1 | package org.greenstone.mars;
|
---|
2 |
|
---|
3 | //import weka.core.converters.ConverterUtils.DataSource;
|
---|
4 | import weka.core.Instances;
|
---|
5 |
|
---|
6 | //import weka.filters.Filter;
|
---|
7 | //import weka.filters.unsupervised.attribute.Remove;
|
---|
8 |
|
---|
9 | import weka.classifiers.Classifier;
|
---|
10 | import weka.classifiers.Evaluation;
|
---|
11 |
|
---|
// Based on:
//   https://waikato.github.io/weka-wiki/use_weka_in_your_java_code/

// Also, for a more direct command-line approach see
//   https://waikato.github.io/weka-wiki/saving_and_loading_models/
//
// You save a trained classifier with the -d option (dumping), e.g.:
//
//   java weka.classifiers.trees.J48 -C 0.25 -M 2 -t /some/where/train.arff -d /other/place/j48.model
//
// And you can load it with -l and use it on a test set, e.g.:
//
//   java weka.classifiers.trees.J48 -l /other/place/j48.model -T /some/where/test.arff
26 |
|
---|
27 |
|
---|
28 | class WekaTrainArousalModel
|
---|
29 | {
|
---|
30 | // Scheme: weka.classifiers.trees.REPTree -M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0 -- (i.e., default vals)
|
---|
31 | // Relation: deam-essentia-features-arousal-valence-weka.filters.unsupervised.attribute.Remove-R472-weka.filters.unsupervised.attribute.Remove-R458-466
|
---|
32 | // Instances: 1743
|
---|
33 |
|
---|
34 | // Note, above missed -R 425 and 427 (ordinal vals)
|
---|
35 |
|
---|
36 | public static void main(String[] args)
|
---|
37 | {
|
---|
38 | /*
|
---|
39 | if (args.length != 2) {
|
---|
40 | System.err.println("Error: incorrect number of command-line arguments");
|
---|
41 | System.err.println("Usage: input_training_data.arff output-model.{model|ser}");
|
---|
42 | System.exit(1);
|
---|
43 | }*/
|
---|
44 |
|
---|
45 | WekaUtil.checkUsageTraining(args);
|
---|
46 |
|
---|
47 | String input_data_filename = args[0];
|
---|
48 | String output_model_filename = args[1];
|
---|
49 |
|
---|
50 | // knock out 472 (valence) from dataset
|
---|
51 | Instances filtered_data_instances = WekaUtil.loadAndFilterDataInstances(input_data_filename, "472");
|
---|
52 | Classifier classifier = WekaUtil.trainREPTree(filtered_data_instances);
|
---|
53 | System.out.println(classifier);
|
---|
54 |
|
---|
55 | Evaluation eval = WekaUtil.evaluateClassifier(classifier,filtered_data_instances);
|
---|
56 | System.out.println(eval.toSummaryString());
|
---|
57 |
|
---|
58 | try {
|
---|
59 | System.out.println("Saving REPTree classifier model as: " + output_model_filename);
|
---|
60 | weka.core.SerializationHelper.write(output_model_filename, classifier);
|
---|
61 | }
|
---|
62 | catch (Exception e) {
|
---|
63 | e.printStackTrace();
|
---|
64 | }
|
---|
65 |
|
---|
66 |
|
---|
67 | /*
|
---|
68 | System.out.println("Training on ARFF file: " + input_arff_filename);
|
---|
69 |
|
---|
70 | try {
|
---|
71 | DataSource data_source = new DataSource(input_arff_filename);
|
---|
72 | Instances data_instances = data_source.getDataSet();
|
---|
73 |
|
---|
74 | // *********
|
---|
75 | Instances filtered_data_instances = WekaUtil.applyFilter(data_instances,"472"); // top-up with removal of 'valence'
|
---|
76 |
|
---|
77 |
|
---|
78 | // Filtering above has removed R472, so class to predict is numAttributes()-1
|
---|
79 | filtered_data_instances.setClassIndex(filtered_data_instances.numAttributes()-1);
|
---|
80 |
|
---|
81 |
|
---|
82 | // Build scheme/classifier
|
---|
83 | REPTree classifier = new REPTree(); // scheme
|
---|
84 | String[] reptree_options = weka.core.Utils.splitOptions("-M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0");
|
---|
85 | classifier.setOptions(reptree_options);
|
---|
86 |
|
---|
87 | classifier.buildClassifier(filtered_data_instances);
|
---|
88 | System.out.println(classifier);
|
---|
89 |
|
---|
90 | / *
|
---|
91 | // Evaluate
|
---|
92 | Evaluation eval = new Evaluation(filtered_data_instances);
|
---|
93 | Random rand = new Random(1);
|
---|
94 | int folds = 10;
|
---|
95 | eval.crossValidateModel(classifier, filtered_data_instances, folds, rand);
|
---|
96 | System.out.println(eval.toSummaryString());
|
---|
97 | * /
|
---|
98 |
|
---|
99 | Evaluation eval = WekaUtil.evaluateClassifier(classifier,filtered_data_instances);
|
---|
100 | System.out.println(eval.toSummaryString());
|
---|
101 |
|
---|
102 | System.out.println("Saving REPTree classifier model as: " + output_model_filename);
|
---|
103 | weka.core.SerializationHelper.write(output_model_filename, classifier);
|
---|
104 |
|
---|
105 |
|
---|
106 | }
|
---|
107 | catch (Exception e) {
|
---|
108 | e.printStackTrace();
|
---|
109 | }
|
---|
110 | */
|
---|
111 | }
|
---|
112 | }
|
---|