1 | package org.greenstone.mars;
|
---|
2 |
|
---|
3 | import java.io.BufferedInputStream;
|
---|
4 | import java.io.FileInputStream;
|
---|
5 |
|
---|
6 | import java.io.BufferedReader;
|
---|
7 | import java.io.BufferedWriter;
|
---|
8 | import java.io.FileReader;
|
---|
9 | import java.io.FileWriter;
|
---|
10 |
|
---|
11 | import java.util.Random;
|
---|
12 |
|
---|
13 | import weka.core.converters.ConverterUtils.DataSource;
|
---|
14 | import weka.core.Attribute;
|
---|
15 | import weka.core.Instance;
|
---|
16 | import weka.core.Instances;
|
---|
17 | import weka.core.SerializationHelper;
|
---|
18 |
|
---|
19 | import weka.filters.Filter;
|
---|
20 | import weka.filters.unsupervised.attribute.Remove;
|
---|
21 |
|
---|
22 | import weka.classifiers.Classifier;
|
---|
23 | import weka.classifiers.Evaluation;
|
---|
24 |
|
---|
25 | import weka.classifiers.trees.REPTree;
|
---|
26 |
|
---|
27 | // Based on:
|
---|
28 | // https://waikato.github.io/weka-wiki/use_weka_in_your_java_code/
|
---|
29 |
|
---|
30 | class WekaUtil
|
---|
31 | {
|
---|
32 | public final static String AROUSAL_ATTRIBUTE_NAME = "arousal_sample_26500ms";
|
---|
33 | public final static String VALENCE_ATTRIBUTE_NAME = "valence_sample_26500ms";
|
---|
34 |
|
---|
35 |
|
---|
36 | public static void checkUsageTraining(String[] args)
|
---|
37 | {
|
---|
38 | if (args.length != 2) {
|
---|
39 | System.err.println("Error: incorrect number of command-line arguments");
|
---|
40 | System.err.println("Usage: input_training_data.arff output-model.{model|ser}");
|
---|
41 | System.exit(1);
|
---|
42 | }
|
---|
43 |
|
---|
44 | }
|
---|
45 |
|
---|
46 | public static void checkUsageApplyModel(String[] args)
|
---|
47 | {
|
---|
48 | if (args.length != 3) {
|
---|
49 | System.err.println("Error: incorrect number of command-line arguments");
|
---|
50 | System.err.println("Usage: trained-model.{model|ser} unclassified-data.{arff|csv} classified-data.{arff|csv}");
|
---|
51 | System.exit(1);
|
---|
52 | }
|
---|
53 | }
|
---|
54 |
|
---|
55 | public static Instances applyFilter(Instances data_instances, String additional_remove)
|
---|
56 | {
|
---|
57 | // Avoid ordinal values in dataset
|
---|
58 | // (as tricky to work with in amc-essentia collection as essentia feature extractor
|
---|
59 | // on a single file is not guaranteed to generate the same set of ordinal vals in
|
---|
60 | // CSV file generated)
|
---|
61 |
|
---|
62 | // Note, for RepTree, with vals to algorithm, does produce some outlier rules
|
---|
63 | // using spectral_spread.dvar with high int values, e.g.:
|
---|
64 | // lowlevel.spectral_spread.dvar < 3831853219840
|
---|
65 | //
|
---|
66 | // The following filtering produces as tree with 41 nodes
|
---|
67 |
|
---|
68 | String remove_option_args = ("-R 425,427,458-466"); // knock out 11 attributes (472 -> 461)
|
---|
69 |
|
---|
70 |
|
---|
71 | // For 'arousal', filtering to the following produces a smaller REPTree
|
---|
72 | // with only 27 nodes
|
---|
73 | //String remove_option_args = ("-R 458-466");
|
---|
74 |
|
---|
75 | if (additional_remove != null) {
|
---|
76 | // top up,
|
---|
77 | // e.g. '471' representing arousal in ground-truth files
|
---|
78 | // e.g. '472' representing valence in ground-truth files
|
---|
79 | remove_option_args += "," + additional_remove;
|
---|
80 | }
|
---|
81 |
|
---|
82 |
|
---|
83 | Instances filtered_data_instances = null;
|
---|
84 | try {
|
---|
85 | // remove ordinal attributes and any additional topups,
|
---|
86 | // such as 'valence' (when predicting 'arousal') and vice versa
|
---|
87 |
|
---|
88 | String[] filter_options = weka.core.Utils.splitOptions(remove_option_args);
|
---|
89 | Remove filter_remove = new Remove();
|
---|
90 | filter_remove.setOptions(filter_options);
|
---|
91 | filter_remove.setInputFormat(data_instances);
|
---|
92 | filtered_data_instances = Filter.useFilter(data_instances, filter_remove);
|
---|
93 | }
|
---|
94 | catch (Exception e) {
|
---|
95 | e.printStackTrace();
|
---|
96 | System.exit(1);
|
---|
97 | }
|
---|
98 |
|
---|
99 | return filtered_data_instances;
|
---|
100 | }
|
---|
101 |
|
---|
102 |
|
---|
103 | public static void checkDatasetInstancesCompatible(Instances new_instances, String additional_remove)
|
---|
104 | {
|
---|
105 |
|
---|
106 | final String gt_datasource_filename = "collect/deam/etc/deam-essentia-features-arousal-valence.arff";
|
---|
107 |
|
---|
108 | try {
|
---|
109 | DataSource gt_data_source = new DataSource(gt_datasource_filename);
|
---|
110 | Instances gt_instances = gt_data_source.getDataSet();
|
---|
111 | gt_instances = applyFilter(gt_instances,additional_remove); // remove 'valence' or 'arousal'
|
---|
112 |
|
---|
113 | gt_instances.setClassIndex(gt_instances.numAttributes() - 1);
|
---|
114 |
|
---|
115 | String equal_header_message = gt_instances.equalHeadersMsg(new_instances);
|
---|
116 |
|
---|
117 | if (equal_header_message != null) {
|
---|
118 | System.err.println("Headers to datasets were not equal!");
|
---|
119 | System.err.println(equal_header_message);
|
---|
120 | System.exit(1);
|
---|
121 | }
|
---|
122 | }
|
---|
123 | catch (Exception e) {
|
---|
124 | e.printStackTrace();
|
---|
125 | System.exit(1);
|
---|
126 | }
|
---|
127 |
|
---|
128 | }
|
---|
129 |
|
---|
130 |
|
---|
131 | public static Instances loadAndFilterDataInstances(String input_filename, String additional_attribute_remove)
|
---|
132 | {
|
---|
133 | System.out.println("Training on file: " + input_filename);
|
---|
134 | Instances filtered_data_instances = null;
|
---|
135 |
|
---|
136 | try {
|
---|
137 | DataSource data_source = new DataSource(input_filename);
|
---|
138 | Instances data_instances = data_source.getDataSet();
|
---|
139 |
|
---|
140 | // Training dataset has two ground-truth attributes: 'arousal' and 'valence'.
|
---|
141 | // When training for one, need to knock out the other. This is the purpose
|
---|
142 | // of 'additional_attribute_remove'
|
---|
143 | filtered_data_instances = applyFilter(data_instances,additional_attribute_remove);
|
---|
144 |
|
---|
145 | // With the 'other' ground-truth attribute taken out, the column to predict
|
---|
146 | // will always be the last column
|
---|
147 | filtered_data_instances.setClassIndex(filtered_data_instances.numAttributes()-1);
|
---|
148 | }
|
---|
149 | catch (Exception e) {
|
---|
150 | e.printStackTrace();
|
---|
151 | System.exit(1);
|
---|
152 | }
|
---|
153 |
|
---|
154 | return filtered_data_instances;
|
---|
155 | }
|
---|
156 |
|
---|
157 | public static Classifier trainREPTree(Instances data_instances)
|
---|
158 | {
|
---|
159 | REPTree classifier = null;
|
---|
160 |
|
---|
161 | try {
|
---|
162 | // Build classifier
|
---|
163 | classifier = new REPTree();
|
---|
164 | String[] reptree_options = weka.core.Utils.splitOptions("-M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0");
|
---|
165 | classifier.setOptions(reptree_options);
|
---|
166 |
|
---|
167 | classifier.buildClassifier(data_instances);
|
---|
168 | }
|
---|
169 | catch (Exception e) {
|
---|
170 | e.printStackTrace();
|
---|
171 | System.exit(1);
|
---|
172 | }
|
---|
173 |
|
---|
174 | return classifier;
|
---|
175 | }
|
---|
176 |
|
---|
177 |
|
---|
178 | public static Evaluation evaluateClassifier(Classifier classifier, Instances data_instances)
|
---|
179 | {
|
---|
180 | Evaluation eval = null;
|
---|
181 | try {
|
---|
182 | eval = new Evaluation(data_instances);
|
---|
183 | Random rand = new Random(1);
|
---|
184 | int folds = 10;
|
---|
185 |
|
---|
186 | eval.crossValidateModel(classifier, data_instances, folds, rand);
|
---|
187 | }
|
---|
188 | catch (Exception e) {
|
---|
189 | e.printStackTrace();
|
---|
190 | System.exit(1);
|
---|
191 | }
|
---|
192 |
|
---|
193 | return eval;
|
---|
194 | }
|
---|
195 |
|
---|
196 | /* Apply Model Specific Methods */
|
---|
197 |
|
---|
198 | public static Classifier loadClassifierModel(String classifier_input_filename
|
---|
199 |
|
---|
200 | )
|
---|
201 | {
|
---|
202 | System.out.println("Loading Weka saved Classifier: " + classifier_input_filename);
|
---|
203 | Classifier classifier = null;
|
---|
204 | try {
|
---|
205 | FileInputStream fis = new FileInputStream(classifier_input_filename);
|
---|
206 | BufferedInputStream bis= new BufferedInputStream(fis);
|
---|
207 | classifier = (Classifier)SerializationHelper.read(bis);
|
---|
208 | }
|
---|
209 | catch (Exception e) {
|
---|
210 | e.printStackTrace();
|
---|
211 | System.exit(1);
|
---|
212 | }
|
---|
213 |
|
---|
214 | return classifier;
|
---|
215 | }
|
---|
216 |
|
---|
217 |
|
---|
218 | public static Instances loadInstancesForClassification(String unlabeled_input_filename)
|
---|
219 | {
|
---|
220 | System.out.println("Loading unlabeled instances: " + unlabeled_input_filename);
|
---|
221 | Instances unlabeled_instances = null;
|
---|
222 |
|
---|
223 | try {
|
---|
224 | DataSource data_source = new DataSource(unlabeled_input_filename);
|
---|
225 | unlabeled_instances = data_source.getDataSet();
|
---|
226 | }
|
---|
227 | catch (Exception e) {
|
---|
228 | e.printStackTrace();
|
---|
229 | System.exit(1);
|
---|
230 | }
|
---|
231 |
|
---|
232 | return unlabeled_instances;
|
---|
233 | }
|
---|
234 |
|
---|
235 | public static boolean instancesHavePredictAttribute(Instances data_instances, String predict_attribute_name)
|
---|
236 | {
|
---|
237 | Attribute predict_attribute = data_instances.attribute(predict_attribute_name);
|
---|
238 |
|
---|
239 | return predict_attribute != null;
|
---|
240 | }
|
---|
241 |
|
---|
242 |
|
---|
243 |
|
---|
244 | public static Instances filterInstancesForApplying(Instances unlabeled_instances, boolean has_groundtruth_data,
|
---|
245 | String predict_attribute_name, String additional_attribute_remove)
|
---|
246 | {
|
---|
247 | Instances filtered_unlabeled_instances = null;
|
---|
248 |
|
---|
249 |
|
---|
250 | // Work out if we're dealing with a ground-truth ARFF file or not
|
---|
251 | // (i.e. already has the desired attribute)
|
---|
252 |
|
---|
253 | //Attribute predict_attribute = unlabeled_instances.attribute(predict_attribute_name);
|
---|
254 |
|
---|
255 | if (!has_groundtruth_data) {
|
---|
256 |
|
---|
257 | filtered_unlabeled_instances = WekaUtil.applyFilter(unlabeled_instances,null); // no additional top-up to remove
|
---|
258 | int num_attributes = filtered_unlabeled_instances.numAttributes();
|
---|
259 |
|
---|
260 | Attribute new_predict_attribute = new Attribute(predict_attribute_name);
|
---|
261 | filtered_unlabeled_instances.insertAttributeAt(new_predict_attribute,num_attributes);
|
---|
262 | // ******
|
---|
263 | //num_attributes++;
|
---|
264 | }
|
---|
265 | else {
|
---|
266 | // Dealing with ground-truth data:
|
---|
267 | // => already has 'arousal' and 'valence' attributes
|
---|
268 | // => need to keep the 'predict_attribute_name' and remove the other one
|
---|
269 | // => (its -R value of which is specified in 'additional_attribute_remove')
|
---|
270 |
|
---|
271 | // Need to massage instances into same form as an unclassified data input file
|
---|
272 | filtered_unlabeled_instances = WekaUtil.applyFilter(unlabeled_instances,additional_attribute_remove);
|
---|
273 |
|
---|
274 | // reference share this as 'groundtruth_instances' to trigger error calculation and output
|
---|
275 | // ******
|
---|
276 | // groundtruth_instances = filtered_unlabeled_instances;
|
---|
277 | }
|
---|
278 |
|
---|
279 | int num_attributes = filtered_unlabeled_instances.numAttributes();
|
---|
280 |
|
---|
281 | // Set class attribute
|
---|
282 | filtered_unlabeled_instances.setClassIndex(num_attributes - 1);
|
---|
283 |
|
---|
284 | // ***** Do I still want to run the check????
|
---|
285 | WekaUtil.checkDatasetInstancesCompatible(filtered_unlabeled_instances, additional_attribute_remove);
|
---|
286 |
|
---|
287 | return filtered_unlabeled_instances;
|
---|
288 | }
|
---|
289 |
|
---|
290 |
|
---|
291 | public static Instances makePredictions(Classifier classifier, Instances unlabeled_instances, Instances groundtruth_instances)
|
---|
292 | {
|
---|
293 |
|
---|
294 | // Create copy where the predictions are to be made
|
---|
295 | Instances labeled_instances = new Instances(unlabeled_instances);
|
---|
296 |
|
---|
297 | try {
|
---|
298 | // Label instances
|
---|
299 |
|
---|
300 | final int num_instances = unlabeled_instances.numInstances();
|
---|
301 | for (int i=0; i<num_instances; i++) {
|
---|
302 | Instance unlabeled_instance = unlabeled_instances.instance(i);
|
---|
303 |
|
---|
304 | System.out.print("Making prediction for: " + i + "/" + num_instances);
|
---|
305 | double classified_value = classifier.classifyInstance(unlabeled_instance);
|
---|
306 | labeled_instances.instance(i).setClassValue(classified_value);
|
---|
307 |
|
---|
308 | String formatted_classified_value = String.format("% 06.3f", classified_value);
|
---|
309 |
|
---|
310 | System.out.print(" value = " + formatted_classified_value);
|
---|
311 |
|
---|
312 | if (groundtruth_instances != null) {
|
---|
313 | Instance gt_instance = groundtruth_instances.instance(i);
|
---|
314 | double gt_class_value = gt_instance.classValue();
|
---|
315 | double error = Math.abs(classified_value - gt_class_value);
|
---|
316 |
|
---|
317 | String formatted_error = String.format("%.3f", error);
|
---|
318 | System.out.print(" [error: " + formatted_error + "]");
|
---|
319 | }
|
---|
320 | System.out.println();
|
---|
321 | }
|
---|
322 | }
|
---|
323 | catch (Exception e) {
|
---|
324 | e.printStackTrace();
|
---|
325 | System.exit(1);
|
---|
326 | }
|
---|
327 |
|
---|
328 | return labeled_instances;
|
---|
329 | }
|
---|
330 |
|
---|
331 | }
|
---|