by CM
Posted on March 28, 2020
In this article, we will explore PyTorch, an open-source machine learning library based on the Torch library. It is primarily developed by Facebook's AI Research lab, whereas TensorFlow, its best-known counterpart, is developed by Google. PyTorch can be used for applications such as computer vision and natural language processing (NLP). Like TensorFlow, it is free and open-source software. To explore PyTorch, we use the classic MNIST dataset of handwritten digits.
Let's jump right into the code. First, we import PyTorch and torchvision, as well as the Matplotlib library, which will help us visualize the images.
import torch
import torchvision
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
Second, we download the training and test data. The transforms.ToTensor() transform converts each image into a PyTorch tensor.
train = datasets.MNIST("", train=True, download=True,
                       transform=transforms.Compose([transforms.ToTensor()]))
test = datasets.MNIST("", train=False, download=True,
                      transform=transforms.Compose([transforms.ToTensor()]))
Downloading the data might take a minute, depending on your internet connection.
We then use PyTorch's data loading utility, at the heart of which is the torch.utils.data.DataLoader class. It wraps a dataset and serves it in mini-batches, here of 10 images each, shuffled in random order.
trainset = torch.utils.data.DataLoader(train, batch_size=10, shuffle=True)
testset = torch.utils.data.DataLoader(test, batch_size=10, shuffle=True)
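To see what the loader actually yields, here is a minimal sketch that does not need the download: it wraps random stand-in tensors (an assumption, not real MNIST data) in a TensorDataset and pulls a single batch. Each batch is a pair of an image tensor of shape [batch_size, 1, 28, 28] and a label tensor of shape [batch_size].

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for MNIST: 100 random "images" with 1 channel of 28x28 pixels
images = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))  # random digit labels 0..9

loader = DataLoader(TensorDataset(images, labels), batch_size=10, shuffle=True)

# Grab a single batch; this is what each iteration of "for data in trainset" receives
X, y = next(iter(loader))
print(X.shape)      # torch.Size([10, 1, 28, 28])
print(y.shape)      # torch.Size([10])
print(len(loader))  # 10 batches cover all 100 samples
```

With the real training set of 60,000 images and batch_size=10, len(trainset) would be 6,000 batches per epoch.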
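We can then have a look at the data and its respective x (image) and y (label) values by grabbing one batch from the training loader.
data = next(iter(trainset))  # one batch: (images, labels)
x, y = data[0][0], data[1][0]  # first image of the batch and its label
print(x, y)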
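Below we see the pixel values (x) and the label (y) of the first image in the batch. Since the loader shuffles the data, you will likely see a different digit than the 3 shown here.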
==========================
OUTPUT
==========================
tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0902, 0.7020, 1.0000, 0.9922, 0.6471,
0.1098, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0275, 0.3216, 0.9176, 0.9882, 0.9922, 0.8392, 0.9412,
0.8431, 0.6627, 0.5647, 0.1490, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.4588, 0.9882, 0.9882, 0.9882, 0.5608, 0.0745, 0.7725,
0.9882, 0.9922, 0.7686, 0.8431, 0.3804, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.9451, 0.9882, 0.9882, 0.7922, 0.0510, 0.3961, 0.9647,
0.7922, 0.5020, 0.0353, 0.3804, 0.9882, 0.2000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.4471, 1.0000, 0.9922, 0.6118, 0.0000, 0.0000, 0.3961, 0.3922,
0.0000, 0.0000, 0.0275, 0.6039, 0.9922, 0.4471, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.7373, 0.9922, 0.8392, 0.0745, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0275, 0.5882, 0.9882, 0.9882, 0.3451, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.8824, 0.9922, 0.4275, 0.0000, 0.0000, 0.0000, 0.0000, 0.1137,
0.3333, 0.6039, 0.9882, 0.9882, 0.6941, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.6863, 0.9922, 0.5216, 0.0000, 0.0000, 0.4471, 0.4431, 0.6235,
0.9882, 0.9922, 0.9882, 0.6941, 0.0118, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.8863, 0.9569, 0.8471, 0.7490, 0.9961, 0.9922, 0.9922,
0.9922, 0.8980, 0.9922, 0.6588, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.2196, 0.3294, 0.3294, 0.6980, 0.9882, 0.9882,
0.9882, 0.3098, 0.9882, 0.6588, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.1490, 0.8980, 0.9882, 0.9882,
0.9882, 0.6980, 0.9882, 0.7686, 0.0392, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.1490, 0.8824, 0.9922, 0.9882, 0.6431,
0.1569, 0.0157, 0.6980, 0.9882, 0.3059, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.4706, 0.9922, 0.9922, 0.8863, 0.1490, 0.0000,
0.0000, 0.0000, 0.6627, 0.9922, 0.5490, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.1765, 0.8824, 0.9882, 0.9882, 0.1490, 0.0000, 0.0000,
0.0000, 0.0745, 0.8824, 0.9882, 0.5451, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.8980, 0.9882, 0.9882, 0.8392, 0.0000, 0.0000, 0.0000,
0.0000, 0.6039, 0.9882, 0.9882, 0.2039, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.9922, 0.9882, 0.8392, 0.1098, 0.0000, 0.0000, 0.0000,
0.3961, 0.9922, 0.9882, 0.9882, 0.1098, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.9961, 0.9922, 0.6588, 0.0000, 0.0000, 0.0745, 0.4078,
0.9922, 0.9961, 0.9686, 0.5882, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.9922, 0.9882, 0.6588, 0.0000, 0.4706, 0.8824, 0.9882,
0.9882, 0.9686, 0.4039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.8471, 0.9882, 0.9176, 0.7725, 0.9922, 0.9882, 0.9882,
0.9882, 0.3922, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0627, 0.4039, 0.8941, 0.9882, 0.9451, 0.5451, 0.4039,
0.1098, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000]]]) tensor(3)
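The values lie between 0 and 1 because ToTensor() converts 8-bit images (intensities 0 to 255) to floats divided by 255, and it adds a leading channel dimension, which is why x has the shape [1, 28, 28]. The sketch below reproduces this by hand; the raw intensities are an assumption, back-computed from one printed row (for example 0.0902 × 255 ≈ 23).

```python
import torch

# Assumed raw 8-bit intensities, back-computed from the printed row (0.0902 * 255 ≈ 23, ...)
raw = torch.tensor([23, 179, 255, 253, 165, 28], dtype=torch.uint8)

# For uint8 input, ToTensor() converts to float and divides by 255 (0..255 -> 0.0..1.0)
scaled = raw.float() / 255
print(scaled)  # tensor([0.0902, 0.7020, 1.0000, 0.9922, 0.6471, 0.1098])

# A hypothetical blank 28x28 grayscale image: ToTensor() also prepends a channel dimension
image = torch.zeros(28, 28, dtype=torch.uint8)
x_like = (image.float() / 255).unsqueeze(0)
print(x_like.shape)  # torch.Size([1, 28, 28])
```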
As the 28x28 pixel values are hard to interpret, we can also visualize them using Matplotlib. Since x has the shape 1x28x28 (one grayscale channel), we reshape it to 28x28 first.
plt.imshow(x.view(28, 28))
plt.show()
For good model training, the dataset should ideally be balanced, i.e. each class should appear roughly equally often. Luckily, MNIST is already prepared so that we can use it to train a model immediately, without having to clean or adjust the data first. To verify this, let's look at how often each digit appears in the training set. To do this, we count the labels in a dictionary.
total = 0
counter_dict = {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0}

for data in trainset:
    Xs, ys = data
    for y in ys:
        counter_dict[int(y)] += 1  # count each label
        total += 1

for i in counter_dict:
    print(f"{i}: {round(counter_dict[i]/total*100, 2)}", "%")
We see that each digit makes up roughly 9 to 11 % of the training images, so the classes are well balanced.
==========================
OUTPUT
==========================
0: 9.87 %
1: 11.24 %
2: 9.93 %
3: 10.22 %
4: 9.74 %
5: 9.04 %
6: 9.86 %
7: 10.44 %
8: 9.75 %
9: 9.92 %
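Iterating over the DataLoader loads every image just to read its label. A quicker alternative, sketched below under the assumption that you only need the labels, is torch.bincount on a label tensor; with the real data, torchvision's MNIST dataset exposes all labels as the tensor train.targets. The labels here are hypothetical.

```python
import torch

# Hypothetical labels; with the real data you could pass train.targets instead
labels = torch.tensor([3, 1, 1, 0, 9, 3, 1, 7, 7, 2])

counts = torch.bincount(labels, minlength=10)  # occurrences of each digit 0..9
percent = counts.float() / counts.sum() * 100  # share of each digit in percent

for digit, p in enumerate(percent.tolist()):
    print(f"{digit}: {p:.2f} %")
```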
We now import two Torch modules that we will use to build our neural network. torch.nn provides nn.Module, the base class for all neural network modules; our model will also subclass it. Modules can contain other modules, which allows us to nest them in a tree structure, and submodules are assigned as regular attributes. torch.nn.functional provides functions such as activations (e.g. relu, log_softmax) and loss functions (e.g. nll_loss).
import torch.nn as nn
import torch.nn.functional as F
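To illustrate the nesting described above, here is a small sketch with two hypothetical modules, Block and Outer. Assigning a module as an attribute registers it, so its parameters show up, with dotted names, in the parent's named_parameters().

```python
import torch.nn as nn

class Block(nn.Module):  # hypothetical submodule
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)  # assigning a Module as an attribute registers it

class Outer(nn.Module):  # hypothetical parent module
    def __init__(self):
        super().__init__()
        self.block = Block()        # a Module nested inside another Module
        self.head = nn.Linear(2, 1)

model = Outer()
names = [name for name, _ in model.named_parameters()]
print(names)  # ['block.fc.weight', 'block.fc.bias', 'head.weight', 'head.bias']
```

This registration is why net.parameters() can later hand the weights of all layers to the optimizer. It only works after super().__init__() has run; assigning a submodule before that raises an AttributeError.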
We then build our neural network model as a class. The class has two methods: __init__, where we define the layers, and forward, which describes how the data flows through them. Because our class inherits from nn.Module, __init__ must first call super().__init__() so that nn.Module is set up properly before we assign any layers. We give our model 5 fully connected (linear) layers: the first takes the 784 (28x28) pixel values of a flattened image, and the last outputs one value for each of the 10 digits. We only have to define the forward function; the backward function (where gradients are computed) is automatically defined for us by autograd, and we can use any tensor operation in forward. We apply a ReLU activation after each of the first 4 layers and a log-softmax to the output of the last layer, so that the model returns a log-probability for each of the 10 possible digits. This log-softmax output pairs with the negative log-likelihood loss (F.nll_loss) that we use during training.
class Net(nn.Module):
    def __init__(self):
        super().__init__()  # initialize nn.Module before assigning layers
        self.fc1 = nn.Linear(784, 128)  # 28*28 = 784 input pixels
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 64)
        self.fc4 = nn.Linear(64, 64)
        self.fc5 = nn.Linear(64, 10)  # one output per digit 0-9

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = self.fc5(x)  # output layer, no ReLU
        return F.log_softmax(x, dim=1)  # log-probabilities over the 10 classes


net = Net()
print(net)
We can have a look at our model.
==========================
OUTPUT
==========================
Net(
(fc1): Linear(in_features=784, out_features=128, bias=True)
(fc2): Linear(in_features=128, out_features=64, bias=True)
(fc3): Linear(in_features=64, out_features=64, bias=True)
(fc4): Linear(in_features=64, out_features=64, bias=True)
(fc5): Linear(in_features=64, out_features=10, bias=True)
)
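As a quick sanity check of the architecture, the sketch below builds the same layer sizes (784-128-64-64-64-10) with nn.Sequential instead of the Net class, counts the trainable parameters (each Linear(in, out) layer has in × out weights plus out biases), and runs a dummy batch. The Sequential version and the random input are assumptions for illustration; exponentiating the log-softmax output should give rows that sum to 1.

```python
import torch
import torch.nn as nn

# Assumed equivalent of Net written as a plain layer stack
model = nn.Sequential(
    nn.Linear(784, 128), nn.ReLU(),
    nn.Linear(128, 64), nn.ReLU(),
    nn.Linear(64, 64), nn.ReLU(),
    nn.Linear(64, 64), nn.ReLU(),
    nn.Linear(64, 10),
    nn.LogSoftmax(dim=1),
)

# Total number of trainable weights and biases
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 117706

# Dummy batch of 10 flattened 28x28 images
out = model(torch.rand(10, 784))
print(out.shape)             # torch.Size([10, 10])
print(out.exp().sum(dim=1))  # each row of probabilities sums to (about) 1
```

nn.Sequential is convenient for plain layer stacks like this one; subclassing nn.Module, as in Net, gives full control over forward when the data flow is more complex.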
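Next, we initialize the optimizer (Adam) with a learning rate of 0.001 and train the network for 3 epochs, i.e. three full passes over the training data. For each batch, we reset the gradients, flatten the 28x28 images into vectors of length 784, compute the negative log-likelihood loss, backpropagate it, and let the optimizer update the weights.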
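import torch.optim as optim

optimizer = optim.Adam(net.parameters(), lr=0.001)
EPOCHS = 3

for epoch in range(EPOCHS):
    for data in trainset:  # iterate over the training batches
        X, y = data
        net.zero_grad()  # clear gradients from the previous step
        output = net(X.view(-1, 28*28))  # flatten each image to 784 values
        loss = F.nll_loss(output, y)  # negative log-likelihood of the true labels
        loss.backward()  # compute gradients via autograd
        optimizer.step()  # update the weights
    print(loss)  # loss of the last batch in this epoch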
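The training history shows the loss of the last batch of an epoch. Because each value is computed on a single batch of only 10 images, the numbers fluctuate considerably and give only a rough indication of how the model is improving.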
==========================
OUTPUT
==========================
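tensor(0.0056, grad_fn=<NllLossBackward>)
tensor(0.0857, grad_fn=<NllLossBackward>)
tensor(0.0009, grad_fn=<NllLossBackward>)
tensor(0.3751, grad_fn=<NllLossBackward>)
tensor(0.0104, grad_fn=<NllLossBackward>)
tensor(0.0841, grad_fn=<NllLossBackward>)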
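The loss relies on the log-softmax output of the network: F.nll_loss picks the log-probability the model assigns to the true class of each image and averages its negative. The sketch below checks this on random, hypothetical scores, and also shows that F.cross_entropy applied to raw scores computes the same value, which is a common alternative to ending the model with log_softmax.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)           # hypothetical raw scores for 4 images and 10 digits
targets = torch.tensor([3, 1, 0, 7])  # hypothetical true labels

log_probs = F.log_softmax(logits, dim=1)

# nll_loss: negative log-probability of each true class, averaged over the batch
manual = -log_probs[torch.arange(4), targets].mean()
loss = F.nll_loss(log_probs, targets)
print(torch.allclose(loss, manual))  # True

# log_softmax followed by nll_loss is what cross_entropy computes from raw scores
print(torch.allclose(F.cross_entropy(logits, targets), loss))  # True
```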
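Let's quickly calculate the accuracy on our test set, i.e. on images the model has not seen during training. For each image, the predicted digit is the index of the largest output value (torch.argmax), which we compare with the true label.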
correct = 0
total = 0
with torch.no_grad():  # no gradients needed for evaluation
    for data in testset:  # evaluate on unseen test images
        X, y = data
        output = net(X.view(-1, 784))
        for idx, i in enumerate(output):
            if torch.argmax(i) == y[idx]:  # predicted digit matches the label
                correct += 1
            total += 1
print("Accuracy: ", round(correct/total, 3))
We find an accuracy of about 97%.
==========================
OUTPUT
==========================
Accuracy: 0.971
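The inner loop over individual predictions can also be written in vectorized form. In this sketch (hypothetical scores for 4 images, with 3 classes instead of 10 for brevity), argmax(dim=1) picks the predicted class for every row at once, and comparing with the labels counts the correct predictions.

```python
import torch

# Hypothetical network outputs for 4 images (rows) over 3 classes (columns)
output = torch.tensor([[-0.1, -2.5, -3.0],
                       [-2.0, -0.2, -2.2],
                       [-1.5, -1.0, -0.9],
                       [-0.3, -1.4, -2.9]])
y = torch.tensor([0, 1, 0, 0])  # hypothetical true labels

# Predicted class = index of the largest value in each row
predictions = output.argmax(dim=1)
correct = (predictions == y).sum().item()
accuracy = correct / len(y)
print(predictions, accuracy)  # tensor([0, 1, 2, 0]) 0.75
```

Inside the evaluation loop, the same idea would replace the inner for-loop with correct += (output.argmax(dim=1) == y).sum().item() and total += y.size(0).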
In this simple tutorial, we have used PyTorch to make sense of the MNIST dataset. We have built a fully connected neural network that predicts the digit shown in a 28x28 image with an accuracy of about 97%.