Add export capabilities to labels.py.
parent
20748890a8
commit
ed062d25ee
|
@ -218,14 +218,9 @@ with open("../machine_learning/config/minimal_labels.yaml", "r") as file:
|
||||||
# %%
|
# %%
|
||||||
labels = machine_learning.labels.Labels(**labels_params)
|
labels = machine_learning.labels.Labels(**labels_params)
|
||||||
labels.participants_usernames = ptcp_2
|
labels.participants_usernames = ptcp_2
|
||||||
|
labels.set_participants_label("nokia_0000003")
|
||||||
labels.questionnaires
|
labels.questionnaires
|
||||||
|
|
||||||
# %%
|
|
||||||
all_features = sensor_features.get_features("all", "all")
|
|
||||||
|
|
||||||
# %%
|
|
||||||
all_features.isna().any().any()
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
labels.set_labels()
|
labels.set_labels()
|
||||||
|
|
||||||
|
|
|
@ -125,7 +125,7 @@ class SensorFeatures:
|
||||||
self.df_sms = helper.get_date_from_timestamp(self.df_sms)
|
self.df_sms = helper.get_date_from_timestamp(self.df_sms)
|
||||||
print("Got sms data from the DB.")
|
print("Got sms data from the DB.")
|
||||||
|
|
||||||
def get_sensor_data(self, data_type) -> pd.DataFrame:
|
def get_sensor_data(self, data_type: str) -> pd.DataFrame:
|
||||||
if data_type == "proximity":
|
if data_type == "proximity":
|
||||||
return self.df_proximity
|
return self.df_proximity
|
||||||
elif data_type == "communication":
|
elif data_type == "communication":
|
||||||
|
|
|
@ -1,11 +1,21 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
from typing import Collection
|
from typing import Collection
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from pyprojroot import here
|
||||||
|
|
||||||
import participants.query_db
|
import participants.query_db
|
||||||
from features import esm
|
from features import esm
|
||||||
from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME
|
from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME
|
||||||
|
from machine_learning.helper import to_csv_with_settings
|
||||||
|
|
||||||
|
WARNING_PARTICIPANTS_LABEL = (
|
||||||
|
"Before aggregating labels, please set participants label using self.set_participants_label() "
|
||||||
|
"to be used as a filename prefix when exporting data. "
|
||||||
|
"The filename will be of the form: %participants_label_%grouping_variable_%data_type.csv"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Labels:
|
class Labels:
|
||||||
|
@ -14,12 +24,13 @@ class Labels:
|
||||||
grouping_variable: str,
|
grouping_variable: str,
|
||||||
labels: dict,
|
labels: dict,
|
||||||
participants_usernames: Collection = None,
|
participants_usernames: Collection = None,
|
||||||
):
|
) -> None:
|
||||||
self.grouping_variable_name = grouping_variable
|
self.grouping_variable_name = grouping_variable
|
||||||
self.grouping_variable = [grouping_variable]
|
self.grouping_variable = [grouping_variable]
|
||||||
|
|
||||||
self.questionnaires = labels.keys()
|
self.questionnaires = labels.keys()
|
||||||
|
|
||||||
|
self.participants_label: str = ""
|
||||||
if participants_usernames is None:
|
if participants_usernames is None:
|
||||||
participants_usernames = participants.query_db.get_usernames(
|
participants_usernames = participants.query_db.get_usernames(
|
||||||
collection_start=datetime.date.fromisoformat("2020-08-01")
|
collection_start=datetime.date.fromisoformat("2020-08-01")
|
||||||
|
@ -32,9 +43,13 @@ class Labels:
|
||||||
self.df_esm_clean = pd.DataFrame()
|
self.df_esm_clean = pd.DataFrame()
|
||||||
|
|
||||||
self.df_esm_means = pd.DataFrame()
|
self.df_esm_means = pd.DataFrame()
|
||||||
|
|
||||||
|
self.folder: Path = Path()
|
||||||
|
self.filename_prefix = ""
|
||||||
|
self.construct_export_path()
|
||||||
print("Labels initialized.")
|
print("Labels initialized.")
|
||||||
|
|
||||||
def set_labels(self):
|
def set_labels(self) -> None:
|
||||||
print("Querying database ...")
|
print("Querying database ...")
|
||||||
self.df_esm = esm.get_esm_data(self.participants_usernames)
|
self.df_esm = esm.get_esm_data(self.participants_usernames)
|
||||||
print("Got ESM data from the DB.")
|
print("Got ESM data from the DB.")
|
||||||
|
@ -54,13 +69,13 @@ class Labels:
|
||||||
self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
|
self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
|
||||||
print("ESM data cleaned.")
|
print("ESM data cleaned.")
|
||||||
|
|
||||||
def get_labels(self, questionnaire):
|
def get_labels(self, questionnaire: str) -> pd.DataFrame:
|
||||||
if questionnaire == "PANAS":
|
if questionnaire == "PANAS":
|
||||||
return self.df_esm_clean
|
return self.df_esm_clean
|
||||||
else:
|
else:
|
||||||
raise KeyError("This questionnaire has not been implemented as a label.")
|
raise KeyError("This questionnaire has not been implemented as a label.")
|
||||||
|
|
||||||
def aggregate_labels(self):
|
def aggregate_labels(self) -> None:
|
||||||
print("Aggregating labels ...")
|
print("Aggregating labels ...")
|
||||||
self.df_esm_means = (
|
self.df_esm_means = (
|
||||||
self.df_esm_clean.groupby(
|
self.df_esm_clean.groupby(
|
||||||
|
@ -81,6 +96,24 @@ class Labels:
|
||||||
.set_index(["participant_id"] + self.grouping_variable)
|
.set_index(["participant_id"] + self.grouping_variable)
|
||||||
)
|
)
|
||||||
print("Labels aggregated.")
|
print("Labels aggregated.")
|
||||||
|
to_csv_with_settings(
|
||||||
|
self.df_esm_means,
|
||||||
|
self.folder,
|
||||||
|
self.filename_prefix,
|
||||||
|
data_type="_".join(self.questionnaires),
|
||||||
|
)
|
||||||
|
|
||||||
def get_aggregated_labels(self):
|
def get_aggregated_labels(self) -> pd.DataFrame:
|
||||||
return self.df_esm_means
|
return self.df_esm_means
|
||||||
|
|
||||||
|
def construct_export_path(self) -> None:
|
||||||
|
if not self.participants_label:
|
||||||
|
warnings.warn(WARNING_PARTICIPANTS_LABEL, UserWarning)
|
||||||
|
self.folder = here("machine_learning/intermediate_results/labels", warn=True)
|
||||||
|
self.filename_prefix = (
|
||||||
|
self.participants_label + "_" + self.grouping_variable_name
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_participants_label(self, label: str) -> None:
|
||||||
|
self.participants_label = label
|
||||||
|
self.construct_export_path()
|
||||||
|
|
Loading…
Reference in New Issue