source: gs3-extensions/mars-src/trunk/src/java/org/greenstone/gsdl3/util/WekaDBWrapper.java@ 36863

Last change on this file since 36863 was 36863, checked in by davidb, 19 months ago

Further improvements to performing AV recommendation

File size: 11.0 KB
Line 
1/*
2 * WekaDBWrapper.java
3 * Copyright (C) 2011 New Zealand Digital Library, http://www.nzdl.org
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation; either version 2 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program; if not, write to the Free Software
17 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18 */
19package org.greenstone.gsdl3.util;
20
21import java.io.*;
22import java.util.Vector;
23import java.util.Collections;
24import java.util.regex.Pattern;
25import java.util.regex.Matcher;
26
27import weka.core.Attribute;
28import weka.core.Instance;
29import weka.core.Instances;
30
31import org.apache.log4j.*;
32
33import org.greenstone.gsdl3.util.WekaFindInstanceKNN;
34
35/** Java wrapper class for access to the Weka
36 * Devised (in the first instance) to operate as: java -jar weka.jar <arg1> <arg2>
37 *
38 * Inspired by MGSearchWrapper.java
39 */
40
41public class WekaDBWrapper
42{
43 public final static double AV_SEGMENT_LENGTH_SECS = 6.0;
44
45 /** the query result, filled in by runQuery */
46 protected Vector query_result_;
47
48 protected int offset_ = 100;
49 protected int length_ = 20; // **** Unused
50
51 // Approximate matching not yet utilized
52 protected double radius_; // **** Unused
53
54 protected int max_docs_;
55
56 protected double arousal_;
57 protected double valence_;
58
59 static Logger logger = Logger.getLogger (org.greenstone.gsdl3.util.WekaDBWrapper.class.getName ());
60
61 public WekaDBWrapper() {
62 query_result_ = null;
63 }
64
65 // query param methods
66
67 /** start point (offset) into the array of feature vectors for a track
68 - 100 by default which equals 10 seconds (assuming 0.1 frame size) */
69 public void setOffset(int offset) {
70 offset_ = offset;
71 }
72
73 /** the number of consecutive frames used in match
74 - 20 by default which equals 2 seconds (assuming 0.1 frame size) */
75 public void setLength(int length) {
76 length_ = length;
77 }
78
79 /** distance used in approximate matching support - default is 50 */
80 public void setRadius(double radius) {
81 radius_ = radius;
82 }
83
84 public void setMaxDocs(int max_docs) {
85 max_docs_ = max_docs;
86 }
87
88 public void setArousal(double arousal) {
89 arousal_ = arousal;
90 }
91 public void setValence(double valence) {
92 valence_ = valence;
93 }
94
95 /** returns a string with all the current query param settings */
96 // the following was in MG version, do we need this in WekaDB version? // ****
97 //public String getQueryParams() {}
98
99
100 protected boolean addQueryResult(boolean first_entry, String doc_id,
101 Vector<Double> rankVector, Vector<Integer> offsetVector)
102 {
103
104 if (first_entry) {
105 WekaDBDocInfo wekaDB_doc_info = new WekaDBDocInfo(doc_id,rankVector,offsetVector);
106 query_result_.add(wekaDB_doc_info);
107 first_entry = false;
108 }
109 else {
110 double rank = rankVector.get(0);
111 int offset = offsetVector.get(0);
112 WekaDBDocInfo wekaDB_doc_info = new WekaDBDocInfo(doc_id,rank,offset);
113
114 query_result_.add(wekaDB_doc_info);
115 }
116
117 return first_entry;
118 }
119
120
121 /** actually carry out the query.
122 Use the set methods to set query results.
123 Writes the result to query_result.
124 * - maintains state between requests as can be slow
125 * base_dir and index_path should join together to provide
126 * the absolute location of the mg index files eg ..../index/dtx/demo
127 * base_dir must end with a file separator (OS dependant)
128 */
129
130
131 public void runQuery(String wekaDB_index_dir, String knn_model_file,
132 String assoc_index_dir, String query_string) {
133
134 // combine index_dir with audiodb fileanem
135
136 String full_knn_model_filename = wekaDB_index_dir + File.separatorChar + knn_model_file;
137
138 //String full_chr12_filename = assoc_index_dir + File.separatorChar
139 // + query_string + File.separatorChar + "doc.chr12";
140
141 System.err.println("**** full knn model filename = " + full_knn_model_filename);
142
143 // Example returned result from Weka KNN
144 // => first line is the input instance ('filename+segment',Arousal,Valence)
145 // following (indented lines) nearest neighbour matches in same format
146 //
147 // ds_22716_5743-6,-0.549489,-0.118439
148 // ds_22716_5743-6,-0.549489,-0.118439
149 // ds_31008_6550-30,-0.549489,-0.118439
150 // ds_72651_26831-6,-0.549489,-0.118439
151 // ds_26196_9214-18,-0.549489,-0.118439
152
153
154 WekaFindInstanceKNN.init(full_knn_model_filename);
155
156 String doc_id = query_string;
157 int segment = offset_;
158
159 String query_doc_id_segment = doc_id + "-" + segment;
160
161 double query_arousal_val = arousal_;
162 double query_valence_val = valence_;
163
164 int k_nearest_num = max_docs_;
165 int expanded_k_nearest_num = max_docs_ * 5; // * internally get more matches, then sift through to arrive at the best 'max_docs_'
166
167 Pattern doc_seg_re = Pattern.compile("^(\\w+)-(\\d+)$");
168 //Matcher query_doc_seg_match = doc_seq_re.matcher(query_doc_id_segment);
169
170 Instances nearest_instances
171 = WekaFindInstanceKNN.kNearestNeighbours(query_doc_id_segment,query_arousal_val,query_valence_val,k_nearest_num);
172
173
174 Vector expanded_query_result = new Vector();
175
176 int nearest_instances_len = nearest_instances.size();
177
178 int clamped_expanded_k_nearest_num = Math.min(expanded_k_nearest_num,nearest_instances_len);
179
180 double pos_penalty = 0.1;
181
182 for (int ei=0; ei<clamped_expanded_k_nearest_num; ei++) {
183 Instance instance = nearest_instances.instance(ei);
184 logger.info("\tProcessing returned instance: " + instance);
185
186 String matching_doc_id_segment = instance.stringValue(0);
187
188
189 //Pattern p = Pattern.compile("^(\\w+)-(\\d+)$");
190 Matcher m = doc_seg_re.matcher(matching_doc_id_segment);
191 if (m.matches()) {
192
193 String matching_doc_id = m.group(1);
194 int end_of_matching_segment_offset = Integer.parseInt(m.group(2));
195 int matching_segment_offset = end_of_matching_segment_offset - (int)AV_SEGMENT_LENGTH_SECS;
196
197 if (matching_doc_id.equals(doc_id)) {
198 continue;
199 }
200
201 double matching_arousal_val = instance.value(1);
202 double matching_valence_val = instance.value(2);
203
204 double matching_diff = (Math.abs(query_arousal_val - matching_arousal_val)
205 + Math.abs(query_valence_val - matching_valence_val))/4.0;
206 double matching_rank = 1.0 - matching_diff - (pos_penalty * (double)ei);
207
208 logger.info("\tAdding in: matching_doc_id = " + matching_doc_id);
209 WekaDBDocInfo wekaDB_doc_info = new WekaDBDocInfo(matching_doc_id,matching_rank,matching_segment_offset);
210 expanded_query_result.add(wekaDB_doc_info);
211 }
212 else {
213 logger.error("Returned AV k-nearest neighbour match '"+matching_doc_id_segment+"' could not be parsed as <doc-id>-<segment>" );
214 }
215 }
216
217 query_result_ = new Vector();
218
219 int i = 0;
220 while (i < k_nearest_num) {
221 if (i >= expanded_query_result.size()) {
222 break;
223 }
224
225 query_result_.add(expanded_query_result.get(i));
226 i++;
227 }
228
229 Collections.sort(query_result_);
230 }
231
232 public void runQueryOLD(String wekaDB_index_dir, String knn_model_file,
233 String assoc_index_dir, String query_string) {
234
235 // combine index_dir with audiodb fileanem
236
237 String full_knn_model_filename = wekaDB_index_dir + File.separatorChar + knn_model_file;
238 String full_chr12_filename = assoc_index_dir + File.separatorChar
239 + query_string + File.separatorChar + "doc.chr12";
240
241 int num_matches_within_track = 6;
242
243 // ****
244 String [] cmd_array = new String[] {
245 "java", "-jar", "weka.jar",
246 "-d", full_knn_model_filename,
247 "-Q", "nsequence",
248 "-p", String.format("%d",offset_),
249 "-n", String.format("%d",num_matches_within_track),
250 "-l", String.format("%d",length_),
251 "-r", String.format("%d",max_docs_),
252 "-f", full_chr12_filename
253 };
254
255 System.err.println("**** cmd_array = " + String.join(" ", cmd_array));
256
257 Runtime runtime = Runtime.getRuntime();
258 try {
259 Process wekaDB_proc = runtime.exec(cmd_array);
260 //int exitVal = wekaDB_proc.waitFor();
261 //System.err.println("*** exit status = " + exitVal);
262
263 InputStream wis = wekaDB_proc.getInputStream();
264 InputStreamReader wisr = new InputStreamReader(wis);
265 BufferedReader wbr = new BufferedReader(wisr);
266
267 query_result_ = new Vector();
268
269 boolean first_entry = true;
270 int line_count = 0;
271
272 String root_doc_id = null;
273 Vector<Double> rankVector = new Vector<Double>();
274 Vector<Integer> offsetVector = new Vector<Integer>();
275
276 // Example output
277 // D8 0.00105175
278 // 1.69786e-16 392 392
279 // 0.00113568 392 673
280 // 0.00127239 392 910
281 // 0.00139736 392 481
282 // 0.00145331 392 303
283 // D2 0.00429758
284 // 0.00403335 392 865
285 // 0.00411288 392 458
286 // 0.00442461 392 866
287 // 0.00444272 392 864
288 // 0.00447434 392 424
289 // ...
290
291 String line;
292 while ((line = wbr.readLine()) != null) {
293 String[] tokens = line.split("\\s+");
294 line_count++;
295
296 if (tokens.length==2) {
297 // processing a top-level doc line
298
299 if (line_count>1) {
300 // struck new top-level entry => store vector vals for previous block
301
302 first_entry = addQueryResult(first_entry,root_doc_id,rankVector,offsetVector);
303 // and now reset vectors to empty to be ready for next chain of values
304 rankVector = new Vector<Double>();
305 offsetVector = new Vector<Integer>();
306 }
307
308 root_doc_id = tokens[0];
309 }
310 else {
311 // should be 3 items
312 double euclidean_dist = Double.parseDouble(tokens[0]);
313 int src_frame = Integer.parseInt(tokens[1]);
314 int target_frame = Integer.parseInt(tokens[2]);
315
316 // ****
317
318 // enforce 1.0 as upper limit due to rounding errors
319 // in audioDB distance calculations
320 double rank = Math.min(1.0 - euclidean_dist,1.0);
321
322 if ((line_count==2) && (src_frame==target_frame)) {
323 // Found match with self
324 continue;
325 }
326
327 rankVector.add(rank);
328 offsetVector.add(target_frame);
329 }
330
331 }
332
333 addQueryResult(first_entry,root_doc_id,rankVector,offsetVector);
334
335 wbr.close();
336
337 // sort query_result_ on 'rank' field
338 // note: compareTo() method impelemented to sort into descending order
339
340 Collections.sort(query_result_);
341
342
343 }
344 catch (Exception e) {
345 logger.error("Failed to execute the following command: " + String.join(" ", cmd_array));
346 e.printStackTrace();
347 }
348
349 }
350
351
352 /** get the result out of the wrapper */
353 public Vector getQueryResult()
354 {
355 return query_result_;
356 }
357}
358
Note: See TracBrowser for help on using the repository browser.