Multithreaded Machine Learning Training & Inference in the Browser Using TensorFlow.js & Comlink.js

Figure 0: The evolution of web apps, with the focus of this blog (a PWA with TFJS) highlighted in red boxes
Figure 1: Eventually we need the machine to figure out the body of the digitRecognizer function automatically

Training

Because we plan to train in the browser, we cannot download hundreds of individual training image files. Unlike training on a backend server, where the data is read from local disk, every file costs a network round trip, and a round trip is roughly a hundred times slower than a local disk read.

Code Block 1: Code to parse MNIST array buffer as specified in http://yann.lecun.com/exdb/mnist/
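The IDX layout described at that URL is small enough to parse with a DataView. What follows is a minimal sketch, not the article's original code: the function names `parseIdxImages` and `parseIdxLabels` are illustrative, and it assumes the files have already been fetched and decompressed into an ArrayBuffer.

```javascript
// Minimal sketch of an IDX parser (function names are illustrative).
// IDX layout: a 4-byte big-endian magic number, one 4-byte big-endian
// size per dimension, then the raw unsigned bytes.
const IMAGE_MAGIC = 2051; // 0x00000803: unsigned bytes, 3 dimensions
const LABEL_MAGIC = 2049; // 0x00000801: unsigned bytes, 1 dimension

function parseIdxImages(buffer) {
  const view = new DataView(buffer);
  const magic = view.getUint32(0, false); // false = big-endian
  if (magic !== IMAGE_MAGIC) {
    throw new Error(`Unexpected image magic number: ${magic}`);
  }
  const count = view.getUint32(4, false);
  const rows = view.getUint32(8, false);
  const cols = view.getUint32(12, false);
  // Pixels start after the 16-byte header, one byte per pixel (0..255).
  const pixels = new Uint8Array(buffer, 16, count * rows * cols);
  return { count, rows, cols, pixels };
}

function parseIdxLabels(buffer) {
  const view = new DataView(buffer);
  const magic = view.getUint32(0, false);
  if (magic !== LABEL_MAGIC) {
    throw new Error(`Unexpected label magic number: ${magic}`);
  }
  const count = view.getUint32(4, false);
  // Labels start after the 8-byte header, one byte per label (0..9).
  const labels = new Uint8Array(buffer, 8, count);
  return { count, labels };
}
```

Both functions return typed-array views over the original buffer rather than copies, so parsing all images costs one header read and no per-image allocation.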
Code Block 2: A worker class that encapsulates both the model and training data
_________________________________________________________________
Layer (type)                 Output shape              Param #
=================================================================
conv2d_Conv2D1 (Conv2D)      [null,26,26,8]            80
_________________________________________________________________
max_pooling2d_MaxPooling2D1  [null,13,13,8]            0
_________________________________________________________________
conv2d_Conv2D2 (Conv2D)      [null,11,11,16]           1168
_________________________________________________________________
max_pooling2d_MaxPooling2D2  [null,5,5,16]             0
_________________________________________________________________
flatten_Flatten1 (Flatten)   [null,400]                0
_________________________________________________________________
dense_Dense1 (Dense)         [null,128]                51328
_________________________________________________________________
dense_Dense2 (Dense)         [null,10]                 1290
=================================================================
Total params: 53866
Trainable params: 53866
Non-trainable params: 0
_________________________________________________________________
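The shapes and parameter counts in this summary can be checked by hand. The sketch below recomputes them in plain JavaScript. It assumes a 28x28x1 input, 3x3 kernels with stride 1 and 'valid' padding, and 2x2 max pooling; these values are inferred from the shapes in the summary, not stated in the text, and the helper and layer names are illustrative.

```javascript
// Sketch: recompute the summary's output shapes and parameter counts.
// Kernel size, padding and pool size are assumptions inferred from the
// printed shapes (for example, 3*3*1*8 weights + 8 biases = 80).
function conv2d([h, w, c], filters, kernel = 3) {
  return {
    shape: [h - kernel + 1, w - kernel + 1, filters],
    params: kernel * kernel * c * filters + filters, // weights + biases
  };
}

function maxPool([h, w, c], pool = 2) {
  return { shape: [Math.floor(h / pool), Math.floor(w / pool), c], params: 0 };
}

function flatten(shape) {
  return { shape: [shape.reduce((a, b) => a * b, 1)], params: 0 };
}

function dense([inputs], units) {
  return { shape: [units], params: inputs * units + units }; // weights + biases
}

function summarize(inputShape) {
  const layers = [
    ['conv_1', (s) => conv2d(s, 8)],
    ['pool_1', (s) => maxPool(s)],
    ['conv_2', (s) => conv2d(s, 16)],
    ['pool_2', (s) => maxPool(s)],
    ['flatten', (s) => flatten(s)],
    ['dense_1', (s) => dense(s, 128)],
    ['dense_2', (s) => dense(s, 10)],
  ];
  let shape = inputShape;
  let total = 0;
  const rows = layers.map(([name, apply]) => {
    const out = apply(shape);
    shape = out.shape;
    total += out.params;
    return { name, shape: out.shape, params: out.params };
  });
  return { rows, total };
}

console.log(summarize([28, 28, 1]).total); // 53866
```

Almost all of the parameters (51328 of 53866) sit in the first dense layer, which is why the two pooling layers that shrink the feature map to 5x5x16 matter for keeping the model small.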
Code Block 3: The train method of the VisionModelWorker class
Figure 2: Loss and accuracy curves during training, visualized using the tfjs-vis library

However, a worker cannot access the DOM, so tfjs-vis callback handlers cannot run inside it directly. We therefore need a way to define the tfjs-vis functions on the main thread and invoke them from the worker thread with the model's loss and accuracy metrics.

Code Block 4: Communicate visualization information between main and worker threads using Comlink
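One way to picture this hand-off, as a rough sketch and not the article's code: the worker builds its `fit` callbacks around a reporter function that it receives from the main thread, and forwards only plain numbers to it. The names `makeFitCallbacks` and `report` are hypothetical, and the `acc` key assumes the model was compiled with an accuracy metric.

```javascript
// Worker side (sketch): build training callbacks that forward plain metric
// objects to a reporter function supplied by the main thread.
// `makeFitCallbacks` and `report` are hypothetical names.
function makeFitCallbacks(report) {
  return {
    // TensorFlow.js calls onEpochEnd(epoch, logs) after every epoch.
    onEpochEnd: async (epoch, logs) => {
      // Copy only the numbers needed for plotting, so the message that
      // crosses the thread boundary is a small plain object.
      await report({ epoch, loss: logs.loss, acc: logs.acc });
    },
  };
}

// Main thread (assumed usage, not run here). Functions cannot be cloned
// across threads, so the reporter would be wrapped with Comlink.proxy:
//   const report = Comlink.proxy((metrics) => drawMetrics(metrics));
//   await visionModelWorker.train(report); // hypothetical signature
```

Because the reporter is awaited, the worker does not start the next epoch's report until the main thread has acknowledged the previous one, which keeps the chart updates in order.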
class VisionModelWorker {
  // ... constructor like before
  async create(saveToDisk = false) {
    this.model = tf.sequential();
    // ... create the model architecture
    if (saveToDisk) {
      await this.model.save(`indexeddb://my-model-1`);
    }
  }
  // ... other methods
}
Code Block 5: Visualize Model Summary on the main thread
Figure 3: Model Summary visualized in DOM by tfjs-vis library
Figure 4: Screencast that shows GPU utilization while training

Use the trained model — Inference

Code Block 6: Shows how to create a canvas which lets users draw anything
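// On the main thread
const rawImage = document.getElementById(imgPlaceholderId);

// Copy the current canvas drawing into the placeholder image element
function captureImage(e) {
  rawImage.src = self.canvas.toDataURL('image/png');
}

// Read the image as one grayscale channel and convert it to a plain
// nested array, which can be sent to the worker
const rawImgArr = tf.browser.fromPixels(rawImage, 1).arraySync();
const result = await visionModelWorker.predict(rawImgArr);

// In the worker
function predict(rawImgArr) {
  // Rebuild the [height, width, 1] tensor from the nested array
  const x = tf.tensor3d(rawImgArr);
  ... // predict logic
}

// Predict logic: resizeBilinear expects the target size as [height, width]
const resized = tf.image.resizeBilinear(x, [imgHeight, imgWidth]);
// Add a batch dimension: [height, width, 1] -> [1, height, width, 1]
const tensor = resized.expandDims(0);
const probabilities = model.predict(tensor).arraySync();
// Most likely class and its probability for the single image in the batch
const [prediction, confidence] = [
  tf.argMax(probabilities, 1).arraySync(),
  tf.max(probabilities, 1).arraySync()
];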
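Once the probabilities have been copied out of the tensor with `arraySync()`, the predicted digit and its confidence can also be read with plain JavaScript, without creating further tensors. The helper below is a small sketch of that alternative; `topPrediction` is an illustrative name, and it expects one row of class probabilities (for example `probabilities[0]` for a batch of one).

```javascript
// Sketch: pick the predicted class and its confidence from one row of
// class probabilities (helper name is illustrative).
function topPrediction(probabilities) {
  if (probabilities.length === 0) {
    throw new Error('Expected at least one class probability');
  }
  let best = 0;
  for (let i = 1; i < probabilities.length; i++) {
    if (probabilities[i] > probabilities[best]) best = i;
  }
  // The index is the digit, because the ten output units map to 0..9.
  return { digit: best, confidence: probabilities[best] };
}
```

The result is a small plain object, so it can be returned from the worker to the main thread as-is.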
Figure 5: Recognition of a user's handwritten digit in action
