stress_at_work_analysis/machine_learning/prox_comm_PANAS_nb.ipynb

479 lines
24 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "25ba2626-2b93-48e7-b9cc-551fe03335f4",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import importlib\n",
"import os\n",
"import sys\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"import yaml\n",
"from sklearn import linear_model\n",
"from sklearn.metrics import mean_absolute_error\n",
"from sklearn.model_selection import LeaveOneGroupOut, cross_val_score\n",
"\n",
"# Make the parent of the working directory importable so that the local\n",
"# `machine_learning` package can be imported in the next cell.\n",
"nb_dir = os.path.split(os.getcwd())[0]\n",
"if nb_dir not in sys.path:\n",
"    sys.path.append(nb_dir)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b429e654-b065-4ea7-9dac-955584f7a016",
"metadata": {},
"outputs": [],
"source": [
"from machine_learning import pipeline, features_sensor, labels, model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f2a92e0-d6ea-49a1-9f06-d808c1bd57e9",
"metadata": {},
"outputs": [],
"source": [
"# Automatically reload edited local modules (e.g. machine_learning/labels.py)\n",
"# before each cell runs, instead of reloading individual modules by hand.\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"id": "9a4c1e7b-2d5f-4b8a-8e6c-3f1a0d7b5c29",
"metadata": {},
"source": [
"## Load features and labels\n",
"\n",
"Proximity and communication features and the PANAS labels are read from previously cached files, as configured in `./config/prox_comm_PANAS_features.yaml` and `./config/prox_comm_PANAS_labels.yaml`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "948cb320-f2c1-46a2-a42d-ab12894d321a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SensorFeatures initialized.\n",
"Calculating features ...\n",
"Read proximity features from the file.\n",
"Read communication features from the file.\n"
]
}
],
"source": [
"with open(\"./config/prox_comm_PANAS_features.yaml\", \"r\", encoding=\"utf-8\") as file:\n",
"    sensor_features_params = yaml.safe_load(file)\n",
"sensor_features = features_sensor.SensorFeatures(**sensor_features_params)\n",
"# cached=True reads previously calculated features from file. To recalculate them,\n",
"# call sensor_features.set_sensor_data() first (presumably with cached=False - confirm).\n",
"sensor_features.calculate_features(cached=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "872679f6-e343-4d2a-bfc5-e4e3d224c766",
"metadata": {},
"outputs": [],
"source": [
"all_features = sensor_features.get_features(\"all\", \"all\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "52f0f3cb-733a-4345-ab36-e52dc3c5a76c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Labels initialized.\n",
"Aggregating labels ...\n",
"Read labels from the file.\n"
]
}
],
"source": [
"with open(\"./config/prox_comm_PANAS_labels.yaml\", \"r\", encoding=\"utf-8\") as file:\n",
"    labels_params = yaml.safe_load(file)\n",
"labels_current = labels.Labels(**labels_params)\n",
"# cached=True reads previously aggregated labels from file. To recompute them,\n",
"# call labels_current.set_labels() first (presumably with cached=False - confirm).\n",
"labels_current.aggregate_labels(cached=True)"
]
},
{
"cell_type": "markdown",
"id": "4e8b2c6d-1a3f-4d7e-9b5c-7f2e0a6d8c13",
"metadata": {},
"source": [
"## Leave-one-participant-out (LOSO) validation\n",
"\n",
"A plain linear regression is evaluated with one cross-validation fold per participant (`group_variable=\"participant_id\"`, `cv_name=\"loso\"`)."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c366516a-6aa6-4101-a18d-0dc35f597d87",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ModelValidation initialized.\n",
"Validation method set.\n"
]
}
],
"source": [
"# Reuse the feature table loaded above instead of fetching it a second time.\n",
"model_validation = model.ModelValidation(\n",
"    all_features,\n",
"    labels_current.get_aggregated_labels(),\n",
"    group_variable=\"participant_id\",\n",
"    cv_name=\"loso\",\n",
")\n",
"model_validation.model = linear_model.LinearRegression()\n",
"model_validation.set_cv_method()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "0eab568d-ad7f-4243-be05-26bafb310c5c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running cross validation ...\n"
]
}
],
"source": [
"# One score per left-out participant.\n",
"# NOTE(review): named r2 on the assumption that ModelValidation scores with R^2 - confirm.\n",
"model_loso_r2 = model_validation.cross_validate()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fde0151b-c259-45e8-af2e-94f37edf0b01",
"metadata": {},
"outputs": [],
"source": [
"# R^2 is unbounded below, so one badly predicted participant can drag the mean\n",
"# far below zero; the median (50%) is the more robust summary of these folds.\n",
"display(model_loso_r2)\n",
"pd.Series(model_loso_r2, name=\"LOSO R^2\").describe()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "0435685e-9998-4eff-a3ee-6edc781dde81",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.06784273, 0.36944609])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Folds where the model beats predicting the left-out participant's own mean (R^2 > 0).\n",
"model_loso_r2[model_loso_r2 > 0]"
]
},
{
"cell_type": "markdown",
"id": "1b6e2f4a-8c3d-4e59-b7a2-6d0c9e1f3a58",
"metadata": {},
"source": [
"## Mean absolute error per left-out participant\n",
"\n",
"The same leave-one-group-out split, scored with mean absolute error (MAE) directly through scikit-learn."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c4560ac5-8c83-43d0-b6e0-b03dfd19c1c3",
"metadata": {},
"outputs": [],
"source": [
"logo = LeaveOneGroupOut()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "75dc95ca-ad77-4e36-bc8b-653d3b73037f",
"metadata": {},
"outputs": [],
"source": [
"# Move the (participant_id, date_lj) index into columns and drop it, so that\n",
"# scikit-learn sees plain features and a plain target; the grouping\n",
"# (model_validation.groups) is passed separately below.\n",
"X_flat = model_validation.X.reset_index().drop([\"participant_id\", \"date_lj\"], axis=1)\n",
"y_flat = model_validation.y.reset_index().drop([\"participant_id\", \"date_lj\"], axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "8a58963d-a10b-468a-ae82-6395e8b2e7b5",
"metadata": {},
"outputs": [],
"source": [
"# Pass the splitter itself; cross_val_score forwards `groups` to logo.split().\n",
"# The scorer returns the negated MAE (higher is better), so flip the sign back.\n",
"model_loso_mean_absolute_error = -cross_val_score(\n",
"    estimator=model_validation.model,\n",
"    X=X_flat,\n",
"    y=y_flat,\n",
"    groups=model_validation.groups,\n",
"    cv=logo,\n",
"    scoring=\"neg_mean_absolute_error\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "c98e13d6-734f-4adc-909b-c4a400a01d3e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.43618444, 0.39780929, 0.56970789, 0.38847095, 0.78244762,\n",
" 0.29847855, 0.4348883 , 1.80633684, 0.29097794, 0.53106755,\n",
" 0.32757327, 0.37845186, 0.30228743, 0.34129752, 0.2555845 ,\n",
" 1.27981007, 0.40270591, 0.35411635, 0.2568122 , 0.5820276 ,\n",
" 0.33293713, 0.47789249, 0.19690204, 0.68629304, 0.67457704,\n",
" 13.0369228 , 0.41234072, 0.31384332, 0.45126702, 0.34806906,\n",
" 0.52854722, 0.28707449, 0.28282637, 0.49286602, 0.26406791,\n",
" 0.39567315, 0.33661383, 1.23764371, 0.43788937, 0.32592072,\n",
" 0.47443271, 0.55999948, 0.50408039, 0.40523803, 0.50241167,\n",
" 0.30617356, 0.31461521, 0.28494495, 0.32278505, 0.29084659,\n",
" 0.47211231, 0.33807521, 0.34608592, 0.40624902, 0.22882316,\n",
" 0.45563856])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_loso_mean_absolute_error"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "cd821657-cc18-46f3-92d1-b331b863790f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.39674122009711504"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.median(model_loso_mean_absolute_error)"
]
},
{
"cell_type": "markdown",
"id": "6c0d3a9e-5f1b-4a72-8d4e-2b9f7c1e0a64",
"metadata": {},
"source": [
"## In-sample fit on all participants\n",
"\n",
"The model is refitted on all days and its predictions are compared with the observed values. These are **training-set** predictions, so the errors below are in-sample and are not an out-of-sample estimate like the LOSO MAE above."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "ead0d898-8a96-404d-a895-b213771dc7ea",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LinearRegression()"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_validation.model.fit(X_flat, y_flat)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "406d82e7-488c-46a3-8426-ca49e01993f5",
"metadata": {},
"outputs": [],
"source": [
"y_flat_predicted = model_validation.model.predict(X_flat)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "310c6287-7d6e-4261-8c9c-3c592c822bd1",
"metadata": {},
"outputs": [],
"source": [
"# Build a new frame instead of renaming / adding columns on y_flat in place, so that\n",
"# re-running the cross-validation or fit cells above still sees only the `NA` target.\n",
"# np.ravel flattens the (n, 1) predictions returned for a one-column target frame.\n",
"NA_comparison = pd.DataFrame(\n",
"    {\n",
"        \"NA_true\": y_flat[\"NA\"],\n",
"        \"NA_predicted\": np.ravel(y_flat_predicted),\n",
"    }\n",
")\n",
"# Long format: one row per (day, value) with value in {\"true\", \"predicted\"}.\n",
"NA_long = pd.wide_to_long(\n",
"    NA_comparison.reset_index(),\n",
"    i=\"index\",\n",
"    j=\"value\",\n",
"    stubnames=\"NA\",\n",
"    sep=\"_\",\n",
"    suffix=\".+\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "62f9312d-f9d7-403c-89c0-5c04d05e76bd",
"metadata": {},
"outputs": [],
"source": [
"predictions_grid = sns.displot(\n",
"    NA_long, x=\"NA\", hue=\"value\", binwidth=0.1, height=5, aspect=1.5\n",
")\n",
"sns.move_legend(predictions_grid, \"upper left\", bbox_to_anchor=(0.55, 0.45))\n",
"predictions_grid.set_axis_labels(\"Daily mean\", \"Day count\")\n",
"\n",
"# Save before showing: plt.show() hands the figure to the inline backend.\n",
"predictions_grid.savefig(\"prox_comm_PANAS_predictions.pdf\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "e1579333-b57b-4bce-9c86-f77c0cd0d3d4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.42725018860641295"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# In-sample (training) MAE: the model was fitted on these same days.\n",
"mean_absolute_error(NA_comparison[\"NA_true\"], NA_comparison[\"NA_predicted\"])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "straw2analysis",
"language": "python",
"name": "straw2analysis"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}