Add export capabilities to labels.py.

rapids
junos 2021-09-15 15:36:36 +02:00
parent 20748890a8
commit ed062d25ee
3 changed files with 40 additions and 12 deletions

View File

@ -218,14 +218,9 @@ with open("../machine_learning/config/minimal_labels.yaml", "r") as file:
# %% # %%
labels = machine_learning.labels.Labels(**labels_params) labels = machine_learning.labels.Labels(**labels_params)
labels.participants_usernames = ptcp_2 labels.participants_usernames = ptcp_2
labels.set_participants_label("nokia_0000003")
labels.questionnaires labels.questionnaires
# %%
all_features = sensor_features.get_features("all", "all")
# %%
all_features.isna().any().any()
# %% # %%
labels.set_labels() labels.set_labels()

View File

@ -125,7 +125,7 @@ class SensorFeatures:
self.df_sms = helper.get_date_from_timestamp(self.df_sms) self.df_sms = helper.get_date_from_timestamp(self.df_sms)
print("Got sms data from the DB.") print("Got sms data from the DB.")
def get_sensor_data(self, data_type) -> pd.DataFrame: def get_sensor_data(self, data_type: str) -> pd.DataFrame:
if data_type == "proximity": if data_type == "proximity":
return self.df_proximity return self.df_proximity
elif data_type == "communication": elif data_type == "communication":

View File

@ -1,11 +1,21 @@
import datetime import datetime
import warnings
from pathlib import Path
from typing import Collection from typing import Collection
import pandas as pd import pandas as pd
from pyprojroot import here
import participants.query_db import participants.query_db
from features import esm from features import esm
from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME
from machine_learning.helper import to_csv_with_settings
WARNING_PARTICIPANTS_LABEL = (
"Before aggregating labels, please set participants label using self.set_participants_label() "
"to be used as a filename prefix when exporting data. "
"The filename will be of the form: %participants_label_%grouping_variable_%data_type.csv"
)
class Labels: class Labels:
@ -14,12 +24,13 @@ class Labels:
grouping_variable: str, grouping_variable: str,
labels: dict, labels: dict,
participants_usernames: Collection = None, participants_usernames: Collection = None,
): ) -> None:
self.grouping_variable_name = grouping_variable self.grouping_variable_name = grouping_variable
self.grouping_variable = [grouping_variable] self.grouping_variable = [grouping_variable]
self.questionnaires = labels.keys() self.questionnaires = labels.keys()
self.participants_label: str = ""
if participants_usernames is None: if participants_usernames is None:
participants_usernames = participants.query_db.get_usernames( participants_usernames = participants.query_db.get_usernames(
collection_start=datetime.date.fromisoformat("2020-08-01") collection_start=datetime.date.fromisoformat("2020-08-01")
@ -32,9 +43,13 @@ class Labels:
self.df_esm_clean = pd.DataFrame() self.df_esm_clean = pd.DataFrame()
self.df_esm_means = pd.DataFrame() self.df_esm_means = pd.DataFrame()
self.folder: Path = Path()
self.filename_prefix = ""
self.construct_export_path()
print("Labels initialized.") print("Labels initialized.")
def set_labels(self): def set_labels(self) -> None:
print("Querying database ...") print("Querying database ...")
self.df_esm = esm.get_esm_data(self.participants_usernames) self.df_esm = esm.get_esm_data(self.participants_usernames)
print("Got ESM data from the DB.") print("Got ESM data from the DB.")
@ -54,13 +69,13 @@ class Labels:
self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest) self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
print("ESM data cleaned.") print("ESM data cleaned.")
def get_labels(self, questionnaire): def get_labels(self, questionnaire: str) -> pd.DataFrame:
if questionnaire == "PANAS": if questionnaire == "PANAS":
return self.df_esm_clean return self.df_esm_clean
else: else:
raise KeyError("This questionnaire has not been implemented as a label.") raise KeyError("This questionnaire has not been implemented as a label.")
def aggregate_labels(self): def aggregate_labels(self) -> None:
print("Aggregating labels ...") print("Aggregating labels ...")
self.df_esm_means = ( self.df_esm_means = (
self.df_esm_clean.groupby( self.df_esm_clean.groupby(
@ -81,6 +96,24 @@ class Labels:
.set_index(["participant_id"] + self.grouping_variable) .set_index(["participant_id"] + self.grouping_variable)
) )
print("Labels aggregated.") print("Labels aggregated.")
to_csv_with_settings(
self.df_esm_means,
self.folder,
self.filename_prefix,
data_type="_".join(self.questionnaires),
)
def get_aggregated_labels(self): def get_aggregated_labels(self) -> pd.DataFrame:
return self.df_esm_means return self.df_esm_means
def construct_export_path(self) -> None:
if not self.participants_label:
warnings.warn(WARNING_PARTICIPANTS_LABEL, UserWarning)
self.folder = here("machine_learning/intermediate_results/labels", warn=True)
self.filename_prefix = (
self.participants_label + "_" + self.grouping_variable_name
)
def set_participants_label(self, label: str) -> None:
self.participants_label = label
self.construct_export_path()