Changeset 34798
- Timestamp:
- 2021-02-02T22:33:22+13:00 (3 years ago)
- Location:
- main/trunk/model-sites-dev/mars/src/java/org/greenstone/mars
- Files:
-
- 4 edited
Legend:
- Unmodified
- Added
- Removed
-
main/trunk/model-sites-dev/mars/src/java/org/greenstone/mars/WekaApplyArousalModel.java
r34797 r34798 1 1 package org.greenstone.mars; 2 3 //import java.util.Random;4 2 5 3 import java.io.BufferedInputStream; … … 28 26 class WekaApplyArousalModel 29 27 { 30 31 //public final static String PREDICT_ATTRIBUTE_NAME = WekaUtil.AROUSAL_ATTRIBUTE_NAME;32 33 34 35 28 public static void main(String[] args) 36 29 { 37 /*38 if (args.length != 3) {39 System.err.println("Error: incorrect number of command-line arguments");40 System.err.println("Usage: trained-model.{model|ser} unclassified-data.{arff|csv} classified-data.{arff|csv}");41 System.exit(1);42 }43 */44 45 30 WekaUtil.checkUsageApplyModel(args); 46 31 … … 49 34 String classified_data_output_filename = args[2]; 50 35 51 52 36 Classifier classifier = WekaUtil.loadClassifierModel(classifier_input_filename); 53 54 /*55 System.out.println("Loading Weka saved Classifier: " + classifier_input_filename);56 57 try {58 FileInputStream fis = new FileInputStream(classifier_input_filename);59 BufferedInputStream bis= new BufferedInputStream(fis);60 Classifier classifier = (Classifier)SerializationHelper.read(bis);61 62 //63 // Load in unlabeled data64 //65 66 System.out.println("Loading unlabeled instances: " + unclassified_data_input_filename);67 68 DataSource data_source = new DataSource(unclassified_data_input_filename);69 Instances unlabeled_instances = data_source.getDataSet();70 71 // Work out if we're dealing with a ground-truth ARFF file or not72 // (i.e. already has the desired attribute)73 74 Instances groundtruth_instances = null;75 Attribute predict_attribute = unlabeled_instances.attribute(PREDICT_ATTRIBUTE_NAME);76 77 if (predict_attribute == null) {78 79 unlabeled_instances = WekaUtil.applyFilter(unlabeled_instances,null); // no additional top-up to remove80 int num_attributes = unlabeled_instances.numAttributes();81 82 Attribute arousal_attribute = new Attribute(WekaUtil.AROUSAL_ATTRIBUTE_NAME);83 unlabeled_instances.insertAttributeAt(arousal_attribute,num_attributes);84 num_attributes++;85 }86 else {87 // Dealing with ground-truth data:88 // => already has 'arousal' attribute89 // => in fact has 'valence' attribute too, which we want to remove90 91 unlabeled_instances = WekaUtil.applyFilter(unlabeled_instances,"472"); // top-up with removal of 'valence'92 93 // reference share this as 'groundtruth_instances' to trigger error calculation and output94 groundtruth_instances = unlabeled_instances;95 }96 97 int num_attributes = unlabeled_instances.numAttributes();98 99 // Set class attribute100 unlabeled_instances.setClassIndex(num_attributes - 1);101 102 WekaUtil.checkDatasetInstancesCompatible(unlabeled_instances);103 104 105 106 107 108 // Create copy where the predictions are to be made109 Instances labeled_instances = new Instances(unlabeled_instances);110 111 //112 // Label instances113 //114 115 final int num_instances = unlabeled_instances.numInstances();116 for (int i=0; i<num_instances; i++) {117 Instance unlabeled_instance = unlabeled_instances.instance(i);118 119 System.out.print("Making prediction for: " + i + "/" + num_instances);120 double classified_value = classifier.classifyInstance(unlabeled_instance);121 labeled_instances.instance(i).setClassValue(classified_value);122 123 String formatted_classified_value = String.format("% 06.3f", classified_value);124 125 System.out.print(" value = " + formatted_classified_value);126 127 if (groundtruth_instances != null) {128 Instance gt_instance = groundtruth_instances.instance(i);129 double gt_class_value = gt_instance.classValue();130 double error = Math.abs(classified_value - gt_class_value);131 132 String formatted_error = String.format("%.3f", error);133 System.out.print(" [error: " + formatted_error + "]");134 }135 System.out.println();136 }137 138 //139 // Save labeled data140 //141 142 System.out.println("Saving labeled instances: " + classified_data_output_filename);143 FileWriter fw = new FileWriter(classified_data_output_filename);144 BufferedWriter bw = new BufferedWriter(fw);145 146 bw.write(labeled_instances.toString());147 bw.newLine();148 bw.flush();149 bw.close();150 151 }152 catch (Exception e) {153 e.printStackTrace();154 }155 156 */157 158 37 159 38 Instances unlabeled_instances= WekaUtil.loadInstancesForClassification(unclassified_data_input_filename); -
main/trunk/model-sites-dev/mars/src/java/org/greenstone/mars/WekaApplyValenceModel.java
r34797 r34798 1 1 package org.greenstone.mars; 2 3 //import java.util.Random;4 2 5 3 import java.io.BufferedInputStream; -
main/trunk/model-sites-dev/mars/src/java/org/greenstone/mars/WekaTrainArousalModel.java
r34797 r34798 1 1 package org.greenstone.mars; 2 2 3 //import weka.core.converters.ConverterUtils.DataSource;4 3 import weka.core.Instances; 5 4 6 //import weka.filters.Filter;7 //import weka.filters.unsupervised.attribute.Remove;8 9 5 import weka.classifiers.Classifier; 10 6 import weka.classifiers.Evaluation; … … 35 31 36 32 public static void main(String[] args) 37 { 38 /* 39 if (args.length != 2) { 40 System.err.println("Error: incorrect number of command-line arguments"); 41 System.err.println("Usage: input_training_data.arff output-model.{model|ser}"); 42 System.exit(1); 43 }*/ 44 33 { 45 34 WekaUtil.checkUsageTraining(args); 46 35 … … 62 51 catch (Exception e) { 63 52 e.printStackTrace(); 64 } 65 66 67 /* 68 System.out.println("Training on ARFF file: " + input_arff_filename); 69 70 try { 71 DataSource data_source = new DataSource(input_arff_filename); 72 Instances data_instances = data_source.getDataSet(); 73 74 // ********* 75 Instances filtered_data_instances = WekaUtil.applyFilter(data_instances,"472"); // top-up with removal of 'valence' 76 77 78 // Filtering above has removed R472, so class to predict is numAttributes()-1 79 filtered_data_instances.setClassIndex(filtered_data_instances.numAttributes()-1); 80 81 82 // Build scheme/classifier 83 REPTree classifier = new REPTree(); // scheme 84 String[] reptree_options = weka.core.Utils.splitOptions("-M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0"); 85 classifier.setOptions(reptree_options); 86 87 classifier.buildClassifier(filtered_data_instances); 88 System.out.println(classifier); 89 90 / * 91 // Evaluate 92 Evaluation eval = new Evaluation(filtered_data_instances); 93 Random rand = new Random(1); 94 int folds = 10; 95 eval.crossValidateModel(classifier, filtered_data_instances, folds, rand); 96 System.out.println(eval.toSummaryString()); 97 * / 98 99 Evaluation eval = WekaUtil.evaluateClassifier(classifier,filtered_data_instances); 100 System.out.println(eval.toSummaryString()); 101 102 System.out.println("Saving REPTree classifier model as: " + output_model_filename); 103 weka.core.SerializationHelper.write(output_model_filename, classifier); 104 105 106 } 107 catch (Exception e) { 108 e.printStackTrace(); 109 } 110 */ 53 } 111 54 } 112 55 } -
main/trunk/model-sites-dev/mars/src/java/org/greenstone/mars/WekaTrainValenceModel.java
r34797 r34798 1 1 package org.greenstone.mars; 2 2 3 //import weka.core.converters.ConverterUtils.DataSource;4 3 import weka.core.Instances; 5 4 6 //import weka.filters.Filter;7 //import weka.filters.unsupervised.attribute.Remove;8 9 5 import weka.classifiers.Classifier; 10 6 import weka.classifiers.Evaluation; 11 //import weka.classifiers.trees.REPTree;12 7 13 8 // Based on: … … 26 21 public static void main(String[] args) 27 22 { 28 /*29 if (args.length != 2) {30 System.err.println("Error: incorrect number of command-line arguments");31 System.err.println("Usage: input_training_data.arff output-model.{model|ser}");32 System.exit(1);33 }*/34 35 23 WekaUtil.checkUsageTraining(args); 36 24 … … 52 40 catch (Exception e) { 53 41 e.printStackTrace(); 54 } 55 56 /* 57 System.out.println("Training on ARFF file: " + input_arff_filename); 58 59 try { 60 DataSource data_source = new DataSource(input_arff_filename); 61 Instances data_instances = data_source.getDataSet(); 62 63 // ********* 64 Instances filtered_data_instances = WekaUtil.applyFilter(data_instances,"471"); // top-up with removal of 'arousal' 65 66 // ****** comment varies 67 // With removal of R471, last column is the one we want to predict => numAttributes()-1 68 filtered_data_instances.setClassIndex(filtered_data_instances.numAttributes()-1); 69 70 71 // Build scheme/classifier 72 REPTree classifier = new REPTree(); // scheme 73 String[] reptree_options = weka.core.Utils.splitOptions("-M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0"); 74 classifier.setOptions(reptree_options); 75 76 classifier.buildClassifier(filtered_data_instances); 77 System.out.println(classifier); 78 79 80 Evaluation eval = WekaUtil.evaluateClassifier(classifier,filtered_data_instances); 81 System.out.println(eval.toSummaryString()); 82 83 System.out.println("Saving REPTree classifier model as: " + output_model_filename); 84 weka.core.SerializationHelper.write(output_model_filename, classifier); 85 86 87 } 88 catch (Exception e) { 89 e.printStackTrace(); 90 } 91 92 */ 93 42 } 94 43 } 95 44 }
Note:
See TracChangeset
for help on using the changeset viewer.