[WIP] Finish the class by assigning columns and validating model.
parent
b06ec6e1ae
commit
d6f36ec8f8
|
@ -1,6 +1,7 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from sklearn.model_selection import cross_val_score
|
||||||
|
|
||||||
import participants.query_db
|
import participants.query_db
|
||||||
from features import esm, helper, proximity
|
from features import esm, helper, proximity
|
||||||
|
@ -8,7 +9,15 @@ from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME
|
||||||
|
|
||||||
|
|
||||||
class MachineLearningPipeline:
|
class MachineLearningPipeline:
|
||||||
def __init__(self, labels_questionnaire, data_types, participants_usernames=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
labels_questionnaire,
|
||||||
|
labels_scale,
|
||||||
|
data_types,
|
||||||
|
participants_usernames=None,
|
||||||
|
feature_names=None,
|
||||||
|
grouping_variable=None,
|
||||||
|
):
|
||||||
if participants_usernames is None:
|
if participants_usernames is None:
|
||||||
participants_usernames = participants.query_db.get_usernames(
|
participants_usernames = participants.query_db.get_usernames(
|
||||||
collection_start=datetime.date.fromisoformat("2020-08-01")
|
collection_start=datetime.date.fromisoformat("2020-08-01")
|
||||||
|
@ -17,6 +26,17 @@ class MachineLearningPipeline:
|
||||||
self.labels_questionnaire = labels_questionnaire
|
self.labels_questionnaire = labels_questionnaire
|
||||||
self.data_types = data_types
|
self.data_types = data_types
|
||||||
|
|
||||||
|
if feature_names is None:
|
||||||
|
self.feature_names = []
|
||||||
|
self.df_features = pd.DataFrame()
|
||||||
|
self.labels_scale = labels_scale
|
||||||
|
self.df_labels = pd.DataFrame()
|
||||||
|
self.grouping_variable = grouping_variable
|
||||||
|
self.df_groups = pd.DataFrame()
|
||||||
|
|
||||||
|
self.model = None
|
||||||
|
self.validation_method = None
|
||||||
|
|
||||||
self.df_esm = pd.DataFrame()
|
self.df_esm = pd.DataFrame()
|
||||||
self.df_esm_preprocessed = pd.DataFrame()
|
self.df_esm_preprocessed = pd.DataFrame()
|
||||||
self.df_esm_interest = pd.DataFrame()
|
self.df_esm_interest = pd.DataFrame()
|
||||||
|
@ -77,3 +97,29 @@ class MachineLearningPipeline:
|
||||||
self.df_full_data_daily_means = self.df_full_data_daily_means.join(
|
self.df_full_data_daily_means = self.df_full_data_daily_means.join(
|
||||||
self.df_proximity_daily_counts
|
self.df_proximity_daily_counts
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def assign_columns(self):
|
||||||
|
self.df_features = self.df_full_data_daily_means[self.feature_names]
|
||||||
|
self.df_labels = self.df_full_data_daily_means[self.labels_scale]
|
||||||
|
if self.grouping_variable:
|
||||||
|
self.df_groups = self.df_full_data_daily_means[self.grouping_variable]
|
||||||
|
else:
|
||||||
|
self.df_groups = None
|
||||||
|
|
||||||
|
def validate_model(self):
|
||||||
|
if self.model is None:
|
||||||
|
raise AttributeError(
|
||||||
|
"Please, specify a machine learning model first, by setting the .model attribute."
|
||||||
|
)
|
||||||
|
if self.validation_method is None:
|
||||||
|
raise AttributeError(
|
||||||
|
"Please, specify a cross validation method first, by setting the .validation_method attribute."
|
||||||
|
)
|
||||||
|
cross_val_score(
|
||||||
|
estimator=self.model,
|
||||||
|
X=self.df_features,
|
||||||
|
y=self.df_labels,
|
||||||
|
groups=self.df_groups,
|
||||||
|
cv=self.validation_method,
|
||||||
|
n_jobs=-1,
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue