Exploring PyTorch

by CM


Posted on March 28, 2020



The Goal:

In this article, we will explore PyTorch, an open-source machine learning library based on the Torch library. It is primarily developed by Facebook's AI Research lab (FAIR), whereas TensorFlow is developed by Google. PyTorch is used for applications such as computer vision and natural language processing (NLP), and, like TensorFlow, it is free and open-source software. To explore PyTorch, we make use of the classic MNIST dataset.


Key components are:

Dataset:
>> The data files (train & test) are downloaded directly through torchvision.

Let's jump right into the code. First, we import PyTorch and torchvision, as well as Matplotlib, which will help us visualize the images.

import torch
import torchvision
from torchvision import transforms, datasets
import matplotlib.pyplot as plt

Next, we download the training and test data.

train = datasets.MNIST("", train=True, download = True,
                       transform=transforms.Compose([transforms.ToTensor()]))

test = datasets.MNIST("", train=False, download = True,
                       transform=transforms.Compose([transforms.ToTensor()]))

Downloading the data might take a minute, depending on your Internet connection.

We then make use of PyTorch's data loading utility, whose heart is the torch.utils.data.DataLoader class.

trainset = torch.utils.data.DataLoader(train, batch_size=10, shuffle=True)
testset = torch.utils.data.DataLoader(test, batch_size=10, shuffle=True)

We can then have a look at one batch of the dataset and its respective x, y values. The DataLoader yields batches, so we grab the first batch and inspect its first image and label.

data = next(iter(trainset))    # one batch of 10 images and labels
x, y = data[0][0], data[1][0]  # first image and its label
print(x, y)

Below we see the pixel values (x) and the label (y).

==========================
OUTPUT
==========================
tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0902, 0.7020, 1.0000, 0.9922, 0.6471,
          0.1098, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0275, 0.3216, 0.9176, 0.9882, 0.9922, 0.8392, 0.9412,
          0.8431, 0.6627, 0.5647, 0.1490, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.4588, 0.9882, 0.9882, 0.9882, 0.5608, 0.0745, 0.7725,
          0.9882, 0.9922, 0.7686, 0.8431, 0.3804, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.9451, 0.9882, 0.9882, 0.7922, 0.0510, 0.3961, 0.9647,
          0.7922, 0.5020, 0.0353, 0.3804, 0.9882, 0.2000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.4471, 1.0000, 0.9922, 0.6118, 0.0000, 0.0000, 0.3961, 0.3922,
          0.0000, 0.0000, 0.0275, 0.6039, 0.9922, 0.4471, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.7373, 0.9922, 0.8392, 0.0745, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0275, 0.5882, 0.9882, 0.9882, 0.3451, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.8824, 0.9922, 0.4275, 0.0000, 0.0000, 0.0000, 0.0000, 0.1137,
          0.3333, 0.6039, 0.9882, 0.9882, 0.6941, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.6863, 0.9922, 0.5216, 0.0000, 0.0000, 0.4471, 0.4431, 0.6235,
          0.9882, 0.9922, 0.9882, 0.6941, 0.0118, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.8863, 0.9569, 0.8471, 0.7490, 0.9961, 0.9922, 0.9922,
          0.9922, 0.8980, 0.9922, 0.6588, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.2196, 0.3294, 0.3294, 0.6980, 0.9882, 0.9882,
          0.9882, 0.3098, 0.9882, 0.6588, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.1490, 0.8980, 0.9882, 0.9882,
          0.9882, 0.6980, 0.9882, 0.7686, 0.0392, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.1490, 0.8824, 0.9922, 0.9882, 0.6431,
          0.1569, 0.0157, 0.6980, 0.9882, 0.3059, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.4706, 0.9922, 0.9922, 0.8863, 0.1490, 0.0000,
          0.0000, 0.0000, 0.6627, 0.9922, 0.5490, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.1765, 0.8824, 0.9882, 0.9882, 0.1490, 0.0000, 0.0000,
          0.0000, 0.0745, 0.8824, 0.9882, 0.5451, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.8980, 0.9882, 0.9882, 0.8392, 0.0000, 0.0000, 0.0000,
          0.0000, 0.6039, 0.9882, 0.9882, 0.2039, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.9922, 0.9882, 0.8392, 0.1098, 0.0000, 0.0000, 0.0000,
          0.3961, 0.9922, 0.9882, 0.9882, 0.1098, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.9961, 0.9922, 0.6588, 0.0000, 0.0000, 0.0745, 0.4078,
          0.9922, 0.9961, 0.9686, 0.5882, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.9922, 0.9882, 0.6588, 0.0000, 0.4706, 0.8824, 0.9882,
          0.9882, 0.9686, 0.4039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.8471, 0.9882, 0.9176, 0.7725, 0.9922, 0.9882, 0.9882,
          0.9882, 0.3922, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0627, 0.4039, 0.8941, 0.9882, 0.9451, 0.5451, 0.4039,
          0.1098, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000]]]) tensor(3)

As the 28x28 pixel values are hard to interpret, we can also visualize them using Matplotlib.

plt.imshow(x.view(28, 28))
plt.show()

For good model training, the dataset needs to be balanced. Luckily, the MNIST dataset is already prepared so that we can use it to train a model immediately, without having to clean or adjust the data first. Still, let's have a look at the distribution of the digits in the dataset. To do this, we create a dictionary that counts the occurrences of each label.

total = 0
counter_dict = {0:0,1:0,2:0,3:0,4:0,5:0,6:0,7:0,8:0,9:0}

for data in trainset:
  Xs, ys = data
  for y in ys:
    counter_dict[int(y)] +=1
    total +=1

for i in counter_dict:
  print(f"{i}: {round(counter_dict[i]/total*100,2)}","%")

We get the following distribution, which confirms that the classes are roughly balanced.

==========================
OUTPUT
==========================

0: 9.87 %
1: 11.24 %
2: 9.93 %
3: 10.22 %
4: 9.74 %
5: 9.04 %
6: 9.86 %
7: 10.44 %
8: 9.75 %
9: 9.92 %

We now import two torch modules that we will use to build our neural network. The torch.nn.Module class is the base class for all neural network modules, and our model will subclass it. Modules can contain other Modules, allowing you to nest them in a tree structure, and submodules can be assigned as regular attributes.

import torch.nn as nn
import torch.nn.functional as F

We then build our neural network model as a class with two methods. In __init__(self) we first call super().__init__() so that the nn.Module machinery is initialized, and then define five fully connected layers. We only have to define the forward method; the backward pass (where gradients are computed) is defined automatically by autograd, and any Tensor operation can be used inside forward. We apply a ReLU activation after each of the first four layers and a (log-)softmax on the last layer, so that we get a probability-like score for each of the ten possible digits.

class Net(nn.Module):
  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(784, 128)
    self.fc2 = nn.Linear(128, 64)
    self.fc3 = nn.Linear(64, 64)
    self.fc4 = nn.Linear(64, 64)
    self.fc5 = nn.Linear(64, 10)

  def forward(self, x):
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = F.relu(self.fc3(x))
    x = F.relu(self.fc4(x))
    x = self.fc5(x)
    return F.log_softmax(x, dim=1)

net = Net()
print(net)

We can have a look at our model.

==========================
OUTPUT
==========================

Net(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=64, bias=True)
  (fc4): Linear(in_features=64, out_features=64, bias=True)
  (fc5): Linear(in_features=64, out_features=10, bias=True)
)
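
To get a feel for the model's size, we can also count its trainable parameters. This is a small sketch using the net defined above; the total should come out to 117,706 for the layer sizes we chose.

num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)  # all trainable weights and biases
print(num_params)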

We then initialize our optimizer (Adam) with a learning rate of 0.001 and train the network for three epochs.

import torch.optim as optim
optimizer = optim.Adam(net.parameters(), lr=0.001)

EPOCHS = 3

for epoch in range(EPOCHS):
  for data in trainset:
    X, y = data
    net.zero_grad()
    output  = net(X.view(-1,28*28))
    loss   = F.nll_loss(output, y)
    loss.backward()
    optimizer.step()
  print(loss)


In the training history we can see how our model is improving.

==========================
OUTPUT
==========================

tensor(0.0056, grad_fn=<NllLossBackward>)
tensor(0.0857, grad_fn=<NllLossBackward>)
tensor(0.0009, grad_fn=<NllLossBackward>)
tensor(0.3751, grad_fn=<NllLossBackward>)
tensor(0.0104, grad_fn=<NllLossBackward>)
tensor(0.0841, grad_fn=<NllLossBackward>)
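
Note that we only print the loss of the last batch of each epoch, so the values above can be somewhat noisy. As a variant, here is a small sketch (reusing the net, trainset, and optimizer defined above) that averages the loss over each epoch instead:

for epoch in range(EPOCHS):
  epoch_loss = 0.0
  batches = 0
  for X, y in trainset:
    net.zero_grad()
    output = net(X.view(-1, 28*28))
    loss = F.nll_loss(output, y)
    loss.backward()
    optimizer.step()
    epoch_loss += loss.item()   # accumulate the batch loss
    batches += 1
  print(f"Epoch {epoch}: average loss {epoch_loss/batches:.4f}")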

Let's quickly calculate the accuracy on our testset.

correct = 0
total = 0

with torch.no_grad():
  for data in testset:
    X, y = data
    output = net(X.view(-1, 784))
    for idx, i in enumerate(output):
      if torch.argmax(i) == y[idx]:
        correct +=1
      total +=1
print("Accuracy: " , round(correct/total, 3))


We find an accuracy above 97% of correct predictions.

==========================
OUTPUT
==========================

Accuracy:  0.971

In this simple tutorial, we have used PyTorch to make sense of the MNIST dataset. We have built a neural network model that allows us to predict the digit shown in a 28x28 image.
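
As a final sanity check, here is a small sketch (reusing the trained net, the testset loader, and Matplotlib from above) that predicts the digit for a single test image and displays it:

X, y = next(iter(testset))                # one batch of 10 test images
with torch.no_grad():
  prediction = torch.argmax(net(X[0].view(-1, 784))[0])  # index of the highest log-probability
print("Predicted:", int(prediction), "- actual label:", int(y[0]))
plt.imshow(X[0].view(28, 28))
plt.show()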

Leverage PyTorch

#EpicML

