stress_at_work_analysis/machine_learning/pipeline.py

33 lines
1.1 KiB
Python
Raw Normal View History

2021-08-21 19:04:09 +02:00
import numpy as np
import yaml
from sklearn import linear_model
from machine_learning.features_sensor import SensorFeatures
from machine_learning.labels import Labels
from machine_learning.model import ModelValidation
2021-08-21 19:04:09 +02:00
def _load_yaml_config(path):
    """Parse the YAML file at ``path`` and return the loaded object.

    The file is decoded as UTF-8 explicitly so that the result does not
    depend on the locale encoding of the machine running the script.
    """
    with open(path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)


if __name__ == "__main__":
    # Script entry point: build features and labels from their YAML parameter
    # files, cross-validate a linear regression on them and print the scores.

    # --- Features -----------------------------------------------------------
    # The parsed YAML content is unpacked as keyword arguments, so it is
    # expected to be a mapping whose keys match SensorFeatures' parameters.
    sensor_features_params = _load_yaml_config(
        "./config/prox_comm_PANAS_features.yaml"
    )
    sensor_features = SensorFeatures(**sensor_features_params)
    sensor_features.set_sensor_data()
    sensor_features.calculate_features()

    # --- Labels -------------------------------------------------------------
    # Same pattern as above: YAML mapping -> keyword arguments for Labels.
    labels_params = _load_yaml_config("./config/prox_comm_PANAS_labels.yaml")
    labels = Labels(**labels_params)
    labels.set_labels()
    labels.aggregate_labels()

    # --- Model validation ---------------------------------------------------
    # NOTE(review): "loso" presumably selects leave-one-subject-out splitting
    # grouped by "participant_id" -- confirm against ModelValidation.
    model_validation = ModelValidation(
        sensor_features.get_features("all", "all"),
        labels.get_aggregated_labels(),
        group_variable="participant_id",
        cv_name="loso",
    )
    # The estimator is assigned before set_cv_method(), keeping the call order
    # of the original script.
    model_validation.model = linear_model.LinearRegression()
    model_validation.set_cv_method()

    # NOTE(review): the variable name suggests cross_validate() returns
    # per-split R^2 scores -- confirm. Both the raw result and its mean are
    # printed.
    model_loso_r2 = model_validation.cross_validate()
    print(model_loso_r2)
    print(np.mean(model_loso_r2))