source: main/trunk/model-sites-dev/mars/src/java/org/greenstone/mars/WekaUtil.java@35207

Last change on this file since 35207 was 34802, checked in by davidb, 3 years ago

Saving instances changed to a DataSink, and located in WekaUtil

File size: 11.2 KB
package org.greenstone.mars;

import java.io.BufferedInputStream;
import java.io.FileInputStream;

import java.util.Random;

import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.core.converters.ConverterUtils.DataSource;
import weka.core.converters.ConverterUtils.DataSink;

import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;

import weka.classifiers.trees.REPTree;

// Based on:
// https://waikato.github.io/weka-wiki/use_weka_in_your_java_code/
class WekaUtil
{
    public static final String AROUSAL_ATTRIBUTE_NAME = "arousal_sample_26500ms";
    public static final String VALENCE_ATTRIBUTE_NAME = "valence_sample_26500ms";

    public static void checkUsageTraining(String[] args)
    {
        if (args.length != 2) {
            System.err.println("Error: incorrect number of command-line arguments");
            System.err.println("Usage: input_training_data.arff output-serialized.model");
            System.exit(1);
        }
    }

    public static void checkUsageApplyModel(String[] args)
    {
        if (args.length != 3) {
            System.err.println("Error: incorrect number of command-line arguments");
            System.err.println("Usage: trained-serialized.model unclassified-data.{arff|csv} classified-data.{arff|csv}");
            System.exit(1);
        }
    }

    public static void checkUsageApplyAVModels(String[] args)
    {
        if (args.length != 4) {
            System.err.println("Error: incorrect number of command-line arguments");
            System.err.println("Usage: arousal-serialized.model valence-serialized.model unclassified-data.{arff|csv} classified-data.{arff|csv}");
            System.exit(1);
        }
    }

    public static Instances applyFilter(Instances data_instances, String additional_remove)
    {
        // Avoid ordinal values in the dataset
        // (they are tricky to work with in the amc-essentia collection, as the essentia
        // feature extractor run on a single file is not guaranteed to generate the same
        // set of ordinal values in the CSV file it produces)

        // Note: for REPTree, leaving these values in does produce some outlier rules
        // using spectral_spread.dvar with high integer values, e.g.:
        //   lowlevel.spectral_spread.dvar < 3831853219840
        //
        // The following filtering produces a tree with 41 nodes

        String remove_option_args = "-R 425,427,458-466"; // knock out 11 attributes (472 -> 461)

        // For 'arousal', filtering with just the following produces a smaller REPTree
        // with only 27 nodes
        //String remove_option_args = "-R 458-466";

        if (additional_remove != null) {
            // Top up the list of attributes to remove,
            // e.g. '471' representing 'arousal' in ground-truth files
            // e.g. '472' representing 'valence' in ground-truth files
            remove_option_args += "," + additional_remove;
        }

        Instances filtered_data_instances = null;
        try {
            // Remove the ordinal attributes and any additional top-ups,
            // such as 'valence' (when predicting 'arousal') and vice versa
            String[] filter_options = weka.core.Utils.splitOptions(remove_option_args);
            Remove filter_remove = new Remove();
            filter_remove.setOptions(filter_options);
            filter_remove.setInputFormat(data_instances);
            filtered_data_instances = Filter.useFilter(data_instances, filter_remove);
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }

        return filtered_data_instances;
    }
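
    // Illustrative note (not part of the original file): Weka's Remove filter
    // takes 1-based attribute indexes, so "-R 425,427,458-466" deletes
    // attributes 425, 427 and the range 458-466. A minimal standalone use of
    // the same API, assuming 'data_instances' is already loaded:
    //
    //   Remove remove = new Remove();
    //   remove.setOptions(weka.core.Utils.splitOptions("-R 458-466"));
    //   remove.setInputFormat(data_instances);
    //   Instances smaller_instances = Filter.useFilter(data_instances, remove);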

    public static void checkDatasetInstancesCompatible(Instances new_instances, String additional_remove)
    {
        final String gt_datasource_filename = "collect/deam/etc/deam-essentia-features-arousal-valence.arff";

        try {
            DataSource gt_data_source = new DataSource(gt_datasource_filename);
            Instances gt_instances = gt_data_source.getDataSet();
            gt_instances = applyFilter(gt_instances, additional_remove); // remove 'valence' or 'arousal'

            gt_instances.setClassIndex(gt_instances.numAttributes() - 1);

            String equal_header_message = gt_instances.equalHeadersMsg(new_instances);

            if (equal_header_message != null) {
                System.err.println("Headers of the datasets are not equal!");
                System.err.println(equal_header_message);
                System.exit(1);
            }
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
    }

    public static Instances loadAndFilterDataInstances(String input_filename, String additional_attribute_remove)
    {
        System.out.println("Loading training data from file: " + input_filename);
        Instances filtered_data_instances = null;

        try {
            DataSource data_source = new DataSource(input_filename);
            Instances data_instances = data_source.getDataSet();

            // The training dataset has two ground-truth attributes: 'arousal' and 'valence'.
            // When training for one, the other needs to be knocked out. This is the purpose
            // of 'additional_attribute_remove'
            filtered_data_instances = applyFilter(data_instances, additional_attribute_remove);

            // With the 'other' ground-truth attribute taken out, the column to predict
            // will always be the last column
            filtered_data_instances.setClassIndex(filtered_data_instances.numAttributes() - 1);
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }

        return filtered_data_instances;
    }

    public static Classifier trainREPTree(Instances data_instances)
    {
        REPTree classifier = null;

        try {
            // Build classifier
            classifier = new REPTree();
            String[] reptree_options = weka.core.Utils.splitOptions("-M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0");
            classifier.setOptions(reptree_options);

            classifier.buildClassifier(data_instances);
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }

        return classifier;
    }
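
    // Illustrative sketch (not part of the original file): train an 'arousal'
    // REPTree from a ground-truth ARFF and serialize it with Weka's
    // SerializationHelper. The "472" index (the 'valence' column) and the
    // filename arguments are assumptions based on the comments in applyFilter()
    public static void exampleTrainArousalModel(String input_arff_filename, String output_model_filename)
    {
        Instances data_instances = loadAndFilterDataInstances(input_arff_filename, "472");
        Classifier classifier = trainREPTree(data_instances);

        try {
            SerializationHelper.write(output_model_filename, classifier);
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
    }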

    public static Evaluation evaluateClassifier(Classifier classifier, Instances data_instances)
    {
        Evaluation eval = null;
        try {
            eval = new Evaluation(data_instances);
            Random rand = new Random(1);
            int folds = 10;

            eval.crossValidateModel(classifier, data_instances, folds, rand);
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }

        return eval;
    }
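
    // Illustrative sketch (not part of the original file): print the standard
    // Weka summary statistics (correlation coefficient, MAE, RMSE, ...) from
    // the 10-fold cross-validation performed by evaluateClassifier()
    public static void examplePrintEvaluation(Classifier classifier, Instances data_instances)
    {
        Evaluation eval = evaluateClassifier(classifier, data_instances);
        System.out.println(eval.toSummaryString("\n=== 10-fold Cross-validation ===\n", false));
    }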

    /* Apply Model Specific Methods */

    public static Classifier loadClassifierModel(String classifier_input_filename)
    {
        System.out.println("Loading Weka saved Classifier: " + classifier_input_filename);
        Classifier classifier = null;
        try {
            FileInputStream fis = new FileInputStream(classifier_input_filename);
            BufferedInputStream bis = new BufferedInputStream(fis);
            classifier = (Classifier)SerializationHelper.read(bis);
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }

        return classifier;
    }

    public static Instances loadInstancesForClassification(String unlabeled_input_filename)
    {
        System.out.println("Loading unlabeled instances: " + unlabeled_input_filename);
        Instances unlabeled_instances = null;

        try {
            DataSource data_source = new DataSource(unlabeled_input_filename);
            unlabeled_instances = data_source.getDataSet();
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }

        return unlabeled_instances;
    }

    public static boolean instancesHavePredictAttribute(Instances data_instances, String predict_attribute_name)
    {
        Attribute predict_attribute = data_instances.attribute(predict_attribute_name);

        return predict_attribute != null;
    }

    public static Instances filterInstancesForApplying(Instances unlabeled_instances, boolean has_groundtruth_data,
                                                       String predict_attribute_name, String additional_attribute_remove)
    {
        Instances filtered_unlabeled_instances = null;

        if (!has_groundtruth_data) {
            filtered_unlabeled_instances = applyFilter(unlabeled_instances, null); // no additional top-up to remove
            int num_attributes = filtered_unlabeled_instances.numAttributes();

            // Append an (as yet unvalued) attribute for the predictions to go in
            Attribute new_predict_attribute = new Attribute(predict_attribute_name);
            filtered_unlabeled_instances.insertAttributeAt(new_predict_attribute, num_attributes);
        }
        else {
            // Dealing with ground-truth data:
            // => already has 'arousal' and 'valence' attributes
            // => need to keep 'predict_attribute_name' and remove the other one
            //    (the -R value of which is specified in 'additional_attribute_remove')

            // Need to massage the instances into the same form as an unclassified data input file
            filtered_unlabeled_instances = applyFilter(unlabeled_instances, additional_attribute_remove);

            // The caller can share a reference to this as 'groundtruth_instances'
            // to trigger error calculation and output in makePredictions()
        }

        int num_attributes = filtered_unlabeled_instances.numAttributes();

        // Set the class attribute
        filtered_unlabeled_instances.setClassIndex(num_attributes - 1);

        return filtered_unlabeled_instances;
    }
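
    // Illustrative note (not part of the original file): the two ways this
    // method is expected to be called, based on the comments above. The "471"
    // index (the 'arousal' column) is an assumption about the ground-truth
    // file layout:
    //
    //   // Unclassified data: append an empty 'valence' prediction column
    //   Instances i1 = filterInstancesForApplying(raw_instances, false, VALENCE_ATTRIBUTE_NAME, null);
    //
    //   // Ground-truth data: keep 'valence', knock out 'arousal' (attribute 471)
    //   Instances i2 = filterInstancesForApplying(raw_instances, true, VALENCE_ATTRIBUTE_NAME, "471");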

    public static void appendUnclassifiedAttribute(Instances data_instances, String predict_attribute_name)
    {
        int num_attributes = data_instances.numAttributes();

        // Append an (as yet unvalued) attribute for the predictions to go in
        Attribute new_predict_attribute = new Attribute(predict_attribute_name);
        data_instances.insertAttributeAt(new_predict_attribute, num_attributes);
        num_attributes++;

        // Set the class attribute
        data_instances.setClassIndex(num_attributes - 1);
    }

    public static Instances makePredictions(Classifier classifier, Instances unlabeled_instances, Instances groundtruth_instances)
    {
        // Create a copy in which the predictions are to be made
        Instances labeled_instances = new Instances(unlabeled_instances);

        try {
            // Label the instances
            final int num_instances = unlabeled_instances.numInstances();
            for (int i=0; i<num_instances; i++) {
                Instance unlabeled_instance = unlabeled_instances.instance(i);

                System.out.print("Making prediction for: " + (i+1) + "/" + num_instances);
                double classified_value = classifier.classifyInstance(unlabeled_instance);
                labeled_instances.instance(i).setClassValue(classified_value);

                String formatted_classified_value = String.format("% 06.3f", classified_value);
                System.out.print(" value = " + formatted_classified_value);

                if (groundtruth_instances != null) {
                    Instance gt_instance = groundtruth_instances.instance(i);
                    double gt_class_value = gt_instance.classValue();
                    double error = Math.abs(classified_value - gt_class_value);

                    String formatted_error = String.format("%.3f", error);
                    System.out.print(" [error: " + formatted_error + "]");
                }
                System.out.println();
            }
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }

        return labeled_instances;
    }
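
    // Illustrative note (not part of the original file): pass null as
    // 'groundtruth_instances' when the input has no ground-truth labels, e.g.:
    //
    //   Instances labeled_instances = makePredictions(classifier, filtered_instances, null);
    //
    // or pass the filtered ground-truth instances to also get per-instance
    // absolute-error output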

    //
    // https://waikato.github.io/weka-wiki/formats_and_processing/save_instances_to_arff/
    //
    public static void saveInstancesAsDataSink(Instances data_instances, String output_filename)
    {
        try {
            System.out.println("Saving labeled instances: " + output_filename);
            DataSink.write(output_filename, data_instances);
        }
        catch (Exception e) {
            System.err.println("Failed to save data to: " + output_filename);
            e.printStackTrace();
            System.exit(1);
        }
    }

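    // End-to-end sketch (not part of the original file): apply a saved model
    // to an unclassified feature file and write out the labeled instances.
    // Assumes the input has no ground-truth columns; the attribute name is the
    // class constant above, and all filenames are hypothetical
    public static void exampleApplyArousalModel(String model_filename, String unlabeled_filename, String output_filename)
    {
        Classifier classifier = loadClassifierModel(model_filename);
        Instances unlabeled_instances = loadInstancesForClassification(unlabeled_filename);

        // No additional ground-truth column to remove for unclassified input
        Instances filtered_instances = filterInstancesForApplying(unlabeled_instances, false,
                                                                  AROUSAL_ATTRIBUTE_NAME, null);

        Instances labeled_instances = makePredictions(classifier, filtered_instances, null);
        saveInstancesAsDataSink(labeled_instances, output_filename);
    }
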
}