Text Classification with Neural Networks

Text Classification with Neural Networks#

OBJECTIVES

  • Understand Text Vectorization through Count Vectorization, Sequential Representation, and Word Embeddings

  • Use Multilayer Perceptrons, Recurrent Neural Networks, and Transformer Models for text classification

  • Use pytorch to build text classification models

  • Use huggingface and the transformers library to fine-tune existing large language models

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

Classifying Jailbreak Prompts#

A Jailbreak attack is an attempt to exploit an LLM’s vunerabilities. Today, we will take a few different approaches to identifying jailbreak attempts through text classification.

#read in data
splits = {'train': 'balanced/jailbreak_dataset_train_balanced.csv', 'test': 'balanced/jailbreak_dataset_test_balanced.csv'}
df = pd.read_csv("hf://datasets/jackhhao/jailbreak-classification/" + splits["train"])
#take a peek
df.head()
prompt type
0 You are a devoted fan of a celebrity. benign
1 You are Joseph Seed from Far Cry 5. Sermonize ... benign
2 As Rei Ayanami from Neon Genesis Evangelion, t... benign
3 You are a medieval blacksmith named Wulfric, l... benign
4 Answer the following question: So, I worked wi... benign
#create a tokenizer
tokenizer = Tokenizer(num_words = 500)
#fit the tokenizer -- learns the vocabulary
tokenizer.fit_on_texts(df['prompt'].values)
#look at tokenizer
tokenizer.num_words
500
#create document term matrix (binarized)
dtm = tokenizer.texts_to_matrix(df['prompt'].values)
#take a peek
dtm
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 1., 1., ..., 0., 0., 0.],
       [0., 1., 1., ..., 0., 0., 0.],
       ...,
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 1., 1., ..., 0., 0., 0.]])
tokenizer.index_word[2]
'and'
[tokenizer.index_word[i] for i in range(1, 50)]
['the',
 'and',
 'to',
 'you',
 'a',
 'of',
 'is',
 'in',
 'it',
 'as',
 'will',
 'that',
 'or',
 'are',
 'with',
 'not',
 'your',
 'i',
 'do',
 'for',
 'if',
 'this',
 'any',
 'be',
 'dan',
 'chatgpt',
 'can',
 'have',
 'answer',
 'an',
 'on',
 'always',
 'all',
 'by',
 'from',
 'about',
 'he',
 'must',
 'no',
 'like',
 'response',
 'anything',
 'should',
 'responses',
 'ai',
 'what',
 'now',
 'user',
 'but']
y = np.where(df['type'] == 'benign', 0, 1)
Xt = torch.tensor(dtm, dtype = torch.float32)
yt = torch.tensor(y, dtype = torch.float32)
from torch.utils.data import TensorDataset
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(Xt, yt, test_size=.2)
X_train
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 1.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 1., 1.,  ..., 0., 0., 0.],
        [0., 1., 1.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.]])
#create data class
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)
#dataset and loader -- making batches of our bigger dataset
trainloader = DataLoader(train_dataset, batch_size = 32)
#dataset and loader
testloader = DataLoader(test_dataset, batch_size = 32)
model = nn.Sequential(nn.Linear(in_features=500, out_features=1000),
                      nn.ReLU(),
                      nn.Linear(1000, 100),
                      nn.ReLU(),
                      nn.Linear(100, 1),
                      nn.Sigmoid()
                      )
model(Xt)
tensor([[0.4850],
        [0.4867],
        [0.4842],
        ...,
        [0.4856],
        [0.4850],
        [0.4841]], grad_fn=<SigmoidBackward0>)
loss_fn = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr = 0.01)
from tqdm import tqdm
model = model.to('cuda')
#keep track of the losses
losses = []
#train it for 20 epochs
for epoch in tqdm(range(20)):
  #iterate over the batches
  for x,y in trainloader:
    x = x.to('cuda')
    y = y.to('cuda')
    #feeds data into model
    yhat = model(x)
    #evaluate the predictions
    loss = loss_fn(yhat, y.unsqueeze(1))
    #update the weights/params
    optimizer.zero_grad() #pytorch house cleaning
    loss.backward() #pass info backwards
    optimizer.step() #step towards less loss
    losses.append(loss.item()) #tracking the loss
