Today, let's train a full digit recognizer on the MNIST dataset. Imports as always:
from fastai.vision.all import *
matplotlib.rc('image', cmap='Greys')
path = untar_data(URLs.MNIST)
Path.BASE_PATH = path
Let's see what we have in the archive:
path.ls()
(#2) [Path('training'),Path('testing')]
[len((path/'training'/str(i)).ls()) for i in range(10)]
[5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]
[len((path/'testing'/str(i)).ls()) for i in range(10)]
[980, 1135, 1032, 1010, 982, 892, 958, 1028, 974, 1009]
So we have around 6k images of each digit in the training set and around 1k of each in the validation (testing) set. Let's build a digit classifier!
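Here is a quick sanity check on those numbers in plain Python. It uses only the counts printed above; `summarize` is just a throwaway helper name:

```python
# Per-digit image counts, copied from the directory listings above.
train_counts = [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]
test_counts = [980, 1135, 1032, 1010, 982, 892, 958, 1028, 974, 1009]


def summarize(counts):
    """Return (total images, rarest digit, most common digit)."""
    total = sum(counts)
    rarest = min(range(len(counts)), key=counts.__getitem__)
    commonest = max(range(len(counts)), key=counts.__getitem__)
    return total, rarest, commonest


for name, counts in [("training", train_counts), ("testing", test_counts)]:
    total, rarest, commonest = summarize(counts)
    print(f"{name}: {total} images, rarest digit {rarest}, most common digit {commonest}")
```

Both splits add up to round numbers: 60,000 training images and 10,000 testing images. In both, 5 is the rarest digit and 1 is the most common. Each class is roughly 9% to 11% of its split, so the classes are close to balanced.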
First we need to define the DataBlock (i.e. tell the framework how the dataset is structured):
digits = DataBlock(
blocks=(ImageBlock, CategoryBlock),
get_items=get_image_files,
splitter=GrandparentSplitter(train_name='training', valid_name='testing'),
get_y=parent_label)
dls = digits.dataloaders(path)
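As far as I understand, `parent_label` takes the name of the folder an image sits in as its label. `GrandparentSplitter` then decides train vs. validation from the folder one level above that. Below is a plain-Python sketch of that logic on made-up file names. `parent_label_sketch` and `grandparent_split_sketch` are hypothetical stand-ins, not fastai's code:

```python
from pathlib import PurePosixPath


def parent_label_sketch(path):
    """Label = name of the folder the image sits in, e.g. '3' for training/3/10.png."""
    return PurePosixPath(path).parent.name


def grandparent_split_sketch(paths, train_name="training", valid_name="testing"):
    """Return (train_indices, valid_indices) based on the grandparent folder name."""
    train_idx, valid_idx = [], []
    for i, p in enumerate(paths):
        grandparent = PurePosixPath(p).parent.parent.name
        if grandparent == train_name:
            train_idx.append(i)
        elif grandparent == valid_name:
            valid_idx.append(i)
    return train_idx, valid_idx


files = ["training/3/10.png", "training/7/42.png", "testing/3/5.png"]  # made-up names
print([parent_label_sketch(f) for f in files])  # ['3', '7', '3']
print(grandparent_split_sketch(files))          # ([0, 1], [2])
```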
Now we can train a resnet18 network (I've no idea what's in it, but it was used in the book previously)
learn = cnn_learner(dls, resnet18, metrics=error_rate, pretrained=True)
learn.fine_tune(5)
learn.recorder.plot_loss(skip_start=0)
Wow, we're getting a tiny error rate! Let's see where it got confused.
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_top_losses(5, nrows=1)
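`plot_top_losses` shows the images on which the model's loss was highest. Assume the loss is cross-entropy, which fastai picks by default for a single-label `CategoryBlock`. Then those are the images where the model was confidently wrong or very unsure. Below is a rough plain-Python sketch of that ranking on toy 3-class logits. `cross_entropy` and `top_losses` are illustrative helpers, not fastai's implementation:

```python
import math


def cross_entropy(logits, target):
    """Per-item cross-entropy, -log(softmax(logits)[target]), computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def top_losses(all_logits, targets, k):
    """(index, loss) pairs for the k items with the highest loss, largest first."""
    losses = [cross_entropy(l, t) for l, t in zip(all_logits, targets)]
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=True)
    return [(i, losses[i]) for i in order[:k]]


logits = [[4.0, 0.0, 0.0],  # confident and correct (target 0)
          [0.0, 3.0, 0.0],  # confident and wrong (target 2)
          [1.0, 1.0, 1.0]]  # completely unsure (target 1)
targets = [0, 2, 1]
for i, loss in top_losses(logits, targets, k=2):
    print(i, round(loss, 2))
```

The confidently wrong item 1 comes out on top with a loss of about 3.1. Next is the coin-flip item 2 at log(3), about 1.1. The confident correct item 0 stays near zero.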
I can't blame the network; I wouldn't necessarily have guessed these myself.
At this point we have a fully trained network ready to be deployed. Let's try to run it in the browser!
We can use onnx.js to try running the recognizer in the browser. To get there, let's export the model to ONNX, so we can use it in environments other than PyTorch. The export needs a dummy input. I'm not really sure about this point, but from what I can gather, the export works by running inference on an example input and recording which operations were executed (PyTorch calls this tracing). The first layer of the resnet is a convolution, which to me sounds like a kind of scanning over the image, so let's just feed it something valid and hope this is fine.
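To convince myself of the "run it and record what happens" idea, here's a toy tracer in plain Python. It illustrates tracing in general, not how torch.onnx is implemented; `Traced`, `toy_model` and `trace` are made-up names. The recorded tape holds the concrete shapes seen during the example run, which points at a practical consequence. By default, torch.onnx.export fixes the exported graph's input shape to that of the example, here (64, 3, 7, 7). The MNIST images coming out of our DataLoaders are 28x28 instead. The `dynamic_axes` argument of `torch.onnx.export` marks chosen dimensions, such as the batch size, as variable.

```python
class Traced:
    """Toy stand-in for a tensor: it only knows its shape and logs every op to a shared tape."""

    def __init__(self, shape, tape):
        self.shape = shape
        self.tape = tape

    def conv(self, out_channels, stride):
        # Pretend "same"-padded convolution: spatial size shrinks by the stride (rounded up).
        n, _, h, w = self.shape
        out = (n, out_channels, -(-h // stride), -(-w // stride))
        self.tape.append(("Conv", self.shape, out))
        return Traced(out, self.tape)

    def flatten(self):
        n, c, h, w = self.shape
        out = (n, c * h * w)
        self.tape.append(("Flatten", self.shape, out))
        return Traced(out, self.tape)


def toy_model(x):
    return x.conv(out_channels=8, stride=2).flatten()


def trace(model, example_shape):
    """Run the model once on an example input and return the recorded ops."""
    tape = []
    model(Traced(example_shape, tape))
    return tape


for op in trace(toy_model, (64, 3, 7, 7)):
    print(op)
```

The tape only knows about the shapes it actually saw. Tracing the same toy model with a (1, 3, 28, 28) input records a different set of shapes, which is why the example input should look like the real data.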
dummy_input = torch.rand(64, 3, 7, 7, device='cuda')
dummy_input.shape
torch.Size([64, 3, 7, 7])
onnx_model_path = "my_mnist.onnx"
torch.onnx.export(learn.model, dummy_input, onnx_model_path)
print(str(os.path.getsize(onnx_model_path) / 1024 / 1024.0) + "Mb")
44.725908279418945Mb
That's quite heavy. We'll have to see if we can do something about that later. For our use case even resnet18 is definitely overkill, especially since the training images (28x28) are way smaller than the 224x224 ImageNet images resnet was pretrained on, and we don't have colour (MNIST is greyscale; fastai just loads it as 3-channel RGB).
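Where do ~45 MB come from? Almost all of it should be weights stored as 32-bit floats, 4 bytes each. As a back-of-the-envelope check, here's a parameter count for the standard torchvision-style resnet18 with its 1000-class ImageNet head, assuming float32 storage. fastai replaces that head with its own, and the ONNX file also stores BatchNorm running statistics, so the exact number differs slightly. The helper names are my own:

```python
def conv(c_in, c_out, k):
    """Weights of a k x k convolution without bias (resnet convs have no bias)."""
    return c_in * c_out * k * k


def bn(c):
    """Learnable BatchNorm parameters: one scale and one shift per channel."""
    return 2 * c


def basic_block(c_in, c_out, stride):
    n = conv(c_in, c_out, 3) + bn(c_out) + conv(c_out, c_out, 3) + bn(c_out)
    if stride != 1 or c_in != c_out:  # 1x1 projection shortcut when the shape changes
        n += conv(c_in, c_out, 1) + bn(c_out)
    return n


def resnet18_params(num_classes=1000):
    n = conv(3, 64, 7) + bn(64)  # stem
    c_in = 64
    for c_out, stride in [(64, 1), (128, 2), (256, 2), (512, 2)]:
        n += basic_block(c_in, c_out, stride) + basic_block(c_out, c_out, 1)
        c_in = c_out
    n += 512 * num_classes + num_classes  # final fully connected layer
    return n


params = resnet18_params()
print(params)                                           # 11689512
print(f"{params * 4 / 1024 / 1024:.1f} MB as float32")  # 44.6 MB as float32
```

That's about 11.7 million parameters and 44.6 MB, the same ballpark as the exported file. Making the file smaller therefore mostly means fewer or smaller weights.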
Then I tried to run it using onnx.js, and loading the model failed with: TypeError: cannot resolve operator 'Shape' with opsets: ai.onnx v9.
I'll have to come back to this later. Some googling hints that I need to modify some layers in the network, but I'm not yet ready for that.
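Before touching any layers, it might help to see which operators the exported graph contains and which opset it was exported with, since the error mentions ai.onnx v9. This is a sketch assuming the `onnx` Python package is installed. `onnx.load`, each node's `op_type` and the model's `opset_import` are real parts of its API, while `onnx_ops` and `op_histogram` are hypothetical helper names:

```python
from collections import Counter


def op_histogram(op_types):
    """Count how often each operator type occurs, most frequent first."""
    return Counter(op_types).most_common()


def onnx_ops(model_path):
    """Return (operator types in graph order, {opset domain: version})."""
    import onnx  # imported lazily: needs `pip install onnx`

    model = onnx.load(model_path)
    op_types = [node.op_type for node in model.graph.node]
    # An empty domain string means the default "ai.onnx" operator set.
    opsets = {imp.domain or "ai.onnx": imp.version for imp in model.opset_import}
    return op_types, opsets
```

To use it, run `ops, opsets = onnx_ops("my_mnist.onnx")`, then `print(opsets)` and `print(op_histogram(ops))`. That shows the declared opset and how many Shape nodes the graph has.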