Multithreaded Machine Learning Training & Inference in the Browser using TensorFlow.js & Comlink.js

Web development has come a long way: from static visualizations to complex single page apps (SPAs), and on to the even more recent Progressive Web Apps (PWAs). SPAs brought complex business logic into web apps, and PWAs took it a step further by working without an active internet connection, like native Android / iOS / desktop apps. In this blog, I want to show you how we can push the limits of a web app even further by training a deep convolutional neural network on the GPU entirely in the browser, while still maintaining 60 fps. But before that, let us understand why something like this is useful.

Figure 0: Shows the evolution of web apps and highlights the intent of this blog in red boxes: PWA with TFJS

Machine Learning (ML) is used when the code needs to adapt to its input, or when a problem is very hard to solve with traditional algorithms. ML in the browser, or on the edge in general, has several advantages, such as:

  1. lower latency, as the user’s raw data need not travel over the network to fetch the model’s prediction from a server
  2. better privacy, as the user’s raw data never has to leave the device for a prediction
  3. lower cost for the company, as some of the server’s computation moves to the end user’s device

However, one has to be mindful of what is offloaded onto clients, as they have much less compute and memory than servers. On top of that, since JavaScript runs on a single main thread, that thread must stay free for user interactions and rendering, so all expensive ML computation should be offloaded to worker thread(s). This is required for a smooth user experience. In this blog, we will see how to train a model and make predictions with it entirely in the browser’s worker threads.

The task of our model is to recognize handwritten digits from 0 to 9 from raw image pixels. Something like this:

Figure 1: Eventually we need the machine to automatically figure out the body of digitRecognizer function

We could attempt to design the digitRecognizer function by hand, defining complex math equations and approximation algorithms for each of the 10 digits. Luckily, modern ML algorithms can look at lots of examples of each digit and figure out those complex equations for us. We call this step Training; it is computationally expensive and done rarely. Once the algorithm is trained well on this task, we can call it with our own handwriting to test it. This step is called Inference, and it can run in real time on the user’s input.


Training, or supervised learning to be more precise, generally consists of four steps:

  1. Prepare the training data
  2. Create a model architecture
  3. Train the model
  4. Analyze the results and repeat steps 1–3

Prepare the data

The training data, handwritten digits and their labels, can be downloaded from its creators’ website (the MNIST database). Gathering the right, unbiased training data is often challenging and must be done with great care, as ML algorithms are known to find shortcuts that achieve high accuracy without learning the underlying concept. Check out my posts Design effective classes in ML Classification Algorithms and Encoding fixed length high cardinality non-numeric columns for a ML algorithm, where I share more on this topic.

Because we train in the browser, we cannot download tens of thousands of individual image files: unlike usual training on a backend server, where data is read from local disk, every file would cost a network round trip, which is much slower.

Luckily, the MNIST images come packed in a single binary file per split (with the labels in a companion file), and the website above documents how to decode each image. With modern JavaScript features like ArrayBuffer, typed arrays and DataView, we can easily decode this raw binary format.

Code Block 1: Code to parse MNIST array buffer as specified in

As described in the dataset specification, the first 16 bytes of the images file hold metadata: a magic number, the number of images, and each image’s height and width. After that, each byte is a grayscale pixel value between 0 and 255, which we read with the DataView class and store in a Uint8Array. This function is then invoked to create the training images and labels, and the test images and labels. This class is later instantiated on worker threads for training.
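For readers who want to follow along, here is a minimal sketch of such a parser. It assumes the standard IDX layout published with MNIST (a 4-byte magic number, big-endian 32-bit counts, then raw bytes) and that the files were decompressed first if they were downloaded gzipped. The names parseImages, parseLabels and imageAt are hypothetical; this is not the author’s parseBuffer().

```javascript
// Minimal MNIST (IDX) parser sketch. Header integers are big-endian, which is
// also DataView's default byte order. Assumed layouts:
//   images file: magic 0x00000803, count, rows, cols, then count*rows*cols pixel bytes
//   labels file: magic 0x00000801, count, then count label bytes (0-9)
const IMAGES_MAGIC = 0x00000803;
const LABELS_MAGIC = 0x00000801;

function parseImages(buffer) {
  const view = new DataView(buffer);
  if (view.getUint32(0) !== IMAGES_MAGIC) throw new Error('Not an IDX images file');
  const count = view.getUint32(4);
  const rows = view.getUint32(8);
  const cols = view.getUint32(12);
  // Pixel bytes start right after the 16-byte header; this is a view, not a copy
  const pixels = new Uint8Array(buffer, 16, count * rows * cols);
  return { count, rows, cols, pixels };
}

function parseLabels(buffer) {
  const view = new DataView(buffer);
  if (view.getUint32(0) !== LABELS_MAGIC) throw new Error('Not an IDX labels file');
  const count = view.getUint32(4);
  // Label bytes start right after the 8-byte header
  return { count, labels: new Uint8Array(buffer, 8, count) };
}

// Returns the pixels of image i (row-major, rows*cols bytes) without copying
function imageAt({ rows, cols, pixels }, i) {
  const size = rows * cols;
  return pixels.subarray(i * size, (i + 1) * size);
}
```

In the browser the ArrayBuffer would come from something like `await (await fetch(url)).arrayBuffer()`, where `url` points at the decompressed file.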

Create a model architecture

In a worker.js file, we create and train a deep convolutional neural network model using TensorFlow.js. As a model needs training data, we encapsulate both the model and its data in a class, say VisionModelWorker.

Code Block 2: A worker class that encapsulates both the model and training data
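The model-building part of such a class could look roughly like the sketch below. It is not the author’s code: `tf` is passed in as a parameter (for example the global that the tfjs script defines inside the worker), and the 3x3 kernels, 2x2 pooling, ReLU activations, default Adam settings and categorical cross-entropy loss are assumptions chosen to be consistent with the shapes and parameter counts in the summary that follows.

```javascript
// Sketch of the conv -> pool -> conv -> pool -> flatten -> dense -> dense stack,
// using the public tfjs Layers API. `tf` is injected so the function has no import.
// Kernel sizes, activations, optimizer settings and loss are assumptions.
function createModel(tf) {
  const model = tf.sequential();
  // 28x28x1 -> 26x26x8 (3x3 kernel, no padding)
  model.add(tf.layers.conv2d({ inputShape: [28, 28, 1], kernelSize: 3, filters: 8, activation: 'relu' }));
  // 26x26x8 -> 13x13x8
  model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
  // 13x13x8 -> 11x11x16
  model.add(tf.layers.conv2d({ kernelSize: 3, filters: 16, activation: 'relu' }));
  // 11x11x16 -> 5x5x16
  model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
  // 5x5x16 -> 400
  model.add(tf.layers.flatten());
  model.add(tf.layers.dense({ units: 128, activation: 'relu' }));
  // softmax turns the 10 logits into a probability distribution
  model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));
  model.compile({
    optimizer: tf.train.adam(),
    loss: 'categoricalCrossentropy', // expects one-hot labels
    metrics: ['accuracy'],
  });
  return model;
}
```

Injecting `tf` also lets the same function run against @tensorflow/tfjs-node when experimenting outside the browser.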

getData() downloads and parses the train and test datasets using the parseBuffer() method we created earlier. The create() method builds the model architecture with the Sequential API, similar to tf.keras in Python. This model has two convolutional layers, each followed by a max pooling layer, then a flatten layer and two dense layers. The last (output) dense layer also converts raw logits into a probability distribution using softmax. model.summary() prints this information in the browser console:

Layer (type)                                  Output shape       Param #
conv2d_Conv2D1 (Conv2D)                       [null,26,26,8]     80
max_pooling2d_MaxPooling2D1 (MaxPooling2D)    [null,13,13,8]     0
conv2d_Conv2D2 (Conv2D)                       [null,11,11,16]    1168
max_pooling2d_MaxPooling2D2 (MaxPooling2D)    [null,5,5,16]      0
flatten_Flatten1 (Flatten)                    [null,400]         0
dense_Dense1 (Dense)                          [null,128]         51328
dense_Dense2 (Dense)                          [null,10]          1290
Total params: 53866
Trainable params: 53866
Non-trainable params: 0
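As a sanity check on the table, the output shapes and the “Param #” column can be recomputed by hand. The sketch assumes 3x3 kernels without padding and 2x2 pooling with stride 2, which match the shapes and counts shown; every trainable layer has one bias per filter or unit on top of its weights.

```javascript
// Recomputes the summary's output shapes and parameter counts by hand.
// Assumes 3x3 "valid" convolutions and 2x2 pooling with stride 2.
const convOut = (size, kernel) => size - kernel + 1;     // no padding
const poolOut = (size, pool) => Math.floor(size / pool); // stride == pool size
const convParams = (kernel, inChannels, filters) => (kernel * kernel * inChannels + 1) * filters; // +1 bias per filter
const denseParams = (inputs, units) => (inputs + 1) * units; // +1 bias per unit

let side = 28;
side = poolOut(convOut(side, 3), 2); // 28 -> 26 -> 13
side = poolOut(convOut(side, 3), 2); // 13 -> 11 -> 5
const flattened = side * side * 16;  // 5 * 5 * 16 = 400

const params = {
  conv2d_Conv2D1: convParams(3, 1, 8),       // 80
  conv2d_Conv2D2: convParams(3, 8, 16),      // 1168
  dense_Dense1: denseParams(flattened, 128), // 51328
  dense_Dense2: denseParams(128, 10),        // 1290
};
const total = Object.values(params).reduce((sum, n) => sum + n, 0);
console.log(flattened, total); // 400 53866
```

Note that almost 95% of the parameters sit in the first dense layer, which is typical for small CNNs that flatten into a wide dense layer.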

During training, the model learns optimal values for these 53,866 parameters so that it can recognize “any” handwritten digit with high confidence. However, the type and number of layers, their sizes, and the arguments for each are all chosen by running multiple experiments. The values of these hyperparameters differ for every problem and dataset, and tuning them is usually the hardest part of deep learning.

Train the Model

Each handwritten digit image, a 28x28x1 tensor, is passed through the network, undergoes several transformations, and eventually becomes a vector of 10 values. Each value is the probability of one of the digits 0–9. This is compared with the actual label to compute a number called the loss (or error), i.e. how far off the model is from the correct answers. An optimizer, Adam in this case, uses the gradient of this loss to adjust all 53,866 parameters so that the loss is lower the next time the same image passes through the network. This is repeated for all the images in the training set, over several passes, until the average loss is as low as possible. There are many more ML concepts not covered here; if you are interested in the theory behind them, check out my Summary of Deep Learning Concepts blog.
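The text does not name the loss function. The sketch below assumes categorical cross-entropy, the usual pairing with a softmax output, to show in plain JavaScript how 10 raw scores become probabilities and how the loss measures how far off a prediction is.

```javascript
// Softmax turns the network's 10 raw scores (logits) into probabilities.
// Subtracting the largest logit first avoids overflow in Math.exp.
function softmax(logits) {
  const max = Math.max(...logits);
  const exps = logits.map((z) => Math.exp(z - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// Categorical cross-entropy for one image whose true digit is `label`:
// the negative log of the probability the model gave to the correct digit.
function crossEntropy(probabilities, label) {
  return -Math.log(probabilities[label]);
}

// A network with no preference (all logits equal) gives every digit 0.1,
// so the loss is -ln(0.1) = ln(10).
const uniform = softmax(new Array(10).fill(0));
console.log(crossEntropy(uniform, 7).toFixed(3)); // 2.303

// Raising the logit of the correct digit lowers the loss; this is the
// direction in which the optimizer nudges the parameters.
const confident = softmax([0, 0, 0, 0, 0, 0, 0, 5, 0, 0]);
console.log(crossEntropy(confident, 7) < crossEntropy(uniform, 7)); // true
```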

Code Block 3: The train method of VisionModelWorker class
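A training call in this spirit could look like the sketch below; it is a hedged stand-in, not the author’s train method. The data object’s property names (trainXs, trainYs, testXs, testYs) and shapes are assumptions, batchSize = 512 is illustrative, and epochs = 5 matches the run reported in the Inference section. The labels are assumed to be one-hot encoded (tf.oneHot can produce them), as categorical cross-entropy requires. optimizerSteps is a hypothetical helper.

```javascript
// Sketch of a training call in the spirit of VisionModelWorker.train().
// `model` is a compiled tf.LayersModel; `data` holds tensors (shapes assumed):
//   trainXs / testXs: [N, 28, 28, 1] images, trainYs / testYs: [N, 10] one-hot labels.
// batchSize = 512 is an illustrative assumption.
async function train(model, data, { batchSize = 512, epochs = 5, callbacks } = {}) {
  return model.fit(data.trainXs, data.trainYs, {
    batchSize,                                  // images per parameter update
    epochs,                                     // full passes over the training set
    validationData: [data.testXs, data.testYs], // evaluated at the end of each epoch
    shuffle: true,                              // reshuffle training data every epoch
    callbacks,                                  // e.g. { onBatchEnd, onEpochEnd } from the main thread
  });
}

// How many optimizer updates that amounts to: one per mini-batch, per epoch
// (the last batch of an epoch may be smaller than batchSize).
function optimizerSteps(numExamples, batchSize, epochs) {
  return Math.ceil(numExamples / batchSize) * epochs;
}

console.log(optimizerSteps(60000, 512, 5)); // 590
```

With MNIST’s 60,000 training images and these settings, fit() would perform 118 updates per epoch, 590 in total, instead of 300,000 if it updated after every single image.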

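As shown in Code Block 3, the training images are fed to the model in mini-batches of batchSize images, with one parameter update per batch, and the whole training set is passed through the network epochs times. Updating once per batch instead of once per image is what makes training fast. The architecture and training code, along with some additional methods we will cover soon, can be found at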

Analyze the Results

To track the progress of the model while it is training, its loss and accuracy curves can be plotted over time in the browser. The tfjs-vis library makes this very easy by hooking into the callbacks exposed by

Figure 2: Loss and Accuracy curves during training visualized using tfjs-vis library

However, since a worker cannot access the DOM, tfjs-vis callback handlers cannot run in it directly. So we need a way to define the tfjs-vis functions on the main thread and invoke them with the model’s loss and accuracy metrics from the worker thread.

This is certainly doable with postMessage(), but the code quickly gets cluttered with message handling. This is where the Comlink library comes in very handy.

Code Block 4: Communicate visualization information between main and worker threads using Comlink

The VisionModel and VisionModelWorker classes have a one-to-one mapping of methods. When the user presses the Train button, the run() method of VisionModel is called on the main thread; it creates the tfjs-vis callback handlers and passes them to the run() method of VisionModelWorker as proxied callbacks. After every batch and epoch of training, tfjs in the worker invokes these handlers, which run on the main thread. And that’s it! There is no need for postMessage() on either end; Comlink takes care of everything. Isn’t that neat?
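For comparison, here is a minimal sketch of the hand-rolled postMessage() plumbing that Comlink spares us. It uses a MessageChannel so both “threads” fit in one file; in a real app one end would be a Worker object and the other the worker’s global scope. The message shapes and function names are hypothetical. With Comlink, all of this collapses into Comlink.expose() in the worker, Comlink.wrap() on the main thread, and Comlink.proxy() around the callbacks.

```javascript
// Hand-rolled postMessage plumbing (what Comlink hides), sketched with a
// MessageChannel. Message shapes ({ type: 'train' | 'epochEnd' | 'done' }) are hypothetical.

// Worker side: it cannot call the main thread's callback, so it posts messages.
function serveTraining(port) {
  port.onmessage = (event) => {
    const msg = event.data;
    if (msg.type !== 'train') return;
    for (let epoch = 0; epoch < msg.epochs; epoch++) {
      // Placeholder metrics; a real worker would forward tfjs's `logs` object
      port.postMessage({ type: 'epochEnd', epoch, logs: { loss: 1 / (epoch + 1) } });
    }
    port.postMessage({ type: 'done' });
  };
}

// Main side: every message type must be routed back to the right callback by hand.
function trainRemotely(port, epochs, onEpochEnd) {
  return new Promise((resolve) => {
    port.onmessage = (event) => {
      const msg = event.data;
      if (msg.type === 'epochEnd') onEpochEnd(msg.epoch, msg.logs);
      else if (msg.type === 'done') resolve();
    };
    port.postMessage({ type: 'train', epochs });
  });
}

async function demo() {
  const { port1, port2 } = new MessageChannel(); // global in browsers and Node 15+
  serveTraining(port2);
  const seen = [];
  await trainRemotely(port1, 3, (epoch, logs) => seen.push([epoch, logs.loss]));
  port1.close(); // closing the ports lets Node's event loop exit
  port2.close();
  return seen;
}
```

Every new method or callback would need another message type and another branch on both sides, which is exactly the clutter the paragraph above describes.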

We can save the trained model to localStorage or IndexedDB, or download it to your machine, with model.save():

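class VisionModelWorker {
  // ... constructor like before
  async create(saveToDisk = false) {
    this.model = tf.sequential();
    // ... create the model architecture
    if (saveToDisk) {
      // Persist to the browser's IndexedDB so the main thread can load it later.
      // 'vision-model' is an illustrative (hypothetical) key; tfjs also accepts
      // 'localstorage://name' and 'downloads://name' URLs here.
      await this.model.save('indexeddb://vision-model');
    }
  }
  // ... other methods
}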

To display the model’s architecture in the DOM with tfjs-vis, we use a different technique than callbacks: a tf.LayersModel instance cannot be structured-cloned, which postMessage(), and therefore Comlink, requires of the values it sends. Instead, on the main thread we load the model that the worker saved to IndexedDB, which gives us a model instance, and pass it to tfjs-vis for visualization.

Code Block 5: Visualize Model Summary on the main thread
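A sketch of that main-thread step follows, assuming the worker saved the model under the hypothetical key 'indexeddb://vision-model'. tf.loadLayersModel() and tfvis.show.modelSummary() are the tfjs and tfjs-vis calls for loading a saved model and rendering its summary; the libraries are passed in as parameters (e.g. the tf and tfvis globals from their script tags), and the visor tab name is an arbitrary choice.

```javascript
// Main-thread sketch: load the model the worker saved to IndexedDB and hand it
// to tfjs-vis. The default key is hypothetical; it must match the key the
// worker passed to model.save().
async function showModelSummary(tf, tfvis, modelUrl = 'indexeddb://vision-model') {
  const model = await tf.loadLayersModel(modelUrl);        // rebuilds a LayersModel instance
  const surface = { name: 'Model Summary', tab: 'Model' }; // where the visor draws it
  await tfvis.show.modelSummary(surface, model);
  return model;
}
```

Only the architecture and weights travel through IndexedDB, so the main thread never needs the training data.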

And voila, when the user presses the Train button, we can also see the model summary!

Figure 3: Model Summary visualized in DOM by tfjs-vis library

The model summary and loss curves are crucial for identifying bugs in the training code. Based on the loss and accuracy curves, we can plan our next training iteration. In my experience, this is one of the most challenging and time-consuming steps in ML, and courses like Structuring Machine Learning Projects really helped me organize my approach. The tfjs library can use the machine’s GPU (through WebGL) or WebAssembly, when available, to train faster. I observed a 50x speedup by training on the GPU!

Figure 4: Screencast that shows GPU utilization while training

And we are done! We have trained a model on the GPU entirely in browser worker threads and visualized the training on the main thread using Comlink. The code for the main thread is available here, and for the worker here.

Use the trained model — Inference

We now have all the pieces in place for the digitRecognizer function mentioned in the introduction. I trained the model for 5 epochs and obtained 88.6% accuracy on the validation dataset. One can keep tuning it to reach higher accuracy. On this trained model, we can call model.predict() with our own handwritten digit to complete our initial goal!

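Our worker’s predict() receives the image as a 3D array (height × width × 1 channel), while model.predict() itself expects a batch of images as a 4D tensor and returns the probability of each digit for each image. We simply take the digit with the highest probability as the model’s answer. Here is what we’ll do: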

  1. Provide an HTML5 canvas for the user to draw a digit
  2. Capture the canvas output as a 3D tensor
  3. Show the prediction and its probability in the UI

User Input via Canvas

Users can draw their input digit on an HTML5 canvas. This idea is inspired by the course Browser-based Models with TensorFlow.js.

Code Block 6: Shows how to create a canvas which lets users draw anything
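The drawing logic can be sketched as a small set of handlers around a 2D canvas context. The function names, the 15 px round white strokes (assumed so that a transparent-background canvas resembles MNIST’s light-on-dark digits when read as a single channel) and the mouse-only wiring are assumptions; a touch-friendly version would listen to pointer events instead.

```javascript
// Sketch of the canvas drawing logic: while the mouse button is down, draw a
// straight segment from the previous pointer position to the current one.
// `ctx` is a CanvasRenderingContext2D; stroke settings are assumptions.
function makeDrawingHandlers(ctx) {
  ctx.lineWidth = 15;      // thick strokes so the digit survives downscaling to 28x28
  ctx.lineCap = 'round';
  ctx.strokeStyle = 'white';
  let prev = null;         // last point while drawing; null when the button is up
  return {
    down(x, y) { prev = { x, y }; },
    move(x, y) {
      if (!prev) return;   // ignore moves when not drawing
      ctx.beginPath();
      ctx.moveTo(prev.x, prev.y);
      ctx.lineTo(x, y);
      ctx.stroke();
      prev = { x, y };
    },
    up() { prev = null; },
    clear() { ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height); },
  };
}

// Browser wiring (only works in a page): offsetX/offsetY are relative to the canvas.
function attachDrawing(canvas) {
  const handlers = makeDrawingHandlers(canvas.getContext('2d'));
  canvas.addEventListener('mousedown', (e) => handlers.down(e.offsetX, e.offsetY));
  canvas.addEventListener('mousemove', (e) => handlers.move(e.offsetX, e.offsetY));
  canvas.addEventListener('mouseup', () => handlers.up());
  canvas.addEventListener('mouseleave', () => handlers.up());
  return handlers; // e.g. hook handlers.clear() up to a Clear button
}
```

Keeping the drawing state in a closure separate from the DOM wiring makes the logic easy to test with a fake context.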

The main idea is to draw a line from the previous mouse position to the current one while the mouse button is held down, tracing the user’s handwritten digit on the canvas. A button can be provided to clear the canvas. All this is wrapped in a class, say ModelInference, and can be found at

Canvas Image to Tensor

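// On Main Thread
const rawImage = document.getElementById(imgPlaceholderId);
async function captureImage(canvas) {
  // Copy the drawing into the <img> placeholder and wait for it to decode
  const loaded = new Promise((resolve) => { rawImage.onload = resolve; });
  rawImage.src = canvas.toDataURL('image/png');
  await loaded;
  // Read one (grayscale) channel as a plain [height][width][1] array, then free the tensor
  const pixels = tf.browser.fromPixels(rawImage, 1);
  const rawImgArr = pixels.arraySync();
  pixels.dispose();
  // visionModelWorker is the Comlink-wrapped worker, so predict() returns a Promise
  return visionModelWorker.predict(rawImgArr);
}
// In Worker
function predict(rawImgArr) {
  const x = tf.tensor3d(rawImgArr);
  // ... predict logic
}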

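Using a couple of tfjs helper functions, it is quite easy to capture the canvas data on the main thread (in ModelInference) and turn it into a tensor in the worker (in VisionModelWorker). Because VisionModelWorker was wrapped with Comlink earlier, we can call predict() with rawImgArr as if it were a local (async) function. A plain nested array is sent rather than a tensor because, like models, tensors cannot be meaningfully passed through postMessage().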

Finale: Show Model Output

const resized = tf.image.resizeBilinear(x, [imgHeight, imgWidth]); // size is [newHeight, newWidth]
const tensor = resized.expandDims(0); // [h, w, 1] -> [1, h, w, 1], a batch of one
const probabilities = model.predict(tensor).arraySync(); // shape [1, 10]

Since model.predict() expects a batch of images, we turn our single image into a batch of one with expandDims(0). Once we have a tensor, we can only use tfjs operators to transform it. probabilities is a nested array of shape [1, 10]: its single row holds the probability of each digit from 0 to 9.

To extract the most likely digit and its probability, tf.argMax and tf.max can be used like this:

// argMax/max along axis 1 (the 10 digits); each result is a one-element array
const [prediction, confidence] = [
  tf.argMax(probabilities, 1).arraySync(),
  tf.max(probabilities, 1).arraySync()
];

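Since arraySync() has already brought the probabilities back as a plain nested array, the same answer can also be computed without tfjs. This sketch (topPrediction is a hypothetical helper) avoids allocating the extra tensors that tf.argMax and tf.max create, which would otherwise need dispose() or tf.tidy() to be freed.

```javascript
// Plain-JS alternative to tf.argMax / tf.max once the output is a nested array.
// `probabilities` has shape [1, 10]: one row of digit probabilities.
function topPrediction(probabilities) {
  const row = probabilities[0];
  let digit = 0;
  for (let i = 1; i < row.length; i++) {
    if (row[i] > row[digit]) digit = i; // keep the index of the largest value
  }
  return { digit, confidence: row[digit] };
}

console.log(topPrediction([[0.01, 0.02, 0.05, 0.02, 0.01, 0.03, 0.01, 0.8, 0.03, 0.02]]));
// { digit: 7, confidence: 0.8 }
```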
Figure 5: Shows recognition of the user’s handwritten digit in action

And there it is! We designed a model and used it to predict handwritten digits entirely in the browser, using vanilla JavaScript on multiple threads!

Applied Deep Learning Engineer | LinkedIn