100%|██████████| 20/20 [00:01<00:00, 15.23it/s]
X_train = X_train.to('cuda')
model(X_train)
tensor([[7.6080e-12],
        [1.0000e+00],
        [5.7215e-08],
        [5.2058e-10],
        [1.0000e+00],
        [1.2170e-06],
        [2.7917e-06],
        [2.5559e-06],
        [1.9318e-09],
        [4.4235e-06],
        [3.1658e-14],
        [1.3348e-10],
        [8.7853e-21],
        [1.8574e-26],
        [1.0863e-12],
        [1.0000e+00],
        [4.8573e-05],
        [1.3076e-08],
        [8.7744e-13],
        [1.0000e+00],
        [1.7011e-27],
        [9.1420e-06],
        [1.1157e-31],
        [2.3193e-10],
        [8.9439e-13],
        [1.0000e+00],
        [3.1938e-29],
        [2.1481e-13],
        [2.3711e-17],
        [1.5619e-18],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [2.7602e-09],
        [1.0000e+00],
        [1.0000e+00],
        [5.4653e-13],
        [6.1534e-08],
        [1.0000e+00],
        [1.0756e-11],
        [3.5019e-08],
        [1.0000e+00],
        [1.6152e-35],
        [1.0000e+00],
        [1.0000e+00],
        [4.0023e-08],
        [1.0000e+00],
        [1.0000e+00],
        [2.6483e-11],
        [1.0000e+00],
        [1.0000e+00],
        [1.5129e-14],
        [4.7407e-11],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.1003e-06],
        [1.0000e+00],
        [4.0842e-07],
        [3.4858e-24],
        [3.0237e-06],
        [1.0000e+00],
        [1.0000e+00],
        [9.2892e-06],
        [1.0000e+00],
        [2.4930e-13],
        [2.5443e-17],
        [1.0000e+00],
        [1.8773e-18],
        [5.4862e-13],
        [1.3422e-12],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [4.3333e-19],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [4.7768e-07],
        [1.0000e+00],
        [6.7554e-12],
        [1.0000e+00],
        [9.9700e-09],
        [1.1951e-06],
        [1.0000e+00],
        [1.0000e+00],
        [4.0301e-13],
        [4.0222e-10],
        [1.0000e+00],
        [1.2033e-30],
        [1.0000e+00],
        [1.0000e+00],
        [7.0814e-14],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [6.6346e-08],
        [1.0000e+00],
        [3.9480e-12],
        [1.0000e+00],
        [1.0000e+00],
        [3.2298e-16],
        [6.9528e-20],
        [1.0000e+00],
        [5.9310e-34],
        [1.0000e+00],
        [3.3052e-07],
        [8.9197e-10],
        [8.4916e-14],
        [1.0000e+00],
        [3.7740e-09],
        [3.5425e-13],
        [2.0322e-20],
        [8.0247e-15],
        [1.0000e+00],
        [2.4492e-09],
        [1.0000e+00],
        [1.0000e+00],
        [5.8535e-22],
        [3.1931e-10],
        [1.0781e-10],
        [1.2999e-16],
        [1.0000e+00],
        [1.0000e+00],
        [7.6543e-17],
        [4.1282e-14],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [5.0808e-31],
        [1.0000e+00],
        [1.0357e-09],
        [1.0000e+00],
        [3.3517e-31],
        [1.0000e+00],
        [2.0929e-12],
        [1.0000e+00],
        [2.2723e-07],
        [1.0000e+00],
        [1.0000e+00],
        [3.4682e-11],
        [1.0000e+00],
        [8.1089e-24],
        [1.0000e+00],
        [4.6083e-29],
        [8.8073e-21],
        [3.5441e-14],
        [2.0895e-07],
        [2.4707e-08],
        [1.0000e+00],
        [1.8157e-07],
        [1.0000e+00],
        [1.0000e+00],
        [1.4945e-11],
        [1.8185e-12],
        [1.0035e-31],
        [1.0000e+00],
        [1.0000e+00],
        [8.4013e-07],
        [4.9424e-07],
        [2.5908e-21],
        [3.9454e-13],
        [4.2798e-17],
        [2.3836e-18],
        [5.9513e-07],
        [4.8654e-06],
        [1.1005e-07],
        [9.6565e-10],
        [1.0000e+00],
        [1.0000e+00],
        [1.9583e-13],
        [2.2274e-17],
        [1.0000e+00],
        [1.0000e+00],
        [2.1209e-12],
        [9.7197e-08],
        [6.9110e-11],
        [1.0000e+00],
        [1.9209e-14],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [3.3753e-07],
        [1.0710e-11],
        [1.3585e-13],
        [1.0000e+00],
        [2.3729e-22],
        [9.7339e-13],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [5.4023e-11],
        [1.0000e+00],
        [1.0000e+00],
        [7.8555e-11],
        [1.0000e+00],
        [1.5852e-13],
        [3.0759e-12],
        [1.0000e+00],
        [1.7310e-06],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.1813e-15],
        [5.6922e-19],
        [1.0000e+00],
        [1.6350e-14],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [2.8431e-06],
        [1.0000e+00],
        [2.4273e-17],
        [1.0000e+00],
        [2.2495e-13],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [4.7385e-07],
        [3.6363e-09],
        [1.0000e+00],
        [8.4714e-11],
        [3.4570e-13],
        [1.0000e+00],
        [3.0373e-17],
        [1.3614e-09],
        [1.0000e+00],
        [6.6554e-07],
        [4.6547e-20],
        [2.1885e-14],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [4.8972e-12],
        [1.6769e-12],
        [1.0000e+00],
        [3.4732e-15],
        [6.7583e-15],
        [8.7086e-25],
        [1.0123e-09],
        [2.9289e-08],
        [5.8780e-39],
        [2.0881e-14],
        [1.0000e+00],
        [1.7181e-09],
        [1.0000e+00],
        [1.8996e-15],
        [2.5248e-17],
        [1.0000e+00],
        [1.7710e-11],
        [1.0000e+00],
        [1.0000e+00],
        [8.5199e-06],
        [7.1810e-33],
        [1.0000e+00],
        [1.0335e-18],
        [1.0000e+00],
        [2.1272e-30],
        [3.1929e-24],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [7.1056e-13],
        [7.3681e-22],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [3.5328e-23],
        [9.3458e-08],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [2.4462e-25],
        [1.0000e+00],
        [1.0000e+00],
        [8.7466e-17],
        [1.0000e+00],
        [1.0000e+00],
        [5.0269e-13],
        [3.5592e-09],
        [2.0682e-10],
        [1.0061e-12],
        [9.9999e-01],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.7691e-22],
        [3.8480e-19],
        [1.0000e+00],
        [1.0000e+00],
        [7.3466e-07],
        [8.5948e-11],
        [1.0000e+00],
        [1.7310e-06],
        [1.2006e-11],
        [5.2096e-33],
        [1.0000e+00],
        [5.0874e-13],
        [1.0000e+00],
        [4.0452e-20],
        [1.0000e+00],
        [2.4319e-11],
        [4.5418e-08],
        [9.6719e-27],
        [1.0000e+00],
        [1.1523e-10],
        [1.0000e+00],
        [1.1233e-08],
        [1.0000e+00],
        [1.0000e+00],
        [3.8030e-08],
        [5.8437e-06],
        [4.0740e-11],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [5.6462e-11],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [4.7617e-09],
        [1.0000e+00],
        [3.0044e-19],
        [1.0000e+00],
        [1.0000e+00],
        [1.0820e-08],
        [1.0000e+00],
        [6.5817e-15],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [8.9099e-12],
        [1.3741e-06],
        [2.6022e-15],
        [1.0000e+00],
        [6.2634e-11],
        [1.0000e+00],
        [3.0271e-39],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [8.2249e-04],
        [6.5539e-08],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [5.1980e-09],
        [1.5318e-12],
        [7.4109e-09],
        [2.7556e-15],
        [7.3833e-07],
        [1.8009e-07],
        [1.0000e+00],
        [4.5214e-11],
        [5.7082e-12],
        [1.0000e+00],
        [1.0000e+00],
        [5.1360e-11],
        [1.0599e-08],
        [1.0000e+00],
        [5.1184e-16],
        [2.2450e-07],
        [6.0685e-08],
        [8.8499e-12],
        [1.0824e-12],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [3.5133e-19],
        [1.9280e-07],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [2.3835e-15],
        [1.2285e-12],
        [1.0000e+00],
        [9.1441e-08],
        [1.4274e-16],
        [3.5973e-12],
        [5.6507e-07],
        [1.0000e+00],
        [2.7861e-27],
        [1.0000e+00],
        [3.6147e-18],
        [8.8834e-10],
        [2.2726e-15],
        [2.9571e-08],
        [2.3905e-08],
        [1.0000e+00],
        [1.0000e+00],
        [6.8630e-12],
        [1.0000e+00],
        [1.0000e+00],
        [8.8646e-38],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [5.4823e-06],
        [1.8902e-07],
        [6.2151e-06],
        [2.2608e-09],
        [5.8893e-09],
        [5.9277e-35],
        [1.0000e+00],
        [1.0000e+00],
        [1.0346e-36],
        [1.0000e+00],
        [1.0000e+00],
        [1.1400e-12],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [2.8878e-13],
        [4.6045e-25],
        [1.9714e-08],
        [1.7071e-06],
        [1.0000e+00],
        [1.0000e+00],
        [9.9999e-01],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.1512e-08],
        [1.2562e-32],
        [6.0701e-05],
        [1.0446e-07],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [3.3442e-12],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [2.2263e-07],
        [8.8751e-10],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.5873e-06],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [9.9815e-01],
        [1.0000e+00],
        [2.6574e-11],
        [1.0000e+00],
        [8.9289e-16],
        [1.0661e-08],
        [1.0000e+00],
        [3.8447e-15],
        [1.0000e+00],
        [4.9434e-09],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.7705e-09],
        [1.0000e+00],
        [9.0765e-12],
        [1.6568e-22],
        [1.0000e+00],
        [3.7118e-15],
        [8.2185e-10],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [3.8024e-14],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [4.1724e-09],
        [1.0000e+00],
        [2.4444e-10],
        [5.5473e-14],
        [6.7498e-26],
        [4.5167e-08],
        [1.0000e+00],
        [2.9914e-19],
        [1.0000e+00],
        [1.0000e+00],
        [3.5664e-06],
        [1.0000e+00],
        [1.6103e-06],
        [1.0000e+00],
        [1.2317e-08],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [3.4416e-12],
        [1.0000e+00],
        [1.7454e-22],
        [1.0000e+00],
        [2.6619e-22],
        [4.4394e-05],
        [6.8070e-10],
        [1.0000e+00],
        [1.0000e+00],
        [1.5606e-25],
        [9.9296e-13],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [2.5479e-19],
        [1.0000e+00],
        [1.0000e+00],
        [1.2502e-16],
        [4.9235e-10],
        [8.2249e-04],
        [1.7822e-29],
        [1.7208e-13],
        [5.7934e-06],
        [1.0000e+00],
        [5.8547e-34],
        [1.0000e+00],
        [8.9631e-15],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.7677e-25],
        [2.4887e-08],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [6.0555e-12],
        [1.3976e-11],
        [7.2694e-12],
        [1.0000e+00],
        [1.0000e+00],
        [3.8431e-07],
        [3.0476e-15],
        [3.1762e-15],
        [2.3198e-34],
        [1.0000e+00],
        [2.1554e-13],
        [1.6552e-11],
        [1.0000e+00],
        [6.8611e-08],
        [3.5678e-19],
        [6.3688e-09],
        [1.0000e+00],
        [1.9865e-07],
        [1.0000e+00],
        [1.3561e-07],
        [9.5887e-10],
        [1.0000e+00],
        [1.3309e-11],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [7.2320e-17],
        [5.0319e-13],
        [1.0000e+00],
        [2.3231e-19],
        [8.3355e-13],
        [8.9346e-09],
        [1.0000e+00],
        [4.0053e-13],
        [1.0000e+00],
        [4.5948e-08],
        [1.0000e+00],
        [1.0000e+00],
        [3.9865e-08],
        [9.2344e-10],
        [4.0831e-12],
        [1.0000e+00],
        [1.0000e+00],
        [5.4003e-08],
        [2.1252e-14],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [4.8304e-06],
        [9.9999e-01],
        [1.2097e-10],
        [1.4807e-12],
        [5.2954e-08],
        [1.2981e-11],
        [1.0000e+00],
        [1.5609e-10],
        [1.0337e-06],
        [1.0000e+00],
        [1.0000e+00],
        [7.2601e-10],
        [2.6015e-11],
        [4.0895e-18],
        [1.0000e+00],
        [1.0000e+00],
        [1.9136e-30],
        [1.0000e+00],
        [4.3966e-11],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [6.4488e-11],
        [1.0000e+00],
        [2.7153e-09],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.2352e-11],
        [1.0000e+00],
        [2.5442e-13],
        [5.3482e-20],
        [1.0000e+00],
        [7.7598e-17],
        [4.3691e-12],
        [1.0000e+00],
        [1.7015e-28],
        [1.0000e+00],
        [1.0000e+00],
        [1.1883e-10],
        [2.2860e-08],
        [2.5602e-12],
        [5.6117e-13],
        [5.4564e-09],
        [1.0000e+00],
        [1.4737e-07],
        [1.0000e+00],
        [1.0000e+00],
        [1.5269e-12],
        [2.4767e-08],
        [1.0000e+00],
        [1.3555e-14],
        [1.7366e-12],
        [1.0000e+00],
        [1.0000e+00],
        [4.5586e-07],
        [1.0000e+00],
        [1.3741e-06],
        [1.7016e-11],
        [1.0000e+00],
        [3.3735e-09],
        [1.0000e+00],
        [6.6276e-06],
        [1.4308e-23],
        [1.0000e+00],
        [1.0000e+00],
        [9.9095e-10],
        [1.0000e+00],
        [1.0000e+00],
        [1.9658e-29],
        [1.0000e+00],
        [1.2357e-13],
        [1.0000e+00],
        [1.0000e+00],
        [9.7498e-08],
        [8.7968e-11],
        [4.7344e-08],
        [1.0000e+00],
        [4.4410e-37],
        [3.7485e-14],
        [5.6180e-11],
        [7.1872e-22],
        [1.0000e+00],
        [1.0000e+00],
        [2.6261e-14],
        [4.8668e-06],
        [1.0071e-08],
        [1.0871e-10],
        [1.5981e-10],
        [1.0000e+00],
        [2.2093e-07],
        [3.0221e-15],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [5.0190e-08],
        [1.0000e+00],
        [5.9400e-07],
        [1.1787e-12],
        [2.4028e-15],
        [6.3309e-24],
        [1.0000e+00],
        [1.7048e-06],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [2.1686e-27],
        [1.0000e+00],
        [4.8304e-06],
        [1.0000e+00],
        [6.6314e-08],
        [2.0999e-08],
        [1.0000e+00],
        [1.8951e-17],
        [4.3959e-08],
        [2.9466e-32],
        [2.9071e-13],
        [1.0000e+00],
        [1.1377e-06],
        [1.0000e+00],
        [4.1744e-15],
        [1.0000e+00],
        [7.8604e-09],
        [3.6231e-08],
        [1.0000e+00],
        [1.0000e+00],
        [3.7443e-09],
        [3.0113e-15],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [2.7225e-17],
        [1.5757e-08],
        [2.1029e-18],
        [1.0000e+00],
        [2.4945e-07],
        [1.0000e+00],
        [1.0000e+00],
        [1.0955e-04],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [2.2524e-10],
        [1.0000e+00],
        [2.0474e-11],
        [1.0000e+00],
        [9.6384e-16],
        [7.1075e-13],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [3.1319e-10],
        [1.0000e+00],
        [2.0716e-08],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [2.9775e-09],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [4.1768e-15],
        [1.0000e+00],
        [6.9567e-07],
        [4.2268e-07],
        [1.0000e+00],
        [1.0000e+00],
        [1.4883e-15],
        [2.8013e-12],
        [1.0000e+00],
        [7.0262e-10],
        [5.4125e-15],
        [2.6895e-33],
        [2.4506e-18],
        [1.0000e+00],
        [2.5652e-11],
        [1.2826e-08],
        [1.0562e-08],
        [1.0000e+00],
        [1.0000e+00],
        [3.6054e-10],
        [1.0000e+00],
        [1.0000e+00],
        [1.3106e-09],
        [1.0000e+00],
        [8.1108e-08],
        [6.3903e-08],
        [1.0000e+00],
        [1.0000e+00],
        [2.5173e-09],
        [2.4364e-11],
        [1.0000e+00],
        [1.0000e+00],
        [1.6809e-13],
        [1.0000e+00],
        [1.0000e+00],
        [2.2221e-15],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00],
        [3.3765e-21],
        [1.0000e+00],
        [1.0000e+00],
        [2.2045e-14],
        [1.0000e+00],
        [7.6733e-09]], device='cuda:0', grad_fn=<SigmoidBackward0>)
