2021-09-13 11:41:57 +02:00
|
|
|
import datetime
|
2021-09-15 15:36:36 +02:00
|
|
|
import warnings
|
|
|
|
from pathlib import Path
|
2021-09-13 11:41:57 +02:00
|
|
|
from typing import Collection
|
|
|
|
|
|
|
|
import pandas as pd
|
2021-09-15 15:36:36 +02:00
|
|
|
from pyprojroot import here
|
2021-09-13 11:41:57 +02:00
|
|
|
|
|
|
|
import participants.query_db
|
|
|
|
from features import esm
|
|
|
|
from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME
|
2021-10-29 12:07:12 +02:00
|
|
|
from machine_learning.helper import read_csv_with_settings, to_csv_with_settings
|
2021-09-15 15:36:36 +02:00
|
|
|
|
|
|
|
# Message used when no participants label has been set: Labels.construct_export_path()
# emits it as a UserWarning, and Labels.aggregate_labels() raises it as a ValueError.
WARNING_PARTICIPANTS_LABEL = (
    "Before aggregating labels, please set participants label "
    "using self.set_participants_label() to be used as a filename prefix "
    "when exporting data. The filename will be of the form: "
    "%participants_label_%grouping_variable_%data_type.csv"
)
|
2021-09-13 11:41:57 +02:00
|
|
|
|
|
|
|
|
|
|
|
class Labels:
    """Build machine-learning labels from ESM questionnaire answers.

    Pipeline: ``set_labels()`` queries and cleans the raw ESM answers, then
    ``aggregate_labels()`` averages them per participant and grouping variable,
    reading from / writing to a CSV cache under ``self.folder``. Results are
    exposed through ``get_labels()`` and ``get_aggregated_labels()``.
    """

    def __init__(
        self,
        grouping_variable: str,
        labels: dict,
        participants_usernames: Collection = None,
    ) -> None:
        """Set up an empty label pipeline.

        Args:
            grouping_variable: Name of the column (in addition to
                ``participant_id``) that answers are grouped by when aggregating.
            labels: Mapping whose keys name the questionnaires to use; only the
                keys are read here.
            participants_usernames: Participants to include. When ``None``,
                usernames are fetched with
                ``participants.query_db.get_usernames(collection_start=2020-08-01)``
                and the participants label is set to ``"all"``. Otherwise the
                label stays empty until ``set_participants_label()`` is called.
        """
        # Raw name, used when building the export filename prefix.
        self.grouping_variable_name = grouping_variable
        # Same name wrapped in a list so it can be concatenated with other
        # column lists in groupby / pivot / set_index calls.
        self.grouping_variable = [grouping_variable]

        # Keys select which questionnaires are filtered (set_labels) and name
        # the exported file (aggregate_labels).
        self.questionnaires = labels.keys()

        self.participants_label: str = ""
        if participants_usernames is None:
            participants_usernames = participants.query_db.get_usernames(
                collection_start=datetime.date.fromisoformat("2020-08-01")
            )
            self.participants_label = "all"
        self.participants_usernames = participants_usernames

        # Successive stages of the label pipeline; filled by set_labels() and
        # aggregate_labels().
        self.df_esm = pd.DataFrame()
        self.df_esm_preprocessed = pd.DataFrame()
        self.df_esm_interest = pd.DataFrame()
        self.df_esm_clean = pd.DataFrame()
        self.df_esm_means = pd.DataFrame()

        self.folder: Path = Path()
        self.filename_prefix = ""
        # Emits a UserWarning while the participants label is still empty.
        self.construct_export_path()

        print("Labels initialized.")

    def set_labels(self) -> None:
        """Query, preprocess and clean the ESM answers of the selected participants.

        Fills ``df_esm`` (data fetched from the DB), ``df_esm_preprocessed``,
        ``df_esm_interest`` (rows of the requested questionnaires) and
        ``df_esm_clean`` using the helpers in ``features.esm``.
        """
        print("Querying database ...")
        self.df_esm = esm.get_esm_data(self.participants_usernames)
        print("Got ESM data from the DB.")

        self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm)
        print("ESM data preprocessed.")

        if "PANAS" in self.questionnaires:
            # Keep only rows answering the PANAS "PA" or "NA" questionnaire.
            questionnaire_id = self.df_esm_preprocessed["questionnaire_id"]
            panas_ids = QUESTIONNAIRE_IDS.get("PANAS")
            self.df_esm_interest = self.df_esm_preprocessed[
                (questionnaire_id == panas_ids.get("PA"))
                | (questionnaire_id == panas_ids.get("NA"))
            ]
        # NOTE(review): only PANAS is filtered. For any other questionnaire set,
        # df_esm_interest keeps its previous value (an empty DataFrame on the
        # first call) and is still passed to clean_up_esm — confirm intended.
        self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
        print("ESM data cleaned.")

    def get_labels(self, questionnaire: str) -> pd.DataFrame:
        """Return the cleaned, non-aggregated answers for ``questionnaire``.

        The stored DataFrame itself is returned, not a copy.

        Raises:
            KeyError: If ``questionnaire`` is not ``"PANAS"``, the only one
                implemented.
        """
        if questionnaire != "PANAS":
            raise KeyError("This questionnaire has not been implemented as a label.")
        return self.df_esm_clean

    def aggregate_labels(self, cached=True) -> None:
        """Average the cleaned ESM answers per participant and grouping variable.

        The result is stored in ``self.df_esm_means``. Its index is
        (``participant_id``, grouping variable) and it has one column per
        questionnaire ID, renamed via ``QUESTIONNAIRE_IDS_RENAME``. Freshly
        computed means are also exported as CSV.

        Args:
            cached: When True (default), first try to read previously exported
                means from ``self.folder``, and recompute only if that file
                does not exist. When False, always recompute and re-export.

        Raises:
            ValueError: If no participants label is set, because it is needed
                for the export filename prefix.
        """
        print("Aggregating labels ...")
        if not self.participants_label:
            raise ValueError(WARNING_PARTICIPANTS_LABEL)

        data_type = "_".join(self.questionnaires)

        if cached:
            try:
                self.df_esm_means = read_csv_with_settings(
                    self.folder,
                    self.filename_prefix,
                    data_type=data_type,
                    grouping_variable=self.grouping_variable,
                )
            except FileNotFoundError:
                pass  # No cached file yet: fall through and recompute.
            else:
                print("Read labels from the file.")
                return

        # Recompute outside of any ``except`` handler so that errors raised
        # here are not chained onto an unrelated FileNotFoundError.
        self.df_esm_means = self._average_answers()
        print("Labels aggregated.")
        to_csv_with_settings(
            self.df_esm_means,
            self.folder,
            self.filename_prefix,
            data_type=data_type,
        )

    def _average_answers(self) -> pd.DataFrame:
        """Return mean numeric answers pivoted to one column per questionnaire ID."""
        index_columns = ["participant_id"] + self.grouping_variable
        # Long format: one row per (participant, questionnaire, group) mean.
        means_long = (
            self.df_esm_clean.groupby(
                ["participant_id", "questionnaire_id"] + self.grouping_variable
            )
            .esm_user_answer_numeric.agg("mean")
            .reset_index()
            .rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"})
        )
        return (
            means_long.pivot(
                index=index_columns,
                columns="questionnaire_id",
                values="esm_numeric_mean",
            )
            # Move the index back to columns so rename() sees plain columns,
            # then restore it. col_level only matters for MultiIndex columns.
            .reset_index(col_level=1)
            .rename(columns=QUESTIONNAIRE_IDS_RENAME)
            .set_index(index_columns)
        )

    def get_aggregated_labels(self) -> pd.DataFrame:
        """Return the aggregated means computed or loaded by ``aggregate_labels()``."""
        return self.df_esm_means

    def construct_export_path(self) -> None:
        """Set ``self.folder`` and ``self.filename_prefix`` used for CSV export.

        The prefix is ``"<participants_label>_<grouping_variable>"``. A
        ``UserWarning`` is emitted while the participants label is empty.
        """
        if not self.participants_label:
            warnings.warn(WARNING_PARTICIPANTS_LABEL, UserWarning)
        # pyprojroot.here() resolves the path from the detected project root.
        self.folder = here("machine_learning/intermediate_results/labels", warn=True)
        self.filename_prefix = (
            self.participants_label + "_" + self.grouping_variable_name
        )

    def set_participants_label(self, label: str) -> None:
        """Set the participants label and rebuild the export path from it."""
        self.participants_label = label
        self.construct_export_path()