[WIP] Add a class for model validation.
parent
0b98d59aad
commit
065cd4347e
|
@ -171,6 +171,7 @@ sensor_features.data_types
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
sensor_features.data_types = ["proximity", "communication"]
|
sensor_features.data_types = ["proximity", "communication"]
|
||||||
|
sensor_features.participants_usernames = ptcp_2
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
sensor_features.get_sensor_data("proximity")
|
sensor_features.get_sensor_data("proximity")
|
||||||
|
@ -199,6 +200,7 @@ with open("../machine_learning/config/minimal_labels.yaml", "r") as file:
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
labels = pipeline.Labels(**labels_params)
|
labels = pipeline.Labels(**labels_params)
|
||||||
|
labels.participants_usernames = ptcp_2
|
||||||
labels.questionnaires
|
labels.questionnaires
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
@ -214,3 +216,19 @@ labels.aggregate_labels()
|
||||||
labels.get_aggregated_labels()
|
labels.get_aggregated_labels()
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
model_validation = pipeline.ModelValidation(
|
||||||
|
sensor_features.get_features("all", "all"),
|
||||||
|
labels.get_aggregated_labels(),
|
||||||
|
group_variable="participant_id",
|
||||||
|
cv_name="loso",
|
||||||
|
)
|
||||||
|
model_validation.model = linear_model.LinearRegression()
|
||||||
|
model_validation.set_cv_method()
|
||||||
|
|
||||||
|
# %%
|
||||||
|
model_validation.cross_validate()
|
||||||
|
|
||||||
|
# %%
|
||||||
|
model_validation.groups
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
|
@ -2,7 +2,7 @@ import datetime
|
||||||
from collections.abc import Collection
|
from collections.abc import Collection
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from sklearn.model_selection import cross_val_score
|
from sklearn.model_selection import LeaveOneGroupOut, cross_val_score
|
||||||
|
|
||||||
import participants.query_db
|
import participants.query_db
|
||||||
from features import communication, esm, helper, proximity
|
from features import communication, esm, helper, proximity
|
||||||
|
@ -139,7 +139,9 @@ class Labels:
|
||||||
|
|
||||||
def aggregate_labels(self):
|
def aggregate_labels(self):
|
||||||
self.df_esm_means = (
|
self.df_esm_means = (
|
||||||
self.df_esm_clean.groupby(["participant_id", "questionnaire_id"] + self.grouping_variable)
|
self.df_esm_clean.groupby(
|
||||||
|
["participant_id", "questionnaire_id"] + self.grouping_variable
|
||||||
|
)
|
||||||
.esm_user_answer_numeric.agg("mean")
|
.esm_user_answer_numeric.agg("mean")
|
||||||
.reset_index()
|
.reset_index()
|
||||||
.rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"})
|
.rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"})
|
||||||
|
@ -159,6 +161,42 @@ class Labels:
|
||||||
return self.df_esm_means
|
return self.df_esm_means
|
||||||
|
|
||||||
|
|
||||||
|
class ModelValidation:
|
||||||
|
def __init__(self, X, y, group_variable=None, cv_name="loso"):
|
||||||
|
self.model = None
|
||||||
|
self.cv = None
|
||||||
|
|
||||||
|
self.y = y["NA"]
|
||||||
|
# TODO Handle the case of multiple labels.
|
||||||
|
self.X = X.loc[self.y.index]
|
||||||
|
self.groups = self.y.index.get_level_values(group_variable)
|
||||||
|
|
||||||
|
self.cv_name = cv_name
|
||||||
|
|
||||||
|
def set_cv_method(self):
|
||||||
|
if self.cv_name == "loso":
|
||||||
|
self.cv = LeaveOneGroupOut()
|
||||||
|
self.cv.get_n_splits(X=self.X, y=self.y, groups=self.groups)
|
||||||
|
|
||||||
|
def cross_validate(self):
|
||||||
|
if self.model is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Please set self.model first, e.g. self.model = sklearn.linear_model.LinearRegression()"
|
||||||
|
)
|
||||||
|
# TODO Is ValueError appropriate here?
|
||||||
|
if self.cv is None:
|
||||||
|
raise ValueError("Please use set_cv_method() first.")
|
||||||
|
return cross_val_score(
|
||||||
|
estimator=self.model,
|
||||||
|
X=self.X,
|
||||||
|
y=self.y,
|
||||||
|
groups=self.groups,
|
||||||
|
cv=self.cv,
|
||||||
|
n_jobs=-1,
|
||||||
|
scoring="r2",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def safe_outer_merge_on_index(left, right):
|
def safe_outer_merge_on_index(left, right):
|
||||||
if left.empty:
|
if left.empty:
|
||||||
return right
|
return right
|
||||||
|
|
Loading…
Reference in New Issue