import torch
from torch import nn

First, the image is divided into non-overlapping patches.

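Open questions: (1) Is it necessary to use a convolutional layer to extract the patches? (2) Could the patches be generated simply with `Image.crop`? Note that a `Conv2d` whose `kernel_size` and `stride` both equal the patch size is equivalent to cropping the image into non-overlapping patches, flattening each patch, and applying one shared linear projection — the convolution just does both steps at once.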

import cv2

# Load the example image. OpenCV returns None (instead of raising) when the
# file is missing or unreadable, so check explicitly before using it.
image_path = 'images/trees_reflected.jpg'
img = cv2.imread(image_path)
if img is None:
    raise FileNotFoundError(f"could not read image: {image_path}")
# cv2.imread stores channels in BGR order; convert so each pixel really is RGB.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (224, 224))  # (height, width, channels) = (224, 224, 3), uint8
img  # each pixel is a 3D vector (RGB)
array([[[  1,  41,  27],
        [  3,  27,  11],
        [ 20,  72,  55],
        ...,
        [152,  82,   3],
        [145,  77,   0],
        [152,  85,   0]],

       [[  2,  39,  31],
        [  4,   9,   5],
        [  7,  29,  27],
        ...,
        [147,  79,   8],
        [146,  80,   3],
        [145,  83,   0]],

       [[ 28,  44,  53],
        [  1,  15,  17],
        [  4,  25,  32],
        ...,
        [148,  90,  16],
        [141,  85,   3],
        [141,  84,   8]],

       ...,

       [[177, 102,   0],
        [177, 102,   0],
        [178, 103,   1],
        ...,
        [170,  93,   1],
        [170,  93,   1],
        [171,  93,   0]],

       [[178, 102,   1],
        [178, 104,   2],
        [178, 103,   1],
        ...,
        [170,  93,   1],
        [169,  92,   0],
        [169,  92,   0]],

       [[177, 102,   0],
        [176, 101,   0],
        [178, 102,   2],
        ...,
        [170,  93,   1],
        [169,  92,   0],
        [171,  93,   0]]], dtype=uint8)
# OpenCV images are (H, W, C); nn.Conv2d expects channels-first (C, H, W).
# transpose(2, 0, 1) moves channels to the front while keeping H before W
# (transpose(2, 1, 0) would also swap H and W, i.e. transpose the picture).
x = img.transpose(2, 0, 1)
x = torch.tensor(x, dtype=torch.float32)
x
tensor([[[  1.,   2.,  28.,  ..., 177., 178., 177.],
         [  3.,   4.,   1.,  ..., 177., 178., 176.],
         [ 20.,   7.,   4.,  ..., 178., 178., 178.],
         ...,
         [152., 147., 148.,  ..., 170., 170., 170.],
         [145., 146., 141.,  ..., 170., 169., 169.],
         [152., 145., 141.,  ..., 171., 169., 171.]],

        [[ 41.,  39.,  44.,  ..., 102., 102., 102.],
         [ 27.,   9.,  15.,  ..., 102., 104., 101.],
         [ 72.,  29.,  25.,  ..., 103., 103., 102.],
         ...,
         [ 82.,  79.,  90.,  ...,  93.,  93.,  93.],
         [ 77.,  80.,  85.,  ...,  93.,  92.,  92.],
         [ 85.,  83.,  84.,  ...,  93.,  92.,  93.]],

        [[ 27.,  31.,  53.,  ...,   0.,   1.,   0.],
         [ 11.,   5.,  17.,  ...,   0.,   2.,   0.],
         [ 55.,  27.,  32.,  ...,   1.,   1.,   2.],
         ...,
         [  3.,   8.,  16.,  ...,   1.,   1.,   1.],
         [  0.,   3.,   3.,  ...,   1.,   0.,   0.],
         [  0.,   0.,   8.,  ...,   0.,   0.,   0.]]])
class PatchEmbeddings(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    A ``Conv2d`` whose ``kernel_size`` and ``stride`` both equal the patch size
    is equivalent to cropping non-overlapping patches, flattening each one, and
    applying a single shared linear projection.

    Args:
        image: Unbatched image tensor of shape ``(C, H, W)``; stored on the
            module and embedded by ``forward``.
        out_channels: Embedding dimension of every patch token.
        kernel_size: Side length of each square patch (also used as stride).
    """

    def __init__(self, image, out_channels, kernel_size):
        super().__init__()
        self.image = image
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        # Whole patches along each spatial axis; border pixels that do not fill
        # a complete patch are dropped (the conv has no padding).
        self.num_tokens = (image.shape[1] // kernel_size) * (image.shape[2] // kernel_size)
        # stride == kernel_size -> patches do not overlap, and the conv output
        # grid has exactly (H // k) x (W // k) = num_tokens positions for any k.
        self.Conv = nn.Conv2d(
            in_channels=image.shape[0],
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=kernel_size,
        )

    def forward(self):
        """Return patch tokens of shape ``(num_tokens, out_channels)``."""
        x = self.Conv(self.image)  # (out_channels, H // k, W // k)
        x = x.flatten(1)           # (out_channels, num_tokens)
        return x.transpose(1, 0)   # (num_tokens, out_channels): one row per patch
# Embed every 3x3 patch of the image as a 10-dimensional token.
token_embedding = PatchEmbeddings(image=x, out_channels=10, kernel_size=3)
out = token_embedding()
out.shape
torch.Size([5476, 10])

Next, the positional embedding and the class token are added.

The class token is prepended to the sequence of image tokens, i.e. it is placed before all of them.

The class token provides an overall representation of the image. Without a class token, the ViT still attends over the patches (image tokens), but there is no single token that summarizes the whole image and can be used to decide which class the image belongs to.

For example, if there are 100 image tokens of dimension 256, adding a class token of dimension 256 gives a sequence of 101 tokens, each of dimension 256.

class Embedding(nn.Module):
    """Patch tokens plus a positional embedding, with a class token prepended.

    Args:
        image: Image tensor passed to ``PatchEmbeddings``.
        out_channels: Embedding dimension of every token.
        kernel_size: Patch size passed to ``PatchEmbeddings``.
    """

    def __init__(self, image, out_channels, kernel_size):
        super().__init__()
        self.patch_embedding = PatchEmbeddings(image, out_channels, kernel_size)
        # Maps a scalar patch position (1..num_tokens) to an out_channels vector.
        # NOTE(review): raw positions up to num_tokens make this output grow
        # linearly with position; a learned (num_tokens, out_channels) table or
        # normalized positions would be the more common choice.
        self.linear = nn.Linear(1, out_channels)
        # Registered in __init__ so the class token is learnable, appears in
        # parameters()/state_dict(), follows .to(device), and is the same
        # tensor on every forward call.
        self.cls_token = nn.Parameter(torch.randn(1, out_channels))

    def forward(self):
        """Return ``(num_tokens + 1, out_channels)``: the class token, then patch tokens."""
        x = self.patch_embedding.forward()  # (num_tokens, out_channels)
        num_tokens = self.patch_embedding.num_tokens
        # Column of positions 1..num_tokens, shape (num_tokens, 1); built directly
        # as a float tensor on x's device (no tensor-from-tensor copy warning).
        position = torch.arange(
            1, num_tokens + 1, dtype=x.dtype, device=x.device
        ).unsqueeze(1)
        x = x + self.linear(position)  # (num_tokens, out_channels)
        # Positional information is added to the patch tokens only.
        return torch.cat([self.cls_token, x], dim=0)

# Patch tokens + positional embedding + class token, with 10-dim tokens.
embed = Embedding(image=x, out_channels=10, kernel_size=3)
embed_out = embed()
embed_out.shape
torch.Size([5476, 10])
torch.Size([5476, 1])
torch.Size([5476, 10])
C:\Users\Hp\AppData\Local\Temp\ipykernel_18220\3315714593.py:11: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  position = torch.tensor(torch.arange(1, num_tokens + 1).unsqueeze(0).T, dtype = torch.float32)
torch.Size([5477, 10])
class MultiheadSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention over an unbatched token sequence.

    Each token of size ``input_dim`` is split into ``heads`` chunks of
    ``head_dim`` features. The same ``wq``/``wk``/``wv`` projections
    (``head_dim -> head_dim``) are applied to every head's chunk.

    Args:
        input_dim: Feature size of each token; must be divisible by ``heads``.
        heads: Number of attention heads.

    Raises:
        ValueError: If ``input_dim`` is not divisible by ``heads``.
    """

    def __init__(self, input_dim, heads):
        super().__init__()
        if input_dim % heads != 0:
            raise ValueError(
                f"input_dim ({input_dim}) must be divisible by heads ({heads})"
            )
        self.input_dim = input_dim
        self.heads = heads
        self.head_dim = input_dim // heads
        self.embed_dim = heads * self.head_dim

        self.wq = nn.Linear(self.head_dim, self.head_dim)
        self.wk = nn.Linear(self.head_dim, self.head_dim)
        self.wv = nn.Linear(self.head_dim, self.head_dim)

    def forward(self, query, key, value):
        """Attend from ``query`` tokens to ``key``/``value`` tokens.

        Args:
            query: Tensor of shape ``(num_queries, input_dim)``.
            key: Tensor of shape ``(num_keys, input_dim)``.
            value: Tensor of shape ``(num_keys, input_dim)``.

        Returns:
            Tensor of shape ``(num_queries, embed_dim)``.
        """
        num_queries = query.shape[0]
        num_keys = key.shape[0]
        # Split features into heads, then project each head's chunk.
        query = self.wq(query.reshape(num_queries, self.heads, self.head_dim))
        key = self.wk(key.reshape(num_keys, self.heads, self.head_dim))
        value = self.wv(value.reshape(num_keys, self.heads, self.head_dim))

        # scores[q, h, k] = <query[q, h], key[k, h]> / sqrt(head_dim)
        scores = torch.einsum("qhd,khd->qhk", query, key) / (self.head_dim ** 0.5)
        weights = nn.functional.softmax(scores, dim=-1)  # normalize over keys

        # Each query's output is the attention-weighted sum over ALL keys' values
        # (the key index k is summed out), separately per head.
        out = torch.einsum("qhk,khd->qhd", weights, value)
        return out.reshape(num_queries, self.embed_dim)


# Build 512-dim tokens and run a single 8-head self-attention layer over them.
embed = Embedding(image=x, out_channels=512, kernel_size=3)
embed_out = embed()
sat = MultiheadSelfAttention(input_dim=512, heads=8)
value = sat(query=embed_out, key=embed_out, value=embed_out)
value.shape
torch.Size([5476, 512])
torch.Size([5476, 1])
torch.Size([5476, 512])
C:\Users\Hp\AppData\Local\Temp\ipykernel_18220\3315714593.py:11: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  position = torch.tensor(torch.arange(1, num_tokens + 1).unsqueeze(0).T, dtype = torch.float32)
torch.Size([5477, 512])
class EncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block, as used in ViT.

    Computes ``x = x + attn(norm1(x))`` and then ``x = x + mlp(norm2(x))``, so
    both residual connections carry the un-normalized token stream.

    Args:
        input_dim: Token feature size given to the attention layer.
        embed_dim: Token feature size used by the LayerNorms and the MLP; the
            attention output size (``heads * (input_dim // heads)``) must equal it.
        heads: Number of attention heads.
    """

    def __init__(self, input_dim, embed_dim, heads):  # embed_dim is out_channels
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn_layer = MultiheadSelfAttention(input_dim, heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        # GELU between the Linear layers gives the MLP a nonlinearity; without
        # it the two Linears collapse into a single affine map. The Linears stay
        # at indices 0 and 2 so their state_dict keys are unchanged.
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
            nn.Dropout(p=0.2),
        )

    def forward(self, embedding):
        """Return tokens with the same shape as ``embedding``."""
        normed = self.norm1(embedding)
        # Residual from the raw input, not from its normalized copy.
        x = embedding + self.attn_layer(normed, normed, normed)
        x = x + self.mlp(self.norm2(x))
        return x
    
# Pass the 512-dim token sequence through one encoder block.
embed = Embedding(image=x, out_channels=512, kernel_size=3)
embed_out = embed()
encoder = EncoderBlock(input_dim=512, embed_dim=512, heads=8)
value = encoder(embedding=embed_out)
value.shape
torch.Size([5476, 512])
torch.Size([5476, 1])
torch.Size([5476, 512])
C:\Users\Hp\AppData\Local\Temp\ipykernel_18220\3315714593.py:11: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  position = torch.tensor(torch.arange(1, num_tokens + 1).unsqueeze(0).T, dtype = torch.float32)
torch.Size([5477, 512])
class Vit(nn.Module):
    """Chain of ``num_encoders`` EncoderBlocks applied one after another.

    Args:
        input_dim: Token feature size given to each block's attention layer.
        embed_dim: Token feature size used by each block's norms and MLP.
        heads: Number of attention heads per block.
        num_encoders: How many encoder blocks to stack.
    """

    def __init__(self, input_dim, embed_dim, heads, num_encoders):
        super().__init__()
        self.num_encoders = num_encoders
        blocks = [
            EncoderBlock(input_dim, embed_dim, heads)
            for _ in range(num_encoders)
        ]
        self.encoder_blocks = nn.ModuleList(blocks)

    def forward(self, embedding):
        """Feed ``embedding`` through every encoder block in order and return the result."""
        hidden = embedding
        for block in self.encoder_blocks:
            hidden = block(hidden)
        return hidden
    
# Full stack: embed the image, then run four encoder blocks over the tokens.
embed = Embedding(image=x, out_channels=512, kernel_size=3)
embed_out = embed()
vit_model = Vit(input_dim=512, embed_dim=512, heads=8, num_encoders=4)
value = vit_model(embedding=embed_out)
value.shape
torch.Size([5476, 512])
torch.Size([5476, 1])
torch.Size([5476, 512])
C:\Users\Hp\AppData\Local\Temp\ipykernel_18220\3315714593.py:11: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  position = torch.tensor(torch.arange(1, num_tokens + 1).unsqueeze(0).T, dtype = torch.float32)
torch.Size([5477, 512])