diff --git a/machine_learning/pipeline.py b/machine_learning/pipeline.py index 572c7a5..223131c 100644 --- a/machine_learning/pipeline.py +++ b/machine_learning/pipeline.py @@ -40,8 +40,10 @@ class SensorFeatures: self.df_proximity = proximity.recode_proximity(self.df_proximity) def get_sensor_data(self, data_type) -> pd.DataFrame: - # TODO implement the getter (Check if it has been set.) - return self.df_proximity + if data_type == "proximity": + return self.df_proximity + else: + raise KeyError("This data type has not been implemented.") def calculate_features(self): if "proximity" in self.data_types: @@ -50,9 +52,13 @@ class SensorFeatures: ) # TODO Think about joining dataframes. - def get_features(self, data_type) -> pd.DataFrame: - # TODO implement the getter (Check if it has been set.) - return self.df_proximity_counts + def get_features(self, data_type, feature_names) -> pd.DataFrame: + if data_type == "proximity": + if feature_names == "all": + feature_names = proximity.FEATURES_PROXIMITY + return self.df_proximity_counts[feature_names] + else: + raise KeyError("This data type has not been implemented.") class MachineLearningPipeline: