Merge features into a common df.
But first, group communication by the grouping_variable.
rapids
parent
72b16af75c
commit
08fdec34f1
|
@ -159,11 +159,15 @@ from machine_learning import pipeline
|
||||||
# %%
|
# %%
|
||||||
with open("../machine_learning/config/minimal_features.yaml", "r") as file:
|
with open("../machine_learning/config/minimal_features.yaml", "r") as file:
|
||||||
sensor_features_params = yaml.safe_load(file)
|
sensor_features_params = yaml.safe_load(file)
|
||||||
|
print(sensor_features_params)
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
sensor_features = pipeline.SensorFeatures(**sensor_features_params)
|
sensor_features = pipeline.SensorFeatures(**sensor_features_params)
|
||||||
sensor_features.data_types
|
sensor_features.data_types
|
||||||
|
|
||||||
|
# %%
|
||||||
|
sensor_features.data_types = ["proximity", "communication"]
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
sensor_features.get_sensor_data("proximity")
|
sensor_features.get_sensor_data("proximity")
|
||||||
|
|
||||||
|
@ -179,6 +183,12 @@ sensor_features.calculate_features()
|
||||||
# %%
|
# %%
|
||||||
sensor_features.get_features("proximity", "all")
|
sensor_features.get_features("proximity", "all")
|
||||||
|
|
||||||
|
# %%
|
||||||
|
sensor_features.get_features("communication", "all")
|
||||||
|
|
||||||
|
# %%
|
||||||
|
sensor_features.get_features("all", "all")
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
with open("../machine_learning/config/minimal_labels.yaml", "r") as file:
|
with open("../machine_learning/config/minimal_labels.yaml", "r") as file:
|
||||||
labels_params = yaml.safe_load(file)
|
labels_params = yaml.safe_load(file)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
grouping_variable: date_lj
|
grouping_variable: [date_lj]
|
||||||
features:
|
features:
|
||||||
proximity:
|
proximity:
|
||||||
all
|
all
|
||||||
|
|
|
@ -12,7 +12,7 @@ from machine_learning import QUESTIONNAIRE_IDS, QUESTIONNAIRE_IDS_RENAME
|
||||||
class SensorFeatures:
|
class SensorFeatures:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
grouping_variable: str,
|
grouping_variable: list,
|
||||||
features: dict,
|
features: dict,
|
||||||
participants_usernames: Collection = None,
|
participants_usernames: Collection = None,
|
||||||
):
|
):
|
||||||
|
@ -26,6 +26,8 @@ class SensorFeatures:
|
||||||
)
|
)
|
||||||
self.participants_usernames = participants_usernames
|
self.participants_usernames = participants_usernames
|
||||||
|
|
||||||
|
self.df_features_all = pd.DataFrame()
|
||||||
|
|
||||||
self.df_proximity = pd.DataFrame()
|
self.df_proximity = pd.DataFrame()
|
||||||
self.df_proximity_counts = pd.DataFrame()
|
self.df_proximity_counts = pd.DataFrame()
|
||||||
|
|
||||||
|
@ -58,13 +60,21 @@ class SensorFeatures:
|
||||||
def calculate_features(self):
|
def calculate_features(self):
|
||||||
if "proximity" in self.data_types:
|
if "proximity" in self.data_types:
|
||||||
self.df_proximity_counts = proximity.count_proximity(
|
self.df_proximity_counts = proximity.count_proximity(
|
||||||
self.df_proximity, ["participant_id", self.grouping_variable]
|
self.df_proximity, self.grouping_variable
|
||||||
)
|
)
|
||||||
|
self.df_features_all = safe_outer_merge_on_index(
|
||||||
|
self.df_features_all, self.df_proximity_counts
|
||||||
|
)
|
||||||
|
|
||||||
if "communication" in self.data_types:
|
if "communication" in self.data_types:
|
||||||
self.df_calls_sms = communication.calls_sms_features(
|
self.df_calls_sms = communication.calls_sms_features(
|
||||||
df_calls=self.df_calls, df_sms=self.df_sms
|
df_calls=self.df_calls,
|
||||||
|
df_sms=self.df_sms,
|
||||||
|
group_by=self.grouping_variable,
|
||||||
|
)
|
||||||
|
self.df_features_all = safe_outer_merge_on_index(
|
||||||
|
self.df_features_all, self.df_calls_sms
|
||||||
)
|
)
|
||||||
# TODO Think about joining dataframes.
|
|
||||||
|
|
||||||
def get_features(self, data_type, feature_names) -> pd.DataFrame:
|
def get_features(self, data_type, feature_names) -> pd.DataFrame:
|
||||||
if data_type == "proximity":
|
if data_type == "proximity":
|
||||||
|
@ -75,6 +85,8 @@ class SensorFeatures:
|
||||||
if feature_names == "all":
|
if feature_names == "all":
|
||||||
feature_names = communication.FEATURES_CALLS_SMS_ALL
|
feature_names = communication.FEATURES_CALLS_SMS_ALL
|
||||||
return self.df_calls_sms[feature_names]
|
return self.df_calls_sms[feature_names]
|
||||||
|
elif data_type == "all":
|
||||||
|
return self.df_features_all
|
||||||
else:
|
else:
|
||||||
raise KeyError("This data type has not been implemented.")
|
raise KeyError("This data type has not been implemented.")
|
||||||
|
|
||||||
|
@ -124,6 +136,22 @@ class Labels:
|
||||||
raise KeyError("This questionnaire has not been implemented as a label.")
|
raise KeyError("This questionnaire has not been implemented as a label.")
|
||||||
|
|
||||||
|
|
||||||
|
def safe_outer_merge_on_index(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
    """Outer-merge two frames on their indexes, tolerating empty inputs.

    If ``left`` is empty, ``right`` is returned unchanged (the very same
    object, not a copy). Otherwise, if ``right`` is empty, ``left`` is
    returned unchanged. Only when both frames hold data are they combined
    with an outer join on the index.

    Args:
        left: Frame accumulated so far; may be an empty ``pd.DataFrame()``.
        right: Frame to merge in; may be empty.

    Returns:
        The non-empty input when the other one is empty, else the
        outer-joined frame.

    Raises:
        pandas.errors.MergeError: If both frames are non-empty and either
            index contains duplicate keys, because the merge is validated
            as one-to-one.
    """
    # Short-circuit on empty inputs so an empty accumulator can be used as
    # the starting point before any features have been merged in.
    if left.empty:
        return right
    if right.empty:
        return left
    return pd.merge(
        left,
        right,
        how="outer",
        left_index=True,
        right_index=True,
        validate="one_to_one",
    )
|
||||||
|
|
||||||
|
|
||||||
class MachineLearningPipeline:
|
class MachineLearningPipeline:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Reference in New Issue