From 577a874288e84f5b57c4ae79b89712c67b1285ef Mon Sep 17 00:00:00 2001
From: junos
Date: Thu, 12 Aug 2021 16:54:00 +0200
Subject: [PATCH] Add an example for linear regression.

---
 config/environment.yml        |  1 +
 exploration/ex_ml_pipeline.py | 58 +++++++++++++++++++++++++++++++++++
 2 files changed, 59 insertions(+)

diff --git a/config/environment.yml b/config/environment.yml
index f5a8128..e1ccedf 100644
--- a/config/environment.yml
+++ b/config/environment.yml
@@ -16,6 +16,7 @@ dependencies:
   - python-dotenv
   - pytz
   - seaborn
+  - scikit-learn
   - sqlalchemy
   - statsmodels
   - tabulate
\ No newline at end of file
diff --git a/exploration/ex_ml_pipeline.py b/exploration/ex_ml_pipeline.py
index cd3c293..4e7e542 100644
--- a/exploration/ex_ml_pipeline.py
+++ b/exploration/ex_ml_pipeline.py
@@ -20,6 +20,8 @@ import os
 import sys
 
 import seaborn as sns
+from sklearn import linear_model
+from sklearn.model_selection import LeaveOneGroupOut, cross_val_score
 
 nb_dir = os.path.split(os.getcwd())[0]
 if nb_dir not in sys.path:
@@ -75,6 +77,19 @@ df_esm_PANAS_daily_means = (
     .rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"})
 )
 
+# %%
+df_esm_PANAS_daily_means = (
+    df_esm_PANAS_daily_means.pivot(
+        index=["participant_id", "date_lj"],
+        columns="questionnaire_id",
+        values="esm_numeric_mean",
+    )
+    .reset_index(col_level=1)
+    .rename(columns={8.0: "PA", 9.0: "NA"})
+    .set_index(["participant_id", "date_lj"])
+)
+
+
 # %%
 df_proximity_daily_counts = proximity.count_proximity(
     df_proximity, ["participant_id", "date_lj"]
@@ -86,7 +101,50 @@ df_proximity_daily_counts
 # %% [markdown]
 # # 3. Join features (and export to csv?)
 
+# %%
+df_full_data_daily_means = df_esm_PANAS_daily_means.join(
+    df_proximity_daily_counts
+).reset_index()
+
 # %% [markdown]
 # # 4. Machine learning model and parameters
 
 # %%
+lin_reg_proximity = linear_model.LinearRegression()
+
+# %% [markdown]
+# ## 4.1 Validation method
+
+# %%
+logo = LeaveOneGroupOut()
+logo.get_n_splits(
+    df_full_data_daily_means[["freq_prox_near", "prop_prox_near"]],
+    df_full_data_daily_means["PA"],
+    groups=df_full_data_daily_means["participant_id"],
+)
+
+# %% [markdown]
+# ## 4.2 Fit results (export?)
+
+# %%
+cross_val_score(
+    lin_reg_proximity,
+    df_full_data_daily_means[["freq_prox_near", "prop_prox_near"]],
+    df_full_data_daily_means["PA"],
+    groups=df_full_data_daily_means["participant_id"],
+    cv=logo,
+    n_jobs=-1,
+    scoring="r2",
+)
+
+# %%
+lin_reg_proximity.fit(
+    df_full_data_daily_means[["freq_prox_near", "prop_prox_near"]],
+    df_full_data_daily_means["PA"],
+)
+
+# %%
+lin_reg_proximity.score(
+    df_full_data_daily_means[["freq_prox_near", "prop_prox_near"]],
+    df_full_data_daily_means["PA"],
+)