Prepare the first full pipeline.
parent
24c4bef7e2
commit
a71e132edf
|
@ -0,0 +1,6 @@
|
||||||
|
grouping_variable: [date_lj]
|
||||||
|
features:
|
||||||
|
proximity:
|
||||||
|
all
|
||||||
|
communication:
|
||||||
|
all
|
|
@ -0,0 +1,5 @@
|
||||||
|
grouping_variable: [date_lj]
|
||||||
|
labels:
|
||||||
|
PANAS:
|
||||||
|
- PA
|
||||||
|
- NA
|
|
@ -1,7 +1,10 @@
|
||||||
import datetime
|
import datetime
|
||||||
from collections.abc import Collection
|
from collections.abc import Collection
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import yaml
|
||||||
|
from sklearn import linear_model
|
||||||
from sklearn.model_selection import LeaveOneGroupOut, cross_val_score
|
from sklearn.model_selection import LeaveOneGroupOut, cross_val_score
|
||||||
|
|
||||||
import participants.query_db
|
import participants.query_db
|
||||||
|
@ -343,3 +346,29 @@ class MachineLearningPipeline:
|
||||||
cv=self.validation_method,
|
cv=self.validation_method,
|
||||||
n_jobs=-1,
|
n_jobs=-1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
with open("./config/prox_comm_PANAS_features.yaml", "r") as file:
|
||||||
|
sensor_features_params = yaml.safe_load(file)
|
||||||
|
sensor_features = SensorFeatures(**sensor_features_params)
|
||||||
|
sensor_features.set_sensor_data()
|
||||||
|
sensor_features.calculate_features()
|
||||||
|
|
||||||
|
with open("./config/prox_comm_PANAS_labels.yaml", "r") as file:
|
||||||
|
labels_params = yaml.safe_load(file)
|
||||||
|
labels = Labels(**labels_params)
|
||||||
|
labels.set_labels()
|
||||||
|
labels.aggregate_labels()
|
||||||
|
|
||||||
|
model_validation = ModelValidation(
|
||||||
|
sensor_features.get_features("all", "all"),
|
||||||
|
labels.get_aggregated_labels(),
|
||||||
|
group_variable="participant_id",
|
||||||
|
cv_name="loso",
|
||||||
|
)
|
||||||
|
model_validation.model = linear_model.LinearRegression()
|
||||||
|
model_validation.set_cv_method()
|
||||||
|
model_loso_r2 = model_validation.cross_validate()
|
||||||
|
print(model_loso_r2)
|
||||||
|
print(np.mean(model_loso_r2))
|
||||||
|
|
Loading…
Reference in New Issue