diff --git a/exploration/ex_ml_pipeline.py b/exploration/ex_ml_pipeline.py index 681cd78..88b09a8 100644 --- a/exploration/ex_ml_pipeline.py +++ b/exploration/ex_ml_pipeline.py @@ -218,14 +218,9 @@ with open("../machine_learning/config/minimal_labels.yaml", "r") as file: # %% labels = machine_learning.labels.Labels(**labels_params) labels.participants_usernames = ptcp_2 +labels.set_participants_label("nokia_0000003") labels.questionnaires -# %% -all_features = sensor_features.get_features("all", "all") - -# %% -all_features.isna().any().any() - # %% labels.set_labels() diff --git a/machine_learning/features_sensor.py b/machine_learning/features_sensor.py index f25e37e..50850f4 100644 --- a/machine_learning/features_sensor.py +++ b/machine_learning/features_sensor.py @@ -125,7 +125,7 @@ class SensorFeatures: self.df_sms = helper.get_date_from_timestamp(self.df_sms) print("Got sms data from the DB.") - def get_sensor_data(self, data_type) -> pd.DataFrame: + def get_sensor_data(self, data_type: str) -> pd.DataFrame: if data_type == "proximity": return self.df_proximity elif data_type == "communication": diff --git a/machine_learning/labels.py b/machine_learning/labels.py index 9c4b968..f685e6b 100644 --- a/machine_learning/labels.py +++ b/machine_learning/labels.py @@ -1,11 +1,21 @@ import datetime +import warnings +from pathlib import Path from typing import Collection import pandas as pd +from pyprojroot import here import participants.query_db from features import esm from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME +from machine_learning.helper import to_csv_with_settings + +WARNING_PARTICIPANTS_LABEL = ( + "Before aggregating labels, please set participants label using self.set_participants_label() " + "to be used as a filename prefix when exporting data. " + "The filename will be of the form: %participants_label_%grouping_variable_%data_type.csv" +) class Labels: @@ -14,12 +24,13 @@ class Labels: grouping_variable: str, labels: dict, participants_usernames: Collection = None, - ): + ) -> None: self.grouping_variable_name = grouping_variable self.grouping_variable = [grouping_variable] self.questionnaires = labels.keys() + self.participants_label: str = "" if participants_usernames is None: participants_usernames = participants.query_db.get_usernames( collection_start=datetime.date.fromisoformat("2020-08-01") @@ -32,9 +43,13 @@ class Labels: self.df_esm_clean = pd.DataFrame() self.df_esm_means = pd.DataFrame() + + self.folder: Path = Path() + self.filename_prefix = "" + self.construct_export_path() print("Labels initialized.") - def set_labels(self): + def set_labels(self) -> None: print("Querying database ...") self.df_esm = esm.get_esm_data(self.participants_usernames) print("Got ESM data from the DB.") @@ -54,13 +69,13 @@ class Labels: self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest) print("ESM data cleaned.") - def get_labels(self, questionnaire): + def get_labels(self, questionnaire: str) -> pd.DataFrame: if questionnaire == "PANAS": return self.df_esm_clean else: raise KeyError("This questionnaire has not been implemented as a label.") - def aggregate_labels(self): + def aggregate_labels(self) -> None: print("Aggregating labels ...") self.df_esm_means = ( self.df_esm_clean.groupby( @@ -81,6 +96,24 @@ class Labels: .set_index(["participant_id"] + self.grouping_variable) ) print("Labels aggregated.") + to_csv_with_settings( + self.df_esm_means, + self.folder, + self.filename_prefix, + data_type="_".join(self.questionnaires), + ) - def get_aggregated_labels(self): + def get_aggregated_labels(self) -> pd.DataFrame: return self.df_esm_means + + def construct_export_path(self) -> None: + if not self.participants_label: + warnings.warn(WARNING_PARTICIPANTS_LABEL, UserWarning) + self.folder = here("machine_learning/intermediate_results/labels", warn=True) + self.filename_prefix = ( + self.participants_label + "_" + self.grouping_variable_name + ) + + def set_participants_label(self, label: str) -> None: + self.participants_label = label + self.construct_export_path()