Ignore:
Timestamp:
2021-02-02T00:46:14+13:00 (3 years ago)
Author:
davidb
Message:

Code refactored, and then valence version of training and applying model developed

Location:
main/trunk/model-sites-dev/mars/src/java/org/greenstone/mars
Files:
2 added
3 edited

Legend:

Unmodified
Added
Removed
  • main/trunk/model-sites-dev/mars/src/java/org/greenstone/mars/WekaApplyArousalModel.java

    r34786 r34788  
    2828class WekaApplyArousalModel
    2929{
    30     //public final static String AROUSAL_ATTRIBUTE_NAME = "arousal_sample_26500ms";
    31     //public final static String VALANCE_ATTRIBUTE_NAME = "valance_sample_26500ms";
    3230
    33     public final static String PREDICT_ATTRIBUTE_NAME = WekaUtil.AROUSAL_ATTRIBUTE_NAME;
     31    //public final static String PREDICT_ATTRIBUTE_NAME = WekaUtil.AROUSAL_ATTRIBUTE_NAME;
    3432   
    3533
    36     /*
    37     public static Instances applyFilter(Instances data_instances,String additional_remove)
    38     {
    39     String remove_option_args = ("-R 425,458-466"); // remove ordinal attributes
    40     if (additional_remove != null) {
    41         // top up, e.g. '472' representing valance in ground-truth files
    42         remove_option_args += "," + additional_remove;
    43     }
    44 
    45 
    46     Instances filtered_data_instances = null;
    47     try {
    48         String[] filter_options = weka.core.Utils.splitOptions(remove_option_args); // remove ordinal attributes and 'valance'
    49         Remove filter_remove = new Remove();
    50         filter_remove.setOptions(filter_options);
    51         filter_remove.setInputFormat(data_instances);
    52         filtered_data_instances = Filter.useFilter(data_instances, filter_remove);
    53     }
    54     catch (Exception e) {
    55         e.printStackTrace();
    56         System.exit(1);
    57     }
    58    
    59     return filtered_data_instances;
    60     }
    61     */
    62    
    63     /*
    64    
    65     public static void checkDatasetInstancesCompatible(Instances new_instances)
    66     {
    67 
    68     final String gt_datasource_filename = "collect/deam/etc/deam-essentia-features-arousal-valence.arff";
    69 
    70     try {
    71         DataSource gt_data_source = new DataSource(gt_datasource_filename);
    72         Instances gt_instances = gt_data_source.getDataSet();
    73         gt_instances = applyFilter(gt_instances,"472"); // top up with removing 'valance'
    74        
    75         gt_instances.setClassIndex(gt_instances.numAttributes() - 1);
    76 
    77         String equal_header_message = gt_instances.equalHeadersMsg(new_instances);
    78        
    79         if (equal_header_message != null) {
    80         System.err.println("Headers to datasets were not equal!");
    81         System.err.println(equal_header_message);
    82         System.exit(1);
    83         }
    84     }
    85     catch (Exception e) {
    86         e.printStackTrace();
    87         System.exit(1);
    88     }
    89 
    90     }
    91     */
    9234   
    9335    public static void main(String[] args)
    9436    {
     37    /*
    9538    if (args.length != 3) {
    9639        System.err.println("Error: incorrect number of command-line arguments");
     
    9841        System.exit(1);
    9942    }
    100        
     43    */
     44
     45    WekaUtil.checkUsageApplyModel(args);
     46   
    10147    String classifier_input_filename        = args[0];
    10248    String unclassified_data_input_filename = args[1];
    10349    String classified_data_output_filename  = args[2];
    10450
     51
     52    Classifier classifier = WekaUtil.loadClassifierModel(classifier_input_filename);
     53
     54    /*
    10555        System.out.println("Loading Weka saved Classifier: " + classifier_input_filename);
    10656
     
    152102        WekaUtil.checkDatasetInstancesCompatible(unlabeled_instances);
    153103
     104
     105
     106
     107       
    154108        // Create copy where the predictions are to be made
    155109        Instances labeled_instances = new Instances(unlabeled_instances);
     
    199153        e.printStackTrace();
    200154    }
    201        
     155
     156    */
     157
     158
     159    Instances unlabeled_instances= WekaUtil.loadInstancesForClassification(unclassified_data_input_filename);
     160
     161    // It is permissible to run this code and supply it with a data file that includes groundtruth in it.
     162    // In this situation, the 'unlabeled' instances:
     163    //  (i) need to be massaged to be in the same form as truly unlabeled data
     164    // (ii) we also set up 'groundtruth_instances' as an alias (reference) to 'filtered_unlabeled_instances'
     165    //      to trigger calculating the error on the predicted values
     166
     167    boolean has_groundtruth_data = WekaUtil.instancesHavePredictAttribute(unlabeled_instances,WekaUtil.AROUSAL_ATTRIBUTE_NAME);
     168
     169    // The following deals with (i) internally, ensuring that what is returned is suitable for making predictions on
     170    Instances filtered_unlabeled_instances
     171        = WekaUtil.filterInstancesForApplying(unlabeled_instances,has_groundtruth_data,
     172                          WekaUtil.AROUSAL_ATTRIBUTE_NAME,"472");
     173
     174    // The following deals with (ii)
     175    Instances groundtruth_instances = (has_groundtruth_data) ? filtered_unlabeled_instances : null;
     176   
     177    Instances labeled_instances = WekaUtil.makePredictions(classifier, filtered_unlabeled_instances, groundtruth_instances);
     178       
     179    try {
     180        // Save labeled data
     181       
     182        System.out.println("Saving labeled instances: " + classified_data_output_filename);     
     183        FileWriter fw = new FileWriter(classified_data_output_filename);
     184        BufferedWriter bw = new BufferedWriter(fw);
     185       
     186        bw.write(labeled_instances.toString());
     187        bw.newLine();
     188        bw.flush();
     189        bw.close();
     190             
     191    }
     192    catch (Exception e) {
     193        e.printStackTrace();
     194    }
     195       
    202196    }
    203197}
  • main/trunk/model-sites-dev/mars/src/java/org/greenstone/mars/WekaTrainArousalModel.java

    r34786 r34788  
    11package org.greenstone.mars;
    22   
    3 import java.util.Random;
    4 
    5 import weka.core.converters.ConverterUtils.DataSource;
     3//import weka.core.converters.ConverterUtils.DataSource;
    64import weka.core.Instances;
    75 
    8 import weka.filters.Filter;
    9 import weka.filters.unsupervised.attribute.Remove;
     6//import weka.filters.Filter;
     7//import weka.filters.unsupervised.attribute.Remove;
    108
     9import weka.classifiers.Classifier;
    1110import weka.classifiers.Evaluation;
    12 import weka.classifiers.trees.REPTree;
    1311
    1412// Based on:
     
    3028class WekaTrainArousalModel
    3129{
    32     // Scheme:       weka.classifiers.trees.REPTree -M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0
     30    // Scheme:       weka.classifiers.trees.REPTree -M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0   -- (i.e., default vals)
    3331    // Relation:     deam-essentia-features-arousal-valence-weka.filters.unsupervised.attribute.Remove-R472-weka.filters.unsupervised.attribute.Remove-R458-466
    3432    // Instances:    1743
     33
     34    // Note, above missed -R 425 and 427 (ordinal vals)
    3535   
    3636    public static void main(String[] args)
    3737    {
     38    /*
    3839    if (args.length != 2) {
    3940        System.err.println("Error: incorrect number of command-line arguments");
    4041        System.err.println("Usage: input_training_data.arff output-model.{model|ser}");
    4142        System.exit(1);
     43        }*/
     44   
     45    WekaUtil.checkUsageTraining(args);
     46   
     47    String input_data_filename = args[0];
     48    String output_model_filename = args[1];
     49
     50    // knock out 472 (valance) from dataset
     51    Instances filtered_data_instances = WekaUtil.loadAndFilterDataInstances(input_data_filename, "472");
     52    Classifier classifier = WekaUtil.trainREPTree(filtered_data_instances);
     53    System.out.println(classifier);
     54
     55    Evaluation eval = WekaUtil.evaluateClassifier(classifier,filtered_data_instances);
     56    System.out.println(eval.toSummaryString());
     57
     58    try {
     59        System.out.println("Saving REPTree classifier model as: " + output_model_filename);
     60        weka.core.SerializationHelper.write(output_model_filename, classifier);
    4261    }
    43     String input_arff_filename = args[0];
    44     String output_model_filename = args[1];
     62    catch (Exception e) {
     63        e.printStackTrace();
     64    }
    4565   
     66
     67    /*
    4668        System.out.println("Training on ARFF file: " + input_arff_filename);
    4769
     
    4971        DataSource data_source = new DataSource(input_arff_filename);
    5072        Instances data_instances = data_source.getDataSet();
    51        
     73
     74        // *********
    5275        Instances filtered_data_instances = WekaUtil.applyFilter(data_instances,"472"); // top-up with removal of 'valance'
    5376
     
    6588        System.out.println(classifier);
    6689
     90        / *
    6791        // Evaluate     
    6892        Evaluation eval = new Evaluation(filtered_data_instances);
     
    7195        eval.crossValidateModel(classifier, filtered_data_instances, folds, rand);
    7296        System.out.println(eval.toSummaryString());
     97        * /
     98       
     99        Evaluation eval = WekaUtil.evaluateClassifier(classifier,filtered_data_instances);
     100        System.out.println(eval.toSummaryString());
    73101
    74102        System.out.println("Saving REPTree classifier model as: " + output_model_filename);
     
    80108        e.printStackTrace();
    81109    }
    82  
     110    */
    83111    }
    84112}
  • main/trunk/model-sites-dev/mars/src/java/org/greenstone/mars/WekaUtil.java

    r34786 r34788  
    88import java.io.FileReader;
    99import java.io.FileWriter;
     10
     11import java.util.Random;
    1012
    1113import weka.core.converters.ConverterUtils.DataSource;
     
    1921
    2022import weka.classifiers.Classifier;
    21 
     23import weka.classifiers.Evaluation;
     24
     25import weka.classifiers.trees.REPTree;
    2226
    2327// Based on:
     
    2731{
    2832    public final static String AROUSAL_ATTRIBUTE_NAME = "arousal_sample_26500ms";
    29     public final static String VALANCE_ATTRIBUTE_NAME = "valance_sample_26500ms";
    30 
    31 
     33    public final static String VALANCE_ATTRIBUTE_NAME = "valence_sample_26500ms";
     34
     35
     36    public static void checkUsageTraining(String[] args)
     37    {
     38    if (args.length != 2) {
     39        System.err.println("Error: incorrect number of command-line arguments");
     40        System.err.println("Usage: input_training_data.arff output-model.{model|ser}");
     41        System.exit(1);
     42    }
     43   
     44    }
     45   
     46    public static void checkUsageApplyModel(String[] args)
     47    {
     48        if (args.length != 3) {
     49        System.err.println("Error: incorrect number of command-line arguments");
     50        System.err.println("Usage: trained-model.{model|ser} unclassified-data.{arff|csv} classified-data.{arff|csv}");
     51        System.exit(1);
     52    }
     53    }
     54   
    3255    public static Instances applyFilter(Instances data_instances, String additional_remove)
    3356    {
     
    4366    // The following filtering produces a tree with 41 nodes
    4467   
    45     String remove_option_args = ("-R 425,427,458-466");
     68    String remove_option_args = ("-R 425,427,458-466"); // knock out 11 attributes (472 -> 461)
    4669
    4770   
     
    5174
    5275    if (additional_remove != null) {
    53         // top up, e.g. '472' representing valance in ground-truth files
     76        // top up,
     77        //   e.g. '471' representing arousal in ground-truth files
     78        //   e.g. '472' representing valance in ground-truth files
    5479        remove_option_args += "," + additional_remove;
    5580    }
     
    5883    Instances filtered_data_instances = null;
    5984    try {
    60         // remove ordinal attributes and any additional topups, such as 'valance' (when predicting 'arousal')
     85        // remove ordinal attributes and any additional topups,
     86        // such as 'valance' (when predicting 'arousal') and vice versa
    6187       
    6288        String[] filter_options = weka.core.Utils.splitOptions(remove_option_args);
     
    75101
    76102   
    77     public static void checkDatasetInstancesCompatible(Instances new_instances)
     103    public static void checkDatasetInstancesCompatible(Instances new_instances, String additional_remove)
    78104    {
    79105
     
    83109        DataSource gt_data_source = new DataSource(gt_datasource_filename);
    84110        Instances gt_instances = gt_data_source.getDataSet();
    85         gt_instances = applyFilter(gt_instances,"472"); // top up with removing 'valance'
     111        gt_instances = applyFilter(gt_instances,additional_remove); // remove 'valance' or 'arousal'
    86112       
    87113        gt_instances.setClassIndex(gt_instances.numAttributes() - 1);
     
    101127
    102128    }
     129
     130
     131    public static Instances loadAndFilterDataInstances(String input_filename, String additional_attribute_remove)
     132    {
     133        System.out.println("Training on file: " + input_filename);
     134    Instances filtered_data_instances = null;
     135   
     136    try {
     137        DataSource data_source = new DataSource(input_filename);
     138        Instances data_instances = data_source.getDataSet();
     139
     140        // Training dataset has two ground-truth attributes: 'arousal' and 'valance'.
     141        // When training for one, need to knock out the other.  This is the purpose
     142        // of 'additional_attribute_remove'
     143        filtered_data_instances = applyFilter(data_instances,additional_attribute_remove);
     144
     145        // With the 'other' ground-truth attribute taken out, the column to predict
     146        // will always be the last column
     147        filtered_data_instances.setClassIndex(filtered_data_instances.numAttributes()-1);       
     148    }
     149    catch (Exception e) {
     150        e.printStackTrace();
     151        System.exit(1);
     152    }
     153   
     154    return filtered_data_instances;
     155    }
     156
     157    public static Classifier trainREPTree(Instances data_instances)
     158    {
     159    REPTree classifier = null;
     160   
     161    try {
     162        // Build classifier
     163        classifier = new REPTree();
     164        String[] reptree_options = weka.core.Utils.splitOptions("-M 2 -V 0.001 -N 3 -S 1 -L -1 -I 0.0");
     165        classifier.setOptions(reptree_options);
     166
     167        classifier.buildClassifier(data_instances);
     168    }
     169    catch (Exception e) {
     170        e.printStackTrace();
     171        System.exit(1);
     172    }
     173
     174    return classifier;
     175    }
     176
     177
     178    public static Evaluation evaluateClassifier(Classifier classifier, Instances data_instances)
     179    {
     180    Evaluation eval = null;
     181    try {
     182        eval = new Evaluation(data_instances);
     183        Random rand = new Random(1);
     184        int folds = 10;
     185       
     186        eval.crossValidateModel(classifier, data_instances, folds, rand);
     187    }
     188    catch (Exception e) {
     189        e.printStackTrace();
     190        System.exit(1);
     191    }
     192   
     193    return eval;
     194    }
     195
     196    /* Apply Model Specific Methods */
     197
     198    public static Classifier loadClassifierModel(String classifier_input_filename
     199
     200                           )
     201    {
     202        System.out.println("Loading Weka saved Classifier: " + classifier_input_filename);
     203    Classifier classifier = null;
     204    try {
     205        FileInputStream fis = new FileInputStream(classifier_input_filename);   
     206        BufferedInputStream bis= new BufferedInputStream(fis);
     207        classifier = (Classifier)SerializationHelper.read(bis);
     208    }
     209    catch (Exception e) {
     210        e.printStackTrace();
     211        System.exit(1);
     212    }
     213
     214    return classifier;
     215    }
     216
     217
     218    public static Instances loadInstancesForClassification(String unlabeled_input_filename)
     219    {
     220    System.out.println("Loading unlabeled instances: " + unlabeled_input_filename);
     221    Instances unlabeled_instances = null;
     222
     223    try {
     224        DataSource data_source = new DataSource(unlabeled_input_filename);
     225        unlabeled_instances = data_source.getDataSet();
     226    }
     227    catch (Exception e) {
     228        e.printStackTrace();
     229        System.exit(1);
     230    }
     231
     232    return unlabeled_instances;
     233    }
     234
     235    public static boolean instancesHavePredictAttribute(Instances data_instances, String predict_attribute_name)
     236    {
     237    Attribute predict_attribute = data_instances.attribute(predict_attribute_name);
     238
     239    return predict_attribute != null;
     240    }
     241
     242   
     243
     244    public static Instances filterInstancesForApplying(Instances unlabeled_instances, boolean has_groundtruth_data,
     245                               String predict_attribute_name, String additional_attribute_remove)
     246    {
     247    Instances filtered_unlabeled_instances = null;
     248
     249       
     250    // Work out if we're dealing with a ground-truth ARFF file or not
     251    // (i.e. already has the desired attribute)
     252
     253    //Attribute predict_attribute = unlabeled_instances.attribute(predict_attribute_name);
     254
     255    if (!has_groundtruth_data) {
     256
     257        filtered_unlabeled_instances = WekaUtil.applyFilter(unlabeled_instances,null); // no additional top-up to remove       
     258        int num_attributes = filtered_unlabeled_instances.numAttributes();
     259       
     260        Attribute new_predict_attribute = new Attribute(predict_attribute_name);
     261        filtered_unlabeled_instances.insertAttributeAt(new_predict_attribute,num_attributes);
     262        // ******
     263        //num_attributes++;     
     264    }
     265    else {
     266        // Dealing with ground-truth data:
     267        //   => already has 'arousal' and 'valance' attributes
     268        //   => need to keep the 'predict_attribute_name' and remove the other one
     269        //   => (its -R value of which is specified in 'additional_attribute_remove')
     270       
     271        // Need to massage instances into same form as an unclassified data input file
     272        filtered_unlabeled_instances = WekaUtil.applyFilter(unlabeled_instances,additional_attribute_remove);
     273       
     274        // reference share this as 'groundtruth_instances' to trigger error calculation and output
     275        // ******
     276        // groundtruth_instances = filtered_unlabeled_instances;
     277    }
     278   
     279    int num_attributes = filtered_unlabeled_instances.numAttributes();
     280
     281    // Set class attribute
     282    filtered_unlabeled_instances.setClassIndex(num_attributes - 1);
     283   
     284    // ***** Do I still want to run the check????
     285    WekaUtil.checkDatasetInstancesCompatible(filtered_unlabeled_instances, additional_attribute_remove);
     286   
     287    return filtered_unlabeled_instances;
     288    }
     289
     290
     291    public static Instances makePredictions(Classifier classifier, Instances unlabeled_instances, Instances groundtruth_instances)
     292    {
     293       
     294    // Create copy where the predictions are to be made
     295    Instances labeled_instances = new Instances(unlabeled_instances);
     296
     297    try {
     298        // Label instances
     299   
     300        final int num_instances = unlabeled_instances.numInstances();       
     301        for (int i=0; i<num_instances; i++) {
     302        Instance unlabeled_instance = unlabeled_instances.instance(i);
     303       
     304        System.out.print("Making prediction for: " + i + "/" + num_instances);
     305        double classified_value = classifier.classifyInstance(unlabeled_instance);
     306        labeled_instances.instance(i).setClassValue(classified_value);
     307       
     308        String formatted_classified_value = String.format("% 06.3f", classified_value);
     309       
     310        System.out.print(" value = " + formatted_classified_value);
     311       
     312        if (groundtruth_instances != null) {
     313            Instance gt_instance = groundtruth_instances.instance(i);
     314            double gt_class_value = gt_instance.classValue();
     315            double error = Math.abs(classified_value - gt_class_value);
     316           
     317            String formatted_error = String.format("%.3f", error);
     318            System.out.print("  [error: " + formatted_error + "]");
     319        }
     320        System.out.println();
     321        }
     322    }
     323    catch (Exception e) {
     324        e.printStackTrace();
     325        System.exit(1);
     326    }
     327   
     328    return labeled_instances;
     329    }
    103330   
    104331}
Note: See TracChangeset for help on using the changeset viewer.