Regression με την ml5.js

Για να ακολουθήσετε αυτό το tutorial θα πρέπει πρώτα να έχετε ολοκληρώσει το προηγούμενο tutorial Classification με την ml5

Στον κόσμο του machine learning ο όρος Regression δεν διαφέρει και πολύ από αυτόν της στατιστικής ανάλυσης. Αυτό που μας ενδιαφέρει, από τη σκοπιά της ml5, είναι ότι το Regression είναι ένας τύπος προβλήματος στο οποίο θέλουμε σαν αποτέλεσμα μια και μόνο τιμή, αντί των βαθμολογημένων ετικετών που ήταν αποτέλεσμα του Classification προβλήματος που είδαμε σε προηγούμενο tutorial.

Θα συνεχίσουμε λοιπόν από το προηγούμενο tutorial και θα κάνουμε κάτι λίγο διαφορετικό. Η κάθε νότα όπως είδαμε έχει μια συγκεκριμένη συχνότητα. Μπορούμε να υπολογίσουμε με την ml5 μια "μέση" τιμή συχνότητας με βάση το σημείο X, Y που επιλέγουμε κάθε φορά με τη χρήση του ποντικιού και ανάλογα το πόσο μακρυά ή κοντα βρίσκεται από άλλες νότες.

Αυτό το πρόβλημα είναι τύπου Regression γιατί ζητούμε μια μόνο τιμή (την πρόβλεψη) και για να το λύσουμε αυτο με την ml5 αρκεί να κάνουμε μερικές μικρές τροποποιήσεις στον κώδικα του προηγούμενου tutorial

Αρχικές ρυθμίσεις

Για αρχή πάμε στην μέθοδο setup() και αλλάζουμε τα options του νευρωνικού δικτύου μας. Εντοπίζουμε αυτό το σημείο στον κώδικα:

function setup() {
    //...
    let nnOptions = {
        inputs: ['x', 'y'],
        outputs: ['label'],
        task: 'classification',
        debug: 'true',
    };  

    nn = ml5.neuralNetwork(nnOptions);
}

Και κάνουμε την εξής αλλαγή:

function setup() {
    //...
    let nnOptions = {
        inputs: ['x', 'y'],
        outputs: ['frequency'],
        task: 'regression',
        debug: 'true',
    };  

    nn = ml5.neuralNetwork(nnOptions);
}

Έτσι λεμε στο νευρωνικό μας δίκτυο οτι θελουμε να λυσουμε ενα προβλημα τυπου Regression και πως το αποτέλεσμα θα είναι μια frequency.

Συλλογή δεδομένων

Κατόπιν πρέπει να δώσουμε τα σωστά δεδομένα στο μοντέλο μας προκειμένου να εκπαιδευτεί. Οπότε στην κατάσταση συλλογής δεδομένων πρέπει να κάνουμε κάποιες αλλαγές. 

Εντοπίζουμε τον παρακάτω κώδικα

if (appState == 'collect') {
    
        let outputValue = {
            note: currentNote
        }
}
Και κάνουμε την εξής αλλαγή:
if (appState == 'collect') {
    
    let outputValue = {
        frequency: notesToFrequencies[currentNote]
    }
}

Πρόβλεψη αντι ταξινόμησης

Επίσης κάνουμε αυτή τη μικρή αλλαγή για να πούμε στο nn ότι θέλουμε να κάνουμε μια πρόβλεψη και άρα να μας επιστρέψει μια τιμή μονο, αντί της ταξινόμησης που επιστρέφει κατηγορίες με βαθμολογία.

Έτσι αντι για αυτό

  } else if (appState == 'predict') {
    
        nn.classify(inputValues, onGotResults);

  }

Έχουμε αυτό

  } else if (appState == 'predict') {
    
        nn.predict(inputValues, onGotResults);

  }

Πόσο θα εκπαιδευτεί ο αλγόριθμος;

Από ότι φαίνεται τα προβλήματα τυπου Regression είναι πιο εύκολα για τον αλγόριθμο μας , οπότε δεν χρειαζόμαστε πολλές εκπαιδευτικές εποχές (epochs)

    let trainingOptions = {
        epochs: 50
    }

Αποτέλεσμα της πρόβλεψης

Για να δείξουμε το αποτέλεσμα της πρόβλεψης θα χρησιμοποιήσουμε και πάλι εικόνα και ήχο.

Πρώτα θα δημιουργήσουμε μια helper μέθοδο που θα μας βοηθήσει να ζωγραφίσουμε την συχνότητα που προβλέφθηκε από τον αλγόριθμο, την drawFrequency()

function drawFrequency (frequency) {
    stroke(0);
    fill(0, 0, 255, 100);
    ellipse(mouseX, mouseY, 24);
    fill(0);
    noStroke();
    textAlign(CENTER, CENTER);
    text(floor(frequency), mouseX, mouseY);
}

Έπειτα θα τροποποιήσουμε λίγο την μέθοδο που εκτελείται μόλις ο αλγόριθμος προβλέψει (επιστρέψει) μιά τιμή (συχνότητα), την onGotResults()

Οπότε έχουμε στον προηγούμενο μας κώδικα:

function onGotResults(error, results) {
  if (error) {
    console.error(error);
    return;
  }
  console.log(results);
  let note = results[0].label
  drawNote (note, mouseX, mouseY);
  playNote (notesToFrequencies[note]);
  
}

Και το αλλάζουμε σε:

function onGotResults(error, results) {
  if (error) {
    console.error(error);
    return;
  }
  console.log(results);
  let noteFrequnecy = results[0].value;
  drawFrequency (noteFrequnecy);
  playNote (noteFrequnecy);
  
}

Τελικό αποτέλεσμα

Μπορείτε να δείτε παρακάτω ολοκληρωμένο τον κώδικα καθώς και να τον τρέξετε στον p5 web editor

Δείτε στον p5 editor
let notesToFrequencies = {
  C: 261.6,
  D: 293.6,
  E: 329.6,
  F: 349.2,
  G: 391.9,
  A: 440.0,
  B: 493.8
}

let appState ='collect'; // collect, predict
let nn;
let currentNote = 'C';
let envelope, oscillator;

function setup() {
    createCanvas(400, 400);
    background(200);
    
    envelope = new p5.Envelope();
    envelope.setADSR(0.05, 0.1, 0.5, 1);
    envelope.setRange(1.2, 0);

    oscillator = new p5.Oscillator();
    oscillator.setType('sine');
    oscillator.freq(440);
    oscillator.amp(envelope);
    oscillator.start();

    let nnOptions = {
        inputs: ['x', 'y'],
        outputs: ['frequency'],
        task: 'regression',
        debug: 'true',
    };  

    nn = ml5.neuralNetwork(nnOptions);
}



function keyPressed() {
    switch (key) {
        case 't' :
            console.log('training model with collected data');
            trainData();
            break;
        default:
            currentNote = key.toUpperCase();
            break;
    }
}

function trainData () {
    nn.normalizeData();
    let trainingOptions = {
        epochs: 50
    }
    nn.train(trainingOptions, onEachEpoch, trainingFinished);
}

function onEachEpoch(epoch, loss) {
    console.log('Epoch ' + epoch);
}

function trainingFinished() {
    console.log('training finished.');
    appState = 'predict';
}

function mousePressed() {
  let inputValues = {
    x: mouseX,
    y: mouseY
  }
  if (appState == 'collect') {
    
    let currrentFrequency = notesToFrequencies[currentNote];
    let outputValue = {
        frequency: currrentFrequency
    }

    nn.addData(inputValues, outputValue);

    drawNote (currentNote, mouseX, mouseY);
    playNote(currrentFrequency);

  } else if (appState == 'predict') {
    
        nn.predict(inputValues, onGotResults);

  }
}

function onGotResults(error, results) {
  if (error) {
    console.error(error);
    return;
  }
  console.log(results);
  let noteFrequnecy = results[0].value;
  drawFrequency (noteFrequnecy);
  playNote (noteFrequnecy);
}

//helpers
function drawNote (noteName, x, y) {
    stroke(0);
    noFill();
    ellipse(x, y, 24);
    fill(0);
    noStroke();
    textAlign(CENTER, CENTER);
    text(noteName, x, y);
}

function playNote (noteFrequency) {
    oscillator.freq(noteFrequency);
    envelope.play();
}

function drawFrequency (frequency) {
    stroke(0);
    fill(0, 0, 255, 100);
    ellipse(mouseX, mouseY, 24);
    fill(0);
    noStroke();
    textAlign(CENTER, CENTER);
    text(floor(frequency), mouseX, mouseY);
}