Refactor machine_learning/pipeline.py by defining one class by file.
parent
c1bb4ddf0f
commit
b19eebbb92
|
@ -6,7 +6,7 @@
|
|||
# extension: .py
|
||||
# format_name: percent
|
||||
# format_version: '1.3'
|
||||
# jupytext_version: 1.11.4
|
||||
# jupytext_version: 1.12.0
|
||||
# kernelspec:
|
||||
# display_name: straw2analysis
|
||||
# language: python
|
||||
|
@ -32,6 +32,9 @@ if nb_dir not in sys.path:
|
|||
# %%
|
||||
import participants.query_db
|
||||
from features import esm, helper, proximity
|
||||
import machine_learning.features_sensor
|
||||
import machine_learning.labels
|
||||
import machine_learning.model
|
||||
|
||||
# %% [markdown]
|
||||
# # 1. Get the relevant data
|
||||
|
@ -166,7 +169,7 @@ with open("../machine_learning/config/minimal_features.yaml", "r") as file:
|
|||
print(sensor_features_params)
|
||||
|
||||
# %%
|
||||
sensor_features = pipeline.SensorFeatures(**sensor_features_params)
|
||||
sensor_features = machine_learning.features_sensor.SensorFeatures(**sensor_features_params)
|
||||
sensor_features.data_types
|
||||
|
||||
# %%
|
||||
|
@ -188,12 +191,6 @@ sensor_features.get_sensor_data("proximity")
|
|||
# %%
|
||||
sensor_features.calculate_features()
|
||||
|
||||
# %%
|
||||
sensor_features.get_features("proximity", "all")
|
||||
|
||||
# %%
|
||||
sensor_features.get_features("communication", "all")
|
||||
|
||||
# %%
|
||||
sensor_features.get_features("all", "all")
|
||||
|
||||
|
@ -202,10 +199,16 @@ with open("../machine_learning/config/minimal_labels.yaml", "r") as file:
|
|||
labels_params = yaml.safe_load(file)
|
||||
|
||||
# %%
|
||||
labels = pipeline.Labels(**labels_params)
|
||||
labels = machine_learning.labels.Labels(**labels_params)
|
||||
labels.participants_usernames = ptcp_2
|
||||
labels.questionnaires
|
||||
|
||||
# %%
|
||||
all_features = sensor_features.get_features("all", "all")
|
||||
|
||||
# %%
|
||||
all_features.isna().any().any()
|
||||
|
||||
# %%
|
||||
labels.set_labels()
|
||||
|
||||
|
@ -219,7 +222,7 @@ labels.aggregate_labels()
|
|||
labels.get_aggregated_labels()
|
||||
|
||||
# %%
|
||||
model_validation = pipeline.ModelValidation(
|
||||
model_validation = machine_learning.model.ModelValidation(
|
||||
sensor_features.get_features("all", "all"),
|
||||
labels.get_aggregated_labels(),
|
||||
group_variable="participant_id",
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
grouping_variable: [date_lj]
|
||||
grouping_variable: date_lj
|
||||
labels:
|
||||
PANAS:
|
||||
- PA
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
grouping_variable: [date_lj]
|
||||
grouping_variable: date_lj
|
||||
features:
|
||||
proximity:
|
||||
all
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
grouping_variable: [date_lj]
|
||||
grouping_variable: date_lj
|
||||
labels:
|
||||
PANAS:
|
||||
- PA
|
||||
|
|
|
@ -0,0 +1,173 @@
|
|||
import datetime
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Collection
|
||||
|
||||
import pandas as pd
|
||||
from pyprojroot import here
|
||||
|
||||
import participants.query_db
|
||||
from features import proximity, helper, communication
|
||||
|
||||
WARNING_PARTICIPANTS_LABEL = (
|
||||
"Before calculating features, please set participants label using self.set_participants_label() "
|
||||
"to be used as a filename prefix when exporting data. "
|
||||
"The filename will be of the form: %participants_label_%grouping_variable_%data_type.csv"
|
||||
)
|
||||
|
||||
|
||||
class SensorFeatures:
|
||||
def __init__(
|
||||
self,
|
||||
grouping_variable: str,
|
||||
features: dict,
|
||||
participants_usernames: Collection = None,
|
||||
):
|
||||
|
||||
self.grouping_variable_name = grouping_variable
|
||||
self.grouping_variable = [grouping_variable]
|
||||
|
||||
self.data_types = features.keys()
|
||||
|
||||
self.participants_label: str = ""
|
||||
if participants_usernames is None:
|
||||
participants_usernames = participants.query_db.get_usernames(
|
||||
collection_start=datetime.date.fromisoformat("2020-08-01")
|
||||
)
|
||||
self.participants_label = "all"
|
||||
self.participants_usernames = participants_usernames
|
||||
|
||||
self.df_features_all = pd.DataFrame()
|
||||
|
||||
self.df_proximity = pd.DataFrame()
|
||||
self.df_proximity_counts = pd.DataFrame()
|
||||
|
||||
self.df_calls = pd.DataFrame()
|
||||
self.df_sms = pd.DataFrame()
|
||||
self.df_calls_sms = pd.DataFrame()
|
||||
|
||||
self.folder = None
|
||||
self.filename_prefix = ""
|
||||
self.construct_export_path()
|
||||
print("SensorFeatures initialized.")
|
||||
|
||||
def set_sensor_data(self):
|
||||
print("Querying database ...")
|
||||
if "proximity" in self.data_types:
|
||||
self.df_proximity = proximity.get_proximity_data(
|
||||
self.participants_usernames
|
||||
)
|
||||
print("Got proximity data from the DB.")
|
||||
self.df_proximity = helper.get_date_from_timestamp(self.df_proximity)
|
||||
self.df_proximity = proximity.recode_proximity(self.df_proximity)
|
||||
if "communication" in self.data_types:
|
||||
self.df_calls = communication.get_call_data(self.participants_usernames)
|
||||
self.df_calls = helper.get_date_from_timestamp(self.df_calls)
|
||||
print("Got calls data from the DB.")
|
||||
|
||||
self.df_sms = communication.get_sms_data(self.participants_usernames)
|
||||
self.df_sms = helper.get_date_from_timestamp(self.df_sms)
|
||||
print("Got sms data from the DB.")
|
||||
|
||||
def get_sensor_data(self, data_type) -> pd.DataFrame:
|
||||
if data_type == "proximity":
|
||||
return self.df_proximity
|
||||
elif data_type == "communication":
|
||||
return self.df_calls_sms
|
||||
else:
|
||||
raise KeyError("This data type has not been implemented.")
|
||||
|
||||
def calculate_features(self):
|
||||
print("Calculating features ...")
|
||||
if not self.participants_label:
|
||||
raise ValueError(WARNING_PARTICIPANTS_LABEL)
|
||||
if "proximity" in self.data_types:
|
||||
self.df_proximity_counts = proximity.count_proximity(
|
||||
self.df_proximity, self.grouping_variable
|
||||
)
|
||||
self.df_features_all = safe_outer_merge_on_index(
|
||||
self.df_features_all, self.df_proximity_counts
|
||||
)
|
||||
print("Calculated proximity features.")
|
||||
to_csv_with_settings(
|
||||
self.df_proximity, self.folder, self.filename_prefix, data_type="prox"
|
||||
)
|
||||
|
||||
if "communication" in self.data_types:
|
||||
self.df_calls_sms = communication.calls_sms_features(
|
||||
df_calls=self.df_calls,
|
||||
df_sms=self.df_sms,
|
||||
group_by=self.grouping_variable,
|
||||
)
|
||||
self.df_features_all = safe_outer_merge_on_index(
|
||||
self.df_features_all, self.df_calls_sms
|
||||
)
|
||||
print("Calculated communication features.")
|
||||
to_csv_with_settings(
|
||||
self.df_calls_sms, self.folder, self.filename_prefix, data_type="comm"
|
||||
)
|
||||
|
||||
self.df_features_all.fillna(
|
||||
value=proximity.FILL_NA_PROXIMITY, inplace=True, downcast="infer",
|
||||
)
|
||||
self.df_features_all.fillna(
|
||||
value=communication.FILL_NA_CALLS_SMS_ALL, inplace=True, downcast="infer",
|
||||
)
|
||||
|
||||
def get_features(self, data_type, feature_names) -> pd.DataFrame:
|
||||
if data_type == "proximity":
|
||||
if feature_names == "all":
|
||||
feature_names = proximity.FEATURES_PROXIMITY
|
||||
return self.df_proximity_counts[feature_names]
|
||||
elif data_type == "communication":
|
||||
if feature_names == "all":
|
||||
feature_names = communication.FEATURES_CALLS_SMS_ALL
|
||||
return self.df_calls_sms[feature_names]
|
||||
elif data_type == "all":
|
||||
return self.df_features_all
|
||||
else:
|
||||
raise KeyError("This data type has not been implemented.")
|
||||
|
||||
def construct_export_path(self):
|
||||
if not self.participants_label:
|
||||
warnings.warn(WARNING_PARTICIPANTS_LABEL, UserWarning)
|
||||
self.folder = here("machine_learning/intermediate_results/features", warn=True)
|
||||
self.filename_prefix = (
|
||||
self.participants_label + "_" + self.grouping_variable_name
|
||||
)
|
||||
|
||||
def set_participants_label(self, label: str):
|
||||
self.participants_label = label
|
||||
self.construct_export_path()
|
||||
|
||||
|
||||
def safe_outer_merge_on_index(left, right):
|
||||
if left.empty:
|
||||
return right
|
||||
elif right.empty:
|
||||
return left
|
||||
else:
|
||||
return pd.merge(
|
||||
left,
|
||||
right,
|
||||
how="outer",
|
||||
left_index=True,
|
||||
right_index=True,
|
||||
validate="one_to_one",
|
||||
)
|
||||
|
||||
|
||||
def to_csv_with_settings(
|
||||
df: pd.DataFrame, folder: Path, filename_prefix: str, data_type: str
|
||||
) -> None:
|
||||
export_filename = filename_prefix + "_" + data_type + ".csv"
|
||||
full_path = folder / export_filename
|
||||
df.to_csv(
|
||||
path_or_buf=full_path,
|
||||
sep=",",
|
||||
na_rep="NA",
|
||||
header=True,
|
||||
index=False,
|
||||
encoding="utf-8",
|
||||
)
|
||||
print("Exported the dataframe to " + str(full_path))
|
|
@ -0,0 +1,86 @@
|
|||
import datetime
|
||||
from typing import Collection
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import participants.query_db
|
||||
from features import esm
|
||||
from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME
|
||||
|
||||
|
||||
class Labels:
|
||||
def __init__(
|
||||
self,
|
||||
grouping_variable: str,
|
||||
labels: dict,
|
||||
participants_usernames: Collection = None,
|
||||
):
|
||||
self.grouping_variable_name = grouping_variable
|
||||
self.grouping_variable = [grouping_variable]
|
||||
|
||||
self.questionnaires = labels.keys()
|
||||
|
||||
if participants_usernames is None:
|
||||
participants_usernames = participants.query_db.get_usernames(
|
||||
collection_start=datetime.date.fromisoformat("2020-08-01")
|
||||
)
|
||||
self.participants_usernames = participants_usernames
|
||||
|
||||
self.df_esm = pd.DataFrame()
|
||||
self.df_esm_preprocessed = pd.DataFrame()
|
||||
self.df_esm_interest = pd.DataFrame()
|
||||
self.df_esm_clean = pd.DataFrame()
|
||||
|
||||
self.df_esm_means = pd.DataFrame()
|
||||
print("Labels initialized.")
|
||||
|
||||
def set_labels(self):
|
||||
print("Querying database ...")
|
||||
self.df_esm = esm.get_esm_data(self.participants_usernames)
|
||||
print("Got ESM data from the DB.")
|
||||
self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm)
|
||||
print("ESM data preprocessed.")
|
||||
if "PANAS" in self.questionnaires:
|
||||
self.df_esm_interest = self.df_esm_preprocessed[
|
||||
(
|
||||
self.df_esm_preprocessed["questionnaire_id"]
|
||||
== QUESTIONNAIRE_IDS.get("PANAS").get("PA")
|
||||
)
|
||||
| (
|
||||
self.df_esm_preprocessed["questionnaire_id"]
|
||||
== QUESTIONNAIRE_IDS.get("PANAS").get("NA")
|
||||
)
|
||||
]
|
||||
self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
|
||||
print("ESM data cleaned.")
|
||||
|
||||
def get_labels(self, questionnaire):
|
||||
if questionnaire == "PANAS":
|
||||
return self.df_esm_clean
|
||||
else:
|
||||
raise KeyError("This questionnaire has not been implemented as a label.")
|
||||
|
||||
def aggregate_labels(self):
|
||||
print("Aggregating labels ...")
|
||||
self.df_esm_means = (
|
||||
self.df_esm_clean.groupby(
|
||||
["participant_id", "questionnaire_id"] + self.grouping_variable
|
||||
)
|
||||
.esm_user_answer_numeric.agg("mean")
|
||||
.reset_index()
|
||||
.rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"})
|
||||
)
|
||||
self.df_esm_means = (
|
||||
self.df_esm_means.pivot(
|
||||
index=["participant_id"] + self.grouping_variable,
|
||||
columns="questionnaire_id",
|
||||
values="esm_numeric_mean",
|
||||
)
|
||||
.reset_index(col_level=1)
|
||||
.rename(columns=QUESTIONNAIRE_IDS_RENAME)
|
||||
.set_index(["participant_id"] + self.grouping_variable)
|
||||
)
|
||||
print("Labels aggregated.")
|
||||
|
||||
def get_aggregated_labels(self):
|
||||
return self.df_esm_means
|
|
@ -0,0 +1,47 @@
|
|||
from sklearn.model_selection import LeaveOneGroupOut, cross_val_score
|
||||
|
||||
|
||||
class ModelValidation:
|
||||
def __init__(self, X, y, group_variable=None, cv_name="loso"):
|
||||
self.model = None
|
||||
self.cv = None
|
||||
|
||||
idx_common = X.index.intersection(y.index)
|
||||
self.y = y.loc[idx_common, "NA"]
|
||||
# TODO Handle the case of multiple labels.
|
||||
self.X = X.loc[idx_common]
|
||||
self.groups = self.y.index.get_level_values(group_variable)
|
||||
|
||||
self.cv_name = cv_name
|
||||
print("ModelValidation initialized.")
|
||||
|
||||
def set_cv_method(self):
|
||||
if self.cv_name == "loso":
|
||||
self.cv = LeaveOneGroupOut()
|
||||
self.cv.get_n_splits(X=self.X, y=self.y, groups=self.groups)
|
||||
print("Validation method set.")
|
||||
|
||||
def cross_validate(self):
|
||||
print("Running cross validation ...")
|
||||
if self.model is None:
|
||||
raise TypeError(
|
||||
"Please, specify a machine learning model first, by setting the .model attribute. "
|
||||
"E.g. self.model = sklearn.linear_model.LinearRegression()"
|
||||
)
|
||||
if self.cv is None:
|
||||
raise TypeError(
|
||||
"Please, specify a cross validation method first, by using set_cv_method() first."
|
||||
)
|
||||
if self.X.isna().any().any() or self.y.isna().any().any():
|
||||
raise ValueError(
|
||||
"NaNs were found in either X or y. Please, check your data before continuing."
|
||||
)
|
||||
return cross_val_score(
|
||||
estimator=self.model,
|
||||
X=self.X,
|
||||
y=self.y,
|
||||
groups=self.groups,
|
||||
cv=self.cv,
|
||||
n_jobs=-1,
|
||||
scoring="r2",
|
||||
)
|
|
@ -1,305 +1,10 @@
|
|||
import datetime
|
||||
import warnings
|
||||
from collections.abc import Collection
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import yaml
|
||||
from pyprojroot import here
|
||||
from sklearn import linear_model
|
||||
from sklearn.model_selection import LeaveOneGroupOut, cross_val_score
|
||||
|
||||
import participants.query_db
|
||||
from features import communication, esm, helper, proximity
|
||||
from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME
|
||||
|
||||
WARNING_PARTICIPANTS_LABEL = (
|
||||
"Before calculating features, please set participants label using self.set_participants_label() "
|
||||
"to be used as a filename prefix when exporting data. "
|
||||
"The filename will be of the form: %participants_label_%grouping_variable_%data_type.csv"
|
||||
)
|
||||
|
||||
|
||||
class SensorFeatures:
|
||||
def __init__(
|
||||
self,
|
||||
grouping_variable: str,
|
||||
features: dict,
|
||||
participants_usernames: Collection = None,
|
||||
):
|
||||
|
||||
self.grouping_variable_name = grouping_variable
|
||||
self.grouping_variable = [grouping_variable]
|
||||
|
||||
self.data_types = features.keys()
|
||||
|
||||
self.participants_label: str = ""
|
||||
if participants_usernames is None:
|
||||
participants_usernames = participants.query_db.get_usernames(
|
||||
collection_start=datetime.date.fromisoformat("2020-08-01")
|
||||
)
|
||||
self.participants_label = "all"
|
||||
self.participants_usernames = participants_usernames
|
||||
|
||||
self.df_features_all = pd.DataFrame()
|
||||
|
||||
self.df_proximity = pd.DataFrame()
|
||||
self.df_proximity_counts = pd.DataFrame()
|
||||
|
||||
self.df_calls = pd.DataFrame()
|
||||
self.df_sms = pd.DataFrame()
|
||||
self.df_calls_sms = pd.DataFrame()
|
||||
|
||||
self.folder = None
|
||||
self.filename_prefix = ""
|
||||
self.construct_export_path()
|
||||
print("SensorFeatures initialized.")
|
||||
|
||||
def set_sensor_data(self):
|
||||
print("Querying database ...")
|
||||
if "proximity" in self.data_types:
|
||||
self.df_proximity = proximity.get_proximity_data(
|
||||
self.participants_usernames
|
||||
)
|
||||
print("Got proximity data from the DB.")
|
||||
self.df_proximity = helper.get_date_from_timestamp(self.df_proximity)
|
||||
self.df_proximity = proximity.recode_proximity(self.df_proximity)
|
||||
if "communication" in self.data_types:
|
||||
self.df_calls = communication.get_call_data(self.participants_usernames)
|
||||
self.df_calls = helper.get_date_from_timestamp(self.df_calls)
|
||||
print("Got calls data from the DB.")
|
||||
|
||||
self.df_sms = communication.get_sms_data(self.participants_usernames)
|
||||
self.df_sms = helper.get_date_from_timestamp(self.df_sms)
|
||||
print("Got sms data from the DB.")
|
||||
|
||||
def get_sensor_data(self, data_type) -> pd.DataFrame:
|
||||
if data_type == "proximity":
|
||||
return self.df_proximity
|
||||
elif data_type == "communication":
|
||||
return self.df_calls_sms
|
||||
else:
|
||||
raise KeyError("This data type has not been implemented.")
|
||||
|
||||
def calculate_features(self):
|
||||
print("Calculating features ...")
|
||||
if not self.participants_label:
|
||||
raise ValueError(WARNING_PARTICIPANTS_LABEL)
|
||||
if "proximity" in self.data_types:
|
||||
self.df_proximity_counts = proximity.count_proximity(
|
||||
self.df_proximity, self.grouping_variable
|
||||
)
|
||||
self.df_features_all = safe_outer_merge_on_index(
|
||||
self.df_features_all, self.df_proximity_counts
|
||||
)
|
||||
print("Calculated proximity features.")
|
||||
to_csv_with_settings(
|
||||
self.df_proximity, self.folder, self.filename_prefix, data_type="prox"
|
||||
)
|
||||
|
||||
if "communication" in self.data_types:
|
||||
self.df_calls_sms = communication.calls_sms_features(
|
||||
df_calls=self.df_calls,
|
||||
df_sms=self.df_sms,
|
||||
group_by=self.grouping_variable,
|
||||
)
|
||||
self.df_features_all = safe_outer_merge_on_index(
|
||||
self.df_features_all, self.df_calls_sms
|
||||
)
|
||||
print("Calculated communication features.")
|
||||
to_csv_with_settings(
|
||||
self.df_calls_sms, self.folder, self.filename_prefix, data_type="comm"
|
||||
)
|
||||
|
||||
self.df_features_all.fillna(
|
||||
value=proximity.FILL_NA_PROXIMITY, inplace=True, downcast="infer",
|
||||
)
|
||||
self.df_features_all.fillna(
|
||||
value=communication.FILL_NA_CALLS_SMS_ALL, inplace=True, downcast="infer",
|
||||
)
|
||||
|
||||
def get_features(self, data_type, feature_names) -> pd.DataFrame:
|
||||
if data_type == "proximity":
|
||||
if feature_names == "all":
|
||||
feature_names = proximity.FEATURES_PROXIMITY
|
||||
return self.df_proximity_counts[feature_names]
|
||||
elif data_type == "communication":
|
||||
if feature_names == "all":
|
||||
feature_names = communication.FEATURES_CALLS_SMS_ALL
|
||||
return self.df_calls_sms[feature_names]
|
||||
elif data_type == "all":
|
||||
return self.df_features_all
|
||||
else:
|
||||
raise KeyError("This data type has not been implemented.")
|
||||
|
||||
def construct_export_path(self):
|
||||
if not self.participants_label:
|
||||
warnings.warn(WARNING_PARTICIPANTS_LABEL, UserWarning)
|
||||
self.folder = here("machine_learning/intermediate_results/features", warn=True)
|
||||
self.filename_prefix = (
|
||||
self.participants_label + "_" + self.grouping_variable_name
|
||||
)
|
||||
|
||||
def set_participants_label(self, label: str):
|
||||
self.participants_label = label
|
||||
self.construct_export_path()
|
||||
|
||||
|
||||
class Labels:
|
||||
def __init__(
|
||||
self,
|
||||
grouping_variable: list,
|
||||
labels: dict,
|
||||
participants_usernames: Collection = None,
|
||||
):
|
||||
self.grouping_variable = grouping_variable
|
||||
|
||||
self.questionnaires = labels.keys()
|
||||
|
||||
if participants_usernames is None:
|
||||
participants_usernames = participants.query_db.get_usernames(
|
||||
collection_start=datetime.date.fromisoformat("2020-08-01")
|
||||
)
|
||||
self.participants_usernames = participants_usernames
|
||||
|
||||
self.df_esm = pd.DataFrame()
|
||||
self.df_esm_preprocessed = pd.DataFrame()
|
||||
self.df_esm_interest = pd.DataFrame()
|
||||
self.df_esm_clean = pd.DataFrame()
|
||||
|
||||
self.df_esm_means = pd.DataFrame()
|
||||
print("Labels initialized.")
|
||||
|
||||
def set_labels(self):
|
||||
print("Querying database ...")
|
||||
self.df_esm = esm.get_esm_data(self.participants_usernames)
|
||||
print("Got ESM data from the DB.")
|
||||
self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm)
|
||||
print("ESM data preprocessed.")
|
||||
if "PANAS" in self.questionnaires:
|
||||
self.df_esm_interest = self.df_esm_preprocessed[
|
||||
(
|
||||
self.df_esm_preprocessed["questionnaire_id"]
|
||||
== QUESTIONNAIRE_IDS.get("PANAS").get("PA")
|
||||
)
|
||||
| (
|
||||
self.df_esm_preprocessed["questionnaire_id"]
|
||||
== QUESTIONNAIRE_IDS.get("PANAS").get("NA")
|
||||
)
|
||||
]
|
||||
self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
|
||||
print("ESM data cleaned.")
|
||||
|
||||
def get_labels(self, questionnaire):
|
||||
if questionnaire == "PANAS":
|
||||
return self.df_esm_clean
|
||||
else:
|
||||
raise KeyError("This questionnaire has not been implemented as a label.")
|
||||
|
||||
def aggregate_labels(self):
|
||||
print("Aggregating labels ...")
|
||||
self.df_esm_means = (
|
||||
self.df_esm_clean.groupby(
|
||||
["participant_id", "questionnaire_id"] + self.grouping_variable
|
||||
)
|
||||
.esm_user_answer_numeric.agg("mean")
|
||||
.reset_index()
|
||||
.rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"})
|
||||
)
|
||||
self.df_esm_means = (
|
||||
self.df_esm_means.pivot(
|
||||
index=["participant_id"] + self.grouping_variable,
|
||||
columns="questionnaire_id",
|
||||
values="esm_numeric_mean",
|
||||
)
|
||||
.reset_index(col_level=1)
|
||||
.rename(columns=QUESTIONNAIRE_IDS_RENAME)
|
||||
.set_index(["participant_id"] + self.grouping_variable)
|
||||
)
|
||||
print("Labels aggregated.")
|
||||
|
||||
def get_aggregated_labels(self):
|
||||
return self.df_esm_means
|
||||
|
||||
|
||||
class ModelValidation:
|
||||
def __init__(self, X, y, group_variable=None, cv_name="loso"):
|
||||
self.model = None
|
||||
self.cv = None
|
||||
|
||||
idx_common = X.index.intersection(y.index)
|
||||
self.y = y.loc[idx_common, "NA"]
|
||||
# TODO Handle the case of multiple labels.
|
||||
self.X = X.loc[idx_common]
|
||||
self.groups = self.y.index.get_level_values(group_variable)
|
||||
|
||||
self.cv_name = cv_name
|
||||
print("ModelValidation initialized.")
|
||||
|
||||
def set_cv_method(self):
|
||||
if self.cv_name == "loso":
|
||||
self.cv = LeaveOneGroupOut()
|
||||
self.cv.get_n_splits(X=self.X, y=self.y, groups=self.groups)
|
||||
print("Validation method set.")
|
||||
|
||||
def cross_validate(self):
|
||||
print("Running cross validation ...")
|
||||
if self.model is None:
|
||||
raise TypeError(
|
||||
"Please, specify a machine learning model first, by setting the .model attribute. "
|
||||
"E.g. self.model = sklearn.linear_model.LinearRegression()"
|
||||
)
|
||||
if self.cv is None:
|
||||
raise TypeError(
|
||||
"Please, specify a cross validation method first, by using set_cv_method() first."
|
||||
)
|
||||
if self.X.isna().any().any() or self.y.isna().any().any():
|
||||
raise ValueError(
|
||||
"NaNs were found in either X or y. Please, check your data before continuing."
|
||||
)
|
||||
return cross_val_score(
|
||||
estimator=self.model,
|
||||
X=self.X,
|
||||
y=self.y,
|
||||
groups=self.groups,
|
||||
cv=self.cv,
|
||||
n_jobs=-1,
|
||||
scoring="r2",
|
||||
)
|
||||
|
||||
|
||||
def safe_outer_merge_on_index(left, right):
|
||||
if left.empty:
|
||||
return right
|
||||
elif right.empty:
|
||||
return left
|
||||
else:
|
||||
return pd.merge(
|
||||
left,
|
||||
right,
|
||||
how="outer",
|
||||
left_index=True,
|
||||
right_index=True,
|
||||
validate="one_to_one",
|
||||
)
|
||||
|
||||
|
||||
def to_csv_with_settings(
|
||||
df: pd.DataFrame, folder: Path, filename_prefix: str, data_type: str
|
||||
) -> None:
|
||||
export_filename = filename_prefix + "_" + data_type + ".csv"
|
||||
full_path = folder / export_filename
|
||||
df.to_csv(
|
||||
path_or_buf=full_path,
|
||||
sep=",",
|
||||
na_rep="NA",
|
||||
header=True,
|
||||
index=False,
|
||||
encoding="utf-8",
|
||||
)
|
||||
print("Exported the dataframe to " + str(full_path))
|
||||
|
||||
from machine_learning.features_sensor import SensorFeatures
|
||||
from machine_learning.labels import Labels
|
||||
from machine_learning.model import ModelValidation
|
||||
|
||||
if __name__ == "__main__":
|
||||
with open("./config/prox_comm_PANAS_features.yaml", "r") as file:
|
||||
|
|
Loading…
Reference in New Issue