Add an option to read cached labels from a file.
parent
ed062d25ee
commit
b8c7606664
|
@ -228,10 +228,20 @@ labels.set_labels()
|
|||
labels.get_labels("PANAS")
|
||||
|
||||
# %%
|
||||
labels.aggregate_labels()
|
||||
labels.aggregate_labels(cached=False)
|
||||
labels_calculated = labels.get_aggregated_labels()
|
||||
|
||||
# %%
|
||||
labels.get_aggregated_labels()
|
||||
labels.aggregate_labels(cached=True)
|
||||
labels_read = labels.get_aggregated_labels()
|
||||
labels_read = labels_read.reset_index()
|
||||
labels_read["date_lj"] = labels_read["date_lj"].dt.date
|
||||
labels_read.set_index(["participant_id", "date_lj"], inplace=True)
|
||||
# When read from CSV, the date_lj column is parsed as a date and represented as a Timestamp.
|
||||
# When the labels are calculated instead, date_lj is represented as a plain date, hence the conversion above.
|
||||
|
||||
# %%
|
||||
np.isclose(labels_read, labels_calculated).all()
|
||||
|
||||
# %%
|
||||
model_validation = machine_learning.model.ModelValidation(
|
||||
|
|
|
@ -9,7 +9,7 @@ from pyprojroot import here
|
|||
import participants.query_db
|
||||
from features import esm
|
||||
from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME
|
||||
from machine_learning.helper import to_csv_with_settings
|
||||
from machine_learning.helper import to_csv_with_settings, read_csv_with_settings
|
||||
|
||||
WARNING_PARTICIPANTS_LABEL = (
|
||||
"Before aggregating labels, please set participants label using self.set_participants_label() "
|
||||
|
@ -75,33 +75,48 @@ class Labels:
|
|||
else:
|
||||
raise KeyError("This questionnaire has not been implemented as a label.")
|
||||
|
||||
def aggregate_labels(self) -> None:
|
||||
def aggregate_labels(self, cached=True) -> None:
|
||||
print("Aggregating labels ...")
|
||||
self.df_esm_means = (
|
||||
self.df_esm_clean.groupby(
|
||||
["participant_id", "questionnaire_id"] + self.grouping_variable
|
||||
if not self.participants_label:
|
||||
raise ValueError(WARNING_PARTICIPANTS_LABEL)
|
||||
|
||||
try:
|
||||
if not cached: # Do not use the file, even if it exists.
|
||||
raise FileNotFoundError
|
||||
self.df_esm_means = read_csv_with_settings(
|
||||
self.folder,
|
||||
self.filename_prefix,
|
||||
data_type="_".join(self.questionnaires),
|
||||
grouping_variable=self.grouping_variable
|
||||
)
|
||||
.esm_user_answer_numeric.agg("mean")
|
||||
.reset_index()
|
||||
.rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"})
|
||||
)
|
||||
self.df_esm_means = (
|
||||
self.df_esm_means.pivot(
|
||||
index=["participant_id"] + self.grouping_variable,
|
||||
columns="questionnaire_id",
|
||||
values="esm_numeric_mean",
|
||||
print("Read labels from the file.")
|
||||
except FileNotFoundError:
|
||||
# We need to recalculate the labels in this case.
|
||||
self.df_esm_means = (
|
||||
self.df_esm_clean.groupby(
|
||||
["participant_id", "questionnaire_id"] + self.grouping_variable
|
||||
)
|
||||
.esm_user_answer_numeric.agg("mean")
|
||||
.reset_index()
|
||||
.rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"})
|
||||
)
|
||||
self.df_esm_means = (
|
||||
self.df_esm_means.pivot(
|
||||
index=["participant_id"] + self.grouping_variable,
|
||||
columns="questionnaire_id",
|
||||
values="esm_numeric_mean",
|
||||
)
|
||||
.reset_index(col_level=1)
|
||||
.rename(columns=QUESTIONNAIRE_IDS_RENAME)
|
||||
.set_index(["participant_id"] + self.grouping_variable)
|
||||
)
|
||||
print("Labels aggregated.")
|
||||
to_csv_with_settings(
|
||||
self.df_esm_means,
|
||||
self.folder,
|
||||
self.filename_prefix,
|
||||
data_type="_".join(self.questionnaires),
|
||||
)
|
||||
.reset_index(col_level=1)
|
||||
.rename(columns=QUESTIONNAIRE_IDS_RENAME)
|
||||
.set_index(["participant_id"] + self.grouping_variable)
|
||||
)
|
||||
print("Labels aggregated.")
|
||||
to_csv_with_settings(
|
||||
self.df_esm_means,
|
||||
self.folder,
|
||||
self.filename_prefix,
|
||||
data_type="_".join(self.questionnaires),
|
||||
)
|
||||
|
||||
def get_aggregated_labels(self) -> pd.DataFrame:
|
||||
return self.df_esm_means
|
||||
|
|
Loading…
Reference in New Issue