Browse Source

Add an example for linear regression.

communication
junos 11 months ago
parent
commit
577a874288
  1. 1
      config/environment.yml
  2. 58
      exploration/ex_ml_pipeline.py

1
config/environment.yml

@ -16,6 +16,7 @@ dependencies:
- python-dotenv
- pytz
- seaborn
- scikit-learn
- sqlalchemy
- statsmodels
- tabulate

58
exploration/ex_ml_pipeline.py

@ -20,6 +20,8 @@ import os
import sys
import seaborn as sns
from sklearn import linear_model
from sklearn.model_selection import LeaveOneGroupOut, cross_val_score
nb_dir = os.path.split(os.getcwd())[0]
if nb_dir not in sys.path:
@ -75,6 +77,19 @@ df_esm_PANAS_daily_means = (
.rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"})
)
# %%
df_esm_PANAS_daily_means = (
df_esm_PANAS_daily_means.pivot(
index=["participant_id", "date_lj"],
columns="questionnaire_id",
values="esm_numeric_mean",
)
.reset_index(col_level=1)
.rename(columns={8.0: "PA", 9.0: "NA"})
.set_index(["participant_id", "date_lj"])
)
# %%
df_proximity_daily_counts = proximity.count_proximity(
df_proximity, ["participant_id", "date_lj"]
@ -86,7 +101,50 @@ df_proximity_daily_counts
# %% [markdown]
# # 3. Join features (and export to csv?)
# %%
df_full_data_daily_means = df_esm_PANAS_daily_means.join(
df_proximity_daily_counts
).reset_index()
# %% [markdown]
# # 4. Machine learning model and parameters
# %%
lin_reg_proximity = linear_model.LinearRegression()
# %% [markdown]
# ## 4.1 Validation method
# %%
logo = LeaveOneGroupOut()
logo.get_n_splits(
df_full_data_daily_means[["freq_prox_near", "prop_prox_near"]],
df_full_data_daily_means["PA"],
groups=df_full_data_daily_means["participant_id"],
)
# %% [markdown]
# ## 4.2 Fit results (export?)
# %%
cross_val_score(
lin_reg_proximity,
df_full_data_daily_means[["freq_prox_near", "prop_prox_near"]],
df_full_data_daily_means["PA"],
groups=df_full_data_daily_means["participant_id"],
cv=logo,
n_jobs=-1,
scoring="r2",
)
# %%
lin_reg_proximity.fit(
df_full_data_daily_means[["freq_prox_near", "prop_prox_near"]],
df_full_data_daily_means["PA"],
)
# %%
lin_reg_proximity.score(
df_full_data_daily_means[["freq_prox_near", "prop_prox_near"]],
df_full_data_daily_means["PA"],
)
Loading…
Cancel
Save