stress_at_work_analysis/machine_learning/pipeline.py

328 lines
11 KiB
Python
Raw Normal View History

import datetime
2021-08-23 16:36:26 +02:00
import warnings
from collections.abc import Collection
2021-08-23 16:36:26 +02:00
from pathlib import Path
2021-08-21 19:04:09 +02:00
import numpy as np
import pandas as pd
2021-08-21 19:04:09 +02:00
import yaml
2021-08-23 16:36:26 +02:00
from pyprojroot import here
2021-08-21 19:04:09 +02:00
from sklearn import linear_model
from sklearn.model_selection import LeaveOneGroupOut, cross_val_score
import participants.query_db
from features import communication, esm, helper, proximity
from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME
2021-08-23 16:36:26 +02:00
# Message used when no participants label has been set: emitted as a
# UserWarning by SensorFeatures.construct_export_path and raised as a
# ValueError by SensorFeatures.calculate_features.
WARNING_PARTICIPANTS_LABEL = (
    "Before calculating features, "
    "please set participants label using self.set_participants_label() "
    "to be used as a filename prefix when exporting data. "
    "The filename will be of the form: "
    "%participants_label_%grouping_variable_%data_type.csv"
)
class SensorFeatures:
    """Query sensor data for a set of participants and compute features from it.

    Typical use: construct the object, call ``set_sensor_data()`` to query the
    database, then ``calculate_features()`` to compute and export the features.

    Args:
        grouping_variable: Name of the column the features are grouped by; it
            is also part of the export filename prefix.
        features: Mapping whose keys select the data types to process
            ("proximity" and/or "communication"); only the keys are used.
        participants_usernames: Usernames to include. If ``None``, the usernames
            returned by ``participants.query_db.get_usernames`` for a collection
            start of 2020-08-01 are used and the participants label is set to
            "all".
    """

    def __init__(
        self,
        grouping_variable: str,
        features: dict,
        participants_usernames: Collection = None,
    ):
        self.grouping_variable_name = grouping_variable
        # The feature functions receive the grouping columns as a list.
        self.grouping_variable = [grouping_variable]
        self.data_types = features.keys()

        # The label becomes the first part of every exported filename.
        self.participants_label: str = ""
        if participants_usernames is None:
            participants_usernames = participants.query_db.get_usernames(
                collection_start=datetime.date.fromisoformat("2020-08-01")
            )
            self.participants_label = "all"
        self.participants_usernames = participants_usernames

        # Features (filled by calculate_features).
        self.df_features_all = pd.DataFrame()
        self.df_proximity_counts = pd.DataFrame()
        self.df_calls_sms = pd.DataFrame()
        # Raw sensor data (filled by set_sensor_data).
        self.df_proximity = pd.DataFrame()
        self.df_calls = pd.DataFrame()
        self.df_sms = pd.DataFrame()

        self.folder = None
        self.filename_prefix = ""
        self.construct_export_path()

        print("SensorFeatures initialized.")

    def set_sensor_data(self):
        """Query the raw data for every requested data type from the database.

        Each frame is passed through ``helper.get_date_from_timestamp``
        (presumably deriving a date from the timestamp -- confirm in helper);
        proximity data is additionally passed through
        ``proximity.recode_proximity``.
        """
        print("Querying database ...")
        if "proximity" in self.data_types:
            self.df_proximity = proximity.get_proximity_data(
                self.participants_usernames
            )
            print("Got proximity data from the DB.")
            self.df_proximity = helper.get_date_from_timestamp(self.df_proximity)
            self.df_proximity = proximity.recode_proximity(self.df_proximity)
        if "communication" in self.data_types:
            self.df_calls = communication.get_call_data(self.participants_usernames)
            self.df_calls = helper.get_date_from_timestamp(self.df_calls)
            print("Got calls data from the DB.")
            self.df_sms = communication.get_sms_data(self.participants_usernames)
            self.df_sms = helper.get_date_from_timestamp(self.df_sms)
            print("Got sms data from the DB.")

    def get_sensor_data(self, data_type) -> pd.DataFrame:
        """Return the data frame stored for ``data_type``.

        NOTE(review): for "communication" this returns the combined calls/SMS
        *feature* frame filled by calculate_features, not the raw calls or SMS
        data -- confirm this is intended.

        Raises:
            KeyError: If ``data_type`` is not "proximity" or "communication".
        """
        if data_type == "proximity":
            return self.df_proximity
        elif data_type == "communication":
            return self.df_calls_sms
        else:
            raise KeyError("This data type has not been implemented.")

    def calculate_features(self):
        """Compute, merge and export the features of every requested data type.

        Each data type's features are exported to CSV and outer-merged on their
        index into ``self.df_features_all``; missing values are then filled with
        ``proximity.FILL_NA_PROXIMITY`` and
        ``communication.FILL_NA_CALLS_SMS_ALL``.

        Raises:
            ValueError: If no participants label is set (it is needed for the
                export filenames).
        """
        print("Calculating features ...")
        if not self.participants_label:
            raise ValueError(WARNING_PARTICIPANTS_LABEL)
        # Start from an empty frame so that calling this method again does not
        # merge the same feature columns a second time (which would produce
        # suffixed duplicate columns).
        self.df_features_all = pd.DataFrame()
        if "proximity" in self.data_types:
            self.df_proximity_counts = proximity.count_proximity(
                self.df_proximity, self.grouping_variable
            )
            self.df_features_all = safe_outer_merge_on_index(
                self.df_features_all, self.df_proximity_counts
            )
            print("Calculated proximity features.")
            # Export the computed features, as the communication branch does;
            # the filename prefix encodes the grouping variable, which only the
            # grouped features (not the raw data) depend on.
            to_csv_with_settings(
                self.df_proximity_counts,
                self.folder,
                self.filename_prefix,
                data_type="prox",
            )
        if "communication" in self.data_types:
            self.df_calls_sms = communication.calls_sms_features(
                df_calls=self.df_calls,
                df_sms=self.df_sms,
                group_by=self.grouping_variable,
            )
            self.df_features_all = safe_outer_merge_on_index(
                self.df_features_all, self.df_calls_sms
            )
            print("Calculated communication features.")
            to_csv_with_settings(
                self.df_calls_sms, self.folder, self.filename_prefix, data_type="comm"
            )

        self.df_features_all.fillna(
            value=proximity.FILL_NA_PROXIMITY, inplace=True, downcast="infer",
        )
        self.df_features_all.fillna(
            value=communication.FILL_NA_CALLS_SMS_ALL, inplace=True, downcast="infer",
        )

    def get_features(self, data_type, feature_names) -> pd.DataFrame:
        """Return computed features.

        Args:
            data_type: "proximity", "communication", or "all" (the merged frame;
                ``feature_names`` is then ignored).
            feature_names: Column(s) to select, or "all" for the module's
                full feature list.

        Raises:
            KeyError: If ``data_type`` is not implemented.
        """
        if data_type == "proximity":
            if feature_names == "all":
                feature_names = proximity.FEATURES_PROXIMITY
            return self.df_proximity_counts[feature_names]
        elif data_type == "communication":
            if feature_names == "all":
                feature_names = communication.FEATURES_CALLS_SMS_ALL
            return self.df_calls_sms[feature_names]
        elif data_type == "all":
            return self.df_features_all
        else:
            raise KeyError("This data type has not been implemented.")

    def construct_export_path(self):
        """Set the export folder and the filename prefix (label + grouping variable).

        Warns with WARNING_PARTICIPANTS_LABEL if no participants label is set.
        """
        if not self.participants_label:
            warnings.warn(WARNING_PARTICIPANTS_LABEL, UserWarning)
        self.folder = here("machine_learning/intermediate_results/features", warn=True)
        self.filename_prefix = (
            self.participants_label + "_" + self.grouping_variable_name
        )

    def set_participants_label(self, label: str):
        """Set the participants label and rebuild the export path from it."""
        self.participants_label = label
        self.construct_export_path()
2021-08-19 17:44:04 +02:00
class Labels:
    """Query ESM answers for a set of participants and aggregate them into labels.

    Args:
        grouping_variable: List of extra columns (besides ``participant_id``)
            the answers are averaged over.
        labels: Mapping whose keys name the questionnaires to use as labels;
            only "PANAS" is currently handled. Only the keys are used.
        participants_usernames: Usernames to include. If ``None``, the usernames
            returned by ``participants.query_db.get_usernames`` for a collection
            start of 2020-08-01 are used.
    """

    def __init__(
        self,
        grouping_variable: list,
        labels: dict,
        participants_usernames: Collection = None,
    ):
        self.grouping_variable = grouping_variable
        self.questionnaires = labels.keys()
        if participants_usernames is None:
            start_date = datetime.date.fromisoformat("2020-08-01")
            participants_usernames = participants.query_db.get_usernames(
                collection_start=start_date
            )
        self.participants_usernames = participants_usernames
        # Successive stages of the ESM pipeline, filled by set_labels and
        # aggregate_labels.
        self.df_esm = pd.DataFrame()
        self.df_esm_preprocessed = pd.DataFrame()
        self.df_esm_interest = pd.DataFrame()
        self.df_esm_clean = pd.DataFrame()
        self.df_esm_means = pd.DataFrame()

        print("Labels initialized.")

    def set_labels(self):
        """Query, preprocess, filter (to the PANAS PA/NA items) and clean ESM data."""
        print("Querying database ...")
        self.df_esm = esm.get_esm_data(self.participants_usernames)
        print("Got ESM data from the DB.")
        self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm)
        print("ESM data preprocessed.")
        if "PANAS" in self.questionnaires:
            panas_ids = QUESTIONNAIRE_IDS.get("PANAS")
            questionnaire_column = self.df_esm_preprocessed["questionnaire_id"]
            is_positive_affect = questionnaire_column == panas_ids.get("PA")
            is_negative_affect = questionnaire_column == panas_ids.get("NA")
            self.df_esm_interest = self.df_esm_preprocessed[
                is_positive_affect | is_negative_affect
            ]
        self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
        print("ESM data cleaned.")

    def get_labels(self, questionnaire):
        """Return the cleaned ESM answers for ``questionnaire``.

        Raises:
            KeyError: If ``questionnaire`` is not "PANAS".
        """
        if questionnaire != "PANAS":
            raise KeyError("This questionnaire has not been implemented as a label.")
        return self.df_esm_clean

    def aggregate_labels(self):
        """Average numeric answers per participant, questionnaire and group.

        The result has one column per questionnaire (renamed with
        QUESTIONNAIRE_IDS_RENAME) and is indexed by ``participant_id`` plus the
        grouping columns.
        """
        print("Aggregating labels ...")
        group_keys = ["participant_id", "questionnaire_id"] + self.grouping_variable
        grouped_answers = self.df_esm_clean.groupby(group_keys)[
            "esm_user_answer_numeric"
        ]
        self.df_esm_means = (
            grouped_answers.agg("mean")
            .reset_index()
            .rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"})
        )
        row_keys = ["participant_id"] + self.grouping_variable
        wide_means = self.df_esm_means.pivot(
            index=row_keys,
            columns="questionnaire_id",
            values="esm_numeric_mean",
        )
        self.df_esm_means = (
            wide_means.reset_index(col_level=1)
            .rename(columns=QUESTIONNAIRE_IDS_RENAME)
            .set_index(["participant_id"] + self.grouping_variable)
        )
        print("Labels aggregated.")

    def get_aggregated_labels(self):
        """Return the aggregated labels computed by aggregate_labels."""
        return self.df_esm_means
