source: main/trunk/model-sites-dev/mars/WekaApplyArousalModel.java@34778

Last change on this file since 34778 was 34778, checked in by davidb, 3 years ago

Next logical commit point after testing on CSV files in amc-essentia/import

File size: 7.2 KB
import java.io.BufferedInputStream;
import java.io.FileInputStream;

import java.io.BufferedWriter;
import java.io.FileWriter;

import weka.core.converters.ConverterUtils.DataSource;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SerializationHelper;

import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

import weka.classifiers.Classifier;


// Based on:
// https://waikato.github.io/weka-wiki/use_weka_in_your_java_code/
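
// Applies a previously trained (serialized) Weka regression model that
// predicts the 'arousal_sample_26500ms' attribute to a set of unlabeled
// Essentia feature instances, and writes out the labeled dataset.
//
// Usage: java WekaApplyArousalModel trained-model.{model|ser} \
//            unclassified-data.{arff|csv} classified-data.{arff|csv}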
class WekaApplyArousalModel
{
    public final static String AROUSAL_ATTRIBUTE_NAME = "arousal_sample_26500ms";
    public final static String VALANCE_ATTRIBUTE_NAME = "valance_sample_26500ms";

    public final static String PREDICT_ATTRIBUTE_NAME = "arousal_sample_26500ms";

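    // Weka's Remove filter takes a 1-based, comma-separated list of
    // attribute indices/ranges via its -R option; the instances that come
    // back carry the reduced header the trained model expects to see.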
    public static Instances applyFilter(Instances data_instances, String additional_remove)
    {
        String remove_option_args = "-R 425,458-466"; // remove ordinal attributes
        if (additional_remove != null) {
            // additionally remove, e.g., '472' representing 'valance' in ground-truth files
            remove_option_args += "," + additional_remove;
        }

        Instances filtered_data_instances = null;
        try {
            String[] filter_options = weka.core.Utils.splitOptions(remove_option_args);
            Remove filter_remove = new Remove();
            filter_remove.setOptions(filter_options);
            filter_remove.setInputFormat(data_instances);
            filtered_data_instances = Filter.useFilter(data_instances, filter_remove);
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }

        return filtered_data_instances;
    }

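    // A Weka classifier can only be applied to instances whose header
    // (attribute names, types and order) matches the data it was trained on,
    // so compare against the filtered ground-truth dataset's header;
    // equalHeadersMsg() returns null when the two headers are compatible.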
    public static void checkDatasetInstancesCompatible(Instances new_instances)
    {
        final String gt_datasource_filename = "collect/deam/etc/deam-essentia-features-arousal-valence.arff";

        try {
            DataSource gt_data_source = new DataSource(gt_datasource_filename);
            Instances gt_instances = gt_data_source.getDataSet();
            gt_instances = applyFilter(gt_instances, "472"); // additionally remove 'valance'

            gt_instances.setClassIndex(gt_instances.numAttributes() - 1);

            String equal_header_message = gt_instances.equalHeadersMsg(new_instances);

            if (equal_header_message != null) {
                System.err.println("Headers of the two datasets were not equal!");
                System.err.println(equal_header_message);
                System.exit(1);
            }
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
    }

    public static void main(String[] args)
    {
        if (args.length != 3) {
            System.err.println("Error: incorrect number of command-line arguments");
            System.err.println("Usage: WekaApplyArousalModel trained-model.{model|ser} unclassified-data.{arff|csv} classified-data.{arff|csv}");
            System.exit(1);
        }

        String classifier_input_filename = args[0];
        String unclassified_data_input_filename = args[1];
        String classified_data_output_filename = args[2];

        System.out.println("Loading Weka saved Classifier: " + classifier_input_filename);

        try {
            FileInputStream fis = new FileInputStream(classifier_input_filename);
            BufferedInputStream bis = new BufferedInputStream(fis);
            Classifier classifier = (Classifier)SerializationHelper.read(bis);
            bis.close();

            //
            // Load in unlabeled data
            //

            System.out.println("Loading unlabeled instances: " + unclassified_data_input_filename);

            DataSource data_source = new DataSource(unclassified_data_input_filename);
            Instances unlabeled_instances = data_source.getDataSet();

            // Work out if we're dealing with a ground-truth ARFF file or not
            // (i.e. one that already has the desired attribute)

            Instances groundtruth_instances = null;
            Attribute predict_attribute = unlabeled_instances.attribute(PREDICT_ATTRIBUTE_NAME);

            if (predict_attribute == null) {
                // Unlabeled data: filter it, then append an empty 'arousal'
                // attribute to hold the predictions
                unlabeled_instances = applyFilter(unlabeled_instances, null); // nothing additional to remove
                int num_attributes = unlabeled_instances.numAttributes();

                Attribute arousal_attribute = new Attribute(AROUSAL_ATTRIBUTE_NAME);
                unlabeled_instances.insertAttributeAt(arousal_attribute, num_attributes);
            }
            else {
                // Dealing with ground-truth data:
                // => already has the 'arousal' attribute
                // => in fact has the 'valance' attribute too, which we want to remove

                unlabeled_instances = applyFilter(unlabeled_instances, "472"); // additionally remove 'valance'

                // Share the same reference as 'groundtruth_instances' to trigger
                // the error calculation and output below
                groundtruth_instances = unlabeled_instances;
            }

            int num_attributes = unlabeled_instances.numAttributes();
            unlabeled_instances.setClassIndex(num_attributes - 1);

            checkDatasetInstancesCompatible(unlabeled_instances);

            // Create a copy where the predictions are to be made
            Instances labeled_instances = new Instances(unlabeled_instances);

            //
            // Label instances
            //

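            // For a numeric class attribute, classifyInstance() returns the
            // predicted regression value directly (rather than a class index)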
            final int num_instances = unlabeled_instances.numInstances();
            for (int i=0; i<num_instances; i++) {
                Instance unlabeled_instance = unlabeled_instances.instance(i);

                System.out.print("Making prediction for: " + i + "/" + num_instances);
                double classified_value = classifier.classifyInstance(unlabeled_instance);
                labeled_instances.instance(i).setClassValue(classified_value);

                String formatted_classified_value = String.format("% 06.3f", classified_value);

                System.out.print(" value = " + formatted_classified_value);

                if (groundtruth_instances != null) {
                    Instance gt_instance = groundtruth_instances.instance(i);
                    double gt_class_value = gt_instance.classValue();
                    double error = Math.abs(classified_value - gt_class_value);

                    String formatted_error = String.format("%.3f", error);
                    System.out.print(" [error: " + formatted_error + "]");
                }
                System.out.println();
            }

            //
            // Save labeled data
            //

            System.out.println("Saving labeled instances: " + classified_data_output_filename);
            FileWriter fw = new FileWriter(classified_data_output_filename);
            BufferedWriter bw = new BufferedWriter(fw);

            bw.write(labeled_instances.toString());
            bw.newLine();
            bw.flush();
            bw.close();
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }
}
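
For reference, a minimal sketch of compiling and running the program (the weka.jar location and the file names are illustrative assumptions, not taken from this file):

    javac -cp weka.jar WekaApplyArousalModel.java
    java -cp .:weka.jar WekaApplyArousalModel arousal.model unclassified-features.csv classified-features.arff

The file does not show how the serialized classifier it loads was produced. A minimal companion sketch, assuming the filtered ground-truth ARFF as training data and using weka.classifiers.functions.LinearRegression as a stand-in learner (the learner actually used is not recorded here):

    import weka.classifiers.Classifier;
    import weka.classifiers.functions.LinearRegression;
    import weka.core.Instances;
    import weka.core.SerializationHelper;
    import weka.core.converters.ConverterUtils.DataSource;

    class TrainArousalModel
    {
        public static void main(String[] args) throws Exception
        {
            // Load the (already filtered) ground-truth ARFF; its last
            // attribute -- arousal -- is the value to predict
            DataSource data_source = new DataSource(args[0]);
            Instances train_instances = data_source.getDataSet();
            train_instances.setClassIndex(train_instances.numAttributes() - 1);

            Classifier classifier = new LinearRegression(); // stand-in learner
            classifier.buildClassifier(train_instances);

            // Serialize so WekaApplyArousalModel can read it back with
            // SerializationHelper.read()
            SerializationHelper.write(args[1], classifier); // e.g. arousal.model
        }
    }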