Add some print statements for monitoring progress.
parent
d19995385d
commit
11381d6447
|
@ -40,14 +40,17 @@ class SensorFeatures:
|
||||||
self.df_proximity = proximity.get_proximity_data(
|
self.df_proximity = proximity.get_proximity_data(
|
||||||
self.participants_usernames
|
self.participants_usernames
|
||||||
)
|
)
|
||||||
|
print("Got proximity data from the DB.")
|
||||||
self.df_proximity = helper.get_date_from_timestamp(self.df_proximity)
|
self.df_proximity = helper.get_date_from_timestamp(self.df_proximity)
|
||||||
self.df_proximity = proximity.recode_proximity(self.df_proximity)
|
self.df_proximity = proximity.recode_proximity(self.df_proximity)
|
||||||
if "communication" in self.data_types:
|
if "communication" in self.data_types:
|
||||||
self.df_calls = communication.get_call_data(self.participants_usernames)
|
self.df_calls = communication.get_call_data(self.participants_usernames)
|
||||||
self.df_calls = helper.get_date_from_timestamp(self.df_calls)
|
self.df_calls = helper.get_date_from_timestamp(self.df_calls)
|
||||||
|
print("Got calls data from the DB.")
|
||||||
|
|
||||||
self.df_sms = communication.get_sms_data(self.participants_usernames)
|
self.df_sms = communication.get_sms_data(self.participants_usernames)
|
||||||
self.df_sms = helper.get_date_from_timestamp(self.df_sms)
|
self.df_sms = helper.get_date_from_timestamp(self.df_sms)
|
||||||
|
print("Got sms data from the DB.")
|
||||||
|
|
||||||
def get_sensor_data(self, data_type) -> pd.DataFrame:
|
def get_sensor_data(self, data_type) -> pd.DataFrame:
|
||||||
if data_type == "proximity":
|
if data_type == "proximity":
|
||||||
|
@ -65,6 +68,7 @@ class SensorFeatures:
|
||||||
self.df_features_all = safe_outer_merge_on_index(
|
self.df_features_all = safe_outer_merge_on_index(
|
||||||
self.df_features_all, self.df_proximity_counts
|
self.df_features_all, self.df_proximity_counts
|
||||||
)
|
)
|
||||||
|
print("Calculated proximity features.")
|
||||||
|
|
||||||
if "communication" in self.data_types:
|
if "communication" in self.data_types:
|
||||||
self.df_calls_sms = communication.calls_sms_features(
|
self.df_calls_sms = communication.calls_sms_features(
|
||||||
|
@ -80,6 +84,7 @@ class SensorFeatures:
|
||||||
inplace=True,
|
inplace=True,
|
||||||
downcast="infer",
|
downcast="infer",
|
||||||
)
|
)
|
||||||
|
print("Calculated communication features.")
|
||||||
|
|
||||||
def get_features(self, data_type, feature_names) -> pd.DataFrame:
|
def get_features(self, data_type, feature_names) -> pd.DataFrame:
|
||||||
if data_type == "proximity":
|
if data_type == "proximity":
|
||||||
|
@ -122,7 +127,9 @@ class Labels:
|
||||||
|
|
||||||
def set_labels(self):
|
def set_labels(self):
|
||||||
self.df_esm = esm.get_esm_data(self.participants_usernames)
|
self.df_esm = esm.get_esm_data(self.participants_usernames)
|
||||||
|
print("Got ESM data from the DB.")
|
||||||
self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm)
|
self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm)
|
||||||
|
print("ESM data preprocessed.")
|
||||||
if "PANAS" in self.questionnaires:
|
if "PANAS" in self.questionnaires:
|
||||||
self.df_esm_interest = self.df_esm_preprocessed[
|
self.df_esm_interest = self.df_esm_preprocessed[
|
||||||
(
|
(
|
||||||
|
@ -135,6 +142,7 @@ class Labels:
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
|
self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
|
||||||
|
print("ESM data cleaned.")
|
||||||
|
|
||||||
def get_labels(self, questionnaire):
|
def get_labels(self, questionnaire):
|
||||||
if questionnaire == "PANAS":
|
if questionnaire == "PANAS":
|
||||||
|
@ -161,6 +169,7 @@ class Labels:
|
||||||
.rename(columns=QUESTIONNAIRE_IDS_RENAME)
|
.rename(columns=QUESTIONNAIRE_IDS_RENAME)
|
||||||
.set_index(["participant_id"] + self.grouping_variable)
|
.set_index(["participant_id"] + self.grouping_variable)
|
||||||
)
|
)
|
||||||
|
print("Labels aggregated.")
|
||||||
|
|
||||||
def get_aggregated_labels(self):
|
def get_aggregated_labels(self):
|
||||||
return self.df_esm_means
|
return self.df_esm_means
|
||||||
|
@ -178,11 +187,13 @@ class ModelValidation:
|
||||||
self.groups = self.y.index.get_level_values(group_variable)
|
self.groups = self.y.index.get_level_values(group_variable)
|
||||||
|
|
||||||
self.cv_name = cv_name
|
self.cv_name = cv_name
|
||||||
|
print("ModelValidation initialized.")
|
||||||
|
|
||||||
def set_cv_method(self):
|
def set_cv_method(self):
|
||||||
if self.cv_name == "loso":
|
if self.cv_name == "loso":
|
||||||
self.cv = LeaveOneGroupOut()
|
self.cv = LeaveOneGroupOut()
|
||||||
self.cv.get_n_splits(X=self.X, y=self.y, groups=self.groups)
|
self.cv.get_n_splits(X=self.X, y=self.y, groups=self.groups)
|
||||||
|
print("Validation method set.")
|
||||||
|
|
||||||
def cross_validate(self):
|
def cross_validate(self):
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
|
|
Loading…
Reference in New Issue