Add an option to read cached labels from a file.

rapids
junos 2021-09-15 15:45:49 +02:00
parent ed062d25ee
commit b8c7606664
2 changed files with 52 additions and 27 deletions

View File

@ -228,10 +228,20 @@ labels.set_labels()
labels.get_labels("PANAS") labels.get_labels("PANAS")
# %% # %%
labels.aggregate_labels() labels.aggregate_labels(cached=False)
labels_calculated = labels.get_aggregated_labels()
# %% # %%
labels.get_aggregated_labels() labels.aggregate_labels(cached=True)
labels_read = labels.get_aggregated_labels()
labels_read = labels_read.reset_index()
labels_read["date_lj"] = labels_read["date_lj"].dt.date
labels_read.set_index(["participant_id", "date_lj"], inplace=True)
# date_lj column is parsed as a date and represented as Timestamp, when read from csv.
# When calculated, it is represented as date.
# %%
np.isclose(labels_read, labels_calculated).all()
# %% # %%
model_validation = machine_learning.model.ModelValidation( model_validation = machine_learning.model.ModelValidation(

View File

@ -9,7 +9,7 @@ from pyprojroot import here
import participants.query_db import participants.query_db
from features import esm from features import esm
from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME
from machine_learning.helper import to_csv_with_settings from machine_learning.helper import to_csv_with_settings, read_csv_with_settings
WARNING_PARTICIPANTS_LABEL = ( WARNING_PARTICIPANTS_LABEL = (
"Before aggregating labels, please set participants label using self.set_participants_label() " "Before aggregating labels, please set participants label using self.set_participants_label() "
@ -75,33 +75,48 @@ class Labels:
else: else:
raise KeyError("This questionnaire has not been implemented as a label.") raise KeyError("This questionnaire has not been implemented as a label.")
def aggregate_labels(self) -> None: def aggregate_labels(self, cached=True) -> None:
print("Aggregating labels ...") print("Aggregating labels ...")
self.df_esm_means = ( if not self.participants_label:
self.df_esm_clean.groupby( raise ValueError(WARNING_PARTICIPANTS_LABEL)
["participant_id", "questionnaire_id"] + self.grouping_variable
try:
if not cached: # Do not use the file, even if it exists.
raise FileNotFoundError
self.df_esm_means = read_csv_with_settings(
self.folder,
self.filename_prefix,
data_type="_".join(self.questionnaires),
grouping_variable=self.grouping_variable
) )
.esm_user_answer_numeric.agg("mean") print("Read labels from the file.")
.reset_index() except FileNotFoundError:
.rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"}) # We need to recalculate the features in this case.
) self.df_esm_means = (
self.df_esm_means = ( self.df_esm_clean.groupby(
self.df_esm_means.pivot( ["participant_id", "questionnaire_id"] + self.grouping_variable
index=["participant_id"] + self.grouping_variable, )
columns="questionnaire_id", .esm_user_answer_numeric.agg("mean")
values="esm_numeric_mean", .reset_index()
.rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"})
)
self.df_esm_means = (
self.df_esm_means.pivot(
index=["participant_id"] + self.grouping_variable,
columns="questionnaire_id",
values="esm_numeric_mean",
)
.reset_index(col_level=1)
.rename(columns=QUESTIONNAIRE_IDS_RENAME)
.set_index(["participant_id"] + self.grouping_variable)
)
print("Labels aggregated.")
to_csv_with_settings(
self.df_esm_means,
self.folder,
self.filename_prefix,
data_type="_".join(self.questionnaires),
) )
.reset_index(col_level=1)
.rename(columns=QUESTIONNAIRE_IDS_RENAME)
.set_index(["participant_id"] + self.grouping_variable)
)
print("Labels aggregated.")
to_csv_with_settings(
self.df_esm_means,
self.folder,
self.filename_prefix,
data_type="_".join(self.questionnaires),
)
def get_aggregated_labels(self) -> pd.DataFrame: def get_aggregated_labels(self) -> pd.DataFrame:
return self.df_esm_means return self.df_esm_means