diff --git a/machine_learning/pipeline.py b/machine_learning/pipeline.py index efc4cb1..bada6df 100644 --- a/machine_learning/pipeline.py +++ b/machine_learning/pipeline.py @@ -171,9 +171,10 @@ class ModelValidation: self.model = None self.cv = None - self.y = y["NA"] + idx_common = X.index.intersection(y.index) + self.y = y.loc[idx_common, "NA"] # TODO Handle the case of multiple labels. - self.X = X.loc[self.y.index] + self.X = X.loc[idx_common] self.groups = self.y.index.get_level_values(group_variable) self.cv_name = cv_name