source: main/trunk/model-sites-dev/mars/src/java/org/greenstone/mars/WekaTrainArousalModel.java@ 34797

Last change on this file since 34797 was 34797, checked in by davidb, 3 years ago

Fixed up spelling mixup valance -> valence

File size: 3.7 KB
Line 
1package org.greenstone.mars;
2
3//import weka.core.converters.ConverterUtils.DataSource;
4import weka.core.Instances;
5
6//import weka.filters.Filter;
7//import weka.filters.unsupervised.attribute.Remove;
8
9import weka.classifiers.Classifier;
10import weka.classifiers.Evaluation;
11
12// Based on:
13// https://waikato.github.io/weka-wiki/use_weka_in_your_java_code/
14
15
16// Also, for a more direct command-line approach see
17// https://waikato.github.io/weka-wiki/saving_and_loading_models/
18//
19// You save a trained classifier with the -d option (dumping), e.g.:
20//
21// java weka.classifiers.trees.J48 -C 0.25 -M 2 -t /some/where/train.arff -d /other/place/j48.model
22//
23// And you can load it with -l and use it on a test set, e.g.:
24//
25// java weka.classifiers.trees.J48 -l /other/place/j48.model -T /some/where/test.arff
26
27
28class WekaTrainArousalModel
29{
30 // Scheme: weka.classifiers.trees.REPTree -M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0 -- (i.e., default vals)
31 // Relation: deam-essentia-features-arousal-valence-weka.filters.unsupervised.attribute.Remove-R472-weka.filters.unsupervised.attribute.Remove-R458-466
32 // Instances: 1743
33
34 // Note, above missed -R 425 and 427 (ordinal vals)
35
36 public static void main(String[] args)
37 {
38 /*
39 if (args.length != 2) {
40 System.err.println("Error: incorrect number of command-line arguments");
41 System.err.println("Usage: input_training_data.arff output-model.{model|ser}");
42 System.exit(1);
43 }*/
44
45 WekaUtil.checkUsageTraining(args);
46
47 String input_data_filename = args[0];
48 String output_model_filename = args[1];
49
50 // knock out 472 (valence) from dataset
51 Instances filtered_data_instances = WekaUtil.loadAndFilterDataInstances(input_data_filename, "472");
52 Classifier classifier = WekaUtil.trainREPTree(filtered_data_instances);
53 System.out.println(classifier);
54
55 Evaluation eval = WekaUtil.evaluateClassifier(classifier,filtered_data_instances);
56 System.out.println(eval.toSummaryString());
57
58 try {
59 System.out.println("Saving REPTree classifier model as: " + output_model_filename);
60 weka.core.SerializationHelper.write(output_model_filename, classifier);
61 }
62 catch (Exception e) {
63 e.printStackTrace();
64 }
65
66
67 /*
68 System.out.println("Training on ARFF file: " + input_arff_filename);
69
70 try {
71 DataSource data_source = new DataSource(input_arff_filename);
72 Instances data_instances = data_source.getDataSet();
73
74 // *********
75 Instances filtered_data_instances = WekaUtil.applyFilter(data_instances,"472"); // top-up with removal of 'valence'
76
77
78 // Filtering above has removed R472, so class to predict is numAttributes()-1
79 filtered_data_instances.setClassIndex(filtered_data_instances.numAttributes()-1);
80
81
82 // Build scheme/classifier
83 REPTree classifier = new REPTree(); // scheme
84 String[] reptree_options = weka.core.Utils.splitOptions("-M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0");
85 classifier.setOptions(reptree_options);
86
87 classifier.buildClassifier(filtered_data_instances);
88 System.out.println(classifier);
89
90 / *
91 // Evaluate
92 Evaluation eval = new Evaluation(filtered_data_instances);
93 Random rand = new Random(1);
94 int folds = 10;
95 eval.crossValidateModel(classifier, filtered_data_instances, folds, rand);
96 System.out.println(eval.toSummaryString());
97 * /
98
99 Evaluation eval = WekaUtil.evaluateClassifier(classifier,filtered_data_instances);
100 System.out.println(eval.toSummaryString());
101
102 System.out.println("Saving REPTree classifier model as: " + output_model_filename);
103 weka.core.SerializationHelper.write(output_model_filename, classifier);
104
105
106 }
107 catch (Exception e) {
108 e.printStackTrace();
109 }
110 */
111 }
112}
Note: See TracBrowser for help on using the repository browser.