Add some print statements for monitoring progress.
parent
d19995385d
commit
11381d6447
|
@ -40,14 +40,17 @@ class SensorFeatures:
|
|||
self.df_proximity = proximity.get_proximity_data(
|
||||
self.participants_usernames
|
||||
)
|
||||
print("Got proximity data from the DB.")
|
||||
self.df_proximity = helper.get_date_from_timestamp(self.df_proximity)
|
||||
self.df_proximity = proximity.recode_proximity(self.df_proximity)
|
||||
if "communication" in self.data_types:
|
||||
self.df_calls = communication.get_call_data(self.participants_usernames)
|
||||
self.df_calls = helper.get_date_from_timestamp(self.df_calls)
|
||||
print("Got calls data from the DB.")
|
||||
|
||||
self.df_sms = communication.get_sms_data(self.participants_usernames)
|
||||
self.df_sms = helper.get_date_from_timestamp(self.df_sms)
|
||||
print("Got sms data from the DB.")
|
||||
|
||||
def get_sensor_data(self, data_type) -> pd.DataFrame:
|
||||
if data_type == "proximity":
|
||||
|
@ -65,6 +68,7 @@ class SensorFeatures:
|
|||
self.df_features_all = safe_outer_merge_on_index(
|
||||
self.df_features_all, self.df_proximity_counts
|
||||
)
|
||||
print("Calculated proximity features.")
|
||||
|
||||
if "communication" in self.data_types:
|
||||
self.df_calls_sms = communication.calls_sms_features(
|
||||
|
@ -80,6 +84,7 @@ class SensorFeatures:
|
|||
inplace=True,
|
||||
downcast="infer",
|
||||
)
|
||||
print("Calculated communication features.")
|
||||
|
||||
def get_features(self, data_type, feature_names) -> pd.DataFrame:
|
||||
if data_type == "proximity":
|
||||
|
@ -122,7 +127,9 @@ class Labels:
|
|||
|
||||
def set_labels(self):
|
||||
self.df_esm = esm.get_esm_data(self.participants_usernames)
|
||||
print("Got ESM data from the DB.")
|
||||
self.df_esm_preprocessed = esm.preprocess_esm(self.df_esm)
|
||||
print("ESM data preprocessed.")
|
||||
if "PANAS" in self.questionnaires:
|
||||
self.df_esm_interest = self.df_esm_preprocessed[
|
||||
(
|
||||
|
@ -135,6 +142,7 @@ class Labels:
|
|||
)
|
||||
]
|
||||
self.df_esm_clean = esm.clean_up_esm(self.df_esm_interest)
|
||||
print("ESM data cleaned.")
|
||||
|
||||
def get_labels(self, questionnaire):
|
||||
if questionnaire == "PANAS":
|
||||
|
@ -161,6 +169,7 @@ class Labels:
|
|||
.rename(columns=QUESTIONNAIRE_IDS_RENAME)
|
||||
.set_index(["participant_id"] + self.grouping_variable)
|
||||
)
|
||||
print("Labels aggregated.")
|
||||
|
||||
def get_aggregated_labels(self):
|
||||
return self.df_esm_means
|
||||
|
@ -178,11 +187,13 @@ class ModelValidation:
|
|||
self.groups = self.y.index.get_level_values(group_variable)
|
||||
|
||||
self.cv_name = cv_name
|
||||
print("ModelValidation initialized.")
|
||||
|
||||
def set_cv_method(self):
|
||||
if self.cv_name == "loso":
|
||||
self.cv = LeaveOneGroupOut()
|
||||
self.cv.get_n_splits(X=self.X, y=self.y, groups=self.groups)
|
||||
print("Validation method set.")
|
||||
|
||||
def cross_validate(self):
|
||||
if self.model is None:
|
||||
|
|
Loading…
Reference in New Issue