diff --git a/exploration/ex_ml_pipeline.py b/exploration/ex_ml_pipeline.py index 88b09a8..fec5717 100644 --- a/exploration/ex_ml_pipeline.py +++ b/exploration/ex_ml_pipeline.py @@ -228,10 +228,20 @@ labels.set_labels() labels.get_labels("PANAS") # %% -labels.aggregate_labels() +labels.aggregate_labels(cached=False) +labels_calculated = labels.get_aggregated_labels() # %% -labels.get_aggregated_labels() +labels.aggregate_labels(cached=True) +labels_read = labels.get_aggregated_labels() +labels_read = labels_read.reset_index() +labels_read["date_lj"] = labels_read["date_lj"].dt.date +labels_read.set_index(["participant_id", "date_lj"], inplace=True) +# date_lj column is parsed as a date and represented as Timestamp, when read from csv. +# When calculated, it is represented as date. + +# %% +np.isclose(labels_read, labels_calculated).all() # %% model_validation = machine_learning.model.ModelValidation( diff --git a/machine_learning/labels.py b/machine_learning/labels.py index f685e6b..6a59c65 100644 --- a/machine_learning/labels.py +++ b/machine_learning/labels.py @@ -9,7 +9,7 @@ from pyprojroot import here import participants.query_db from features import esm from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME -from machine_learning.helper import to_csv_with_settings +from machine_learning.helper import to_csv_with_settings, read_csv_with_settings WARNING_PARTICIPANTS_LABEL = ( "Before aggregating labels, please set participants label using self.set_participants_label() " @@ -75,33 +75,48 @@ class Labels: else: raise KeyError("This questionnaire has not been implemented as a label.") - def aggregate_labels(self) -> None: + def aggregate_labels(self, cached=True) -> None: print("Aggregating labels ...") - self.df_esm_means = ( - self.df_esm_clean.groupby( - ["participant_id", "questionnaire_id"] + self.grouping_variable + if not self.participants_label: + raise ValueError(WARNING_PARTICIPANTS_LABEL) + + try: + if not cached: # Do not use the file, even if it exists. + raise FileNotFoundError + self.df_esm_means = read_csv_with_settings( + self.folder, + self.filename_prefix, + data_type="_".join(self.questionnaires), + grouping_variable=self.grouping_variable ) - .esm_user_answer_numeric.agg("mean") - .reset_index() - .rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"}) - ) - self.df_esm_means = ( - self.df_esm_means.pivot( - index=["participant_id"] + self.grouping_variable, - columns="questionnaire_id", - values="esm_numeric_mean", + print("Read labels from the file.") + except FileNotFoundError: + # We need to recalculate the features in this case. + self.df_esm_means = ( + self.df_esm_clean.groupby( + ["participant_id", "questionnaire_id"] + self.grouping_variable + ) + .esm_user_answer_numeric.agg("mean") + .reset_index() + .rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"}) + ) + self.df_esm_means = ( + self.df_esm_means.pivot( + index=["participant_id"] + self.grouping_variable, + columns="questionnaire_id", + values="esm_numeric_mean", + ) + .reset_index(col_level=1) + .rename(columns=QUESTIONNAIRE_IDS_RENAME) + .set_index(["participant_id"] + self.grouping_variable) + ) + print("Labels aggregated.") + to_csv_with_settings( + self.df_esm_means, + self.folder, + self.filename_prefix, + data_type="_".join(self.questionnaires), ) - .reset_index(col_level=1) - .rename(columns=QUESTIONNAIRE_IDS_RENAME) - .set_index(["participant_id"] + self.grouping_variable) - ) - print("Labels aggregated.") - to_csv_with_settings( - self.df_esm_means, - self.folder, - self.filename_prefix, - data_type="_".join(self.questionnaires), - ) def get_aggregated_labels(self) -> pd.DataFrame: return self.df_esm_means