From d6f36ec8f8f02523c3dfc04ca2a4e66abffbf16d Mon Sep 17 00:00:00 2001
From: junos
Date: Fri, 13 Aug 2021 17:40:31 +0200
Subject: [PATCH] [WIP] Finish the class by assigning columns and validating model.

---
Review notes (text between the "---" line and the diff is not part of the
commit message):

- __init__ assigned self.feature_names only when the feature_names argument
  was None, so a caller-supplied list was dropped and assign_columns() then
  failed on the missing attribute. The argument is now always stored.
- validate_model() discarded the result of cross_val_score(); it now
  returns it.
- Docstrings were added to assign_columns() and validate_model(); the last
  hunk length and the diffstat below were updated to match.
- NOTE(review): the indentation of the unchanged context lines was
  reconstructed. If a context line does not match the target file, apply
  with --ignore-whitespace. The post-image blob id on the "index" line
  predates these amendments.

 machine_learning/pipeline.py | 61 +++++++++++++++++++++++++++++++++++-
 1 file changed, 60 insertions(+), 1 deletion(-)

diff --git a/machine_learning/pipeline.py b/machine_learning/pipeline.py
index df2d035..5b57f53 100644
--- a/machine_learning/pipeline.py
+++ b/machine_learning/pipeline.py
@@ -1,6 +1,7 @@
 import datetime
 
 import pandas as pd
+from sklearn.model_selection import cross_val_score
 
 import participants.query_db
 from features import esm, helper, proximity
@@ -8,7 +9,15 @@ from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME
 
 
 class MachineLearningPipeline:
-    def __init__(self, labels_questionnaire, data_types, participants_usernames=None):
+    def __init__(
+        self,
+        labels_questionnaire,
+        labels_scale,
+        data_types,
+        participants_usernames=None,
+        feature_names=None,
+        grouping_variable=None,
+    ):
         if participants_usernames is None:
             participants_usernames = participants.query_db.get_usernames(
                 collection_start=datetime.date.fromisoformat("2020-08-01")
@@ -17,6 +26,17 @@ class MachineLearningPipeline:
         self.labels_questionnaire = labels_questionnaire
         self.data_types = data_types
 
+        # Always store the argument; fall back to an empty list when omitted.
+        self.feature_names = [] if feature_names is None else feature_names
+        self.df_features = pd.DataFrame()
+        self.labels_scale = labels_scale
+        self.df_labels = pd.DataFrame()
+        self.grouping_variable = grouping_variable
+        self.df_groups = pd.DataFrame()
+
+        self.model = None
+        self.validation_method = None
+
         self.df_esm = pd.DataFrame()
         self.df_esm_preprocessed = pd.DataFrame()
         self.df_esm_interest = pd.DataFrame()
@@ -77,3 +97,42 @@ class MachineLearningPipeline:
         self.df_full_data_daily_means = self.df_full_data_daily_means.join(
             self.df_proximity_daily_counts
         )
+
+    def assign_columns(self):
+        """Select the feature, label and group columns from the daily means.
+
+        Reads ``self.df_full_data_daily_means`` and stores the columns named by
+        ``self.feature_names`` in ``self.df_features`` and those named by
+        ``self.labels_scale`` in ``self.df_labels``. ``self.df_groups`` gets the
+        ``self.grouping_variable`` column(s), or ``None`` when that is falsy.
+        """
+        self.df_features = self.df_full_data_daily_means[self.feature_names]
+        self.df_labels = self.df_full_data_daily_means[self.labels_scale]
+        if self.grouping_variable:
+            self.df_groups = self.df_full_data_daily_means[self.grouping_variable]
+        else:
+            self.df_groups = None
+
+    def validate_model(self):
+        """Cross-validate ``self.model`` and return the array of fold scores.
+
+        Raises:
+            AttributeError: If ``self.model`` or ``self.validation_method``
+                is still ``None``.
+        """
+        if self.model is None:
+            raise AttributeError(
+                "Please, specify a machine learning model first, by setting the .model attribute."
+            )
+        if self.validation_method is None:
+            raise AttributeError(
+                "Please, specify a cross validation method first, by setting the .validation_method attribute."
+            )
+        return cross_val_score(
+            estimator=self.model,
+            X=self.df_features,
+            y=self.df_labels,
+            groups=self.df_groups,
+            cv=self.validation_method,
+            n_jobs=-1,
+        )