import torch import torch.nn as nn import torch.optim as optim from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence class SimpleModel(nn.Module): # ... (same as before) # Sample input sequences with varying lengths input_sequences = [ [0, 1, 2], [3, 4, 5, 6], [7, 8, 9, 10, 11] ] # Convert input sequences to PyTorch tensors padded_input = nn.utils.rnn.pad_sequence([torch.LongTensor(seq) for seq in input_sequences], batch_first=True) # Create the model model = SimpleModel(input_size, hidden_size, output_size) criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) # Training loop optimizer.zero_grad() loss = 0 for i in range(len(input_sequences)): input_sequence = padded_input[i] target_sequence = input_sequence[1:] # Predict the next letter # Initialize hidden state hidden = (torch.zeros(1, len(input_sequence), hidden_size), torch.zeros(1, len(input_sequence), hidden_size)) # Forward pass packed_input = pack_padded_sequence(input_sequence.unsqueeze(0), [len(input_sequence)], batch_first=True) packed_output, hidden = model(packed_input, hidden) output, _ = pad_packed_sequence(packed_output, batch_first=True) # Compute loss loss += criterion(output.view(-1, output_size), target_sequence) loss.backward() optimizer.step() print("Loss:", loss.item()) 23,24,18,9,22,10,4,1,6,8,21,25,3,19,7,20,13,12,16 7,6,2,16,26,18,11,25,1,3,9,4,23,10,17,8,16,22,21,19 16,20,2,15,9,18,25,6,17,7,8,5,13,22,3,23,26,4,10,1 24,10,9,19,7,5,2,15,8,13,20,23,4,12,6,11,22,26,3,16 7,12,2,18,14,21,16,5,10,4,11,3,25,8,22,1,15,24,6,13 2,26,8,9,7,4,18,5,13,19,25,1,24,15,11,6,16,22,14,21 16,22,5,21,25,17,10,13,15,3,24,14,20,9,19,6,12,1,11 25,7,2,26,21,16,3,18,14,8,10,17,9,23,13,20,5,11,22,6 13,12,8,7,15,10,9,1,24,2,14,11,10,25,26,5,22,17,21,18 22,26,5,24,14,3,25,23,1,21,13,11,6,16,18,17,2,8,15,8 1,11,6,23,5,8,10,24,13,26,2,14,9,4,19,17,18,16,25,21 9,23,26,22,8,11,10,19,1,16,24,15,4,3,6,17,2,12,5,13 21,9,24,13,20,22,8,25,2,23,16,7,1,11,15,14,3,4,18,12 15,25,22,21,9,8,18,6,7,3,24,26,19,2,20,14,12,1,10,16 15,1,4,21,23,7,8,18,16,25,14,2,22,20,12,3,5,6,11,24 3,14,16,6,8,1,15,11,12,5,18,26,22,21,10,7,17,9,8,18 7,17,6,16,10,2,21,12,22,3,8,13,11,22,1,20,9,12,19,15 2,9,3,4,21,26,7,18,22,5,15,8,11,25,10,16,13,6,1,19 8,14,22,16,2,23,10,5,9,15,21,13,1,4,6,18,11,12,3,10 5,8,18,12,17,16,26,2,23,13,21,22,19,14,16,4,1,15,25,24 24,5,18,22,3,9,15,21,16,25,7,17,14,6,8,11,19,13,2,23 13,2,14,8,15,21,9,5,19,18,22,6,10,16,7,12,4,3,25,11 2,22,4,23,18,17,16,15,11,20,14,9,21,3,13,10