Add an example for linear regression.
parent
c8bb481508
commit
577a874288
|
@ -16,6 +16,7 @@ dependencies:
|
||||||
- python-dotenv
|
- python-dotenv
|
||||||
- pytz
|
- pytz
|
||||||
- seaborn
|
- seaborn
|
||||||
|
- scikit-learn
|
||||||
- sqlalchemy
|
- sqlalchemy
|
||||||
- statsmodels
|
- statsmodels
|
||||||
- tabulate
|
- tabulate
|
|
@ -20,6 +20,8 @@ import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
|
from sklearn import linear_model
|
||||||
|
from sklearn.model_selection import LeaveOneGroupOut, cross_val_score
|
||||||
|
|
||||||
nb_dir = os.path.split(os.getcwd())[0]
|
nb_dir = os.path.split(os.getcwd())[0]
|
||||||
if nb_dir not in sys.path:
|
if nb_dir not in sys.path:
|
||||||
|
@ -75,6 +77,19 @@ df_esm_PANAS_daily_means = (
|
||||||
.rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"})
|
.rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
df_esm_PANAS_daily_means = (
|
||||||
|
df_esm_PANAS_daily_means.pivot(
|
||||||
|
index=["participant_id", "date_lj"],
|
||||||
|
columns="questionnaire_id",
|
||||||
|
values="esm_numeric_mean",
|
||||||
|
)
|
||||||
|
.reset_index(col_level=1)
|
||||||
|
.rename(columns={8.0: "PA", 9.0: "NA"})
|
||||||
|
.set_index(["participant_id", "date_lj"])
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
df_proximity_daily_counts = proximity.count_proximity(
|
df_proximity_daily_counts = proximity.count_proximity(
|
||||||
df_proximity, ["participant_id", "date_lj"]
|
df_proximity, ["participant_id", "date_lj"]
|
||||||
|
@ -86,7 +101,50 @@ df_proximity_daily_counts
|
||||||
# %% [markdown]
|
# %% [markdown]
|
||||||
# # 3. Join features (and export to csv?)
|
# # 3. Join features (and export to csv?)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
df_full_data_daily_means = df_esm_PANAS_daily_means.join(
|
||||||
|
df_proximity_daily_counts
|
||||||
|
).reset_index()
|
||||||
|
|
||||||
# %% [markdown]
|
# %% [markdown]
|
||||||
# # 4. Machine learning model and parameters
|
# # 4. Machine learning model and parameters
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
lin_reg_proximity = linear_model.LinearRegression()
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ## 4.1 Validation method
|
||||||
|
|
||||||
|
# %%
|
||||||
|
logo = LeaveOneGroupOut()
|
||||||
|
logo.get_n_splits(
|
||||||
|
df_full_data_daily_means[["freq_prox_near", "prop_prox_near"]],
|
||||||
|
df_full_data_daily_means["PA"],
|
||||||
|
groups=df_full_data_daily_means["participant_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ## 4.2 Fit results (export?)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
cross_val_score(
|
||||||
|
lin_reg_proximity,
|
||||||
|
df_full_data_daily_means[["freq_prox_near", "prop_prox_near"]],
|
||||||
|
df_full_data_daily_means["PA"],
|
||||||
|
groups=df_full_data_daily_means["participant_id"],
|
||||||
|
cv=logo,
|
||||||
|
n_jobs=-1,
|
||||||
|
scoring="r2",
|
||||||
|
)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
lin_reg_proximity.fit(
|
||||||
|
df_full_data_daily_means[["freq_prox_near", "prop_prox_near"]],
|
||||||
|
df_full_data_daily_means["PA"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
lin_reg_proximity.score(
|
||||||
|
df_full_data_daily_means[["freq_prox_near", "prop_prox_near"]],
|
||||||
|
df_full_data_daily_means["PA"],
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue