321 lines
11 KiB
Python
321 lines
11 KiB
Python
import datetime
|
|
from collections.abc import Collection
|
|
|
|
import pandas as pd
|
|
from sklearn.model_selection import LeaveOneGroupOut, cross_val_score
|
|
|
|
import participants.query_db
|
|
from features import communication, esm, helper, proximity
|
|
from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME
|
|
|
|
|
|
class SensorFeatures:
    """Load sensor data for a set of participants and derive features from it.

    The methods visible here are meant to be used in order:
    ``set_sensor_data()`` loads the raw tables, ``calculate_features()``
    derives the feature tables, and ``get_features()`` reads them back.
    """

    def __init__(
        self,
        grouping_variable: list,
        features: dict,
        participants_usernames: Collection = None,
    ):
        """Store the configuration and create empty placeholder tables.

        Args:
            grouping_variable: Column names handed to the feature
                calculations as the grouping key.
            features: Mapping whose keys name the data types to process
                ("proximity" and/or "communication"). Only the keys are
                used by this class.
            participants_usernames: Usernames to load data for. When None,
                they are fetched with ``participants.query_db.get_usernames``
                using a collection start of 2020-08-01.
        """
        self.grouping_variable = grouping_variable

        self.data_types = features.keys()

        if participants_usernames is None:
            participants_usernames = participants.query_db.get_usernames(
                collection_start=datetime.date.fromisoformat("2020-08-01")
            )
        self.participants_usernames = participants_usernames

        self.df_features_all = pd.DataFrame()

        self.df_proximity = pd.DataFrame()
        self.df_proximity_counts = pd.DataFrame()

        self.df_calls = pd.DataFrame()
        self.df_sms = pd.DataFrame()
        self.df_calls_sms = pd.DataFrame()

    def set_sensor_data(self):
        """Load the raw tables for every configured data type.

        Proximity data is passed through ``helper.get_date_from_timestamp``
        and ``proximity.recode_proximity``; call and SMS data only through
        ``helper.get_date_from_timestamp``.
        """
        if "proximity" in self.data_types:
            self.df_proximity = proximity.get_proximity_data(
                self.participants_usernames
            )
            self.df_proximity = helper.get_date_from_timestamp(self.df_proximity)
            self.df_proximity = proximity.recode_proximity(self.df_proximity)
        if "communication" in self.data_types:
            self.df_calls = communication.get_call_data(self.participants_usernames)
            self.df_calls = helper.get_date_from_timestamp(self.df_calls)

            self.df_sms = communication.get_sms_data(self.participants_usernames)
            self.df_sms = helper.get_date_from_timestamp(self.df_sms)

    def get_sensor_data(self, data_type) -> pd.DataFrame:
        """Return the stored table for ``data_type``.

        Raises:
            KeyError: If ``data_type`` is neither "proximity" nor
                "communication".
        """
        if data_type == "proximity":
            return self.df_proximity
        elif data_type == "communication":
            # NOTE(review): this is the table filled by calculate_features(),
            # not the raw df_calls / df_sms loaded by set_sensor_data().
            # Confirm that returning features here is intended.
            return self.df_calls_sms
        else:
            raise KeyError("This data type has not been implemented.")

    def calculate_features(self):
        """Compute the per-data-type feature tables and their combined table.

        The combined table ``df_features_all`` is rebuilt from scratch on
        every call, so calling this method repeatedly is safe.
        """
        # Reset first: merging onto the result of a previous call would
        # join the same columns again and yield suffixed duplicates.
        self.df_features_all = pd.DataFrame()

        if "proximity" in self.data_types:
            self.df_proximity_counts = proximity.count_proximity(
                self.df_proximity, self.grouping_variable
            )
            self.df_features_all = safe_outer_merge_on_index(
                self.df_features_all, self.df_proximity_counts
            )

        if "communication" in self.data_types:
            self.df_calls_sms = communication.calls_sms_features(
                df_calls=self.df_calls,
                df_sms=self.df_sms,
                group_by=self.grouping_variable,
            )
            self.df_features_all = safe_outer_merge_on_index(
                self.df_features_all, self.df_calls_sms
            )

    def get_features(self, data_type, feature_names) -> pd.DataFrame:
        """Return computed features for one data type, or for all of them.

        Args:
            data_type: "proximity", "communication" or "all".
            feature_names: Column selector applied to the feature table, or
                the string "all" to select the module's default feature
                list. Ignored when ``data_type`` is "all".

        Raises:
            KeyError: If ``data_type`` is not one of the values above.
        """
        # Only a plain string can mean "all"; comparing an array-like
        # selector with == would produce an element-wise result whose
        # truth value is ambiguous.
        select_all = isinstance(feature_names, str) and feature_names == "all"

        if data_type == "proximity":
            if select_all:
                feature_names = proximity.FEATURES_PROXIMITY
            return self.df_proximity_counts[feature_names]
        elif data_type == "communication":
            if select_all:
                feature_names = communication.FEATURES_CALLS_SMS_ALL
            return self.df_calls_sms[feature_names]
        elif data_type == "all":
            return self.df_features_all
        else:
            raise KeyError("This data type has not been implemented.")
|
|
|
|
|
|
class Labels:
    """Load questionnaire (ESM) answers and aggregate them into label columns."""

    def __init__(
        self,
        grouping_variable: list,
        labels: dict,
        participants_usernames: Collection = None,
    ):
        """Store the configuration and create empty placeholder tables.

        Args:
            grouping_variable: Column names used, together with
                "participant_id", as the aggregation key.
            labels: Mapping whose keys name the questionnaires to use;
                only the keys are read by this class.
            participants_usernames: Usernames to load answers for. When
                None, they are fetched with
                ``participants.query_db.get_usernames`` using a collection
                start of 2020-08-01.
        """
        self.grouping_variable = grouping_variable

        self.questionnaires = labels.keys()

        if participants_usernames is None:
            default_start = datetime.date.fromisoformat("2020-08-01")
            participants_usernames = participants.query_db.get_usernames(
                collection_start=default_start
            )
        self.participants_usernames = participants_usernames

        # Raw, preprocessed, filtered and cleaned answers, in that order.
        self.df_esm = pd.DataFrame()
        self.df_esm_preprocessed = pd.DataFrame()
        self.df_esm_interest = pd.DataFrame()
        self.df_esm_clean = pd.DataFrame()

        # Aggregated answers, filled by aggregate_labels().
        self.df_esm_means = pd.DataFrame()

    def set_labels(self):
        """Load and preprocess the answers, then keep and clean the PANAS ones.

        Filtering and cleaning only happen when "PANAS" is among the
        configured questionnaires.
        """
        self.df_esm = esm.get_esm_data(self.participants_usernames)
        self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm)

        if "PANAS" not in self.questionnaires:
            return

        # Keep the rows belonging to either of the two PANAS questionnaires.
        matches_pa = self.df_esm_preprocessed[
            "questionnaire_id"
        ] == QUESTIONNAIRE_IDS.get("PANAS").get("PA")
        matches_na = self.df_esm_preprocessed[
            "questionnaire_id"
        ] == QUESTIONNAIRE_IDS.get("PANAS").get("NA")
        self.df_esm_interest = self.df_esm_preprocessed[matches_pa | matches_na]

        # NOTE(review): cleaning is treated as part of the PANAS branch;
        # confirm this nesting against the original layout.
        self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)

    def get_labels(self, questionnaire):
        """Return the cleaned answers for ``questionnaire``.

        Raises:
            KeyError: If ``questionnaire`` is anything other than "PANAS".
        """
        if questionnaire != "PANAS":
            raise KeyError("This questionnaire has not been implemented as a label.")
        return self.df_esm_clean

    def aggregate_labels(self):
        """Average the cleaned answers and reshape them to one column per questionnaire.

        The result is indexed by "participant_id" plus the grouping
        variable, with columns renamed through ``QUESTIONNAIRE_IDS_RENAME``.
        """
        row_key = ["participant_id"] + self.grouping_variable

        # Step 1: long format, one mean per participant, questionnaire and group.
        grouped = self.df_esm_clean.groupby(
            ["participant_id", "questionnaire_id"] + self.grouping_variable
        )
        mean_answers = grouped["esm_user_answer_numeric"].agg("mean")
        self.df_esm_means = mean_answers.reset_index().rename(
            columns={"esm_user_answer_numeric": "esm_numeric_mean"}
        )

        # Step 2: wide format, one column per questionnaire.
        wide = self.df_esm_means.pivot(
            index=row_key,
            columns="questionnaire_id",
            values="esm_numeric_mean",
        )
        wide = wide.reset_index(col_level=1)
        wide = wide.rename(columns=QUESTIONNAIRE_IDS_RENAME)
        self.df_esm_means = wide.set_index(row_key)

    def get_aggregated_labels(self):
        """Return the table produced by ``aggregate_labels()``."""
        return self.df_esm_means
|
|
|
|
|
|
class ModelValidation:
    """Cross-validate a model on features ``X`` against one label column of ``y``.

    Set ``self.model`` to an estimator and call ``set_cv_method()`` before
    calling ``cross_validate()``.
    """

    def __init__(self, X, y, group_variable=None, cv_name="loso", label="NA"):
        """Select the label column and align the features and groups to it.

        Args:
            X: Feature table. It is restricted to the rows whose index
                labels appear in the selected label column.
            y: Label table containing a column named ``label``.
            group_variable: Index level of ``y`` passed to
                ``Index.get_level_values`` to obtain the cross-validation
                groups.
            cv_name: Name of the cross-validation scheme; only "loso"
                is implemented.
            label: Column of ``y`` used as the prediction target. The
                default "NA" is the column that used to be hard-coded.
        """
        self.model = None
        self.cv = None

        self.y = y[label]
        # TODO Handle the case of multiple labels.
        self.X = X.loc[self.y.index]
        self.groups = self.y.index.get_level_values(group_variable)

        self.cv_name = cv_name

    def set_cv_method(self):
        """Create the cross-validation splitter named by ``self.cv_name``.

        Raises:
            ValueError: If ``self.cv_name`` is not a supported scheme.
        """
        if self.cv_name == "loso":
            self.cv = LeaveOneGroupOut()
            # The return value is discarded; presumably this call is kept
            # as an early check that the groups are usable.
            self.cv.get_n_splits(X=self.X, y=self.y, groups=self.groups)
        else:
            # Fail here rather than later in cross_validate(), where the
            # error would wrongly suggest that set_cv_method() was not called.
            raise ValueError(
                f'Unsupported cv_name {self.cv_name!r}; only "loso" is implemented.'
            )

    def cross_validate(self):
        """Run ``cross_val_score`` with R^2 scoring and return its result.

        Raises:
            ValueError: If ``self.model`` or ``self.cv`` has not been set.
        """
        if self.model is None:
            raise ValueError(
                "Please set self.model first, e.g. self.model = sklearn.linear_model.LinearRegression()"
            )
            # TODO Is ValueError appropriate here?
        if self.cv is None:
            raise ValueError("Please use set_cv_method() first.")
        return cross_val_score(
            estimator=self.model,
            X=self.X,
            y=self.y,
            groups=self.groups,
            cv=self.cv,
            n_jobs=-1,
            scoring="r2",
        )
|
|
|
|
|
|
def safe_outer_merge_on_index(left, right):
    """Outer-join two tables on their indexes, tolerating an empty side.

    If ``left`` is empty, ``right`` is returned as is (also when both are
    empty); if only ``right`` is empty, ``left`` is returned as is.
    Otherwise the two are outer-merged on their indexes, and pandas
    validates that the join is one-to-one.
    """
    # Guard clauses: an empty side has nothing to contribute to the join.
    if left.empty:
        return right
    if right.empty:
        return left

    return pd.merge(
        left,
        right,
        how="outer",
        left_index=True,
        right_index=True,
        validate="one_to_one",
    )
|
|
|
|
|
|
class MachineLearningPipeline:
    """Hold the configuration and intermediate tables of one modelling run.

    NOTE(review): no active code in this class fills
    ``df_full_data_daily_means``; the methods that did are commented out
    below. The caller has to assign it before ``assign_columns()``.
    """

    def __init__(
        self,
        labels_questionnaire,
        labels_scale,
        data_types,
        participants_usernames=None,
        feature_names=None,
        grouping_variable=None,
    ):
        """Store the configuration and create empty placeholder tables.

        Args:
            labels_questionnaire: Questionnaire the labels come from;
                stored as given.
            labels_scale: Column selector for the label column(s) of
                ``df_full_data_daily_means``.
            data_types: Sensor data types to use; stored as given.
            participants_usernames: Usernames to work with. When None, they
                are fetched with ``participants.query_db.get_usernames``
                using a collection start of 2020-08-01.
            feature_names: Column selector for the feature columns of
                ``df_full_data_daily_means``. Defaults to an empty list.
            grouping_variable: Column selector for the cross-validation
                groups, or None for no grouping.
        """
        if participants_usernames is None:
            participants_usernames = participants.query_db.get_usernames(
                collection_start=datetime.date.fromisoformat("2020-08-01")
            )
        self.participants_usernames = participants_usernames
        self.labels_questionnaire = labels_questionnaire
        self.data_types = data_types

        # Always set the attribute: previously a caller-supplied value was
        # dropped, leaving assign_columns() without the requested features.
        self.feature_names = [] if feature_names is None else feature_names
        self.df_features = pd.DataFrame()
        self.labels_scale = labels_scale
        self.df_labels = pd.DataFrame()
        self.grouping_variable = grouping_variable
        self.df_groups = pd.DataFrame()

        self.model = None
        self.validation_method = None

        self.df_esm = pd.DataFrame()
        self.df_esm_preprocessed = pd.DataFrame()
        self.df_esm_interest = pd.DataFrame()
        self.df_esm_clean = pd.DataFrame()

        self.df_full_data_daily_means = pd.DataFrame()
        self.df_esm_daily_means = pd.DataFrame()
        self.df_proximity_daily_counts = pd.DataFrame()

    # def get_labels(self):
    #     self.df_esm = esm.get_esm_data(self.participants_usernames)
    #     self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm)
    #     if self.labels_questionnaire == "PANAS":
    #         self.df_esm_interest = self.df_esm_preprocessed[
    #             (
    #                 self.df_esm_preprocessed["questionnaire_id"]
    #                 == QUESTIONNAIRE_IDS.get("PANAS").get("PA")
    #             )
    #             | (
    #                 self.df_esm_preprocessed["questionnaire_id"]
    #                 == QUESTIONNAIRE_IDS.get("PANAS").get("NA")
    #             )
    #         ]
    #     self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)

    # def aggregate_daily(self):
    #     self.df_esm_daily_means = (
    #         self.df_esm_clean.groupby(["participant_id", "date_lj", "questionnaire_id"])
    #         .esm_user_answer_numeric.agg("mean")
    #         .reset_index()
    #         .rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"})
    #     )
    #     self.df_esm_daily_means = (
    #         self.df_esm_daily_means.pivot(
    #             index=["participant_id", "date_lj"],
    #             columns="questionnaire_id",
    #             values="esm_numeric_mean",
    #         )
    #         .reset_index(col_level=1)
    #         .rename(columns=QUESTIONNAIRE_IDS_RENAME)
    #         .set_index(["participant_id", "date_lj"])
    #     )
    #     self.df_full_data_daily_means = self.df_esm_daily_means.copy()
    #     if "proximity" in self.data_types:
    #         self.df_proximity_daily_counts = proximity.count_proximity(
    #             self.df_proximity, ["participant_id", "date_lj"]
    #         )
    #         self.df_full_data_daily_means = self.df_full_data_daily_means.join(
    #             self.df_proximity_daily_counts
    #         )

    def assign_columns(self):
        """Split ``df_full_data_daily_means`` into features, labels and groups.

        ``df_groups`` is set to None when no grouping variable is configured.
        """
        self.df_features = self.df_full_data_daily_means[self.feature_names]
        self.df_labels = self.df_full_data_daily_means[self.labels_scale]
        if self.grouping_variable:
            self.df_groups = self.df_full_data_daily_means[self.grouping_variable]
        else:
            self.df_groups = None

    def validate_model(self):
        """Cross-validate ``self.model`` on the assigned features and labels.

        Returns:
            The result of ``cross_val_score``. It was previously computed
            and discarded, so this method returned None.

        Raises:
            AttributeError: If ``self.model`` or ``self.validation_method``
                has not been set.
        """
        if self.model is None:
            raise AttributeError(
                "Please, specify a machine learning model first, by setting the .model attribute."
            )
        if self.validation_method is None:
            raise AttributeError(
                "Please, specify a cross validation method first, by setting the .validation_method attribute."
            )
        return cross_val_score(
            estimator=self.model,
            X=self.df_features,
            y=self.df_labels,
            groups=self.df_groups,
            cv=self.validation_method,
            n_jobs=-1,
        )
|