-Adwita Singh

Variables

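x -> input at each time step, shape (batch_size, input_size)

Wh -> hidden-to-hidden (recurrent) weights, shape (hidden_size, hidden_size)

Ux -> input-to-hidden weights, shape (input_size, hidden_size)

bh -> bias for the hidden layer, shape (hidden_size,)

Vo -> hidden-to-output weights, shape (hidden_size, output_size)

co -> bias for the output layer, shape (output_size,)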

We create a class called RNN that holds the learnable parameters defined above.

import torch
import torch.nn as nn

class RNN(nn.Module):
  def __init__(self, input_size, hidden_size, output_size):
    super().__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.output_size = output_size

    # Every weight uses the (input_dim, output_dim) layout, so it is applied as x @ W
    self.Wh = nn.Parameter(torch.randn(hidden_size, hidden_size))  # hidden -> hidden
    self.Ux = nn.Parameter(torch.randn(input_size, hidden_size))   # input -> hidden
    self.bh = nn.Parameter(torch.randn(hidden_size))               # hidden bias
    self.Vo = nn.Parameter(torch.randn(hidden_size, output_size))  # hidden -> output
    self.co = nn.Parameter(torch.randn(output_size))               # output bias

Note: all weights are defined as torch.randn(input_dim, output_dim). With this layout the forward pass can multiply as x @ W directly, without transposing the weight matrix. This is not the usual convention: standard implementations such as PyTorch's nn.Linear store weights as (output_dim, input_dim) and compute x @ W.T. I forgot about that while coding and left it as is; the two layouts differ only by a transpose.
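To check that the two layouts really compute the same thing, here is a small sketch (sizes and seed are arbitrary choices). Multiplying by W stored as (input_dim, output_dim) gives the same result as torch.nn.functional.linear applied to W.T, and nn.Linear's own weight has shape (output_dim, input_dim).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 4)              # (batch_size, input_dim)

# Layout used in these notes: W has shape (input_dim, output_dim)
W = torch.randn(4, 3)
y_notes = x @ W

# PyTorch convention: weight has shape (output_dim, input_dim), applied as x @ weight.T
y_torch = F.linear(x, W.T)         # F.linear computes x @ weight.T (+ bias if given)

lin = nn.Linear(4, 3, bias=False)
print(lin.weight.shape)            # torch.Size([3, 4])
print(torch.allclose(y_notes, y_torch, atol=1e-6))  # True
```

Either layout works as long as the multiplication order matches how the weight is stored.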

[Figure: RNN equations used for the forward pass (IMG_0093.jpeg)]
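The figure is not reproduced here. Reading the equations back from the forward() code (row-vector layout, so inputs multiply the weights from the left), one time step computes:

$$
\begin{aligned}
a_t &= b_h + h_{t-1} W_h + x_t U_x \\
h_t &= \tanh(a_t) \\
o_t &= c_o + h_t V_o \\
\hat{y}_t &= \operatorname{softmax}(o_t)
\end{aligned}
$$

with $h_0 = 0$ when no initial hidden state is passed in.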

Computing the forward pass

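  def forward(self, inputs, h0=None):
    """
    inputs: (seq_len, batch_size, input_size)
    h0: (batch_size, hidden_size), defaults to zeros
    returns: outputs of shape (seq_len, batch_size, output_size)
             and the final hidden state of shape (batch_size, hidden_size)
    """

    seq_len, batch_size, _ = inputs.size()

    # Start from a zero hidden state unless one is supplied
    if h0 is None:
      ht = torch.zeros(batch_size, self.hidden_size, device=inputs.device)
    else:
      ht = h0

    outputs = []
    hidden_states = [ht]  # every hidden state (collected, not returned here)

    # Forward-pass equations from the figure above (no backward pass here)
    for t in range(seq_len):
      xt = inputs[t]                                  # (batch_size, input_size)
      at = self.bh + ht @ self.Wh + xt @ self.Ux      # pre-activation
      ht = torch.tanh(at)                             # new hidden state
      hidden_states.append(ht)

      ot = self.co + ht @ self.Vo                     # output logits
      ot = torch.softmax(ot, dim=1)                   # probabilities over output_size
      outputs.append(ot)

    return torch.stack(outputs), ht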
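As a sanity check on the recurrence, a sketch like the following compares the hidden-state loop from forward() against PyTorch's built-in nn.RNN with a tanh nonlinearity. Assumptions: the sizes and seed are arbitrary, only hidden states are compared (nn.RNN has no output layer), and because nn.RNN stores its weights as (out, in) and has two biases (bias_ih_l0 and bias_hh_l0), the transposed weights are copied in and the second bias is zeroed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch_size, input_size, hidden_size = 5, 3, 4, 6

# Parameters in the same (input_dim, output_dim) layout as the RNN class
Wh = torch.randn(hidden_size, hidden_size)
Ux = torch.randn(input_size, hidden_size)
bh = torch.randn(hidden_size)

inputs = torch.randn(seq_len, batch_size, input_size)

# Manual recurrence, mirroring the loop in forward()
ht = torch.zeros(batch_size, hidden_size)
manual = []
for t in range(seq_len):
    ht = torch.tanh(bh + ht @ Wh + inputs[t] @ Ux)
    manual.append(ht)
manual = torch.stack(manual)                 # (seq_len, batch_size, hidden_size)

# Built-in nn.RNN stores weights as (out, in), so copy in the transposes
ref = nn.RNN(input_size, hidden_size, nonlinearity="tanh")
with torch.no_grad():
    ref.weight_ih_l0.copy_(Ux.T)
    ref.weight_hh_l0.copy_(Wh.T)
    ref.bias_ih_l0.copy_(bh)
    ref.bias_hh_l0.zero_()                   # our class has a single hidden bias
    ref_out, ref_hn = ref(inputs)            # ref_out: all hidden states, ref_hn: last one

print(torch.allclose(manual, ref_out, atol=1e-5))
```

If this prints True, the loop computes the same hidden states as nn.RNN; the class above then adds the softmax output layer on top of each ht.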

Predicting the next word

We take a passage from English Wikipedia's featured article of 21 May 2025, convert it to lowercase, and split it into a list of tokens.

sentence = """The red-capped parrot (Purpureicephalus spurius) is a species of 
broad-tailed parrot native to southwest Western Australia. Described by Heinrich 
Kuhl in 1820, it is classified in its own genus owing to its distinctive 
elongated beak. Its closest relative is the mulga parrot. It is not easily 
confused with other parrot species; both adult sexes have a bright crimson 
crown, green-yellow cheeks, and a distinctive long bill. The wings, back, and 
long tail are dark green, and the underparts are purple-blue. Found in woodland 
and open savanna country, the red-capped parrot consumes seeds (particularly of 
eucalypts), flowers, berries, and occasionally insects. Nesting takes place in 
tree hollows. Although the red-capped parrot has been shot as a pest, and 
affected by land clearing, the population is growing and the species 
is not threatened."""

words = sentence.lower().split()
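Since str.split() only breaks on whitespace, punctuation stays attached to words, so "parrot" and "parrot." (from "the mulga parrot.") become different tokens. The sketch below shows this on a short excerpt of the passage, and then, as one possible alternative that these notes do not use, a regular-expression tokenizer that keeps hyphenated words together and splits punctuation into its own tokens. The regex pattern is an illustrative assumption.

```python
import re

excerpt = "Its closest relative is the mulga parrot. It is not easily confused with other parrot species;"

# Whitespace split, as in the notes: punctuation stays glued to the word
split_tokens = excerpt.lower().split()
print(split_tokens[5:8])   # ['mulga', 'parrot.', 'it']

# Hypothetical alternative: words (optionally hyphenated) or single punctuation marks
regex_tokens = re.findall(r"\w+(?:-\w+)*|[^\w\s]", excerpt.lower())
print(regex_tokens[5:9])   # ['mulga', 'parrot', '.', 'it']
```

The trade-off is vocabulary size: splitting punctuation off gives fewer distinct word tokens at the cost of longer sequences.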

Creating a simple mapping that stands in for .encode() and .decode()