Document the SensorFeatures class and its __init__ method.

rapids
junos 2021-09-13 17:43:47 +02:00
parent b19eebbb92
commit af9e81fe40
4 changed files with 62 additions and 8 deletions

View File

@ -29,13 +29,14 @@ nb_dir = os.path.split(os.getcwd())[0]
if nb_dir not in sys.path: if nb_dir not in sys.path:
sys.path.append(nb_dir) sys.path.append(nb_dir)
# %%
import participants.query_db
from features import esm, helper, proximity
import machine_learning.features_sensor import machine_learning.features_sensor
import machine_learning.labels import machine_learning.labels
import machine_learning.model import machine_learning.model
# %%
import participants.query_db
from features import esm, helper, proximity
# %% [markdown] # %% [markdown]
# # 1. Get the relevant data # # 1. Get the relevant data
@ -169,7 +170,9 @@ with open("../machine_learning/config/minimal_features.yaml", "r") as file:
print(sensor_features_params) print(sensor_features_params)
# %% # %%
sensor_features = machine_learning.features_sensor.SensorFeatures(**sensor_features_params) sensor_features = machine_learning.features_sensor.SensorFeatures(
**sensor_features_params
)
sensor_features.data_types sensor_features.data_types
# %% # %%

View File

@ -7,7 +7,7 @@ import pandas as pd
from pyprojroot import here from pyprojroot import here
import participants.query_db import participants.query_db
from features import proximity, helper, communication from features import communication, helper, proximity
WARNING_PARTICIPANTS_LABEL = ( WARNING_PARTICIPANTS_LABEL = (
"Before calculating features, please set participants label using self.set_participants_label() " "Before calculating features, please set participants label using self.set_participants_label() "
@ -17,13 +17,64 @@ WARNING_PARTICIPANTS_LABEL = (
class SensorFeatures: class SensorFeatures:
"""
A class to represent all sensor (AWARE) features.
Attributes
----------
grouping_variable: str
The name of the variable by which to group (segment) data, e.g. date_lj.
features: dict
A dictionary of sensors (data types) and features to calculate.
See config/minimal_features.yaml for an example.
participants_usernames: Collection
A list of usernames for which to calculate features.
If None, use all participants.
Methods
-------
set_sensor_data():
Query the database for data types defined by self.features.
get_sensor_data(data_type): pd.DataFrame
Returns the dataframe of sensor data for specified data_type.
calculate_features():
Calls appropriate functions from features/ and joins them in a single dataframe, df_features_all.
get_features(data_type, feature_names): pd.DataFrame
Returns the dataframe of specified features for selected sensor.
construct_export_path():
Construct a path for exporting the features as csv files.
set_participants_label(label):
Sets a label for the usernames subset. This is used to distinguish feature exports.
"""
def __init__( def __init__(
self, self,
grouping_variable: str, grouping_variable: str,
features: dict, features: dict,
participants_usernames: Collection = None, participants_usernames: Collection = None,
): ):
"""
Specifies the grouping variable and usernames for which to calculate features.
Sets other (implicit) attributes used in other methods.
If participants_usernames=None, this queries the usernames which belong to the main part of the study,
i.e. from 2020-08-01 on.
Parameters
----------
grouping_variable: str
The name of the variable by which to group (segment) data, e.g. date_lj.
features: dict
A dictionary of sensors (data types) and features to calculate.
See config/minimal_features.yaml for an example.
participants_usernames: Collection
A list of usernames for which to calculate features.
If None, use all participants.
Returns
-------
None
"""
self.grouping_variable_name = grouping_variable self.grouping_variable_name = grouping_variable
self.grouping_variable = [grouping_variable] self.grouping_variable = [grouping_variable]
@ -170,4 +221,4 @@ def to_csv_with_settings(
index=False, index=False,
encoding="utf-8", encoding="utf-8",
) )
print("Exported the dataframe to " + str(full_path)) print("Exported the dataframe to " + str(full_path))

View File

@ -83,4 +83,4 @@ class Labels:
print("Labels aggregated.") print("Labels aggregated.")
def get_aggregated_labels(self): def get_aggregated_labels(self):
return self.df_esm_means return self.df_esm_means

View File

@ -44,4 +44,4 @@ class ModelValidation:
cv=self.cv, cv=self.cv,
n_jobs=-1, n_jobs=-1,
scoring="r2", scoring="r2",
) )