1 | //import java.util.Random;
|
---|
2 |
|
---|
3 | import java.io.BufferedInputStream;
|
---|
4 | import java.io.FileInputStream;
|
---|
5 |
|
---|
6 | import java.io.BufferedReader;
|
---|
7 | import java.io.BufferedWriter;
|
---|
8 | import java.io.FileReader;
|
---|
9 | import java.io.FileWriter;
|
---|
10 |
|
---|
11 | import weka.core.converters.ConverterUtils.DataSource;
|
---|
12 | import weka.core.Attribute;
|
---|
13 | import weka.core.Instance;
|
---|
14 | import weka.core.Instances;
|
---|
15 | import weka.core.SerializationHelper;
|
---|
16 |
|
---|
17 | import weka.filters.Filter;
|
---|
18 | import weka.filters.unsupervised.attribute.Remove;
|
---|
19 |
|
---|
20 | import weka.classifiers.Classifier;
|
---|
21 |
|
---|
22 |
|
---|
23 | // Based on:
|
---|
24 | // https://waikato.github.io/weka-wiki/use_weka_in_your_java_code/
|
---|
25 |
|
---|
26 | class WekaApplyArousalModel
|
---|
27 | {
|
---|
28 | public final static String AROUSAL_ATTRIBUTE_NAME = "arousal_sample_26500ms";
|
---|
29 | public final static String VALANCE_ATTRIBUTE_NAME = "valance_sample_26500ms";
|
---|
30 |
|
---|
31 | public final static String PREDICT_ATTRIBUTE_NAME = "arousal_sample_26500ms";
|
---|
32 |
|
---|
33 |
|
---|
34 | public static Instances applyFilter(Instances data_instances,String additional_remove)
|
---|
35 | {
|
---|
36 | String remove_option_args = ("-R 425,458-466"); // remove ordinal attributes
|
---|
37 | if (additional_remove != null) {
|
---|
38 | // top up, e.g. '472' representing valance in ground-truth files
|
---|
39 | remove_option_args += "," + additional_remove;
|
---|
40 | }
|
---|
41 |
|
---|
42 |
|
---|
43 | Instances filtered_data_instances = null;
|
---|
44 | try {
|
---|
45 | String[] filter_options = weka.core.Utils.splitOptions(remove_option_args); // remove ordinal attributes and 'valance'
|
---|
46 | Remove filter_remove = new Remove();
|
---|
47 | filter_remove.setOptions(filter_options);
|
---|
48 | filter_remove.setInputFormat(data_instances);
|
---|
49 | filtered_data_instances = Filter.useFilter(data_instances, filter_remove);
|
---|
50 | }
|
---|
51 | catch (Exception e) {
|
---|
52 | e.printStackTrace();
|
---|
53 | System.exit(1);
|
---|
54 | }
|
---|
55 |
|
---|
56 | return filtered_data_instances;
|
---|
57 | }
|
---|
58 |
|
---|
59 |
|
---|
60 |
|
---|
61 | public static void checkDatasetInstancesCompatible(Instances new_instances)
|
---|
62 | {
|
---|
63 |
|
---|
64 | final String gt_datasource_filename = "collect/deam/etc/deam-essentia-features-arousal-valence.arff";
|
---|
65 |
|
---|
66 | try {
|
---|
67 | DataSource gt_data_source = new DataSource(gt_datasource_filename);
|
---|
68 | Instances gt_instances = gt_data_source.getDataSet();
|
---|
69 | gt_instances = applyFilter(gt_instances,"472"); // top up with removing 'valance'
|
---|
70 |
|
---|
71 | gt_instances.setClassIndex(gt_instances.numAttributes() - 1);
|
---|
72 |
|
---|
73 | String equal_header_message = gt_instances.equalHeadersMsg(new_instances);
|
---|
74 |
|
---|
75 | if (equal_header_message != null) {
|
---|
76 | System.err.println("Headers to datasets were not equal!");
|
---|
77 | System.err.println(equal_header_message);
|
---|
78 | System.exit(1);
|
---|
79 | }
|
---|
80 | }
|
---|
81 | catch (Exception e) {
|
---|
82 | e.printStackTrace();
|
---|
83 | System.exit(1);
|
---|
84 | }
|
---|
85 |
|
---|
86 | }
|
---|
87 |
|
---|
88 | public static void main(String[] args)
|
---|
89 | {
|
---|
90 | if (args.length != 3) {
|
---|
91 | System.err.println("Error: incorrect number of command-line arguments");
|
---|
92 | System.err.println("Usage: trained-model.{model|ser} unclassified-data.{arff|csv} classified-data.{arff|csv}");
|
---|
93 | System.exit(1);
|
---|
94 | }
|
---|
95 |
|
---|
96 | String classifier_input_filename = args[0];
|
---|
97 | String unclassified_data_input_filename = args[1];
|
---|
98 | String classified_data_output_filename = args[2];
|
---|
99 |
|
---|
100 | System.out.println("Loading Weka saved Classifier: " + classifier_input_filename);
|
---|
101 |
|
---|
102 | try {
|
---|
103 | FileInputStream fis = new FileInputStream(classifier_input_filename);
|
---|
104 | BufferedInputStream bis= new BufferedInputStream(fis);
|
---|
105 | Classifier classifier = (Classifier)SerializationHelper.read(bis);
|
---|
106 |
|
---|
107 | //
|
---|
108 | // Load in unlabeled data
|
---|
109 | //
|
---|
110 |
|
---|
111 | System.out.println("Loading unlabeled instances: " + unclassified_data_input_filename);
|
---|
112 |
|
---|
113 | //FileReader fr = new FileReader(unclassified_data_input_filename);
|
---|
114 | //BufferedReader br = new BufferedReader(fr);
|
---|
115 | //Instances unlabeled_instances = new Instances(br);
|
---|
116 | //br.close();
|
---|
117 |
|
---|
118 | DataSource data_source = new DataSource(unclassified_data_input_filename);
|
---|
119 | Instances unlabeled_instances = data_source.getDataSet();
|
---|
120 |
|
---|
121 | // Work out if we're dealing with a ground-truth ARFF file or not
|
---|
122 | // (i.e. already has the desired attribute)
|
---|
123 |
|
---|
124 | Instances groundtruth_instances = null;
|
---|
125 | Attribute predict_attribute = unlabeled_instances.attribute(PREDICT_ATTRIBUTE_NAME);
|
---|
126 |
|
---|
127 | if (predict_attribute == null) {
|
---|
128 |
|
---|
129 | unlabeled_instances = applyFilter(unlabeled_instances,null); // no additional top-up to remove
|
---|
130 | int num_attributes = unlabeled_instances.numAttributes();
|
---|
131 |
|
---|
132 | Attribute arousal_attribute = new Attribute(AROUSAL_ATTRIBUTE_NAME);
|
---|
133 | unlabeled_instances.insertAttributeAt(arousal_attribute,num_attributes);
|
---|
134 | num_attributes++;
|
---|
135 |
|
---|
136 | //Attribute valance_attribute = new Attribute(VALANCE_ATTRIBUTE_NAME);
|
---|
137 | //unlabeled_instances.insertAttributeAt(valance_attribute,num_attributes);
|
---|
138 | //num_attributes++;
|
---|
139 |
|
---|
140 | //predict_attribute = new Attribute(PREDICT_ATTRIBUTE_NAME);
|
---|
141 | //unlabeled_instances.insertAttributeAt(predict_attribute,num_attributes);
|
---|
142 | //unlabeled_instances.setClassIndex(num_attributes);
|
---|
143 | //num_attributes++;
|
---|
144 |
|
---|
145 | //predict_attribute = unlabeled_instances.attribute(PREDICT_ATTRIBUTE_NAME);
|
---|
146 | //unlabeled_instances.setClass(predict_attribute);
|
---|
147 |
|
---|
148 | //unlabeled_instances.setClassIndex(num_attributes - 1);
|
---|
149 | }
|
---|
150 | else {
|
---|
151 | // Dealing with ground-truth data:
|
---|
152 | // => already has 'arousal' attribute
|
---|
153 | // => in fact has 'valance' attribute too, which we want to remove
|
---|
154 |
|
---|
155 | unlabeled_instances = applyFilter(unlabeled_instances,"472"); // top-up with removal of 'valance'
|
---|
156 | //unlabeled_instances.setClass(predict_attribute);
|
---|
157 |
|
---|
158 | // reference share this as 'groundtruth_instances' to trigger error calculation and output
|
---|
159 | groundtruth_instances = unlabeled_instances;
|
---|
160 | }
|
---|
161 |
|
---|
162 | int num_attributes = unlabeled_instances.numAttributes();
|
---|
163 | unlabeled_instances.setClassIndex(num_attributes - 1);
|
---|
164 |
|
---|
165 | checkDatasetInstancesCompatible(unlabeled_instances);
|
---|
166 |
|
---|
167 | // Set class attribute
|
---|
168 | //unlabeled_instances.setClassIndex(gt_instances.numAttributes() - 2);
|
---|
169 |
|
---|
170 |
|
---|
171 | // Create copy where the predictions are to be made
|
---|
172 | Instances labeled_instances = new Instances(unlabeled_instances);
|
---|
173 |
|
---|
174 | //
|
---|
175 | // Label instances
|
---|
176 | //
|
---|
177 |
|
---|
178 | final int num_instances = unlabeled_instances.numInstances();
|
---|
179 | for (int i=0; i<num_instances; i++) {
|
---|
180 | Instance unlabeled_instance = unlabeled_instances.instance(i);
|
---|
181 |
|
---|
182 | System.out.print("Making prediction for: " + i + "/" + num_instances);
|
---|
183 | double classified_value = classifier.classifyInstance(unlabeled_instance);
|
---|
184 | labeled_instances.instance(i).setClassValue(classified_value);
|
---|
185 |
|
---|
186 | String formatted_classified_value = String.format("% 06.3f", classified_value);
|
---|
187 |
|
---|
188 | System.out.print(" value = " + formatted_classified_value);
|
---|
189 |
|
---|
190 | if (groundtruth_instances != null) {
|
---|
191 | Instance gt_instance = groundtruth_instances.instance(i);
|
---|
192 | double gt_class_value = gt_instance.classValue();
|
---|
193 | double error = Math.abs(classified_value - gt_class_value);
|
---|
194 |
|
---|
195 | String formatted_error = String.format("%.3f", error);
|
---|
196 | System.out.print(" [error: " + formatted_error + "]");
|
---|
197 | }
|
---|
198 | System.out.println();
|
---|
199 | }
|
---|
200 |
|
---|
201 | //
|
---|
202 | // Save labeled data
|
---|
203 | //
|
---|
204 |
|
---|
205 | System.out.println("Saving labeled instances: " + classified_data_output_filename);
|
---|
206 | FileWriter fw = new FileWriter(classified_data_output_filename);
|
---|
207 | BufferedWriter bw = new BufferedWriter(fw);
|
---|
208 |
|
---|
209 | bw.write(labeled_instances.toString());
|
---|
210 | bw.newLine();
|
---|
211 | bw.flush();
|
---|
212 | bw.close();
|
---|
213 |
|
---|
214 | }
|
---|
215 | catch (Exception e) {
|
---|
216 | e.printStackTrace();
|
---|
217 | }
|
---|
218 |
|
---|
219 | }
|
---|
220 | }
|
---|