Add export capabilities to labels.py.
parent
20748890a8
commit
ed062d25ee
|
@ -218,14 +218,9 @@ with open("../machine_learning/config/minimal_labels.yaml", "r") as file:
|
|||
# %%
|
||||
labels = machine_learning.labels.Labels(**labels_params)
|
||||
labels.participants_usernames = ptcp_2
|
||||
labels.set_participants_label("nokia_0000003")
|
||||
labels.questionnaires
|
||||
|
||||
# %%
|
||||
all_features = sensor_features.get_features("all", "all")
|
||||
|
||||
# %%
|
||||
all_features.isna().any().any()
|
||||
|
||||
# %%
|
||||
labels.set_labels()
|
||||
|
||||
|
|
|
@ -125,7 +125,7 @@ class SensorFeatures:
|
|||
self.df_sms = helper.get_date_from_timestamp(self.df_sms)
|
||||
print("Got sms data from the DB.")
|
||||
|
||||
def get_sensor_data(self, data_type) -> pd.DataFrame:
|
||||
def get_sensor_data(self, data_type: str) -> pd.DataFrame:
|
||||
if data_type == "proximity":
|
||||
return self.df_proximity
|
||||
elif data_type == "communication":
|
||||
|
|
|
@ -1,11 +1,21 @@
|
|||
import datetime
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Collection
|
||||
|
||||
import pandas as pd
|
||||
from pyprojroot import here
|
||||
|
||||
import participants.query_db
|
||||
from features import esm
|
||||
from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME
|
||||
from machine_learning.helper import to_csv_with_settings
|
||||
|
||||
WARNING_PARTICIPANTS_LABEL = (
|
||||
"Before aggregating labels, please set participants label using self.set_participants_label() "
|
||||
"to be used as a filename prefix when exporting data. "
|
||||
"The filename will be of the form: %participants_label_%grouping_variable_%data_type.csv"
|
||||
)
|
||||
|
||||
|
||||
class Labels:
|
||||
|
@ -14,12 +24,13 @@ class Labels:
|
|||
grouping_variable: str,
|
||||
labels: dict,
|
||||
participants_usernames: Collection = None,
|
||||
):
|
||||
) -> None:
|
||||
self.grouping_variable_name = grouping_variable
|
||||
self.grouping_variable = [grouping_variable]
|
||||
|
||||
self.questionnaires = labels.keys()
|
||||
|
||||
self.participants_label: str = ""
|
||||
if participants_usernames is None:
|
||||
participants_usernames = participants.query_db.get_usernames(
|
||||
collection_start=datetime.date.fromisoformat("2020-08-01")
|
||||
|
@ -32,9 +43,13 @@ class Labels:
|
|||
self.df_esm_clean = pd.DataFrame()
|
||||
|
||||
self.df_esm_means = pd.DataFrame()
|
||||
|
||||
self.folder: Path = Path()
|
||||
self.filename_prefix = ""
|
||||
self.construct_export_path()
|
||||
print("Labels initialized.")
|
||||
|
||||
def set_labels(self):
|
||||
def set_labels(self) -> None:
|
||||
print("Querying database ...")
|
||||
self.df_esm = esm.get_esm_data(self.participants_usernames)
|
||||
print("Got ESM data from the DB.")
|
||||
|
@ -54,13 +69,13 @@ class Labels:
|
|||
self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
|
||||
print("ESM data cleaned.")
|
||||
|
||||
def get_labels(self, questionnaire):
|
||||
def get_labels(self, questionnaire: str) -> pd.DataFrame:
|
||||
if questionnaire == "PANAS":
|
||||
return self.df_esm_clean
|
||||
else:
|
||||
raise KeyError("This questionnaire has not been implemented as a label.")
|
||||
|
||||
def aggregate_labels(self):
|
||||
def aggregate_labels(self) -> None:
|
||||
print("Aggregating labels ...")
|
||||
self.df_esm_means = (
|
||||
self.df_esm_clean.groupby(
|
||||
|
@ -81,6 +96,24 @@ class Labels:
|
|||
.set_index(["participant_id"] + self.grouping_variable)
|
||||
)
|
||||
print("Labels aggregated.")
|
||||
to_csv_with_settings(
|
||||
self.df_esm_means,
|
||||
self.folder,
|
||||
self.filename_prefix,
|
||||
data_type="_".join(self.questionnaires),
|
||||
)
|
||||
|
||||
def get_aggregated_labels(self):
|
||||
def get_aggregated_labels(self) -> pd.DataFrame:
|
||||
return self.df_esm_means
|
||||
|
||||
def construct_export_path(self) -> None:
|
||||
if not self.participants_label:
|
||||
warnings.warn(WARNING_PARTICIPANTS_LABEL, UserWarning)
|
||||
self.folder = here("machine_learning/intermediate_results/labels", warn=True)
|
||||
self.filename_prefix = (
|
||||
self.participants_label + "_" + self.grouping_variable_name
|
||||
)
|
||||
|
||||
def set_participants_label(self, label: str) -> None:
|
||||
self.participants_label = label
|
||||
self.construct_export_path()
|
||||
|
|
Loading…
Reference in New Issue