Add an option to read cached labels from a file.

rapids
junos 2021-09-15 15:45:49 +02:00
parent ed062d25ee
commit b8c7606664
2 changed files with 52 additions and 27 deletions

View File

@ -228,10 +228,20 @@ labels.set_labels()
labels.get_labels("PANAS")
# %%
labels.aggregate_labels()
labels.aggregate_labels(cached=False)
labels_calculated = labels.get_aggregated_labels()
# %%
labels.get_aggregated_labels()
labels.aggregate_labels(cached=True)
labels_read = labels.get_aggregated_labels()
labels_read = labels_read.reset_index()
labels_read["date_lj"] = labels_read["date_lj"].dt.date
labels_read.set_index(["participant_id", "date_lj"], inplace=True)
# date_lj column is parsed as a date and represented as Timestamp, when read from csv.
# When calculated, it is represented as date.
# %%
np.isclose(labels_read, labels_calculated).all()
# %%
model_validation = machine_learning.model.ModelValidation(

View File

@ -9,7 +9,7 @@ from pyprojroot import here
import participants.query_db
from features import esm
from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME
from machine_learning.helper import to_csv_with_settings
from machine_learning.helper import to_csv_with_settings, read_csv_with_settings
WARNING_PARTICIPANTS_LABEL = (
"Before aggregating labels, please set participants label using self.set_participants_label() "
@ -75,33 +75,48 @@ class Labels:
else:
raise KeyError("This questionnaire has not been implemented as a label.")
def aggregate_labels(self) -> None:
def aggregate_labels(self, cached=True) -> None:
print("Aggregating labels ...")
self.df_esm_means = (
self.df_esm_clean.groupby(
["participant_id", "questionnaire_id"] + self.grouping_variable
if not self.participants_label:
raise ValueError(WARNING_PARTICIPANTS_LABEL)
try:
if not cached: # Do not use the file, even if it exists.
raise FileNotFoundError
self.df_esm_means = read_csv_with_settings(
self.folder,
self.filename_prefix,
data_type="_".join(self.questionnaires),
grouping_variable=self.grouping_variable
)
.esm_user_answer_numeric.agg("mean")
.reset_index()
.rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"})
)
self.df_esm_means = (
self.df_esm_means.pivot(
index=["participant_id"] + self.grouping_variable,
columns="questionnaire_id",
values="esm_numeric_mean",
print("Read labels from the file.")
except FileNotFoundError:
# We need to recalculate the features in this case.
self.df_esm_means = (
self.df_esm_clean.groupby(
["participant_id", "questionnaire_id"] + self.grouping_variable
)
.esm_user_answer_numeric.agg("mean")
.reset_index()
.rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"})
)
self.df_esm_means = (
self.df_esm_means.pivot(
index=["participant_id"] + self.grouping_variable,
columns="questionnaire_id",
values="esm_numeric_mean",
)
.reset_index(col_level=1)
.rename(columns=QUESTIONNAIRE_IDS_RENAME)
.set_index(["participant_id"] + self.grouping_variable)
)
print("Labels aggregated.")
to_csv_with_settings(
self.df_esm_means,
self.folder,
self.filename_prefix,
data_type="_".join(self.questionnaires),
)
.reset_index(col_level=1)
.rename(columns=QUESTIONNAIRE_IDS_RENAME)
.set_index(["participant_id"] + self.grouping_variable)
)
print("Labels aggregated.")
to_csv_with_settings(
self.df_esm_means,
self.folder,
self.filename_prefix,
data_type="_".join(self.questionnaires),
)
def get_aggregated_labels(self) -> pd.DataFrame:
return self.df_esm_means