Changeset 34788 for main/trunk/model-sites-dev/mars/src/java
- Timestamp:
- 2021-02-02T00:46:14+13:00 (3 years ago)
- Location:
- main/trunk/model-sites-dev/mars/src/java/org/greenstone/mars
- Files:
-
- 2 added
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
main/trunk/model-sites-dev/mars/src/java/org/greenstone/mars/WekaApplyArousalModel.java
r34786 r34788 28 28 class WekaApplyArousalModel 29 29 { 30 //public final static String AROUSAL_ATTRIBUTE_NAME = "arousal_sample_26500ms";31 //public final static String VALANCE_ATTRIBUTE_NAME = "valance_sample_26500ms";32 30 33 public final static String PREDICT_ATTRIBUTE_NAME = WekaUtil.AROUSAL_ATTRIBUTE_NAME;31 //public final static String PREDICT_ATTRIBUTE_NAME = WekaUtil.AROUSAL_ATTRIBUTE_NAME; 34 32 35 33 36 /*37 public static Instances applyFilter(Instances data_instances,String additional_remove)38 {39 String remove_option_args = ("-R 425,458-466"); // remove ordinal attributes40 if (additional_remove != null) {41 // top up, e.g. '472' representing valance in ground-truth files42 remove_option_args += "," + additional_remove;43 }44 45 46 Instances filtered_data_instances = null;47 try {48 String[] filter_options = weka.core.Utils.splitOptions(remove_option_args); // remove ordinal attributes and 'valance'49 Remove filter_remove = new Remove();50 filter_remove.setOptions(filter_options);51 filter_remove.setInputFormat(data_instances);52 filtered_data_instances = Filter.useFilter(data_instances, filter_remove);53 }54 catch (Exception e) {55 e.printStackTrace();56 System.exit(1);57 }58 59 return filtered_data_instances;60 }61 */62 63 /*64 65 public static void checkDatasetInstancesCompatible(Instances new_instances)66 {67 68 final String gt_datasource_filename = "collect/deam/etc/deam-essentia-features-arousal-valence.arff";69 70 try {71 DataSource gt_data_source = new DataSource(gt_datasource_filename);72 Instances gt_instances = gt_data_source.getDataSet();73 gt_instances = applyFilter(gt_instances,"472"); // top up with removing 'valance'74 75 gt_instances.setClassIndex(gt_instances.numAttributes() - 1);76 77 String equal_header_message = gt_instances.equalHeadersMsg(new_instances);78 79 if (equal_header_message != null) {80 System.err.println("Headers to datasets were not equal!");81 System.err.println(equal_header_message);82 System.exit(1);83 }84 }85 catch (Exception e) {86 e.printStackTrace();87 System.exit(1);88 }89 90 }91 */92 34 93 35 public static void main(String[] args) 94 36 { 37 /* 95 38 if (args.length != 3) { 96 39 System.err.println("Error: incorrect number of command-line arguments"); … … 98 41 System.exit(1); 99 42 } 100 43 */ 44 45 WekaUtil.checkUsageApplyModel(args); 46 101 47 String classifier_input_filename = args[0]; 102 48 String unclassified_data_input_filename = args[1]; 103 49 String classified_data_output_filename = args[2]; 104 50 51 52 Classifier classifier = WekaUtil.loadClassifierModel(classifier_input_filename); 53 54 /* 105 55 System.out.println("Loading Weka saved Classifier: " + classifier_input_filename); 106 56 … … 152 102 WekaUtil.checkDatasetInstancesCompatible(unlabeled_instances); 153 103 104 105 106 107 154 108 // Create copy where the predictions are to be made 155 109 Instances labeled_instances = new Instances(unlabeled_instances); … … 199 153 e.printStackTrace(); 200 154 } 201 155 156 */ 157 158 159 Instances unlabeled_instances= WekaUtil.loadInstancesForClassification(unclassified_data_input_filename); 160 161 // It is permissible to run this code and supply it with a data file that includes groundtruth in it. 162 // In this situation, the 'unlabeled' instances: 163 // (i) need to be massaged to be in the same form as truly unlabeled data 164 // (ii) we also set up 'groundtruth_instances' as an alias (reference) to 'filtered_unlabeled_instanced' 165 // to trigger calculating the error on the predicted vaues 166 167 boolean has_groundtruth_data = WekaUtil.instancesHavePredictAttribute(unlabeled_instances,WekaUtil.AROUSAL_ATTRIBUTE_NAME); 168 169 // The following deals with (i) internally, ensuring that what is returned is suitable for making predictions on 170 Instances filtered_unlabeled_instances 171 = WekaUtil.filterInstancesForApplying(unlabeled_instances,has_groundtruth_data, 172 WekaUtil.AROUSAL_ATTRIBUTE_NAME,"472"); 173 174 // The following deals with (ii) 175 Instances groundtruth_instances = (has_groundtruth_data) ? filtered_unlabeled_instances : null; 176 177 Instances labeled_instances = WekaUtil.makePredictions(classifier, filtered_unlabeled_instances, groundtruth_instances); 178 179 try { 180 // Save labeled data 181 182 System.out.println("Saving labeled instances: " + classified_data_output_filename); 183 FileWriter fw = new FileWriter(classified_data_output_filename); 184 BufferedWriter bw = new BufferedWriter(fw); 185 186 bw.write(labeled_instances.toString()); 187 bw.newLine(); 188 bw.flush(); 189 bw.close(); 190 191 } 192 catch (Exception e) { 193 e.printStackTrace(); 194 } 195 202 196 } 203 197 } -
main/trunk/model-sites-dev/mars/src/java/org/greenstone/mars/WekaTrainArousalModel.java
r34786 r34788 1 1 package org.greenstone.mars; 2 2 3 import java.util.Random; 4 5 import weka.core.converters.ConverterUtils.DataSource; 3 //import weka.core.converters.ConverterUtils.DataSource; 6 4 import weka.core.Instances; 7 5 8 import weka.filters.Filter;9 import weka.filters.unsupervised.attribute.Remove;6 //import weka.filters.Filter; 7 //import weka.filters.unsupervised.attribute.Remove; 10 8 9 import weka.classifiers.Classifier; 11 10 import weka.classifiers.Evaluation; 12 import weka.classifiers.trees.REPTree;13 11 14 12 // Based on: … … 30 28 class WekaTrainArousalModel 31 29 { 32 // Scheme: weka.classifiers.trees.REPTree -M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0 30 // Scheme: weka.classifiers.trees.REPTree -M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0 -- (i.e., default vals) 33 31 // Relation: deam-essentia-features-arousal-valence-weka.filters.unsupervised.attribute.Remove-R472-weka.filters.unsupervised.attribute.Remove-R458-466 34 32 // Instances: 1743 33 34 // Note, above missed -R 425 and 427 (ordinal vals) 35 35 36 36 public static void main(String[] args) 37 37 { 38 /* 38 39 if (args.length != 2) { 39 40 System.err.println("Error: incorrect number of command-line arguments"); 40 41 System.err.println("Usage: input_training_data.arff output-model.{model|ser}"); 41 42 System.exit(1); 43 }*/ 44 45 WekaUtil.checkUsageTraining(args); 46 47 String input_data_filename = args[0]; 48 String output_model_filename = args[1]; 49 50 // knock out 472 (valance) from dataset 51 Instances filtered_data_instances = WekaUtil.loadAndFilterDataInstances(input_data_filename, "472"); 52 Classifier classifier = WekaUtil.trainREPTree(filtered_data_instances); 53 System.out.println(classifier); 54 55 Evaluation eval = WekaUtil.evaluateClassifier(classifier,filtered_data_instances); 56 System.out.println(eval.toSummaryString()); 57 58 try { 59 System.out.println("Saving REPTree classifier model as: " + output_model_filename); 60 weka.core.SerializationHelper.write(output_model_filename, classifier); 42 61 } 43 String input_arff_filename = args[0]; 44 String output_model_filename = args[1]; 62 catch (Exception e) { 63 e.printStackTrace(); 64 } 45 65 66 67 /* 46 68 System.out.println("Training on ARFF file: " + input_arff_filename); 47 69 … … 49 71 DataSource data_source = new DataSource(input_arff_filename); 50 72 Instances data_instances = data_source.getDataSet(); 51 73 74 // ********* 52 75 Instances filtered_data_instances = WekaUtil.applyFilter(data_instances,"472"); // top-up with removal of 'valance' 53 76 … … 65 88 System.out.println(classifier); 66 89 90 / * 67 91 // Evaluate 68 92 Evaluation eval = new Evaluation(filtered_data_instances); … … 71 95 eval.crossValidateModel(classifier, filtered_data_instances, folds, rand); 72 96 System.out.println(eval.toSummaryString()); 97 * / 98 99 Evaluation eval = WekaUtil.evaluateClassifier(classifier,filtered_data_instances); 100 System.out.println(eval.toSummaryString()); 73 101 74 102 System.out.println("Saving REPTree classifier model as: " + output_model_filename); … … 80 108 e.printStackTrace(); 81 109 } 82 110 */ 83 111 } 84 112 } -
main/trunk/model-sites-dev/mars/src/java/org/greenstone/mars/WekaUtil.java
r34786 r34788 8 8 import java.io.FileReader; 9 9 import java.io.FileWriter; 10 11 import java.util.Random; 10 12 11 13 import weka.core.converters.ConverterUtils.DataSource; … … 19 21 20 22 import weka.classifiers.Classifier; 21 23 import weka.classifiers.Evaluation; 24 25 import weka.classifiers.trees.REPTree; 22 26 23 27 // Based on: … … 27 31 { 28 32 public final static String AROUSAL_ATTRIBUTE_NAME = "arousal_sample_26500ms"; 29 public final static String VALANCE_ATTRIBUTE_NAME = "valance_sample_26500ms"; 30 31 33 public final static String VALANCE_ATTRIBUTE_NAME = "valence_sample_26500ms"; 34 35 36 public static void checkUsageTraining(String[] args) 37 { 38 if (args.length != 2) { 39 System.err.println("Error: incorrect number of command-line arguments"); 40 System.err.println("Usage: input_training_data.arff output-model.{model|ser}"); 41 System.exit(1); 42 } 43 44 } 45 46 public static void checkUsageApplyModel(String[] args) 47 { 48 if (args.length != 3) { 49 System.err.println("Error: incorrect number of command-line arguments"); 50 System.err.println("Usage: trained-model.{model|ser} unclassified-data.{arff|csv} classified-data.{arff|csv}"); 51 System.exit(1); 52 } 53 } 54 32 55 public static Instances applyFilter(Instances data_instances, String additional_remove) 33 56 { … … 43 66 // The following filtering produces as tree with 41 nodes 44 67 45 String remove_option_args = ("-R 425,427,458-466"); 68 String remove_option_args = ("-R 425,427,458-466"); // knock out 11 attributes (472 -> 461) 46 69 47 70 … … 51 74 52 75 if (additional_remove != null) { 53 // top up, e.g. '472' representing valance in ground-truth files 76 // top up, 77 // e.g. '471' representing arousal in ground-truth files 78 // e.g. '472' representing valance in ground-truth files 54 79 remove_option_args += "," + additional_remove; 55 80 } … … 58 83 Instances filtered_data_instances = null; 59 84 try { 60 // remove ordinal attributes and any additional topups, such as 'valance' (when predicting 'arousal') 85 // remove ordinal attributes and any additional topups, 86 // such as 'valance' (when predicting 'arousal') and vice versa 61 87 62 88 String[] filter_options = weka.core.Utils.splitOptions(remove_option_args); … … 75 101 76 102 77 public static void checkDatasetInstancesCompatible(Instances new_instances )103 public static void checkDatasetInstancesCompatible(Instances new_instances, String additional_remove) 78 104 { 79 105 … … 83 109 DataSource gt_data_source = new DataSource(gt_datasource_filename); 84 110 Instances gt_instances = gt_data_source.getDataSet(); 85 gt_instances = applyFilter(gt_instances, "472"); // top up with removing 'valance'111 gt_instances = applyFilter(gt_instances,additional_remove); // remove 'valance' or 'arousal' 86 112 87 113 gt_instances.setClassIndex(gt_instances.numAttributes() - 1); … … 101 127 102 128 } 129 130 131 public static Instances loadAndFilterDataInstances(String input_filename, String additional_attribute_remove) 132 { 133 System.out.println("Training on file: " + input_filename); 134 Instances filtered_data_instances = null; 135 136 try { 137 DataSource data_source = new DataSource(input_filename); 138 Instances data_instances = data_source.getDataSet(); 139 140 // Training dataset has two ground-truth attributes: 'arousal' and 'valance'. 141 // When training for one, need to knock out the other. This is the purpose 142 // of 'additional_attribute_remove' 143 filtered_data_instances = applyFilter(data_instances,additional_attribute_remove); 144 145 // With the 'other' ground-truth attribute taken out, the column to predict 146 // will always be the last column 147 filtered_data_instances.setClassIndex(filtered_data_instances.numAttributes()-1); 148 } 149 catch (Exception e) { 150 e.printStackTrace(); 151 System.exit(1); 152 } 153 154 return filtered_data_instances; 155 } 156 157 public static Classifier trainREPTree(Instances data_instances) 158 { 159 REPTree classifier = null; 160 161 try { 162 // Build classifier 163 classifier = new REPTree(); 164 String[] reptree_options = weka.core.Utils.splitOptions("-M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0"); 165 classifier.setOptions(reptree_options); 166 167 classifier.buildClassifier(data_instances); 168 } 169 catch (Exception e) { 170 e.printStackTrace(); 171 System.exit(1); 172 } 173 174 return classifier; 175 } 176 177 178 public static Evaluation evaluateClassifier(Classifier classifier, Instances data_instances) 179 { 180 Evaluation eval = null; 181 try { 182 eval = new Evaluation(data_instances); 183 Random rand = new Random(1); 184 int folds = 10; 185 186 eval.crossValidateModel(classifier, data_instances, folds, rand); 187 } 188 catch (Exception e) { 189 e.printStackTrace(); 190 System.exit(1); 191 } 192 193 return eval; 194 } 195 196 /* Apply Model Specific Methods */ 197 198 public static Classifier loadClassifierModel(String classifier_input_filename 199 200 ) 201 { 202 System.out.println("Loading Weka saved Classifier: " + classifier_input_filename); 203 Classifier classifier = null; 204 try { 205 FileInputStream fis = new FileInputStream(classifier_input_filename); 206 BufferedInputStream bis= new BufferedInputStream(fis); 207 classifier = (Classifier)SerializationHelper.read(bis); 208 } 209 catch (Exception e) { 210 e.printStackTrace(); 211 System.exit(1); 212 } 213 214 return classifier; 215 } 216 217 218 public static Instances loadInstancesForClassification(String unlabeled_input_filename) 219 { 220 System.out.println("Loading unlabeled instances: " + unlabeled_input_filename); 221 Instances unlabeled_instances = null; 222 223 try { 224 DataSource data_source = new DataSource(unlabeled_input_filename); 225 unlabeled_instances = data_source.getDataSet(); 226 } 227 catch (Exception e) { 228 e.printStackTrace(); 229 System.exit(1); 230 } 231 232 return unlabeled_instances; 233 } 234 235 public static boolean instancesHavePredictAttribute(Instances data_instances, String predict_attribute_name) 236 { 237 Attribute predict_attribute = data_instances.attribute(predict_attribute_name); 238 239 return predict_attribute != null; 240 } 241 242 243 244 public static Instances filterInstancesForApplying(Instances unlabeled_instances, boolean has_groundtruth_data, 245 String predict_attribute_name, String additional_attribute_remove) 246 { 247 Instances filtered_unlabeled_instances = null; 248 249 250 // Work out if we're dealing with a ground-truth ARFF file or not 251 // (i.e. already has the desired attribute) 252 253 //Attribute predict_attribute = unlabeled_instances.attribute(predict_attribute_name); 254 255 if (!has_groundtruth_data) { 256 257 filtered_unlabeled_instances = WekaUtil.applyFilter(unlabeled_instances,null); // no additional top-up to remove 258 int num_attributes = filtered_unlabeled_instances.numAttributes(); 259 260 Attribute new_predict_attribute = new Attribute(predict_attribute_name); 261 filtered_unlabeled_instances.insertAttributeAt(new_predict_attribute,num_attributes); 262 // ****** 263 //num_attributes++; 264 } 265 else { 266 // Dealing with ground-truth data: 267 // => already has 'arousal' and 'valance' attributes 268 // => need to keep the 'predict_attribute_name' and remove the other one 269 // => (its -R value of which is specified in 'additional_attribute_remove') 270 271 // Need to massage instances into same form as an unclassified data input file 272 filtered_unlabeled_instances = WekaUtil.applyFilter(unlabeled_instances,additional_attribute_remove); 273 274 // reference share this as 'groundtruth_instances' to trigger error calculation and output 275 // ****** 276 // groundtruth_instances = filtered_unlabeled_instances; 277 } 278 279 int num_attributes = filtered_unlabeled_instances.numAttributes(); 280 281 // Set class attribute 282 filtered_unlabeled_instances.setClassIndex(num_attributes - 1); 283 284 // ***** Do I still want to run the check???? 285 WekaUtil.checkDatasetInstancesCompatible(filtered_unlabeled_instances, additional_attribute_remove); 286 287 return filtered_unlabeled_instances; 288 } 289 290 291 public static Instances makePredictions(Classifier classifier, Instances unlabeled_instances, Instances groundtruth_instances) 292 { 293 294 // Create copy where the predictions are to be made 295 Instances labeled_instances = new Instances(unlabeled_instances); 296 297 try { 298 // Label instances 299 300 final int num_instances = unlabeled_instances.numInstances(); 301 for (int i=0; i<num_instances; i++) { 302 Instance unlabeled_instance = unlabeled_instances.instance(i); 303 304 System.out.print("Making prediction for: " + i + "/" + num_instances); 305 double classified_value = classifier.classifyInstance(unlabeled_instance); 306 labeled_instances.instance(i).setClassValue(classified_value); 307 308 String formatted_classified_value = String.format("% 06.3f", classified_value); 309 310 System.out.print(" value = " + formatted_classified_value); 311 312 if (groundtruth_instances != null) { 313 Instance gt_instance = groundtruth_instances.instance(i); 314 double gt_class_value = gt_instance.classValue(); 315 double error = Math.abs(classified_value - gt_class_value); 316 317 String formatted_error = String.format("%.3f", error); 318 System.out.print(" [error: " + formatted_error + "]"); 319 } 320 System.out.println(); 321 } 322 } 323 catch (Exception e) { 324 e.printStackTrace(); 325 System.exit(1); 326 } 327 328 return labeled_instances; 329 } 103 330 104 331 }
Note:
See TracChangeset
for help on using the changeset viewer.