Multithreaded Machine Learning Training & Inference in the Browser Using TensorFlow.js & Comlink.js

Figure 0: Shows the evolution of web apps and highlights the focus of this blog in red boxes (PWA with TFJS)
Figure 1: Eventually we need the machine to automatically figure out the body of the digitRecognizer function


Because we plan to train in the browser, we cannot download hundreds of individual training image files. Unlike the usual training setup on a backend server, where images are read from local disk, every file here costs a network round trip, which is on the order of a hundred times slower than a local disk read.
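To make the "one download instead of many" idea concrete, here is a minimal sketch of turning a single binary buffer into many images. It assumes the buffer uses the IDX layout of the original MNIST image files (a 16-byte big-endian header followed by raw pixel bytes); `parseIdxImages` is a hypothetical helper name, and this is not the article's own parsing code.

```javascript
// Sketch: parse an IDX-format image buffer (the layout of the original MNIST image files).
// Assumption: the header holds four big-endian int32 values (magic 2051, image count,
// rows, cols) and is followed by count * rows * cols unsigned pixel bytes.
function parseIdxImages(buffer) {
  const view = new DataView(buffer);
  const magic = view.getInt32(0, false); // false = big-endian
  if (magic !== 2051) {
    throw new Error(`Unexpected magic number ${magic}, expected 2051`);
  }
  const count = view.getInt32(4, false);
  const rows = view.getInt32(8, false);
  const cols = view.getInt32(12, false);

  const headerBytes = 16;
  const pixelCount = count * rows * cols;
  if (buffer.byteLength < headerBytes + pixelCount) {
    throw new Error('Buffer is shorter than the header claims');
  }

  // Normalize pixel bytes from [0, 255] to [0, 1] in one flat array.
  const pixels = new Uint8Array(buffer, headerBytes, pixelCount);
  const data = new Float32Array(pixelCount);
  for (let i = 0; i < pixelCount; i++) {
    data[i] = pixels[i] / 255;
  }
  return { count, rows, cols, data };
}

// Tiny synthetic buffer (one 2x2 image) so the sketch runs without a network request.
const demoBuffer = new ArrayBuffer(16 + 4);
const demoHeader = new DataView(demoBuffer);
demoHeader.setInt32(0, 2051); // DataView setters default to big-endian
demoHeader.setInt32(4, 1);
demoHeader.setInt32(8, 2);
demoHeader.setInt32(12, 2);
new Uint8Array(demoBuffer, 16).set([0, 255, 255, 0]);
console.log(parseIdxImages(demoBuffer));
```

In a browser, the buffer would typically come from a single `fetch(...)` followed by `response.arrayBuffer()`, so the whole training set costs one round trip rather than one per image.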

Code Block 1: Code to parse MNIST array buffer as specified in
Code Block 2: A worker class that encapsulates both the model and training data
Layer (type)                   Output shape       Param #
conv2d_Conv2D1 (Conv2D)        [null,26,26,8]     80
max_pooling2d_MaxPooling2D1    [null,13,13,8]     0
conv2d_Conv2D2 (Conv2D)        [null,11,11,16]    1168
max_pooling2d_MaxPooling2D2    [null,5,5,16]      0
flatten_Flatten1 (Flatten)     [null,400]         0
dense_Dense1 (Dense)           [null,128]         51328
dense_Dense2 (Dense)           [null,10]          1290

Total params: 53866
Trainable params: 53866
Non-trainable params: 0
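The shapes and parameter counts in this summary can be cross-checked with a little arithmetic. The sketch below recomputes them in plain JavaScript. The hyperparameters are assumptions inferred from the summary itself rather than taken from the article: a 28x28x1 input, 3x3 kernels with stride 1 and no padding, and 2x2 pooling.

```javascript
// Sketch: recompute the output shapes and parameter counts listed in the summary.
// Assumptions (inferred from the shapes, not stated in the article):
// 28x28x1 input, 3x3 kernels, stride 1, "valid" padding, 2x2 max pooling.
function conv2d([h, w, c], filters, k = 3) {
  return {
    shape: [h - k + 1, w - k + 1, filters],
    params: k * k * c * filters + filters, // kernel weights + one bias per filter
  };
}

function maxPool([h, w, c], size = 2) {
  return { shape: [Math.floor(h / size), Math.floor(w / size), c], params: 0 };
}

function flatten(shape) {
  return { shape: [shape.reduce((a, b) => a * b, 1)], params: 0 };
}

function dense([n], units) {
  return { shape: [units], params: n * units + units }; // weights + biases
}

function summarize(inputShape) {
  const layers = [
    ['conv2d_Conv2D1', (s) => conv2d(s, 8)],
    ['max_pooling2d_MaxPooling2D1', (s) => maxPool(s)],
    ['conv2d_Conv2D2', (s) => conv2d(s, 16)],
    ['max_pooling2d_MaxPooling2D2', (s) => maxPool(s)],
    ['flatten_Flatten1', (s) => flatten(s)],
    ['dense_Dense1', (s) => dense(s, 128)],
    ['dense_Dense2', (s) => dense(s, 10)],
  ];
  let shape = inputShape;
  let total = 0;
  const rows = [];
  for (const [name, layer] of layers) {
    const out = layer(shape);
    rows.push({ name, shape: out.shape, params: out.params });
    shape = out.shape;
    total += out.params;
  }
  return { rows, total };
}

const summary = summarize([28, 28, 1]);
console.log(summary.rows);
console.log('Total params:', summary.total);
```

Under those assumptions the per-layer numbers and the total match the summary, which is a quick way to confirm that the dense layer after flattening holds nearly all of the model's weights.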
Code Block 3: The train method of the VisionModelWorker class
Figure 2: Loss and accuracy curves during training, visualized using the tfjs-vis library

However, because a worker cannot access the DOM, the tfjs-vis callback handlers cannot run inside it directly. We therefore need a way to define the tfjs-vis functions on the main thread and invoke them from the worker thread with the model's loss and accuracy metrics.
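The division of labor can be sketched without any library: the training side only reports numbers through a callback, and the main-thread side owns everything that draws. The names `TrainerSketch` and `createMetricsHandler` are hypothetical, the metrics are placeholders, and both sides run in one thread here so the sketch stays runnable. Across a real worker boundary, the handler would be wrapped with `Comlink.proxy(...)` on the main thread before being passed to the worker.

```javascript
// Worker side (sketch): no DOM access here, so it only reports numbers.
class TrainerSketch {
  // onEpochEnd is supplied by the caller. Across a real worker boundary it would
  // arrive as a Comlink proxy, so calling it returns a promise; hence the await.
  async train(epochs, onEpochEnd) {
    for (let epoch = 0; epoch < epochs; epoch++) {
      // Placeholder metrics. A real implementation would forward the `logs`
      // object that TensorFlow.js passes to its fit callbacks.
      const logs = { loss: 1 / (epoch + 1), acc: epoch / epochs };
      await onEpochEnd(epoch, logs);
    }
    return epochs;
  }
}

// Main-thread side (sketch): owns everything that touches the DOM.
// `render` stands in for whatever draws the charts (tfjs-vis in the article).
function createMetricsHandler(render) {
  const history = [];
  return {
    history,
    onEpochEnd(epoch, logs) {
      history.push({ epoch, ...logs });
      render(history);
    },
  };
}

// Wiring, done in a single thread for illustration.
const handler = createMetricsHandler((history) => {
  const last = history[history.length - 1];
  console.log(`epoch ${last.epoch}: loss=${last.loss.toFixed(3)}`);
});
new TrainerSketch().train(3, handler.onEpochEnd);
```

Keeping the handler free of any worker-specific code means the same function can be exercised directly, as above, or handed to a worker unchanged once it is wrapped in a proxy.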

Code Block 4: Communicate visualization information between main and worker threads using Comlink
class VisionModelWorker {
  // ... constructor like before
  async create(saveToDisk = false) {
    this.model = tf.sequential();
    // ... create the model architecture
    if (saveToDisk) {
      // ... (save logic not shown)
    }
  }
  // ... other methods
}
Code Block 5: Visualize Model Summary on the main thread
Figure 3: Model summary visualized in the DOM by the tfjs-vis library
Figure 4: Screencast that shows GPU utilization while training

Use the trained model — Inference

Code Block 6: Shows how to create a canvas which lets users draw anything
// On Main Thread
const rawImage = document.getElementById(imgPlaceholderId);
// Must be async because it awaits the worker's prediction
async function captureImage(e) {
  rawImage.src = self.canvas.toDataURL('image/png');
  // Read the drawing as a single-channel nested array, which can be sent to the worker
  const rawImgArr = tf.browser.fromPixels(rawImage, 1).arraySync();
  const result = await visionModelWorker.predict(rawImgArr);
}

// In Worker
function predict(rawImgArr) {
  const x = tf.tensor3d(rawImgArr);
  // ... predict logic
  // resizeBilinear expects the target size as [height, width]
  const resized = tf.image.resizeBilinear(x, [imgHeight, imgWidth]);
  // Add a batch dimension: [height, width, 1] -> [1, height, width, 1]
  const tensor = resized.expandDims(0);
  const probabilities = model.predict(tensor).arraySync();
  const [prediction, confidence] = [
    tf.argMax(probabilities, 1).arraySync(),
    tf.max(probabilities, 1).arraySync()
  ];
  // ...
}
Figure 5: Shows recognition of a user's handwritten digit in action


