Add a similar class for labels.

rapids
junos 2021-08-19 17:44:04 +02:00
parent 97c693d252
commit 6592612db7
3 changed files with 80 additions and 15 deletions

View File

@ -180,3 +180,17 @@ sensor_features.calculate_features()
sensor_features.get_features("proximity", "all")
# %%
with open("../machine_learning/config/minimal_labels.yaml", "r") as file:
labels_params = yaml.safe_load(file)
# %%
labels = pipeline.Labels(**labels_params)
labels.questionnaires
# %%
labels.set_labels()
# %%
labels.get_labels("PANAS")
# %%

View File

@ -0,0 +1,6 @@
grouping_variable: date_lj
labels:
PANAS:
- PA
- NA
participants_usernames: [nokia_0000003]

View File

@ -79,6 +79,51 @@ class SensorFeatures:
raise KeyError("This data type has not been implemented.")
class Labels:
def __init__(
self,
grouping_variable: str,
labels: dict,
participants_usernames: Collection = None,
):
self.grouping_variable = grouping_variable
self.questionnaires = labels.keys()
if participants_usernames is None:
participants_usernames = participants.query_db.get_usernames(
collection_start=datetime.date.fromisoformat("2020-08-01")
)
self.participants_usernames = participants_usernames
self.df_esm = pd.DataFrame()
self.df_esm_preprocessed = pd.DataFrame()
self.df_esm_interest = pd.DataFrame()
self.df_esm_clean = pd.DataFrame()
def set_labels(self):
self.df_esm = esm.get_esm_data(self.participants_usernames)
self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm)
if "PANAS" in self.questionnaires:
self.df_esm_interest = self.df_esm_preprocessed[
(
self.df_esm_preprocessed["questionnaire_id"]
== QUESTIONNAIRE_IDS.get("PANAS").get("PA")
)
| (
self.df_esm_preprocessed["questionnaire_id"]
== QUESTIONNAIRE_IDS.get("PANAS").get("NA")
)
]
self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
def get_labels(self, questionnaire):
if questionnaire == "PANAS":
return self.df_esm_clean
else:
raise KeyError("This questionnaire has not been implemented as a label.")
class MachineLearningPipeline:
def __init__(
self,
@ -117,21 +162,21 @@ class MachineLearningPipeline:
self.df_esm_daily_means = pd.DataFrame()
self.df_proximity_daily_counts = pd.DataFrame()
def get_labels(self):
self.df_esm = esm.get_esm_data(self.participants_usernames)
self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm)
if self.labels_questionnaire == "PANAS":
self.df_esm_interest = self.df_esm_preprocessed[
(
self.df_esm_preprocessed["questionnaire_id"]
== QUESTIONNAIRE_IDS.get("PANAS").get("PA")
)
| (
self.df_esm_preprocessed["questionnaire_id"]
== QUESTIONNAIRE_IDS.get("PANAS").get("NA")
)
]
self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
# def get_labels(self):
# self.df_esm = esm.get_esm_data(self.participants_usernames)
# self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm)
# if self.labels_questionnaire == "PANAS":
# self.df_esm_interest = self.df_esm_preprocessed[
# (
# self.df_esm_preprocessed["questionnaire_id"]
# == QUESTIONNAIRE_IDS.get("PANAS").get("PA")
# )
# | (
# self.df_esm_preprocessed["questionnaire_id"]
# == QUESTIONNAIRE_IDS.get("PANAS").get("NA")
# )
# ]
# self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
# def aggregate_daily(self):
# self.df_esm_daily_means = (