From ca08738c7b146238319aa229b786c650dfcce645 Mon Sep 17 00:00:00 2001 From: "@E181658" <celine.hourcade@etu.univ-nantes.fr> Date: Wed, 21 Sep 2022 10:56:01 +0200 Subject: [PATCH] Creation of a single "discrim" function, csv file required for all modes --- README.md | 11 +- data_process.py | 13 +- output_demo/prediction_network_level.csv | 12 +- output_demo/prediction_station_level.csv | 118 +++++++-------- prediction.py | 178 ++++++++--------------- run.py | 34 ++--- 6 files changed, 148 insertions(+), 218 deletions(-) diff --git a/README.md b/README.md index f43d6c9..addcee4 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,9 @@ To apply the algorithm, we need a folder architecture: Each mseed file corresponds to the raw 3 component recordings of 60 sec. +A csv file is also required to apply the algorithm. +It is composed of a column with folders/events to discriminate. If the -valid option is specified, the file must have a second column with the label associated with the event. + ## Trained model Located in directory: model/model_2021354T1554.h5 @@ -47,11 +50,12 @@ Located in directory: model/model_2021354T1554.h5 ## Prediction +As input, you need a one-column csv file with the folders/events to be discriminated. The algorithm supports the mseed data format. - mseed format ``` - python run.py --mode pred --data_dir ./mseed_demo --output_dir ./output_demo + python run.py --data_dir ./mseed_demo --spectro_dir ./spectro_demo --output_dir ./output_demo --csv_dir demo_file.csv ``` Output files are automatically saved in "output_demo". @@ -84,12 +88,13 @@ event and then average them. This gives us an event-based classification. ## Validation -This mode can be used if the label is known. As input you need a csv file with the associated label for each event. +This mode can be used if the label is known. The -valid argument is then specified. +As input you need a csv file of two columns with the associated label for each event. The algorithm supports the mseed data format. - mseed format ``` - python run.py --mode valid --csv_dir ./demo_file.csv --data_dir ./mseed_demo --output_dir ./output_demo + python run.py --data_dir ./mseed_demo --spectro_dir ./spectro_demo --output_dir ./output_demo --csv_dir demo_file.csv --valid ``` Output files are automatically saved in "output_demo". diff --git a/data_process.py b/data_process.py index d1a5e2a..1365045 100644 --- a/data_process.py +++ b/data_process.py @@ -71,7 +71,7 @@ def get_fft(trace: op.core.trace.Trace, WINDOW_LENGTH: int, def spectro_extract(data_dir: str, spectro_dir: str, - events_list: list = None) -> None: + events_list: list) -> None: """ Compute the spectrograms that will be used for the validation. The matrices are saved as NumPy objects. @@ -87,10 +87,8 @@ def spectro_extract(data_dir: str, spectro_dir: str, WINDOW_LENGTH = 1 OVERLAP = (1 - 0.75) - if events_list is None: - events = glob.glob(f'{data_dir}/*') - else: - events = events_list + + events = events_list print(f'Number of events: {len(events)}') nb_evt = 0 @@ -99,10 +97,7 @@ def spectro_extract(data_dir: str, spectro_dir: str, print('*****************') print(f'EVENT {nb_evt} / {len(events)}') - if events_list is None: - time = events[a].split('/')[-1] - else: - time = events[a][0] + time = events[a][0] if not os.path.exists(f'{spectro_dir}/{time}'): os.makedirs(f'{spectro_dir}/{time}') diff --git a/output_demo/prediction_network_level.csv b/output_demo/prediction_network_level.csv index 776a646..05f7158 100644 --- a/output_demo/prediction_network_level.csv +++ b/output_demo/prediction_network_level.csv @@ -1,10 +1,10 @@ event,prob_nat,prob_ant,pred,nature -2022004T111040,0.006,0.994,1,Anthropogenic -2022001T213524,0.961,0.039,0,Natural -2022003T080315,0.043,0.957,1,Anthropogenic -2022003T084110,0.019,0.981,1,Anthropogenic 2022004T134407,0.77,0.23,0,Natural +2022003T041502,0.988,0.012,0,Natural +2022001T213524,0.961,0.039,0,Natural +2022004T111745,0.005,0.995,1,Anthropogenic +2022004T111040,0.006,0.994,1,Anthropogenic 2022004T105235,0.011,0.989,1,Anthropogenic 2022004T101445,0.015,0.985,1,Anthropogenic -2022004T111745,0.005,0.995,1,Anthropogenic -2022003T041502,0.988,0.012,0,Natural +2022003T084110,0.019,0.981,1,Anthropogenic +2022003T080315,0.043,0.957,1,Anthropogenic diff --git a/output_demo/prediction_station_level.csv b/output_demo/prediction_station_level.csv index 4a8b717..c99fc26 100644 --- a/output_demo/prediction_station_level.csv +++ b/output_demo/prediction_station_level.csv @@ -1,4 +1,39 @@ file_name,station,prob_nat,prob_ant,pred,nature +FR_ABJF_2022004T134407,ABJF,0.044,0.956,1,Anthropogenic +FR_CHLF_2022004T134407,CHLF,0.962,0.038,0,Natural +FR_LBL_2022004T134407,LBL,0.216,0.784,1,Anthropogenic +FR_GARF_2022004T134407,GARF,0.982,0.018,0,Natural +FR_GNEF_2022004T134407,GNEF,0.914,0.086,0,Natural +FR_VERF_2022004T134407,VERF,0.985,0.015,0,Natural +FR_GZNF_2022004T134407,GZNF,0.985,0.015,0,Natural +FR_COLF_2022004T134407,COLF,0.808,0.192,0,Natural +FR_HRSF_2022004T134407,HRSF,0.989,0.011,0,Natural +FR_BRGF_2022004T134407,BRGF,0.637,0.363,0,Natural +FR_FRNF_2022004T134407,FRNF,0.952,0.048,0,Natural +FR_PLDF_2022003T041502,PLDF,0.989,0.011,0,Natural +FR_CHLF_2022003T041502,CHLF,0.985,0.015,0,Natural +FR_AGO_2022003T041502,AGO,0.999,0.001,0,Natural +FR_BRGF_2022003T041502,BRGF,0.988,0.012,0,Natural +FR_FRNF_2022003T041502,FRNF,0.985,0.015,0,Natural +FR_COLF_2022003T041502,COLF,0.986,0.014,0,Natural +FR_HRSF_2022003T041502,HRSF,0.987,0.013,0,Natural +FR_GZNF_2022003T041502,GZNF,0.997,0.003,0,Natural +FR_VERF_2022003T041502,VERF,0.985,0.015,0,Natural +FR_LBL_2022003T041502,LBL,0.977,0.023,0,Natural +FR_MTNF_2022001T213524,MTNF,0.924,0.076,0,Natural +FR_SROF_2022001T213524,SROF,0.961,0.039,0,Natural +FR_PLOF_2022001T213524,PLOF,0.949,0.051,0,Natural +FR_DAUF_2022001T213524,DAUF,0.996,0.004,0,Natural +FR_PLEF_2022001T213524,PLEF,0.989,0.011,0,Natural +FR_RIAF_2022001T213524,RIAF,0.999,0.001,0,Natural +FR_LOCF_2022001T213524,LOCF,0.866,0.134,0,Natural +FR_SOMF_2022001T213524,SOMF,0.944,0.056,0,Natural +FR_LOUF_2022001T213524,LOUF,0.992,0.008,0,Natural +FR_BESN_2022001T213524,BESN,0.991,0.009,0,Natural +FR_PLEF_2022004T111745,PLEF,0.011,0.989,1,Anthropogenic +FR_RIAF_2022004T111745,RIAF,0.006,0.994,1,Anthropogenic +FR_BEGF_2022004T111745,BEGF,0.001,0.999,1,Anthropogenic +FR_MTNF_2022004T111745,MTNF,0.001,0.999,1,Anthropogenic FR_DAUF_2022004T111040,DAUF,0.001,0.999,1,Anthropogenic FR_LEUC_2022004T111040,LEUC,0.002,0.998,1,Anthropogenic FR_CRNF_2022004T111040,CRNF,0.014,0.986,1,Anthropogenic @@ -14,51 +49,6 @@ FR_SLVF_2022004T111040,SLVF,0.001,0.999,1,Anthropogenic FR_VERF_2022004T111040 2,VERF,0.014,0.986,1,Anthropogenic FR_BRGF_2022004T111040,BRGF,0.005,0.995,1,Anthropogenic FR_HRSF_2022004T111040,HRSF,0.003,0.997,1,Anthropogenic -FR_MTNF_2022001T213524,MTNF,0.924,0.076,0,Natural -FR_SROF_2022001T213524,SROF,0.961,0.039,0,Natural -FR_PLOF_2022001T213524,PLOF,0.949,0.051,0,Natural -FR_DAUF_2022001T213524,DAUF,0.996,0.004,0,Natural -FR_PLEF_2022001T213524,PLEF,0.989,0.011,0,Natural -FR_RIAF_2022001T213524,RIAF,0.999,0.001,0,Natural -FR_LOCF_2022001T213524,LOCF,0.866,0.134,0,Natural -FR_SOMF_2022001T213524,SOMF,0.944,0.056,0,Natural -FR_LOUF_2022001T213524,LOUF,0.992,0.008,0,Natural -FR_BESN_2022001T213524,BESN,0.991,0.009,0,Natural -FR_CRNF_2022003T080315,CRNF,0.031,0.969,1,Anthropogenic -FR_CHIF_2022003T080315 2,CHIF,0.047,0.953,1,Anthropogenic -FR_GNEF_2022003T080315 2,GNEF,0.014,0.986,1,Anthropogenic -FR_GARF_2022003T080315 2,GARF,0.008,0.992,1,Anthropogenic -FR_LGIF_2022003T080315 2,LGIF,0.014,0.986,1,Anthropogenic -FR_CHIF_2022003T080315,CHIF,0.047,0.953,1,Anthropogenic -FR_BRGF_2022003T080315 2,BRGF,0.162,0.838,1,Anthropogenic -FR_ABJF_2022003T080315 2,ABJF,0.04,0.96,1,Anthropogenic -FR_CRNF_2022003T080315 2,CRNF,0.031,0.969,1,Anthropogenic -FR_ABJF_2022003T080315,ABJF,0.04,0.96,1,Anthropogenic -FR_BSCF_2022003T080315 2,BSCF,0.03,0.97,1,Anthropogenic -FR_GNEF_2022003T080315,GNEF,0.014,0.986,1,Anthropogenic -FR_GARF_2022003T080315,GARF,0.008,0.992,1,Anthropogenic -FR_LGIF_2022003T080315,LGIF,0.014,0.986,1,Anthropogenic -FR_BSCF_2022003T080315,BSCF,0.03,0.97,1,Anthropogenic -FR_BRGF_2022003T080315,BRGF,0.162,0.838,1,Anthropogenic -FR_CHIF_2022003T084110,CHIF,0.006,0.994,1,Anthropogenic -FR_RIAF_2022003T084110,RIAF,0.039,0.961,1,Anthropogenic -FR_CRNF_2022003T084110,CRNF,0.014,0.986,1,Anthropogenic -FR_BOUF_2022003T084110,BOUF,0.054,0.946,1,Anthropogenic -FR_DAUF_2022003T084110,DAUF,0.011,0.989,1,Anthropogenic -FR_GNEF_2022003T084110,GNEF,0.011,0.989,1,Anthropogenic -FR_LGIF_2022003T084110,LGIF,0.012,0.988,1,Anthropogenic -FR_BSCF_2022003T084110,BSCF,0.004,0.996,1,Anthropogenic -FR_ABJF_2022004T134407,ABJF,0.044,0.956,1,Anthropogenic -FR_CHLF_2022004T134407,CHLF,0.962,0.038,0,Natural -FR_LBL_2022004T134407,LBL,0.216,0.784,1,Anthropogenic -FR_GARF_2022004T134407,GARF,0.982,0.018,0,Natural -FR_GNEF_2022004T134407,GNEF,0.914,0.086,0,Natural -FR_VERF_2022004T134407,VERF,0.985,0.015,0,Natural -FR_GZNF_2022004T134407,GZNF,0.985,0.015,0,Natural -FR_COLF_2022004T134407,COLF,0.808,0.192,0,Natural -FR_HRSF_2022004T134407,HRSF,0.989,0.011,0,Natural -FR_BRGF_2022004T134407,BRGF,0.637,0.363,0,Natural -FR_FRNF_2022004T134407,FRNF,0.952,0.048,0,Natural FR_DAUF_2022004T105235,DAUF,0.003,0.997,1,Anthropogenic FR_BEGF_2022004T105235,BEGF,0.009,0.991,1,Anthropogenic FR_BOUF_2022004T105235,BOUF,0.022,0.978,1,Anthropogenic @@ -78,17 +68,27 @@ FR_SROF_2022004T101445,SROF,0.061,0.939,1,Anthropogenic FR_MTNF_2022004T101445,MTNF,0.006,0.994,1,Anthropogenic FR_GUEF_2022004T101445,GUEF,0.002,0.998,1,Anthropogenic FR_OLIV_2022004T101445,OLIV,0.004,0.996,1,Anthropogenic -FR_PLEF_2022004T111745,PLEF,0.011,0.989,1,Anthropogenic -FR_RIAF_2022004T111745,RIAF,0.006,0.994,1,Anthropogenic -FR_BEGF_2022004T111745,BEGF,0.001,0.999,1,Anthropogenic -FR_MTNF_2022004T111745,MTNF,0.001,0.999,1,Anthropogenic -FR_PLDF_2022003T041502,PLDF,0.989,0.011,0,Natural -FR_CHLF_2022003T041502,CHLF,0.985,0.015,0,Natural -FR_AGO_2022003T041502,AGO,0.999,0.001,0,Natural -FR_BRGF_2022003T041502,BRGF,0.988,0.012,0,Natural -FR_FRNF_2022003T041502,FRNF,0.985,0.015,0,Natural -FR_COLF_2022003T041502,COLF,0.986,0.014,0,Natural -FR_HRSF_2022003T041502,HRSF,0.987,0.013,0,Natural -FR_GZNF_2022003T041502,GZNF,0.997,0.003,0,Natural -FR_VERF_2022003T041502,VERF,0.985,0.015,0,Natural -FR_LBL_2022003T041502,LBL,0.977,0.023,0,Natural +FR_CHIF_2022003T084110,CHIF,0.006,0.994,1,Anthropogenic +FR_RIAF_2022003T084110,RIAF,0.039,0.961,1,Anthropogenic +FR_CRNF_2022003T084110,CRNF,0.014,0.986,1,Anthropogenic +FR_BOUF_2022003T084110,BOUF,0.054,0.946,1,Anthropogenic +FR_DAUF_2022003T084110,DAUF,0.011,0.989,1,Anthropogenic +FR_GNEF_2022003T084110,GNEF,0.011,0.989,1,Anthropogenic +FR_LGIF_2022003T084110,LGIF,0.012,0.988,1,Anthropogenic +FR_BSCF_2022003T084110,BSCF,0.004,0.996,1,Anthropogenic +FR_CRNF_2022003T080315,CRNF,0.031,0.969,1,Anthropogenic +FR_CHIF_2022003T080315 2,CHIF,0.047,0.953,1,Anthropogenic +FR_GNEF_2022003T080315 2,GNEF,0.014,0.986,1,Anthropogenic +FR_GARF_2022003T080315 2,GARF,0.008,0.992,1,Anthropogenic +FR_LGIF_2022003T080315 2,LGIF,0.014,0.986,1,Anthropogenic +FR_CHIF_2022003T080315,CHIF,0.047,0.953,1,Anthropogenic +FR_BRGF_2022003T080315 2,BRGF,0.162,0.838,1,Anthropogenic +FR_ABJF_2022003T080315 2,ABJF,0.04,0.96,1,Anthropogenic +FR_CRNF_2022003T080315 2,CRNF,0.031,0.969,1,Anthropogenic +FR_ABJF_2022003T080315,ABJF,0.04,0.96,1,Anthropogenic +FR_BSCF_2022003T080315 2,BSCF,0.03,0.97,1,Anthropogenic +FR_GNEF_2022003T080315,GNEF,0.014,0.986,1,Anthropogenic +FR_GARF_2022003T080315,GARF,0.008,0.992,1,Anthropogenic +FR_LGIF_2022003T080315,LGIF,0.014,0.986,1,Anthropogenic +FR_BSCF_2022003T080315,BSCF,0.03,0.97,1,Anthropogenic +FR_BRGF_2022003T080315,BRGF,0.162,0.838,1,Anthropogenic diff --git a/prediction.py b/prediction.py index eb0dfbc..ae17472 100644 --- a/prediction.py +++ b/prediction.py @@ -10,7 +10,7 @@ from numpy import moveaxis import tensorflow as tf -def pred(model_dir, spectro_dir, output_dir): +def discrim(spectro_dir, output_dir, event_label, valid): """ Event class prediction. @@ -20,140 +20,63 @@ def pred(model_dir, spectro_dir, output_dir): :param spectro_dir: Absolute path to the input spectrograms. :type output_dir: str :param output_dir: Absolute path where to save to output files. + :type event_label: list + :param event_label: The class of event to validate. """ - - csvPr_sta = open(os.path.join( - output_dir, 'prediction_station_level.csv'), 'w') - predict_sta = csv.writer(csvPr_sta, delimiter=',', - quotechar='"', quoting=csv.QUOTE_MINIMAL) - predict_sta.writerow(['file_name', + + if valid : + filename_csvsta = 'validation_station_level.csv' + csvsta_row = ['file_name', 'station', + 'label_cat', 'prob_nat', 'prob_ant', 'pred', - 'nature', - ]) + 'nature'] - csvPr_net = open(os.path.join( - output_dir, 'prediction_network_level.csv'), 'w') - predict_net = csv.writer(csvPr_net, delimiter=',', - quotechar='"', quoting=csv.QUOTE_MINIMAL) - predict_net.writerow(['event', + filename_csvnet = 'validation_network_level.csv' + csvnet_row = ['event', + 'label_cat', 'prob_nat', 'prob_ant', 'pred', - 'nature', - ]) - - model = tf.keras.models.load_model(model_dir) - - events = glob.glob(f'{spectro_dir}/*') - - print(f'Number of events: {len(events)}') - - nb_evt = 0 - for evt in events: - nb_evt += 1 - print('*****************') - print(f'EVENT {nb_evt} / {len(events)}') - - time = evt.split('/')[-1] - pred_nat = 0 - pred_ant = 0 - - list_spect = glob.glob(f'{spectro_dir}/{time}/*') - print(f'Number of station: {len(list_spect)}') - nb_st = 0 - for spect in list_spect: - nb_st += 1 - print(f'Station {nb_st} / {len(list_spect)}', end = "\r") - file_name = (spect.split('/')[-1]).split('.npy')[0] - - station = file_name.split('_')[1] - spect_file = np.load(f'{spect}', allow_pickle=True) - spect_file = [np.array(spect_file)] - - x = moveaxis(spect_file, 1, 3) - - model_output = model.predict(x).round(3) - pred = np.argmax(model_output, axis=1) - - if pred == 0: - pred_final = 'Natural' - if pred == 1: - pred_final = 'Anthropogenic' - - predict_sta.writerow([file_name, - station, - model_output[0][0], - model_output[0][1], - pred[0], - pred_final, - ]) + 'nature'] - pred_nat += model_output[0][0] - pred_ant += model_output[0][1] - - pred_total = [pred_nat, pred_ant] - pred_total = [(float(i)/sum(pred_total)).round(3) for i in pred_total] - pred_event = np.argmax(pred_total) - if pred_event == 0: - pred_final = 'Natural' - if pred_event == 1: - pred_final = 'Anthropogenic' - - predict_net.writerow([time, - pred_total[0], - pred_total[1], - pred_event, - pred_final, - ]) - - -def valid(model_dir, spectro_dir, output_dir, event_label): - """ - Event class validation. - - :type model_dir: str - :param model_dir: Absolute path the the trained model. - :type spectro_dir: str - :param spectro_dir: Absolute path to the input spectrograms. - :type output_dir: str - :param output_dir: Absolute path where to save the output files. - :type event_label: list - :param event_label: The class of event to validate. - """ - - csvPr_sta = open(os.path.join( - output_dir, 'validation_station_level.csv'), 'w') - predict_sta = csv.writer(csvPr_sta, delimiter=',', - quotechar='"', quoting=csv.QUOTE_MINIMAL) - predict_sta.writerow(['file_name', + else : + filename_csvsta = 'prediction_station_level.csv' + csvsta_row = ['file_name', 'station', - 'label_cat', 'prob_nat', 'prob_ant', 'pred', - 'nature', - ]) + 'nature'] - csvPr_net = open(os.path.join( - output_dir, 'validation_network_level.csv'), 'w') - predict_net = csv.writer(csvPr_net, delimiter=',', - quotechar='"', quoting=csv.QUOTE_MINIMAL) - predict_net.writerow(['event', - 'label_cat', + filename_csvnet = 'prediction_network_level.csv' + csvnet_row = ['event', 'prob_nat', 'prob_ant', 'pred', - 'nature']) + 'nature' + ] + + csvPr_sta = open(os.path.join( + output_dir, filename_csvsta), 'w') + predict_sta = csv.writer(csvPr_sta, delimiter=',', + quotechar='"', quoting=csv.QUOTE_MINIMAL) + predict_sta.writerow(csvsta_row) - model = tf.keras.models.load_model(model_dir) + csvPr_net = open(os.path.join( + output_dir, filename_csvnet), 'w') + predict_net = csv.writer(csvPr_net, delimiter=',', + quotechar='"', quoting=csv.QUOTE_MINIMAL) + predict_net.writerow(csvnet_row) + + model = tf.keras.models.load_model("./model/model_2021354T1554.h5") events = glob.glob(f'{spectro_dir}/*') print(f'Number of events: {len(events)}') - print(len(event_label)) + nb_evt = 0 for a in range(len(event_label)): nb_evt += 1 @@ -161,8 +84,8 @@ def valid(model_dir, spectro_dir, output_dir, event_label): print(f'EVENT {nb_evt} / {len(event_label)}') time = event_label[a][0] - class_ = event_label[a][1] - #time = evt.split('/')[2] + if valid: + class_ = event_label[a][1] pred_nat = 0 pred_ant = 0 @@ -187,8 +110,9 @@ def valid(model_dir, spectro_dir, output_dir, event_label): pred_final = 'Natural' if pred == 1: pred_final = 'Anthropogenic' - - predict_sta.writerow([file_name, + + if valid : + predict_sta.writerow([file_name, station, class_, model_output[0][0], @@ -196,6 +120,14 @@ def valid(model_dir, spectro_dir, output_dir, event_label): pred[0], pred_final, ]) + else : + predict_sta.writerow([file_name, + station, + model_output[0][0], + model_output[0][1], + pred[0], + pred_final, + ]) pred_nat += model_output[0][0] pred_ant += model_output[0][1] @@ -207,11 +139,21 @@ def valid(model_dir, spectro_dir, output_dir, event_label): pred_final = 'Natural' if pred_event == 1: pred_final = 'Anthropogenic' - - predict_net.writerow([time, + + if valid : + predict_net.writerow([time, class_, pred_total[0], pred_total[1], pred_event, pred_final, ]) + else : + predict_net.writerow([time, + pred_total[0], + pred_total[1], + pred_event, + pred_final, + ]) + + diff --git a/run.py b/run.py index 41090c5..cd4a2b5 100644 --- a/run.py +++ b/run.py @@ -5,7 +5,7 @@ import argparse import numpy as np -from prediction import pred, valid +from prediction import discrim from data_process import spectro_extract @@ -13,14 +13,6 @@ def read_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - parser.add_argument('--mode', - type=str, default='pred', - help=' "valid" for the validation mode or "pred" for the prediction mode.') - - parser.add_argument('--model_dir', - type=str, default='./model/model_2021354T1554.h5', - help="Model file directory.") - parser.add_argument('--data_dir', type=str, default='./mseed_demo', help="Input mseed file directory.") @@ -30,36 +22,32 @@ def read_args() -> argparse.Namespace: help='Output spectrogram file directory.') parser.add_argument('--csv_dir', - default=None, + required=True, help="Input csv file directory") parser.add_argument('--output_dir', type=str, default='./output_demo', help='Output directory') + parser.add_argument('--valid', + action="store_true", + help=' if the option "valid" is specified the validation mode will be applied. Csv input must have two columns (time, label_cat)') + args = parser.parse_args() return args def main(args: argparse.Namespace): - if args.mode == 'pred': - spectro_extract(data_dir=args.data_dir, - spectro_dir=args.spectro_dir) - pred(model_dir=args.model_dir, spectro_dir=args.spectro_dir, - output_dir=args.output_dir) - elif args.mode == 'valid': - events = np.genfromtxt( + events = np.genfromtxt( f'{args.csv_dir}', delimiter=',', skip_header=1, dtype=str) - spectro_extract(data_dir=args.data_dir, + spectro_extract(data_dir=args.data_dir, spectro_dir=args.spectro_dir, events_list=events) - valid(model_dir=args.model_dir, spectro_dir=args.spectro_dir, - output_dir=args.output_dir, event_label=events) - - else: - print("Mode should be: valid, or pred") + discrim(spectro_dir=args.spectro_dir, + output_dir=args.output_dir, event_label=events, valid = args.valid) + return -- GitLab