From ca08738c7b146238319aa229b786c650dfcce645 Mon Sep 17 00:00:00 2001
From: "@E181658" <celine.hourcade@etu.univ-nantes.fr>
Date: Wed, 21 Sep 2022 10:56:01 +0200
Subject: [PATCH] Creation of a single "discrim" function, csv file required
 for all modes

---
 README.md                                |  11 +-
 data_process.py                          |  13 +-
 output_demo/prediction_network_level.csv |  12 +-
 output_demo/prediction_station_level.csv | 118 +++++++--------
 prediction.py                            | 178 ++++++++---------------
 run.py                                   |  34 ++---
 6 files changed, 148 insertions(+), 218 deletions(-)

diff --git a/README.md b/README.md
index f43d6c9..addcee4 100644
--- a/README.md
+++ b/README.md
@@ -40,6 +40,9 @@ To apply the algorithm, we need a folder architecture:
 
 Each mseed file corresponds to the raw 3 component recordings of 60 sec.
 
+A csv file is also required to apply the algorithm. 
+It is composed of a column with folders/events to discriminate. If the -valid option is specified, the file must have a second column with the label associated with the event. 
+
 ## Trained model
 
 Located in directory: model/model_2021354T1554.h5
@@ -47,11 +50,12 @@ Located in directory: model/model_2021354T1554.h5
 
 ## Prediction
 
+As input, you need a one-column csv file with the folders/events to be discriminated.
 The algorithm supports the mseed data format.
 
 -  mseed format
 ```
- python run.py --mode pred --data_dir ./mseed_demo --output_dir ./output_demo
+ python run.py --data_dir ./mseed_demo --spectro_dir ./spectro_demo --output_dir ./output_demo --csv_dir demo_file.csv
 ```
 
 Output files are automatically saved in "output_demo".
@@ -84,12 +88,13 @@ event and then average them. This gives us an event-based classification.
 
 ## Validation
 
-This mode can be used if the label is known. As input you need a csv file with the associated label for each event.
+This mode can be used if the label is known. The -valid argument is then specified.
+As input you need a csv file of two columns with the associated label for each event.
 The algorithm supports the mseed data format.
 
 -  mseed format
 ```
- python run.py --mode valid --csv_dir ./demo_file.csv --data_dir ./mseed_demo --output_dir ./output_demo 
+ python run.py --data_dir ./mseed_demo --spectro_dir ./spectro_demo --output_dir ./output_demo --csv_dir demo_file.csv --valid
 ```
 
 Output files are automatically saved in "output_demo".
