Changeset 34774
- Timestamp: 2021-02-01T00:25:56+13:00 (3 years ago)
- Location: main/trunk/model-sites-dev/mars
- Files: 3 edited
Legend:
- Unmodified
- Added
- Removed
main/trunk/model-sites-dev/mars/RUN-APPLY-AROUSAL-MODEL.sh
r34593 r34774
   3    3   . ./_set_weka_classpath.bash
   4    4
   5      - java -cp "$cp_args" \
   6      - WekaApplyArousalModel "reptree-arousal-serialized.model"
        5 + java -cp "$cp_args" WekaApplyArousalModel \
        6 + "reptree-arousal-serialized.model" \
        7 + "collect/deam/etc/deam-essentia-features-arousal-valence.arff" \
        8 + "predicted-arousal.csv"
   7    9
   8   10
main/trunk/model-sites-dev/mars/WekaApplyArousalModel.java
r34593 r34774
   1    1   //import java.util.Random;
   2    2
        3 + import java.io.BufferedInputStream;
        4 + import java.io.FileInputStream;
        5 +
        6 + import java.io.BufferedReader;
        7 + import java.io.BufferedWriter;
        8 + import java.io.FileReader;
        9 + import java.io.FileWriter;
       10 +
   3   11   import weka.core.converters.ConverterUtils.DataSource;
       12 + import weka.core.Instance;
   4   13   import weka.core.Instances;
   5   14   import weka.core.SerializationHelper;
   …    …
  13   22
  14   23
  15      - import java.io.BufferedInputStream;
  16      - import java.io.FileInputStream;
  17   24
  18      - //import java.io.BufferedReader;
  19      - //import java.io.BufferedWriter;
  20      - //import java.io.FileReader;
  21      - //import java.io.FileWriter;
  22      - import weka.core.Instances;
       25 + //import weka.core.Instances;
  23   26
  24   27
   …    …
  30   33   public static void main(String[] args)
  31   34   {
  32      - String classifier_input_filename = args[0];
  33      - System.out.println("Loading Weka saved Classifier:" + classifier_input_filename);
       35 + if (args.length != 3) {
       36 + System.err.println("Error: incorrect number of command-line arguments");
       37 + System.err.println("Usage: trained-model.{model|ser} unclassified-data.{arff|csv} classified-data.{arff|csv}");
       38 + System.exit(1);
       39 + }
       40 +
       41 + String classifier_input_filename = args[0];
       42 + String unclassified_data_input_filename = args[1];
       43 + String classified_data_output_filename = args[2];
       44 +
       45 + System.out.println("Loading Weka saved Classifier: " + classifier_input_filename);
  34   46
  35   47   try {
   …    …
  38   50   Classifier classifier = (Classifier)SerializationHelper.read(bis);
  39   51
  40      - /*
       52 +
  41   53
  42      - // load unlabeled data
  43      - Instances unlabeled = new Instances(
  44      - new BufferedReader(
  45      - new FileReader("/some/where/unlabeled.arff")));
       54 + // load unlabeled data
       55 + System.out.println("Loading unlabeled instances: " + unclassified_data_input_filename);
       56 + FileReader fr = new FileReader(unclassified_data_input_filename);
       57 + BufferedReader br = new BufferedReader(fr);
       58 + Instances gt_instances = new Instances(br);
       59 + br.close();
       60 +
       61 + // set class attribute
       62 + gt_instances.setClassIndex(gt_instances.numAttributes() - 2);
       63 +
       64 + // create copy
       65 + Instances unlabeled_instances = new Instances(gt_instances);
       66 + Instances labeled_instances = new Instances(unlabeled_instances);
  46   67
  47      - // set class attribute
  48      - unlabeled.setClassIndex(unlabeled.numAttributes() - 1);
  49      -
  50      - // create copy
  51      - Instances labeled = new Instances(unlabeled);
  52      -
  53      - // label instances
  54      - for (int i = 0; i < unlabeled.numInstances(); i++) {
  55      - double clsLabel = tree.classifyInstance(unlabeled.instance(i));
  56      - labeled.instance(i).setClassValue(clsLabel);
  57      - }
  58      - // save labeled data
  59      - BufferedWriter writer = new BufferedWriter(
  60      - new FileWriter("/some/where/labeled.arff"));
  61      - writer.write(labeled.toString());
  62      - writer.newLine();
  63      - writer.flush();
  64      - writer.close();
  65      - */
  66      -
  67      - /*
       68 + // label instances
       69 + final int num_instances = unlabeled_instances.numInstances();
       70 + for (int i=0; i<num_instances; i++) {
       71 + Instance unlabeled_instance = unlabeled_instances.instance(i);
       72 + System.out.print("Making prediction for: " + i + "/" + num_instances);
       73 + double classified_label = classifier.classifyInstance(unlabeled_instance);
       74 + labeled_instances.instance(i).setClassValue(classified_label);
  68   75
  69      - try {
  70      - DataSource data_source = new DataSource(arff_input_filename);
  71      - Instances data_instances = data_source.getDataSet();
       76 + Instance gt_instance = gt_instances.instance(i);
       77 + double gt_class_value = gt_instance.classValue();
       78 + System.out.println(" error: " + Math.abs(classified_label-gt_class_value));
       79 + }
  72   80
  73   81
  74      - // Filter out Valence
  75      - String[] filter_options = weka.core.Utils.splitOptions("-R 472");
  76      - //String[] filter_options = weka.core.Utils.splitOptions("-R 472 -R 458-466");
  77      - Remove filter_remove = new Remove();
  78      - filter_remove.setOptions(filter_options);
  79      - filter_remove.setInputFormat(data_instances);
  80      - Instances filtered_data_instances = Filter.useFilter(data_instances, filter_remove);
  81      - filtered_data_instances.setClassIndex(filtered_data_instances.numAttributes() - 1);
       82 + // save labeled data
       83 + System.out.println("Saving labeled instances: " + classified_data_output_filename);
       84 + FileWriter fw = new FileWriter(classified_data_output_filename);
       85 + BufferedWriter bw = new BufferedWriter(fw);
  82   86
  83      -
  84      - // Build scheme/classifier
  85      - REPTree classifier = new REPTree(); // scheme
  86      - String[] reptree_options = weka.core.Utils.splitOptions("-M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0");
  87      - classifier.setOptions(reptree_options);
  88      -
  89      - classifier.buildClassifier(filtered_data_instances);
  90      - System.out.println(classifier);
  91      -
  92      -
  93      - Evaluation eval = new Evaluation(filtered_data_instances);
  94      - Random rand = new Random(1);
  95      - int folds = 10;
  96      - eval.crossValidateModel(classifier, filtered_data_instances, folds, rand);
  97      - System.out.println(eval.toSummaryString());
  98      -
  99      - weka.core.SerializationHelper.write("reptree-model.ser", classifier);
 100      - */
 101      -
 102      -
       87 + bw.write(labeled_instances.toString());
       88 + bw.newLine();
       89 + bw.flush();
       90 + bw.close();
       91 +
 103   92   }
 104   93   catch (Exception e) {
main/trunk/model-sites-dev/mars/WekaTrainArousalModel.java
r34593 r34774
  42   42   String output_model_filename = args[1];
  43   43
  44      - System.out.println("Training on ARFF file: " + input_arff_filename);
       44 + System.out.println("Training on ARFF file: " + input_arff_filename);
  45   45
  46   46   try {
   …    …
  48   48   Instances data_instances = data_source.getDataSet();
  49   49
  50      -
  51   50   // Filter out Valence
  52      - String[] filter_options = weka.core.Utils.splitOptions("-R 472");
       51 + //String[] filter_options = weka.core.Utils.splitOptions("-R 472");
  53   52   //String[] filter_options = weka.core.Utils.splitOptions("-R 472 -R 458-466");
       53 + String[] filter_options = weka.core.Utils.splitOptions("-R 472,458-466");
  54   54   Remove filter_remove = new Remove();
  55   55   filter_remove.setOptions(filter_options);
  56   56   filter_remove.setInputFormat(data_instances);
  57   57   Instances filtered_data_instances = Filter.useFilter(data_instances, filter_remove);
       58 +
       59 + // **** have removed R472, so class to predict is numAttributes()-1
  58   60   filtered_data_instances.setClassIndex(filtered_data_instances.numAttributes() - 1);
  59   61
   …    …
  74   76   System.out.println(eval.toSummaryString());
  75   77
  76      - System.out.println("Saving REPTree classifier model as: " + output_model_filename);
       78 + System.out.println("Saving REPTree classifier model as: " + output_model_filename);
  77   79   weka.core.SerializationHelper.write(output_model_filename, classifier);
  78   80
Note: See TracChangeset for help on using the changeset viewer.