Changeset 34774


Ignore:
Timestamp:
2021-02-01T00:25:56+13:00 (3 years ago)
Author:
davidb
Message:

Development of code, and the resulting test runs, to train and apply the arousal model

Location:
main/trunk/model-sites-dev/mars
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • main/trunk/model-sites-dev/mars/RUN-APPLY-AROUSAL-MODEL.sh

    r34593 r34774  
    33. ./_set_weka_classpath.bash
    44
    5 java -cp "$cp_args" \
    6      WekaApplyArousalModel "reptree-arousal-serialized.model"
     5java -cp "$cp_args" WekaApplyArousalModel \
     6     "reptree-arousal-serialized.model" \
     7     "collect/deam/etc/deam-essentia-features-arousal-valence.arff" \
     8     "predicted-arousal.csv"
    79
    810
  • main/trunk/model-sites-dev/mars/WekaApplyArousalModel.java

    r34593 r34774  
    11//import java.util.Random;
    22
     3import java.io.BufferedInputStream;
     4import java.io.FileInputStream;
     5
     6import java.io.BufferedReader;
     7import java.io.BufferedWriter;
     8import java.io.FileReader;
     9import java.io.FileWriter;
     10
    311import weka.core.converters.ConverterUtils.DataSource;
     12import weka.core.Instance;
    413import weka.core.Instances;
    514import weka.core.SerializationHelper;
     
    1322
    1423
    15 import java.io.BufferedInputStream;
    16 import java.io.FileInputStream;
    1724
    18 //import java.io.BufferedReader;
    19 //import java.io.BufferedWriter;
    20 //import java.io.FileReader;
    21 //import java.io.FileWriter;
    22 import weka.core.Instances;
     25//import weka.core.Instances;
    2326
    2427
     
    3033    public static void main(String[] args)
    3134    {
    32     String classifier_input_filename = args[0];
    33         System.out.println("Loading Weka saved Classifier:" + classifier_input_filename);
     35    if (args.length != 3) {
     36        System.err.println("Error: incorrect number of command-line arguments");
     37        System.err.println("Usage: trained-model.{model|ser} unclassified-data.{arff|csv} classified-data.{arff|csv}");
     38        System.exit(1);
     39    }
     40       
     41    String classifier_input_filename        = args[0];
     42    String unclassified_data_input_filename = args[1];
     43    String classified_data_output_filename  = args[2];
     44
     45        System.out.println("Loading Weka saved Classifier: " + classifier_input_filename);
    3446
    3547    try {
     
    3850        Classifier classifier = (Classifier)SerializationHelper.read(bis);
    3951   
    40     /*
     52   
    4153
    42     // load unlabeled data
    43     Instances unlabeled = new Instances(
    44                         new BufferedReader(
    45                                    new FileReader("/some/where/unlabeled.arff")));
     54        // load unlabeled data
     55        System.out.println("Loading unlabeled instances: " + unclassified_data_input_filename);
     56        FileReader fr = new FileReader(unclassified_data_input_filename);
     57        BufferedReader br = new BufferedReader(fr);
     58        Instances gt_instances = new Instances(br);
     59        br.close();
     60
     61        // set class attribute
     62        gt_instances.setClassIndex(gt_instances.numAttributes() - 2);
     63       
     64        // create copy
     65        Instances unlabeled_instances = new Instances(gt_instances);   
     66        Instances labeled_instances = new Instances(unlabeled_instances);
    4667   
    47     // set class attribute
    48     unlabeled.setClassIndex(unlabeled.numAttributes() - 1);
    49    
    50     // create copy
    51     Instances labeled = new Instances(unlabeled);
    52    
    53     // label instances
    54     for (int i = 0; i < unlabeled.numInstances(); i++) {
    55         double clsLabel = tree.classifyInstance(unlabeled.instance(i));
    56         labeled.instance(i).setClassValue(clsLabel);
    57     }
    58     // save labeled data
    59     BufferedWriter writer = new BufferedWriter(
    60                            new FileWriter("/some/where/labeled.arff"));
    61     writer.write(labeled.toString());
    62     writer.newLine();
    63     writer.flush();
    64     writer.close();
    65     */
    66    
    67     /*
     68        // label instances
     69        final int num_instances = unlabeled_instances.numInstances();       
     70        for (int i=0; i<num_instances; i++) {
     71        Instance unlabeled_instance = unlabeled_instances.instance(i);
     72        System.out.print("Making prediction for: " + i + "/" + num_instances);
     73        double classified_label = classifier.classifyInstance(unlabeled_instance);
     74        labeled_instances.instance(i).setClassValue(classified_label);
    6875
    69     try {
    70         DataSource data_source = new DataSource(arff_input_filename);
    71         Instances data_instances = data_source.getDataSet();
     76        Instance gt_instance = gt_instances.instance(i);
     77        double gt_class_value = gt_instance.classValue();
     78        System.out.println("  error: " + Math.abs(classified_label-gt_class_value));
     79        }
    7280
    7381
    74         // Filter out Valence
    75         String[] filter_options = weka.core.Utils.splitOptions("-R 472");
    76         //String[] filter_options = weka.core.Utils.splitOptions("-R 472 -R 458-466");
    77         Remove filter_remove = new Remove();
    78         filter_remove.setOptions(filter_options);
    79         filter_remove.setInputFormat(data_instances);
    80         Instances filtered_data_instances = Filter.useFilter(data_instances, filter_remove);
    81         filtered_data_instances.setClassIndex(filtered_data_instances.numAttributes() - 1);
     82        // save labeled data           
     83        System.out.println("Saving labeled instances: " + classified_data_output_filename);     
     84        FileWriter fw = new FileWriter(classified_data_output_filename);
     85        BufferedWriter bw = new BufferedWriter(fw);
    8286       
    83 
    84         // Build scheme/classifier
    85         REPTree classifier = new REPTree(); // scheme
    86         String[] reptree_options = weka.core.Utils.splitOptions("-M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0");
    87         classifier.setOptions(reptree_options);
    88 
    89         classifier.buildClassifier(filtered_data_instances);
    90         System.out.println(classifier);
    91 
    92 
    93         Evaluation eval = new Evaluation(filtered_data_instances);
    94         Random rand = new Random(1);
    95         int folds = 10;
    96         eval.crossValidateModel(classifier, filtered_data_instances, folds, rand);
    97         System.out.println(eval.toSummaryString());
    98 
    99         weka.core.SerializationHelper.write("reptree-model.ser", classifier);
    100     */
    101        
    102          
     87        bw.write(labeled_instances.toString());
     88        bw.newLine();
     89        bw.flush();
     90        bw.close();
     91             
    10392    }
    10493    catch (Exception e) {
  • main/trunk/model-sites-dev/mars/WekaTrainArousalModel.java

    r34593 r34774  
    4242    String output_model_filename = args[1];
    4343   
    44         System.out.println("Training on ARFF file:" + input_arff_filename);
     44        System.out.println("Training on ARFF file: " + input_arff_filename);
    4545
    4646    try {
     
    4848        Instances data_instances = data_source.getDataSet();
    4949
    50 
    5150        // Filter out Valence
    52         String[] filter_options = weka.core.Utils.splitOptions("-R 472");
     51        //String[] filter_options = weka.core.Utils.splitOptions("-R 472");
    5352        //String[] filter_options = weka.core.Utils.splitOptions("-R 472 -R 458-466");
     53        String[] filter_options = weka.core.Utils.splitOptions("-R 472,458-466");
    5454        Remove filter_remove = new Remove();
    5555        filter_remove.setOptions(filter_options);
    5656        filter_remove.setInputFormat(data_instances);
    5757        Instances filtered_data_instances = Filter.useFilter(data_instances, filter_remove);
     58
     59        // **** have removed R472, so class to predict is numAttributes()-1
    5860        filtered_data_instances.setClassIndex(filtered_data_instances.numAttributes() - 1);
    5961       
     
    7476        System.out.println(eval.toSummaryString());
    7577
    76         System.out.println("Saving REPTree classifier model as:" + output_model_filename);
     78        System.out.println("Saving REPTree classifier model as: " + output_model_filename);
    7779        weka.core.SerializationHelper.write(output_model_filename, classifier);
    7880
Note: See TracChangeset for help on using the changeset viewer.