2021-08-19 17:44:04 +02:00
class ModelValidation:
    """Cross-validate a regression model that predicts one label from features.

    Only rows whose index appears in both ``X`` and ``y`` are used.

    Args:
        X: Feature frame.
        y: Label frame sharing its index with ``X``.
        group_variable: Index level of ``y`` whose values define the CV groups
            (e.g. "participant_id").
        cv_name: Cross-validation scheme; only "loso" (LeaveOneGroupOut over
            ``group_variable``) is implemented.
        label: Column of ``y`` used as the target. Defaults to "NA".
    """

    def __init__(self, X, y, group_variable=None, cv_name="loso", label="NA"):
        self.model = None
        self.cv = None
        # Keep only observations that have both features and labels.
        idx_common = X.index.intersection(y.index)
        # TODO Handle the case of multiple labels.
        self.y = y.loc[idx_common, label]
        self.X = X.loc[idx_common]
        self.groups = self.y.index.get_level_values(group_variable)
        self.cv_name = cv_name
        print("ModelValidation initialized.")

    def set_cv_method(self):
        """Create the cross-validation splitter named by ``self.cv_name``.

        Raises:
            KeyError: If ``self.cv_name`` is not an implemented method.
        """
        if self.cv_name == "loso":
            self.cv = LeaveOneGroupOut()
        else:
            # Previously self.cv stayed None and the call below crashed with
            # an AttributeError; fail with an explicit message instead.
            raise KeyError("This cross validation method has not been implemented.")
        # The return value is unused; the call presumably serves as an early
        # check that the groups suit the splitter (see sklearn docs).
        self.cv.get_n_splits(X=self.X, y=self.y, groups=self.groups)
        print("Validation method set.")

    def cross_validate(self):
        """Run cross-validation and return the per-fold R^2 scores.

        Raises:
            TypeError: If no model has been assigned to ``self.model`` or
                set_cv_method has not been called.
            ValueError: If ``X`` or ``y`` contain NaNs.
        """
        print("Running cross validation ...")
        if self.model is None:
            raise TypeError(
                "Please, specify a machine learning model first, by setting the .model attribute. "
                "E.g. self.model = sklearn.linear_model.LinearRegression()"
            )
        if self.cv is None:
            raise TypeError(
                "Please, specify a cross validation method first, by using set_cv_method() first."
            )
        if self.X.isna().any().any() or self.y.isna().any().any():
            raise ValueError(
                "NaNs were found in either X or y. Please, check your data before continuing."
            )
        return cross_val_score(
            estimator=self.model,
            X=self.X,
            y=self.y,
            groups=self.groups,
            cv=self.cv,
            n_jobs=-1,
            scoring="r2",
        )
