source: main/trunk/model-sites-dev/mars/src/java/org/greenstone/mars/WekaUtil.java@ 34797

Last change on this file since 34797 was 34797, checked in by davidb, 3 years ago

Fixed up spelling mixup valance -> valence

File size: 10.2 KB
Line 
1package org.greenstone.mars;
2
3import java.io.BufferedInputStream;
4import java.io.FileInputStream;
5
6import java.io.BufferedReader;
7import java.io.BufferedWriter;
8import java.io.FileReader;
9import java.io.FileWriter;
10
11import java.util.Random;
12
13import weka.core.converters.ConverterUtils.DataSource;
14import weka.core.Attribute;
15import weka.core.Instance;
16import weka.core.Instances;
17import weka.core.SerializationHelper;
18
19import weka.filters.Filter;
20import weka.filters.unsupervised.attribute.Remove;
21
22import weka.classifiers.Classifier;
23import weka.classifiers.Evaluation;
24
25import weka.classifiers.trees.REPTree;
26
27// Based on:
28// https://waikato.github.io/weka-wiki/use_weka_in_your_java_code/
29
30class WekaUtil
31{
32 public final static String AROUSAL_ATTRIBUTE_NAME = "arousal_sample_26500ms";
33 public final static String VALENCE_ATTRIBUTE_NAME = "valence_sample_26500ms";
34
35
36 public static void checkUsageTraining(String[] args)
37 {
38 if (args.length != 2) {
39 System.err.println("Error: incorrect number of command-line arguments");
40 System.err.println("Usage: input_training_data.arff output-model.{model|ser}");
41 System.exit(1);
42 }
43
44 }
45
46 public static void checkUsageApplyModel(String[] args)
47 {
48 if (args.length != 3) {
49 System.err.println("Error: incorrect number of command-line arguments");
50 System.err.println("Usage: trained-model.{model|ser} unclassified-data.{arff|csv} classified-data.{arff|csv}");
51 System.exit(1);
52 }
53 }
54
55 public static Instances applyFilter(Instances data_instances, String additional_remove)
56 {
57 // Avoid ordinal values in dataset
58 // (as tricky to work with in amc-essentia collection as essentia feature extractor
59 // on a single file is not guaranteed to generate the same set of ordinal vals in
60 // CSV file generated)
61
62 // Note, for RepTree, with vals to algorithm, does produce some outlier rules
63 // using spectral_spread.dvar with high int values, e.g.:
64 // lowlevel.spectral_spread.dvar < 3831853219840
65 //
66 // The following filtering produces as tree with 41 nodes
67
68 String remove_option_args = ("-R 425,427,458-466"); // knock out 11 attributes (472 -> 461)
69
70
71 // For 'arousal', filtering to the following produces a smaller REPTree
72 // with only 27 nodes
73 //String remove_option_args = ("-R 458-466");
74
75 if (additional_remove != null) {
76 // top up,
77 // e.g. '471' representing arousal in ground-truth files
78 // e.g. '472' representing valence in ground-truth files
79 remove_option_args += "," + additional_remove;
80 }
81
82
83 Instances filtered_data_instances = null;
84 try {
85 // remove ordinal attributes and any additional topups,
86 // such as 'valence' (when predicting 'arousal') and vice versa
87
88 String[] filter_options = weka.core.Utils.splitOptions(remove_option_args);
89 Remove filter_remove = new Remove();
90 filter_remove.setOptions(filter_options);
91 filter_remove.setInputFormat(data_instances);
92 filtered_data_instances = Filter.useFilter(data_instances, filter_remove);
93 }
94 catch (Exception e) {
95 e.printStackTrace();
96 System.exit(1);
97 }
98
99 return filtered_data_instances;
100 }
101
102
103 public static void checkDatasetInstancesCompatible(Instances new_instances, String additional_remove)
104 {
105
106 final String gt_datasource_filename = "collect/deam/etc/deam-essentia-features-arousal-valence.arff";
107
108 try {
109 DataSource gt_data_source = new DataSource(gt_datasource_filename);
110 Instances gt_instances = gt_data_source.getDataSet();
111 gt_instances = applyFilter(gt_instances,additional_remove); // remove 'valence' or 'arousal'
112
113 gt_instances.setClassIndex(gt_instances.numAttributes() - 1);
114
115 String equal_header_message = gt_instances.equalHeadersMsg(new_instances);
116
117 if (equal_header_message != null) {
118 System.err.println("Headers to datasets were not equal!");
119 System.err.println(equal_header_message);
120 System.exit(1);
121 }
122 }
123 catch (Exception e) {
124 e.printStackTrace();
125 System.exit(1);
126 }
127
128 }
129
130
131 public static Instances loadAndFilterDataInstances(String input_filename, String additional_attribute_remove)
132 {
133 System.out.println("Training on file: " + input_filename);
134 Instances filtered_data_instances = null;
135
136 try {
137 DataSource data_source = new DataSource(input_filename);
138 Instances data_instances = data_source.getDataSet();
139
140 // Training dataset has two ground-truth attributes: 'arousal' and 'valence'.
141 // When training for one, need to knock out the other. This is the purpose
142 // of 'additional_attribute_remove'
143 filtered_data_instances = applyFilter(data_instances,additional_attribute_remove);
144
145 // With the 'other' ground-truth attribute taken out, the column to predict
146 // will always be the last column
147 filtered_data_instances.setClassIndex(filtered_data_instances.numAttributes()-1);
148 }
149 catch (Exception e) {
150 e.printStackTrace();
151 System.exit(1);
152 }
153
154 return filtered_data_instances;
155 }
156
157 public static Classifier trainREPTree(Instances data_instances)
158 {
159 REPTree classifier = null;
160
161 try {
162 // Build classifier
163 classifier = new REPTree();
164 String[] reptree_options = weka.core.Utils.splitOptions("-M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0");
165 classifier.setOptions(reptree_options);
166
167 classifier.buildClassifier(data_instances);
168 }
169 catch (Exception e) {
170 e.printStackTrace();
171 System.exit(1);
172 }
173
174 return classifier;
175 }
176
177
178 public static Evaluation evaluateClassifier(Classifier classifier, Instances data_instances)
179 {
180 Evaluation eval = null;
181 try {
182 eval = new Evaluation(data_instances);
183 Random rand = new Random(1);
184 int folds = 10;
185
186 eval.crossValidateModel(classifier, data_instances, folds, rand);
187 }
188 catch (Exception e) {
189 e.printStackTrace();
190 System.exit(1);
191 }
192
193 return eval;
194 }
195
196 /* Apply Model Specific Methods */
197
198 public static Classifier loadClassifierModel(String classifier_input_filename
199
200 )
201 {
202 System.out.println("Loading Weka saved Classifier: " + classifier_input_filename);
203 Classifier classifier = null;
204 try {
205 FileInputStream fis = new FileInputStream(classifier_input_filename);
206 BufferedInputStream bis= new BufferedInputStream(fis);
207 classifier = (Classifier)SerializationHelper.read(bis);
208 }
209 catch (Exception e) {
210 e.printStackTrace();
211 System.exit(1);
212 }
213
214 return classifier;
215 }
216
217
218 public static Instances loadInstancesForClassification(String unlabeled_input_filename)
219 {
220 System.out.println("Loading unlabeled instances: " + unlabeled_input_filename);
221 Instances unlabeled_instances = null;
222
223 try {
224 DataSource data_source = new DataSource(unlabeled_input_filename);
225 unlabeled_instances = data_source.getDataSet();
226 }
227 catch (Exception e) {
228 e.printStackTrace();
229 System.exit(1);
230 }
231
232 return unlabeled_instances;
233 }
234
235 public static boolean instancesHavePredictAttribute(Instances data_instances, String predict_attribute_name)
236 {
237 Attribute predict_attribute = data_instances.attribute(predict_attribute_name);
238
239 return predict_attribute != null;
240 }
241
242
243
244 public static Instances filterInstancesForApplying(Instances unlabeled_instances, boolean has_groundtruth_data,
245 String predict_attribute_name, String additional_attribute_remove)
246 {
247 Instances filtered_unlabeled_instances = null;
248
249
250 // Work out if we're dealing with a ground-truth ARFF file or not
251 // (i.e. already has the desired attribute)
252
253 //Attribute predict_attribute = unlabeled_instances.attribute(predict_attribute_name);
254
255 if (!has_groundtruth_data) {
256
257 filtered_unlabeled_instances = WekaUtil.applyFilter(unlabeled_instances,null); // no additional top-up to remove
258 int num_attributes = filtered_unlabeled_instances.numAttributes();
259
260 Attribute new_predict_attribute = new Attribute(predict_attribute_name);
261 filtered_unlabeled_instances.insertAttributeAt(new_predict_attribute,num_attributes);
262 // ******
263 //num_attributes++;
264 }
265 else {
266 // Dealing with ground-truth data:
267 // => already has 'arousal' and 'valence' attributes
268 // => need to keep the 'predict_attribute_name' and remove the other one
269 // => (its -R value of which is specified in 'additional_attribute_remove')
270
271 // Need to massage instances into same form as an unclassified data input file
272 filtered_unlabeled_instances = WekaUtil.applyFilter(unlabeled_instances,additional_attribute_remove);
273
274 // reference share this as 'groundtruth_instances' to trigger error calculation and output
275 // ******
276 // groundtruth_instances = filtered_unlabeled_instances;
277 }
278
279 int num_attributes = filtered_unlabeled_instances.numAttributes();
280
281 // Set class attribute
282 filtered_unlabeled_instances.setClassIndex(num_attributes - 1);
283
284 // ***** Do I still want to run the check????
285 WekaUtil.checkDatasetInstancesCompatible(filtered_unlabeled_instances, additional_attribute_remove);
286
287 return filtered_unlabeled_instances;
288 }
289
290
291 public static Instances makePredictions(Classifier classifier, Instances unlabeled_instances, Instances groundtruth_instances)
292 {
293
294 // Create copy where the predictions are to be made
295 Instances labeled_instances = new Instances(unlabeled_instances);
296
297 try {
298 // Label instances
299
300 final int num_instances = unlabeled_instances.numInstances();
301 for (int i=0; i<num_instances; i++) {
302 Instance unlabeled_instance = unlabeled_instances.instance(i);
303
304 System.out.print("Making prediction for: " + i + "/" + num_instances);
305 double classified_value = classifier.classifyInstance(unlabeled_instance);
306 labeled_instances.instance(i).setClassValue(classified_value);
307
308 String formatted_classified_value = String.format("% 06.3f", classified_value);
309
310 System.out.print(" value = " + formatted_classified_value);
311
312 if (groundtruth_instances != null) {
313 Instance gt_instance = groundtruth_instances.instance(i);
314 double gt_class_value = gt_instance.classValue();
315 double error = Math.abs(classified_value - gt_class_value);
316
317 String formatted_error = String.format("%.3f", error);
318 System.out.print(" [error: " + formatted_error + "]");
319 }
320 System.out.println();
321 }
322 }
323 catch (Exception e) {
324 e.printStackTrace();
325 System.exit(1);
326 }
327
328 return labeled_instances;
329 }
330
331}
Note: See TracBrowser for help on using the repository browser.