source: main/trunk/model-sites-dev/mars/src/java/org/greenstone/mars/WekaApplyArousalModel.java@ 34797

Last change on this file since 34797 was 34797, checked in by davidb, 3 years ago

Fixed up spelling mixup valance -> valence

File size: 6.3 KB
Line 
1package org.greenstone.mars;
2
3//import java.util.Random;
4
5import java.io.BufferedInputStream;
6import java.io.FileInputStream;
7
8import java.io.BufferedReader;
9import java.io.BufferedWriter;
10import java.io.FileReader;
11import java.io.FileWriter;
12
13import weka.core.converters.ConverterUtils.DataSource;
14import weka.core.Attribute;
15import weka.core.Instance;
16import weka.core.Instances;
17import weka.core.SerializationHelper;
18
19import weka.filters.Filter;
20import weka.filters.unsupervised.attribute.Remove;
21
22import weka.classifiers.Classifier;
23
24
25// Based on:
26// https://waikato.github.io/weka-wiki/use_weka_in_your_java_code/
27
28class WekaApplyArousalModel
29{
30
31 //public final static String PREDICT_ATTRIBUTE_NAME = WekaUtil.AROUSAL_ATTRIBUTE_NAME;
32
33
34
35 public static void main(String[] args)
36 {
37 /*
38 if (args.length != 3) {
39 System.err.println("Error: incorrect number of command-line arguments");
40 System.err.println("Usage: trained-model.{model|ser} unclassified-data.{arff|csv} classified-data.{arff|csv}");
41 System.exit(1);
42 }
43 */
44
45 WekaUtil.checkUsageApplyModel(args);
46
47 String classifier_input_filename = args[0];
48 String unclassified_data_input_filename = args[1];
49 String classified_data_output_filename = args[2];
50
51
52 Classifier classifier = WekaUtil.loadClassifierModel(classifier_input_filename);
53
54 /*
55 System.out.println("Loading Weka saved Classifier: " + classifier_input_filename);
56
57 try {
58 FileInputStream fis = new FileInputStream(classifier_input_filename);
59 BufferedInputStream bis= new BufferedInputStream(fis);
60 Classifier classifier = (Classifier)SerializationHelper.read(bis);
61
62 //
63 // Load in unlabeled data
64 //
65
66 System.out.println("Loading unlabeled instances: " + unclassified_data_input_filename);
67
68 DataSource data_source = new DataSource(unclassified_data_input_filename);
69 Instances unlabeled_instances = data_source.getDataSet();
70
71 // Work out if we're dealing with a ground-truth ARFF file or not
72 // (i.e. already has the desired attribute)
73
74 Instances groundtruth_instances = null;
75 Attribute predict_attribute = unlabeled_instances.attribute(PREDICT_ATTRIBUTE_NAME);
76
77 if (predict_attribute == null) {
78
79 unlabeled_instances = WekaUtil.applyFilter(unlabeled_instances,null); // no additional top-up to remove
80 int num_attributes = unlabeled_instances.numAttributes();
81
82 Attribute arousal_attribute = new Attribute(WekaUtil.AROUSAL_ATTRIBUTE_NAME);
83 unlabeled_instances.insertAttributeAt(arousal_attribute,num_attributes);
84 num_attributes++;
85 }
86 else {
87 // Dealing with ground-truth data:
88 // => already has 'arousal' attribute
89 // => in fact has 'valence' attribute too, which we want to remove
90
91 unlabeled_instances = WekaUtil.applyFilter(unlabeled_instances,"472"); // top-up with removal of 'valence'
92
93 // reference share this as 'groundtruth_instances' to trigger error calculation and output
94 groundtruth_instances = unlabeled_instances;
95 }
96
97 int num_attributes = unlabeled_instances.numAttributes();
98
99 // Set class attribute
100 unlabeled_instances.setClassIndex(num_attributes - 1);
101
102 WekaUtil.checkDatasetInstancesCompatible(unlabeled_instances);
103
104
105
106
107
108 // Create copy where the predictions are to be made
109 Instances labeled_instances = new Instances(unlabeled_instances);
110
111 //
112 // Label instances
113 //
114
115 final int num_instances = unlabeled_instances.numInstances();
116 for (int i=0; i<num_instances; i++) {
117 Instance unlabeled_instance = unlabeled_instances.instance(i);
118
119 System.out.print("Making prediction for: " + i + "/" + num_instances);
120 double classified_value = classifier.classifyInstance(unlabeled_instance);
121 labeled_instances.instance(i).setClassValue(classified_value);
122
123 String formatted_classified_value = String.format("% 06.3f", classified_value);
124
125 System.out.print(" value = " + formatted_classified_value);
126
127 if (groundtruth_instances != null) {
128 Instance gt_instance = groundtruth_instances.instance(i);
129 double gt_class_value = gt_instance.classValue();
130 double error = Math.abs(classified_value - gt_class_value);
131
132 String formatted_error = String.format("%.3f", error);
133 System.out.print(" [error: " + formatted_error + "]");
134 }
135 System.out.println();
136 }
137
138 //
139 // Save labeled data
140 //
141
142 System.out.println("Saving labeled instances: " + classified_data_output_filename);
143 FileWriter fw = new FileWriter(classified_data_output_filename);
144 BufferedWriter bw = new BufferedWriter(fw);
145
146 bw.write(labeled_instances.toString());
147 bw.newLine();
148 bw.flush();
149 bw.close();
150
151 }
152 catch (Exception e) {
153 e.printStackTrace();
154 }
155
156 */
157
158
159 Instances unlabeled_instances= WekaUtil.loadInstancesForClassification(unclassified_data_input_filename);
160
161 // It is permissible to run this code and supply it with a data file that includes groundtruth in it.
162 // In this situation, the 'unlabeled' instances:
163 // (i) need to be massaged to be in the same form as truly unlabeled data
164 // (ii) we also set up 'groundtruth_instances' as an alias (reference) to 'filtered_unlabeled_instanced'
165 // to trigger calculating the error on the predicted vaues
166
167 boolean has_groundtruth_data = WekaUtil.instancesHavePredictAttribute(unlabeled_instances,WekaUtil.AROUSAL_ATTRIBUTE_NAME);
168
169 // The following deals with (i) internally, ensuring that what is returned is suitable for making predictions on
170 Instances filtered_unlabeled_instances
171 = WekaUtil.filterInstancesForApplying(unlabeled_instances,has_groundtruth_data,
172 WekaUtil.AROUSAL_ATTRIBUTE_NAME,"472");
173
174 // The following deals with (ii)
175 Instances groundtruth_instances = (has_groundtruth_data) ? filtered_unlabeled_instances : null;
176
177 Instances labeled_instances = WekaUtil.makePredictions(classifier, filtered_unlabeled_instances, groundtruth_instances);
178
179 try {
180 // Save labeled data
181
182 System.out.println("Saving labeled instances: " + classified_data_output_filename);
183 FileWriter fw = new FileWriter(classified_data_output_filename);
184 BufferedWriter bw = new BufferedWriter(fw);
185
186 bw.write(labeled_instances.toString());
187 bw.newLine();
188 bw.flush();
189 bw.close();
190
191 }
192 catch (Exception e) {
193 e.printStackTrace();
194 }
195
196 }
197}
Note: See TracBrowser for help on using the repository browser.