diff --git a/exploration/ml_pipeline_classification_composite.py b/exploration/ml_pipeline_classification_composite.py index 97ca903..39330d9 100644 --- a/exploration/ml_pipeline_classification_composite.py +++ b/exploration/ml_pipeline_classification_composite.py @@ -102,8 +102,9 @@ model_input["target"], edges = pd.cut( ) # ['low', 'medium', 'high'] print(model_input["target"].value_counts()) REMOVE_MEDIUM = True -if ("medium" in model_input["target"]) and REMOVE_MEDIUM: - model_input = model_input[model_input["target"] != "medium"] +if REMOVE_MEDIUM: + if "medium" in model_input["target"]: + model_input = model_input[model_input["target"] != "medium"] model_input["target"] = ( model_input["target"].astype(str).apply(lambda x: 0 if x == "low" else 1) )