source: gs3-extensions/mars-src/trunk/src/java/org/greenstone/gsdl3/util/WekaDBWrapper.java@ 36864

Last change on this file since 36864 was 36864, checked in by davidb, 20 months ago

Attempt at a result set that merges later matches from the same doc-id

File size: 12.3 KB
Line 
1/*
2 * WekaDBWrapper.java
3 * Copyright (C) 2011 New Zealand Digital Library, http://www.nzdl.org
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation; either version 2 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program; if not, write to the Free Software
17 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18 */
19package org.greenstone.gsdl3.util;
20
21import java.io.*;
22import java.util.Vector;
23import java.util.Collections;
24import java.util.regex.Pattern;
25import java.util.regex.Matcher;
26
27import weka.core.Attribute;
28import weka.core.Instance;
29import weka.core.Instances;
30
31import org.apache.log4j.*;
32
33import org.greenstone.gsdl3.util.WekaFindInstanceKNN;
34
35/** Java wrapper class for access to the Weka
36 * Devised (in the first instance) to operate as: java -jar weka.jar <arg1> <arg2>
37 *
38 * Inspired by MGSearchWrapper.java
39 */
40
41public class WekaDBWrapper
42{
43 public final static double AV_SEGMENT_LENGTH_SECS = 6.0;
44
45 /** the query result, filled in by runQuery */
46 protected Vector query_results_;
47
48 protected int offset_ = 100;
49 protected int length_ = 20; // **** Unused
50
51 // Approximate matching not yet utilized
52 protected double radius_; // **** Unused
53
54 protected int max_docs_;
55
56 protected double arousal_;
57 protected double valence_;
58
59 static Logger logger = Logger.getLogger (org.greenstone.gsdl3.util.WekaDBWrapper.class.getName ());
60
61 public WekaDBWrapper() {
62 query_results_ = null;
63 }
64
65 // query param methods
66
67 /** start point (offset) into the array of feature vectors for a track
68 - 100 by default which equals 10 seconds (assuming 0.1 frame size) */
69 public void setOffset(int offset) {
70 offset_ = offset;
71 }
72
73 /** the number of consecutive frames used in match
74 - 20 by default which equals 2 seconds (assuming 0.1 frame size) */
75 public void setLength(int length) {
76 length_ = length;
77 }
78
79 /** distance used in approximate matching support - default is 50 */
80 public void setRadius(double radius) {
81 radius_ = radius;
82 }
83
84 public void setMaxDocs(int max_docs) {
85 max_docs_ = max_docs;
86 }
87
88 public void setArousal(double arousal) {
89 arousal_ = arousal;
90 }
91 public void setValence(double valence) {
92 valence_ = valence;
93 }
94
95 /** returns a string with all the current query param settings */
96 // the following was in MG version, do we need this in WekaDB version? // ****
97 //public String getQueryParams() {}
98
99
100 protected boolean addQueryResult(boolean first_entry, String doc_id,
101 Vector<Double> rankVector, Vector<Integer> offsetVector)
102 {
103
104 if (first_entry) {
105 WekaDBDocInfo wekaDB_doc_info = new WekaDBDocInfo(doc_id,rankVector,offsetVector);
106 query_results_.add(wekaDB_doc_info);
107 first_entry = false;
108 }
109 else {
110 double rank = rankVector.get(0);
111 int offset = offsetVector.get(0);
112 WekaDBDocInfo wekaDB_doc_info = new WekaDBDocInfo(doc_id,rank,offset);
113
114 query_results_.add(wekaDB_doc_info);
115 }
116
117 return first_entry;
118 }
119
120
121 protected int mergeResultDoc(Vector query_results, WekaDBDocInfo new_doc_info, double inc_rank_val)
122 {
123 int merged = 0;
124
125 String new_doc_id = new_doc_info.getDocID();
126
127 final int query_results_len = query_results.size();
128
129 for (int i=0; i<query_results_len; i++) {
130 WekaDBDocInfo existing_doc_info = (WekaDBDocInfo)query_results.get(i);
131
132 String existing_doc_id = existing_doc_info.getDocID();
133 if (new_doc_id.equals(existing_doc_id)) {
134 merged = 1;
135 existing_doc_info.incTopRank(inc_rank_val);
136 break;
137 }
138 }
139
140 if (merged == 0) {
141 query_results.add(new_doc_info);
142 }
143
144 return merged;
145 }
146
147 /** actually carry out the query.
148 Use the set methods to set query results.
149 Writes the result to query_results.
150 * - maintains state between requests as can be slow
151 * base_dir and index_path should join together to provide
152 * the absolute location of the mg index files eg ..../index/dtx/demo
153 * base_dir must end with a file separator (OS dependant)
154 */
155
156
157 public void runQuery(String wekaDB_index_dir, String knn_model_file,
158 String assoc_index_dir, String query_string) {
159
160 // combine index_dir with audiodb fileanem
161
162 String full_knn_model_filename = wekaDB_index_dir + File.separatorChar + knn_model_file;
163
164 //String full_chr12_filename = assoc_index_dir + File.separatorChar
165 // + query_string + File.separatorChar + "doc.chr12";
166
167 System.err.println("**** full knn model filename = " + full_knn_model_filename);
168
169 // Example returned result from Weka KNN
170 // => first line is the input instance ('filename+segment',Arousal,Valence)
171 // following (indented lines) nearest neighbour matches in same format
172 //
173 // ds_22716_5743-6,-0.549489,-0.118439
174 // ds_22716_5743-6,-0.549489,-0.118439
175 // ds_31008_6550-30,-0.549489,-0.118439
176 // ds_72651_26831-6,-0.549489,-0.118439
177 // ds_26196_9214-18,-0.549489,-0.118439
178
179
180 WekaFindInstanceKNN.init(full_knn_model_filename);
181
182 String doc_id = query_string;
183 int segment = offset_;
184
185 String query_doc_id_segment = doc_id + "-" + segment;
186
187 double query_arousal_val = arousal_;
188 double query_valence_val = valence_;
189
190 int k_nearest_num = max_docs_;
191 int expanded_k_nearest_num = max_docs_ * 5; // * internally get more matches, then sift through to arrive at the best 'max_docs_'
192
193 Pattern doc_seg_re = Pattern.compile("^(\\w+)-(\\d+)$");
194 //Matcher query_doc_seg_match = doc_seq_re.matcher(query_doc_id_segment);
195
196 Instances nearest_instances
197 = WekaFindInstanceKNN.kNearestNeighbours(query_doc_id_segment,query_arousal_val,query_valence_val,k_nearest_num);
198
199
200 Vector expanded_query_results = new Vector();
201
202 int nearest_instances_len = nearest_instances.size();
203
204 int clamped_expanded_k_nearest_num = Math.min(expanded_k_nearest_num,nearest_instances_len);
205
206 double pos_penalty = 0.1;
207 int topup_count = 0;
208
209 for (int ei=0; ei<clamped_expanded_k_nearest_num; ei++) {
210 Instance instance = nearest_instances.instance(ei);
211 logger.info("\tProcessing returned instance: " + instance);
212
213 String matching_doc_id_segment = instance.stringValue(0);
214
215 //Pattern p = Pattern.compile("^(\\w+)-(\\d+)$");
216 Matcher m = doc_seg_re.matcher(matching_doc_id_segment);
217 if (m.matches()) {
218
219 String matching_doc_id = m.group(1);
220 int end_of_matching_segment_offset = Integer.parseInt(m.group(2));
221 //int matching_segment_offset = end_of_matching_segment_offset - (int)AV_SEGMENT_LENGTH_SECS;
222 int matching_segment_offset = end_of_matching_segment_offset;
223
224 if (matching_doc_id.equals(doc_id)) {
225 // don't add in matches that come from a matching segment in the query doc
226 continue;
227 }
228
229 double matching_arousal_val = instance.value(1);
230 double matching_valence_val = instance.value(2);
231
232 double matching_diff = (Math.abs(query_arousal_val - matching_arousal_val)
233 + Math.abs(query_valence_val - matching_valence_val))/4.0;
234 double matching_rank = 1.0 - matching_diff - (pos_penalty * (double)ei);
235
236 logger.info("\tAdding in: matching_doc_id = " + matching_doc_id);
237 WekaDBDocInfo wekaDB_doc_info = new WekaDBDocInfo(matching_doc_id,matching_rank,matching_segment_offset);
238
239 //expanded_query_results.add(wekaDB_doc_info);
240
241 double inc_rank_val = matching_rank / (double)(topup_count+2); // starts to a 50% (/2) weighting when topup_count == 0
242 int merged = mergeResultDoc(expanded_query_results,wekaDB_doc_info,inc_rank_val);
243
244 topup_count += merged;
245
246 if ((expanded_query_results.size() > k_nearest_num) && (topup_count > k_nearest_num)) {
247 // guard to stop multiple recurring matches in the same doc dominationg the rank_val
248 break;
249 }
250 }
251 else {
252 logger.error("Returned AV k-nearest neighbour match '"+matching_doc_id_segment+"' could not be parsed as <doc-id>-<segment>" );
253 }
254 }
255
256 Collections.sort(expanded_query_results);
257
258 query_results_ = new Vector();
259
260 int i = 0;
261 while (i < k_nearest_num) {
262 if (i >= expanded_query_results.size()) {
263 break;
264 }
265
266 query_results_.add(expanded_query_results.get(i));
267 i++;
268 }
269
270 //Collections.sort(query_results_);
271 }
272
273 public void runQueryOLD(String wekaDB_index_dir, String knn_model_file,
274 String assoc_index_dir, String query_string) {
275
276 // combine index_dir with audiodb fileanem
277
278 String full_knn_model_filename = wekaDB_index_dir + File.separatorChar + knn_model_file;
279 String full_chr12_filename = assoc_index_dir + File.separatorChar
280 + query_string + File.separatorChar + "doc.chr12";
281
282 int num_matches_within_track = 6;
283
284 // ****
285 String [] cmd_array = new String[] {
286 "java", "-jar", "weka.jar",
287 "-d", full_knn_model_filename,
288 "-Q", "nsequence",
289 "-p", String.format("%d",offset_),
290 "-n", String.format("%d",num_matches_within_track),
291 "-l", String.format("%d",length_),
292 "-r", String.format("%d",max_docs_),
293 "-f", full_chr12_filename
294 };
295
296 System.err.println("**** cmd_array = " + String.join(" ", cmd_array));
297
298 Runtime runtime = Runtime.getRuntime();
299 try {
300 Process wekaDB_proc = runtime.exec(cmd_array);
301 //int exitVal = wekaDB_proc.waitFor();
302 //System.err.println("*** exit status = " + exitVal);
303
304 InputStream wis = wekaDB_proc.getInputStream();
305 InputStreamReader wisr = new InputStreamReader(wis);
306 BufferedReader wbr = new BufferedReader(wisr);
307
308 query_results_ = new Vector();
309
310 boolean first_entry = true;
311 int line_count = 0;
312
313 String root_doc_id = null;
314 Vector<Double> rankVector = new Vector<Double>();
315 Vector<Integer> offsetVector = new Vector<Integer>();
316
317 // Example output
318 // D8 0.00105175
319 // 1.69786e-16 392 392
320 // 0.00113568 392 673
321 // 0.00127239 392 910
322 // 0.00139736 392 481
323 // 0.00145331 392 303
324 // D2 0.00429758
325 // 0.00403335 392 865
326 // 0.00411288 392 458
327 // 0.00442461 392 866
328 // 0.00444272 392 864
329 // 0.00447434 392 424
330 // ...
331
332 String line;
333 while ((line = wbr.readLine()) != null) {
334 String[] tokens = line.split("\\s+");
335 line_count++;
336
337 if (tokens.length==2) {
338 // processing a top-level doc line
339
340 if (line_count>1) {
341 // struck new top-level entry => store vector vals for previous block
342
343 first_entry = addQueryResult(first_entry,root_doc_id,rankVector,offsetVector);
344 // and now reset vectors to empty to be ready for next chain of values
345 rankVector = new Vector<Double>();
346 offsetVector = new Vector<Integer>();
347 }
348
349 root_doc_id = tokens[0];
350 }
351 else {
352 // should be 3 items
353 double euclidean_dist = Double.parseDouble(tokens[0]);
354 int src_frame = Integer.parseInt(tokens[1]);
355 int target_frame = Integer.parseInt(tokens[2]);
356
357 // ****
358
359 // enforce 1.0 as upper limit due to rounding errors
360 // in audioDB distance calculations
361 double rank = Math.min(1.0 - euclidean_dist,1.0);
362
363 if ((line_count==2) && (src_frame==target_frame)) {
364 // Found match with self
365 continue;
366 }
367
368 rankVector.add(rank);
369 offsetVector.add(target_frame);
370 }
371
372 }
373
374 addQueryResult(first_entry,root_doc_id,rankVector,offsetVector);
375
376 wbr.close();
377
378 // sort query_results_ on 'rank' field
379 // note: compareTo() method impelemented to sort into descending order
380
381 Collections.sort(query_results_);
382
383
384 }
385 catch (Exception e) {
386 logger.error("Failed to execute the following command: " + String.join(" ", cmd_array));
387 e.printStackTrace();
388 }
389
390 }
391
392
393 /** get the result out of the wrapper */
394 public Vector getQueryResult()
395 {
396 return query_results_;
397 }
398}
399
Note: See TracBrowser for help on using the repository browser.