Move to presentation.
parent
525496418f
commit
95ab66fd81
|
@ -5,6 +5,7 @@ __pycache__/
|
|||
/exploration/*.ipynb
|
||||
/config/*.ipynb
|
||||
/statistical_analysis/*.ipynb
|
||||
/presentation/*.ipynb
|
||||
/machine_learning/intermediate_results/
|
||||
/data/features/
|
||||
/data/baseline/
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,131 @@
|
|||
# ---
|
||||
# jupyter:
|
||||
# jupytext:
|
||||
# text_representation:
|
||||
# extension: .py
|
||||
# format_name: percent
|
||||
# format_version: '1.3'
|
||||
# jupytext_version: 1.13.0
|
||||
# kernelspec:
|
||||
# display_name: straw2analysis
|
||||
# language: python
|
||||
# name: straw2analysis
|
||||
# ---
|
||||
|
||||
# %%
|
||||
# %matplotlib inline
|
||||
import yaml
|
||||
from sklearn import linear_model
|
||||
from sklearn.model_selection import LeaveOneGroupOut, cross_val_score
|
||||
import os
|
||||
import importlib
|
||||
import matplotlib.pyplot as plt
|
||||
import sys
|
||||
import numpy as np
|
||||
import seaborn as sns
|
||||
import pandas as pd
|
||||
|
||||
nb_dir = os.path.split(os.getcwd())[0]
|
||||
if nb_dir not in sys.path:
|
||||
sys.path.append(nb_dir)
|
||||
|
||||
# %%
|
||||
from machine_learning import pipeline, features_sensor, labels, model
|
||||
|
||||
# %%
|
||||
importlib.reload(labels)
|
||||
|
||||
# %%
|
||||
with open("./config/prox_comm_PANAS_features.yaml", "r") as file:
|
||||
sensor_features_params = yaml.safe_load(file)
|
||||
sensor_features = features_sensor.SensorFeatures(**sensor_features_params)
|
||||
#sensor_features.set_sensor_data()
|
||||
sensor_features.calculate_features(cached=True)
|
||||
|
||||
# %%
|
||||
all_features = sensor_features.get_features("all","all")
|
||||
|
||||
# %%
|
||||
with open("./config/prox_comm_PANAS_labels.yaml", "r") as file:
|
||||
labels_params = yaml.safe_load(file)
|
||||
labels_current = labels.Labels(**labels_params)
|
||||
#labels_current.set_labels()
|
||||
labels_current.aggregate_labels(cached=True)
|
||||
|
||||
# %%
|
||||
model_validation = model.ModelValidation(
|
||||
sensor_features.get_features("all", "all"),
|
||||
labels_current.get_aggregated_labels(),
|
||||
group_variable="participant_id",
|
||||
cv_name="loso",
|
||||
)
|
||||
model_validation.model = linear_model.LinearRegression()
|
||||
model_validation.set_cv_method()
|
||||
|
||||
# %%
|
||||
model_loso_r2 = model_validation.cross_validate()
|
||||
|
||||
# %%
|
||||
print(model_loso_r2)
|
||||
print(np.mean(model_loso_r2))
|
||||
|
||||
# %%
|
||||
model_loso_r2[model_loso_r2 > 0]
|
||||
|
||||
# %%
|
||||
logo = LeaveOneGroupOut()
|
||||
|
||||
# %%
|
||||
try_X = model_validation.X.reset_index().drop(["participant_id","date_lj"], axis=1)
|
||||
try_y = model_validation.y.reset_index().drop(["participant_id","date_lj"], axis=1)
|
||||
|
||||
# %%
|
||||
model_loso_mean_absolute_error = -1 * cross_val_score(
|
||||
estimator=model_validation.model,
|
||||
X=try_X,
|
||||
y=try_y,
|
||||
groups=model_validation.groups,
|
||||
cv=logo.split(X=try_X, y=try_y, groups=model_validation.groups),
|
||||
scoring='neg_mean_absolute_error'
|
||||
)
|
||||
|
||||
# %%
|
||||
model_loso_mean_absolute_error
|
||||
|
||||
# %%
|
||||
np.median(model_loso_mean_absolute_error)
|
||||
|
||||
# %%
|
||||
model_validation.model.fit(try_X, try_y)
|
||||
|
||||
# %%
|
||||
Y_predicted = model_validation.model.predict(try_X)
|
||||
|
||||
# %%
|
||||
try_y.rename(columns={"NA": "NA_true"}, inplace=True)
|
||||
try_y["NA_predicted"] = Y_predicted
|
||||
NA_long = pd.wide_to_long(
|
||||
try_y.reset_index(),
|
||||
i="index",
|
||||
j="value",
|
||||
stubnames="NA",
|
||||
sep="_",
|
||||
suffix=".+",
|
||||
)
|
||||
|
||||
# %%
|
||||
g1 = sns.displot(NA_long, x="NA", hue="value", binwidth=0.1, height=5, aspect=1.5)
|
||||
sns.move_legend(g1, "upper left", bbox_to_anchor=(.55, .45))
|
||||
g1.set_axis_labels("Daily mean", "Day count")
|
||||
|
||||
display(g1)
|
||||
g1.savefig("prox_comm_PANAS_predictions.pdf")
|
||||
|
||||
# %%
|
||||
from sklearn.metrics import mean_absolute_error
|
||||
|
||||
# %%
|
||||
mean_absolute_error(try_y["NA_true"], try_y["NA_predicted"])
|
||||
|
||||
# %%
|
||||
model_loso_mean_absolute_error
|
2
rapids
2
rapids
|
@ -1 +1 @@
|
|||
Subproject commit f78aa3e7b3567423b44045766b230cd60d557cb0
|
||||
Subproject commit 8a6b52a97c95dcd8b70b980b4f46421b1a847905
|
Loading…
Reference in New Issue