Add an option to read cached labels from a file.
parent
ed062d25ee
commit
b8c7606664
|
@ -228,10 +228,20 @@ labels.set_labels()
|
||||||
labels.get_labels("PANAS")
|
labels.get_labels("PANAS")
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
labels.aggregate_labels()
|
labels.aggregate_labels(cached=False)
|
||||||
|
labels_calculated = labels.get_aggregated_labels()
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
labels.get_aggregated_labels()
|
labels.aggregate_labels(cached=True)
|
||||||
|
labels_read = labels.get_aggregated_labels()
|
||||||
|
labels_read = labels_read.reset_index()
|
||||||
|
labels_read["date_lj"] = labels_read["date_lj"].dt.date
|
||||||
|
labels_read.set_index(["participant_id", "date_lj"], inplace=True)
|
||||||
|
# date_lj column is parsed as a date and represented as Timestamp, when read from csv.
|
||||||
|
# When calculated, it is represented as date.
|
||||||
|
|
||||||
|
# %%
|
||||||
|
np.isclose(labels_read, labels_calculated).all()
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
model_validation = machine_learning.model.ModelValidation(
|
model_validation = machine_learning.model.ModelValidation(
|
||||||
|
|
|
@ -9,7 +9,7 @@ from pyprojroot import here
|
||||||
import participants.query_db
|
import participants.query_db
|
||||||
from features import esm
|
from features import esm
|
||||||
from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME
|
from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME
|
||||||
from machine_learning.helper import to_csv_with_settings
|
from machine_learning.helper import to_csv_with_settings, read_csv_with_settings
|
||||||
|
|
||||||
WARNING_PARTICIPANTS_LABEL = (
|
WARNING_PARTICIPANTS_LABEL = (
|
||||||
"Before aggregating labels, please set participants label using self.set_participants_label() "
|
"Before aggregating labels, please set participants label using self.set_participants_label() "
|
||||||
|
@ -75,8 +75,23 @@ class Labels:
|
||||||
else:
|
else:
|
||||||
raise KeyError("This questionnaire has not been implemented as a label.")
|
raise KeyError("This questionnaire has not been implemented as a label.")
|
||||||
|
|
||||||
def aggregate_labels(self) -> None:
|
def aggregate_labels(self, cached=True) -> None:
|
||||||
print("Aggregating labels ...")
|
print("Aggregating labels ...")
|
||||||
|
if not self.participants_label:
|
||||||
|
raise ValueError(WARNING_PARTICIPANTS_LABEL)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not cached: # Do not use the file, even if it exists.
|
||||||
|
raise FileNotFoundError
|
||||||
|
self.df_esm_means = read_csv_with_settings(
|
||||||
|
self.folder,
|
||||||
|
self.filename_prefix,
|
||||||
|
data_type="_".join(self.questionnaires),
|
||||||
|
grouping_variable=self.grouping_variable
|
||||||
|
)
|
||||||
|
print("Read labels from the file.")
|
||||||
|
except FileNotFoundError:
|
||||||
|
# We need to recalculate the features in this case.
|
||||||
self.df_esm_means = (
|
self.df_esm_means = (
|
||||||
self.df_esm_clean.groupby(
|
self.df_esm_clean.groupby(
|
||||||
["participant_id", "questionnaire_id"] + self.grouping_variable
|
["participant_id", "questionnaire_id"] + self.grouping_variable
|
||||||
|
|
Loading…
Reference in New Issue