Changeset 34778
- Timestamp:
- 2021-02-01T17:03:58+13:00 (3 years ago)
- Location:
- main/trunk/model-sites-dev/mars
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
main/trunk/model-sites-dev/mars/RUN-APPLY-AROUSAL-MODEL.sh
r34774 r34778 1 1 #!/bin/bash 2 3 input_data_filename=${1:-collect/deam/etc/deam-essentia-features-arousal-valence.arff} 4 output_data_filename=${2:-predicted-arousal.csv} 2 5 3 6 . ./_set_weka_classpath.bash … … 5 8 java -cp "$cp_args" WekaApplyArousalModel \ 6 9 "reptree-arousal-serialized.model" \ 7 "collect/deam/etc/deam-essentia-features-arousal-valence.arff" \ 8 "predicted-arousal.csv" 10 "$input_data_filename" \ 11 "$output_data_filename" 12 13 echo "****" 14 echo "* Saved output data: $output_data_filename" 15 echo "****" 16 9 17 10 18 -
main/trunk/model-sites-dev/mars/WekaApplyArousalModel.java
r34774 r34778 10 10 11 11 import weka.core.converters.ConverterUtils.DataSource; 12 import weka.core.Attribute; 12 13 import weka.core.Instance; 13 14 import weka.core.Instances; 14 15 import weka.core.SerializationHelper; 15 16 16 //import weka.filters.Filter; 17 //import weka.filters.unsupervised.attribute.Remove; 18 19 //import weka.classifiers.Evaluation; 17 import weka.filters.Filter; 18 import weka.filters.unsupervised.attribute.Remove; 19 20 20 import weka.classifiers.Classifier; 21 //import weka.classifiers.trees.REPTree;22 23 24 25 //import weka.core.Instances;26 21 27 22 … … 31 26 class WekaApplyArousalModel 32 27 { 28 public final static String AROUSAL_ATTRIBUTE_NAME = "arousal_sample_26500ms"; 29 public final static String VALANCE_ATTRIBUTE_NAME = "valance_sample_26500ms"; 30 31 public final static String PREDICT_ATTRIBUTE_NAME = "arousal_sample_26500ms"; 32 33 34 public static Instances applyFilter(Instances data_instances,String additional_remove) 35 { 36 String remove_option_args = ("-R 425,458-466"); // remove ordinal attributes 37 if (additional_remove != null) { 38 // top up, e.g. '472' representing valance in ground-truth files 39 remove_option_args += "," + additional_remove; 40 } 41 42 43 Instances filtered_data_instances = null; 44 try { 45 String[] filter_options = weka.core.Utils.splitOptions(remove_option_args); // remove ordinal attributes and 'valance' 46 Remove filter_remove = new Remove(); 47 filter_remove.setOptions(filter_options); 48 filter_remove.setInputFormat(data_instances); 49 filtered_data_instances = Filter.useFilter(data_instances, filter_remove); 50 } 51 catch (Exception e) { 52 e.printStackTrace(); 53 System.exit(1); 54 } 55 56 return filtered_data_instances; 57 } 58 59 60 61 public static void checkDatasetInstancesCompatible(Instances new_instances) 62 { 63 64 final String gt_datasource_filename = "collect/deam/etc/deam-essentia-features-arousal-valence.arff"; 65 66 try { 67 DataSource gt_data_source = new DataSource(gt_datasource_filename); 68 Instances gt_instances = gt_data_source.getDataSet(); 69 gt_instances = applyFilter(gt_instances,"472"); // top up with removing 'valance' 70 71 gt_instances.setClassIndex(gt_instances.numAttributes() - 1); 72 73 String equal_header_message = gt_instances.equalHeadersMsg(new_instances); 74 75 if (equal_header_message != null) { 76 System.err.println("Headers to datasets were not equal!"); 77 System.err.println(equal_header_message); 78 System.exit(1); 79 } 80 } 81 catch (Exception e) { 82 e.printStackTrace(); 83 System.exit(1); 84 } 85 86 } 87 33 88 public static void main(String[] args) 34 89 { … … 49 104 BufferedInputStream bis= new BufferedInputStream(fis); 50 105 Classifier classifier = (Classifier)SerializationHelper.read(bis); 51 52 53 54 // load unlabeled data 106 107 // 108 // Load in unlabeled data 109 // 110 55 111 System.out.println("Loading unlabeled instances: " + unclassified_data_input_filename); 56 FileReader fr = new FileReader(unclassified_data_input_filename); 57 BufferedReader br = new BufferedReader(fr); 58 Instances gt_instances = new Instances(br); 59 br.close(); 60 61 // set class attribute 62 gt_instances.setClassIndex(gt_instances.numAttributes() - 2); 63 64 // create copy 65 Instances unlabeled_instances = new Instances(gt_instances); 112 113 //FileReader fr = new FileReader(unclassified_data_input_filename); 114 //BufferedReader br = new BufferedReader(fr); 115 //Instances unlabeled_instances = new Instances(br); 116 //br.close(); 117 118 DataSource data_source = new DataSource(unclassified_data_input_filename); 119 Instances unlabeled_instances = data_source.getDataSet(); 120 121 // Work out if we're dealing with a ground-truth ARFF file or not 122 // (i.e. already has the desired attribute) 123 124 Instances groundtruth_instances = null; 125 Attribute predict_attribute = unlabeled_instances.attribute(PREDICT_ATTRIBUTE_NAME); 126 127 if (predict_attribute == null) { 128 129 unlabeled_instances = applyFilter(unlabeled_instances,null); // no additional top-up to remove 130 int num_attributes = unlabeled_instances.numAttributes(); 131 132 Attribute arousal_attribute = new Attribute(AROUSAL_ATTRIBUTE_NAME); 133 unlabeled_instances.insertAttributeAt(arousal_attribute,num_attributes); 134 num_attributes++; 135 136 //Attribute valance_attribute = new Attribute(VALANCE_ATTRIBUTE_NAME); 137 //unlabeled_instances.insertAttributeAt(valance_attribute,num_attributes); 138 //num_attributes++; 139 140 //predict_attribute = new Attribute(PREDICT_ATTRIBUTE_NAME); 141 //unlabeled_instances.insertAttributeAt(predict_attribute,num_attributes); 142 //unlabeled_instances.setClassIndex(num_attributes); 143 //num_attributes++; 144 145 //predict_attribute = unlabeled_instances.attribute(PREDICT_ATTRIBUTE_NAME); 146 //unlabeled_instances.setClass(predict_attribute); 147 148 //unlabeled_instances.setClassIndex(num_attributes - 1); 149 } 150 else { 151 // Dealing with ground-truth data: 152 // => already has 'arousal' attribute 153 // => in fact has 'valance' attribute too, which we want to remove 154 155 unlabeled_instances = applyFilter(unlabeled_instances,"472"); // top-up with removal of 'valance' 156 //unlabeled_instances.setClass(predict_attribute); 157 158 // reference share this as 'groundtruth_instances' to trigger error calculation and output 159 groundtruth_instances = unlabeled_instances; 160 } 161 162 int num_attributes = unlabeled_instances.numAttributes(); 163 unlabeled_instances.setClassIndex(num_attributes - 1); 164 165 checkDatasetInstancesCompatible(unlabeled_instances); 166 167 // Set class attribute 168 //unlabeled_instances.setClassIndex(gt_instances.numAttributes() - 2); 169 170 171 // Create copy where the predictions are to be made 66 172 Instances labeled_instances = new Instances(unlabeled_instances); 67 68 // label instances 173 174 // 175 // Label instances 176 // 177 69 178 final int num_instances = unlabeled_instances.numInstances(); 70 179 for (int i=0; i<num_instances; i++) { 71 180 Instance unlabeled_instance = unlabeled_instances.instance(i); 181 72 182 System.out.print("Making prediction for: " + i + "/" + num_instances); 73 double classified_label = classifier.classifyInstance(unlabeled_instance); 74 labeled_instances.instance(i).setClassValue(classified_label); 75 76 Instance gt_instance = gt_instances.instance(i); 77 double gt_class_value = gt_instance.classValue(); 78 System.out.println(" error: " + Math.abs(classified_label-gt_class_value)); 79 } 80 81 82 // save labeled data 183 double classified_value = classifier.classifyInstance(unlabeled_instance); 184 labeled_instances.instance(i).setClassValue(classified_value); 185 186 String formatted_classified_value = String.format("% 06.3f", classified_value); 187 188 System.out.print(" value = " + formatted_classified_value); 189 190 if (groundtruth_instances != null) { 191 Instance gt_instance = groundtruth_instances.instance(i); 192 double gt_class_value = gt_instance.classValue(); 193 double error = Math.abs(classified_value - gt_class_value); 194 195 String formatted_error = String.format("%.3f", error); 196 System.out.print(" [error: " + formatted_error + "]"); 197 } 198 System.out.println(); 199 } 200 201 // 202 // Save labeled data 203 // 204 83 205 System.out.println("Saving labeled instances: " + classified_data_output_filename); 84 206 FileWriter fw = new FileWriter(classified_data_output_filename); -
main/trunk/model-sites-dev/mars/WekaTrainArousalModel.java
r34774 r34778 36 36 if (args.length != 2) { 37 37 System.err.println("Error: incorrect number of command-line arguments"); 38 System.err.println("Usage: input_training_data.arff output-model. ser");38 System.err.println("Usage: input_training_data.arff output-model.{model|ser}"); 39 39 System.exit(1); 40 40 } … … 51 51 //String[] filter_options = weka.core.Utils.splitOptions("-R 472"); 52 52 //String[] filter_options = weka.core.Utils.splitOptions("-R 472 -R 458-466"); 53 String[] filter_options = weka.core.Utils.splitOptions("-R 4 72,458-466");53 String[] filter_options = weka.core.Utils.splitOptions("-R 425,458-466,472"); // remove ordinal attributes and 'valance' 54 54 Remove filter_remove = new Remove(); 55 55 filter_remove.setOptions(filter_options);
Note:
See TracChangeset
for help on using the changeset viewer.