/*
 *    KEP.java
 *    Version 1.1
 *
 *    Kea -- Automatic Keyphrase Extraction
 *    Copyright 1998-1999 by Eibe Frank and Gordon Paynter
 *    Contact eibe@cs.waikato.ac.nz or gwp@cs.waikato.ac.nz
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * KEP is the machine learning and post-processing part of KEA,
 * the Keyphrase Extraction Algorithm.
 *
 * KEP assumes that the class is the last attribute and
 * that keyphrases not found in the text have a
 * missing class value.
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 */

/*
 * Version History
 *
 * 1.0    Witten et al. version
 * 1.1    First distribution
 * 1.1.1  Bug: when more output was requested than there are (filtered)
 *        instances in the arff file, Kea would output garbage.
 *        STDOUT output cleaned up in training stage.
 */

import java.io.*;
import java.util.*;
import java.util.zip.*;
import jaws.corePackage.*;
import jaws.evaluationPackage.*;
import jaws.classifierPackage.*;
import jaws.filterPackage.*;

public class KEP implements Serializable {

  // ================
  // Public variables
  // ================

  /**
   * The maximum number of documents in
   * training and test corpus.
   */
  public final static int MAX_NUM_DOCS = 1000;

  /**
   * Index of the column with the document name.
   */
  public final static int DOC_NUM_INDEX = 1;

  /**
   * Number of steps for optimization of
   * probability cutoff.
   */
  public final static int PROB_STEPS = 100;

  /**
   * Minimum probability that a phrase
   * has to achieve.
   */
  public final static double MIN_PROB = 0.00;

  /**
   * Maximum number of bins for equal
   * width discretization.
   */
  public final static int MAX_BINS = 30;

  /**
   * Minimum number of bins for equal
   * width discretization.
   */
  public final static int MIN_BINS = 2;

  /**
   * Number of folds for standard
   * cross-validation.
   */
  public final static int NUM_FOLDS = 5;

  /**
   * Columns to be discretized.
   */
  public final static String DISC_RANGE = "first-last";

  // =================
  // Private variables
  // =================

  /**
   * The classifier for determining
   * keyphrases.
   */
  private DistributionClassifier theClassifier = new NaiveBayes();

  /**
   * The number of phrases to be output.
   */
  private int theNumPhrases = -1;

  /**
   * The minimum probability a phrase has
   * to achieve.
   */
  private double theMinProb = -1;

  /**
   * Rank cutoff.
   */
  private int theRankCutoff;

  /**
   * The maximum rank cutoff.
   */
  private int theMaxRankCutoff = 20;

  /**
   * The minimum rank cutoff.
   */
  private int theMinRankCutoff = 1;

  /**
   * Probability cutoff.
   */
  private double theProbCutoff;

  /**
   * The number of bins for equal width discretization.
   */
  private int theNumBins = 23;

  /**
   * The discretizer built up on the training data.
   */
  private DiscretiseFilter theDisc;

  /**
   * Equal width discretization?
   */
  private boolean theEqual = false;

  /**
   * Columns in arff file to be used for building
   * the classifier.
   */
  private String[] theCols = {"tfidf", "first_occurrence", "class"};

  /**
   * Column to be used for subphrase selection.
   */
  private String theSubphraseColumnName = "tfidf";

  /**
   * Column for subphrase selection: higher is better.
   */
  private boolean theSubphraseHi = true;

  /**
   * The secondary key for sorting the phrases in
   * the table (primary key: probability).
   */
  private String theSecKeyColumnName = "tfidf";

  /**
   * Column for secondary key: higher is better.
   */
  private boolean theSecKeyHi = true;

  /**
   * Column to be used for selecting batches for the
   * cross-validation.
   */
  private String theCVColumnName = "";

  /**
   * Use training data instead of cross-validation?
   */
  private boolean theUseTrain = true;

  /**
   * Use discretization?
   */
  private boolean theUseDiscretization = true;

  // ==============
  // Public methods
  // ==============

  /**
   * Returns a comma-separated list of the (1-based) indices of the
   * columns named in theCols, looked up in the given dataset.
   * @param dataset the dataset in which the attributes are looked up
   * @return the list of indices as a string
   */
  private String indices(Instances dataset) {

    String rangeList = "";

    for (int i = 0; i < theCols.length; i++)
      if (i == 0)
        rangeList = "" + (dataset.attribute(theCols[i]).index() + 1);
      else
        rangeList += "," + (dataset.attribute(theCols[i]).index() + 1);
    return rangeList;
  }

  /**
   * Performs a cross-validation on the given dataset.
   * @return the summed true and false positives for all rank
   * and probability cutoffs
   */
  private double[][][] crossValidation(Instances instances, int numPos)
    throws Exception {

    SelectFilter selFilter;
    DiscretiseFilter discFilter;
    DeleteFilter delFilter;
    Instances train, test, trainDisc, testDisc;
    int numFolds, indexI, indexP;
    boolean batchedMode = false;
    double[][][] results = new double[theMaxRankCutoff][PROB_STEPS][2];
    double[][][] bestInsts;
    double[] currentTPandFP;
    int[] range = new int[1];
    int theCVColumn = -1;

    if (theCVColumnName.length() > 0)
      theCVColumn = instances.attribute(theCVColumnName).index();

    // Is there more than one document in the dataset?
    if ((theCVColumn != -1) &&
        (instances.attribute(theCVColumn).numValues() > 1)) {
      batchedMode = true;
      numFolds = instances.attribute(theCVColumn).numValues();
    } else
      numFolds = NUM_FOLDS;

    // Perform the cross-validation
    for (int k = 0; k < numFolds; k++) {
      System.err.println("Processing fold: " + k);
      if (batchedMode) {

        // Batched cross-validation
        selFilter = new SelectFilter();
        selFilter.setInvertSelection(true);
        selFilter.setAttributeIndex(theCVColumn + 1);
        selFilter.inputFormat(instances);

        // Generate training set
        range[0] = k;
        selFilter.setNominalIndicesArr(range);
        train = Filter.useFilter(instances, selFilter);

        // Generate test set
        selFilter.setInvertSelection(false);
        test = Filter.useFilter(instances, selFilter);
      } else {

        // Standard NUM_FOLDS cross-validation
        train = instances.trainCV(NUM_FOLDS, k);
        test = instances.testCV(NUM_FOLDS, k);
      }
      System.err.println("Choosing attributes...");

      // Choose attributes in training and test data
      delFilter = new DeleteFilter();
      delFilter.setInvertSelection(true);
      delFilter.setAttributeIndices(indices(instances));
      delFilter.inputFormat(instances);
      trainDisc = Filter.useFilter(train, delFilter);
      testDisc = Filter.useFilter(test, delFilter);
      System.err.println("Discretizing...");

      // Discretise training and test data
      discFilter = new DiscretiseFilter();
      discFilter.setAttributeIndices(DISC_RANGE);
      if (theEqual) {
        discFilter.setBins(theNumBins);
        discFilter.setClassIndex(-1);
        discFilter.setUseMDL(false);
      } else
        discFilter.setClassIndex(trainDisc.numAttributes());
      discFilter.inputFormat(trainDisc);
      trainDisc = Filter.useFilter(trainDisc, discFilter);
      testDisc = Filter.useFilter(testDisc, discFilter);
      System.err.println("Building classifier...");

      // Build classifier on training data
      theClassifier.buildClassifier(trainDisc);

      // Find the most probable keyphrases in the test data
      System.err.println("Looking for the most probable keyphrases...");
      bestInsts = new double[MAX_NUM_DOCS][theMaxRankCutoff][2];
      mostProbablePhrases(test, testDisc, bestInsts);

      // Compute F-measure for all probability and rank
      // cutoffs and add to sum.
      indexI = 0;
      for (int i = theMinRankCutoff; i <= theMaxRankCutoff; i++) {
        indexP = 0;
        for (double p = 1 - (1.0 / PROB_STEPS); p >= MIN_PROB;
             p = p - (1.0 / PROB_STEPS)) {
          currentTPandFP = TPandFP(test, bestInsts, i, p, false);
          results[indexI][indexP][0] += currentTPandFP[0];
          results[indexI][indexP][1] += currentTPandFP[1];
          indexP++;
        }
        indexI++;
      }
    }
    return results;
  }

  /**
   * Optimizes the F-measure by choosing the discretization and the
   * rank and probability cutoffs.
   */
  private double optimizeFMeasure(Instances instances, int numPos)
    throws Exception {

    double[][][][] results;
    double currentFMeasure, maxFMeasure = 0;
    int temp_MAX_BINS = MIN_BINS;

    if (theEqual)
      temp_MAX_BINS = MAX_BINS;

    // Do cross-validation for different number of bins.
    results =
      new double[temp_MAX_BINS - MIN_BINS + 1][theMaxRankCutoff][PROB_STEPS][2];
    for (int b = MIN_BINS; b <= temp_MAX_BINS; b++) {
      if (theEqual) {
        System.err.println("Computing Estimate for " + b + " bins");
        theNumBins = b;
      }
      results[b - MIN_BINS] = crossValidation(instances, numPos);
    }

    // Choose best parameter settings
    for (int b = MIN_BINS; b <= temp_MAX_BINS; b++) {

      // Find best probability and rank cutoffs
      int indexI = 0;
      for (int i = theMinRankCutoff; i <= theMaxRankCutoff; i++) {
        int indexP = 0;
        for (double p = 1 - (1.0 / PROB_STEPS); p >= MIN_PROB;
             p = p - (1.0 / PROB_STEPS)) {
          currentFMeasure =
            fMeasure(results[b - MIN_BINS][indexI][indexP][0],
                     results[b - MIN_BINS][indexI][indexP][1],
                     numPos - results[b - MIN_BINS][indexI][indexP][0]);
          if (Utils.grOrEq(currentFMeasure, maxFMeasure)) {
            theRankCutoff = i;
            theProbCutoff = p;
            theNumBins = b;
            maxFMeasure = currentFMeasure;
          }
          indexP++;
        }
        indexI++;
      }
    }
    return maxFMeasure;
  }

  /**
   * Computes true positives and false positives for the dataset.
   */
  private double[] TPandFP(Instances dataset, double[][][] bestInsts,
                           int rankCutoff, double probCutoff, boolean output)
    throws Exception {

    double[] TPandFP = new double[2];

    // For all documents
    for (int i = 0; i < dataset.attribute(DOC_NUM_INDEX).numValues(); i++) {
      if (output)
        System.out.println("Current document: " +
                           dataset.attribute(DOC_NUM_INDEX).value(i));

      // Up to the current rank cutoff
      for (int j = 0; j < rankCutoff; j++) {

        // Don't output trash (instances with probability 0)
        if (bestInsts[i][j][0] == 0) {
          break;
        }

        // Find the true positives and false positives
        if (Utils.sm(bestInsts[i][j][0], probCutoff))
          break;
        if (dataset.instance((int)bestInsts[i][j][1]).classValue() == 0) {
          TPandFP[0]++;
          if (output)
            System.out.println("Hit: " +
                               dataset.instance((int)bestInsts[i][j][1]) +
                               " " + bestInsts[i][j][0]);
        } else {
          TPandFP[1]++;
          if (output)
            System.out.println("Miss: " +
                               dataset.instance((int)bestInsts[i][j][1]) +
                               " " + bestInsts[i][j][0]);
        }
      }
    }
    return TPandFP;
  }

  /**
   * Computes and outputs the precision per document for rank cutoffs
   * 1 to theMaxRankCutoff. Also outputs the average precision and a
   * confidence interval for it.
   */
  private void precisions(Instances dataset, double[][][] bestInsts)
    throws Exception {

    double[] precisions;
    int i, tp, fp;
    double standardError, deltaConfidence;

    // For all rank cutoffs
    for (int r = 1; r <= theMaxRankCutoff; r++) {

      // For all documents
      precisions = new double[dataset.attribute(DOC_NUM_INDEX).numValues()];
      for (i = 0; i < dataset.attribute(DOC_NUM_INDEX).numValues(); i++) {

        // Up to the current rank cutoff
        tp = fp = 0;
        for (int j = 0; j < r; j++) {

          // Find the true positives and false positives
          if (dataset.instance((int)bestInsts[i][j][1]).classValue() == 0)
            tp++;
          else
            fp++;
        }
        precisions[i] = ((double) tp) / ((double) tp + fp);
      }
      System.err.print(r);
      System.err.print("\t" + Utils.doubleToString(Utils.sum(precisions) / i,
                                                   6, 4));
      standardError = Math.sqrt(Utils.variance(precisions) / i);
      deltaConfidence = Statistics.studentTConfidenceInterval(i - 1, 0.05,
                                                              standardError);
      System.err.println("\t" + Utils.doubleToString(deltaConfidence, 6, 4) +
                         " &&&");
    }
  }

  /**
   * Computes the F-Measure.
   */
  private double fMeasure(double tp, double fp, double fn) {

    return (2 * tp) / ((2 * tp) + fp + fn);
  }

  /**
   * Computes the F-Measure for the given dataset.
   */
  private double fMeasure(Instances dataset, double[][][] bestInsts,
                          int rankCutoff, double probCutoff, int numPos,
                          boolean output) throws Exception {

    double[] TPandFP = TPandFP(dataset, bestInsts, rankCutoff, probCutoff,
                               output);
    double fn = (double)numPos - TPandFP[0];
    double precision, recall;

    if (output) {
      System.out.println("tp: " + TPandFP[0] + "\tfp: " + TPandFP[1] +
                         "\tfn: " + fn);
      precision = TPandFP[0] / (TPandFP[0] + TPandFP[1]);
      recall = TPandFP[0] / (TPandFP[0] + fn);
      System.out.println("Precision: " + precision);
      System.out.println("Recall: " + recall);
      System.out.println("F-measure: " +
                         ((2 * precision * recall) / (precision + recall)));
    }
    return fMeasure(TPandFP[0], TPandFP[1], fn);
  }

  /**
   * Computes the best cutoffs.
   */
  private void bestCutoffs(Instances dataset, double[][][] bestInsts,
                           int numPos) throws Exception {

    double maxFMeasure = 0, currentFMeasure;

    theRankCutoff = 0;
    theProbCutoff = 1;
    for (int i = theMinRankCutoff; i <= theMaxRankCutoff; i++) {
      for (double p = 1 - (1.0 / PROB_STEPS); p >= MIN_PROB;
           p = p - (1.0 / PROB_STEPS)) {
        currentFMeasure = fMeasure(dataset, bestInsts, i, p, numPos, false);
        if (Utils.grOrEq(currentFMeasure, maxFMeasure)) {
          theRankCutoff = i;
          theProbCutoff = p;
          maxFMeasure = currentFMeasure;
        }
        // System.out.println("Current rank cutoff: "+i+
        //                    "\tCurrent prob cutoff: "+p+
        //                    "\tF-Measure: "+currentFMeasure);
      }
    }
    System.err.println("Best rank cutoff: " + theRankCutoff +
                       "\tBest prob cutoff: " + theProbCutoff +
                       "\tF-Measure: " + maxFMeasure);
  }

  /**
   * Counts the author-defined keyphrases in the dataset, including those
   * that were not found in the text (these have a missing class value),
   * and returns the number of positive examples.
   */
  private int findMissedPositives(Instances data) throws Exception {

    int numPos = 0, missing = 0;

    System.err.println("Looking for missed keyphrases in the dataset...");
    for (int i = 0; i < data.numInstances(); i++) {
      if (data.instance(i).classIsMissing()) {
        missing++;
        numPos++;
      } else {
        if (data.instance(i).classValue() == 0)
          numPos++;
      }
    }
    System.err.println("Found " +
                       (int)data.attribute(DOC_NUM_INDEX).numValues() +
                       " documents in the dataset.");
    System.err.println("Total number of examples missing: " + missing);
    System.err.println("Total number of positive examples: " + numPos);
    return numPos;
  }

  /**
   * Gets all the subphrases of a phrase.
   */
  public String[] getSubphrases(String phrase) {

    StringTokenizer tokenizer = new StringTokenizer(phrase);
    String[] words = new String[tokenizer.countTokens()], subphrases;
    int i = 0, numSubphrases = 0, s = 0;

    // Get all words in the phrase
    while (tokenizer.hasMoreElements()) {
      words[i] = tokenizer.nextToken();
      i++;
    }

    // How many subphrases?
    for (i = 0; i < words.length - 1; i++)
      numSubphrases += words.length - i;

    // Array for storing the subphrases.
    subphrases = new String[numSubphrases];

    // Compute subphrases.
    for (i = 0; i < words.length; i++) {

      // All subphrases at position i.
      for (int j = 0; j < words.length - i; j++) {

        // The subphrase of length j + 1, apart
        // from the phrase itself.
        if ((i != 0) || (j != (words.length - 1))) {
          if (j == 0)
            subphrases[s] = words[i];
          else
            subphrases[s] = subphrases[s - 1] + ' ' + words[i + j];
          s++;
        }
      }
    }
    return subphrases;
  }

  /**
   * Gets the stemmed phrase of an instance.
   */
  public String getPhrase(Instance instance) throws Exception {

    String string = instance.attribute(0).value((int)instance.value(0));

    return string.substring(1, string.indexOf('(') - 1);
  }

  /**
   * Finds the theMaxRankCutoff most probable phrases for each document.
   * @param originalData the original set of instances
   * @param data the corresponding filtered (and discretized) instances
   *        that are passed to the classifier
   * @param bestInsts an array double[MAX_NUM_DOCS][theMaxRankCutoff][2]:
   *        double[i][j][0] contains the probability, and
   *        double[i][j][1] the instance's index.
   * @exception Exception if an instance can't be classified
   */
  public void mostProbablePhrases(Instances originalData, Instances data,
                                  double bestInsts[][][]) throws Exception {

    double[] probAndIndex, probAndIndexSubphrase;
    int numDoc;
    Hashtable[] hashtables =
      new Hashtable[originalData.attribute(DOC_NUM_INDEX).numValues()];
    String[] subphrases;
    String phrase;
    int theSubphraseColumn = -1, theSecKeyColumn = -1;

    if (theSubphraseColumnName.length() > 0)
      theSubphraseColumn =
        originalData.attribute(theSubphraseColumnName).index();
    if (theSecKeyColumnName.length() > 0)
      theSecKeyColumn = originalData.attribute(theSecKeyColumnName).index();

    // Make a new hashtable for each document that stores the phrases.
    for (numDoc = 0;
         numDoc < originalData.attribute(DOC_NUM_INDEX).numValues(); numDoc++)
      hashtables[numDoc] =
        new Hashtable(originalData.numInstances() /
                      originalData.attribute(DOC_NUM_INDEX).numValues());

    // Go through all the instances and keep the best phrases in hash tables
    // (i.e., their indices and probabilities). Delete all subphrases that suck.
    for (int i = 0; i < data.numInstances(); i++) {

      // Do nothing if the class is missing.
      if (!data.instance(i).classIsMissing()) {

        // Store index and probability in an array.
        probAndIndex = new double[2];
        probAndIndex[0] =
          theClassifier.distributionForInstance(data.instance(i))[0];
        probAndIndex[1] = (double)i;

        // Shortcut -- getting rid of phrases with very small probability.
        if (Utils.sm(probAndIndex[0], MIN_PROB))
          continue;

        // Which document does the current phrase belong to?
        numDoc = (int)originalData.instance(i).value(DOC_NUM_INDEX);

        // Get the actual phrase.
        phrase = getPhrase(originalData.instance(i));

        // Put the phrase into the hashtable.
        hashtables[numDoc].put(phrase, probAndIndex);

        // Do we want to get rid of subphrases?
        if (theSubphraseColumn != -2) {

          // Get all subphrases.
          subphrases = getSubphrases(phrase);

          // Delete all subphrases with lower or equal probability from the
          // hash table.
          for (int j = 0; j < subphrases.length; j++) {
            probAndIndexSubphrase =
              (double[]) hashtables[numDoc].get(subphrases[j]);
            if ((probAndIndexSubphrase != null) &&
                Utils.smOrEq(probAndIndexSubphrase[0], probAndIndex[0]))
              if ((theSubphraseColumn == -1) ||
                  ((theSubphraseHi &&
                    Utils.smOrEq(originalData.
                                 instance((int)probAndIndexSubphrase[1]).
                                 value(theSubphraseColumn),
                                 originalData.instance((int)probAndIndex[1]).
                                 value(theSubphraseColumn))) ||
                   ((!theSubphraseHi) &&
                    Utils.grOrEq(originalData.
                                 instance((int)probAndIndexSubphrase[1]).
                                 value(theSubphraseColumn),
                                 originalData.instance((int)probAndIndex[1]).
                                 value(theSubphraseColumn)))))
                hashtables[numDoc].remove(subphrases[j]);
          }
        }
      }
    }

    // Find the theMaxRankCutoff most probable phrases for each document.
    for (numDoc = 0;
         numDoc < originalData.attribute(DOC_NUM_INDEX).numValues(); numDoc++) {
      Enumeration en = hashtables[numDoc].elements();
      while (en.hasMoreElements()) {
        probAndIndex = (double[]) en.nextElement();
        for (int j = 0; j < theMaxRankCutoff; j++)
          if (Utils.gr(probAndIndex[0], bestInsts[numDoc][j][0]) ||
              ((theSecKeyColumn != -1) &&
               Utils.eq(probAndIndex[0], bestInsts[numDoc][j][0]) &&
               ((theSecKeyHi &&
                 Utils.gr(originalData.instance((int)probAndIndex[1]).
                          value(theSecKeyColumn),
                          originalData.instance((int)bestInsts[numDoc][j][1]).
                          value(theSecKeyColumn))) ||
                ((!theSecKeyHi) &&
                 Utils.sm(originalData.instance((int)probAndIndex[1]).
                          value(theSecKeyColumn),
                          originalData.instance((int)bestInsts[numDoc][j][1]).
                          value(theSecKeyColumn)))))) {
            for (int k = theMaxRankCutoff - 2; k >= j; k--) {
              bestInsts[numDoc][k + 1][0] = bestInsts[numDoc][k][0];
              bestInsts[numDoc][k + 1][1] = bestInsts[numDoc][k][1];
            }
            bestInsts[numDoc][j][0] = probAndIndex[0];
            bestInsts[numDoc][j][1] = probAndIndex[1];
            break;
          }
      }
    }
  }

  /**
   * Method that builds classifier.
   * @exception Exception if classifier can't be built successfully
   */
  public void buildClassifier(InputStream inputStream) throws Exception {

    Instances instances, instancesDel, instancesDisc;
    DeleteFilter del;
    int numPos;
    double[][][] bestInsts = new double[MAX_NUM_DOCS][theMaxRankCutoff][2];

    System.err.println("Loading data...");
    instances = new Instances(inputStream);
    instances.setClassIndex(instances.numAttributes() - 1);

    // Number of positive examples?
    System.err.println("Counting positive examples...");
    numPos = findMissedPositives(instances);
    if ((theNumPhrases <= 0) && (theMinProb < 0) && (!theUseTrain)) {

      // Cross-validation for parameter settings
      System.err.println("Performing cross-validation...");
      System.err.println("F-measure: (cross-validation) " +
                         optimizeFMeasure(instances, numPos));
      System.err.println("Deleting columns...");
    }
    System.err.println("Deleting columns...");

    // Deleting columns
    del = new DeleteFilter();
    del.setInvertSelection(true);
    del.setAttributeIndices(indices(instances));
    del.inputFormat(instances);
    instancesDel = Filter.useFilter(instances, del);
    if (theUseDiscretization) {
      System.err.println("Discretizing columns...");

      // Discretizing columns
      theDisc = new DiscretiseFilter();
      theDisc.setAttributeIndices(DISC_RANGE);
      if (theEqual) {
        theDisc.setBins(theNumBins);
        theDisc.setClassIndex(-1);
        theDisc.setUseMDL(false);
      } else
        theDisc.setClassIndex(instancesDel.numAttributes());
      theDisc.inputFormat(instancesDel);
      instancesDisc = Filter.useFilter(instancesDel, theDisc);
    } else
      instancesDisc = instancesDel;
    System.err.println("Building classifier...");

    // Build the classifier.
    theClassifier.buildClassifier(instancesDisc);
    System.err.println("\n" + theClassifier);
    if ((theNumPhrases <= 0) && (theUseTrain) && (theMinProb < 0)) {
      System.err.println("Finding most probable keyphrases...");

      // Find the most probable keyphrases.
      mostProbablePhrases(instances, instancesDisc, bestInsts);
      System.err.println("Finding best cutoffs...");

      // Find the best cutoffs.
      bestCutoffs(instances, bestInsts, numPos);
    } else if ((theNumPhrases > 0) || (theMinProb >= 0)) {
      theRankCutoff = theNumPhrases;
      theProbCutoff = theMinProb;
    }
  }

  /**
   * Method that classifies phrases.
   * @exception Exception if phrases can't be classified successfully.
   */
  public void classifyPhrases(InputStream inputStream) throws Exception {

    Instances instances, instancesDel, instancesDisc;
    DeleteFilter del;
    int numPos;
    double[][][] bestInsts = new double[MAX_NUM_DOCS][theMaxRankCutoff][2];

    if ((theNumPhrases > 0) || (theMinProb >= 0)) {
      theRankCutoff = theNumPhrases;
      theProbCutoff = theMinProb;
    }
    System.err.println("Preparing data...");

    // Prepare data and find total number of positive
    // examples.
    instances = new Instances(inputStream);
    instances.setClassIndex(instances.numAttributes() - 1);
    System.err.println("Counting positive examples...");

    // Number of positive examples?
    numPos = findMissedPositives(instances);
    System.err.println("Deleting columns...");

    // Deleting columns
    del = new DeleteFilter();
    del.setInvertSelection(true);
    del.setAttributeIndices(indices(instances));
    del.inputFormat(instances);
    instancesDel = Filter.useFilter(instances, del);
    if (theUseDiscretization)

      // Discretizing columns
      instancesDisc = Filter.useFilter(instancesDel, theDisc);
    else
      instancesDisc = instancesDel;
    System.err.println("Finding most probable keyphrases...");

    // Find the most probable keyphrases.
    mostProbablePhrases(instances, instancesDisc, bestInsts);
    System.err.println("Computing statistics...");

    // Compute F-Measure.
    System.err.println("Rank cutoff: " + theRankCutoff);
    System.err.println("Probability cutoff: " + theProbCutoff);
    System.err.println("F-Measure: " +
                       fMeasure(instances, bestInsts, theRankCutoff,
                                theProbCutoff, numPos, true));

    // Output average precisions for ranks and confidence intervals
    System.err.println("Rank\tPrec.\tConf.");
    precisions(instances, bestInsts);
  }

  /**
   * Generate results for different size training sets.
   */
  public void iterateTrainingSets(InputStream train, InputStream test)
    throws Exception {

    int[] sizes = new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40, 50,
                             60, 70, 80, 90, 100, 110, 120, 130};
    Instances trainInst, testInst, tempTrainInst, tempTrainInstDisc,
      testInstDisc;
    SelectFilter selFilter;
    DeleteFilter del;
    int[] indices;
    int numPos;
    double[][][] bestInsts;

    trainInst = new Instances(train);
    trainInst.setClassIndex(trainInst.numAttributes() - 1);
    testInst = new Instances(test);
    testInst.setClassIndex(testInst.numAttributes() - 1);
    for (int i = 0; i < sizes.length; i++) {
      System.err.println("Processing " + sizes[i] + " documents");

      // Delete documents
      System.err.println("Selecting documents...");
      selFilter = new SelectFilter();
      indices = new int[sizes[i]];
      for (int j = 0; j < sizes[i]; j++)
        indices[j] = j;
      selFilter.setInvertSelection(false);
      selFilter.setAttributeIndex(2);
      selFilter.setNominalIndicesArr(indices);
      selFilter.inputFormat(trainInst);
      tempTrainInst = Filter.useFilter(trainInst, selFilter);
      numPos = findMissedPositives(tempTrainInst);
      System.err.println("Deleting columns...");

      // Deleting columns
      del = new DeleteFilter();
      del.setInvertSelection(true);
      del.setAttributeIndices(indices(tempTrainInst));
      del.inputFormat(tempTrainInst);
      tempTrainInstDisc = Filter.useFilter(tempTrainInst, del);
      if (theUseDiscretization) {
        System.err.println("Discretizing columns...");

        // Discretizing columns
        theDisc = new DiscretiseFilter();
        theDisc.setAttributeIndices(DISC_RANGE);
        if (theEqual) {
          theDisc.setBins(theNumBins);
          theDisc.setClassIndex(-1);
          theDisc.setUseMDL(false);
        } else
          theDisc.setClassIndex(tempTrainInstDisc.numAttributes());
        theDisc.inputFormat(tempTrainInstDisc);
        tempTrainInstDisc = Filter.useFilter(tempTrainInstDisc, theDisc);
      }
      System.err.println("Building classifier...");

      // Build the classifier.
      theClassifier.buildClassifier(tempTrainInstDisc);
      System.err.println("\n" + theClassifier);
      if ((theNumPhrases <= 0) && (theUseTrain) && (theMinProb < 0)) {
        System.err.println("Finding most probable keyphrases...");

        // Find the most probable keyphrases.
        bestInsts = new double[MAX_NUM_DOCS][theMaxRankCutoff][2];
        mostProbablePhrases(tempTrainInst, tempTrainInstDisc, bestInsts);
        System.err.println("Finding best cutoffs...");

        // Find the best cutoffs.
        bestCutoffs(tempTrainInst, bestInsts, numPos);
      } else if ((theNumPhrases > 0) || (theMinProb >= 0)) {
        theRankCutoff = theNumPhrases;
        theProbCutoff = theMinProb;
      }

      // Number of positive examples?
      numPos = findMissedPositives(testInst);
      System.err.println("Deleting columns...");

      // Deleting columns
      del = new DeleteFilter();
      del.setInvertSelection(true);
      del.setAttributeIndices(indices(testInst));
      del.inputFormat(testInst);
      testInstDisc = Filter.useFilter(testInst, del);

      // Discretizing columns
      testInstDisc = Filter.useFilter(testInstDisc, theDisc);
      System.err.println("Finding most probable keyphrases...");

      // Find the most probable keyphrases.
      bestInsts = new double[MAX_NUM_DOCS][theMaxRankCutoff][2];
      mostProbablePhrases(testInst, testInstDisc, bestInsts);
      System.err.println("Computing statistics...");

      // Compute F-Measure.
      System.err.println("Rank cutoff: " + theRankCutoff);
      System.err.println("Probability cutoff: " + theProbCutoff);
      System.err.println("F-Measure: " +
                         fMeasure(testInst, bestInsts, theRankCutoff,
                                  theProbCutoff, numPos, false));

      // Output average precisions for ranks and confidence intervals
      System.err.println("Rank\tPrec.\tConf.");
      precisions(testInst, bestInsts);
    }
  }

  /**
   * Returns an enumeration describing the available options.
   * @return an enumeration of all the available options
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(14);

    newVector.addElement(new Option(
      "\tName of the training file.",
      "t", 1, "-t "));
    newVector.addElement(new Option(
      "\tName of the test file.",
      "T", 1, "-T "));
    newVector.addElement(new Option(
      "\tSpecify list of columns to be used. Class column has to\n" +
      "\tbe listed.",
      "R", 1, "-R "));
    newVector.addElement(new Option(
      "\tUse equal width discretization.",
      "E", 0, "-E"));
    newVector.addElement(new Option(
      "\tDelete unlikely subphrases. Takes column to be used\n" +
      "\tin conjunction with probability to delete unlikely\n" +
      "\tsubphrases. Higher value is better. (\"\" means no column)",
      "S", 1, "-S "));
    newVector.addElement(new Option(
      "\tDelete unlikely subphrases. Takes column to be used in\n" +
      "\tconjunction with probability to delete unlikely subphrases.\n" +
      "\tLower value is better. (\"\" means no column)",
      "s", 1, "-s "));
    newVector.addElement(new Option(
      "\tSecondary key for sorting phrases in table. Primary key:\n" +
      "\tprobability. Higher value is better. (\"\" means no secondary key)",
      "K", 1, "-K "));
    newVector.addElement(new Option(
      "\tSecondary key for sorting phrases in table. Primary key:\n" +
      "\tprobability. Lower value is better. (\"\" means no secondary key)",
      "k", 1, "-k "));
    newVector.addElement(new Option(
      "\tColumn to be used for batched cross-validation.\n" +
      "\t(\"\" means standard 5-fold cross-validation)",
      "B", 1, "-B "));
    newVector.addElement(new Option(
      "\tThe maximum rank cutoff. (Default 20)",
      "X", 1, "-X "));
    newVector.addElement(new Option(
      "\tThe minimum rank cutoff. (Default 1)",
      "M", 1, "-M "));
    newVector.addElement(new Option(
      "\tDon't optimize F-measure. Just output the given number of\n" +
      "\tphrases.",
      "N", 1, "-N "));
    newVector.addElement(new Option(
      "\tDon't optimize F-measure. Just output phrases with probability\n" +
      "\tgreater than or equal to the given one.",
      "P", 1, "-P "));
    return newVector.elements();
  }

  /**
   * Parses a given list of options.
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void parseOptions(String[] options) throws Exception {

    String numPhrases = Utils.getOption('N', options);
    if (numPhrases.length() != 0)
      theNumPhrases = Integer.parseInt(numPhrases);
    String minProb = Utils.getOption('P', options);
    if (minProb.length() != 0)
      theMinProb = (new Double(minProb)).doubleValue();
    String useList = Utils.getOption('R', options);
    if (useList.length() != 0) {
      StringTokenizer st = new StringTokenizer(useList, ",", false);
      theCols = new String[st.countTokens()];
      for (int i = 0; i < theCols.length; i++)
        theCols[i] = st.nextToken();
    }
    if (Utils.getFlag('E', options)) {
      theEqual = true;
    }
    if (Utils.getFlag('D', options)) {
      theUseDiscretization = false;
    }
    String subphrases = Utils.getOption('S', options);
    if (subphrases.length() != 0) {
      theSubphraseColumnName = subphrases;
      theSubphraseHi = true;
    }
    String subphrasesII = Utils.getOption('s', options);
    if (subphrasesII.length() != 0) {
      if (subphrases.length() != 0)
        throw new Exception("Can't use -S and -s at the same time!");
      theSubphraseColumnName = subphrasesII;
      theSubphraseHi = false;
    }
    String key = Utils.getOption('K', options);
    if (key.length() != 0) {
      theSecKeyColumnName = key;
      theSecKeyHi = true;
    }
    String keyII = Utils.getOption('k', options);
    if (keyII.length() != 0) {
      if (key.length() != 0)
        throw new Exception("Can't use -K and -k at the same time!");
      theSecKeyColumnName = keyII;
      theSecKeyHi = false;
    }
    String cv = Utils.getOption('B', options);
    if (cv.length() != 0) {
      theUseTrain = false;
      theCVColumnName = cv;
    }
    String maxRankCutoff = Utils.getOption('X', options);
    if (maxRankCutoff.length() != 0)
      theMaxRankCutoff = Integer.parseInt(maxRankCutoff);
    String minRankCutoff = Utils.getOption('M', options);
    if (minRankCutoff.length() != 0)
      theMinRankCutoff = Integer.parseInt(minRankCutoff);

    // Adjust the rank cutoffs only if a fixed number of phrases (-N)
    // was actually given; otherwise theNumPhrases is -1 and must not
    // overwrite the minimum rank cutoff.
    if (theNumPhrases > 0) {
      if (theNumPhrases > theMaxRankCutoff)
        theMaxRankCutoff = theNumPhrases;
      if (theNumPhrases < theMinRankCutoff)
        theMinRankCutoff = theNumPhrases;
    }
  }

  // ======================================
  // Main method for command line interface
  // ======================================
  public static void main(String[] opts) {

    KEP kep;
    BufferedInputStream trainFile, testFile;
    ObjectInputStream modelIn;
    ObjectOutputStream modelOut;

    try {
      if (Utils.getFlag('I', opts)) {
        System.err.println("Iterating over training sets...");
        String train = Utils.getOption('t', opts);
        String test = Utils.getOption('T', opts);
        kep = new KEP();
        kep.parseOptions(opts);
        trainFile = new BufferedInputStream(new FileInputStream(train));
        testFile = new BufferedInputStream(new FileInputStream(test));
        kep.iterateTrainingSets(trainFile, testFile);
      } else {
        String model = Utils.getOption('m', opts);
        String train = Utils.getOption('t', opts);
        if (model.length() != 0) {
          if (train.length() != 0) {
            System.err.println("Building classifier and saving it...");
            kep = new KEP();
            kep.parseOptions(opts);
            trainFile = new BufferedInputStream(new FileInputStream(train));
            kep.buildClassifier(trainFile);
            modelOut = new ObjectOutputStream(
              new GZIPOutputStream(new FileOutputStream(model)));
            modelOut.writeObject(kep);
            modelOut.flush();
            modelOut.close();
          } else {
            System.err.println("Loading classifier...");
            modelIn = new ObjectInputStream(
              new GZIPInputStream(new FileInputStream(model)));
            kep = (KEP) modelIn.readObject();
            modelIn.close();
            kep.parseOptions(opts);
          }
        } else if (train.length() != 0) {
          System.err.println("Building classifier...");
          kep = new KEP();
          kep.parseOptions(opts);
          trainFile = new BufferedInputStream(new FileInputStream(train));
          kep.buildClassifier(trainFile);
        } else {
          throw new Exception("Neither classifier nor training file given!");
        }
        String test = Utils.getOption('T', opts);
        if (test.length() != 0) {
          System.err.println("Classifying phrases...");
          testFile = new BufferedInputStream(new FileInputStream(test));
          kep.classifyPhrases(testFile);
        }
      }
    } catch (Exception e) {
      System.err.println("\nOptions:\n");
      kep = new KEP();
      Enumeration en = kep.listOptions();
      while (en.hasMoreElements()) {
        Option option = (Option) en.nextElement();
        System.err.println(option.synopsis() + '\n' + option.description());
      }
      System.err.println();
      e.printStackTrace();
    }
  }
}
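
/*
 * Example command lines (a sketch only; "train.arff", "test.arff" and
 * "kea.model" are placeholder file names, but the flags are the ones
 * handled by parseOptions() and main() above):
 *
 *   java KEP -t train.arff -m kea.model
 *       builds a classifier from train.arff and serializes it to kea.model
 *
 *   java KEP -m kea.model -T test.arff
 *       loads the saved classifier and extracts keyphrases from test.arff
 *
 *   java KEP -t train.arff -T test.arff -N 5
 *       trains and then reports the five most probable phrases per test
 *       document instead of optimizing the F-measure
 */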