Multithreaded Machine Learning Training & Inference in the Browser using TensorFlow.js & Comlink.js

Figure 0: The evolution of web apps, with the focus of this blog (PWAs with TFJS) highlighted in red boxes
Running the model in the browser brings three benefits:
  1. Lower latency, as the user's raw data need not be sent over the network to get a prediction from a server
  2. Better privacy, as the user's raw data never has to leave the device for a prediction
  3. Lower cost for the company, as some of the server's computation moves to the user's device
Figure 1: Eventually, we need the machine to figure out the body of the digitRecognizer function automatically


To get there, we follow the usual machine learning workflow:
  1. Prepare the training data
  2. Create a model architecture
  3. Train the model
  4. Analyze the results and repeat steps 1–3

Prepare the data

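As we plan to train in the browser, we cannot download hundreds of individual training image files. Unlike the usual training on a backend server, where images are read from local disk, every file fetched by a browser costs a network round trip, which is about a hundred times slower than a local disk read. Instead, we load the dataset as one binary file and parse it from an array buffer.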

Code Block 1: Code to parse MNIST array buffer as specified in
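As a minimal sketch of the parsing step, assuming the files are already decompressed and follow the IDX layout described on the MNIST site (a big-endian header followed by one unsigned byte per pixel or label), the whole dataset can be read from an ArrayBuffer with a DataView. The function names, and the choice to scale pixels to [0, 1], are our own assumptions rather than the author's code.

```javascript
// Minimal sketch of an MNIST IDX parser. Assumes decompressed files in the
// IDX layout: a big-endian header, then one unsigned byte per pixel/label.

const IMAGES_MAGIC = 2051; // 0x00000803: unsigned bytes, 3 dimensions
const LABELS_MAGIC = 2049; // 0x00000801: unsigned bytes, 1 dimension

function parseIdxImages(arrayBuffer) {
  // DataView reads big-endian by default, which matches the IDX header.
  const view = new DataView(arrayBuffer);
  const magic = view.getInt32(0);
  if (magic !== IMAGES_MAGIC) {
    throw new Error(`Unexpected image magic number: ${magic}`);
  }
  const count = view.getInt32(4);
  const rows = view.getInt32(8);
  const cols = view.getInt32(12);
  // Pixel bytes start right after the 16-byte header, row-major per image.
  const bytes = new Uint8Array(arrayBuffer, 16, count * rows * cols);
  // Scale from [0, 255] to [0, 1] (an assumed, though common, normalization).
  const pixels = Float32Array.from(bytes, (b) => b / 255);
  return { count, rows, cols, pixels };
}

function parseIdxLabels(arrayBuffer) {
  const view = new DataView(arrayBuffer);
  const magic = view.getInt32(0);
  if (magic !== LABELS_MAGIC) {
    throw new Error(`Unexpected label magic number: ${magic}`);
  }
  const count = view.getInt32(4);
  // Label bytes (digits 0-9) start right after the 8-byte header.
  const labels = new Uint8Array(arrayBuffer, 8, count);
  return { count, labels };
}
```

In the browser, each buffer would come from a single request, for example `await (await fetch(imagesUrl)).arrayBuffer()`, where `imagesUrl` is a placeholder.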

Create a model architecture

Code Block 2: A worker class that encapsulates both the model and training data
Layer (type)                                 Output shape       Param #
conv2d_Conv2D1 (Conv2D)                      [null,26,26,8]     80
max_pooling2d_MaxPooling2D1 (MaxPooling2D)   [null,13,13,8]     0
conv2d_Conv2D2 (Conv2D)                      [null,11,11,16]    1168
max_pooling2d_MaxPooling2D2 (MaxPooling2D)   [null,5,5,16]      0
flatten_Flatten1 (Flatten)                   [null,400]         0
dense_Dense1 (Dense)                         [null,128]         51328
dense_Dense2 (Dense)                         [null,10]          1290
Total params: 53866
Trainable params: 53866
Non-trainable params: 0
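The shapes and parameter counts in the summary follow from a few simple rules. The sketch below recomputes them in plain JavaScript, assuming a 28×28×1 input, 3×3 kernels with stride 1 and no padding, and 2×2 max pooling with stride 2. These settings are inferred from the numbers in the table, not taken from the original code, and the helper names are hypothetical.

```javascript
// Recompute output shapes and parameter counts for a small MNIST CNN.
// Assumed settings (inferred from the summary table): valid 3x3 convolutions
// with stride 1, and 2x2 max pooling with stride 2.

function conv2d([h, w, c], { filters, kernel }) {
  // Each spatial dimension shrinks by kernel - 1 (no padding, stride 1).
  // Params: a kernel x kernel x c weight volume plus one bias per filter.
  return {
    shape: [h - kernel + 1, w - kernel + 1, filters],
    params: (kernel * kernel * c + 1) * filters,
  };
}

function maxPool2d([h, w, c], { pool }) {
  // Stride equals the pool size; leftover rows/columns are dropped (floor).
  return { shape: [Math.floor(h / pool), Math.floor(w / pool), c], params: 0 };
}

function flatten(shape) {
  return { shape: [shape.reduce((a, b) => a * b, 1)], params: 0 };
}

function dense([n], { units }) {
  // An n x units weight matrix plus one bias per unit.
  return { shape: [units], params: n * units + units };
}

function summarize(inputShape, layers) {
  let shape = inputShape;
  const rows = [];
  for (const [name, fn, cfg] of layers) {
    const out = fn(shape, cfg);
    rows.push({ name, shape: out.shape, params: out.params });
    shape = out.shape; // each layer feeds the next
  }
  const total = rows.reduce((sum, r) => sum + r.params, 0);
  return { rows, total };
}

const mnistCnn = [
  ["conv2d_1", conv2d, { filters: 8, kernel: 3 }],
  ["max_pooling2d_1", maxPool2d, { pool: 2 }],
  ["conv2d_2", conv2d, { filters: 16, kernel: 3 }],
  ["max_pooling2d_2", maxPool2d, { pool: 2 }],
  ["flatten_1", flatten, {}],
  ["dense_1", dense, { units: 128 }],
  ["dense_2", dense, { units: 10 }],
];

const summary = summarize([28, 28, 1], mnistCnn);
```

For example, the second convolution sees 8 input channels, so it has (3·3·8 + 1)·16 = 1168 parameters, and flattening the final 5×5×16 volume gives the 400 inputs of the first dense layer.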

Train the Model

Code Block 3: The train method of the VisionModelWorker class
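If the train method relies on model.fit, its epochs, batchSize and shuffle options decide how the training data is consumed: with shuffle enabled, the data is reshuffled before every epoch and then fed in mini-batches. A dependency-free sketch of that iteration follows; the function names and the seeded mulberry32 generator (used only so runs are reproducible) are our own choices, not part of TensorFlow.js.

```javascript
// Sketch of shuffled mini-batching: every epoch visits each example exactly
// once, in a fresh random order, in chunks of at most batchSize.

// Small seeded PRNG (mulberry32) returning floats in [0, 1).
function mulberry32(seed) {
  let a = seed >>> 0;
  return function random() {
    a = (a + 0x6d2b79f5) >>> 0;
    let t = a;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}

// Fisher-Yates shuffle of the indices 0 .. n - 1.
function shuffledIndices(n, random) {
  const idx = Array.from({ length: n }, (_, i) => i);
  for (let i = n - 1; i > 0; i--) {
    const j = Math.floor(random() * (i + 1));
    [idx[i], idx[j]] = [idx[j], idx[i]];
  }
  return idx;
}

// Yields [epoch, batchIndices]; the last batch of an epoch may be smaller.
function* miniBatches(numExamples, batchSize, epochs, random) {
  for (let epoch = 0; epoch < epochs; epoch++) {
    const order = shuffledIndices(numExamples, random);
    for (let start = 0; start < numExamples; start += batchSize) {
      yield [epoch, order.slice(start, start + batchSize)];
    }
  }
}
```

With 10 examples and a batch size of 4, each epoch yields batches of 4, 4 and 2 examples.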

Analyze the Results

Figure 2: Loss and accuracy curves during training, visualized with the tfjs-vis library

However, as a worker cannot access the DOM, the tfjs-vis callback handlers cannot run inside it. So we need a way to define the tfjs-vis functions on the main thread and invoke them from the worker thread with the model's loss and accuracy metrics.
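One way to bridge the gap, sketched here without any libraries: the worker builds fit-style callbacks that copy only the numeric metrics into small plain objects and hand them to a `send` function, and the main thread routes each message to a visualization handler. In practice, `send` could be a main-thread function wrapped with `Comlink.proxy` and passed to the worker, and the handlers could be the `onBatchEnd`/`onEpochEnd` pair returned by `tfvis.show.fitCallbacks`. Those wirings, the metric names and the helper names are all assumptions.

```javascript
// Worker side: fit-style callbacks that forward metrics through `send`
// (hypothetically a Comlink-proxied main-thread function or postMessage wrapper).
function makeMetricsForwarder(send, metrics = ["loss", "acc"]) {
  // Keep only the requested numeric metrics, so each payload is a small,
  // structured-cloneable plain object.
  const pick = (logs) => {
    const out = {};
    for (const m of metrics) {
      if (logs && typeof logs[m] === "number") out[m] = logs[m];
    }
    return out;
  };
  return {
    onBatchEnd: async (batch, logs) => send({ type: "batch", index: batch, logs: pick(logs) }),
    onEpochEnd: async (epoch, logs) => send({ type: "epoch", index: epoch, logs: pick(logs) }),
  };
}

// Main-thread side: dispatch each message to the matching handler,
// e.g. the callbacks produced by tfjs-vis.
function makeMetricsReceiver(handlers) {
  return (msg) => {
    const handler = msg.type === "batch" ? handlers.onBatchEnd : handlers.onEpochEnd;
    if (handler) return handler(msg.index, msg.logs);
  };
}
```

Sending plain objects keeps the worker free of any DOM dependency; all drawing stays on the main thread.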

Code Block 4: Communicate visualization information between main and worker threads using Comlink
class VisionModelWorker {
  // ... constructor like before

  async create(saveToDisk = false) {
    this.model = tf.sequential();
    // ... create the model architecture
    if (saveToDisk) {
      // ...
    }
  }

  // ... other methods
}
Code Block 5: Visualize Model Summary on the main thread
Figure 3: Model Summary visualized in DOM by tfjs-vis library
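A `tf.LayersModel` living in the worker cannot usefully cross the thread boundary, since structured cloning does not preserve class methods, so the summary has to travel as plain data. A minimal sketch, assuming the worker reads the standard Layer members `name`, `getClassName()`, `outputShape` and `countParams()`: the worker serializes one row per layer, and the main thread turns the rows into a `{ headers, values }` object, the input shape used by `tfvis.render.table`. The function names are hypothetical.

```javascript
// Worker side: one plain, structured-cloneable row per layer.
function serializeSummary(model) {
  const rows = model.layers.map((layer) => ({
    name: layer.name,
    type: layer.getClassName(),
    outputShape: JSON.stringify(layer.outputShape), // e.g. "[null,26,26,8]"
    params: layer.countParams(),
  }));
  const totalParams = rows.reduce((sum, r) => sum + r.params, 0);
  return { rows, totalParams };
}

// Main-thread side: table data with a header row and a final "Total" row.
function toTableData({ rows, totalParams }) {
  return {
    headers: ["Layer (type)", "Output shape", "Param #"],
    values: [
      ...rows.map((r) => [`${r.name} (${r.type})`, r.outputShape, r.params]),
      ["Total", "", totalParams],
    ],
  };
}
```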
Figure 4: Screencast showing GPU utilization while training

Use the trained model — Inference

  1. Provide an HTML5 canvas for the user to draw a digit
  2. Capture the canvas output as a 3D tensor
  3. Show the prediction and its probability on the UI

User Input via Canvas

Code Block 6: Creating a canvas that lets users draw anything

Canvas Image to Tensor

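// On Main Thread
const rawImage = document.getElementById(imgPlaceholderId);

async function captureImage(e) {
  rawImage.src = self.canvas.toDataURL('image/png');
  // Wait for the <img> to decode the new data URL before reading its pixels
  await rawImage.decode();
  // [height, width, 1] tensor -> plain nested array, which can cross threads
  const pixels = tf.browser.fromPixels(rawImage, 1);
  const rawImgArr = pixels.arraySync();
  pixels.dispose(); // free the tensor's memory
  const result = await visionModelWorker.predict(rawImgArr);
}

// In Worker
function predict(rawImgArr) {
  const x = tf.tensor3d(rawImgArr);
  // ... predict logic
}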
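The pixels could also be read straight from the canvas with `getImageData`, skipping the data-URL round trip. The sketch below (the helper name is hypothetical) converts raw RGBA bytes into the same `[height][width][1]` nested-array layout that `fromPixels(image, 1).arraySync()` returns, keeping the first (red) channel of each pixel, which is our understanding of how `fromPixels` handles `numChannels = 1`. On a transparent canvas, untouched pixels are (0, 0, 0, 0), so strokes only show up in this channel if their colour has a non-zero red component.

```javascript
// Convert RGBA bytes (e.g. ctx.getImageData(0, 0, w, h).data) into a
// [height][width][1] nested array holding each pixel's red channel.
function rgbaToSingleChannel(rgba, width, height) {
  if (rgba.length !== width * height * 4) {
    throw new Error("Expected width * height * 4 RGBA bytes");
  }
  const out = [];
  for (let y = 0; y < height; y++) {
    const row = [];
    for (let x = 0; x < width; x++) {
      // Each pixel occupies 4 consecutive bytes: R, G, B, A.
      row.push([rgba[(y * width + x) * 4]]);
    }
    out.push(row);
  }
  return out;
}
```

The result is a plain array, so it can be handed to the worker exactly like `rawImgArr`.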

Finale: Show Model Output

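// tf.image.resizeBilinear expects the new size as [height, width]
const resized = tf.image.resizeBilinear(x, [imgHeight, imgWidth]);
// Add a batch dimension: [h, w, 1] -> [1, h, w, 1]
const tensor = resized.expandDims(0);
const probabilities = model.predict(tensor).arraySync();
// Most likely digit and its probability for the single image in the batch
const [prediction, confidence] = [
  tf.argMax(probabilities, 1).arraySync(),
  tf.max(probabilities, 1).arraySync()
];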
Figure 5: Recognizing a user's handwritten digit in action
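Because `probabilities` is already a plain array after `arraySync()`, the last step can also be done in plain JavaScript, without the two extra tensors that `tf.argMax` and `tf.max` allocate (and that the snippet never disposes). A minimal sketch, with a hypothetical helper name, for the `[1, numClasses]` array that a single-image prediction produces:

```javascript
// Return the most likely class and its probability for a one-image batch.
function topPrediction(probabilities) {
  const row = probabilities[0]; // the only image in the batch
  let best = 0;
  for (let i = 1; i < row.length; i++) {
    // Strict ">" keeps the first index when probabilities tie.
    if (row[i] > row[best]) best = i;
  }
  return { prediction: best, confidence: row[best] };
}
```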




Applied Deep Learning Engineer | LinkedIn

Nitin Pasumarthy
