Creating a Baseline Model for MNIST Dataset

pip install fastbook
import fastbook
from fastai.vision.all import *
from fastai import *
path = untar_data(URLs.MNIST_SAMPLE)
(path/"train").ls()
(#2) [Path('/root/.fastai/data/mnist_sample/train/7'),Path('/root/.fastai/data/mnist_sample/train/3')]
threes = (path/"train"/"3").ls().sorted()
sevens = (path/"train"/"7").ls().sorted()
im3_path = threes[1]
im3 = Image.open(im3_path)
im3
Output of the above code
array(im3)
Output of the above code
tensor(im3)
Output of the above code
pip install pandas
im_3t = tensor(im3)
df = pd.DataFrame(im_3t)
df.style.background_gradient("Greys")

Baseline Model

Basic Idea

What are Baseline Models and why are they important?

Model Creation Step

seven_tensors = [tensor(Image.open(i)) for i in sevens]
three_tensors = [tensor(Image.open(i)) for i in threes]
len(three_tensors), len(seven_tensors)
(6131, 6265)
show_image(three_tensors[1]);
Image of the tensor
Image of a rank-3 tensor
stacked_sevens = torch.stack(seven_tensors).float()/255
stacked_threes = torch.stack(three_tensors).float()/255
stacked_threes.shape
torch.Size([6131, 28, 28])
stacked_threes.ndim
3
mean3 = stacked_threes.mean(0)
show_image(mean3)
Image of the ideal 3
mean7 = stacked_sevens.mean(0)
show_image(mean7)
Image of the ideal 7
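Formula for Mean Absolute Error: $\mathrm{MAE} = \frac{1}{n}\sum_{i=1}^{n} \lvert a_i - b_i \rvert$
Formula for Root Mean Squared Error: $\mathrm{RMSE} = \sqrt{\frac{1}{n}\sum_{i=1}^{n} (a_i - b_i)^2}$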
a_3 = stacked_threes[1]
show_image(a_3);
dist_3_abs = (a_3 - mean3).abs().mean() #MAD
dist_3_sqr = ((a_3 - mean3)**2).mean().sqrt() #RMSE
dist_3_abs, dist_3_sqr
(tensor(0.1114), tensor(0.2021))
dist_7_abs = (a_3 - mean7).abs().mean()
dist_7_sqr = ((a_3 - mean7)**2).mean().sqrt()
dist_7_abs, dist_7_sqr
(tensor(0.1586), tensor(0.3021))
F.l1_loss(a_3.float(), mean7), F.mse_loss(a_3, mean7).sqrt()
(tensor(0.1586), tensor(0.3021))
valid_3_tensor = torch.stack([tensor(Image.open(i)) for i in (path/"valid"/"3").ls()])
valid_3_tensor = valid_3_tensor.float()/255
show_image(valid_3_tensor[1])
2nd image in the valid set
valid_7_tensor = torch.stack([tensor(Image.open(i)) for i in (path/"valid"/"7").ls()])
valid_7_tensor = valid_7_tensor.float()/255
show_image(valid_7_tensor[0]);
1st image in the valid set
valid_3_tensor.shape, valid_7_tensor.shape
(torch.Size([1010, 28, 28]), torch.Size([1028, 28, 28]))
def mnist_dist(a, b):
    """Return the mean absolute difference between `a` and `b`.

    The element-wise absolute difference is averaged over the last two
    axes only, so any leading axes of the (broadcast) difference are kept.
    Passing a stack of images as `a` and a single image as `b` therefore
    yields one distance per image in the stack, while two single images
    yield a scalar tensor.
    """
    # Reduce over the last two axes (-1, -2) rather than over everything,
    # so a batch of inputs produces a batch of distances.
    return (a - b).abs().mean((-1, -2))
mnist_dist(a_3, mean3)
tensor(0.1114)
valid_3_dist = mnist_dist(valid_3_tensor, mean3)
valid_3_dist.shape
torch.Size([1010])
valid_3_dist
tensor([0.1117, 0.1295, 0.1168,  ..., 0.1506, 0.1380, 0.1483])
def is_3(x):
    """Return whether `x` is closer to the mean 3 than to the mean 7.

    Distances are computed with `mnist_dist` against the module-level
    `mean3` and `mean7` images. The result is a boolean tensor with one
    entry per distance that `mnist_dist` returns for `x`.

    The comparison is strict, so an input that is equally distant from
    both means is reported as False (i.e. not a 3).
    """
    return mnist_dist(x, mean3) < mnist_dist(x, mean7)
is_3(a_3)
tensor(True)
is_3(valid_3_tensor)
tensor([ True, False,  True,  ...,  True,  True,  True])
accuracy_3s = is_3(valid_3_tensor).float().mean()
accuracy_7s = (1 - is_3(valid_7_tensor).float()).mean()
accuracy_3s, accuracy_7s
(tensor(0.9168), tensor(0.9854))

--

--

Get the Medium app

A button that says 'Download on the App Store', and if clicked it will lead you to the iOS App store
A button that says 'Get it on, Google Play', and if clicked it will lead you to the Google Play store