From 0b98d59aadc19f1cd168ea6bed36c5307a9a3c46 Mon Sep 17 00:00:00 2001 From: junos Date: Fri, 20 Aug 2021 19:17:22 +0200 Subject: [PATCH] Aggregate labels using grouping_variable. --- exploration/ex_ml_pipeline.py | 10 +++++++++ machine_learning/config/minimal_labels.yaml | 2 +- machine_learning/pipeline.py | 25 ++++++++++++++++++++- 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/exploration/ex_ml_pipeline.py b/exploration/ex_ml_pipeline.py index 9471ebc..596d67e 100644 --- a/exploration/ex_ml_pipeline.py +++ b/exploration/ex_ml_pipeline.py @@ -16,6 +16,7 @@ # %% # %matplotlib inline import datetime +import importlib import os import sys @@ -156,6 +157,9 @@ lin_reg_proximity.score( # %% from machine_learning import pipeline +# %% +importlib.reload(pipeline) + # %% with open("../machine_learning/config/minimal_features.yaml", "r") as file: sensor_features_params = yaml.safe_load(file) @@ -204,3 +208,9 @@ labels.set_labels() labels.get_labels("PANAS") # %% +labels.aggregate_labels() + +# %% +labels.get_aggregated_labels() + +# %% diff --git a/machine_learning/config/minimal_labels.yaml b/machine_learning/config/minimal_labels.yaml index 9e719ed..25d3e8f 100644 --- a/machine_learning/config/minimal_labels.yaml +++ b/machine_learning/config/minimal_labels.yaml @@ -1,4 +1,4 @@ -grouping_variable: date_lj +grouping_variable: [date_lj] labels: PANAS: - PA diff --git a/machine_learning/pipeline.py b/machine_learning/pipeline.py index 2370fde..00fc00b 100644 --- a/machine_learning/pipeline.py +++ b/machine_learning/pipeline.py @@ -94,7 +94,7 @@ class SensorFeatures: class Labels: def __init__( self, - grouping_variable: str, + grouping_variable: list, labels: dict, participants_usernames: Collection = None, ): @@ -113,6 +113,8 @@ class Labels: self.df_esm_interest = pd.DataFrame() self.df_esm_clean = pd.DataFrame() + self.df_esm_means = pd.DataFrame() + def set_labels(self): self.df_esm = esm.get_esm_data(self.participants_usernames) self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm) @@ -135,6 +137,27 @@ class Labels: else: raise KeyError("This questionnaire has not been implemented as a label.") + def aggregate_labels(self): + self.df_esm_means = ( + self.df_esm_clean.groupby(["participant_id", "questionnaire_id"] + self.grouping_variable) + .esm_user_answer_numeric.agg("mean") + .reset_index() + .rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"}) + ) + self.df_esm_means = ( + self.df_esm_means.pivot( + index=["participant_id"] + self.grouping_variable, + columns="questionnaire_id", + values="esm_numeric_mean", + ) + .reset_index(col_level=1) + .rename(columns=QUESTIONNAIRE_IDS_RENAME) + .set_index(["participant_id"] + self.grouping_variable) + ) + + def get_aggregated_labels(self): + return self.df_esm_means + def safe_outer_merge_on_index(left, right): if left.empty: