[ https://issues.apache.org/jira/browse/SPARK-16872?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
zhengruifeng updated SPARK-16872: --------------------------------- Summary: Impl Gaussian Naive Bayes Classifier (was: Include Gaussian Naive Bayes Classifier) > Impl Gaussian Naive Bayes Classifier > ------------------------------------ > > Key: SPARK-16872 > URL: https://issues.apache.org/jira/browse/SPARK-16872 > Project: Spark > Issue Type: New Feature > Components: ML, PySpark > Reporter: zhengruifeng > Assignee: zhengruifeng > Priority: Major > > I implemented Gaussian NB according to scikit-learn's {{GaussianNB}}. > In the GaussianNB model, the {{theta}} matrix is used to store means and there is > an extra {{sigma}} matrix storing the variance of each feature. > GaussianNB in Spark > {code} > scala> import org.apache.spark.ml.classification.GaussianNaiveBayes > import org.apache.spark.ml.classification.GaussianNaiveBayes > scala> val path = > "/Users/zrf/.dev/spark-2.1.0-bin-hadoop2.7/data/mllib/sample_multiclass_classification_data.txt" > path: String = > /Users/zrf/.dev/spark-2.1.0-bin-hadoop2.7/data/mllib/sample_multiclass_classification_data.txt > scala> val data = spark.read.format("libsvm").load(path).persist() > data: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: > double, features: vector] > scala> val gnb = new GaussianNaiveBayes() > gnb: org.apache.spark.ml.classification.GaussianNaiveBayes = gnb_54c50467306c > scala> val model = gnb.fit(data) > 17/01/03 14:25:48 INFO Instrumentation: > GaussianNaiveBayes-gnb_54c50467306c-720112035-1: training: numPartitions=1 > storageLevel=StorageLevel(1 replicas) > 17/01/03 14:25:48 INFO Instrumentation: > GaussianNaiveBayes-gnb_54c50467306c-720112035-1: {} > 17/01/03 14:25:49 INFO Instrumentation: > GaussianNaiveBayes-gnb_54c50467306c-720112035-1: {"numFeatures":4} > 17/01/03 14:25:49 INFO Instrumentation: > GaussianNaiveBayes-gnb_54c50467306c-720112035-1: {"numClasses":3} > 17/01/03 14:25:49 INFO Instrumentation: > GaussianNaiveBayes-gnb_54c50467306c-720112035-1: training finished > model: 
org.apache.spark.ml.classification.GaussianNaiveBayesModel = > GaussianNaiveBayesModel (uid=gnb_54c50467306c) with 3 classes > scala> model.pi > res0: org.apache.spark.ml.linalg.Vector = > [-1.0986122886681098,-1.0986122886681098,-1.0986122886681098] > scala> model.pi.toArray.map(math.exp) > res1: Array[Double] = Array(0.3333333333333333, 0.3333333333333333, > 0.3333333333333333) > scala> model.theta > res2: org.apache.spark.ml.linalg.Matrix = > 0.2711110067018001 -0.18833335400000006 0.5430507200000001 0.605000046 > -0.6077777799999998 0.181666672 -0.8427117400000006 > -0.8800001399999998 > -0.0911111425964 -0.3583333580000001 0.105084738 > 0.021666701507102017 > scala> model.sigma > res3: org.apache.spark.ml.linalg.Matrix = > 0.1223012510889361 0.07078051983960698 0.03430000595243976 > 0.051336071297393815 > 0.03758145300924998 0.09880280046403413 0.003390296940069426 > 0.007822241779598893 > 0.08058763609659315 0.06701386661293329 0.024866409227781675 > 0.02661391644759426 > scala> model.transform(data).select("probability").take(10) > [rdd_68_0] > res4: Array[org.apache.spark.sql.Row] = > Array([[1.0627410543476422E-21,0.9999999999999938,6.2765233965353945E-15]], > [[7.254521422345374E-26,1.0,1.3849442153180895E-18]], > [[1.9629244119173135E-24,0.9999999999999998,1.9424765181237926E-16]], > [[6.061218297948492E-22,0.9999999999999902,9.853216073401884E-15]], > [[0.9972225671942837,8.844241161578932E-165,0.002777432805716399]], > [[5.361683970373604E-26,1.0,2.3004604508982183E-18]], > [[0.01062850630038623,3.3102617689978775E-100,0.9893714936996136]], > [[1.9297314618271785E-4,2.124922209137708E-71,0.9998070268538172]], > [[3.118816393732361E-27,1.0,6.5310299615983584E-21]], > [[0.9999926009854522,8.734773657627494E-206,7.399014547943611E-6]]) > scala> model.transform(data).select("prediction").take(10) > [rdd_68_0] > res5: Array[org.apache.spark.sql.Row] = Array([1.0], [1.0], [1.0], [1.0], > [0.0], [1.0], [2.0], [2.0], [1.0], [0.0]) > {code} > GaussianNB in 
scikit-learn > {code} > import numpy as np > from sklearn.naive_bayes import GaussianNB > from sklearn.datasets import load_svmlight_file > path = > '/Users/zrf/.dev/spark-2.1.0-bin-hadoop2.7/data/mllib/sample_multiclass_classification_data.txt' > X, y = load_svmlight_file(path) > X = X.toarray() > clf = GaussianNB() > clf.fit(X, y) > >>> clf.class_prior_ > array([ 0.33333333, 0.33333333, 0.33333333]) > >>> clf.theta_ > array([[ 0.27111101, -0.18833335, 0.54305072, 0.60500005], > [-0.60777778, 0.18166667, -0.84271174, -0.88000014], > [-0.09111114, -0.35833336, 0.10508474, 0.0216667 ]]) > > >>> clf.sigma_ > array([[ 0.12230125, 0.07078052, 0.03430001, 0.05133607], > [ 0.03758145, 0.0988028 , 0.0033903 , 0.00782224], > [ 0.08058764, 0.06701387, 0.02486641, 0.02661392]]) > > >>> clf.predict_proba(X)[:10] > array([[ 1.06274105e-021, 1.00000000e+000, 6.27652340e-015], > [ 7.25452142e-026, 1.00000000e+000, 1.38494422e-018], > [ 1.96292441e-024, 1.00000000e+000, 1.94247652e-016], > [ 6.06121830e-022, 1.00000000e+000, 9.85321607e-015], > [ 9.97222567e-001, 8.84424116e-165, 2.77743281e-003], > [ 5.36168397e-026, 1.00000000e+000, 2.30046045e-018], > [ 1.06285063e-002, 3.31026177e-100, 9.89371494e-001], > [ 1.92973146e-004, 2.12492221e-071, 9.99807027e-001], > [ 3.11881639e-027, 1.00000000e+000, 6.53102996e-021], > [ 9.99992601e-001, 8.73477366e-206, 7.39901455e-006]]) > > >>> clf.predict(X)[:10] > array([ 1., 1., 1., 1., 0., 1., 2., 2., 1., 0.]) > {code} -- This message was sent by Atlassian Jira (v8.3.4#803005) --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org For additional commands, e-mail: issues-h...@spark.apache.org