def safe_outer_merge_on_index(left, right):
    """Outer-join two frames on their indexes, tolerating empty inputs.

    If either frame is empty, the other one is returned unchanged (the same
    object, not a copy). Otherwise the frames are outer-merged on their
    indexes with ``validate="one_to_one"``, so duplicate index keys raise.
    """
    if left.empty:
        return right
    if right.empty:
        return left
    return left.merge(
        right,
        how="outer",
        left_index=True,
        right_index=True,
        validate="one_to_one",
    )
2021-08-23 16:36:26 +02:00
def to_csv_with_settings(
    df: pd.DataFrame, folder: Path, filename_prefix: str, data_type: str, index=None
) -> None:
    """Export ``df`` to ``<folder>/<filename_prefix>_<data_type>.csv``.

    The file is comma-separated, UTF-8 encoded, has a header row, and writes
    missing values as "NA".

    Args:
        df: Data frame to export.
        folder: Destination directory (``Path`` or ``str``); created, with any
            missing parents, if it does not exist.
        filename_prefix: First part of the file name.
        data_type: Last part of the file name (e.g. "prox", "comm").
        index: Whether to write the row index. ``None`` (default) writes it
            only when at least one index level is named, so that the keys of
            index-aligned frames (such as merged feature frames) are not lost,
            while an anonymous default index is still omitted.
    """
    if index is None:
        index = any(name is not None for name in df.index.names)
    folder = Path(folder)
    folder.mkdir(parents=True, exist_ok=True)
    full_path = folder / f"{filename_prefix}_{data_type}.csv"
    df.to_csv(
        path_or_buf=full_path,
        sep=",",
        na_rep="NA",
        header=True,
        index=index,
        encoding="utf-8",
    )
    print("Exported the dataframe to " + str(full_path))
2021-08-21 19:04:09 +02:00
if __name__ == "__main__":
    # Settings are read from YAML files under ./config, resolved relative to
    # the current working directory. Decode them explicitly as UTF-8 rather
    # than with the platform's locale encoding.
    with open(
        "./config/prox_comm_PANAS_features.yaml", "r", encoding="utf-8"
    ) as file:
        sensor_features_params = yaml.safe_load(file)
    sensor_features = SensorFeatures(**sensor_features_params)
    sensor_features.set_sensor_data()
    sensor_features.calculate_features()

    with open("./config/prox_comm_PANAS_labels.yaml", "r", encoding="utf-8") as file:
        labels_params = yaml.safe_load(file)
    labels = Labels(**labels_params)
    labels.set_labels()
    labels.aggregate_labels()

    # Validate a linear regression on all features, grouping folds by participant.
    model_validation = ModelValidation(
        sensor_features.get_features("all", "all"),
        labels.get_aggregated_labels(),
        group_variable="participant_id",
        cv_name="loso",
    )
    model_validation.model = linear_model.LinearRegression()
    model_validation.set_cv_method()
    model_loso_r2 = model_validation.cross_validate()
    print(model_loso_r2)
    print(np.mean(model_loso_r2))