Changeset 34778


Ignore:
Timestamp:
2021-02-01T17:03:58+13:00 (3 years ago)
Author:
davidb
Message:

Next logical commit point after testing on CSV files in amc-essentia/import

Location:
main/trunk/model-sites-dev/mars
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • main/trunk/model-sites-dev/mars/RUN-APPLY-AROUSAL-MODEL.sh

    r34774 r34778  
    11#!/bin/bash
     2
     3input_data_filename=${1:-collect/deam/etc/deam-essentia-features-arousal-valence.arff}
     4output_data_filename=${2:-predicted-arousal.csv}
    25
    36. ./_set_weka_classpath.bash
     
    58java -cp "$cp_args" WekaApplyArousalModel \
    69     "reptree-arousal-serialized.model" \
    7      "collect/deam/etc/deam-essentia-features-arousal-valence.arff" \
    8      "predicted-arousal.csv"
     10     "$input_data_filename" \
     11     "$output_data_filename"
     12
     13echo "****"
     14echo "* Saved output data: $output_data_filename"
     15echo "****"
     16
    917
    1018
  • main/trunk/model-sites-dev/mars/WekaApplyArousalModel.java

    r34774 r34778  
    1010
    1111import weka.core.converters.ConverterUtils.DataSource;
     12import weka.core.Attribute;
    1213import weka.core.Instance;
    1314import weka.core.Instances;
    1415import weka.core.SerializationHelper;
    1516
    16 //import weka.filters.Filter;
    17 //import weka.filters.unsupervised.attribute.Remove;
    18 
    19 //import weka.classifiers.Evaluation;
     17import weka.filters.Filter;
     18import weka.filters.unsupervised.attribute.Remove;
     19
    2020import weka.classifiers.Classifier;
    21 //import weka.classifiers.trees.REPTree;
    22 
    23 
    24 
    25 //import weka.core.Instances;
    2621
    2722
     
    3126class WekaApplyArousalModel
    3227{
     28    public final static String AROUSAL_ATTRIBUTE_NAME = "arousal_sample_26500ms";
     29    public final static String VALANCE_ATTRIBUTE_NAME = "valance_sample_26500ms";
     30
     31    public final static String PREDICT_ATTRIBUTE_NAME = "arousal_sample_26500ms";
     32   
     33
     34    public static Instances applyFilter(Instances data_instances,String additional_remove)
     35    {
     36    String remove_option_args = ("-R 425,458-466"); // remove ordinal attributes
     37    if (additional_remove != null) {
     38        // top up, e.g. '472' representing valance in ground-truth files
     39        remove_option_args += "," + additional_remove;
     40    }
     41
     42
     43    Instances filtered_data_instances = null;
     44    try {
     45        String[] filter_options = weka.core.Utils.splitOptions(remove_option_args); // remove ordinal attributes and 'valance'
     46        Remove filter_remove = new Remove();
     47        filter_remove.setOptions(filter_options);
     48        filter_remove.setInputFormat(data_instances);
     49        filtered_data_instances = Filter.useFilter(data_instances, filter_remove);
     50    }
     51    catch (Exception e) {
     52        e.printStackTrace();
     53        System.exit(1);
     54    }
     55   
     56    return filtered_data_instances;
     57    }
     58
     59
     60   
     61    public static void checkDatasetInstancesCompatible(Instances new_instances)
     62    {
     63
     64    final String gt_datasource_filename = "collect/deam/etc/deam-essentia-features-arousal-valence.arff";
     65
     66    try {
     67        DataSource gt_data_source = new DataSource(gt_datasource_filename);
     68        Instances gt_instances = gt_data_source.getDataSet();
     69        gt_instances = applyFilter(gt_instances,"472"); // top up with removing 'valance'
     70       
     71        gt_instances.setClassIndex(gt_instances.numAttributes() - 1);
     72
     73        String equal_header_message = gt_instances.equalHeadersMsg(new_instances);
     74       
     75        if (equal_header_message != null) {
     76        System.err.println("Headers to datasets were not equal!");
     77        System.err.println(equal_header_message);
     78        System.exit(1);
     79        }
     80    }
     81    catch (Exception e) {
     82        e.printStackTrace();
     83        System.exit(1);
     84    }
     85
     86    }
     87   
    3388    public static void main(String[] args)
    3489    {
     
    49104        BufferedInputStream bis= new BufferedInputStream(fis);
    50105        Classifier classifier = (Classifier)SerializationHelper.read(bis);
    51    
    52    
    53 
    54         // load unlabeled data
     106
     107        //
     108        // Load in unlabeled data
     109        //
     110       
    55111        System.out.println("Loading unlabeled instances: " + unclassified_data_input_filename);
    56         FileReader fr = new FileReader(unclassified_data_input_filename);
    57         BufferedReader br = new BufferedReader(fr);
    58         Instances gt_instances = new Instances(br);
    59         br.close();
    60 
    61         // set class attribute
    62         gt_instances.setClassIndex(gt_instances.numAttributes() - 2);
    63        
    64         // create copy
    65         Instances unlabeled_instances = new Instances(gt_instances);   
     112
     113        //FileReader fr = new FileReader(unclassified_data_input_filename);
     114        //BufferedReader br = new BufferedReader(fr);
     115        //Instances unlabeled_instances = new Instances(br);
     116        //br.close();
     117
     118        DataSource data_source = new DataSource(unclassified_data_input_filename);
     119        Instances unlabeled_instances = data_source.getDataSet();
     120
     121        // Work out if we're dealing with a ground-truth ARFF file or not
     122        // (i.e. already has the desired attribute)
     123
     124        Instances groundtruth_instances = null;     
     125        Attribute predict_attribute = unlabeled_instances.attribute(PREDICT_ATTRIBUTE_NAME);
     126       
     127        if (predict_attribute == null) {
     128
     129        unlabeled_instances = applyFilter(unlabeled_instances,null); // no additional top-up to remove     
     130        int num_attributes = unlabeled_instances.numAttributes();
     131
     132        Attribute arousal_attribute = new Attribute(AROUSAL_ATTRIBUTE_NAME);
     133        unlabeled_instances.insertAttributeAt(arousal_attribute,num_attributes);
     134        num_attributes++;
     135       
     136        //Attribute valance_attribute = new Attribute(VALANCE_ATTRIBUTE_NAME);
     137        //unlabeled_instances.insertAttributeAt(valance_attribute,num_attributes);
     138        //num_attributes++;
     139
     140        //predict_attribute = new Attribute(PREDICT_ATTRIBUTE_NAME);
     141        //unlabeled_instances.insertAttributeAt(predict_attribute,num_attributes);
     142        //unlabeled_instances.setClassIndex(num_attributes);
     143        //num_attributes++;
     144
     145        //predict_attribute = unlabeled_instances.attribute(PREDICT_ATTRIBUTE_NAME);
     146        //unlabeled_instances.setClass(predict_attribute);
     147
     148        //unlabeled_instances.setClassIndex(num_attributes - 1);
     149        }
     150        else {
     151        // Dealing with ground-truth data:
     152        //   => already has 'arousal' attribute
     153        //   => in fact has 'valance' attribute too, which we want to remove
     154       
     155        unlabeled_instances = applyFilter(unlabeled_instances,"472"); // top-up with removal of 'valance'
     156        //unlabeled_instances.setClass(predict_attribute);
     157       
     158        // reference share this as 'groundtruth_instances' to trigger error calculation and output
     159        groundtruth_instances = unlabeled_instances;
     160        }
     161
     162        int num_attributes = unlabeled_instances.numAttributes();
     163        unlabeled_instances.setClassIndex(num_attributes - 1);
     164       
     165        checkDatasetInstancesCompatible(unlabeled_instances);
     166       
     167        // Set class attribute
     168        //unlabeled_instances.setClassIndex(gt_instances.numAttributes() - 2);
     169
     170
     171        // Create copy where the predictions are to be made
    66172        Instances labeled_instances = new Instances(unlabeled_instances);
    67    
    68         // label instances
     173
     174        //
     175        // Label instances
     176        //
     177       
    69178        final int num_instances = unlabeled_instances.numInstances();       
    70179        for (int i=0; i<num_instances; i++) {
    71180        Instance unlabeled_instance = unlabeled_instances.instance(i);
     181
    72182        System.out.print("Making prediction for: " + i + "/" + num_instances);
    73         double classified_label = classifier.classifyInstance(unlabeled_instance);
    74         labeled_instances.instance(i).setClassValue(classified_label);
    75 
    76         Instance gt_instance = gt_instances.instance(i);
    77         double gt_class_value = gt_instance.classValue();
    78         System.out.println("  error: " + Math.abs(classified_label-gt_class_value));
    79         }
    80 
    81 
    82         // save labeled data           
     183        double classified_value = classifier.classifyInstance(unlabeled_instance);
     184        labeled_instances.instance(i).setClassValue(classified_value);
     185
     186        String formatted_classified_value = String.format("% 06.3f", classified_value);
     187       
     188        System.out.print(" value = " + formatted_classified_value);
     189       
     190        if (groundtruth_instances != null) {
     191            Instance gt_instance = groundtruth_instances.instance(i);
     192            double gt_class_value = gt_instance.classValue();
     193            double error = Math.abs(classified_value - gt_class_value);
     194
     195            String formatted_error = String.format("%.3f", error);
     196            System.out.print("  [error: " + formatted_error + "]");
     197        }
     198        System.out.println();
     199        }
     200
     201        //
     202        // Save labeled data
     203        //
     204       
    83205        System.out.println("Saving labeled instances: " + classified_data_output_filename);     
    84206        FileWriter fw = new FileWriter(classified_data_output_filename);
  • main/trunk/model-sites-dev/mars/WekaTrainArousalModel.java

    r34774 r34778  
    3636    if (args.length != 2) {
    3737        System.err.println("Error: incorrect number of command-line arguments");
    38         System.err.println("Usage: input_training_data.arff output-model.ser");
     38        System.err.println("Usage: input_training_data.arff output-model.{model|ser}");
    3939        System.exit(1);
    4040    }
     
    5151        //String[] filter_options = weka.core.Utils.splitOptions("-R 472");
    5252        //String[] filter_options = weka.core.Utils.splitOptions("-R 472 -R 458-466");
    53         String[] filter_options = weka.core.Utils.splitOptions("-R 472,458-466");
     53        String[] filter_options = weka.core.Utils.splitOptions("-R 425,458-466,472"); // remove ordinal attributes and 'valance'
    5454        Remove filter_remove = new Remove();
    5555        filter_remove.setOptions(filter_options);
Note: See TracChangeset for help on using the changeset viewer.