Move to presentation.

ml_pipeline
junos 2022-12-07 21:24:20 +01:00
parent 525496418f
commit 95ab66fd81
20 changed files with 133 additions and 479 deletions

1
.gitignore vendored
View File

@ -5,6 +5,7 @@ __pycache__/
/exploration/*.ipynb
/config/*.ipynb
/statistical_analysis/*.ipynb
/presentation/*.ipynb
/machine_learning/intermediate_results/
/data/features/
/data/baseline/

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,131 @@
# ---
# jupyter:
# jupytext:
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.13.0
# kernelspec:
# display_name: straw2analysis
# language: python
# name: straw2analysis
# ---
# %%
# %matplotlib inline
import yaml
from sklearn import linear_model
from sklearn.model_selection import LeaveOneGroupOut, cross_val_score
import os
import importlib
import matplotlib.pyplot as plt
import sys
import numpy as np
import seaborn as sns
import pandas as pd
nb_dir = os.path.split(os.getcwd())[0]
if nb_dir not in sys.path:
sys.path.append(nb_dir)
# %%
from machine_learning import pipeline, features_sensor, labels, model
# %%
importlib.reload(labels)
# %%
with open("./config/prox_comm_PANAS_features.yaml", "r") as file:
sensor_features_params = yaml.safe_load(file)
sensor_features = features_sensor.SensorFeatures(**sensor_features_params)
#sensor_features.set_sensor_data()
sensor_features.calculate_features(cached=True)
# %%
all_features = sensor_features.get_features("all","all")
# %%
with open("./config/prox_comm_PANAS_labels.yaml", "r") as file:
labels_params = yaml.safe_load(file)
labels_current = labels.Labels(**labels_params)
#labels_current.set_labels()
labels_current.aggregate_labels(cached=True)
# %%
model_validation = model.ModelValidation(
sensor_features.get_features("all", "all"),
labels_current.get_aggregated_labels(),
group_variable="participant_id",
cv_name="loso",
)
model_validation.model = linear_model.LinearRegression()
model_validation.set_cv_method()
# %%
model_loso_r2 = model_validation.cross_validate()
# %%
print(model_loso_r2)
print(np.mean(model_loso_r2))
# %%
model_loso_r2[model_loso_r2 > 0]
# %%
logo = LeaveOneGroupOut()
# %%
try_X = model_validation.X.reset_index().drop(["participant_id","date_lj"], axis=1)
try_y = model_validation.y.reset_index().drop(["participant_id","date_lj"], axis=1)
# %%
model_loso_mean_absolute_error = -1 * cross_val_score(
estimator=model_validation.model,
X=try_X,
y=try_y,
groups=model_validation.groups,
cv=logo.split(X=try_X, y=try_y, groups=model_validation.groups),
scoring='neg_mean_absolute_error'
)
# %%
model_loso_mean_absolute_error
# %%
np.median(model_loso_mean_absolute_error)
# %%
model_validation.model.fit(try_X, try_y)
# %%
Y_predicted = model_validation.model.predict(try_X)
# %%
try_y.rename(columns={"NA": "NA_true"}, inplace=True)
try_y["NA_predicted"] = Y_predicted
NA_long = pd.wide_to_long(
try_y.reset_index(),
i="index",
j="value",
stubnames="NA",
sep="_",
suffix=".+",
)
# %%
g1 = sns.displot(NA_long, x="NA", hue="value", binwidth=0.1, height=5, aspect=1.5)
sns.move_legend(g1, "upper left", bbox_to_anchor=(.55, .45))
g1.set_axis_labels("Daily mean", "Day count")
display(g1)
g1.savefig("prox_comm_PANAS_predictions.pdf")
# %%
from sklearn.metrics import mean_absolute_error
# %%
mean_absolute_error(try_y["NA_true"], try_y["NA_predicted"])
# %%
model_loso_mean_absolute_error

2
rapids

@ -1 +1 @@
Subproject commit f78aa3e7b3567423b44045766b230cd60d557cb0
Subproject commit 8a6b52a97c95dcd8b70b980b4f46421b1a847905