Compare commits
20 Commits
6592612db7
...
c1bb4ddf0f
Author | SHA1 | Date |
---|---|---|
junos | c1bb4ddf0f | |
junos | 0152fbe4ac | |
junos | 3611fc76f7 | |
junos | ee30c042ea | |
junos | a71e132edf | |
junos | 24c4bef7e2 | |
junos | 11381d6447 | |
junos | d19995385d | |
junos | f73f86486a | |
junos | aed73bb7ed | |
junos | 8507ff5761 | |
junos | 0b85ee8fdc | |
junos | e2e268148d | |
junos | 00015a3b8d | |
junos | 065cd4347e | |
junos | 0b98d59aad | |
junos | 08fdec34f1 | |
junos | 72b16af75c | |
junos | d6337e82ac | |
junos | 9a319ac6e5 |
|
@ -15,6 +15,7 @@ dependencies:
|
||||||
- psycopg2
|
- psycopg2
|
||||||
- python-dotenv
|
- python-dotenv
|
||||||
- pytz
|
- pytz
|
||||||
|
- pyprojroot
|
||||||
- pyyaml
|
- pyyaml
|
||||||
- seaborn
|
- seaborn
|
||||||
- scikit-learn
|
- scikit-learn
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
# %%
|
# %%
|
||||||
# %matplotlib inline
|
# %matplotlib inline
|
||||||
import datetime
|
import datetime
|
||||||
|
import importlib
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
@ -156,14 +157,25 @@ lin_reg_proximity.score(
|
||||||
# %%
|
# %%
|
||||||
from machine_learning import pipeline
|
from machine_learning import pipeline
|
||||||
|
|
||||||
|
# %%
|
||||||
|
importlib.reload(pipeline)
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
with open("../machine_learning/config/minimal_features.yaml", "r") as file:
|
with open("../machine_learning/config/minimal_features.yaml", "r") as file:
|
||||||
sensor_features_params = yaml.safe_load(file)
|
sensor_features_params = yaml.safe_load(file)
|
||||||
|
print(sensor_features_params)
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
sensor_features = pipeline.SensorFeatures(**sensor_features_params)
|
sensor_features = pipeline.SensorFeatures(**sensor_features_params)
|
||||||
sensor_features.data_types
|
sensor_features.data_types
|
||||||
|
|
||||||
|
# %%
|
||||||
|
sensor_features.set_participants_label("nokia_0000003")
|
||||||
|
|
||||||
|
# %%
|
||||||
|
sensor_features.data_types = ["proximity", "communication"]
|
||||||
|
sensor_features.participants_usernames = ptcp_2
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
sensor_features.get_sensor_data("proximity")
|
sensor_features.get_sensor_data("proximity")
|
||||||
|
|
||||||
|
@ -179,12 +191,19 @@ sensor_features.calculate_features()
|
||||||
# %%
|
# %%
|
||||||
sensor_features.get_features("proximity", "all")
|
sensor_features.get_features("proximity", "all")
|
||||||
|
|
||||||
|
# %%
|
||||||
|
sensor_features.get_features("communication", "all")
|
||||||
|
|
||||||
|
# %%
|
||||||
|
sensor_features.get_features("all", "all")
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
with open("../machine_learning/config/minimal_labels.yaml", "r") as file:
|
with open("../machine_learning/config/minimal_labels.yaml", "r") as file:
|
||||||
labels_params = yaml.safe_load(file)
|
labels_params = yaml.safe_load(file)
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
labels = pipeline.Labels(**labels_params)
|
labels = pipeline.Labels(**labels_params)
|
||||||
|
labels.participants_usernames = ptcp_2
|
||||||
labels.questionnaires
|
labels.questionnaires
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
@ -194,3 +213,25 @@ labels.set_labels()
|
||||||
labels.get_labels("PANAS")
|
labels.get_labels("PANAS")
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
labels.aggregate_labels()
|
||||||
|
|
||||||
|
# %%
|
||||||
|
labels.get_aggregated_labels()
|
||||||
|
|
||||||
|
# %%
|
||||||
|
model_validation = pipeline.ModelValidation(
|
||||||
|
sensor_features.get_features("all", "all"),
|
||||||
|
labels.get_aggregated_labels(),
|
||||||
|
group_variable="participant_id",
|
||||||
|
cv_name="loso",
|
||||||
|
)
|
||||||
|
model_validation.model = linear_model.LinearRegression()
|
||||||
|
model_validation.set_cv_method()
|
||||||
|
|
||||||
|
# %%
|
||||||
|
model_validation.cross_validate()
|
||||||
|
|
||||||
|
# %%
|
||||||
|
model_validation.groups
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
|
@ -13,14 +13,15 @@
|
||||||
# name: straw2analysis
|
# name: straw2analysis
|
||||||
# ---
|
# ---
|
||||||
|
|
||||||
|
# %%
|
||||||
|
import importlib
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# %matplotlib inline
|
# %matplotlib inline
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
# %%
|
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
|
|
||||||
nb_dir = os.path.split(os.getcwd())[0]
|
nb_dir = os.path.split(os.getcwd())[0]
|
||||||
|
@ -28,21 +29,29 @@ if nb_dir not in sys.path:
|
||||||
sys.path.append(nb_dir)
|
sys.path.append(nb_dir)
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
from features.communication import *
|
from features import communication, helper
|
||||||
|
|
||||||
|
# %%
|
||||||
|
importlib.reload(communication)
|
||||||
|
|
||||||
# %% [markdown]
|
# %% [markdown]
|
||||||
# # Example of communication data and feature calculation
|
# # Example of communication data and feature calculation
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
df_calls = get_call_data(["nokia_0000003"])
|
df_calls = communication.get_call_data(["nokia_0000003"])
|
||||||
print(df_calls)
|
print(df_calls)
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
count_comms(df_calls)
|
df_calls = helper.get_date_from_timestamp(df_calls)
|
||||||
|
communication.count_comms(df_calls, ["date_lj"])
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
df_sms = get_sms_data(["nokia_0000003"])
|
df_sms = communication.get_sms_data(["nokia_0000003"])
|
||||||
count_comms(df_sms)
|
df_sms = helper.get_date_from_timestamp(df_sms)
|
||||||
|
communication.count_comms(df_sms, ["date_lj"])
|
||||||
|
|
||||||
|
# %%
|
||||||
|
communication.calls_sms_features(df_calls, df_sms, ["date_lj"])
|
||||||
|
|
||||||
# %% [markdown]
|
# %% [markdown]
|
||||||
# # Call data
|
# # Call data
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
# %%
|
# %%
|
||||||
# %matplotlib inline
|
# %matplotlib inline
|
||||||
import datetime
|
import datetime
|
||||||
|
import importlib
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
@ -32,13 +33,16 @@ import participants.query_db
|
||||||
TZ_LJ = timezone("Europe/Ljubljana")
|
TZ_LJ = timezone("Europe/Ljubljana")
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
from features.proximity import *
|
from features import helper, proximity
|
||||||
|
|
||||||
|
# %%
|
||||||
|
importlib.reload(proximity)
|
||||||
|
|
||||||
# %% [markdown]
|
# %% [markdown]
|
||||||
# # Basic characteristics
|
# # Basic characteristics
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
df_proximity_nokia = get_proximity_data(["nokia_0000003"])
|
df_proximity_nokia = proximity.get_proximity_data(["nokia_0000003"])
|
||||||
print(df_proximity_nokia)
|
print(df_proximity_nokia)
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
@ -53,7 +57,7 @@ df_proximity_nokia.double_proximity.value_counts()
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
participants_inactive_usernames = participants.query_db.get_usernames()
|
participants_inactive_usernames = participants.query_db.get_usernames()
|
||||||
df_proximity_inactive = get_proximity_data(participants_inactive_usernames)
|
df_proximity_inactive = proximity.get_proximity_data(participants_inactive_usernames)
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
df_proximity_inactive.double_proximity.describe()
|
df_proximity_inactive.double_proximity.describe()
|
||||||
|
@ -110,3 +114,13 @@ df_proximity_combinations[
|
||||||
(df_proximity_combinations[5.0] != 0)
|
(df_proximity_combinations[5.0] != 0)
|
||||||
& (df_proximity_combinations[5.00030517578125] != 0)
|
& (df_proximity_combinations[5.00030517578125] != 0)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# # Features
|
||||||
|
|
||||||
|
# %%
|
||||||
|
df_proximity_inactive = helper.get_date_from_timestamp(df_proximity_inactive)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
df_proximity_features = proximity.count_proximity(df_proximity_inactive, ["date_lj"])
|
||||||
|
display(df_proximity_features)
|
||||||
|
|
|
@ -8,14 +8,21 @@ from setup import db_engine, session
|
||||||
call_types = {1: "incoming", 2: "outgoing", 3: "missed"}
|
call_types = {1: "incoming", 2: "outgoing", 3: "missed"}
|
||||||
sms_types = {1: "received", 2: "sent"}
|
sms_types = {1: "received", 2: "sent"}
|
||||||
|
|
||||||
FEATURES_CALLS = (
|
FILL_NA_CALLS = {
|
||||||
["no_calls_all"]
|
"no_calls_all": 0,
|
||||||
+ ["no_" + call_type for call_type in call_types.values()]
|
"no_" + call_types.get(1): 0,
|
||||||
+ ["duration_total_" + call_types.get(1), "duration_total_" + call_types.get(2)]
|
"no_" + call_types.get(2): 0,
|
||||||
+ ["duration_max_" + call_types.get(1), "duration_max_" + call_types.get(2)]
|
"no_" + call_types.get(3): 0,
|
||||||
+ ["no_" + call_types.get(1) + "_ratio", "no_" + call_types.get(2) + "_ratio"]
|
"duration_total_" + call_types.get(1): 0,
|
||||||
+ ["no_contacts_calls"]
|
"duration_total_" + call_types.get(2): 0,
|
||||||
)
|
"duration_max_" + call_types.get(1): 0,
|
||||||
|
"duration_max_" + call_types.get(2): 0,
|
||||||
|
"no_" + call_types.get(1) + "_ratio": 1 / 3, # Three different types
|
||||||
|
"no_" + call_types.get(2) + "_ratio": 1 / 3,
|
||||||
|
"no_contacts_calls": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
FEATURES_CALLS = list(FILL_NA_CALLS.keys())
|
||||||
|
|
||||||
# FEATURES_CALLS =
|
# FEATURES_CALLS =
|
||||||
# ["no_calls_all",
|
# ["no_calls_all",
|
||||||
|
@ -23,19 +30,24 @@ FEATURES_CALLS = (
|
||||||
# "duration_total_incoming", "duration_total_outgoing",
|
# "duration_total_incoming", "duration_total_outgoing",
|
||||||
# "duration_max_incoming", "duration_max_outgoing",
|
# "duration_max_incoming", "duration_max_outgoing",
|
||||||
# "no_incoming_ratio", "no_outgoing_ratio",
|
# "no_incoming_ratio", "no_outgoing_ratio",
|
||||||
# "no_contacts"]
|
# "no_contacts_calls"]
|
||||||
|
|
||||||
|
FILL_NA_SMS = {
|
||||||
|
"no_sms_all": 0,
|
||||||
|
"no_" + sms_types.get(1): 0,
|
||||||
|
"no_" + sms_types.get(2): 0,
|
||||||
|
"no_" + sms_types.get(1) + "_ratio": 1 / 2, # Two different types
|
||||||
|
"no_" + sms_types.get(2) + "_ratio": 1 / 2,
|
||||||
|
"no_contacts_sms": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
FEATURES_SMS = list(FILL_NA_SMS.keys())
|
||||||
|
|
||||||
FEATURES_SMS = (
|
|
||||||
["no_sms_all"]
|
|
||||||
+ ["no_" + sms_type for sms_type in sms_types.values()]
|
|
||||||
+ ["no_" + sms_types.get(1) + "_ratio", "no_" + sms_types.get(2) + "_ratio"]
|
|
||||||
+ ["no_contacts_sms"]
|
|
||||||
)
|
|
||||||
# FEATURES_SMS =
|
# FEATURES_SMS =
|
||||||
# ["no_sms_all",
|
# ["no_sms_all",
|
||||||
# "no_received", "no_sent",
|
# "no_received", "no_sent",
|
||||||
# "no_received_ratio", "no_sent_ratio",
|
# "no_received_ratio", "no_sent_ratio",
|
||||||
# "no_contacts"]
|
# "no_contacts_sms"]
|
||||||
|
|
||||||
FEATURES_CALLS_SMS_PROP = [
|
FEATURES_CALLS_SMS_PROP = [
|
||||||
"proportion_calls_all",
|
"proportion_calls_all",
|
||||||
|
@ -45,8 +57,15 @@ FEATURES_CALLS_SMS_PROP = [
|
||||||
"proportion_calls_missed_sms_received",
|
"proportion_calls_missed_sms_received",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
FILL_NA_CALLS_SMS_PROP = {
|
||||||
|
key: 1 / 2 for key in FEATURES_CALLS_SMS_PROP
|
||||||
|
} # All of the form of a / (a + b).
|
||||||
|
|
||||||
FEATURES_CALLS_SMS_ALL = FEATURES_CALLS + FEATURES_SMS + FEATURES_CALLS_SMS_PROP
|
FEATURES_CALLS_SMS_ALL = FEATURES_CALLS + FEATURES_SMS + FEATURES_CALLS_SMS_PROP
|
||||||
|
|
||||||
|
FILL_NA_CALLS_SMS_ALL = FILL_NA_CALLS | FILL_NA_SMS | FILL_NA_CALLS_SMS_PROP
|
||||||
|
# As per PEP-584 a union for dicts was implemented in Python 3.9.0.
|
||||||
|
|
||||||
|
|
||||||
def get_call_data(usernames: Collection) -> pd.DataFrame:
|
def get_call_data(usernames: Collection) -> pd.DataFrame:
|
||||||
"""
|
"""
|
||||||
|
@ -137,7 +156,7 @@ def enumerate_contacts(comm_df: pd.DataFrame) -> pd.DataFrame:
|
||||||
return comm_df
|
return comm_df
|
||||||
|
|
||||||
|
|
||||||
def count_comms(comm_df: pd.DataFrame) -> pd.DataFrame:
|
def count_comms(comm_df: pd.DataFrame, group_by=None) -> pd.DataFrame:
|
||||||
"""
|
"""
|
||||||
Calculate frequencies (and duration) of messages (or calls), grouped by their types.
|
Calculate frequencies (and duration) of messages (or calls), grouped by their types.
|
||||||
|
|
||||||
|
@ -145,6 +164,9 @@ def count_comms(comm_df: pd.DataFrame) -> pd.DataFrame:
|
||||||
----------
|
----------
|
||||||
comm_df: pd.DataFrame
|
comm_df: pd.DataFrame
|
||||||
A dataframe of calls or SMSes.
|
A dataframe of calls or SMSes.
|
||||||
|
group_by: list
|
||||||
|
A list of strings, specifying by which parameters to group.
|
||||||
|
By default, the features are calculated per participant, but could be "date_lj" etc.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
@ -157,11 +179,13 @@ def count_comms(comm_df: pd.DataFrame) -> pd.DataFrame:
|
||||||
* the number of messages by type (received, sent), and
|
* the number of messages by type (received, sent), and
|
||||||
* the number of communication contacts by type.
|
* the number of communication contacts by type.
|
||||||
"""
|
"""
|
||||||
|
if group_by is None:
|
||||||
|
group_by = []
|
||||||
if "call_type" in comm_df:
|
if "call_type" in comm_df:
|
||||||
data_type = "calls"
|
data_type = "calls"
|
||||||
comm_counts = (
|
comm_counts = (
|
||||||
comm_df.value_counts(subset=["participant_id", "call_type"])
|
comm_df.value_counts(subset=group_by + ["participant_id", "call_type"])
|
||||||
.unstack()
|
.unstack(level="call_type", fill_value=0)
|
||||||
.rename(columns=call_types)
|
.rename(columns=call_types)
|
||||||
.add_prefix("no_")
|
.add_prefix("no_")
|
||||||
)
|
)
|
||||||
|
@ -174,17 +198,17 @@ def count_comms(comm_df: pd.DataFrame) -> pd.DataFrame:
|
||||||
)
|
)
|
||||||
# Ratio of incoming and outgoing calls to all calls.
|
# Ratio of incoming and outgoing calls to all calls.
|
||||||
comm_duration_total = (
|
comm_duration_total = (
|
||||||
comm_df.groupby(["participant_id", "call_type"])
|
comm_df.groupby(group_by + ["participant_id", "call_type"])
|
||||||
.sum()["call_duration"]
|
.sum()["call_duration"]
|
||||||
.unstack()
|
.unstack(level="call_type", fill_value=0)
|
||||||
.rename(columns=call_types)
|
.rename(columns=call_types)
|
||||||
.add_prefix("duration_total_")
|
.add_prefix("duration_total_")
|
||||||
)
|
)
|
||||||
# Total call duration by type.
|
# Total call duration by type.
|
||||||
comm_duration_max = (
|
comm_duration_max = (
|
||||||
comm_df.groupby(["participant_id", "call_type"])
|
comm_df.groupby(group_by + ["participant_id", "call_type"])
|
||||||
.max()["call_duration"]
|
.max()["call_duration"]
|
||||||
.unstack()
|
.unstack(level="call_type", fill_value=0)
|
||||||
.rename(columns=call_types)
|
.rename(columns=call_types)
|
||||||
.add_prefix("duration_max_")
|
.add_prefix("duration_max_")
|
||||||
)
|
)
|
||||||
|
@ -202,8 +226,8 @@ def count_comms(comm_df: pd.DataFrame) -> pd.DataFrame:
|
||||||
elif "message_type" in comm_df:
|
elif "message_type" in comm_df:
|
||||||
data_type = "sms"
|
data_type = "sms"
|
||||||
comm_counts = (
|
comm_counts = (
|
||||||
comm_df.value_counts(subset=["participant_id", "message_type"])
|
comm_df.value_counts(subset=group_by + ["participant_id", "message_type"])
|
||||||
.unstack()
|
.unstack(level="message_type", fill_value=0)
|
||||||
.rename(columns=sms_types)
|
.rename(columns=sms_types)
|
||||||
.add_prefix("no_")
|
.add_prefix("no_")
|
||||||
)
|
)
|
||||||
|
@ -218,7 +242,7 @@ def count_comms(comm_df: pd.DataFrame) -> pd.DataFrame:
|
||||||
raise KeyError("The dataframe contains neither call_type or message_type")
|
raise KeyError("The dataframe contains neither call_type or message_type")
|
||||||
comm_contacts_counts = (
|
comm_contacts_counts = (
|
||||||
enumerate_contacts(comm_df)
|
enumerate_contacts(comm_df)
|
||||||
.groupby(["participant_id"])
|
.groupby(group_by + ["participant_id"])
|
||||||
.nunique()["contact_id"]
|
.nunique()["contact_id"]
|
||||||
.rename("no_contacts_" + data_type)
|
.rename("no_contacts_" + data_type)
|
||||||
)
|
)
|
||||||
|
@ -270,7 +294,9 @@ def contact_features(comm_df: pd.DataFrame) -> pd.DataFrame:
|
||||||
return contacts_count
|
return contacts_count
|
||||||
|
|
||||||
|
|
||||||
def calls_sms_features(df_calls: pd.DataFrame, df_sms: pd.DataFrame) -> pd.DataFrame:
|
def calls_sms_features(
|
||||||
|
df_calls: pd.DataFrame, df_sms: pd.DataFrame, group_by=None
|
||||||
|
) -> pd.DataFrame:
|
||||||
"""
|
"""
|
||||||
Calculates additional features relating calls and sms data.
|
Calculates additional features relating calls and sms data.
|
||||||
|
|
||||||
|
@ -280,6 +306,9 @@ def calls_sms_features(df_calls: pd.DataFrame, df_sms: pd.DataFrame) -> pd.DataF
|
||||||
A dataframe of calls (return of get_call_data).
|
A dataframe of calls (return of get_call_data).
|
||||||
df_sms: pd.DataFrame
|
df_sms: pd.DataFrame
|
||||||
A dataframe of SMSes (return of get_sms_data).
|
A dataframe of SMSes (return of get_sms_data).
|
||||||
|
group_by: list
|
||||||
|
A list of strings, specifying by which parameters to group.
|
||||||
|
By default, the features are calculated per participant, but could be "date_lj" etc.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
@ -297,24 +326,38 @@ def calls_sms_features(df_calls: pd.DataFrame, df_sms: pd.DataFrame) -> pd.DataF
|
||||||
* proportion_calls_contacts:
|
* proportion_calls_contacts:
|
||||||
proportion of calls contacts in total number of communication contacts
|
proportion of calls contacts in total number of communication contacts
|
||||||
"""
|
"""
|
||||||
count_calls = count_comms(df_calls)
|
if group_by is None:
|
||||||
count_sms = count_comms(df_sms)
|
group_by = []
|
||||||
count_joined = count_calls.join(count_sms).assign(
|
count_calls = count_comms(df_calls, group_by)
|
||||||
proportion_calls_all=(
|
count_sms = count_comms(df_sms, group_by)
|
||||||
lambda x: x.no_calls_all / (x.no_calls_all + x.no_sms_all)
|
count_joined = (
|
||||||
),
|
count_calls.merge(
|
||||||
proportion_calls_incoming=(
|
count_sms,
|
||||||
lambda x: x.no_incoming / (x.no_incoming + x.no_received)
|
how="outer",
|
||||||
),
|
left_index=True,
|
||||||
proportion_calls_missed_sms_received=(
|
right_index=True,
|
||||||
lambda x: x.no_missed / (x.no_missed + x.no_received)
|
validate="one_to_one",
|
||||||
),
|
|
||||||
proportion_calls_outgoing=(
|
|
||||||
lambda x: x.no_outgoing / (x.no_outgoing + x.no_sent)
|
|
||||||
),
|
|
||||||
proportion_calls_contacts=(
|
|
||||||
lambda x: x.no_contacts_calls / (x.no_contacts_calls + x.no_contacts_sms)
|
|
||||||
)
|
)
|
||||||
# Calculate new features and create additional columns
|
.fillna(0, downcast="infer")
|
||||||
|
.assign(
|
||||||
|
proportion_calls_all=(
|
||||||
|
lambda x: x.no_calls_all / (x.no_calls_all + x.no_sms_all)
|
||||||
|
),
|
||||||
|
proportion_calls_incoming=(
|
||||||
|
lambda x: x.no_incoming / (x.no_incoming + x.no_received)
|
||||||
|
),
|
||||||
|
proportion_calls_missed_sms_received=(
|
||||||
|
lambda x: x.no_missed / (x.no_missed + x.no_received)
|
||||||
|
),
|
||||||
|
proportion_calls_outgoing=(
|
||||||
|
lambda x: x.no_outgoing / (x.no_outgoing + x.no_sent)
|
||||||
|
),
|
||||||
|
proportion_calls_contacts=(
|
||||||
|
lambda x: x.no_contacts_calls
|
||||||
|
/ (x.no_contacts_calls + x.no_contacts_sms)
|
||||||
|
)
|
||||||
|
# Calculate new features and create additional columns
|
||||||
|
)
|
||||||
|
.fillna(0.5, downcast="infer")
|
||||||
)
|
)
|
||||||
return count_joined
|
return count_joined
|
||||||
|
|
|
@ -5,7 +5,12 @@ import pandas as pd
|
||||||
from config.models import Participant, Proximity
|
from config.models import Participant, Proximity
|
||||||
from setup import db_engine, session
|
from setup import db_engine, session
|
||||||
|
|
||||||
FEATURES_PROXIMITY = ["freq_prox_near", "prop_prox_near"]
|
FILL_NA_PROXIMITY = {
|
||||||
|
"freq_prox_near": 0,
|
||||||
|
"prop_prox_near": 1 / 2, # Of the form of a / (a + b).
|
||||||
|
}
|
||||||
|
|
||||||
|
FEATURES_PROXIMITY = list(FILL_NA_PROXIMITY.keys())
|
||||||
|
|
||||||
|
|
||||||
def get_proximity_data(usernames: Collection) -> pd.DataFrame:
|
def get_proximity_data(usernames: Collection) -> pd.DataFrame:
|
||||||
|
@ -78,11 +83,11 @@ def count_proximity(
|
||||||
A dataframe with the count of "near" proximity values and their relative count.
|
A dataframe with the count of "near" proximity values and their relative count.
|
||||||
"""
|
"""
|
||||||
if group_by is None:
|
if group_by is None:
|
||||||
group_by = ["participant_id"]
|
group_by = []
|
||||||
if "bool_prox_near" not in df_proximity:
|
if "bool_prox_near" not in df_proximity:
|
||||||
df_proximity = recode_proximity(df_proximity)
|
df_proximity = recode_proximity(df_proximity)
|
||||||
df_proximity["bool_prox_far"] = ~df_proximity["bool_prox_near"]
|
df_proximity["bool_prox_far"] = ~df_proximity["bool_prox_near"]
|
||||||
df_proximity_features = df_proximity.groupby(group_by).sum()[
|
df_proximity_features = df_proximity.groupby(["participant_id"] + group_by).sum()[
|
||||||
["bool_prox_near", "bool_prox_far"]
|
["bool_prox_near", "bool_prox_far"]
|
||||||
]
|
]
|
||||||
df_proximity_features = df_proximity_features.assign(
|
df_proximity_features = df_proximity_features.assign(
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
grouping_variable: date_lj
|
grouping_variable: [date_lj]
|
||||||
labels:
|
labels:
|
||||||
PANAS:
|
PANAS:
|
||||||
- PA
|
- PA
|
||||||
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
grouping_variable: [date_lj]
|
||||||
|
features:
|
||||||
|
proximity:
|
||||||
|
all
|
||||||
|
communication:
|
||||||
|
all
|
|
@ -0,0 +1,5 @@
|
||||||
|
grouping_variable: [date_lj]
|
||||||
|
labels:
|
||||||
|
PANAS:
|
||||||
|
- PA
|
||||||
|
- NA
|
|
@ -1,13 +1,25 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
import warnings
|
||||||
from collections.abc import Collection
|
from collections.abc import Collection
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from sklearn.model_selection import cross_val_score
|
import yaml
|
||||||
|
from pyprojroot import here
|
||||||
|
from sklearn import linear_model
|
||||||
|
from sklearn.model_selection import LeaveOneGroupOut, cross_val_score
|
||||||
|
|
||||||
import participants.query_db
|
import participants.query_db
|
||||||
from features import communication, esm, helper, proximity
|
from features import communication, esm, helper, proximity
|
||||||
from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME
|
from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME
|
||||||
|
|
||||||
|
WARNING_PARTICIPANTS_LABEL = (
|
||||||
|
"Before calculating features, please set participants label using self.set_participants_label() "
|
||||||
|
"to be used as a filename prefix when exporting data. "
|
||||||
|
"The filename will be of the form: %participants_label_%grouping_variable_%data_type.csv"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SensorFeatures:
|
class SensorFeatures:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -16,16 +28,22 @@ class SensorFeatures:
|
||||||
features: dict,
|
features: dict,
|
||||||
participants_usernames: Collection = None,
|
participants_usernames: Collection = None,
|
||||||
):
|
):
|
||||||
self.grouping_variable = grouping_variable
|
|
||||||
|
self.grouping_variable_name = grouping_variable
|
||||||
|
self.grouping_variable = [grouping_variable]
|
||||||
|
|
||||||
self.data_types = features.keys()
|
self.data_types = features.keys()
|
||||||
|
|
||||||
|
self.participants_label: str = ""
|
||||||
if participants_usernames is None:
|
if participants_usernames is None:
|
||||||
participants_usernames = participants.query_db.get_usernames(
|
participants_usernames = participants.query_db.get_usernames(
|
||||||
collection_start=datetime.date.fromisoformat("2020-08-01")
|
collection_start=datetime.date.fromisoformat("2020-08-01")
|
||||||
)
|
)
|
||||||
|
self.participants_label = "all"
|
||||||
self.participants_usernames = participants_usernames
|
self.participants_usernames = participants_usernames
|
||||||
|
|
||||||
|
self.df_features_all = pd.DataFrame()
|
||||||
|
|
||||||
self.df_proximity = pd.DataFrame()
|
self.df_proximity = pd.DataFrame()
|
||||||
self.df_proximity_counts = pd.DataFrame()
|
self.df_proximity_counts = pd.DataFrame()
|
||||||
|
|
||||||
|
@ -33,19 +51,28 @@ class SensorFeatures:
|
||||||
self.df_sms = pd.DataFrame()
|
self.df_sms = pd.DataFrame()
|
||||||
self.df_calls_sms = pd.DataFrame()
|
self.df_calls_sms = pd.DataFrame()
|
||||||
|
|
||||||
|
self.folder = None
|
||||||
|
self.filename_prefix = ""
|
||||||
|
self.construct_export_path()
|
||||||
|
print("SensorFeatures initialized.")
|
||||||
|
|
||||||
def set_sensor_data(self):
|
def set_sensor_data(self):
|
||||||
|
print("Querying database ...")
|
||||||
if "proximity" in self.data_types:
|
if "proximity" in self.data_types:
|
||||||
self.df_proximity = proximity.get_proximity_data(
|
self.df_proximity = proximity.get_proximity_data(
|
||||||
self.participants_usernames
|
self.participants_usernames
|
||||||
)
|
)
|
||||||
|
print("Got proximity data from the DB.")
|
||||||
self.df_proximity = helper.get_date_from_timestamp(self.df_proximity)
|
self.df_proximity = helper.get_date_from_timestamp(self.df_proximity)
|
||||||
self.df_proximity = proximity.recode_proximity(self.df_proximity)
|
self.df_proximity = proximity.recode_proximity(self.df_proximity)
|
||||||
if "communication" in self.data_types:
|
if "communication" in self.data_types:
|
||||||
self.df_calls = communication.get_call_data(self.participants_usernames)
|
self.df_calls = communication.get_call_data(self.participants_usernames)
|
||||||
self.df_calls = helper.get_date_from_timestamp(self.df_calls)
|
self.df_calls = helper.get_date_from_timestamp(self.df_calls)
|
||||||
|
print("Got calls data from the DB.")
|
||||||
|
|
||||||
self.df_sms = communication.get_sms_data(self.participants_usernames)
|
self.df_sms = communication.get_sms_data(self.participants_usernames)
|
||||||
self.df_sms = helper.get_date_from_timestamp(self.df_sms)
|
self.df_sms = helper.get_date_from_timestamp(self.df_sms)
|
||||||
|
print("Got sms data from the DB.")
|
||||||
|
|
||||||
def get_sensor_data(self, data_type) -> pd.DataFrame:
|
def get_sensor_data(self, data_type) -> pd.DataFrame:
|
||||||
if data_type == "proximity":
|
if data_type == "proximity":
|
||||||
|
@ -56,15 +83,41 @@ class SensorFeatures:
|
||||||
raise KeyError("This data type has not been implemented.")
|
raise KeyError("This data type has not been implemented.")
|
||||||
|
|
||||||
def calculate_features(self):
|
def calculate_features(self):
|
||||||
|
print("Calculating features ...")
|
||||||
|
if not self.participants_label:
|
||||||
|
raise ValueError(WARNING_PARTICIPANTS_LABEL)
|
||||||
if "proximity" in self.data_types:
|
if "proximity" in self.data_types:
|
||||||
self.df_proximity_counts = proximity.count_proximity(
|
self.df_proximity_counts = proximity.count_proximity(
|
||||||
self.df_proximity, ["participant_id", self.grouping_variable]
|
self.df_proximity, self.grouping_variable
|
||||||
)
|
)
|
||||||
|
self.df_features_all = safe_outer_merge_on_index(
|
||||||
|
self.df_features_all, self.df_proximity_counts
|
||||||
|
)
|
||||||
|
print("Calculated proximity features.")
|
||||||
|
to_csv_with_settings(
|
||||||
|
self.df_proximity, self.folder, self.filename_prefix, data_type="prox"
|
||||||
|
)
|
||||||
|
|
||||||
if "communication" in self.data_types:
|
if "communication" in self.data_types:
|
||||||
self.df_calls_sms = communication.calls_sms_features(
|
self.df_calls_sms = communication.calls_sms_features(
|
||||||
df_calls=self.df_calls, df_sms=self.df_sms
|
df_calls=self.df_calls,
|
||||||
|
df_sms=self.df_sms,
|
||||||
|
group_by=self.grouping_variable,
|
||||||
)
|
)
|
||||||
# TODO Think about joining dataframes.
|
self.df_features_all = safe_outer_merge_on_index(
|
||||||
|
self.df_features_all, self.df_calls_sms
|
||||||
|
)
|
||||||
|
print("Calculated communication features.")
|
||||||
|
to_csv_with_settings(
|
||||||
|
self.df_calls_sms, self.folder, self.filename_prefix, data_type="comm"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.df_features_all.fillna(
|
||||||
|
value=proximity.FILL_NA_PROXIMITY, inplace=True, downcast="infer",
|
||||||
|
)
|
||||||
|
self.df_features_all.fillna(
|
||||||
|
value=communication.FILL_NA_CALLS_SMS_ALL, inplace=True, downcast="infer",
|
||||||
|
)
|
||||||
|
|
||||||
def get_features(self, data_type, feature_names) -> pd.DataFrame:
|
def get_features(self, data_type, feature_names) -> pd.DataFrame:
|
||||||
if data_type == "proximity":
|
if data_type == "proximity":
|
||||||
|
@ -75,14 +128,28 @@ class SensorFeatures:
|
||||||
if feature_names == "all":
|
if feature_names == "all":
|
||||||
feature_names = communication.FEATURES_CALLS_SMS_ALL
|
feature_names = communication.FEATURES_CALLS_SMS_ALL
|
||||||
return self.df_calls_sms[feature_names]
|
return self.df_calls_sms[feature_names]
|
||||||
|
elif data_type == "all":
|
||||||
|
return self.df_features_all
|
||||||
else:
|
else:
|
||||||
raise KeyError("This data type has not been implemented.")
|
raise KeyError("This data type has not been implemented.")
|
||||||
|
|
||||||
|
def construct_export_path(self):
|
||||||
|
if not self.participants_label:
|
||||||
|
warnings.warn(WARNING_PARTICIPANTS_LABEL, UserWarning)
|
||||||
|
self.folder = here("machine_learning/intermediate_results/features", warn=True)
|
||||||
|
self.filename_prefix = (
|
||||||
|
self.participants_label + "_" + self.grouping_variable_name
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_participants_label(self, label: str):
|
||||||
|
self.participants_label = label
|
||||||
|
self.construct_export_path()
|
||||||
|
|
||||||
|
|
||||||
class Labels:
|
class Labels:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
grouping_variable: str,
|
grouping_variable: list,
|
||||||
labels: dict,
|
labels: dict,
|
||||||
participants_usernames: Collection = None,
|
participants_usernames: Collection = None,
|
||||||
):
|
):
|
||||||
|
@ -101,9 +168,15 @@ class Labels:
|
||||||
self.df_esm_interest = pd.DataFrame()
|
self.df_esm_interest = pd.DataFrame()
|
||||||
self.df_esm_clean = pd.DataFrame()
|
self.df_esm_clean = pd.DataFrame()
|
||||||
|
|
||||||
|
self.df_esm_means = pd.DataFrame()
|
||||||
|
print("Labels initialized.")
|
||||||
|
|
||||||
def set_labels(self):
|
def set_labels(self):
|
||||||
|
print("Querying database ...")
|
||||||
self.df_esm = esm.get_esm_data(self.participants_usernames)
|
self.df_esm = esm.get_esm_data(self.participants_usernames)
|
||||||
|
print("Got ESM data from the DB.")
|
||||||
self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm)
|
self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm)
|
||||||
|
print("ESM data preprocessed.")
|
||||||
if "PANAS" in self.questionnaires:
|
if "PANAS" in self.questionnaires:
|
||||||
self.df_esm_interest = self.df_esm_preprocessed[
|
self.df_esm_interest = self.df_esm_preprocessed[
|
||||||
(
|
(
|
||||||
|
@ -116,6 +189,7 @@ class Labels:
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
|
self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
|
||||||
|
print("ESM data cleaned.")
|
||||||
|
|
||||||
def get_labels(self, questionnaire):
|
def get_labels(self, questionnaire):
|
||||||
if questionnaire == "PANAS":
|
if questionnaire == "PANAS":
|
||||||
|
@ -123,109 +197,131 @@ class Labels:
|
||||||
else:
|
else:
|
||||||
raise KeyError("This questionnaire has not been implemented as a label.")
|
raise KeyError("This questionnaire has not been implemented as a label.")
|
||||||
|
|
||||||
|
def aggregate_labels(self):
|
||||||
class MachineLearningPipeline:
|
print("Aggregating labels ...")
|
||||||
def __init__(
|
self.df_esm_means = (
|
||||||
self,
|
self.df_esm_clean.groupby(
|
||||||
labels_questionnaire,
|
["participant_id", "questionnaire_id"] + self.grouping_variable
|
||||||
labels_scale,
|
|
||||||
data_types,
|
|
||||||
participants_usernames=None,
|
|
||||||
feature_names=None,
|
|
||||||
grouping_variable=None,
|
|
||||||
):
|
|
||||||
if participants_usernames is None:
|
|
||||||
participants_usernames = participants.query_db.get_usernames(
|
|
||||||
collection_start=datetime.date.fromisoformat("2020-08-01")
|
|
||||||
)
|
)
|
||||||
self.participants_usernames = participants_usernames
|
.esm_user_answer_numeric.agg("mean")
|
||||||
self.labels_questionnaire = labels_questionnaire
|
.reset_index()
|
||||||
self.data_types = data_types
|
.rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"})
|
||||||
|
|
||||||
if feature_names is None:
|
|
||||||
self.feature_names = []
|
|
||||||
self.df_features = pd.DataFrame()
|
|
||||||
self.labels_scale = labels_scale
|
|
||||||
self.df_labels = pd.DataFrame()
|
|
||||||
self.grouping_variable = grouping_variable
|
|
||||||
self.df_groups = pd.DataFrame()
|
|
||||||
|
|
||||||
self.model = None
|
|
||||||
self.validation_method = None
|
|
||||||
|
|
||||||
self.df_esm = pd.DataFrame()
|
|
||||||
self.df_esm_preprocessed = pd.DataFrame()
|
|
||||||
self.df_esm_interest = pd.DataFrame()
|
|
||||||
self.df_esm_clean = pd.DataFrame()
|
|
||||||
|
|
||||||
self.df_full_data_daily_means = pd.DataFrame()
|
|
||||||
self.df_esm_daily_means = pd.DataFrame()
|
|
||||||
self.df_proximity_daily_counts = pd.DataFrame()
|
|
||||||
|
|
||||||
# def get_labels(self):
|
|
||||||
# self.df_esm = esm.get_esm_data(self.participants_usernames)
|
|
||||||
# self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm)
|
|
||||||
# if self.labels_questionnaire == "PANAS":
|
|
||||||
# self.df_esm_interest = self.df_esm_preprocessed[
|
|
||||||
# (
|
|
||||||
# self.df_esm_preprocessed["questionnaire_id"]
|
|
||||||
# == QUESTIONNAIRE_IDS.get("PANAS").get("PA")
|
|
||||||
# )
|
|
||||||
# | (
|
|
||||||
# self.df_esm_preprocessed["questionnaire_id"]
|
|
||||||
# == QUESTIONNAIRE_IDS.get("PANAS").get("NA")
|
|
||||||
# )
|
|
||||||
# ]
|
|
||||||
# self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
|
|
||||||
|
|
||||||
# def aggregate_daily(self):
|
|
||||||
# self.df_esm_daily_means = (
|
|
||||||
# self.df_esm_clean.groupby(["participant_id", "date_lj", "questionnaire_id"])
|
|
||||||
# .esm_user_answer_numeric.agg("mean")
|
|
||||||
# .reset_index()
|
|
||||||
# .rename(columns={"esm_user_answer_numeric": "esm_numeric_mean"})
|
|
||||||
# )
|
|
||||||
# self.df_esm_daily_means = (
|
|
||||||
# self.df_esm_daily_means.pivot(
|
|
||||||
# index=["participant_id", "date_lj"],
|
|
||||||
# columns="questionnaire_id",
|
|
||||||
# values="esm_numeric_mean",
|
|
||||||
# )
|
|
||||||
# .reset_index(col_level=1)
|
|
||||||
# .rename(columns=QUESTIONNAIRE_IDS_RENAME)
|
|
||||||
# .set_index(["participant_id", "date_lj"])
|
|
||||||
# )
|
|
||||||
# self.df_full_data_daily_means = self.df_esm_daily_means.copy()
|
|
||||||
# if "proximity" in self.data_types:
|
|
||||||
# self.df_proximity_daily_counts = proximity.count_proximity(
|
|
||||||
# self.df_proximity, ["participant_id", "date_lj"]
|
|
||||||
# )
|
|
||||||
# self.df_full_data_daily_means = self.df_full_data_daily_means.join(
|
|
||||||
# self.df_proximity_daily_counts
|
|
||||||
# )
|
|
||||||
|
|
||||||
def assign_columns(self):
|
|
||||||
self.df_features = self.df_full_data_daily_means[self.feature_names]
|
|
||||||
self.df_labels = self.df_full_data_daily_means[self.labels_scale]
|
|
||||||
if self.grouping_variable:
|
|
||||||
self.df_groups = self.df_full_data_daily_means[self.grouping_variable]
|
|
||||||
else:
|
|
||||||
self.df_groups = None
|
|
||||||
|
|
||||||
def validate_model(self):
|
|
||||||
if self.model is None:
|
|
||||||
raise AttributeError(
|
|
||||||
"Please, specify a machine learning model first, by setting the .model attribute."
|
|
||||||
)
|
|
||||||
if self.validation_method is None:
|
|
||||||
raise AttributeError(
|
|
||||||
"Please, specify a cross validation method first, by setting the .validation_method attribute."
|
|
||||||
)
|
|
||||||
cross_val_score(
|
|
||||||
estimator=self.model,
|
|
||||||
X=self.df_features,
|
|
||||||
y=self.df_labels,
|
|
||||||
groups=self.df_groups,
|
|
||||||
cv=self.validation_method,
|
|
||||||
n_jobs=-1,
|
|
||||||
)
|
)
|
||||||
|
self.df_esm_means = (
|
||||||
|
self.df_esm_means.pivot(
|
||||||
|
index=["participant_id"] + self.grouping_variable,
|
||||||
|
columns="questionnaire_id",
|
||||||
|
values="esm_numeric_mean",
|
||||||
|
)
|
||||||
|
.reset_index(col_level=1)
|
||||||
|
.rename(columns=QUESTIONNAIRE_IDS_RENAME)
|
||||||
|
.set_index(["participant_id"] + self.grouping_variable)
|
||||||
|
)
|
||||||
|
print("Labels aggregated.")
|
||||||
|
|
||||||
|
def get_aggregated_labels(self):
|
||||||
|
return self.df_esm_means
|
||||||
|
|
||||||
|
|
||||||
|
class ModelValidation:
|
||||||
|
def __init__(self, X, y, group_variable=None, cv_name="loso"):
|
||||||
|
self.model = None
|
||||||
|
self.cv = None
|
||||||
|
|
||||||
|
idx_common = X.index.intersection(y.index)
|
||||||
|
self.y = y.loc[idx_common, "NA"]
|
||||||
|
# TODO Handle the case of multiple labels.
|
||||||
|
self.X = X.loc[idx_common]
|
||||||
|
self.groups = self.y.index.get_level_values(group_variable)
|
||||||
|
|
||||||
|
self.cv_name = cv_name
|
||||||
|
print("ModelValidation initialized.")
|
||||||
|
|
||||||
|
def set_cv_method(self):
|
||||||
|
if self.cv_name == "loso":
|
||||||
|
self.cv = LeaveOneGroupOut()
|
||||||
|
self.cv.get_n_splits(X=self.X, y=self.y, groups=self.groups)
|
||||||
|
print("Validation method set.")
|
||||||
|
|
||||||
|
def cross_validate(self):
|
||||||
|
print("Running cross validation ...")
|
||||||
|
if self.model is None:
|
||||||
|
raise TypeError(
|
||||||
|
"Please, specify a machine learning model first, by setting the .model attribute. "
|
||||||
|
"E.g. self.model = sklearn.linear_model.LinearRegression()"
|
||||||
|
)
|
||||||
|
if self.cv is None:
|
||||||
|
raise TypeError(
|
||||||
|
"Please, specify a cross validation method first, by using set_cv_method() first."
|
||||||
|
)
|
||||||
|
if self.X.isna().any().any() or self.y.isna().any().any():
|
||||||
|
raise ValueError(
|
||||||
|
"NaNs were found in either X or y. Please, check your data before continuing."
|
||||||
|
)
|
||||||
|
return cross_val_score(
|
||||||
|
estimator=self.model,
|
||||||
|
X=self.X,
|
||||||
|
y=self.y,
|
||||||
|
groups=self.groups,
|
||||||
|
cv=self.cv,
|
||||||
|
n_jobs=-1,
|
||||||
|
scoring="r2",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def safe_outer_merge_on_index(left, right):
|
||||||
|
if left.empty:
|
||||||
|
return right
|
||||||
|
elif right.empty:
|
||||||
|
return left
|
||||||
|
else:
|
||||||
|
return pd.merge(
|
||||||
|
left,
|
||||||
|
right,
|
||||||
|
how="outer",
|
||||||
|
left_index=True,
|
||||||
|
right_index=True,
|
||||||
|
validate="one_to_one",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def to_csv_with_settings(
|
||||||
|
df: pd.DataFrame, folder: Path, filename_prefix: str, data_type: str
|
||||||
|
) -> None:
|
||||||
|
export_filename = filename_prefix + "_" + data_type + ".csv"
|
||||||
|
full_path = folder / export_filename
|
||||||
|
df.to_csv(
|
||||||
|
path_or_buf=full_path,
|
||||||
|
sep=",",
|
||||||
|
na_rep="NA",
|
||||||
|
header=True,
|
||||||
|
index=False,
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
print("Exported the dataframe to " + str(full_path))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
with open("./config/prox_comm_PANAS_features.yaml", "r") as file:
|
||||||
|
sensor_features_params = yaml.safe_load(file)
|
||||||
|
sensor_features = SensorFeatures(**sensor_features_params)
|
||||||
|
sensor_features.set_sensor_data()
|
||||||
|
sensor_features.calculate_features()
|
||||||
|
|
||||||
|
with open("./config/prox_comm_PANAS_labels.yaml", "r") as file:
|
||||||
|
labels_params = yaml.safe_load(file)
|
||||||
|
labels = Labels(**labels_params)
|
||||||
|
labels.set_labels()
|
||||||
|
labels.aggregate_labels()
|
||||||
|
|
||||||
|
model_validation = ModelValidation(
|
||||||
|
sensor_features.get_features("all", "all"),
|
||||||
|
labels.get_aggregated_labels(),
|
||||||
|
group_variable="participant_id",
|
||||||
|
cv_name="loso",
|
||||||
|
)
|
||||||
|
model_validation.model = linear_model.LinearRegression()
|
||||||
|
model_validation.set_cv_method()
|
||||||
|
model_loso_r2 = model_validation.cross_validate()
|
||||||
|
print(model_loso_r2)
|
||||||
|
print(np.mean(model_loso_r2))
|
||||||
|
|
Loading…
Reference in New Issue