source: trunk/gsdl/packages/kea/kea-3.0/weka/classifiers/NaiveBayesSimple.java@ 8815

Last change on this file since 8815 was 8815, checked in by mdewsnip, 19 years ago

Kea 3.0, as downloaded from http://www.nzdl.org/kea but with CSTR_abstracts_test, CSTR_abstracts_train, Chinese_test, and Chinese_train directories removed.

  • Property svn:keywords set to Author Date Id Revision
File size: 9.7 KB
Line 
1/*
2 * This program is free software; you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation; either version 2 of the License, or
5 * (at your option) any later version.
6 *
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 * GNU General Public License for more details.
11 *
12 * You should have received a copy of the GNU General Public License
13 * along with this program; if not, write to the Free Software
14 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * NaiveBayesSimple.java
19 * Copyright (C) 1999 Eibe Frank
20 *
21 */
22
23package weka.classifiers;
24
25import java.io.*;
26import java.util.*;
27import weka.core.*;
28
29/**
30 * Class for building and using a simple Naive Bayes classifier.
31 * Numeric attributes are modelled by a normal distribution. For more
32 * information, see<p>
33 *
34 * Richard Duda and Peter Hart (1973).<i>Pattern
35 * Classification and Scene Analysis</i>. Wiley, New York.
36
37 * @author Eibe Frank ([email protected])
38 * @version $Revision: 8815 $
39*/
40public class NaiveBayesSimple extends DistributionClassifier {
41
42 /** All the counts for nominal attributes. */
43 private double [][][] m_Counts;
44
45 /** The means for numeric attributes. */
46 private double [][] m_Means;
47
48 /** The standard deviations for numeric attributes. */
49 private double [][] m_Devs;
50
51 /** The prior probabilities of the classes. */
52 private double [] m_Priors;
53
54 /** The instances used for training. */
55 private Instances m_Instances;
56
57 /** Constant for normal distribution. */
58 private static double NORM_CONST = Math.sqrt(2 * Math.PI);
59
60 /**
61 * Generates the classifier.
62 *
63 * @param instances set of instances serving as training data
64 * @exception Exception if the classifier has not been generated successfully
65 */
66 public void buildClassifier(Instances instances) throws Exception {
67
68 int attIndex = 0;
69 double sum;
70
71 if (instances.checkForStringAttributes()) {
72 throw new Exception("Can't handle string attributes!");
73 }
74 if (instances.classAttribute().isNumeric()) {
75 throw new Exception("Naive Bayes: Class is numeric!");
76 }
77
78 m_Instances = new Instances(instances, 0);
79
80 // Reserve space
81 m_Counts = new double[instances.numClasses()]
82 [instances.numAttributes() - 1][0];
83 m_Means = new double[instances.numClasses()]
84 [instances.numAttributes() - 1];
85 m_Devs = new double[instances.numClasses()]
86 [instances.numAttributes() - 1];
87 m_Priors = new double[instances.numClasses()];
88 Enumeration enum = instances.enumerateAttributes();
89 while (enum.hasMoreElements()) {
90 Attribute attribute = (Attribute) enum.nextElement();
91 if (attribute.isNominal()) {
92 for (int j = 0; j < instances.numClasses(); j++) {
93 m_Counts[j][attIndex] = new double[attribute.numValues()];
94 }
95 } else {
96 for (int j = 0; j < instances.numClasses(); j++) {
97 m_Counts[j][attIndex] = new double[1];
98 }
99 }
100 attIndex++;
101 }
102
103 // Compute counts and sums
104 Enumeration enumInsts = instances.enumerateInstances();
105 while (enumInsts.hasMoreElements()) {
106 Instance instance = (Instance) enumInsts.nextElement();
107 if (!instance.classIsMissing()) {
108 Enumeration enumAtts = instances.enumerateAttributes();
109 attIndex = 0;
110 while (enumAtts.hasMoreElements()) {
111 Attribute attribute = (Attribute) enumAtts.nextElement();
112 if (!instance.isMissing(attribute)) {
113 if (attribute.isNominal()) {
114 m_Counts[(int)instance.classValue()][attIndex]
115 [(int)instance.value(attribute)]++;
116 } else {
117 m_Means[(int)instance.classValue()][attIndex] +=
118 instance.value(attribute);
119 m_Counts[(int)instance.classValue()][attIndex][0]++;
120 }
121 }
122 attIndex++;
123 }
124 m_Priors[(int)instance.classValue()]++;
125 }
126 }
127
128 // Compute means
129 Enumeration enumAtts = instances.enumerateAttributes();
130 attIndex = 0;
131 while (enumAtts.hasMoreElements()) {
132 Attribute attribute = (Attribute) enumAtts.nextElement();
133 if (attribute.isNumeric()) {
134 for (int j = 0; j < instances.numClasses(); j++) {
135 if (m_Counts[j][attIndex][0] < 2) {
136 throw new Exception("attribute " + attribute.name() +
137 ": less than two values for class " +
138 instances.classAttribute().value(j));
139 }
140 m_Means[j][attIndex] /= m_Counts[j][attIndex][0];
141 }
142 }
143 attIndex++;
144 }
145
146 // Compute standard deviations
147 enumInsts = instances.enumerateInstances();
148 while (enumInsts.hasMoreElements()) {
149 Instance instance =
150 (Instance) enumInsts.nextElement();
151 if (!instance.classIsMissing()) {
152 enumAtts = instances.enumerateAttributes();
153 attIndex = 0;
154 while (enumAtts.hasMoreElements()) {
155 Attribute attribute = (Attribute) enumAtts.nextElement();
156 if (!instance.isMissing(attribute)) {
157 if (attribute.isNumeric()) {
158 m_Devs[(int)instance.classValue()][attIndex] +=
159 (m_Means[(int)instance.classValue()][attIndex]-
160 instance.value(attribute))*
161 (m_Means[(int)instance.classValue()][attIndex]-
162 instance.value(attribute));
163 }
164 }
165 attIndex++;
166 }
167 }
168 }
169 enumAtts = instances.enumerateAttributes();
170 attIndex = 0;
171 while (enumAtts.hasMoreElements()) {
172 Attribute attribute = (Attribute) enumAtts.nextElement();
173 if (attribute.isNumeric()) {
174 for (int j = 0; j < instances.numClasses(); j++) {
175 if (m_Devs[j][attIndex] <= 0) {
176 throw new Exception("attribute " + attribute.name() +
177 ": standard deviation is 0 for class " +
178 instances.classAttribute().value(j));
179 }
180 else {
181 m_Devs[j][attIndex] /= m_Counts[j][attIndex][0] - 1;
182 m_Devs[j][attIndex] = Math.sqrt(m_Devs[j][attIndex]);
183 }
184 }
185 }
186 attIndex++;
187 }
188
189 // Normalize counts
190 enumAtts = instances.enumerateAttributes();
191 attIndex = 0;
192 while (enumAtts.hasMoreElements()) {
193 Attribute attribute = (Attribute) enumAtts.nextElement();
194 if (attribute.isNominal()) {
195 for (int j = 0; j < instances.numClasses(); j++) {
196 sum = Utils.sum(m_Counts[j][attIndex]);
197 for (int i = 0; i < attribute.numValues(); i++) {
198 m_Counts[j][attIndex][i] =
199 (m_Counts[j][attIndex][i] + 1)
200 / (sum + (double)attribute.numValues());
201 }
202 }
203 }
204 attIndex++;
205 }
206
207 // Normalize priors
208 sum = Utils.sum(m_Priors);
209 for (int j = 0; j < instances.numClasses(); j++)
210 m_Priors[j] = (m_Priors[j] + 1)
211 / (sum + (double)instances.numClasses());
212 }
213
214 /**
215 * Calculates the class membership probabilities for the given test instance.
216 *
217 * @param instance the instance to be classified
218 * @return predicted class probability distribution
219 * @exception Exception if distribution can't be computed
220 */
221 public double[] distributionForInstance(Instance instance) throws Exception {
222
223 double [] probs = new double[instance.numClasses()];
224 int attIndex;
225
226 for (int j = 0; j < instance.numClasses(); j++) {
227 probs[j] = 1;
228 Enumeration enumAtts = instance.enumerateAttributes();
229 attIndex = 0;
230 while (enumAtts.hasMoreElements()) {
231 Attribute attribute = (Attribute) enumAtts.nextElement();
232 if (!instance.isMissing(attribute)) {
233 if (attribute.isNominal()) {
234 probs[j] *= m_Counts[j][attIndex][(int)instance.value(attribute)];
235 } else {
236 probs[j] *= normalDens(instance.value(attribute),
237 m_Means[j][attIndex],
238 m_Devs[j][attIndex]);}
239 }
240 attIndex++;
241 }
242 probs[j] *= m_Priors[j];
243 }
244
245 // Normalize probabilities
246 Utils.normalize(probs);
247
248 return probs;
249 }
250
251 /**
252 * Returns a description of the classifier.
253 *
254 * @return a description of the classifier as a string.
255 */
256 public String toString() {
257
258 if (m_Instances == null) {
259 return "Naive Bayes (simple): No model built yet.";
260 }
261 try {
262 StringBuffer text = new StringBuffer("Naive Bayes (simple)");
263 int attIndex;
264
265 for (int i = 0; i < m_Instances.numClasses(); i++) {
266 text.append("\n\nClass " + m_Instances.classAttribute().value(i)
267 + ": P(C) = "
268 + Utils.doubleToString(m_Priors[i], 10, 8)
269 + "\n\n");
270 Enumeration enumAtts = m_Instances.enumerateAttributes();
271 attIndex = 0;
272 while (enumAtts.hasMoreElements()) {
273 Attribute attribute = (Attribute) enumAtts.nextElement();
274 text.append("Attribute " + attribute.name() + "\n");
275 if (attribute.isNominal()) {
276 for (int j = 0; j < attribute.numValues(); j++) {
277 text.append(attribute.value(j) + "\t");
278 }
279 text.append("\n");
280 for (int j = 0; j < attribute.numValues(); j++)
281 text.append(Utils.
282 doubleToString(m_Counts[i][attIndex][j], 10, 8)
283 + "\t");
284 } else {
285 text.append("Mean: " + Utils.
286 doubleToString(m_Means[i][attIndex], 10, 8) + "\t");
287 text.append("Standard Deviation: "
288 + Utils.doubleToString(m_Devs[i][attIndex], 10, 8));
289 }
290 text.append("\n\n");
291 attIndex++;
292 }
293 }
294
295 return text.toString();
296 } catch (Exception e) {
297 return "Can't print Naive Bayes classifier!";
298 }
299 }
300
301 /**
302 * Density function of normal distribution.
303 */
304 private double normalDens(double x, double mean, double stdDev) {
305
306 double diff = x - mean;
307
308 return (1 / (NORM_CONST * stdDev))
309 * Math.exp(-(diff * diff / (2 * stdDev * stdDev)));
310 }
311
312 /**
313 * Main method for testing this class.
314 *
315 * @param argv the options
316 */
317 public static void main(String [] argv) {
318
319 Classifier scheme;
320
321 try {
322 scheme = new NaiveBayesSimple();
323 System.out.println("Evaluation disabled!");
324 //System.out.println(Evaluation.evaluateModel(scheme, argv));
325 } catch (Exception e) {
326 System.err.println(e.getMessage());
327 }
328 }
329}
330
331
Note: See TracBrowser for help on using the repository browser.