1 | package org.greenstone.mars;
|
---|
2 |
|
---|
3 | //import weka.core.converters.ConverterUtils.DataSource;
|
---|
4 | import weka.core.Instances;
|
---|
5 |
|
---|
6 | //import weka.filters.Filter;
|
---|
7 | //import weka.filters.unsupervised.attribute.Remove;
|
---|
8 |
|
---|
9 | import weka.classifiers.Classifier;
|
---|
10 | import weka.classifiers.Evaluation;
|
---|
11 | //import weka.classifiers.trees.REPTree;
|
---|
12 |
|
---|
13 | // Based on:
|
---|
14 | // https://waikato.github.io/weka-wiki/use_weka_in_your_java_code/
|
---|
15 |
|
---|
16 |
|
---|
17 | class WekaTrainValanceModel
|
---|
18 | {
|
---|
19 | // Scheme: weka.classifiers.trees.REPTree -M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0 -- (i.e., default vals)
|
---|
20 | // Relation: deam-essentia-features-arousal-valence-weka.filters.unsupervised.attribute.Remove-R471-weka.filters.unsupervised.attribute.Remove-R458-466-weka.filters.unsupervised.attribute.Remove-R425
|
---|
21 | // Instances: 1743
|
---|
22 |
|
---|
23 | // Note, above missed -R 427 (ordinal val)
|
---|
24 |
|
---|
25 |
|
---|
26 | public static void main(String[] args)
|
---|
27 | {
|
---|
28 | /*
|
---|
29 | if (args.length != 2) {
|
---|
30 | System.err.println("Error: incorrect number of command-line arguments");
|
---|
31 | System.err.println("Usage: input_training_data.arff output-model.{model|ser}");
|
---|
32 | System.exit(1);
|
---|
33 | }*/
|
---|
34 |
|
---|
35 | WekaUtil.checkUsageTraining(args);
|
---|
36 |
|
---|
37 | String input_data_filename = args[0];
|
---|
38 | String output_model_filename = args[1];
|
---|
39 |
|
---|
40 | // knock out 471 (arousal) from dataset
|
---|
41 | Instances filtered_data_instances = WekaUtil.loadAndFilterDataInstances(input_data_filename, "471");
|
---|
42 | Classifier classifier = WekaUtil.trainREPTree(filtered_data_instances);
|
---|
43 | System.out.println(classifier);
|
---|
44 |
|
---|
45 | Evaluation eval = WekaUtil.evaluateClassifier(classifier,filtered_data_instances);
|
---|
46 | System.out.println(eval.toSummaryString());
|
---|
47 |
|
---|
48 | try {
|
---|
49 | System.out.println("Saving REPTree classifier model as: " + output_model_filename);
|
---|
50 | weka.core.SerializationHelper.write(output_model_filename, classifier);
|
---|
51 | }
|
---|
52 | catch (Exception e) {
|
---|
53 | e.printStackTrace();
|
---|
54 | }
|
---|
55 |
|
---|
56 | /*
|
---|
57 | System.out.println("Training on ARFF file: " + input_arff_filename);
|
---|
58 |
|
---|
59 | try {
|
---|
60 | DataSource data_source = new DataSource(input_arff_filename);
|
---|
61 | Instances data_instances = data_source.getDataSet();
|
---|
62 |
|
---|
63 | // *********
|
---|
64 | Instances filtered_data_instances = WekaUtil.applyFilter(data_instances,"471"); // top-up with removal of 'arousal'
|
---|
65 |
|
---|
66 | // ****** comment varies
|
---|
67 | // With removal of R471, last column is the one we want to predict => numAttributes()-1
|
---|
68 | filtered_data_instances.setClassIndex(filtered_data_instances.numAttributes()-1);
|
---|
69 |
|
---|
70 |
|
---|
71 | // Build scheme/classifier
|
---|
72 | REPTree classifier = new REPTree(); // scheme
|
---|
73 | String[] reptree_options = weka.core.Utils.splitOptions("-M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0");
|
---|
74 | classifier.setOptions(reptree_options);
|
---|
75 |
|
---|
76 | classifier.buildClassifier(filtered_data_instances);
|
---|
77 | System.out.println(classifier);
|
---|
78 |
|
---|
79 |
|
---|
80 | Evaluation eval = WekaUtil.evaluateClassifier(classifier,filtered_data_instances);
|
---|
81 | System.out.println(eval.toSummaryString());
|
---|
82 |
|
---|
83 | System.out.println("Saving REPTree classifier model as: " + output_model_filename);
|
---|
84 | weka.core.SerializationHelper.write(output_model_filename, classifier);
|
---|
85 |
|
---|
86 |
|
---|
87 | }
|
---|
88 | catch (Exception e) {
|
---|
89 | e.printStackTrace();
|
---|
90 | }
|
---|
91 |
|
---|
92 | */
|
---|
93 |
|
---|
94 | }
|
---|
95 | }
|
---|