train_predictions = model(X_train)
ytrain_preds = torch.where(train_predictions > .5, 1, 0)
ytrain_preds.shape
torch.Size([835, 1])
y_train.shape
torch.Size([835])
y_train = y_train.to('cuda')
torch.sum(ytrain_preds.squeeze(1) == y_train)/len(y_train)
tensor(1., device='cuda:0')
X_test, y_test = X_test.to('cuda'), y_test.to('cuda')
ytest_preds = torch.where(model(X_test) > .5, 1, 0)
torch.sum(ytest_preds.squeeze(1) == y_test)/len(y_test)
tensor(0.9617, device='cuda:0')
#loss and optimizer
class TextModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.lin1 = nn.Linear(in_features = 500, out_features = 100)
    self.lin2 = nn.Linear(100, 100)
    self.lin3 = nn.Linear(100, 1)
    self.sigmoid = nn.Sigmoid()
    self.act = nn.ReLU()

  def forward(self, x):
    x = self.act(self.lin1(x))
    x = self.act(self.lin2(x))
    return self.sigmoid(self.lin3(x))


#training function
model = TextModel()
optimizer = optim.Adam(model.parameters(), lr = 0.01)
loss_fn = nn.BCELoss()
#torch.save(model, 'textmodel.pt')
from tqdm import tqdm
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
#evaluate
for epoch in tqdm(range(10)):
  losses = 0
  for x,y in trainloader:
    x = x.to(device)
    y = y.to(device)
    yhat = model(x)
    y = y.reshape(-1, 1)
    loss = loss_fn(yhat, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses += loss.item()
  if epoch % 10 == 0:
    print(f'Epoch {epoch} Loss: {losses}')
 40%|████      | 4/10 [00:00<00:00, 18.42it/s]
Epoch 0 Loss: 7.5143995471298695
100%|██████████| 10/10 [00:00<00:00, 18.21it/s]
Xt = torch.tensor(X_test.to(device), dtype = torch.float)
/tmp/ipython-input-823397259.py:1: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  Xt = torch.tensor(X_test.to(device), dtype = torch.float)

output = model(Xt) #model predictions
output[:5]
tensor([[4.5087e-07],
        [2.6038e-06],
        [1.0000e+00],
        [1.0000e+00],
        [1.0000e+00]], device='cuda:0', grad_fn=<SliceBackward0>)
#Converting probabilities to prediction
preds = torch.where(output >= .5, 1, 0)
preds.shape
torch.Size([209, 1])
sum(preds[:, 0] == y_test)/len(y_test)
tensor(0.9761, device='cuda:0')

Basic RNN#

# !pip install -U torch torchtext
#new tokenizer
tokenizer = Tokenizer()
tokenizer.fit_on_texts(df['prompt'].values)
#create sequences
sequences = tokenizer.texts_to_sequences(df['prompt'].values)
#look at first sequence
sequences[0]
[4, 14, 5, 3094, 4337, 6, 5, 5442]
#compare to text
df['prompt'].values[1]
'You are Joseph Seed from Far Cry 5. Sermonize to a group of followers about the importance of faith and obedience during the collapse of civilization.'
sequences[1]
[4,
 14,
 4338,
 7554,
 35,
 1060,
 2721,
 203,
 7555,
 3,
 5,
 657,
 6,
 3599,
 36,
 1,
 1805,
 6,
 2722,
 2,
 1583,
 470,
 1,
 3600,
 6,
 3095]
#pad and make all same length
sequences = pad_sequences(sequences, maxlen=100)
#examine results
sequences[1].shape
(100,)
sequences[1]
array([   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    4,   14, 4338,
       7554,   35, 1060, 2721,  203, 7555,    3,    5,  657,    6, 3599,
         36,    1, 1805,    6, 2722,    2, 1583,  470,    1, 3600,    6,
       3095], dtype=int32)
#example rnn
rnn = nn.RNN(input_size = 100,
             hidden_size = 30,
             num_layers = 1,
             batch_first = True)
#pass data through
sample_sequence = torch.tensor(sequences[1],
                               dtype = torch.float,
                               ).reshape(1, -1)
sample_sequence.shape
torch.Size([1, 100])
#output
output, hidden = rnn(sample_sequence)
#hidden
hidden.shape
torch.Size([1, 30])
#linear layer
output.shape
torch.Size([1, 30])
#pass through linear
lin1 = nn.Linear(in_features = 30, out_features = 1)
lin1(output)
tensor([[0.6444]], grad_fn=<AddmmBackward0>)
X_train, X_test, y_train, y_test = train_test_split(sequences, yt)
train_dataset = TensorDataset(torch.tensor(X_train, dtype=torch.float32), y_train)
test_dataset = TensorDataset(torch.tensor(X_test, dtype=torch.float32), y_test)
trainloader = DataLoader(train_dataset, batch_size = 32)
testloader = DataLoader(test_dataset, batch_size = 32)
model = nn.Sequential(nn.RNN(input_size = 100, hidden_size = 50, num_layers=2),
                      nn.Linear(in_features = 50, out_features=1),
                      nn.Sigmoid())
ex_rnn = nn.RNN(input_size = 100, hidden_size = 50, num_layers=2)
ex_rnn(train_dataset[0][0].unsqueeze(0))
(tensor([[-4.0267e-02, -6.5108e-01,  5.6029e-01, -1.8202e-01,  7.5730e-01,
          -7.5152e-01,  4.0158e-01,  1.2469e-03,  7.9779e-01,  3.2215e-01,
          -7.4351e-01, -6.9993e-01, -1.3441e-01, -6.5982e-01,  1.8362e-02,
          -2.6788e-01, -1.9714e-01, -2.8105e-01,  7.9029e-01,  5.8012e-01,
          -6.4096e-01,  6.6497e-01,  4.7763e-01,  6.3916e-01, -2.8295e-01,
          -3.6966e-01, -3.0895e-01,  5.7973e-01, -5.4188e-01, -1.9349e-01,
          -2.6235e-01,  8.9672e-01,  1.8769e-01, -1.2924e-02,  4.4903e-01,
           4.5544e-01,  2.2435e-01,  1.1875e-01,  3.5073e-01,  7.1262e-01,
          -3.2007e-01, -3.6262e-01, -1.0571e-01, -5.7416e-01, -3.8877e-01,
           4.8972e-01, -8.4148e-04,  3.9820e-01, -4.7924e-01,  1.2552e-01]],
        grad_fn=<SqueezeBackward1>),
 tensor([[ 1.0000e+00, -1.0000e+00,  1.0000e+00, -1.0000e+00, -1.0000e+00,
           1.0000e+00,  1.0000e+00,  1.0000e+00, -1.0000e+00,  1.0000e+00,
           1.0000e+00, -1.0000e+00, -1.0000e+00,  1.0000e+00, -9.9918e-01,
           1.0000e+00,  1.0000e+00, -1.0000e+00,  1.0000e+00, -1.0000e+00,
           1.0000e+00, -1.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00,
          -1.0000e+00,  1.0000e+00,  1.0000e+00, -1.0000e+00,  1.0000e+00,
           1.0000e+00,  1.0000e+00, -1.0000e+00, -1.0000e+00,  1.0000e+00,
          -1.0000e+00, -1.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00,
           1.0000e+00,  1.0000e+00, -1.0000e+00, -1.0000e+00, -1.0000e+00,
           1.0000e+00, -1.0000e+00, -1.0000e+00,  1.0000e+00,  1.0000e+00],
         [-4.0267e-02, -6.5108e-01,  5.6029e-01, -1.8202e-01,  7.5730e-01,
          -7.5152e-01,  4.0158e-01,  1.2469e-03,  7.9779e-01,  3.2215e-01,
          -7.4351e-01, -6.9993e-01, -1.3441e-01, -6.5982e-01,  1.8362e-02,
          -2.6788e-01, -1.9714e-01, -2.8105e-01,  7.9029e-01,  5.8012e-01,
          -6.4096e-01,  6.6497e-01,  4.7763e-01,  6.3916e-01, -2.8295e-01,
          -3.6966e-01, -3.0895e-01,  5.7973e-01, -5.4188e-01, -1.9349e-01,
          -2.6235e-01,  8.9672e-01,  1.8769e-01, -1.2924e-02,  4.4903e-01,
           4.5544e-01,  2.2435e-01,  1.1875e-01,  3.5073e-01,  7.1262e-01,
          -3.2007e-01, -3.6262e-01, -1.0571e-01, -5.7416e-01, -3.8877e-01,
           4.8972e-01, -8.4148e-04,  3.9820e-01, -4.7924e-01,  1.2552e-01]],
        grad_fn=<SqueezeBackward1>))
#class
class BasicRNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.rnn = nn.RNN(input_size = 100,
                    hidden_size = 100,
                    num_layers = 2,
                    batch_first = True,
                    bidirectional = True)
    self.lin1 = nn.Linear(in_features = 200, out_features=128)
    self.lin2 = nn.Linear(128,64)
    self.lin3 = nn.Linear(64, 1)
    self.act = nn.ReLU()
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    x, _ = self.rnn(x) #extracting important information
    x = self.act(self.lin1(x)) #multilayer perceptron -- to predict
    x = self.act(self.lin2(x))
    x = self.sigmoid(self.lin3(x))
    return x
#optimizer and loss
model = BasicRNN()
optimizer = optim.Adam(model.parameters(), lr = 0.01)
loss_fn = nn.BCELoss()
model = model.to(device)
#train
for epoch in tqdm(range(100)):
  losses = 0
  for x,y in trainloader:
    x,y = x.to(device), y.to(device)
    yhat = model(x)
    y = y.reshape(-1, 1)
    loss = loss_fn(yhat, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses += loss.item()
  if epoch % 100 == 0:
    print(f'Epoch {epoch} Loss: {losses}')
  2%|▏         | 2/100 [00:00<00:09, 10.49it/s]
Epoch 0 Loss: 16.39161452651024
100%|██████████| 100/100 [00:09<00:00, 10.90it/s]
Xt = torch.tensor(X_test, dtype = torch.float)
output = model(Xt.to(device))
preds = torch.where(output >= .5, 1, 0)
#preds = output.argmax(axis = 1)
y_test.shape, preds.shape
(torch.Size([261]), torch.Size([261, 1]))
y_test = y_test.to(device)
sum(preds.squeeze(-1) == y_test)/len(y_test)
tensor(0.7126, device='cuda:0')

LSTM#

# nn.LSTM()
class BasicLSTM(nn.Module):
  def __init__(self):
    super().__init__()
    self.rnn = nn.LSTM(input_size = 100,
                    hidden_size = 100,
                    num_layers = 1,
                    batch_first = True)

    self.lin1 = nn.Linear(in_features = 100, out_features=100)
    self.lin2 = nn.Linear(in_features = 100, out_features = 1)
    self.act = nn.ReLU()
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    x, _ = self.rnn(x)
    x = self.act(self.lin1(x))
    x = self.lin2(x)
    return self.sigmoid(x)
model = BasicLSTM()
optimizer = optim.Adam(model.parameters(), lr = 0.01)
loss_fn = nn.BCELoss()
model = model.to(device)
#train
for epoch in range(100):
  losses = 0
  for x,y in trainloader:
    x,y = x.to(device), y.to(device)
    yhat = model(x)
    y = y.reshape(-1, 1)
    loss = loss_fn(yhat, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses += loss.item()
  if epoch % 10 == 0:
    print(f'Epoch {epoch} Loss: {losses}')
Epoch 0 Loss: 16.08923327922821
Epoch 10 Loss: 12.361571580171585
Epoch 20 Loss: 11.54874774813652
Epoch 30 Loss: 11.193515717983246
Epoch 40 Loss: 10.83355501294136
Epoch 50 Loss: 10.410791963338852
Epoch 60 Loss: 10.066820561885834
Epoch 70 Loss: 9.951143264770508
Epoch 80 Loss: 9.611373007297516
Epoch 90 Loss: 9.232007339596748
Xt = torch.tensor(X_test, dtype = torch.float).to(device)
output = model(Xt)
preds = torch.where(output >= .5, 1, 0)
sum(preds[:, 0] == y_test)/len(y_test)
tensor(0.6437, device='cuda:0')
class RNN2(nn.Module):
  def __init__(self):
    super().__init__()
    self.rnn = nn.GRU(input_size = 100,
                    hidden_size = 50,
                    num_layers = 2,
                    batch_first = True)

    self.lin1 = nn.Linear(in_features = 50, out_features=100)
    self.lin2 = nn.Linear(in_features = 100, out_features = 1)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    x, _ = self.rnn(x)
    x = self.lin1(x)
    x = self.lin2(x)
    return self.sigmoid(x)
model = RNN2()
optimizer = optim.Adam(model.parameters(), lr = 0.01)
loss_fn = nn.BCELoss()
model = model.to(device)
#train
for epoch in range(100):
  losses = 0
  for x,y in trainloader:
    x, y = x.to(device), y.to(device)
    yhat = model(x)
    y = y.reshape(-1, 1)
    loss = loss_fn(yhat, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses += loss.item()
  if epoch % 100 == 0:
    print(f'Epoch {epoch} Loss: {losses}')
Epoch 0 Loss: 2.572824880015105
Epoch 10 Loss: 1.6687806567642838
Epoch 20 Loss: 2.084140522405505
Epoch 30 Loss: 0.14215507486005663
Epoch 40 Loss: 2.5895073297433555
Epoch 50 Loss: 4.6147964131087065
Epoch 60 Loss: 1.1158297238289379
Epoch 70 Loss: 1.6419790575746447
Epoch 80 Loss: 1.0041726826311788
Epoch 90 Loss: 7.350915879011154
Epoch 100 Loss: 0.32188589729139494
Epoch 110 Loss: 0.826660448419716
Epoch 120 Loss: 0.09769756537468766
Epoch 130 Loss: 0.010040155945596041
Epoch 140 Loss: 0.005505460409189311
Epoch 150 Loss: 0.0035882416049775046
Epoch 160 Loss: 0.002542968425865322
Epoch 170 Loss: 0.0018933118432123974
Epoch 180 Loss: 0.0014562517539018494
Epoch 190 Loss: 0.001146036667118655
Epoch 200 Loss: 0.0009173696277831878
Epoch 210 Loss: 0.0007440028362964424
Epoch 220 Loss: 0.0006096175069977006
Epoch 230 Loss: 0.0005037316108048862
Epoch 240 Loss: 0.00041903144462378944
Epoch 250 Loss: 0.0003505299123727282
Epoch 260 Loss: 0.0002946198410908145
Epoch 270 Loss: 0.0002486322255693882
Epoch 280 Loss: 0.0002105297801064561
Epoch 290 Loss: 0.0001787880755889093
Epoch 300 Loss: 0.0001522262615572187
Epoch 310 Loss: 0.00012987622631860287
Epoch 320 Loss: 0.0001110454306913625
Epoch 330 Loss: 9.509693933032644e-05
Epoch 340 Loss: 8.157440949980001e-05
Epoch 350 Loss: 7.005838113996249e-05
Epoch 360 Loss: 6.024776315617349e-05
Epoch 370 Loss: 5.187681379641715e-05
Epoch 380 Loss: 4.469714198158123e-05
Epoch 390 Loss: 3.8565075329249085e-05
Epoch 400 Loss: 3.3286928844570044e-05
Epoch 410 Loss: 2.875219730579491e-05
Epoch 420 Loss: 2.4848146654375767e-05
Epoch 430 Loss: 2.150737431606364e-05
Epoch 440 Loss: 1.860629562645599e-05
Epoch 450 Loss: 1.6106333091094147e-05
Epoch 460 Loss: 1.3980879112434865e-05
Epoch 470 Loss: 1.2109346418927346e-05
Epoch 480 Loss: 1.0493581766806125e-05
Epoch 490 Loss: 9.120325635563579e-06
output = model(Xt)
preds = torch.where(output >= .5, 1, 0)
sum(preds.squeeze(-1) == y_test)/len(y_test)
tensor(0.6475, device='cuda:0')