1 | package org.greenstone.mars;
|
---|
2 |
|
---|
3 | import java.io.BufferedInputStream;
|
---|
4 | import java.io.FileInputStream;
|
---|
5 |
|
---|
6 | import java.io.BufferedReader;
|
---|
7 | import java.io.BufferedWriter;
|
---|
8 | import java.io.FileReader;
|
---|
9 | import java.io.FileWriter;
|
---|
10 |
|
---|
11 | import java.util.Random;
|
---|
12 |
|
---|
13 | import weka.core.Attribute;
|
---|
14 | import weka.core.Instance;
|
---|
15 | import weka.core.Instances;
|
---|
16 | import weka.core.SerializationHelper;
|
---|
17 | import weka.core.converters.ConverterUtils.DataSource;
|
---|
18 | import weka.core.converters.ConverterUtils.DataSink;
|
---|
19 |
|
---|
20 | import weka.filters.Filter;
|
---|
21 | import weka.filters.unsupervised.attribute.Remove;
|
---|
22 |
|
---|
23 | import weka.classifiers.Classifier;
|
---|
24 | import weka.classifiers.Evaluation;
|
---|
25 |
|
---|
26 | import weka.classifiers.trees.REPTree;
|
---|
27 |
|
---|
28 | // Based on:
|
---|
29 | // https://waikato.github.io/weka-wiki/use_weka_in_your_java_code/
|
---|
30 |
|
---|
31 | class WekaUtil
|
---|
32 | {
|
---|
    // Names of the two ground-truth attribute columns (sampled at 26500ms)
    // that the arousal/valence models predict; used to detect whether an
    // input dataset already carries ground-truth values.
    public final static String AROUSAL_ATTRIBUTE_NAME = "arousal_sample_26500ms";
    public final static String VALENCE_ATTRIBUTE_NAME = "valence_sample_26500ms";
35 |
|
---|
36 |
|
---|
37 | public static void checkUsageTraining(String[] args)
|
---|
38 | {
|
---|
39 | if (args.length != 2) {
|
---|
40 | System.err.println("Error: incorrect number of command-line arguments");
|
---|
41 | System.err.println("Usage: input_training_data.arff output-serialized.model");
|
---|
42 | System.exit(1);
|
---|
43 | }
|
---|
44 |
|
---|
45 | }
|
---|
46 |
|
---|
47 | public static void checkUsageApplyModel(String[] args)
|
---|
48 | {
|
---|
49 | if (args.length != 3) {
|
---|
50 | System.err.println("Error: incorrect number of command-line arguments");
|
---|
51 | System.err.println("Usage: trained-serialized.model unclassified-data.{arff|csv} classified-data.{arff|csv}");
|
---|
52 | System.exit(1);
|
---|
53 | }
|
---|
54 | }
|
---|
55 |
|
---|
56 | public static void checkUsageApplyAVModels(String[] args)
|
---|
57 | {
|
---|
58 | if (args.length != 4) {
|
---|
59 | System.err.println("Error: incorrect number of command-line arguments");
|
---|
60 | System.err.println("Usage: arousal-serialized.model valence-serialized.model unclassified-data.{arff|csv} classified-data.{arff|csv}");
|
---|
61 | System.exit(1);
|
---|
62 | }
|
---|
63 | }
|
---|
64 |
|
---|
65 | public static Instances applyFilter(Instances data_instances, String additional_remove)
|
---|
66 | {
|
---|
67 | // Avoid ordinal values in dataset
|
---|
68 | // (as tricky to work with in amc-essentia collection as essentia feature extractor
|
---|
69 | // on a single file is not guaranteed to generate the same set of ordinal vals in
|
---|
70 | // CSV file generated)
|
---|
71 |
|
---|
72 | // Note, for RepTree, with vals to algorithm, does produce some outlier rules
|
---|
73 | // using spectral_spread.dvar with high int values, e.g.:
|
---|
74 | // lowlevel.spectral_spread.dvar < 3831853219840
|
---|
75 | //
|
---|
76 | // The following filtering produces as tree with 41 nodes
|
---|
77 |
|
---|
78 | String remove_option_args = ("-R 425,427,458-466"); // knock out 11 attributes (472 -> 461)
|
---|
79 |
|
---|
80 |
|
---|
81 | // For 'arousal', filtering to the following produces a smaller REPTree
|
---|
82 | // with only 27 nodes
|
---|
83 | //String remove_option_args = ("-R 458-466");
|
---|
84 |
|
---|
85 | if (additional_remove != null) {
|
---|
86 | // top up,
|
---|
87 | // e.g. '471' representing arousal in ground-truth files
|
---|
88 | // e.g. '472' representing valence in ground-truth files
|
---|
89 | remove_option_args += "," + additional_remove;
|
---|
90 | }
|
---|
91 |
|
---|
92 |
|
---|
93 | Instances filtered_data_instances = null;
|
---|
94 | try {
|
---|
95 | // remove ordinal attributes and any additional topups,
|
---|
96 | // such as 'valence' (when predicting 'arousal') and vice versa
|
---|
97 |
|
---|
98 | String[] filter_options = weka.core.Utils.splitOptions(remove_option_args);
|
---|
99 | Remove filter_remove = new Remove();
|
---|
100 | filter_remove.setOptions(filter_options);
|
---|
101 | filter_remove.setInputFormat(data_instances);
|
---|
102 | filtered_data_instances = Filter.useFilter(data_instances, filter_remove);
|
---|
103 | }
|
---|
104 | catch (Exception e) {
|
---|
105 | e.printStackTrace();
|
---|
106 | System.exit(1);
|
---|
107 | }
|
---|
108 |
|
---|
109 | return filtered_data_instances;
|
---|
110 | }
|
---|
111 |
|
---|
112 |
|
---|
113 | public static void checkDatasetInstancesCompatible(Instances new_instances, String additional_remove)
|
---|
114 | {
|
---|
115 |
|
---|
116 | final String gt_datasource_filename = "collect/deam/etc/deam-essentia-features-arousal-valence.arff";
|
---|
117 |
|
---|
118 | try {
|
---|
119 | DataSource gt_data_source = new DataSource(gt_datasource_filename);
|
---|
120 | Instances gt_instances = gt_data_source.getDataSet();
|
---|
121 | gt_instances = applyFilter(gt_instances,additional_remove); // remove 'valence' or 'arousal'
|
---|
122 |
|
---|
123 | gt_instances.setClassIndex(gt_instances.numAttributes() - 1);
|
---|
124 |
|
---|
125 | String equal_header_message = gt_instances.equalHeadersMsg(new_instances);
|
---|
126 |
|
---|
127 | if (equal_header_message != null) {
|
---|
128 | System.err.println("Headers to datasets were not equal!");
|
---|
129 | System.err.println(equal_header_message);
|
---|
130 | System.exit(1);
|
---|
131 | }
|
---|
132 | }
|
---|
133 | catch (Exception e) {
|
---|
134 | e.printStackTrace();
|
---|
135 | System.exit(1);
|
---|
136 | }
|
---|
137 |
|
---|
138 | }
|
---|
139 |
|
---|
140 |
|
---|
141 | public static Instances loadAndFilterDataInstances(String input_filename, String additional_attribute_remove)
|
---|
142 | {
|
---|
143 | System.out.println("Training on file: " + input_filename);
|
---|
144 | Instances filtered_data_instances = null;
|
---|
145 |
|
---|
146 | try {
|
---|
147 | DataSource data_source = new DataSource(input_filename);
|
---|
148 | Instances data_instances = data_source.getDataSet();
|
---|
149 |
|
---|
150 | // Training dataset has two ground-truth attributes: 'arousal' and 'valence'.
|
---|
151 | // When training for one, need to knock out the other. This is the purpose
|
---|
152 | // of 'additional_attribute_remove'
|
---|
153 | filtered_data_instances = applyFilter(data_instances,additional_attribute_remove);
|
---|
154 |
|
---|
155 | // With the 'other' ground-truth attribute taken out, the column to predict
|
---|
156 | // will always be the last column
|
---|
157 | filtered_data_instances.setClassIndex(filtered_data_instances.numAttributes()-1);
|
---|
158 | }
|
---|
159 | catch (Exception e) {
|
---|
160 | e.printStackTrace();
|
---|
161 | System.exit(1);
|
---|
162 | }
|
---|
163 |
|
---|
164 | return filtered_data_instances;
|
---|
165 | }
|
---|
166 |
|
---|
167 | public static Classifier trainREPTree(Instances data_instances)
|
---|
168 | {
|
---|
169 | REPTree classifier = null;
|
---|
170 |
|
---|
171 | try {
|
---|
172 | // Build classifier
|
---|
173 | classifier = new REPTree();
|
---|
174 | String[] reptree_options = weka.core.Utils.splitOptions("-M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0");
|
---|
175 | classifier.setOptions(reptree_options);
|
---|
176 |
|
---|
177 | classifier.buildClassifier(data_instances);
|
---|
178 | }
|
---|
179 | catch (Exception e) {
|
---|
180 | e.printStackTrace();
|
---|
181 | System.exit(1);
|
---|
182 | }
|
---|
183 |
|
---|
184 | return classifier;
|
---|
185 | }
|
---|
186 |
|
---|
187 |
|
---|
188 | public static Evaluation evaluateClassifier(Classifier classifier, Instances data_instances)
|
---|
189 | {
|
---|
190 | Evaluation eval = null;
|
---|
191 | try {
|
---|
192 | eval = new Evaluation(data_instances);
|
---|
193 | Random rand = new Random(1);
|
---|
194 | int folds = 10;
|
---|
195 |
|
---|
196 | eval.crossValidateModel(classifier, data_instances, folds, rand);
|
---|
197 | }
|
---|
198 | catch (Exception e) {
|
---|
199 | e.printStackTrace();
|
---|
200 | System.exit(1);
|
---|
201 | }
|
---|
202 |
|
---|
203 | return eval;
|
---|
204 | }
|
---|
205 |
|
---|
206 | /* Apply Model Specific Methods */
|
---|
207 |
|
---|
208 | public static Classifier loadClassifierModel(String classifier_input_filename
|
---|
209 |
|
---|
210 | )
|
---|
211 | {
|
---|
212 | System.out.println("Loading Weka saved Classifier: " + classifier_input_filename);
|
---|
213 | Classifier classifier = null;
|
---|
214 | try {
|
---|
215 | FileInputStream fis = new FileInputStream(classifier_input_filename);
|
---|
216 | BufferedInputStream bis= new BufferedInputStream(fis);
|
---|
217 | classifier = (Classifier)SerializationHelper.read(bis);
|
---|
218 | }
|
---|
219 | catch (Exception e) {
|
---|
220 | e.printStackTrace();
|
---|
221 | System.exit(1);
|
---|
222 | }
|
---|
223 |
|
---|
224 | return classifier;
|
---|
225 | }
|
---|
226 |
|
---|
227 |
|
---|
228 | public static Instances loadInstancesForClassification(String unlabeled_input_filename)
|
---|
229 | {
|
---|
230 | System.out.println("Loading unlabeled instances: " + unlabeled_input_filename);
|
---|
231 | Instances unlabeled_instances = null;
|
---|
232 |
|
---|
233 | try {
|
---|
234 | DataSource data_source = new DataSource(unlabeled_input_filename);
|
---|
235 | unlabeled_instances = data_source.getDataSet();
|
---|
236 | }
|
---|
237 | catch (Exception e) {
|
---|
238 | e.printStackTrace();
|
---|
239 | System.exit(1);
|
---|
240 | }
|
---|
241 |
|
---|
242 | return unlabeled_instances;
|
---|
243 | }
|
---|
244 |
|
---|
245 | public static boolean instancesHavePredictAttribute(Instances data_instances, String predict_attribute_name)
|
---|
246 | {
|
---|
247 | Attribute predict_attribute = data_instances.attribute(predict_attribute_name);
|
---|
248 |
|
---|
249 | return predict_attribute != null;
|
---|
250 | }
|
---|
251 |
|
---|
252 |
|
---|
253 |
|
---|
254 | public static Instances filterInstancesForApplying(Instances unlabeled_instances, boolean has_groundtruth_data,
|
---|
255 | String predict_attribute_name, String additional_attribute_remove)
|
---|
256 | {
|
---|
257 | Instances filtered_unlabeled_instances = null;
|
---|
258 |
|
---|
259 | if (!has_groundtruth_data) {
|
---|
260 |
|
---|
261 | filtered_unlabeled_instances = applyFilter(unlabeled_instances,null); // no additional top-up to remove
|
---|
262 | int num_attributes = filtered_unlabeled_instances.numAttributes();
|
---|
263 |
|
---|
264 | Attribute new_predict_attribute = new Attribute(predict_attribute_name);
|
---|
265 | filtered_unlabeled_instances.insertAttributeAt(new_predict_attribute,num_attributes);
|
---|
266 | // ******
|
---|
267 | //num_attributes++;
|
---|
268 | }
|
---|
269 | else {
|
---|
270 | // Dealing with ground-truth data:
|
---|
271 | // => already has 'arousal' and 'valence' attributes
|
---|
272 | // => need to keep the 'predict_attribute_name' and remove the other one
|
---|
273 | // => (its -R value of which is specified in 'additional_attribute_remove')
|
---|
274 |
|
---|
275 | // Need to massage instances into same form as an unclassified data input file
|
---|
276 | filtered_unlabeled_instances = applyFilter(unlabeled_instances,additional_attribute_remove);
|
---|
277 |
|
---|
278 | // reference share this as 'groundtruth_instances' to trigger error calculation and output
|
---|
279 | // ******
|
---|
280 | // groundtruth_instances = filtered_unlabeled_instances;
|
---|
281 | }
|
---|
282 |
|
---|
283 | int num_attributes = filtered_unlabeled_instances.numAttributes();
|
---|
284 |
|
---|
285 | // Set class attribute
|
---|
286 | filtered_unlabeled_instances.setClassIndex(num_attributes - 1);
|
---|
287 |
|
---|
288 | return filtered_unlabeled_instances;
|
---|
289 | }
|
---|
290 |
|
---|
291 | public static void appendUnclassifiedAttribute(Instances data_instances, String predict_attribute_name)
|
---|
292 | {
|
---|
293 | Instances filtered_unlabeled_instances = null;
|
---|
294 |
|
---|
295 | int num_attributes = data_instances.numAttributes();
|
---|
296 |
|
---|
297 | Attribute new_predict_attribute = new Attribute(predict_attribute_name);
|
---|
298 | data_instances.insertAttributeAt(new_predict_attribute,num_attributes);
|
---|
299 | num_attributes++;
|
---|
300 |
|
---|
301 | // Set class attribute
|
---|
302 | data_instances.setClassIndex(num_attributes - 1);
|
---|
303 | }
|
---|
304 |
|
---|
305 |
|
---|
306 | public static Instances makePredictions(Classifier classifier, Instances unlabeled_instances, Instances groundtruth_instances)
|
---|
307 | {
|
---|
308 |
|
---|
309 | // Create copy where the predictions are to be made
|
---|
310 | Instances labeled_instances = new Instances(unlabeled_instances);
|
---|
311 |
|
---|
312 | try {
|
---|
313 | // Label instances
|
---|
314 |
|
---|
315 | final int num_instances = unlabeled_instances.numInstances();
|
---|
316 | for (int i=0; i<num_instances; i++) {
|
---|
317 | Instance unlabeled_instance = unlabeled_instances.instance(i);
|
---|
318 |
|
---|
319 | System.out.print("Making prediction for: " + i + "/" + num_instances);
|
---|
320 | double classified_value = classifier.classifyInstance(unlabeled_instance);
|
---|
321 | labeled_instances.instance(i).setClassValue(classified_value);
|
---|
322 |
|
---|
323 | String formatted_classified_value = String.format("% 06.3f", classified_value);
|
---|
324 |
|
---|
325 | System.out.print(" value = " + formatted_classified_value);
|
---|
326 |
|
---|
327 | if (groundtruth_instances != null) {
|
---|
328 | Instance gt_instance = groundtruth_instances.instance(i);
|
---|
329 | double gt_class_value = gt_instance.classValue();
|
---|
330 | double error = Math.abs(classified_value - gt_class_value);
|
---|
331 |
|
---|
332 | String formatted_error = String.format("%.3f", error);
|
---|
333 | System.out.print(" [error: " + formatted_error + "]");
|
---|
334 | }
|
---|
335 | System.out.println();
|
---|
336 | }
|
---|
337 | }
|
---|
338 | catch (Exception e) {
|
---|
339 | e.printStackTrace();
|
---|
340 | System.exit(1);
|
---|
341 | }
|
---|
342 |
|
---|
343 | return labeled_instances;
|
---|
344 | }
|
---|
345 |
|
---|
346 | //
|
---|
347 | // https://waikato.github.io/weka-wiki/formats_and_processing/save_instances_to_arff/
|
---|
348 | //
|
---|
349 | public static void saveInstancesAsDataSink(Instances data_instances, String output_filename)
|
---|
350 | {
|
---|
351 | try {
|
---|
352 | System.out.println("Saving labeled instances: " + output_filename);
|
---|
353 | DataSink.write(output_filename, data_instances);
|
---|
354 | }
|
---|
355 | catch (Exception e) {
|
---|
356 | System.err.println("Failed to save data to: " + output_filename);
|
---|
357 | e.printStackTrace();
|
---|
358 | System.exit(1);
|
---|
359 | }
|
---|
360 | }
|
---|
361 |
|
---|
362 | }
|
---|