232 lines
8.9 KiB
Python
232 lines
8.9 KiB
Python
import datetime
|
|
import warnings
|
|
from pathlib import Path
|
|
from typing import Collection
|
|
|
|
import pandas as pd
|
|
from pyprojroot import here
|
|
|
|
import participants.query_db
|
|
from features import communication, helper, proximity
|
|
from machine_learning.helper import (
|
|
read_csv_with_settings,
|
|
safe_outer_merge_on_index,
|
|
to_csv_with_settings,
|
|
)
|
|
|
|
# Message shown when no participants label has been set. It is emitted as a
# UserWarning by SensorFeatures.construct_export_path() and used as the
# ValueError text in SensorFeatures.calculate_features().
WARNING_PARTICIPANTS_LABEL = (
    "Before calculating features, please set participants label using "
    "self.set_participants_label() to be used as a filename prefix when "
    "exporting data. The filename will be of the form: "
    "%participants_label_%grouping_variable_%data_type.csv"
)
|
|
|
|
|
|
class SensorFeatures:
    """
    A class to represent all sensor (AWARE) features.

    Attributes
    ----------
    grouping_variable_name: str
        The name of the variable by which to group (segment) data, e.g. date_lj.
    grouping_variable: list
        The same name wrapped in a single-element list; this is what is passed
        to the feature functions and the csv helpers.
    data_types: Collection
        The keys of the `features` dict given to the constructor, e.g. "proximity"
        or "communication". Only these keys are used by this class.
    participants_usernames: Collection
        A list of usernames for which to calculate features.
        If None was given, the usernames from the main part of the study
        (collection start 2020-08-01) are queried and the label is set to "all".

    Methods
    -------
    set_sensor_data():
        Query the database for data types defined by self.data_types.
    get_sensor_data(data_type): pd.DataFrame
        Returns the dataframe stored for the specified data_type.
    calculate_features(cached=True):
        Calls appropriate functions from features/ and joins them in a single dataframe, df_features_all.
        Previously exported feature files are reused unless cached=False.
    get_features(data_type, feature_names): pd.DataFrame
        Returns the dataframe of specified features for selected sensor.
    construct_export_path():
        Construct a path for exporting the features as csv files.
    set_participants_label(label):
        Sets a label for the usernames subset. This is used to distinguish feature exports.
    """

    def __init__(
        self,
        grouping_variable: str,
        features: dict,
        participants_usernames: Collection = None,
    ) -> None:
        """
        Specifies the grouping variable and usernames for which to calculate features.

        Sets other (implicit) attributes used in other methods.
        If participants_usernames=None, this queries the usernames which belong to the main part of the study,
        i.e. from 2020-08-01 on, and sets the participants label to "all".

        Parameters
        ----------
        grouping_variable: str
            The name of the variable by which to group (segment) data, e.g. date_lj.
        features: dict
            A dictionary of sensors (data types) and features to calculate.
            See config/minimal_features.yaml for an example.
        participants_usernames: Collection
            A list of usernames for which to calculate features.
            If None, use all participants of the main part of the study.

        Returns
        -------
        None
        """
        self.grouping_variable_name = grouping_variable
        # The feature functions and csv helpers receive the grouping as a list.
        self.grouping_variable = [grouping_variable]

        # Only the sensor names (keys) are used; the per-sensor feature lists
        # in `features` are not read by this class.
        self.data_types = features.keys()

        self.participants_label: str = ""
        if participants_usernames is None:
            participants_usernames = participants.query_db.get_usernames(
                collection_start=datetime.date.fromisoformat("2020-08-01")
            )
            self.participants_label = "all"
        self.participants_usernames = participants_usernames

        self.df_features_all = pd.DataFrame()

        self.df_proximity = pd.DataFrame()
        self.df_proximity_counts = pd.DataFrame()

        self.df_calls = pd.DataFrame()
        self.df_sms = pd.DataFrame()
        self.df_calls_sms = pd.DataFrame()

        self.folder: Path = Path()
        self.filename_prefix = ""
        # Warns if no participants label is set yet (i.e. explicit usernames given).
        self.construct_export_path()
        print("SensorFeatures initialized.")

    def set_sensor_data(self) -> None:
        """
        Query the database for the raw sensor data of every requested data type.

        For "proximity", sets self.df_proximity, passed through
        helper.get_date_from_timestamp and proximity.recode_proximity.
        For "communication", sets self.df_calls and self.df_sms, each passed
        through helper.get_date_from_timestamp.
        """
        print("Querying database ...")
        if "proximity" in self.data_types:
            self.df_proximity = proximity.get_proximity_data(
                self.participants_usernames
            )
            print("Got proximity data from the DB.")
            self.df_proximity = helper.get_date_from_timestamp(self.df_proximity)
            self.df_proximity = proximity.recode_proximity(self.df_proximity)
        if "communication" in self.data_types:
            self.df_calls = communication.get_call_data(self.participants_usernames)
            self.df_calls = helper.get_date_from_timestamp(self.df_calls)
            print("Got calls data from the DB.")

            self.df_sms = communication.get_sms_data(self.participants_usernames)
            self.df_sms = helper.get_date_from_timestamp(self.df_sms)
            print("Got sms data from the DB.")

    def get_sensor_data(self, data_type: str) -> pd.DataFrame:
        """
        Return the dataframe stored for the specified data type.

        Parameters
        ----------
        data_type: str
            "proximity": the proximity data set by set_sensor_data().
            "calls" / "sms": the call / sms data set by set_sensor_data().
            "communication": self.df_calls_sms, which holds the calculated
            communication *features* and is only filled by calculate_features().

        Returns
        -------
        pd.DataFrame

        Raises
        ------
        KeyError
            If data_type is not one of the values above.
        """
        if data_type == "proximity":
            return self.df_proximity
        elif data_type == "communication":
            # NOTE(review): this returns the calculated features, not the raw
            # call/sms data; kept as-is for backward compatibility.
            return self.df_calls_sms
        elif data_type == "calls":
            return self.df_calls
        elif data_type == "sms":
            return self.df_sms
        else:
            raise KeyError("This data type has not been implemented.")

    def _load_or_calculate(
        self, data_type: str, description: str, calculate, cached: bool
    ) -> pd.DataFrame:
        """
        Return features for one data type, read from csv or freshly calculated.

        Parameters
        ----------
        data_type: str
            Short tag used in the exported filename, e.g. "prox" or "comm".
        description: str
            Human-readable name used in the progress messages.
        calculate: callable
            Zero-argument function that computes the features dataframe.
        cached: bool
            If True, first try to read a previously exported csv file.

        Returns
        -------
        pd.DataFrame
            The features. Freshly calculated features are also exported to csv.
        """
        if cached:
            try:
                df = read_csv_with_settings(
                    self.folder,
                    self.filename_prefix,
                    data_type=data_type,
                    grouping_variable=self.grouping_variable,
                )
            except FileNotFoundError:
                pass  # No exported file yet: fall through to recalculation.
            else:
                print(f"Read {description} features from the file.")
                return df

        df = calculate()
        print(f"Calculated {description} features.")
        to_csv_with_settings(
            df,
            self.folder,
            self.filename_prefix,
            data_type=data_type,
        )
        return df

    def calculate_features(self, cached=True) -> None:
        """
        Calculate (or load) features for every requested data type and join them.

        Parameters
        ----------
        cached: bool
            If True (default), reuse previously exported feature files when they
            exist; if False, always recalculate and overwrite the exports.

        Raises
        ------
        ValueError
            If no participants label has been set (see set_participants_label()).
        """
        print("Calculating features ...")
        if not self.participants_label:
            raise ValueError(WARNING_PARTICIPANTS_LABEL)
        self.df_features_all = pd.DataFrame()

        # Each data type is merged only after it was read or calculated
        # successfully, so a failure never merges a stale frame.
        if "proximity" in self.data_types:
            self.df_proximity_counts = self._load_or_calculate(
                "prox",
                "proximity",
                lambda: proximity.count_proximity(
                    self.df_proximity, self.grouping_variable
                ),
                cached,
            )
            self.df_features_all = safe_outer_merge_on_index(
                self.df_features_all, self.df_proximity_counts
            )

        if "communication" in self.data_types:
            self.df_calls_sms = self._load_or_calculate(
                "comm",
                "communication",
                lambda: communication.calls_sms_features(
                    df_calls=self.df_calls,
                    df_sms=self.df_sms,
                    group_by=self.grouping_variable,
                ),
                cached,
            )
            self.df_features_all = safe_outer_merge_on_index(
                self.df_features_all, self.df_calls_sms
            )

        # Rows missing from one data type after the outer merge get that data
        # type's fill values. NOTE(review): pandas >= 2.1 deprecates the
        # `downcast` keyword of fillna; revisit when upgrading pandas.
        self.df_features_all.fillna(
            value=proximity.FILL_NA_PROXIMITY, inplace=True, downcast="infer",
        )
        self.df_features_all.fillna(
            value=communication.FILL_NA_CALLS_SMS_ALL, inplace=True, downcast="infer",
        )

    @staticmethod
    def _requests_all(feature_names) -> bool:
        """Return True if feature_names is the literal string "all"."""
        # feature_names may also be a list/Index/array of column names; comparing
        # those to a str with == is element-wise and ambiguous in an `if`.
        return isinstance(feature_names, str) and feature_names == "all"

    def get_features(self, data_type, feature_names) -> pd.DataFrame:
        """
        Return the specified features for the selected sensor.

        Parameters
        ----------
        data_type: str
            "proximity", "communication", or "all" (the joined features; in this
            case feature_names is ignored).
        feature_names: str or list-like
            Column name(s) to select, or "all" for the sensor's full feature list.

        Returns
        -------
        pd.DataFrame

        Raises
        ------
        KeyError
            If data_type is not implemented.
        """
        if data_type == "proximity":
            if self._requests_all(feature_names):
                feature_names = proximity.FEATURES_PROXIMITY
            return self.df_proximity_counts[feature_names]
        elif data_type == "communication":
            if self._requests_all(feature_names):
                feature_names = communication.FEATURES_CALLS_SMS_ALL
            return self.df_calls_sms[feature_names]
        elif data_type == "all":
            return self.df_features_all
        else:
            raise KeyError("This data type has not been implemented.")

    def construct_export_path(self) -> None:
        """
        Set self.folder and self.filename_prefix used for exporting features.

        The prefix is "<participants_label>_<grouping_variable_name>". Emits a
        UserWarning if no participants label has been set.
        """
        if not self.participants_label:
            warnings.warn(WARNING_PARTICIPANTS_LABEL, UserWarning)
        self.folder = here("machine_learning/intermediate_results/features", warn=True)
        self.filename_prefix = (
            self.participants_label + "_" + self.grouping_variable_name
        )

    def set_participants_label(self, label: str) -> None:
        """Set the participants label and rebuild the export path from it."""
        self.participants_label = label
        self.construct_export_path()