Provide data instead of csv input.

master
junos 2023-05-10 15:20:33 +02:00
parent cd5d8b6a10
commit caeaf03239
3 changed files with 55 additions and 17 deletions

View File

@ -0,0 +1,6 @@
<component name="ProjectCodeStyleConfiguration">
<code_scheme name="Project" version="173">
<option name="RIGHT_MARGIN" value="150" />
<option name="SOFT_MARGINS" value="88" />
</code_scheme>
</component>

View File

@ -0,0 +1,5 @@
<component name="ProjectCodeStyleConfiguration">
<state>
<option name="USE_PER_PROJECT_SETTINGS" value="true" />
</state>
</component>

View File

@ -136,14 +136,11 @@ def prepare_regression_model_input(model_input, cv_method="logo"):
return train_x, data_y, data_groups
def run_all_regression_models(input_csv):
# Prepare data
data_x, data_y, data_groups = prepare_regression_model_input(input_csv)
def run_all_regression_models(train_x, data_y, data_groups):
# Prepare cross validation
logo = LeaveOneGroupOut()
logo.get_n_splits(
data_x,
train_x,
data_y,
groups=data_groups,
)
@ -155,7 +152,7 @@ def run_all_regression_models(input_csv):
dummy_regr = DummyRegressor(strategy="mean")
dummy_regr_scores = cross_validate(
dummy_regr,
X=data_x,
X=train_x,
y=data_y,
groups=data_groups,
cv=logo,
@ -173,7 +170,7 @@ def run_all_regression_models(input_csv):
lin_reg_rapids = linear_model.LinearRegression()
lin_reg_scores = cross_validate(
lin_reg_rapids,
X=data_x,
X=train_x,
y=data_y,
groups=data_groups,
cv=logo,
@ -191,7 +188,7 @@ def run_all_regression_models(input_csv):
ridge_reg = linear_model.Ridge(alpha=0.5)
ridge_reg_scores = cross_validate(
ridge_reg,
X=data_x,
X=train_x,
y=data_y,
groups=data_groups,
cv=logo,
@ -208,7 +205,7 @@ def run_all_regression_models(input_csv):
lasso_reg = linear_model.Lasso(alpha=0.1)
lasso_reg_score = cross_validate(
lasso_reg,
X=data_x,
X=train_x,
y=data_y,
groups=data_groups,
cv=logo,
@ -225,7 +222,7 @@ def run_all_regression_models(input_csv):
bayesian_ridge_reg = linear_model.BayesianRidge()
bayesian_ridge_reg_score = cross_validate(
bayesian_ridge_reg,
X=data_x,
X=train_x,
y=data_y,
groups=data_groups,
cv=logo,
@ -242,7 +239,7 @@ def run_all_regression_models(input_csv):
ransac_reg = linear_model.RANSACRegressor()
ransac_reg_score = cross_validate(
ransac_reg,
X=data_x,
X=train_x,
y=data_y,
groups=data_groups,
cv=logo,
@ -258,7 +255,13 @@ def run_all_regression_models(input_csv):
svr = svm.SVR()
svr_score = cross_validate(
svr, X=data_x, y=data_y, groups=data_groups, cv=logo, n_jobs=-1, scoring=metrics
svr,
X=train_x,
y=data_y,
groups=data_groups,
cv=logo,
n_jobs=-1,
scoring=metrics,
)
print("Support vector regression")
@ -270,7 +273,7 @@ def run_all_regression_models(input_csv):
kridge = kernel_ridge.KernelRidge()
kridge_score = cross_validate(
kridge,
X=data_x,
X=train_x,
y=data_y,
groups=data_groups,
cv=logo,
@ -286,7 +289,13 @@ def run_all_regression_models(input_csv):
gpr = gaussian_process.GaussianProcessRegressor()
gpr_score = cross_validate(
gpr, X=data_x, y=data_y, groups=data_groups, cv=logo, n_jobs=-1, scoring=metrics
gpr,
X=train_x,
y=data_y,
groups=data_groups,
cv=logo,
n_jobs=-1,
scoring=metrics,
)
print("Gaussian Process Regression")
@ -297,7 +306,13 @@ def run_all_regression_models(input_csv):
rfr = ensemble.RandomForestRegressor(max_features=0.3, n_jobs=-1)
rfr_score = cross_validate(
rfr, X=data_x, y=data_y, groups=data_groups, cv=logo, n_jobs=-1, scoring=metrics
rfr,
X=train_x,
y=data_y,
groups=data_groups,
cv=logo,
n_jobs=-1,
scoring=metrics,
)
print("Random Forest Regression")
@ -308,7 +323,13 @@ def run_all_regression_models(input_csv):
xgb = XGBRegressor()
xgb_score = cross_validate(
xgb, X=data_x, y=data_y, groups=data_groups, cv=logo, n_jobs=-1, scoring=metrics
xgb,
X=train_x,
y=data_y,
groups=data_groups,
cv=logo,
n_jobs=-1,
scoring=metrics,
)
print("XGBoost Regressor")
@ -319,7 +340,13 @@ def run_all_regression_models(input_csv):
ada = ensemble.AdaBoostRegressor()
ada_score = cross_validate(
ada, X=data_x, y=data_y, groups=data_groups, cv=logo, n_jobs=-1, scoring=metrics
ada,
X=train_x,
y=data_y,
groups=data_groups,
cv=logo,
n_jobs=-1,
scoring=metrics,
)
print("ADA Boost Regressor")