source: gs3-extensions/mars-src/trunk/src/java/org/greenstone/gsdl3/util/WekaDBWrapper.java@ 36859

Last change on this file since 36859 was 36859, checked in by davidb, 19 months ago

Coding developments that mean the arousal and valence values passed in as params are not used; query_result_ is capped to max_docs_

File size: 10.8 KB
Line 
1/*
2 * WekaDBWrapper.java
3 * Copyright (C) 2011 New Zealand Digital Library, http://www.nzdl.org
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation; either version 2 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program; if not, write to the Free Software
17 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18 */
19package org.greenstone.gsdl3.util;
20
21import java.io.*;
22import java.util.Vector;
23import java.util.Collections;
24import java.util.regex.Pattern;
25import java.util.regex.Matcher;
26
27import weka.core.Attribute;
28import weka.core.Instance;
29import weka.core.Instances;
30
31import org.apache.log4j.*;
32
33import org.greenstone.gsdl3.util.WekaFindInstanceKNN;
34
35/** Java wrapper class for access to the Weka
36 * Devised (in the first instance) to operate as: java -jar weka.jar <arg1> <arg2>
37 *
38 * Inspired by MGSearchWrapper.java
39 */
40
41public class WekaDBWrapper
42{
43 /** the query result, filled in by runQuery */
44 protected Vector query_result_;
45
46 protected int offset_ = 100;
47 protected int length_ = 20; // **** Unused
48
49 // Approximate matching not yet utilized
50 protected double radius_; // **** Unused
51
52 protected int max_docs_;
53
54 protected double arousal_;
55 protected double valence_;
56
57 static Logger logger = Logger.getLogger (org.greenstone.gsdl3.util.WekaDBWrapper.class.getName ());
58
59 public WekaDBWrapper() {
60 query_result_ = null;
61 }
62
63 // query param methods
64
65 /** start point (offset) into the array of feature vectors for a track
66 - 100 by default which equals 10 seconds (assuming 0.1 frame size) */
67 public void setOffset(int offset) {
68 offset_ = offset;
69 }
70
71 /** the number of consecutive frames used in match
72 - 20 by default which equals 2 seconds (assuming 0.1 frame size) */
73 public void setLength(int length) {
74 length_ = length;
75 }
76
77 /** distance used in approximate matching support - default is 50 */
78 public void setRadius(double radius) {
79 radius_ = radius;
80 }
81
82 public void setMaxDocs(int max_docs) {
83 max_docs_ = max_docs;
84 }
85
86 public void setArousal(double arousal) {
87 arousal_ = arousal;
88 }
89 public void setValence(double valence) {
90 valence_ = valence;
91 }
92
93 /** returns a string with all the current query param settings */
94 // the following was in MG version, do we need this in WekaDB version? // ****
95 //public String getQueryParams() {}
96
97
98 protected boolean addQueryResult(boolean first_entry, String doc_id,
99 Vector<Double> rankVector, Vector<Integer> offsetVector)
100 {
101
102 if (first_entry) {
103 WekaDBDocInfo wekaDB_doc_info = new WekaDBDocInfo(doc_id,rankVector,offsetVector);
104 query_result_.add(wekaDB_doc_info);
105 first_entry = false;
106 }
107 else {
108 double rank = rankVector.get(0);
109 int offset = offsetVector.get(0);
110 WekaDBDocInfo wekaDB_doc_info = new WekaDBDocInfo(doc_id,rank,offset);
111
112 query_result_.add(wekaDB_doc_info);
113 }
114
115 return first_entry;
116 }
117
118
119 /** actually carry out the query.
120 Use the set methods to set query results.
121 Writes the result to query_result.
122 * - maintains state between requests as can be slow
123 * base_dir and index_path should join together to provide
124 * the absolute location of the mg index files eg ..../index/dtx/demo
125 * base_dir must end with a file separator (OS dependant)
126 */
127
128
129 public void runQuery(String wekaDB_index_dir, String knn_model_file,
130 String assoc_index_dir, String query_string) {
131
132 // combine index_dir with audiodb fileanem
133
134 String full_knn_model_filename = wekaDB_index_dir + File.separatorChar + knn_model_file;
135
136 //String full_chr12_filename = assoc_index_dir + File.separatorChar
137 // + query_string + File.separatorChar + "doc.chr12";
138
139 System.err.println("**** full knn model filename = " + full_knn_model_filename);
140
141 // Example returned result from Weka KNN
142 // => first line is the input instance ('filename+segment',Arousal,Valence)
143 // following (indented lines) nearest neighbour matches in same format
144 //
145 // ds_22716_5743-6,-0.549489,-0.118439
146 // ds_22716_5743-6,-0.549489,-0.118439
147 // ds_31008_6550-30,-0.549489,-0.118439
148 // ds_72651_26831-6,-0.549489,-0.118439
149 // ds_26196_9214-18,-0.549489,-0.118439
150
151
152 WekaFindInstanceKNN.init(full_knn_model_filename);
153
154 String doc_id = query_string;
155 int segment = offset_;
156
157 String query_doc_id_segment = doc_id + "-" + segment;
158
159 double query_arousal_val = arousal_;
160 double query_valence_val = valence_;
161
162 int k_nearest_num = max_docs_;
163 int expanded_k_nearest_num = max_docs_ * 5; // * internally get more matches, then sift through to arrive at the best 'max_docs_'
164
165 Pattern doc_seg_re = Pattern.compile("^(\\w+)-(\\d+)$");
166 //Matcher query_doc_seg_match = doc_seq_re.matcher(query_doc_id_segment);
167
168 Instances nearest_instances
169 = WekaFindInstanceKNN.kNearestNeighbours(query_doc_id_segment,query_arousal_val,query_valence_val,k_nearest_num);
170
171
172 Vector expanded_query_result = new Vector();
173
174 int nearest_instances_len = nearest_instances.size();
175
176 int clamped_expanded_k_nearest_num = Math.min(expanded_k_nearest_num,nearest_instances_len);
177
178 double pos_penalty = 0.1;
179
180 for (int ei=0; ei<clamped_expanded_k_nearest_num; ei++) {
181 Instance instance = nearest_instances.instance(ei);
182 logger.info("\tProcessing returned instance: " + instance);
183
184 String matching_doc_id_segment = instance.stringValue(0);
185
186
187 //Pattern p = Pattern.compile("^(\\w+)-(\\d+)$");
188 Matcher m = doc_seg_re.matcher(matching_doc_id_segment);
189 if (m.matches()) {
190
191 String matching_doc_id = m.group(1);
192 int matching_segment_offset = Integer.parseInt(m.group(2));
193
194 if (matching_doc_id.equals(doc_id)) {
195 continue;
196 }
197
198 double matching_arousal_val = instance.value(1);
199 double matching_valence_val = instance.value(2);
200
201 double matching_diff = (Math.abs(query_arousal_val - matching_arousal_val)
202 + Math.abs(query_valence_val - matching_valence_val))/4.0;
203 double matching_rank = 1.0 - matching_diff - (pos_penalty * (double)ei);
204
205 logger.info("\tAdding in: matching_doc_id = " + matching_doc_id);
206 WekaDBDocInfo wekaDB_doc_info = new WekaDBDocInfo(matching_doc_id,matching_rank,matching_segment_offset);
207 expanded_query_result.add(wekaDB_doc_info);
208 }
209 else {
210 logger.error("Returned AV k-nearest neighbour match '"+matching_doc_id_segment+"' could not be parsed as <doc-id>-<segment>" );
211 }
212 }
213
214 query_result_ = new Vector();
215
216 int i = 0;
217 while (i < k_nearest_num) {
218 if (i >= expanded_query_result.size()) {
219 break;
220 }
221
222 query_result_.add(expanded_query_result.get(i));
223 i++;
224 }
225
226 Collections.sort(query_result_);
227 }
228
229 public void runQueryOLD(String wekaDB_index_dir, String knn_model_file,
230 String assoc_index_dir, String query_string) {
231
232 // combine index_dir with audiodb fileanem
233
234 String full_knn_model_filename = wekaDB_index_dir + File.separatorChar + knn_model_file;
235 String full_chr12_filename = assoc_index_dir + File.separatorChar
236 + query_string + File.separatorChar + "doc.chr12";
237
238 int num_matches_within_track = 6;
239
240 // ****
241 String [] cmd_array = new String[] {
242 "java", "-jar", "weka.jar",
243 "-d", full_knn_model_filename,
244 "-Q", "nsequence",
245 "-p", String.format("%d",offset_),
246 "-n", String.format("%d",num_matches_within_track),
247 "-l", String.format("%d",length_),
248 "-r", String.format("%d",max_docs_),
249 "-f", full_chr12_filename
250 };
251
252 System.err.println("**** cmd_array = " + String.join(" ", cmd_array));
253
254 Runtime runtime = Runtime.getRuntime();
255 try {
256 Process wekaDB_proc = runtime.exec(cmd_array);
257 //int exitVal = wekaDB_proc.waitFor();
258 //System.err.println("*** exit status = " + exitVal);
259
260 InputStream wis = wekaDB_proc.getInputStream();
261 InputStreamReader wisr = new InputStreamReader(wis);
262 BufferedReader wbr = new BufferedReader(wisr);
263
264 query_result_ = new Vector();
265
266 boolean first_entry = true;
267 int line_count = 0;
268
269 String root_doc_id = null;
270 Vector<Double> rankVector = new Vector<Double>();
271 Vector<Integer> offsetVector = new Vector<Integer>();
272
273 // Example output
274 // D8 0.00105175
275 // 1.69786e-16 392 392
276 // 0.00113568 392 673
277 // 0.00127239 392 910
278 // 0.00139736 392 481
279 // 0.00145331 392 303
280 // D2 0.00429758
281 // 0.00403335 392 865
282 // 0.00411288 392 458
283 // 0.00442461 392 866
284 // 0.00444272 392 864
285 // 0.00447434 392 424
286 // ...
287
288 String line;
289 while ((line = wbr.readLine()) != null) {
290 String[] tokens = line.split("\\s+");
291 line_count++;
292
293 if (tokens.length==2) {
294 // processing a top-level doc line
295
296 if (line_count>1) {
297 // struck new top-level entry => store vector vals for previous block
298
299 first_entry = addQueryResult(first_entry,root_doc_id,rankVector,offsetVector);
300 // and now reset vectors to empty to be ready for next chain of values
301 rankVector = new Vector<Double>();
302 offsetVector = new Vector<Integer>();
303 }
304
305 root_doc_id = tokens[0];
306 }
307 else {
308 // should be 3 items
309 double euclidean_dist = Double.parseDouble(tokens[0]);
310 int src_frame = Integer.parseInt(tokens[1]);
311 int target_frame = Integer.parseInt(tokens[2]);
312
313 // ****
314
315 // enforce 1.0 as upper limit due to rounding errors
316 // in audioDB distance calculations
317 double rank = Math.min(1.0 - euclidean_dist,1.0);
318
319 if ((line_count==2) && (src_frame==target_frame)) {
320 // Found match with self
321 continue;
322 }
323
324 rankVector.add(rank);
325 offsetVector.add(target_frame);
326 }
327
328 }
329
330 addQueryResult(first_entry,root_doc_id,rankVector,offsetVector);
331
332 wbr.close();
333
334 // sort query_result_ on 'rank' field
335 // note: compareTo() method impelemented to sort into descending order
336
337 Collections.sort(query_result_);
338
339
340 }
341 catch (Exception e) {
342 logger.error("Failed to execute the following command: " + String.join(" ", cmd_array));
343 e.printStackTrace();
344 }
345
346 }
347
348
349 /** get the result out of the wrapper */
350 public Vector getQueryResult()
351 {
352 return query_result_;
353 }
354}
355
Note: See TracBrowser for help on using the repository browser.