diff --git a/exploration/ml_pipeline_classification.py b/exploration/ml_pipeline_classification.py index 84301b5..7a46401 100644 --- a/exploration/ml_pipeline_classification.py +++ b/exploration/ml_pipeline_classification.py @@ -82,11 +82,15 @@ print(model_input["target"].value_counts()) REMOVE_MEDIUM = True if ("medium" in model_input["target"]) and REMOVE_MEDIUM: model_input = model_input[model_input["target"] != "medium"] + model_input["target"] = ( + model_input["target"].astype(str).apply(lambda x: 0 if x == "low" else 1) + ) +else: + model_input["target"] = model_input["target"].map( + {"low": 0, "medium": 1, "high": 2} + ) print(model_input["target"].value_counts()) -model_input["target"] = ( - model_input["target"].astype(str).apply(lambda x: 0 if x == "low" else 1) -) # %% jupyter={"outputs_hidden": false, "source_hidden": false} # UnderSampling