Add some print statements for monitoring progress.

rapids
junos 2021-08-21 18:54:02 +02:00
parent d19995385d
commit 11381d6447
1 changed file with 11 additions and 0 deletions

View File

@@ -40,14 +40,17 @@ class SensorFeatures:
self.df_proximity = proximity.get_proximity_data(
self.participants_usernames
)
print("Got proximity data from the DB.")
self.df_proximity = helper.get_date_from_timestamp(self.df_proximity)
self.df_proximity = proximity.recode_proximity(self.df_proximity)
if "communication" in self.data_types:
self.df_calls = communication.get_call_data(self.participants_usernames)
self.df_calls = helper.get_date_from_timestamp(self.df_calls)
print("Got calls data from the DB.")
self.df_sms = communication.get_sms_data(self.participants_usernames)
self.df_sms = helper.get_date_from_timestamp(self.df_sms)
print("Got sms data from the DB.")
def get_sensor_data(self, data_type) -> pd.DataFrame:
if data_type == "proximity":
@@ -65,6 +68,7 @@ class SensorFeatures:
self.df_features_all = safe_outer_merge_on_index(
self.df_features_all, self.df_proximity_counts
)
print("Calculated proximity features.")
if "communication" in self.data_types:
self.df_calls_sms = communication.calls_sms_features(
@@ -80,6 +84,7 @@ class SensorFeatures:
inplace=True,
downcast="infer",
)
print("Calculated communication features.")
def get_features(self, data_type, feature_names) -> pd.DataFrame:
if data_type == "proximity":
@@ -122,7 +127,9 @@ class Labels:
def set_labels(self):
self.df_esm = esm.get_esm_data(self.participants_usernames)
print("Got ESM data from the DB.")
self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm)
print("ESM data preprocessed.")
if "PANAS" in self.questionnaires:
self.df_esm_interest = self.df_esm_preprocessed[
(
@@ -135,6 +142,7 @@ class Labels:
)
]
self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
print("ESM data cleaned.")
def get_labels(self, questionnaire):
if questionnaire == "PANAS":
@@ -161,6 +169,7 @@ class Labels:
.rename(columns=QUESTIONNAIRE_IDS_RENAME)
.set_index(["participant_id"] + self.grouping_variable)
)
print("Labels aggregated.")
def get_aggregated_labels(self):
    """Return the aggregated labels held in ``self.df_esm_means``.

    The attribute is handed back as-is (no copy is made), so the caller
    receives the very same object that is stored on the instance.
    """
    aggregated_labels = self.df_esm_means
    return aggregated_labels
@@ -178,11 +187,13 @@ class ModelValidation:
self.groups = self.y.index.get_level_values(group_variable)
self.cv_name = cv_name
print("ModelValidation initialized.")
# Set up the cross-validation splitter chosen by ``self.cv_name`` and store
# it in ``self.cv``. Only the "loso" name is handled in the lines shown here.
def set_cv_method(self):
if self.cv_name == "loso":
# presumably sklearn.model_selection.LeaveOneGroupOut -- confirm the import.
self.cv = LeaveOneGroupOut()
# The value returned by get_n_splits() is not stored or used here.
self.cv.get_n_splits(X=self.X, y=self.y, groups=self.groups)
# Progress message for monitoring a run.
# NOTE(review): indentation is lost in this view, so it cannot be told
# whether this print sits inside the "loso" branch -- confirm in the file.
print("Validation method set.")
def cross_validate(self):
if self.model is None: