Classification με την ml5.js (5/5)

Βήμα 3ο: Πρόβλεψη

Είμαστε σχεδόν έτοιμοι να κανουμε τις προβλέψεις μας. Μας απομένει να δώσουμε την εντολή στο nn μας και έπειτα θα πρέπει να βρούμε έναν τρόπο να απεικονίσουμε την πρόβλεψη.

Πρώτα όμως ας δώσουμε στο πρόγραμμά μας τη δυνατότητα να ξέρει σε ποια κατάσταση (state) βρίσκεται. Για να το πετύχουμε αυτό φτιάχνουμε μια μεταβλητή για να παρακολουθούμε την κατάσταση που βρίσκεται το app μας, την appState. Στην αρχή του προγράμματος θέλουμε κατευθείαν αυτο να βρίσκεται στην κατάσταση "συλλογής στοιχείων" ή αλλιώς σε "collect" state. 

let nn;
let currentNote = 'C';
let appState ='collect';

Κάθε φορά που κάνουμε κλικ με το ποντίκι θέλουμε σε πρώτη φάση να συλλέγουμε δεδομένα και σε άλλη φάση να κάνουμε την πρόβλεψη. Αυτο θα το κάνουμε με την χρήση μιας απλής if...else if.

function mousePressed() {
  let inputValues = {
    x: mouseX,
    y: mouseY
  }
  if (appState == 'collect') {
    
    let outputValue = {
        note: currentNote
    }

    nn.addData(inputValues, outputValue);

    drawNote (currentNote, mouseX, mouseY);

  } else if (appState == 'predict') {
    
       //predict
  }
}

Επίσης θέλουμε να αλλάξουμε το state όταν ξεκινάει και όταν τελειώνει η εκπαίδευση

function keyPressed() {

  switch (key) {
    case 't':
      console.log('training model with collected data');
      appState = 'train'
      trainData();
      break;
      ...

function trainingFinished() {
    
    console.log('training finished.');
    appState = 'predict';
  
}

Τώρα είμαστε έτοιμοι ώστε να πούμε στο nn μας να κάνει προβλέψεις κάνοντας κλικ και με βάση (input) την τρέχουσα θέση του ποντικιού, καλώντας την μέθοδο classify() του nn μας.

} else if (appState == 'predict') {
    
        nn.classify(inputValues, onGotResults);

  }
function onGotResults(error, results) {
  
  if (error) {
    console.error(error);
    return;
  }
  
  console.log(results);
 
}

Χρησιμοποιούμε την classify γιατί το πρόβλημα είναι τέτοιου τύπου (classification). Στο μέλλον θα δούμε και άλλου τύπου προβλήματα όπως είναι το Regression.

Όπως προσέξατε η μέθοδος classify() παίρνει 2 παραμέτρους.

Η πρώτη είναι το input της πρόβλεψης - ταξινόμησης. Εκεί περνάμε τη μεταβλητή inputValues η οποία έχει καθοριστεί λίγο πιο πάνω στη μέθοδο mousePressed() και περιέχει τις θέσεις x, y του δείκτη του ποντικιού πάνω στον καμβά μας.

Η  δεύτερη παράμετρος πρόκειται για μια callback function η οποία καλείται μόλις γίνει η πρόβλεψη από το nn μας και μας επιστρέφει τα αποτελέσματα της πρόβλεψης σε ένα πίνακα

Αν τώρα αρχίσουμε να κάνουμε κλίκ σε διάφορα σημεία του καμβά , θα δούμε στην κονσόλα μας τα αποτελέσματα (results). Ας ρίξουμε μια ματιά σε ένα από αυτά

0: Object
	label: "C"
	confidence: 0.7744138240814209
1: Object
	label: "E"
	confidence: 0.14987757802009583
2: Object
	label: "D"
	confidence: 0.07570852339267731

Για να παράγω αυτό το αποτέλεσμα σημαίνει ότι εκπαίδευσα το μοντέλο μού με 3 μόνο νότες (C, D, E), για αυτό βλέπω 3 αποτελέσματα. Αν το είχα κάνει και με τις 7 νότες, τότε θα είχα 7 αποτελέσματα στον πίνακα results. Επίσης θα παρατηρήσατε ότι από μόνη της η ml5 ονόμασε την κάθε έξοδο του nn μας ανάλογα με το Input που της δώσαμε για εκπαίδευση, όπως φυσικά και τον αριθμό των outputs.

Το κάθε στοιχείο στον πίνακα results περιέχει το όνομα της κλάσης (label) και το πόσο σίγουρο είναι το nn μας(confidence) ότι ήταν η σωστή απάντηση (πρόβλεψη) για το συγκεκριμένο input. Αν προσθέσουμε όλα τα confidence scores θα μας δώσουν αποτέλεσμα 1 (100%).

Εμάς στην ουσία μας ενδιαφέρει το στοιχείο στον πίνακα το οποίο έχει το μεγαλύτερο score και επειδή η ml5 επιστρέφει τον πίνακα ταξινομημένο κατα φθίνουσα σειρά score, αυτό που θέλουμε είναι το στοιχείο 0 του πίνακα (results[0]).

Απεικονίζοντας το αποτέλεσμα

Θα ήταν πιο πρακτικό αν μπορούσαμε να δούμε τα αποτελέσματα επιτόπου στον καμβά μας. Για αυτό λοιπόν θα φτιάξουμε μια δικη μας μέθοδο που θα ζωγραφίζει πάλι έναν κύκλο με το όνομα της νότας (που προβλέφθηκε) και θα τον χρωματίζει με μωβ χρώμα.

function onGotResults(error, results) {

  if (error) {
    console.error(error);
    return;
  }

  console.log(results);

  let note = results[0].label
  drawNotePrediction(note, mouseX, mouseY);

}
//helpers
...
function drawNotePrediction (noteName, x, y) {
    stroke(0);
    fill(0, 0, 255, 100);
    ellipse(x, y, 24);
    fill(0);
    noStroke();
    textAlign(CENTER, CENTER);
    text(noteName, x, y);
}

Μπορείτε να δείτε το τελικό αποτέλεσμα και να πειραματιστείτε με την εφαρμογή μας στον p5 web editor

Δείτε στον p5 editor