2021-08-21 19:04:09 +02:00
|
|
|
import numpy as np
|
|
|
|
import yaml
|
|
|
|
from sklearn import linear_model
|
|
|
|
|
2021-09-13 11:41:57 +02:00
|
|
|
from machine_learning.features_sensor import SensorFeatures
|
|
|
|
from machine_learning.labels import Labels
|
|
|
|
from machine_learning.model import ModelValidation
|
2021-08-21 19:04:09 +02:00
|
|
|
|
|
|
|
def _load_yaml_config(path):
    """Parse the YAML file at *path* with ``yaml.safe_load`` and return the result.

    The file is opened with an explicit UTF-8 encoding so that parsing does not
    depend on the platform's locale-default text encoding.
    """
    with open(path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)


def main():
    """Cross-validate a linear regression of aggregated labels on sensor features.

    Steps:
    1. Build ``SensorFeatures`` from ``./config/prox_comm_PANAS_features.yaml``
       (the parsed mapping is passed as keyword arguments), then call
       ``set_sensor_data`` and ``calculate_features``.
    2. Build ``Labels`` from ``./config/prox_comm_PANAS_labels.yaml`` the same
       way, then call ``set_labels`` and ``aggregate_labels``.
    3. Build a ``ModelValidation`` from all features and the aggregated labels,
       grouped by ``participant_id`` with the ``"loso"`` CV name. Attach a
       ``LinearRegression`` model, configure the CV method, cross-validate,
       and print the scores and their mean.

    The config paths are relative, so the script must be run from the
    directory that contains ``config/``.
    """
    sensor_features = SensorFeatures(
        **_load_yaml_config("./config/prox_comm_PANAS_features.yaml")
    )
    sensor_features.set_sensor_data()
    sensor_features.calculate_features()

    labels = Labels(**_load_yaml_config("./config/prox_comm_PANAS_labels.yaml"))
    labels.set_labels()
    labels.aggregate_labels()

    # NOTE(review): "loso" presumably means leave-one-subject-out, with
    # participant_id as the subject grouping; confirm in ModelValidation.
    model_validation = ModelValidation(
        sensor_features.get_features("all", "all"),
        labels.get_aggregated_labels(),
        group_variable="participant_id",
        cv_name="loso",
    )
    # The model is assigned before set_cv_method()/cross_validate(), matching
    # the original call order.
    model_validation.model = linear_model.LinearRegression()
    model_validation.set_cv_method()

    # Presumably one R^2 score per held-out group (per the variable name);
    # TODO confirm against ModelValidation.cross_validate.
    model_loso_r2 = model_validation.cross_validate()
    print(model_loso_r2)
    print(np.mean(model_loso_r2))


if __name__ == "__main__":
    main()
|