From 5be3e82797a188dd77596334e615475401f365cb Mon Sep 17 00:00:00 2001 From: junos Date: Thu, 19 Aug 2021 17:23:23 +0200 Subject: [PATCH] Accept nested feature configuration. To do this, pass a dict as parameters to SensorFeatures class, rather than actually reading the object from yaml file. --- exploration/ex_ml_pipeline.py | 8 +++++++- machine_learning/config/minimal_features.yaml | 6 +++--- machine_learning/pipeline.py | 17 ++++++----------- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/exploration/ex_ml_pipeline.py b/exploration/ex_ml_pipeline.py index d6a9831..c91b8f7 100644 --- a/exploration/ex_ml_pipeline.py +++ b/exploration/ex_ml_pipeline.py @@ -158,7 +158,11 @@ from machine_learning import pipeline # %% with open("../machine_learning/config/minimal_features.yaml", "r") as file: - sensor_features = yaml.full_load(file) + sensor_features_params = yaml.safe_load(file) + +# %% +sensor_features = pipeline.SensorFeatures(**sensor_features_params) +sensor_features.data_types # %% sensor_features.get_sensor_data("proximity") @@ -174,3 +178,5 @@ sensor_features.calculate_features() # %% sensor_features.get_features("proximity", "all") + +# %% diff --git a/machine_learning/config/minimal_features.yaml b/machine_learning/config/minimal_features.yaml index c54f47e..a015607 100644 --- a/machine_learning/config/minimal_features.yaml +++ b/machine_learning/config/minimal_features.yaml @@ -1,5 +1,5 @@ ---- !SensorFeatures grouping_variable: date_lj -data_types: [proximity] -feature_names: all +features: + proximity: + all participants_usernames: [nokia_0000003] diff --git a/machine_learning/pipeline.py b/machine_learning/pipeline.py index f325e91..8f91f28 100644 --- a/machine_learning/pipeline.py +++ b/machine_learning/pipeline.py @@ -1,7 +1,7 @@ import datetime +from collections.abc import Collection import pandas as pd -import yaml from sklearn.model_selection import cross_val_score import participants.query_db @@ -9,21 +9,16 @@ from features import communication, esm, helper, proximity from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME -class SensorFeatures(yaml.YAMLObject): - yaml_tag = u"!SensorFeatures" - +class SensorFeatures: def __init__( self, - grouping_variable, - data_types, - feature_names=None, - participants_usernames=None, + grouping_variable: str, + features: dict, + participants_usernames: Collection = None, ): - self.data_types = data_types self.grouping_variable = grouping_variable - if feature_names is None: - self.feature_names = [] + self.data_types = features.keys() if participants_usernames is None: participants_usernames = participants.query_db.get_usernames(