source: trunk/gsdl/perllib/Kea-1.1.4/KEP.java@ 3161

Last change on this file since 3161 was 1972, checked in by jmt14, 23 years ago

* empty log message *

  • Property svn:keywords set to Author Date Id Revision
File size: 36.6 KB
/*
 * KEP.java
 * Version 1.1
 *
 * Kea -- Automatic Keyphrase Extraction
 * Copyright 1998-1999 by Eibe Frank and Gordon Paynter
 * Contact [email protected] or [email protected]
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * KEP is the machine learning and post-processing part of KEA,
 * the Keyphrase Extraction Algorithm.
 *
 * KEP assumes that the class is the last attribute and
 * that keyphrases not found in the text have a
 * missing class value.
 *
 * @author Eibe Frank ([email protected])
 */

/*
 * Version History
 *
 * 1.0    Witten et al. version
 * 1.1    First distribution
 * 1.1.1  Bug fix: when more output was requested than there were (filtered)
 *        instances in the arff file, Kea would output garbage.
 *        STDOUT output cleaned up in the training stage.
 */

import java.io.*;
import java.util.*;
import java.util.zip.*;
import jaws.corePackage.*;
import jaws.evaluationPackage.*;
import jaws.classifierPackage.*;
import jaws.filterPackage.*;

public class KEP implements Serializable {

  // ================
  // Public variables
  // ================

  /**
   * The maximum number of documents in
   * training and test corpus.
   */

  public final static int MAX_NUM_DOCS = 1000;

  /**
   * Index of the column with the document name.
   */

  public final static int DOC_NUM_INDEX = 1;

  /**
   * Number of steps for optimization of
   * probability cutoff.
   */

  public final static int PROB_STEPS = 100;

  /**
   * Minimum probability that a phrase
   * has to achieve.
   */

  public final static double MIN_PROB = 0.00;

  /**
   * Maximum number of bins for equal
   * width discretization.
   */

  public final static int MAX_BINS = 30;

  /**
   * Minimum number of bins for equal
   * width discretization.
   */

  public final static int MIN_BINS = 2;

  /**
   * Number of folds for standard
   * cross-validation.
   */

  public final static int NUM_FOLDS = 5;

  /**
   * Columns to be discretized.
   */

  public final static String DISC_RANGE = "first-last";

  // =================
  // Private variables
  // =================

  /**
   * The classifier for determining
   * keyphrases.
   */

  private DistributionClassifier theClassifier = new NaiveBayes();

  /**
   * The number of phrases to be output.
   */

  private int theNumPhrases = -1;

  /**
   * The minimum probability a phrase has
   * to achieve.
   */

  private double theMinProb = -1;

  /**
   * Rank cutoff.
   */

  private int theRankCutoff;

  /**
   * The maximum rank cutoff.
   */

  private int theMaxRankCutoff = 20;

  /**
   * The minimum rank cutoff.
   */

  private int theMinRankCutoff = 1;

  /**
   * Probability cutoff.
   */

  private double theProbCutoff;

  /**
   * The number of bins for equal width discretization.
   */

  private int theNumBins = 23;

  /**
   * The discretizer built up on the training data.
   */

  private DiscretiseFilter theDisc;

  /**
   * Equal width discretization?
   */

  private boolean theEqual = false;

  /**
   * Columns in arff file to be used for building
   * the classifier.
   */

  private String[] theCols = {"tfidf","first_occurrence","class"};

  /**
   * Column to be used for subphrase selection.
   */

  private String theSubphraseColumnName = "tfidf";

  /**
   * Column for subphrase selection: higher is better.
   */

  private boolean theSubphraseHi = true;

  /**
   * The secondary key for sorting the phrases in
   * the table (primary key: probability).
   */

  private String theSecKeyColumnName = "tfidf";

  /**
   * Column for secondary key: higher is better.
   */

  private boolean theSecKeyHi = true;

  /**
   * Column to be used for selecting batches for the
   * cross-validation.
   */

  private String theCVColumnName = "";

  /**
   * Use training data instead of cross-validation?
   */

  private boolean theUseTrain = true;

  /**
   * Use discretization?
   */

  private boolean theUseDiscretization = true;

  // ==============
  // Public methods
  // ==============

  /**
   * Returns a range list with the indices of the attributes
   * named in theCols.
   * @param dataset the dataset in which to look up the attributes
   * @return a comma-separated list of 1-based attribute indices
   */

  private String indices(Instances dataset) {

    String rangeList = "";
    for(int i = 0; i < theCols.length; i++)
      if (i == 0)
        rangeList = ""+(dataset.attribute(theCols[i]).index() + 1);
      else
        rangeList += ","+(dataset.attribute(theCols[i]).index() + 1);
    return rangeList;
  }
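
  // For example, with the default theCols = {"tfidf","first_occurrence","class"}
  // and an arff file that stores these as its third, fourth, and fifth
  // attributes, indices() would return "3,4,5". (The actual numbers depend
  // on the layout of the input file.)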

  /**
   * Performs a cross-validation on the given dataset.
   * @return the true positive and false positive counts, summed
   * over all folds, for every rank and probability cutoff
   */

  private double[][][] crossValidation(Instances instances, int numPos)
    throws Exception {

    SelectFilter selFilter;
    DiscretiseFilter discFilter;
    DeleteFilter delFilter;
    Instances train, test, trainDisc, testDisc;
    int numFolds, indexI, indexP;
    boolean batchedMode = false;
    double[][][] results = new double[theMaxRankCutoff][PROB_STEPS][2];
    double[][][] bestInsts;
    double[] currentTPandFP;
    int[] range = new int[1];
    int theCVColumn = -1;

    if (theCVColumnName.length() > 0)
      theCVColumn = instances.attribute(theCVColumnName).index();

    // Is there more than one document in the dataset?

    if ((theCVColumn != -1) && (instances.attribute(theCVColumn).numValues() > 1)) {
      batchedMode = true;
      numFolds = instances.attribute(theCVColumn).numValues();
    } else
      numFolds = NUM_FOLDS;

    // Perform the cross-validation

    for (int k = 0; k < numFolds; k++) {

      System.err.println("Processing fold: "+k);

      if (batchedMode) {

        // Batched cross-validation

        selFilter = new SelectFilter();
        selFilter.setInvertSelection(true);
        selFilter.setAttributeIndex(theCVColumn + 1);
        selFilter.inputFormat(instances);

        // Generate training set

        range[0] = k;
        selFilter.setNominalIndicesArr(range);
        train = Filter.useFilter(instances, selFilter);

        // Generate test set

        selFilter.setInvertSelection(false);
        test = Filter.useFilter(instances, selFilter);
      } else {

        // Standard NUM_FOLDS cross-validation

        train = instances.trainCV(NUM_FOLDS, k);
        test = instances.testCV(NUM_FOLDS, k);
      }

      System.err.println("Choosing attributes...");

      // Choose attributes in training and test data

      delFilter = new DeleteFilter();
      delFilter.setInvertSelection(true);
      delFilter.setAttributeIndices(indices(instances));
      delFilter.inputFormat(instances);
      trainDisc = Filter.useFilter(train, delFilter);
      testDisc = Filter.useFilter(test, delFilter);

      System.err.println("Discretizing...");

      // Discretise training and test data

      discFilter = new DiscretiseFilter();
      discFilter.setAttributeIndices(DISC_RANGE);
      if (theEqual) {
        discFilter.setBins(theNumBins);
        discFilter.setClassIndex(-1);
        discFilter.setUseMDL(false);
      } else
        discFilter.setClassIndex(trainDisc.numAttributes());
      discFilter.inputFormat(trainDisc);
      trainDisc = Filter.useFilter(trainDisc, discFilter);
      testDisc = Filter.useFilter(testDisc, discFilter);

      System.err.println("Building classifier...");

      // Build classifier on training data

      theClassifier.buildClassifier(trainDisc);

      // Find the most probable keyphrases in the test data

      System.err.println("Looking for the most probable keyphrases...");

      bestInsts = new double[MAX_NUM_DOCS][theMaxRankCutoff][2];
      mostProbablePhrases(test, testDisc, bestInsts);

      // Compute true and false positives for all probability and rank
      // cutoffs and add them to the running sums.

      indexI = 0;
      for (int i = theMinRankCutoff; i <= theMaxRankCutoff; i++) {
        indexP = 0;
        for (double p = 1 - (1.0/PROB_STEPS); p >= MIN_PROB;
             p = p - (1.0/PROB_STEPS)) {
          currentTPandFP = TPandFP(test, bestInsts, i, p, false);
          results[indexI][indexP][0] += currentTPandFP[0];
          results[indexI][indexP][1] += currentTPandFP[1];
          indexP++;
        }
        indexI++;
      }
    }

    return results;
  }
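
  // Note: results[i][p] accumulates raw TP/FP counts across folds rather
  // than per-fold F-measures; optimizeFMeasure() below turns the summed
  // counts into a single F-measure per parameter setting.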

  /**
   * Optimizes the F-measure by choosing the discretization and the
   * rank and probability cutoffs.
   */

  private double optimizeFMeasure(Instances instances, int numPos)
    throws Exception {

    double[][][][] results;
    double currentFMeasure, maxFMeasure = 0;
    int temp_MAX_BINS = MIN_BINS;

    if (theEqual)
      temp_MAX_BINS = MAX_BINS;

    // Do cross-validation for different number of bins.

    results =
      new double[temp_MAX_BINS - MIN_BINS + 1][theMaxRankCutoff][PROB_STEPS][2];
    for (int b = MIN_BINS; b <= temp_MAX_BINS; b++) {
      if (theEqual) {
        System.err.println("Computing Estimate for "+b+" bins");
        theNumBins = b;
      }
      results[b - MIN_BINS] = crossValidation(instances, numPos);
    }

    // Choose best parameter settings

    for (int b = MIN_BINS; b <= temp_MAX_BINS; b++) {

      // Find best probability and rank cutoffs

      int indexI = 0;
      for (int i = theMinRankCutoff; i <= theMaxRankCutoff; i++) {
        int indexP = 0;
        for (double p = 1-(1.0/PROB_STEPS); p >= MIN_PROB; p = p-(1.0/PROB_STEPS)) {
          currentFMeasure =
            fMeasure(results[b - MIN_BINS][indexI][indexP][0],
                     results[b - MIN_BINS][indexI][indexP][1],
                     numPos - results[b - MIN_BINS][indexI][indexP][0]);
          if (Utils.grOrEq(currentFMeasure, maxFMeasure)) {
            theRankCutoff = i;
            theProbCutoff = p;
            theNumBins = b;
            maxFMeasure = currentFMeasure;
          }
          indexP++;
        }
        indexI++;
      }
    }

    return maxFMeasure;
  }
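
  // Because the comparison uses grOrEq, ties on F-measure are resolved in
  // favour of the setting examined last: more bins, a larger rank cutoff,
  // and a lower probability cutoff.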

  /**
   * Computes true positives and false positives for the dataset.
   */

  private double[] TPandFP(Instances dataset, double[][][] bestInsts,
                           int rankCutoff, double probCutoff, boolean output)
    throws Exception {

    double[] TPandFP = new double[2];

    // For all documents

    for (int i = 0; i < dataset.attribute(DOC_NUM_INDEX).numValues(); i++) {

      if (output)
        System.out.println("Current document: "+
                           dataset.attribute(DOC_NUM_INDEX).value(i));

      // Up to the current rank cutoff

      for (int j = 0; j < rankCutoff; j++) {

        // Don't output trash (instances with probability 0)

        if (bestInsts[i][j][0] == 0) {
          break;
        }

        // Find the true positives and false positives

        if (Utils.sm(bestInsts[i][j][0], probCutoff))
          break;
        if (dataset.instance((int)bestInsts[i][j][1]).classValue() == 0) {
          TPandFP[0]++;
          if (output)
            System.out.println("Hit: "+dataset.instance((int)bestInsts[i][j][1])+
                               " "+bestInsts[i][j][0]);
        } else {
          TPandFP[1]++;
          if (output)
            System.out.println("Miss: "+dataset.instance((int)bestInsts[i][j][1])+
                               " "+bestInsts[i][j][0]);
        }
      }
    }

    return TPandFP;
  }
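
  // A candidate counts as a true positive when its class value is 0 (the
  // first, "keyphrase" class). Since bestInsts is sorted by probability,
  // the scan for a document can stop at the first entry below probCutoff
  // (or with probability 0, which marks an unused slot).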

  /**
   * Computes and outputs precision per document for rank cutoffs
   * 1 to theMaxRankCutoff.
   * Also outputs average precision and its confidence interval.
   */

  private void precisions(Instances dataset, double[][][] bestInsts)
    throws Exception {

    double[] precisions;
    int i, tp, fp;
    double standardError, deltaConfidence;

    // For all rank cutoffs

    for (int r = 1; r <= theMaxRankCutoff; r++) {

      // For all documents

      precisions = new double[dataset.attribute(DOC_NUM_INDEX).numValues()];
      for (i = 0; i < dataset.attribute(DOC_NUM_INDEX).numValues(); i++) {

        // Up to the current rank cutoff

        tp = fp = 0;
        for (int j = 0; j < r; j++) {

          // Find the true positives and false positives

          if (dataset.instance((int)bestInsts[i][j][1]).classValue() == 0)
            tp++;
          else
            fp++;
        }
        precisions[i] = ((double) tp) / ((double) tp + fp);
      }
      System.err.print(r);
      System.err.print("\t"+Utils.
                       doubleToString(Utils.sum(precisions) / i, 6,4));
      standardError = Math.sqrt(Utils.variance(precisions) / i);
      deltaConfidence = Statistics.studentTConfidenceInterval(i - 1,
                                                              0.05, standardError);
      System.err.println("\t"+Utils.doubleToString(deltaConfidence, 6,4)+" &&&");
    }
  }
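
  // The second column is the mean precision over documents; the third is
  // presumably the half-width of a two-sided 95% Student's t confidence
  // interval on that mean (significance 0.05, numDocs - 1 degrees of
  // freedom, scaled by the standard error of the mean).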

  /**
   * Computes F-Measure.
   */

  private double fMeasure(double tp, double fp, double fn) {

    return (2 * tp) / ((2 * tp) + fp + fn);
  }
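
  // With precision P = tp / (tp + fp) and recall R = tp / (tp + fn),
  // F = 2PR / (P + R) simplifies to 2tp / (2tp + fp + fn), which is the
  // formula used above; it also stays well-defined when P and R are both
  // zero, as long as tp + fp + fn > 0.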

  /**
   * Computes the F-Measure for the given dataset.
   */

  private double fMeasure(Instances dataset, double[][][] bestInsts,
                          int rankCutoff, double probCutoff, int numPos,
                          boolean output)
    throws Exception {

    double[] TPandFP = TPandFP(dataset, bestInsts, rankCutoff, probCutoff, output);
    double fn = (double)numPos - TPandFP[0];
    double precision, recall;

    if (output) {
      System.out.println("tp: "+TPandFP[0]+"\tfp: "+TPandFP[1]+"\tfn: "+fn);
      precision = TPandFP[0]/(TPandFP[0]+TPandFP[1]);
      recall = TPandFP[0]/(TPandFP[0]+fn);
      System.out.println("Precision: "+precision);
      System.out.println("Recall: "+recall);
      System.out.println("F-measure: "+((2*precision*recall)/(precision+recall)));
    }

    return fMeasure(TPandFP[0], TPandFP[1], fn);
  }

  /**
   * Computes the best rank and probability cutoffs.
   */

  private void bestCutoffs(Instances dataset, double[][][] bestInsts,
                           int numPos) throws Exception {

    double maxFMeasure = 0, currentFMeasure;

    theRankCutoff = 0;
    theProbCutoff = 1;
    for (int i = theMinRankCutoff; i <= theMaxRankCutoff; i++) {
      for (double p = 1-(1.0/PROB_STEPS); p >= MIN_PROB; p = p-(1.0/PROB_STEPS)) {
        currentFMeasure = fMeasure(dataset, bestInsts, i, p, numPos, false);
        if (Utils.grOrEq(currentFMeasure, maxFMeasure)) {
          theRankCutoff = i;
          theProbCutoff = p;
          maxFMeasure = currentFMeasure;
        }
        // System.out.println("Current rank cutoff: "+i+
        //                    "\tCurrent prob cutoff: "+p+
        //                    "\tF-Measure: "+currentFMeasure);
      }
    }
    System.err.println("Best rank cutoff: "+theRankCutoff+
                       "\tBest prob cutoff: "+theProbCutoff+
                       "\tF-Measure: "+maxFMeasure);
  }

  /**
   * Counts the author-defined keyphrases in the dataset, including
   * those that were not found in the text (missing class value), and
   * returns the total number of positive examples.
   */

  private int findMissedPositives(Instances data) throws Exception {

    int numPos = 0, missing = 0;

    System.err.println("Looking for missed keyphrases in the dataset...");

    for (int i = 0; i < data.numInstances(); i++) {
      if (data.instance(i).classIsMissing()) {
        missing++;
        numPos++;
      } else {
        if (data.instance(i).classValue() == 0)
          numPos++;
      }
    }

    System.err.println("Found " + data.attribute(DOC_NUM_INDEX).numValues() +
                       " documents in the dataset.");
    System.err.println("Total number of examples missing: "+missing);
    System.err.println("Total number of positive examples: "+numPos);

    return numPos;
  }

  /**
   * Gets all the subphrases of a phrase.
   */

  public String[] getSubphrases(String phrase) {

    StringTokenizer tokenizer = new StringTokenizer(phrase);
    String[] words = new String[tokenizer.countTokens()], subphrases;
    int i = 0, numSubphrases = 0, s = 0;

    // Get all words in the phrase

    while (tokenizer.hasMoreElements()) {
      words[i] = tokenizer.nextToken();
      i++;
    }

    // How many subphrases?

    for (i = 0; i < words.length - 1; i++)
      numSubphrases += words.length - i;

    // Array for storing the subphrases.

    subphrases = new String[numSubphrases];

    // Compute subphrases.

    for (i = 0; i < words.length; i++) {

      // All subphrases starting at position i.

      for (int j = 0; j < words.length - i; j++) {

        // The subphrase of length j + 1, apart from
        // the phrase itself.

        if ((i != 0) || (j != (words.length - 1))) {
          if (j == 0)
            subphrases[s] = words[i];
          else
            subphrases[s] = subphrases[s - 1]+' '+words[i + j];
          s++;
        }
      }
    }

    return subphrases;
  }
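
  // Worked example: for the phrase "neural network model", getSubphrases()
  // returns {"neural", "neural network", "network", "network model",
  // "model"} -- every contiguous sub-sequence of words except the full
  // phrase itself.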

  /**
   * Gets the stemmed phrase of an instance.
   */

  public String getPhrase(Instance instance) throws Exception {

    String string = instance.attribute(0).value((int)instance.value(0));

    return string.substring(1, string.indexOf('(') - 1);
  }
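
  // This assumes the first attribute's values look something like
  // "'some phrase (...)'": the leading character and everything from the
  // space before the first '(' onward are stripped, leaving the bare
  // stemmed phrase. The exact value format is fixed by the Kea front end
  // that generates the arff file.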

  /**
   * Finds the theMaxRankCutoff most probable phrases for each document.
   * @param originalData the original (undiscretized) set of instances
   * @param data the corresponding filtered and discretized instances
   * @param bestInsts an array double[MAX_NUM_DOCS][theMaxRankCutoff][2]:
   * bestInsts[i][j][0] contains the probability, and
   * bestInsts[i][j][1] the instance's index.
   * @exception Exception if an instance can't be classified
   */

  public void mostProbablePhrases(Instances originalData,
                                  Instances data, double[][][] bestInsts)
    throws Exception {

    double[] probAndIndex, probAndIndexSubphrase;
    int numDoc;
    Hashtable[] hashtables = new Hashtable[originalData.attribute(DOC_NUM_INDEX).
                                           numValues()];
    String[] subphrases;
    String phrase;
    int theSubphraseColumn = -1, theSecKeyColumn = -1;

    if (theSubphraseColumnName.length() > 0)
      theSubphraseColumn = originalData.attribute(theSubphraseColumnName).index();
    if (theSecKeyColumnName.length() > 0)
      theSecKeyColumn = originalData.attribute(theSecKeyColumnName).index();

    // Make a new hashtable for each document that stores the phrases.

    for (numDoc = 0; numDoc < originalData.attribute(DOC_NUM_INDEX).numValues();
         numDoc++)
      hashtables[numDoc] =
        new Hashtable(originalData.numInstances()/
                      originalData.attribute(DOC_NUM_INDEX).numValues());

    // Go through all the instances and keep the best phrases in hash tables
    // (i.e., their indices and probabilities). Delete all subphrases that
    // score worse than the phrase containing them.

    for (int i = 0; i < data.numInstances(); i++) {

      // Do nothing if the class is missing.

      if (!data.instance(i).classIsMissing()) {

        // Store index and probability in an array.

        probAndIndex = new double[2];
        probAndIndex[0] = theClassifier.
          distributionForInstance(data.instance(i))[0];
        probAndIndex[1] = (double)i;

        // Shortcut -- getting rid of phrases with very small probability.

        if (Utils.sm(probAndIndex[0], MIN_PROB))
          continue;

        // Which document does the current phrase belong to?

        numDoc = (int)originalData.instance(i).value(DOC_NUM_INDEX);

        // Get the actual phrase.

        phrase = getPhrase(originalData.instance(i));

        // Put the phrase into the hashtable.

        hashtables[numDoc].put(phrase, probAndIndex);

        // Do we want to get rid of subphrases? (As written, this test is
        // always true: theSubphraseColumn is either -1 or a column index.)

        if (theSubphraseColumn != -2) {

          // Get all subphrases.

          subphrases = getSubphrases(phrase);

          // Delete all subphrases with lower or equal probability from the
          // hash table.

          for (int j = 0; j < subphrases.length; j++) {
            probAndIndexSubphrase =
              (double[]) hashtables[numDoc].get(subphrases[j]);
            if ((probAndIndexSubphrase != null) &&
                Utils.smOrEq(probAndIndexSubphrase[0], probAndIndex[0]))
              if ((theSubphraseColumn == -1) ||
                  ((theSubphraseHi &&
                    Utils.smOrEq(originalData.
                                 instance((int)probAndIndexSubphrase[1]).
                                 value(theSubphraseColumn),
                                 originalData.instance((int)probAndIndex[1]).
                                 value(theSubphraseColumn))) ||
                   ((!theSubphraseHi) &&
                    Utils.grOrEq(originalData.
                                 instance((int)probAndIndexSubphrase[1]).
                                 value(theSubphraseColumn),
                                 originalData.instance((int)probAndIndex[1]).
                                 value(theSubphraseColumn)))))
                hashtables[numDoc].remove(subphrases[j]);
          }
        }
      }
    }

    // Find the theMaxRankCutoff most probable phrases for each document.

    for (numDoc = 0; numDoc < originalData.attribute(DOC_NUM_INDEX).numValues();
         numDoc++) {
      Enumeration en = hashtables[numDoc].elements();
      while (en.hasMoreElements()) {
        probAndIndex = (double[]) en.nextElement();
        for (int j = 0; j < theMaxRankCutoff; j++)
          if (Utils.gr(probAndIndex[0], bestInsts[numDoc][j][0]) ||
              ((theSecKeyColumn != -1) &&
               Utils.eq(probAndIndex[0], bestInsts[numDoc][j][0]) &&
               ((theSecKeyHi &&
                 Utils.gr(originalData.instance((int)probAndIndex[1]).
                          value(theSecKeyColumn),
                          originalData.instance((int)bestInsts[numDoc][j][1]).
                          value(theSecKeyColumn))) ||
                ((!theSecKeyHi) &&
                 Utils.sm(originalData.instance((int)probAndIndex[1]).
                          value(theSecKeyColumn),
                          originalData.instance((int)bestInsts[numDoc][j][1]).
                          value(theSecKeyColumn)))))) {
            for (int k = theMaxRankCutoff - 2; k >= j; k--) {
              bestInsts[numDoc][k + 1][0] = bestInsts[numDoc][k][0];
              bestInsts[numDoc][k + 1][1] = bestInsts[numDoc][k][1];
            }
            bestInsts[numDoc][j][0] = probAndIndex[0];
            bestInsts[numDoc][j][1] = probAndIndex[1];
            break;
          }
      }
    }
  }
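
  // bestInsts[d] is maintained in decreasing order of probability: each
  // candidate is insertion-sorted into the first slot it beats (ties broken
  // by the secondary key column), the entries below it shift down one
  // place, and the entry at position theMaxRankCutoff - 1 falls off the end.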

  /**
   * Method that builds the classifier.
   * @exception Exception if the classifier can't be built successfully
   */

  public void buildClassifier(InputStream inputStream) throws Exception {

    Instances instances, instancesDel, instancesDisc;
    DeleteFilter del;
    int numPos;
    double[][][] bestInsts =
      new double[MAX_NUM_DOCS][theMaxRankCutoff][2];

    System.err.println("Loading data...");

    instances = new Instances(inputStream);
    instances.setClassIndex(instances.numAttributes() - 1);

    // Number of positive examples?

    System.err.println("Counting positive examples...");

    numPos = findMissedPositives(instances);

    if ((theNumPhrases <= 0) && (theMinProb < 0) && (!theUseTrain)) {

      // Cross-validation for parameter settings

      System.err.println("Performing cross-validation...");

      System.err.println("F-measure: (cross-validation) "+
                         optimizeFMeasure(instances, numPos));
    }

    System.err.println("Deleting columns...");

    // Deleting columns

    del = new DeleteFilter();
    del.setInvertSelection(true);
    del.setAttributeIndices(indices(instances));
    del.inputFormat(instances);
    instancesDel = Filter.useFilter(instances, del);

    if (theUseDiscretization) {

      System.err.println("Discretizing columns...");

      // Discretizing columns

      theDisc = new DiscretiseFilter();
      theDisc.setAttributeIndices(DISC_RANGE);
      if (theEqual) {
        theDisc.setBins(theNumBins);
        theDisc.setClassIndex(-1);
        theDisc.setUseMDL(false);
      } else
        theDisc.setClassIndex(instancesDel.numAttributes());
      theDisc.inputFormat(instancesDel);
      instancesDisc = Filter.useFilter(instancesDel, theDisc);
    } else
      instancesDisc = instancesDel;

    System.err.println("Building classifier...");

    // Build the classifier.

    theClassifier.buildClassifier(instancesDisc);

    System.err.println("\n"+theClassifier);

    if ((theNumPhrases <= 0) && (theUseTrain) && (theMinProb < 0)) {

      System.err.println("Finding most probable keyphrases...");

      // Find the most probable keyphrases.

      mostProbablePhrases(instances, instancesDisc, bestInsts);

      System.err.println("Finding best cutoffs...");

      // Find the best cutoffs.

      bestCutoffs(instances, bestInsts, numPos);
    } else if ((theNumPhrases > 0) || (theMinProb >= 0)) {
      theRankCutoff = theNumPhrases;
      theProbCutoff = theMinProb;
    }
  }
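
  // Note that theDisc, theRankCutoff, and theProbCutoff are instance fields
  // of this Serializable class, so saving a KEP object (see main()) keeps
  // the fitted discretization and cutoffs available to classifyPhrases()
  // after the model is loaded again.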

  /**
   * Method that classifies phrases.
   * @exception Exception if phrases can't be classified successfully
   */

  public void classifyPhrases(InputStream inputStream) throws Exception {

    Instances instances, instancesDel, instancesDisc;
    DeleteFilter del;
    int numPos;
    double[][][] bestInsts =
      new double[MAX_NUM_DOCS][theMaxRankCutoff][2];

    if ((theNumPhrases > 0) || (theMinProb >= 0)) {
      theRankCutoff = theNumPhrases;
      theProbCutoff = theMinProb;
    }

    System.err.println("Preparing data...");

    // Prepare data and find total number of positive
    // examples.

    instances = new Instances(inputStream);
    instances.setClassIndex(instances.numAttributes() - 1);

    System.err.println("Counting positive examples...");

    // Number of positive examples?

    numPos = findMissedPositives(instances);

    System.err.println("Deleting columns...");

    // Deleting columns

    del = new DeleteFilter();
    del.setInvertSelection(true);
    del.setAttributeIndices(indices(instances));
    del.inputFormat(instances);
    instancesDel = Filter.useFilter(instances, del);

    if (theUseDiscretization)

      // Discretizing columns

      instancesDisc = Filter.useFilter(instancesDel, theDisc);
    else
      instancesDisc = instancesDel;

    System.err.println("Finding most probable keyphrases...");

    // Find the most probable keyphrases.

    mostProbablePhrases(instances, instancesDisc, bestInsts);

    System.err.println("Computing statistics...");

    // Compute F-Measure.

    System.err.println("Rank cutoff: "+theRankCutoff);
    System.err.println("Probability cutoff: "+theProbCutoff);

    System.err.println("F-Measure: "+
                       fMeasure(instances, bestInsts, theRankCutoff,
                                theProbCutoff, numPos, true));

    // Output average precisions for ranks and confidence intervals

    System.err.println("Rank\tPrec.\tConf.");
    precisions(instances, bestInsts);
  }

  /**
   * Generates results for training sets of different sizes.
   */

  public void iterateTrainingSets(InputStream train, InputStream test)
    throws Exception {

    int[] sizes = new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40, 50,
                             60, 70, 80, 90, 100, 110, 120, 130};
    Instances trainInst, testInst, tempTrainInst, tempTrainInstDisc, testInstDisc;
    SelectFilter selFilter;
    DeleteFilter del;
    int[] indices;
    int numPos;
    double[][][] bestInsts;

    trainInst = new Instances(train);
    trainInst.setClassIndex(trainInst.numAttributes() - 1);
    testInst = new Instances(test);
    testInst.setClassIndex(testInst.numAttributes() - 1);
    for (int i = 0; i < sizes.length; i++) {

      System.err.println("Processing "+sizes[i]+" documents");

      // Select the first sizes[i] documents

      System.err.println("Selecting documents...");

      selFilter = new SelectFilter();
      indices = new int[sizes[i]];
      for (int j = 0; j < sizes[i]; j++)
        indices[j] = j;
      selFilter.setInvertSelection(false);
      selFilter.setAttributeIndex(DOC_NUM_INDEX + 1); // 1-based document column
      selFilter.setNominalIndicesArr(indices);
      selFilter.inputFormat(trainInst);
      tempTrainInst = Filter.useFilter(trainInst, selFilter);

      numPos = findMissedPositives(tempTrainInst);

      System.err.println("Deleting columns...");

      // Deleting columns

      del = new DeleteFilter();
      del.setInvertSelection(true);
      del.setAttributeIndices(indices(tempTrainInst));
      del.inputFormat(tempTrainInst);
      tempTrainInstDisc = Filter.useFilter(tempTrainInst, del);

      if (theUseDiscretization) {

        System.err.println("Discretizing columns...");

        // Discretizing columns

        theDisc = new DiscretiseFilter();
        theDisc.setAttributeIndices(DISC_RANGE);
        if (theEqual) {
          theDisc.setBins(theNumBins);
          theDisc.setClassIndex(-1);
          theDisc.setUseMDL(false);
        } else
          theDisc.setClassIndex(tempTrainInstDisc.numAttributes());
        theDisc.inputFormat(tempTrainInstDisc);
        tempTrainInstDisc = Filter.useFilter(tempTrainInstDisc, theDisc);
      }

      System.err.println("Building classifier...");

      // Build the classifier.

      theClassifier.buildClassifier(tempTrainInstDisc);

      System.err.println("\n"+theClassifier);

      if ((theNumPhrases <= 0) && (theUseTrain) && (theMinProb < 0)) {

        System.err.println("Finding most probable keyphrases...");

        // Find the most probable keyphrases.

        bestInsts =
          new double[MAX_NUM_DOCS][theMaxRankCutoff][2];
        mostProbablePhrases(tempTrainInst, tempTrainInstDisc, bestInsts);

        System.err.println("Finding best cutoffs...");

        // Find the best cutoffs.

        bestCutoffs(tempTrainInst, bestInsts, numPos);
      } else if ((theNumPhrases > 0) || (theMinProb >= 0)) {
        theRankCutoff = theNumPhrases;
        theProbCutoff = theMinProb;
      }

      // Number of positive examples?

      numPos = findMissedPositives(testInst);

      System.err.println("Deleting columns...");

      // Deleting columns

      del = new DeleteFilter();
      del.setInvertSelection(true);
      del.setAttributeIndices(indices(testInst));
      del.inputFormat(testInst);
      testInstDisc = Filter.useFilter(testInst, del);

      if (theUseDiscretization)

        // Discretizing columns

        testInstDisc = Filter.useFilter(testInstDisc, theDisc);

      System.err.println("Finding most probable keyphrases...");

      // Find the most probable keyphrases.

      bestInsts =
        new double[MAX_NUM_DOCS][theMaxRankCutoff][2];
      mostProbablePhrases(testInst, testInstDisc, bestInsts);

      System.err.println("Computing statistics...");

      // Compute F-Measure.

      System.err.println("Rank cutoff: "+theRankCutoff);
      System.err.println("Probability cutoff: "+theProbCutoff);

      System.err.println("F-Measure: "+
                         fMeasure(testInst, bestInsts, theRankCutoff,
                                  theProbCutoff, numPos, false));

      // Output average precisions for ranks and confidence intervals

      System.err.println("Rank\tPrec.\tConf.");
      precisions(testInst, bestInsts);
    }
  }

  /**
   * Returns an enumeration describing the available options.
   * @return an enumeration of all the available options
   */

  public Enumeration listOptions() {

    Vector newVector = new Vector(14);

    newVector.addElement(new Option(
              "\tName of the training file.",
              "t", 1, "-t <training file>"));
    newVector.addElement(new Option(
              "\tName of the test file.",
              "T", 1, "-T <test file>"));
    newVector.addElement(new Option(
              "\tSpecify list of columns to be used. Class column has to\n"+
              "\tbe listed.",
              "R", 1, "-R <name1,name2-name4,...>"));
    newVector.addElement(new Option(
              "\tUse equal width discretization.",
              "E", 0, "-E"));
    newVector.addElement(new Option(
              "\tDon't use discretization.",
              "D", 0, "-D"));
    newVector.addElement(new Option(
              "\tDelete unlikely subphrases. Takes column to be used\n"+
              "\tin conjunction with probability to delete unlikely\n"+
              "\tsubphrases. Higher value is better. (\"\" means no column)",
              "S", 1, "-S <column name>"));
    newVector.addElement(new Option(
              "\tDelete unlikely subphrases. Takes column to be used in\n"+
              "\tconjunction with probability to delete unlikely subphrases.\n"+
              "\tLower value is better. (\"\" means no column)",
              "s", 1, "-s <column name>"));
    newVector.addElement(new Option(
              "\tSecondary key for sorting phrases in table. Primary key:\n"+
              "\tprobability. Higher value is better. (\"\" means no secondary key)",
              "K", 1, "-K <column name>"));
    newVector.addElement(new Option(
              "\tSecondary key for sorting phrases in table. Primary key:\n"+
              "\tprobability. Lower value is better. (\"\" means no secondary key)",
              "k", 1, "-k <column name>"));
    newVector.addElement(new Option(
              "\tColumn to be used for batched cross-validation.\n"+
              "\t(\"\" means standard 5-fold cross-validation)",
              "B", 1, "-B <column name>"));
    newVector.addElement(new Option(
              "\tThe maximum rank cutoff. (Default 20)",
              "X", 1, "-X <rank cutoff>"));
    newVector.addElement(new Option(
              "\tThe minimum rank cutoff. (Default 1)",
              "M", 1, "-M <rank cutoff>"));
    newVector.addElement(new Option(
              "\tDon't optimize F-measure. Just output the given number of\n"+
              "\tphrases.",
              "N", 1, "-N <number>"));
    newVector.addElement(new Option(
              "\tDon't optimize F-measure. Just output phrases with probability\n"+
              "\tgreater than or equal to the given one.",
              "P", 1, "-P <prob>"));

    return newVector.elements();
  }

  /**
   * Parses a given list of options.
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */

  public void parseOptions(String[] options) throws Exception {

    String numPhrases = Utils.getOption('N', options);
    if (numPhrases.length() != 0)
      theNumPhrases = Integer.parseInt(numPhrases);

    String minProb = Utils.getOption('P', options);
    if (minProb.length() != 0)
      theMinProb = (new Double(minProb)).doubleValue();

    String useList = Utils.getOption('R', options);
    if (useList.length() != 0) {
      StringTokenizer st = new StringTokenizer(useList, ",", false);
      theCols = new String[st.countTokens()];
      for (int i = 0; i < theCols.length; i++)
        theCols[i] = st.nextToken();
    }

    if (Utils.getFlag('E', options)) {
      theEqual = true;
    }

    if (Utils.getFlag('D', options)) {
      theUseDiscretization = false;
    }

    String subphrases = Utils.getOption('S', options);
    if (subphrases.length() != 0) {
      theSubphraseColumnName = subphrases;
      theSubphraseHi = true;
    }

    String subphrasesII = Utils.getOption('s', options);
    if (subphrasesII.length() != 0) {
      if (subphrases.length() != 0)
        throw new Exception("Can't use -S and -s at the same time!");
      theSubphraseColumnName = subphrasesII;
      theSubphraseHi = false;
    }

    String key = Utils.getOption('K', options);
    if (key.length() != 0) {
      theSecKeyColumnName = key;
      theSecKeyHi = true;
    }

    String keyII = Utils.getOption('k', options);
    if (keyII.length() != 0) {
      if (key.length() != 0)
        throw new Exception("Can't use -K and -k at the same time!");
      theSecKeyColumnName = keyII;
      theSecKeyHi = false;
    }

    String cv = Utils.getOption('B', options);
    if (cv.length() != 0) {
      theUseTrain = false;
      theCVColumnName = cv;
    }

    String maxRankCutoff = Utils.getOption('X', options);
    if (maxRankCutoff.length() != 0)
      theMaxRankCutoff = Integer.parseInt(maxRankCutoff);

    String minRankCutoff = Utils.getOption('M', options);
    if (minRankCutoff.length() != 0)
      theMinRankCutoff = Integer.parseInt(minRankCutoff);

    // Widen the rank cutoff range if a fixed number of phrases was
    // requested. The guard on theNumPhrases > 0 matters: theNumPhrases
    // is -1 when -N is absent and must not shrink theMinRankCutoff.

    if (theNumPhrases > theMaxRankCutoff)
      theMaxRankCutoff = theNumPhrases;
    if ((theNumPhrases > 0) && (theNumPhrases < theMinRankCutoff))
      theMinRankCutoff = theNumPhrases;
  }
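
  // Illustrative command lines (file names are placeholders; the jaws
  // classes must be on the classpath):
  //
  //   java KEP -t train.arff -m kea.model       build a model and save it
  //   java KEP -m kea.model -T test.arff        load a model, rank test phrases
  //   java KEP -t train.arff -T test.arff -N 5  output five phrases per document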

  // ======================================
  // Main method for command line interface
  // ======================================

  public static void main(String[] opts) {

    KEP kep;
    BufferedInputStream trainFile, testFile;
    ObjectInputStream modelIn;
    ObjectOutputStream modelOut;

    try {

      if (Utils.getFlag('I', opts)) {
        System.err.println("Iterating over training sets...");
        String train = Utils.getOption('t', opts);
        String test = Utils.getOption('T', opts);
        kep = new KEP();
        kep.parseOptions(opts);
        trainFile = new BufferedInputStream(new FileInputStream(train));
        testFile = new BufferedInputStream(new FileInputStream(test));
        kep.iterateTrainingSets(trainFile, testFile);
      } else {
        String model = Utils.getOption('m', opts);
        String train = Utils.getOption('t', opts);
        if (model.length() != 0)
          if (train.length() != 0) {
            System.err.println("Building classifier and saving it...");
            kep = new KEP();
            kep.parseOptions(opts);
            trainFile = new BufferedInputStream(new FileInputStream(train));
            kep.buildClassifier(trainFile);
            modelOut =
              new ObjectOutputStream(new
                GZIPOutputStream(new FileOutputStream(model)));
            modelOut.writeObject(kep);
            modelOut.flush();
            modelOut.close();
          } else {
            System.err.println("Loading classifier...");
            modelIn =
              new ObjectInputStream(new
                GZIPInputStream(new FileInputStream(model)));
            kep = (KEP) modelIn.readObject();
            modelIn.close();
            kep.parseOptions(opts);
          }
        else if (train.length() != 0) {
          System.err.println("Building classifier...");
          kep = new KEP();
          kep.parseOptions(opts);
          trainFile = new BufferedInputStream(new FileInputStream(train));
          kep.buildClassifier(trainFile);
        } else
          throw new Exception("Neither classifier nor training file given!");

        String test = Utils.getOption('T', opts);
        if (test.length() != 0) {
          System.err.println("Classifying phrases...");
          testFile = new BufferedInputStream(new FileInputStream(test));
          kep.classifyPhrases(testFile);
        }
      }
    } catch (Exception e) {
      System.err.println("\nOptions:\n");
      kep = new KEP();
      Enumeration en = kep.listOptions();
      while (en.hasMoreElements()) {
        Option option = (Option) en.nextElement();
        System.err.println(option.synopsis()+'\n'+option.description());
      }
      System.err.println();
      e.printStackTrace();
    }
  }
}