rapids/calculatingfeatures/eda_explorer/classify.py

91 lines
2.1 KiB
Python

from sklearn.svm import SVC
import pickle
#docs: http://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html#sklearn.svm.SVC
class SVM:
def __init__(self, C=1.0,beta=0.0, kernel='linear', poly_degree=3, max_iter=-1, tol=0.001):
#possible kernels: linear, 'poly', 'rbf', 'sigmoid', 'precomputed' or a callable
#data features
self.n_features = None
self.train_X = []
self.train_Y = []
self.val_X = []
self.val_Y = []
self.test_X = []
self.test_Y = []
#classifier features
self.C = C
self.beta = beta
self.kernel = kernel
self.poly_degree = poly_degree
self.max_iter = max_iter
self.tolerance = tol
self.classifier = None
#SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3,
# gamma=0.0, kernel='rbf', max_iter=-1, probability=False,
# random_state=None, shrinking=True, tol=0.001, verbose=False)
def setTrainData(self, X, Y):
self.train_X = X
self.train_Y = Y
self.n_features = self.train_X.shape[1]
def setTestData(self, X, Y):
self.test_X = X
self.test_Y = Y
def setC(self, c):
self.C = c
def setBeta(self, beta):
if beta is None:
self.beta = 0.0
else:
self.beta=beta
def setKernel(self, kernel, poly_degree=3):
self.kernel = kernel
self.poly_degree = poly_degree
def setValData(self, X, Y):
self.val_X = X
self.val_Y = Y
def train(self):
self.classifier = SVC(C=self.C, kernel=self.kernel, gamma=self.beta, degree=self.poly_degree, max_iter=self.max_iter,
tol=self.tolerance)
self.classifier.fit(self.train_X, self.train_Y)
def predict(self, X):
return self.classifier.predict(X)
def getScore(self, X, Y):
#returns accuracy
return self.classifier.score(X, Y)
def getNumSupportVectors(self):
return self.classifier.n_support_
def getHingeLoss(self,X,Y):
preds = self.predict(X)
hinge_inner_func = 1.0 - preds*Y
hinge_inner_func = [max(0,x) for x in hinge_inner_func]
return sum(hinge_inner_func)
def saveClassifierToFile(self, filepath):
s = pickle.dumps(self.classifier)
f = open(filepath, 'wb')
f.write(s)
def loadClassifierFromFile(self, filepath):
f2 = open(filepath, 'rb')
s2 = f2.read()
self.classifier = pickle.loads(s2)