import numpy as np
# NOTE(review): LabelBinarizer is no longer used by the classes below; the
# import is kept only in case other code imports it from this module.
from sklearn.preprocessing import LabelBinarizer


class GaussianNB(object):
    """Naive Bayes classifier for continuous features.

    Within each class, every feature is modelled as an independent univariate
    Gaussian whose mean and variance are estimated by ``fit``.
    """

    def __init__(self, var_smoothing=1e-9):
        """
        :param var_smoothing: fraction of the largest per-feature variance of
            the training data that is added to every class variance, so that a
            feature which is constant within a class does not divide by zero.
        """
        self.var_smoothing = var_smoothing

    @staticmethod
    def _log_gaussfunc(x, mu, singma):
        """Sum over features of the log Gaussian density of each row of ``x``.

        Working in log space keeps the per-feature product from underflowing
        to 0 when there are many features.

        :param x: data, shape = [n_samples, n_features]
        :param mu: per-feature means, shape = [n_features]
        :param singma: per-feature variances (not standard deviations),
            shape = [n_features]
        :return: shape = [n_samples]
        """
        x = np.asarray(x, dtype=np.float64)
        var = np.asarray(singma, dtype=np.float64)
        return -0.5 * np.sum(np.log(2 * np.pi * var) + (x - mu) ** 2 / var, axis=1)

    @staticmethod
    def gaussfunc(x, mu, singma):
        """Naive-Bayes Gaussian likelihood of each row of ``x``.

        Returns prod_j N(x_j | mu_j, singma_j): the product of independent
        univariate normal densities, one per feature.

        :param x: data, shape = [n_samples, n_features]
        :param mu: per-feature means, shape = [n_features]
        :param singma: per-feature variances, shape = [n_features]
        :return: shape = [n_samples]
        """
        return np.exp(GaussianNB._log_gaussfunc(x, mu, singma))

    def fit(self, X, y):
        """Estimate per-class feature means, variances and class priors.

        :param X: shape = [n_samples, n_features]
        :param y: shape = [n_samples]
        :return: self
        """
        X = np.asarray(X, dtype=np.float64)
        y = np.asarray(y).ravel()
        self.classes, self.classes_count = np.unique(y, return_counts=True)
        n_classes, n_features = self.classes.shape[0], X.shape[1]
        self.mean = np.zeros((n_classes, n_features), dtype=np.float64)
        self.var = np.zeros((n_classes, n_features), dtype=np.float64)
        for i, label in enumerate(self.classes):
            x_i = X[y == label]
            self.mean[i, :] = np.mean(x_i, axis=0)
            self.var[i, :] = np.var(x_i, axis=0)
        # Variance smoothing (as in scikit-learn): scale epsilon by the largest
        # feature variance; fall back to an absolute epsilon if all are 0.
        max_var = np.var(X, axis=0).max()
        self.var += self.var_smoothing * (max_var if max_var > 0 else 1.0)
        self.class_log_prior = np.log(self.classes_count) - np.log(self.classes_count.sum())
        return self

    def _joint_log_likelihood(self, X):
        """log P(c) + log P(x | c) for every sample and class.

        :param X: shape = [n_samples, n_features]
        :return: shape = [n_samples, n_classes]
        """
        X = np.asarray(X, dtype=np.float64)
        return np.stack(
            [
                self.class_log_prior[i]
                + GaussianNB._log_gaussfunc(X, self.mean[i, :], self.var[i, :])
                for i in range(self.classes.shape[0])
            ],
            axis=1,
        )

    def predict(self, X):
        """Predict the most probable class label for each sample.

        :param X: shape = [n_samples, n_features]
        :return: class labels (values taken from ``y`` in ``fit``),
            shape = [n_samples]
        """
        jll = self._joint_log_likelihood(X)
        return self.classes[np.argmax(jll, axis=1)]


class MultinomialNB(object):
    """Naive Bayes classifier for discrete (count) features."""

    def __init__(self, alpha=1.0, fit_prior=True):
        """
        :param alpha: additive (Laplace/Lidstone) smoothing for feature counts.
        :param fit_prior: if True, use class frequencies from ``fit`` as the
            prior; if False, use a uniform prior over the classes.
        """
        self.alpha = alpha
        self.fit_prior = fit_prior

    def fit(self, X, y):
        """Estimate smoothed per-class feature log-probabilities and priors.

        :param X: non-negative feature counts, shape = [n_samples, n_features]
        :param y: shape = [n_samples]
        :return: self
        """
        y = np.asarray(y).ravel()
        self.classes, y_index = np.unique(y, return_inverse=True)
        n_classes = self.classes.shape[0]
        # One-hot class indicators, shape [n_samples, n_classes]. Built
        # explicitly so a two-class problem gets two columns (LabelBinarizer
        # yields a single column for binary targets).
        Y = (y_index.reshape(-1, 1) == np.arange(n_classes)).astype(np.float64)
        self.class_count = Y.sum(axis=0)
        self.feature_count = np.asarray(Y.T @ X, dtype=np.float64)
        smoothed_fc = self.feature_count + self.alpha
        smoothed_cc = smoothed_fc.sum(axis=1, keepdims=True)
        self.feature_log_prob = np.log(smoothed_fc) - np.log(smoothed_cc)
        if self.fit_prior:
            self.class_log_prior = np.log(self.class_count) - np.log(self.class_count.sum())
        else:
            self.class_log_prior = np.full(n_classes, -np.log(n_classes))
        return self

    def predict(self, X):
        """Predict the most probable class label for each sample.

        :param X: feature counts, shape = [n_samples, n_features]
        :return: class labels (values taken from ``y`` in ``fit``),
            shape = [n_samples]
        """
        jll = np.asarray(X @ self.feature_log_prob.T) + self.class_log_prior
        return self.classes[np.argmax(jll, axis=1)]