diff --git a/data_process.py b/data_process.py
index d1a5e2a..1365045 100644
--- a/data_process.py
+++ b/data_process.py
@@ -71,7 +71,7 @@ def get_fft(trace: op.core.trace.Trace, WINDOW_LENGTH: int,
 
 
 def spectro_extract(data_dir: str, spectro_dir: str,
-                    events_list: list = None) -> None:
+                    events_list: list) -> None:
     """
         Compute the spectrograms that will be used for the validation.
         The matrices are saved as NumPy objects.
@@ -87,10 +87,8 @@ def spectro_extract(data_dir: str, spectro_dir: str,
     WINDOW_LENGTH = 1
     OVERLAP = (1 - 0.75)
 
-    if events_list is None:
-        events = glob.glob(f'{data_dir}/*')
-    else:
-        events = events_list
+    
+    events = events_list
 
     print(f'Number of events: {len(events)}')
     nb_evt = 0
@@ -99,10 +97,7 @@ def spectro_extract(data_dir: str, spectro_dir: str,
         print('*****************')
         print(f'EVENT {nb_evt} / {len(events)}')
 
-        if events_list is None:
-            time = events[a].split('/')[-1]
-        else:
-            time = events[a][0]
+        time = events[a][0]
 
         if not os.path.exists(f'{spectro_dir}/{time}'):
             os.makedirs(f'{spectro_dir}/{time}')
diff --git a/output_demo/prediction_network_level.csv b/output_demo/prediction_network_level.csv
index 776a646..05f7158 100644
--- a/output_demo/prediction_network_level.csv
+++ b/output_demo/prediction_network_level.csv
@@ -1,10 +1,10 @@
 event,prob_nat,prob_ant,pred,nature
-2022004T111040,0.006,0.994,1,Anthropogenic
-2022001T213524,0.961,0.039,0,Natural
-2022003T080315,0.043,0.957,1,Anthropogenic
-2022003T084110,0.019,0.981,1,Anthropogenic
 2022004T134407,0.77,0.23,0,Natural
+2022003T041502,0.988,0.012,0,Natural
+2022001T213524,0.961,0.039,0,Natural
+2022004T111745,0.005,0.995,1,Anthropogenic
+2022004T111040,0.006,0.994,1,Anthropogenic
 2022004T105235,0.011,0.989,1,Anthropogenic
 2022004T101445,0.015,0.985,1,Anthropogenic
-2022004T111745,0.005,0.995,1,Anthropogenic
-2022003T041502,0.988,0.012,0,Natural
+2022003T084110,0.019,0.981,1,Anthropogenic
+2022003T080315,0.043,0.957,1,Anthropogenic
diff --git a/output_demo/prediction_station_level.csv b/output_demo/prediction_station_level.csv
index 4a8b717..c99fc26 100644
--- a/output_demo/prediction_station_level.csv
+++ b/output_demo/prediction_station_level.csv
@@ -1,4 +1,39 @@
 file_name,station,prob_nat,prob_ant,pred,nature
+FR_ABJF_2022004T134407,ABJF,0.044,0.956,1,Anthropogenic
+FR_CHLF_2022004T134407,CHLF,0.962,0.038,0,Natural
+FR_LBL_2022004T134407,LBL,0.216,0.784,1,Anthropogenic
+FR_GARF_2022004T134407,GARF,0.982,0.018,0,Natural
+FR_GNEF_2022004T134407,GNEF,0.914,0.086,0,Natural
+FR_VERF_2022004T134407,VERF,0.985,0.015,0,Natural
+FR_GZNF_2022004T134407,GZNF,0.985,0.015,0,Natural
+FR_COLF_2022004T134407,COLF,0.808,0.192,0,Natural
+FR_HRSF_2022004T134407,HRSF,0.989,0.011,0,Natural
+FR_BRGF_2022004T134407,BRGF,0.637,0.363,0,Natural
+FR_FRNF_2022004T134407,FRNF,0.952,0.048,0,Natural
+FR_PLDF_2022003T041502,PLDF,0.989,0.011,0,Natural
+FR_CHLF_2022003T041502,CHLF,0.985,0.015,0,Natural
+FR_AGO_2022003T041502,AGO,0.999,0.001,0,Natural
+FR_BRGF_2022003T041502,BRGF,0.988,0.012,0,Natural
+FR_FRNF_2022003T041502,FRNF,0.985,0.015,0,Natural
+FR_COLF_2022003T041502,COLF,0.986,0.014,0,Natural
+FR_HRSF_2022003T041502,HRSF,0.987,0.013,0,Natural
+FR_GZNF_2022003T041502,GZNF,0.997,0.003,0,Natural
+FR_VERF_2022003T041502,VERF,0.985,0.015,0,Natural
+FR_LBL_2022003T041502,LBL,0.977,0.023,0,Natural
+FR_MTNF_2022001T213524,MTNF,0.924,0.076,0,Natural
+FR_SROF_2022001T213524,SROF,0.961,0.039,0,Natural
+FR_PLOF_2022001T213524,PLOF,0.949,0.051,0,Natural
+FR_DAUF_2022001T213524,DAUF,0.996,0.004,0,Natural
+FR_PLEF_2022001T213524,PLEF,0.989,0.011,0,Natural
+FR_RIAF_2022001T213524,RIAF,0.999,0.001,0,Natural
+FR_LOCF_2022001T213524,LOCF,0.866,0.134,0,Natural
+FR_SOMF_2022001T213524,SOMF,0.944,0.056,0,Natural
+FR_LOUF_2022001T213524,LOUF,0.992,0.008,0,Natural
+FR_BESN_2022001T213524,BESN,0.991,0.009,0,Natural
+FR_PLEF_2022004T111745,PLEF,0.011,0.989,1,Anthropogenic
+FR_RIAF_2022004T111745,RIAF,0.006,0.994,1,Anthropogenic
+FR_BEGF_2022004T111745,BEGF,0.001,0.999,1,Anthropogenic
+FR_MTNF_2022004T111745,MTNF,0.001,0.999,1,Anthropogenic
 FR_DAUF_2022004T111040,DAUF,0.001,0.999,1,Anthropogenic
 FR_LEUC_2022004T111040,LEUC,0.002,0.998,1,Anthropogenic
 FR_CRNF_2022004T111040,CRNF,0.014,0.986,1,Anthropogenic
@@ -14,51 +49,6 @@ FR_SLVF_2022004T111040,SLVF,0.001,0.999,1,Anthropogenic
 FR_VERF_2022004T111040 2,VERF,0.014,0.986,1,Anthropogenic
 FR_BRGF_2022004T111040,BRGF,0.005,0.995,1,Anthropogenic
 FR_HRSF_2022004T111040,HRSF,0.003,0.997,1,Anthropogenic
-FR_MTNF_2022001T213524,MTNF,0.924,0.076,0,Natural
-FR_SROF_2022001T213524,SROF,0.961,0.039,0,Natural
-FR_PLOF_2022001T213524,PLOF,0.949,0.051,0,Natural
-FR_DAUF_2022001T213524,DAUF,0.996,0.004,0,Natural
-FR_PLEF_2022001T213524,PLEF,0.989,0.011,0,Natural
-FR_RIAF_2022001T213524,RIAF,0.999,0.001,0,Natural
-FR_LOCF_2022001T213524,LOCF,0.866,0.134,0,Natural
-FR_SOMF_2022001T213524,SOMF,0.944,0.056,0,Natural
-FR_LOUF_2022001T213524,LOUF,0.992,0.008,0,Natural
-FR_BESN_2022001T213524,BESN,0.991,0.009,0,Natural
-FR_CRNF_2022003T080315,CRNF,0.031,0.969,1,Anthropogenic
-FR_CHIF_2022003T080315 2,CHIF,0.047,0.953,1,Anthropogenic
-FR_GNEF_2022003T080315 2,GNEF,0.014,0.986,1,Anthropogenic
-FR_GARF_2022003T080315 2,GARF,0.008,0.992,1,Anthropogenic
-FR_LGIF_2022003T080315 2,LGIF,0.014,0.986,1,Anthropogenic
-FR_CHIF_2022003T080315,CHIF,0.047,0.953,1,Anthropogenic
-FR_BRGF_2022003T080315 2,BRGF,0.162,0.838,1,Anthropogenic
-FR_ABJF_2022003T080315 2,ABJF,0.04,0.96,1,Anthropogenic
-FR_CRNF_2022003T080315 2,CRNF,0.031,0.969,1,Anthropogenic
-FR_ABJF_2022003T080315,ABJF,0.04,0.96,1,Anthropogenic
-FR_BSCF_2022003T080315 2,BSCF,0.03,0.97,1,Anthropogenic
-FR_GNEF_2022003T080315,GNEF,0.014,0.986,1,Anthropogenic
-FR_GARF_2022003T080315,GARF,0.008,0.992,1,Anthropogenic
-FR_LGIF_2022003T080315,LGIF,0.014,0.986,1,Anthropogenic
-FR_BSCF_2022003T080315,BSCF,0.03,0.97,1,Anthropogenic
-FR_BRGF_2022003T080315,BRGF,0.162,0.838,1,Anthropogenic
-FR_CHIF_2022003T084110,CHIF,0.006,0.994,1,Anthropogenic
-FR_RIAF_2022003T084110,RIAF,0.039,0.961,1,Anthropogenic
-FR_CRNF_2022003T084110,CRNF,0.014,0.986,1,Anthropogenic
-FR_BOUF_2022003T084110,BOUF,0.054,0.946,1,Anthropogenic
-FR_DAUF_2022003T084110,DAUF,0.011,0.989,1,Anthropogenic
-FR_GNEF_2022003T084110,GNEF,0.011,0.989,1,Anthropogenic
-FR_LGIF_2022003T084110,LGIF,0.012,0.988,1,Anthropogenic
-FR_BSCF_2022003T084110,BSCF,0.004,0.996,1,Anthropogenic
-FR_ABJF_2022004T134407,ABJF,0.044,0.956,1,Anthropogenic
-FR_CHLF_2022004T134407,CHLF,0.962,0.038,0,Natural
-FR_LBL_2022004T134407,LBL,0.216,0.784,1,Anthropogenic
-FR_GARF_2022004T134407,GARF,0.982,0.018,0,Natural
-FR_GNEF_2022004T134407,GNEF,0.914,0.086,0,Natural
-FR_VERF_2022004T134407,VERF,0.985,0.015,0,Natural
-FR_GZNF_2022004T134407,GZNF,0.985,0.015,0,Natural
-FR_COLF_2022004T134407,COLF,0.808,0.192,0,Natural
-FR_HRSF_2022004T134407,HRSF,0.989,0.011,0,Natural
-FR_BRGF_2022004T134407,BRGF,0.637,0.363,0,Natural
-FR_FRNF_2022004T134407,FRNF,0.952,0.048,0,Natural
 FR_DAUF_2022004T105235,DAUF,0.003,0.997,1,Anthropogenic
 FR_BEGF_2022004T105235,BEGF,0.009,0.991,1,Anthropogenic
 FR_BOUF_2022004T105235,BOUF,0.022,0.978,1,Anthropogenic
@@ -78,17 +68,27 @@ FR_SROF_2022004T101445,SROF,0.061,0.939,1,Anthropogenic
 FR_MTNF_2022004T101445,MTNF,0.006,0.994,1,Anthropogenic
 FR_GUEF_2022004T101445,GUEF,0.002,0.998,1,Anthropogenic
 FR_OLIV_2022004T101445,OLIV,0.004,0.996,1,Anthropogenic
-FR_PLEF_2022004T111745,PLEF,0.011,0.989,1,Anthropogenic
-FR_RIAF_2022004T111745,RIAF,0.006,0.994,1,Anthropogenic
-FR_BEGF_2022004T111745,BEGF,0.001,0.999,1,Anthropogenic
-FR_MTNF_2022004T111745,MTNF,0.001,0.999,1,Anthropogenic
-FR_PLDF_2022003T041502,PLDF,0.989,0.011,0,Natural
-FR_CHLF_2022003T041502,CHLF,0.985,0.015,0,Natural
-FR_AGO_2022003T041502,AGO,0.999,0.001,0,Natural
-FR_BRGF_2022003T041502,BRGF,0.988,0.012,0,Natural
-FR_FRNF_2022003T041502,FRNF,0.985,0.015,0,Natural
-FR_COLF_2022003T041502,COLF,0.986,0.014,0,Natural
-FR_HRSF_2022003T041502,HRSF,0.987,0.013,0,Natural
-FR_GZNF_2022003T041502,GZNF,0.997,0.003,0,Natural
-FR_VERF_2022003T041502,VERF,0.985,0.015,0,Natural
-FR_LBL_2022003T041502,LBL,0.977,0.023,0,Natural
+FR_CHIF_2022003T084110,CHIF,0.006,0.994,1,Anthropogenic
+FR_RIAF_2022003T084110,RIAF,0.039,0.961,1,Anthropogenic
+FR_CRNF_2022003T084110,CRNF,0.014,0.986,1,Anthropogenic
+FR_BOUF_2022003T084110,BOUF,0.054,0.946,1,Anthropogenic
+FR_DAUF_2022003T084110,DAUF,0.011,0.989,1,Anthropogenic
+FR_GNEF_2022003T084110,GNEF,0.011,0.989,1,Anthropogenic
+FR_LGIF_2022003T084110,LGIF,0.012,0.988,1,Anthropogenic
+FR_BSCF_2022003T084110,BSCF,0.004,0.996,1,Anthropogenic
+FR_CRNF_2022003T080315,CRNF,0.031,0.969,1,Anthropogenic
+FR_CHIF_2022003T080315 2,CHIF,0.047,0.953,1,Anthropogenic
+FR_GNEF_2022003T080315 2,GNEF,0.014,0.986,1,Anthropogenic
+FR_GARF_2022003T080315 2,GARF,0.008,0.992,1,Anthropogenic
+FR_LGIF_2022003T080315 2,LGIF,0.014,0.986,1,Anthropogenic
+FR_CHIF_2022003T080315,CHIF,0.047,0.953,1,Anthropogenic
+FR_BRGF_2022003T080315 2,BRGF,0.162,0.838,1,Anthropogenic
+FR_ABJF_2022003T080315 2,ABJF,0.04,0.96,1,Anthropogenic
+FR_CRNF_2022003T080315 2,CRNF,0.031,0.969,1,Anthropogenic
+FR_ABJF_2022003T080315,ABJF,0.04,0.96,1,Anthropogenic
+FR_BSCF_2022003T080315 2,BSCF,0.03,0.97,1,Anthropogenic
+FR_GNEF_2022003T080315,GNEF,0.014,0.986,1,Anthropogenic
+FR_GARF_2022003T080315,GARF,0.008,0.992,1,Anthropogenic
+FR_LGIF_2022003T080315,LGIF,0.014,0.986,1,Anthropogenic
+FR_BSCF_2022003T080315,BSCF,0.03,0.97,1,Anthropogenic
+FR_BRGF_2022003T080315,BRGF,0.162,0.838,1,Anthropogenic
diff --git a/prediction.py b/prediction.py
index eb0dfbc..ae17472 100644
--- a/prediction.py
+++ b/prediction.py
@@ -10,7 +10,7 @@ from numpy import moveaxis
 import tensorflow as tf
 
 
-def pred(model_dir, spectro_dir, output_dir):
+def discrim(spectro_dir, output_dir, event_label, valid):
     """
     Event class prediction.
 
