1 | /*
|
---|
2 | * WekaDBWrapper.java
|
---|
3 | * Copyright (C) 2011 New Zealand Digital Library, http://www.nzdl.org
|
---|
4 | *
|
---|
5 | * This program is free software; you can redistribute it and/or modify
|
---|
6 | * it under the terms of the GNU General Public License as published by
|
---|
7 | * the Free Software Foundation; either version 2 of the License, or
|
---|
8 | * (at your option) any later version.
|
---|
9 | *
|
---|
10 | * This program is distributed in the hope that it will be useful,
|
---|
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
|
---|
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
---|
13 | * GNU General Public License for more details.
|
---|
14 | *
|
---|
15 | * You should have received a copy of the GNU General Public License
|
---|
16 | * along with this program; if not, write to the Free Software
|
---|
17 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
|
---|
18 | */
|
---|
19 | package org.greenstone.gsdl3.util;
|
---|
20 |
|
---|
21 | import java.io.*;
|
---|
22 | import java.util.Vector;
|
---|
23 | import java.util.Collections;
|
---|
24 | import java.util.regex.Pattern;
|
---|
25 | import java.util.regex.Matcher;
|
---|
26 |
|
---|
27 | import weka.core.Attribute;
|
---|
28 | import weka.core.Instance;
|
---|
29 | import weka.core.Instances;
|
---|
30 |
|
---|
31 | import org.apache.log4j.*;
|
---|
32 |
|
---|
33 | import org.greenstone.gsdl3.util.WekaFindInstanceKNN;
|
---|
34 |
|
---|
35 | /** Java wrapper class for access to the Weka
|
---|
36 | * Devised (in the first instance) to operate as: java -jar weka.jar <arg1> <arg2>
|
---|
37 | *
|
---|
38 | * Inspired by MGSearchWrapper.java
|
---|
39 | */
|
---|
40 |
|
---|
41 | public class WekaDBWrapper
|
---|
42 | {
|
---|
43 | /** the query result, filled in by runQuery */
|
---|
44 | protected Vector query_result_;
|
---|
45 |
|
---|
46 | protected int offset_ = 100;
|
---|
47 | protected int length_ = 20; // **** Unused
|
---|
48 |
|
---|
49 | // Approximate matching not yet utilized
|
---|
50 | protected double radius_; // **** Unused
|
---|
51 |
|
---|
52 | protected int max_docs_;
|
---|
53 |
|
---|
54 | protected double arousal_;
|
---|
55 | protected double valence_;
|
---|
56 |
|
---|
57 | static Logger logger = Logger.getLogger (org.greenstone.gsdl3.util.WekaDBWrapper.class.getName ());
|
---|
58 |
|
---|
59 | public WekaDBWrapper() {
|
---|
60 | query_result_ = null;
|
---|
61 | }
|
---|
62 |
|
---|
63 | // query param methods
|
---|
64 |
|
---|
65 | /** start point (offset) into the array of feature vectors for a track
|
---|
66 | - 100 by default which equals 10 seconds (assuming 0.1 frame size) */
|
---|
67 | public void setOffset(int offset) {
|
---|
68 | offset_ = offset;
|
---|
69 | }
|
---|
70 |
|
---|
71 | /** the number of consecutive frames used in match
|
---|
72 | - 20 by default which equals 2 seconds (assuming 0.1 frame size) */
|
---|
73 | public void setLength(int length) {
|
---|
74 | length_ = length;
|
---|
75 | }
|
---|
76 |
|
---|
77 | /** distance used in approximate matching support - default is 50 */
|
---|
78 | public void setRadius(double radius) {
|
---|
79 | radius_ = radius;
|
---|
80 | }
|
---|
81 |
|
---|
82 | public void setMaxDocs(int max_docs) {
|
---|
83 | max_docs_ = max_docs;
|
---|
84 | }
|
---|
85 |
|
---|
86 | public void setArousal(double arousal) {
|
---|
87 | arousal_ = arousal;
|
---|
88 | }
|
---|
89 | public void setValence(double valence) {
|
---|
90 | valence_ = valence;
|
---|
91 | }
|
---|
92 |
|
---|
93 | /** returns a string with all the current query param settings */
|
---|
94 | // the following was in MG version, do we need this in WekaDB version? // ****
|
---|
95 | //public String getQueryParams() {}
|
---|
96 |
|
---|
97 |
|
---|
98 | protected boolean addQueryResult(boolean first_entry, String doc_id,
|
---|
99 | Vector<Double> rankVector, Vector<Integer> offsetVector)
|
---|
100 | {
|
---|
101 |
|
---|
102 | if (first_entry) {
|
---|
103 | WekaDBDocInfo wekaDB_doc_info = new WekaDBDocInfo(doc_id,rankVector,offsetVector);
|
---|
104 | query_result_.add(wekaDB_doc_info);
|
---|
105 | first_entry = false;
|
---|
106 | }
|
---|
107 | else {
|
---|
108 | double rank = rankVector.get(0);
|
---|
109 | int offset = offsetVector.get(0);
|
---|
110 | WekaDBDocInfo wekaDB_doc_info = new WekaDBDocInfo(doc_id,rank,offset);
|
---|
111 |
|
---|
112 | query_result_.add(wekaDB_doc_info);
|
---|
113 | }
|
---|
114 |
|
---|
115 | return first_entry;
|
---|
116 | }
|
---|
117 |
|
---|
118 |
|
---|
119 | /** actually carry out the query.
|
---|
120 | Use the set methods to set query results.
|
---|
121 | Writes the result to query_result.
|
---|
122 | * - maintains state between requests as can be slow
|
---|
123 | * base_dir and index_path should join together to provide
|
---|
124 | * the absolute location of the mg index files eg ..../index/dtx/demo
|
---|
125 | * base_dir must end with a file separator (OS dependant)
|
---|
126 | */
|
---|
127 |
|
---|
128 |
|
---|
129 | public void runQuery(String wekaDB_index_dir, String knn_model_file,
|
---|
130 | String assoc_index_dir, String query_string) {
|
---|
131 |
|
---|
132 | // combine index_dir with audiodb fileanem
|
---|
133 |
|
---|
134 | String full_knn_model_filename = wekaDB_index_dir + File.separatorChar + knn_model_file;
|
---|
135 |
|
---|
136 | //String full_chr12_filename = assoc_index_dir + File.separatorChar
|
---|
137 | // + query_string + File.separatorChar + "doc.chr12";
|
---|
138 |
|
---|
139 | System.err.println("**** full knn model filename = " + full_knn_model_filename);
|
---|
140 |
|
---|
141 | // Example returned result from Weka KNN
|
---|
142 | // => first line is the input instance ('filename+segment',Arousal,Valence)
|
---|
143 | // following (indented lines) nearest neighbour matches in same format
|
---|
144 | //
|
---|
145 | // ds_22716_5743-6,-0.549489,-0.118439
|
---|
146 | // ds_22716_5743-6,-0.549489,-0.118439
|
---|
147 | // ds_31008_6550-30,-0.549489,-0.118439
|
---|
148 | // ds_72651_26831-6,-0.549489,-0.118439
|
---|
149 | // ds_26196_9214-18,-0.549489,-0.118439
|
---|
150 |
|
---|
151 |
|
---|
152 | WekaFindInstanceKNN.init(full_knn_model_filename);
|
---|
153 |
|
---|
154 | String doc_id = query_string;
|
---|
155 | int segment = offset_;
|
---|
156 |
|
---|
157 | String query_doc_id_segment = doc_id + "-" + segment;
|
---|
158 |
|
---|
159 | double query_arousal_val = arousal_;
|
---|
160 | double query_valence_val = valence_;
|
---|
161 |
|
---|
162 | int k_nearest_num = max_docs_;
|
---|
163 | int expanded_k_nearest_num = max_docs_ * 5; // * internally get more matches, then sift through to arrive at the best 'max_docs_'
|
---|
164 |
|
---|
165 | Pattern doc_seg_re = Pattern.compile("^(\\w+)-(\\d+)$");
|
---|
166 | //Matcher query_doc_seg_match = doc_seq_re.matcher(query_doc_id_segment);
|
---|
167 |
|
---|
168 | Instances nearest_instances
|
---|
169 | = WekaFindInstanceKNN.kNearestNeighbours(query_doc_id_segment,query_arousal_val,query_valence_val,k_nearest_num);
|
---|
170 |
|
---|
171 |
|
---|
172 | Vector expanded_query_result = new Vector();
|
---|
173 |
|
---|
174 | int nearest_instances_len = nearest_instances.size();
|
---|
175 |
|
---|
176 | int clamped_expanded_k_nearest_num = Math.min(expanded_k_nearest_num,nearest_instances_len);
|
---|
177 |
|
---|
178 | double pos_penalty = 0.1;
|
---|
179 |
|
---|
180 | for (int ei=0; ei<clamped_expanded_k_nearest_num; ei++) {
|
---|
181 | Instance instance = nearest_instances.instance(ei);
|
---|
182 | logger.info("\tProcessing returned instance: " + instance);
|
---|
183 |
|
---|
184 | String matching_doc_id_segment = instance.stringValue(0);
|
---|
185 |
|
---|
186 |
|
---|
187 | //Pattern p = Pattern.compile("^(\\w+)-(\\d+)$");
|
---|
188 | Matcher m = doc_seg_re.matcher(matching_doc_id_segment);
|
---|
189 | if (m.matches()) {
|
---|
190 |
|
---|
191 | String matching_doc_id = m.group(1);
|
---|
192 | int matching_segment_offset = Integer.parseInt(m.group(2));
|
---|
193 |
|
---|
194 | if (matching_doc_id.equals(doc_id)) {
|
---|
195 | continue;
|
---|
196 | }
|
---|
197 |
|
---|
198 | double matching_arousal_val = instance.value(1);
|
---|
199 | double matching_valence_val = instance.value(2);
|
---|
200 |
|
---|
201 | double matching_diff = (Math.abs(query_arousal_val - matching_arousal_val)
|
---|
202 | + Math.abs(query_valence_val - matching_valence_val))/4.0;
|
---|
203 | double matching_rank = 1.0 - matching_diff - (pos_penalty * (double)ei);
|
---|
204 |
|
---|
205 | logger.info("\tAdding in: matching_doc_id = " + matching_doc_id);
|
---|
206 | WekaDBDocInfo wekaDB_doc_info = new WekaDBDocInfo(matching_doc_id,matching_rank,matching_segment_offset);
|
---|
207 | expanded_query_result.add(wekaDB_doc_info);
|
---|
208 | }
|
---|
209 | else {
|
---|
210 | logger.error("Returned AV k-nearest neighbour match '"+matching_doc_id_segment+"' could not be parsed as <doc-id>-<segment>" );
|
---|
211 | }
|
---|
212 | }
|
---|
213 |
|
---|
214 | query_result_ = new Vector();
|
---|
215 |
|
---|
216 | int i = 0;
|
---|
217 | while (i < k_nearest_num) {
|
---|
218 | if (i >= expanded_query_result.size()) {
|
---|
219 | break;
|
---|
220 | }
|
---|
221 |
|
---|
222 | query_result_.add(expanded_query_result.get(i));
|
---|
223 | i++;
|
---|
224 | }
|
---|
225 |
|
---|
226 | Collections.sort(query_result_);
|
---|
227 | }
|
---|
228 |
|
---|
229 | public void runQueryOLD(String wekaDB_index_dir, String knn_model_file,
|
---|
230 | String assoc_index_dir, String query_string) {
|
---|
231 |
|
---|
232 | // combine index_dir with audiodb fileanem
|
---|
233 |
|
---|
234 | String full_knn_model_filename = wekaDB_index_dir + File.separatorChar + knn_model_file;
|
---|
235 | String full_chr12_filename = assoc_index_dir + File.separatorChar
|
---|
236 | + query_string + File.separatorChar + "doc.chr12";
|
---|
237 |
|
---|
238 | int num_matches_within_track = 6;
|
---|
239 |
|
---|
240 | // ****
|
---|
241 | String [] cmd_array = new String[] {
|
---|
242 | "java", "-jar", "weka.jar",
|
---|
243 | "-d", full_knn_model_filename,
|
---|
244 | "-Q", "nsequence",
|
---|
245 | "-p", String.format("%d",offset_),
|
---|
246 | "-n", String.format("%d",num_matches_within_track),
|
---|
247 | "-l", String.format("%d",length_),
|
---|
248 | "-r", String.format("%d",max_docs_),
|
---|
249 | "-f", full_chr12_filename
|
---|
250 | };
|
---|
251 |
|
---|
252 | System.err.println("**** cmd_array = " + String.join(" ", cmd_array));
|
---|
253 |
|
---|
254 | Runtime runtime = Runtime.getRuntime();
|
---|
255 | try {
|
---|
256 | Process wekaDB_proc = runtime.exec(cmd_array);
|
---|
257 | //int exitVal = wekaDB_proc.waitFor();
|
---|
258 | //System.err.println("*** exit status = " + exitVal);
|
---|
259 |
|
---|
260 | InputStream wis = wekaDB_proc.getInputStream();
|
---|
261 | InputStreamReader wisr = new InputStreamReader(wis);
|
---|
262 | BufferedReader wbr = new BufferedReader(wisr);
|
---|
263 |
|
---|
264 | query_result_ = new Vector();
|
---|
265 |
|
---|
266 | boolean first_entry = true;
|
---|
267 | int line_count = 0;
|
---|
268 |
|
---|
269 | String root_doc_id = null;
|
---|
270 | Vector<Double> rankVector = new Vector<Double>();
|
---|
271 | Vector<Integer> offsetVector = new Vector<Integer>();
|
---|
272 |
|
---|
273 | // Example output
|
---|
274 | // D8 0.00105175
|
---|
275 | // 1.69786e-16 392 392
|
---|
276 | // 0.00113568 392 673
|
---|
277 | // 0.00127239 392 910
|
---|
278 | // 0.00139736 392 481
|
---|
279 | // 0.00145331 392 303
|
---|
280 | // D2 0.00429758
|
---|
281 | // 0.00403335 392 865
|
---|
282 | // 0.00411288 392 458
|
---|
283 | // 0.00442461 392 866
|
---|
284 | // 0.00444272 392 864
|
---|
285 | // 0.00447434 392 424
|
---|
286 | // ...
|
---|
287 |
|
---|
288 | String line;
|
---|
289 | while ((line = wbr.readLine()) != null) {
|
---|
290 | String[] tokens = line.split("\\s+");
|
---|
291 | line_count++;
|
---|
292 |
|
---|
293 | if (tokens.length==2) {
|
---|
294 | // processing a top-level doc line
|
---|
295 |
|
---|
296 | if (line_count>1) {
|
---|
297 | // struck new top-level entry => store vector vals for previous block
|
---|
298 |
|
---|
299 | first_entry = addQueryResult(first_entry,root_doc_id,rankVector,offsetVector);
|
---|
300 | // and now reset vectors to empty to be ready for next chain of values
|
---|
301 | rankVector = new Vector<Double>();
|
---|
302 | offsetVector = new Vector<Integer>();
|
---|
303 | }
|
---|
304 |
|
---|
305 | root_doc_id = tokens[0];
|
---|
306 | }
|
---|
307 | else {
|
---|
308 | // should be 3 items
|
---|
309 | double euclidean_dist = Double.parseDouble(tokens[0]);
|
---|
310 | int src_frame = Integer.parseInt(tokens[1]);
|
---|
311 | int target_frame = Integer.parseInt(tokens[2]);
|
---|
312 |
|
---|
313 | // ****
|
---|
314 |
|
---|
315 | // enforce 1.0 as upper limit due to rounding errors
|
---|
316 | // in audioDB distance calculations
|
---|
317 | double rank = Math.min(1.0 - euclidean_dist,1.0);
|
---|
318 |
|
---|
319 | if ((line_count==2) && (src_frame==target_frame)) {
|
---|
320 | // Found match with self
|
---|
321 | continue;
|
---|
322 | }
|
---|
323 |
|
---|
324 | rankVector.add(rank);
|
---|
325 | offsetVector.add(target_frame);
|
---|
326 | }
|
---|
327 |
|
---|
328 | }
|
---|
329 |
|
---|
330 | addQueryResult(first_entry,root_doc_id,rankVector,offsetVector);
|
---|
331 |
|
---|
332 | wbr.close();
|
---|
333 |
|
---|
334 | // sort query_result_ on 'rank' field
|
---|
335 | // note: compareTo() method impelemented to sort into descending order
|
---|
336 |
|
---|
337 | Collections.sort(query_result_);
|
---|
338 |
|
---|
339 |
|
---|
340 | }
|
---|
341 | catch (Exception e) {
|
---|
342 | logger.error("Failed to execute the following command: " + String.join(" ", cmd_array));
|
---|
343 | e.printStackTrace();
|
---|
344 | }
|
---|
345 |
|
---|
346 | }
|
---|
347 |
|
---|
348 |
|
---|
349 | /** get the result out of the wrapper */
|
---|
350 | public Vector getQueryResult()
|
---|
351 | {
|
---|
352 | return query_result_;
|
---|
353 | }
|
---|
354 | }
|
---|
355 |
|
---|