source: main/trunk/model-sites-dev/mars/src/java/WekaTrainArousalModel.java@ 34780

Last change on this file since 34780 was 34780, checked in by davidb, 3 years ago

New location for Java source

File size: 3.1 KB
Line 
1import java.util.Random;
2
3import weka.core.converters.ConverterUtils.DataSource;
4import weka.core.Instances;
5
6import weka.filters.Filter;
7import weka.filters.unsupervised.attribute.Remove;
8
9import weka.classifiers.Evaluation;
10import weka.classifiers.trees.REPTree;
11
12// Based on:
13// https://waikato.github.io/weka-wiki/use_weka_in_your_java_code/
14
15
16// Also, for a more direct command-line approach see
17// https://waikato.github.io/weka-wiki/saving_and_loading_models/
18//
19// You save a trained classifier with the -d option (dumping), e.g.:
20//
21// java weka.classifiers.trees.J48 -C 0.25 -M 2 -t /some/where/train.arff -d /other/place/j48.model
22//
23// And you can load it with -l and use it on a test set, e.g.:
24//
25// java weka.classifiers.trees.J48 -l /other/place/j48.model -T /some/where/test.arff
26
27
28class WekaTrainArousalModel
29{
30 // Scheme: weka.classifiers.trees.REPTree -M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0
31 // Relation: deam-essentia-features-arousal-valence-weka.filters.unsupervised.attribute.Remove-R472-weka.filters.unsupervised.attribute.Remove-R458-466
32 // Instances: 1743
33
34 public static void main(String[] args)
35 {
36 if (args.length != 2) {
37 System.err.println("Error: incorrect number of command-line arguments");
38 System.err.println("Usage: input_training_data.arff output-model.{model|ser}");
39 System.exit(1);
40 }
41 String input_arff_filename = args[0];
42 String output_model_filename = args[1];
43
44 System.out.println("Training on ARFF file: " + input_arff_filename);
45
46 try {
47 DataSource data_source = new DataSource(input_arff_filename);
48 Instances data_instances = data_source.getDataSet();
49
50 // Filter out Valence
51 //String[] filter_options = weka.core.Utils.splitOptions("-R 472");
52 //String[] filter_options = weka.core.Utils.splitOptions("-R 472 -R 458-466");
53 String[] filter_options = weka.core.Utils.splitOptions("-R 425,458-466,472"); // remove ordinal attributes and 'valance'
54 Remove filter_remove = new Remove();
55 filter_remove.setOptions(filter_options);
56 filter_remove.setInputFormat(data_instances);
57 Instances filtered_data_instances = Filter.useFilter(data_instances, filter_remove);
58
59 // **** have removed R472, so class to predict is numAttributes()-1
60 filtered_data_instances.setClassIndex(filtered_data_instances.numAttributes() - 1);
61
62
63 // Build scheme/classifier
64 REPTree classifier = new REPTree(); // scheme
65 String[] reptree_options = weka.core.Utils.splitOptions("-M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0");
66 classifier.setOptions(reptree_options);
67
68 classifier.buildClassifier(filtered_data_instances);
69 System.out.println(classifier);
70
71
72 Evaluation eval = new Evaluation(filtered_data_instances);
73 Random rand = new Random(1);
74 int folds = 10;
75 eval.crossValidateModel(classifier, filtered_data_instances, folds, rand);
76 System.out.println(eval.toSummaryString());
77
78 System.out.println("Saving REPTree classifier model as: " + output_model_filename);
79 weka.core.SerializationHelper.write(output_model_filename, classifier);
80
81
82 }
83 catch (Exception e) {
84 e.printStackTrace();
85 }
86
87 }
88}
Note: See TracBrowser for help on using the repository browser.