@@ -20,140 +20,63 @@ def pred(model_dir, spectro_dir, output_dir):
     :param spectro_dir: Absolute path to the input spectrograms.
     :type output_dir: str
     :param output_dir: Absolute path where to save to output files.
+    :type event_label: list
+    :param event_label: The class of event to validate.
     """
-
-    csvPr_sta = open(os.path.join(
-        output_dir, 'prediction_station_level.csv'), 'w')
-    predict_sta = csv.writer(csvPr_sta, delimiter=',',
-                             quotechar='"', quoting=csv.QUOTE_MINIMAL)
-    predict_sta.writerow(['file_name',
+    
+    if valid : 
+        filename_csvsta = 'validation_station_level.csv'
+        csvsta_row = ['file_name',
                           'station',
+                          'label_cat',
                           'prob_nat',
                           'prob_ant',
                           'pred',
-                          'nature',
-                          ])
+                          'nature']
 
-    csvPr_net = open(os.path.join(
-        output_dir, 'prediction_network_level.csv'), 'w')
-    predict_net = csv.writer(csvPr_net, delimiter=',',
-                             quotechar='"', quoting=csv.QUOTE_MINIMAL)
-    predict_net.writerow(['event',
+        filename_csvnet = 'validation_network_level.csv'
+        csvnet_row = ['event',
+                          'label_cat',
                           'prob_nat',
                           'prob_ant',
                           'pred',
-                          'nature',
-                          ])
-
-    model = tf.keras.models.load_model(model_dir)
-
-    events = glob.glob(f'{spectro_dir}/*')
-
-    print(f'Number of events: {len(events)}')
-
-    nb_evt = 0
-    for evt in events:
-        nb_evt += 1
-        print('*****************')
-        print(f'EVENT {nb_evt} / {len(events)}')
-
-        time = evt.split('/')[-1]
-        pred_nat = 0
-        pred_ant = 0
-
-        list_spect = glob.glob(f'{spectro_dir}/{time}/*')
-        print(f'Number of station: {len(list_spect)}')
-        nb_st = 0
-        for spect in list_spect:
-            nb_st += 1
-            print(f'Station {nb_st} / {len(list_spect)}', end = "\r")
-            file_name = (spect.split('/')[-1]).split('.npy')[0]
-
-            station = file_name.split('_')[1]
-            spect_file = np.load(f'{spect}', allow_pickle=True)
-            spect_file = [np.array(spect_file)]
-
-            x = moveaxis(spect_file, 1, 3)
-
-            model_output = model.predict(x).round(3)
-            pred = np.argmax(model_output, axis=1)
-
-            if pred == 0:
-                pred_final = 'Natural'
-            if pred == 1:
-                pred_final = 'Anthropogenic'
-
-            predict_sta.writerow([file_name,
-                                  station,
-                                  model_output[0][0],
-                                  model_output[0][1],
-                                  pred[0],
-                                  pred_final,
-                                  ])
+                          'nature']
 
-            pred_nat += model_output[0][0]
-            pred_ant += model_output[0][1]
-
-        pred_total = [pred_nat, pred_ant]
-        pred_total = [(float(i)/sum(pred_total)).round(3) for i in pred_total]
-        pred_event = np.argmax(pred_total)
-        if pred_event == 0:
-            pred_final = 'Natural'
-        if pred_event == 1:
-            pred_final = 'Anthropogenic'
-
-        predict_net.writerow([time,
-                              pred_total[0],
-                              pred_total[1],
-                              pred_event,
-                              pred_final,
-                              ])
-
-
-def valid(model_dir, spectro_dir, output_dir, event_label):
-    """
-    Event class validation.
-
-    :type model_dir: str
-    :param model_dir: Absolute path the the trained model.
-    :type spectro_dir: str
-    :param spectro_dir: Absolute path to the input spectrograms.
-    :type output_dir: str
-    :param output_dir: Absolute path where to save the output files.
-    :type event_label: list
-    :param event_label: The class of event to validate.
-    """
-
-    csvPr_sta = open(os.path.join(
-        output_dir, 'validation_station_level.csv'), 'w')
-    predict_sta = csv.writer(csvPr_sta, delimiter=',',
-                             quotechar='"', quoting=csv.QUOTE_MINIMAL)
-    predict_sta.writerow(['file_name',
+    else : 
+        filename_csvsta = 'prediction_station_level.csv'
+        csvsta_row = ['file_name',
                           'station',
-                          'label_cat',
                           'prob_nat',
                           'prob_ant',
                           'pred',
-                          'nature',
-                          ])
+                          'nature']
 
-    csvPr_net = open(os.path.join(
-        output_dir, 'validation_network_level.csv'), 'w')
-    predict_net = csv.writer(csvPr_net, delimiter=',',
-                             quotechar='"', quoting=csv.QUOTE_MINIMAL)
-    predict_net.writerow(['event',
-                          'label_cat',
+        filename_csvnet = 'prediction_network_level.csv'
+        csvnet_row = ['event',
                           'prob_nat',
                           'prob_ant',
                           'pred',
-                          'nature'])
+                          'nature'
+                          ]
+    
+    csvPr_sta = open(os.path.join(
+        output_dir, filename_csvsta), 'w')
+    predict_sta = csv.writer(csvPr_sta, delimiter=',',
+                             quotechar='"', quoting=csv.QUOTE_MINIMAL)
+    predict_sta.writerow(csvsta_row)
 
-    model = tf.keras.models.load_model(model_dir)
+    csvPr_net = open(os.path.join(
+        output_dir, filename_csvnet), 'w')
+    predict_net = csv.writer(csvPr_net, delimiter=',',
+                             quotechar='"', quoting=csv.QUOTE_MINIMAL)
+    predict_net.writerow(csvnet_row)
+
+    model = tf.keras.models.load_model("./model/model_2021354T1554.h5")
 
     events = glob.glob(f'{spectro_dir}/*')
 
     print(f'Number of events: {len(events)}')
-    print(len(event_label))
+
     nb_evt = 0
     for a in range(len(event_label)):
         nb_evt += 1
@@ -161,8 +84,8 @@ def valid(model_dir, spectro_dir, output_dir, event_label):
         print(f'EVENT {nb_evt} / {len(event_label)}')
         time = event_label[a][0]
 
-        class_ = event_label[a][1]
-        #time = evt.split('/')[2]
+        if valid: 
+            class_ = event_label[a][1]
         pred_nat = 0
         pred_ant = 0
 
@@ -187,8 +110,9 @@ def valid(model_dir, spectro_dir, output_dir, event_label):
                 pred_final = 'Natural'
             if pred == 1:
                 pred_final = 'Anthropogenic'
-
-            predict_sta.writerow([file_name,
+            
+            if valid : 
+                predict_sta.writerow([file_name,
                                   station,
                                   class_,
                                   model_output[0][0],
@@ -196,6 +120,14 @@ def valid(model_dir, spectro_dir, output_dir, event_label):
                                   pred[0],
                                   pred_final,
                                   ])
+            else : 
+                predict_sta.writerow([file_name,
+                                  station,
+                                  model_output[0][0],
+                                  model_output[0][1],
+                                  pred[0],
+                                  pred_final,
+                                  ])
 
             pred_nat += model_output[0][0]
             pred_ant += model_output[0][1]
@@ -207,11 +139,21 @@ def valid(model_dir, spectro_dir, output_dir, event_label):
             pred_final = 'Natural'
         if pred_event == 1:
             pred_final = 'Anthropogenic'
-
-        predict_net.writerow([time,
+        
+        if valid : 
+            predict_net.writerow([time,
                               class_,
                               pred_total[0],
                               pred_total[1],
                               pred_event,
                               pred_final,
                               ])
+        else : 
+            predict_net.writerow([time,
+                              pred_total[0],
+                              pred_total[1],
+                              pred_event,
+                              pred_final,
+                              ])
+
+
diff --git a/run.py b/run.py
index 41090c5..cd4a2b5 100644
--- a/run.py
+++ b/run.py
@@ -5,7 +5,7 @@ import argparse
 
 import numpy as np
 
-from prediction import pred, valid
+from prediction import discrim
 from data_process import spectro_extract
 
 
@@ -13,14 +13,6 @@ def read_args() -> argparse.Namespace:
 
     parser = argparse.ArgumentParser()
 
-    parser.add_argument('--mode',
-                        type=str, default='pred',
-                        help=' "valid" for the validation mode or "pred" for the prediction mode.')
-
-    parser.add_argument('--model_dir',
-                        type=str, default='./model/model_2021354T1554.h5',
-                        help="Model file directory.")
-
     parser.add_argument('--data_dir',
                         type=str, default='./mseed_demo',
                         help="Input mseed file directory.")
@@ -30,36 +22,32 @@ def read_args() -> argparse.Namespace:
                         help='Output spectrogram file directory.')
 
     parser.add_argument('--csv_dir',
-                        default=None,
+                        required=True,
                         help="Input csv file directory")
 
     parser.add_argument('--output_dir',
                         type=str, default='./output_demo',
                         help='Output directory')
 
+    parser.add_argument('--valid', 
+                        action="store_true", 
+                        help=' if the option "valid" is specified the validation mode will be applied. Csv input must have two columns (time, label_cat)')
+
     args = parser.parse_args()
     return args
 
 
 def main(args: argparse.Namespace):
 
-    if args.mode == 'pred':
-        spectro_extract(data_dir=args.data_dir,
-                        spectro_dir=args.spectro_dir)
-        pred(model_dir=args.model_dir, spectro_dir=args.spectro_dir,
-             output_dir=args.output_dir)
 
-    elif args.mode == 'valid':
-        events = np.genfromtxt(
+    events = np.genfromtxt(
             f'{args.csv_dir}', delimiter=',', skip_header=1, dtype=str)
-        spectro_extract(data_dir=args.data_dir,
+    spectro_extract(data_dir=args.data_dir,
                         spectro_dir=args.spectro_dir, events_list=events)
-        valid(model_dir=args.model_dir, spectro_dir=args.spectro_dir,
-              output_dir=args.output_dir, event_label=events)
-
-    else:
-        print("Mode should be: valid, or pred")
+    discrim(spectro_dir=args.spectro_dir,
+              output_dir=args.output_dir, event_label=events, valid = args.valid)
 
+    
     return
 
 
-- 
GitLab