Add a similar class for labels.
parent
97c693d252
commit
6592612db7
|
@ -180,3 +180,17 @@ sensor_features.calculate_features()
|
||||||
sensor_features.get_features("proximity", "all")
|
sensor_features.get_features("proximity", "all")
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
# Load the keyword arguments for pipeline.Labels from the YAML config.
# The encoding is pinned so that parsing does not depend on the platform's
# default locale encoding.
with open(
    "../machine_learning/config/minimal_labels.yaml", "r", encoding="utf-8"
) as file:
    labels_params = yaml.safe_load(file)

# %%
# The top-level keys of the YAML file are passed as keyword arguments.
labels = pipeline.Labels(**labels_params)
labels.questionnaires

# %%
# Fetch and filter the ESM data for the configured questionnaires.
labels.set_labels()

# %%
labels.get_labels("PANAS")
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
grouping_variable: date_lj
|
||||||
|
labels:
|
||||||
|
PANAS:
|
||||||
|
- PA
|
||||||
|
- NA
|
||||||
|
participants_usernames: [nokia_0000003]
|
|
@ -79,6 +79,51 @@ class SensorFeatures:
|
||||||
raise KeyError("This data type has not been implemented.")
|
raise KeyError("This data type has not been implemented.")
|
||||||
|
|
||||||
|
|
||||||
|
class Labels:
    """Fetch ESM questionnaire data and expose it as labels.

    Typical use is ``Labels(...)``, then ``set_labels()`` to load and filter
    the data, then ``get_labels("PANAS")`` to retrieve the result.
    """

    def __init__(
        self,
        grouping_variable: str,
        labels: dict,
        participants_usernames: Collection = None,
    ):
        """Store the configuration and initialise empty data frames.

        Parameters
        ----------
        grouping_variable
            Stored as ``self.grouping_variable``; not used further by this class.
        labels
            Mapping whose keys are questionnaire names (e.g. ``"PANAS"``).
            Only the keys are read here.
        participants_usernames
            Usernames to load ESM data for. When ``None``, the usernames are
            queried from the participants database with a collection start
            of 2020-08-01.
        """
        self.grouping_variable = grouping_variable

        self.questionnaires = labels.keys()

        if participants_usernames is None:
            participants_usernames = participants.query_db.get_usernames(
                collection_start=datetime.date.fromisoformat("2020-08-01")
            )
        self.participants_usernames = participants_usernames

        # Placeholders that set_labels() fills in.
        self.df_esm = pd.DataFrame()
        self.df_esm_preprocessed = pd.DataFrame()
        self.df_esm_interest = pd.DataFrame()
        self.df_esm_clean = pd.DataFrame()

    def set_labels(self):
        """Load, preprocess and filter the ESM data for the configured labels.

        The raw and preprocessed frames are always refreshed. The filtered
        (``df_esm_interest``) and cleaned (``df_esm_clean``) frames are only
        computed when ``"PANAS"`` is among the configured questionnaires;
        otherwise they are left untouched.

        Raises
        ------
        KeyError
            If ``QUESTIONNAIRE_IDS`` lacks the ``"PANAS"`` entry or its
            ``"PA"`` / ``"NA"`` sub-entries.
        """
        self.df_esm = esm.get_esm_data(self.participants_usernames)
        self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm)

        if "PANAS" not in self.questionnaires:
            # Nothing to select, so do not run the clean-up on an empty frame.
            return

        # Subscript rather than chain .get(): a missing ID should fail loudly
        # instead of turning into a comparison against None that matches nothing.
        panas_ids = QUESTIONNAIRE_IDS["PANAS"]
        wanted_ids = [panas_ids["PA"], panas_ids["NA"]]

        is_panas = self.df_esm_preprocessed["questionnaire_id"].isin(wanted_ids)
        self.df_esm_interest = self.df_esm_preprocessed[is_panas]
        self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)

    def get_labels(self, questionnaire):
        """Return the cleaned ESM data for the given questionnaire.

        Only ``"PANAS"`` is supported. The returned frame is empty until
        ``set_labels()`` has populated it.

        Raises
        ------
        KeyError
            If ``questionnaire`` is anything other than ``"PANAS"``.
        """
        if questionnaire != "PANAS":
            raise KeyError("This questionnaire has not been implemented as a label.")
        return self.df_esm_clean
|
||||||
|
|
||||||
|
|
||||||
class MachineLearningPipeline:
|
class MachineLearningPipeline:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -117,21 +162,21 @@ class MachineLearningPipeline:
|
||||||
self.df_esm_daily_means = pd.DataFrame()
|
self.df_esm_daily_means = pd.DataFrame()
|
||||||
self.df_proximity_daily_counts = pd.DataFrame()
|
self.df_proximity_daily_counts = pd.DataFrame()
|
||||||
|
|
||||||
def get_labels(self):
|
# def get_labels(self):
|
||||||
self.df_esm = esm.get_esm_data(self.participants_usernames)
|
# self.df_esm = esm.get_esm_data(self.participants_usernames)
|
||||||
self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm)
|
# self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm)
|
||||||
if self.labels_questionnaire == "PANAS":
|
# if self.labels_questionnaire == "PANAS":
|
||||||
self.df_esm_interest = self.df_esm_preprocessed[
|
# self.df_esm_interest = self.df_esm_preprocessed[
|
||||||
(
|
# (
|
||||||
self.df_esm_preprocessed["questionnaire_id"]
|
# self.df_esm_preprocessed["questionnaire_id"]
|
||||||
== QUESTIONNAIRE_IDS.get("PANAS").get("PA")
|
# == QUESTIONNAIRE_IDS.get("PANAS").get("PA")
|
||||||
)
|
# )
|
||||||
| (
|
# | (
|
||||||
self.df_esm_preprocessed["questionnaire_id"]
|
# self.df_esm_preprocessed["questionnaire_id"]
|
||||||
== QUESTIONNAIRE_IDS.get("PANAS").get("NA")
|
# == QUESTIONNAIRE_IDS.get("PANAS").get("NA")
|
||||||
)
|
# )
|
||||||
]
|
# ]
|
||||||
self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
|
# self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
|
||||||
|
|
||||||
# def aggregate_daily(self):
|
# def aggregate_daily(self):
|
||||||
# self.df_esm_daily_means = (
|
# self.df_esm_daily_means = (
|
||||||
|
|
Loading…
Reference in New Issue