From 6592612db73a698005dbe10b8a1acf842a771755 Mon Sep 17 00:00:00 2001 From: junos Date: Thu, 19 Aug 2021 17:44:04 +0200 Subject: [PATCH] Add a similar class for labels. --- exploration/ex_ml_pipeline.py | 14 ++++ machine_learning/config/minimal_labels.yaml | 6 ++ machine_learning/pipeline.py | 75 ++++++++++++++++----- 3 files changed, 80 insertions(+), 15 deletions(-) create mode 100644 machine_learning/config/minimal_labels.yaml diff --git a/exploration/ex_ml_pipeline.py b/exploration/ex_ml_pipeline.py index c91b8f7..d37b335 100644 --- a/exploration/ex_ml_pipeline.py +++ b/exploration/ex_ml_pipeline.py @@ -180,3 +180,17 @@ sensor_features.calculate_features() sensor_features.get_features("proximity", "all") # %% +with open("../machine_learning/config/minimal_labels.yaml", "r") as file: + labels_params = yaml.safe_load(file) + +# %% +labels = pipeline.Labels(**labels_params) +labels.questionnaires + +# %% +labels.set_labels() + +# %% +labels.get_labels("PANAS") + +# %% diff --git a/machine_learning/config/minimal_labels.yaml b/machine_learning/config/minimal_labels.yaml new file mode 100644 index 0000000..9e719ed --- /dev/null +++ b/machine_learning/config/minimal_labels.yaml @@ -0,0 +1,6 @@ +grouping_variable: date_lj +labels: + PANAS: + - PA + - NA +participants_usernames: [nokia_0000003] diff --git a/machine_learning/pipeline.py b/machine_learning/pipeline.py index f08d2cf..d5ead2f 100644 --- a/machine_learning/pipeline.py +++ b/machine_learning/pipeline.py @@ -79,6 +79,51 @@ class SensorFeatures: raise KeyError("This data type has not been implemented.") +class Labels: + def __init__( + self, + grouping_variable: str, + labels: dict, + participants_usernames: Collection = None, + ): + self.grouping_variable = grouping_variable + + self.questionnaires = labels.keys() + + if participants_usernames is None: + participants_usernames = participants.query_db.get_usernames( + collection_start=datetime.date.fromisoformat("2020-08-01") + ) + self.participants_usernames = participants_usernames + + self.df_esm = pd.DataFrame() + self.df_esm_preprocessed = pd.DataFrame() + self.df_esm_interest = pd.DataFrame() + self.df_esm_clean = pd.DataFrame() + + def set_labels(self): + self.df_esm = esm.get_esm_data(self.participants_usernames) + self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm) + if "PANAS" in self.questionnaires: + self.df_esm_interest = self.df_esm_preprocessed[ + ( + self.df_esm_preprocessed["questionnaire_id"] + == QUESTIONNAIRE_IDS.get("PANAS").get("PA") + ) + | ( + self.df_esm_preprocessed["questionnaire_id"] + == QUESTIONNAIRE_IDS.get("PANAS").get("NA") + ) + ] + self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest) + + def get_labels(self, questionnaire): + if questionnaire == "PANAS": + return self.df_esm_clean + else: + raise KeyError("This questionnaire has not been implemented as a label.") + + class MachineLearningPipeline: def __init__( self, @@ -117,21 +162,21 @@ class MachineLearningPipeline: self.df_esm_daily_means = pd.DataFrame() self.df_proximity_daily_counts = pd.DataFrame() - def get_labels(self): - self.df_esm = esm.get_esm_data(self.participants_usernames) - self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm) - if self.labels_questionnaire == "PANAS": - self.df_esm_interest = self.df_esm_preprocessed[ - ( - self.df_esm_preprocessed["questionnaire_id"] - == QUESTIONNAIRE_IDS.get("PANAS").get("PA") - ) - | ( - self.df_esm_preprocessed["questionnaire_id"] - == QUESTIONNAIRE_IDS.get("PANAS").get("NA") - ) - ] - self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest) + # def get_labels(self): + # self.df_esm = esm.get_esm_data(self.participants_usernames) + # self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm) + # if self.labels_questionnaire == "PANAS": + # self.df_esm_interest = self.df_esm_preprocessed[ + # ( + # self.df_esm_preprocessed["questionnaire_id"] + # == QUESTIONNAIRE_IDS.get("PANAS").get("PA") + # ) + # | ( + # self.df_esm_preprocessed["questionnaire_id"] + # == QUESTIONNAIRE_IDS.get("PANAS").get("NA") + # ) + # ] + # self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest) # def aggregate_daily(self): # self.df_esm_daily_means = (