Introduction to Artificial Neural Networks#

OBJECTIVES

  • Understand connections between linear models and neural networks

  • Use PyTorch to build basic multilayer perceptron models for regression and classification

NEXT CLASS!!!

NYU faculty, administrators, staff, and students may apply for a WRDS account. Non-PhD students can request a temporary Research Assistant account when working with faculty.

  • Visit Wharton Research Data Services

  • Click the Register button.

  • Complete the form, read the Terms of Use, and submit.

  • The NYU Representative for WRDS will review your application and affiliation.

  • Upon approval, WRDS will email instructions for setting your password.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

Please use the form here to discuss your time series problem and final project ideas.

import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

Linear Models Revisited#

Recall our logistic regression model. We can think of this model as a linear function composed with the sigmoid function.

\[l(x) = \beta_0 + \beta_1 x\]
\[\sigma(x) = \frac{1}{1 + e^{-x}}\]
\[\text{Logistic Model} = \sigma(l(x))\]

We also measured the quality of the model by minimizing a loss function – log loss. This is also called binary cross entropy (BCE), and the “best” model was the one whose parameters minimized the log loss or BCE. Together, we have two important ideas – a linear function and the goal of minimizing a loss function.
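For reference, if p_i denotes the model's predicted probability for observation i and y_i its 0/1 label, the log loss is:

\[\text{BCE} = -\frac{1}{n}\sum_{i=1}^{n}\left[y_i \log p_i + (1 - y_i)\log(1 - p_i)\right]\]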

Linear regression is a similar algorithm with a different loss function – Mean Squared Error.
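Here the loss is the average squared distance between the targets and the predictions:

\[\text{MSE} = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2\]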

Minimizing Functions#

x = torch.linspace(-2, 2, 100)
def f(x): return x**2
plt.plot(x, f(x));
[figure: plot of f(x) = x**2]
def df(x): return 2*x  # derivative of f
def tan_line(x, a): return df(a)*(x - a) + f(a)  # tangent line to f at x = a
plt.plot(x, f(x))
plt.plot(x, tan_line(x, 1), 'r')
plt.plot(1, f(1), 'ro')
plt.ylim(0, 4)
plt.grid();
[figure: f(x) = x**2 with its tangent line at x = 1]
To minimize the function, gradient descent starts at a point and repeatedly steps against the derivative:

\[x_{n+1} = x_n - \text{step size} \cdot f'(x_n)\]
x0 = 1
step_size = 0.1
x1 = x0 - step_size*df(x0)
x1
0.8
x2 = x1 - step_size*df(x1)
x2
0.64
xs = [x0]
for i in range(100):
    xnext = xs[-1] - step_size*df(xs[-1])  # one gradient descent update
    xs.append(xnext)
xs = np.array(xs)
plt.plot(xs, f(xs), 'ro', label = 'Iterations of Descent')
plt.plot(x, f(x))
plt.grid()
plt.title('Minimizing A Function with Gradient Descent')
plt.legend();
[figure: Minimizing A Function with Gradient Descent]
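Why the iterates converge here: with step size 0.1 the update rule becomes

\[x_{n+1} = x_n - 0.1 \cdot 2x_n = 0.8\,x_n,\]

so each step shrinks the current point by a factor of 0.8 and the sequence decays geometrically toward the minimizer at 0. A step size above 1 would make this factor exceed 1 in magnitude, and the iterates would diverge instead.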

nn.Linear#

This layer applies an affine transformation to its input (multiplication by a weight matrix, plus a bias), using randomly initialized parameters.

linear = nn.Linear(in_features=1, out_features = 1)
list(linear.parameters()) #slope and intercept/bias
[Parameter containing:
 tensor([[-0.1917]], requires_grad=True),
 Parameter containing:
 tensor([0.6534], requires_grad=True)]
linear(torch.tensor([5.]))
tensor([-0.3048], grad_fn=<ViewBackward0>)
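As a quick check, we can reproduce this output by hand from the layer's parameters (the exact numbers depend on the random initialization above):

x = torch.tensor([5.])
linear.weight @ x + linear.bias  # slope * 5 + intercept, matching linear(x)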

A Visual Example#

from sklearn.datasets import load_breast_cancer
cancer = load_breast_cancer()
X, Y = cancer.data, cancer.target
l1 = nn.Linear(in_features=30, out_features = 1)
X = torch.tensor(X, dtype = torch.float32)
Y = torch.tensor(Y, dtype = torch.float32)
l1(X)
tensor([[385.6059],
        [425.1644],
        [381.2654],
        ...,
        [260.8066],
        [402.2528],
        [ 60.4572]], grad_fn=<AddmmBackward0>)
output_activation = nn.Sigmoid()
yhat = output_activation(l1(X))
yhat.shape
torch.Size([569, 1])
loss = nn.BCELoss()
loss(yhat, Y.unsqueeze(-1))  # BCELoss expects (predictions, targets), in that order
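As an aside, PyTorch also offers nn.BCEWithLogitsLoss, which folds the sigmoid into the loss for better numerical stability. A minimal sketch of the equivalent computation on the raw (pre-sigmoid) outputs:

logits = l1(X)  # raw linear outputs, no sigmoid applied
stable_loss = nn.BCEWithLogitsLoss()
stable_loss(logits, Y.unsqueeze(-1))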

Training the Model#

Ingredients:

  • Model: We will build this with either nn.Sequential or by subclassing nn.Module.

  • Loss Function: this depends on the task

    • Regression: nn.MSELoss

    • Binary classification: nn.BCELoss

    • Multi-class classification: nn.CrossEntropyLoss

  • Optimizer: the algorithm used to minimize the loss – today we will use optim.SGD

model = nn.Sequential(nn.Linear(30, 100),  # 30 input features to 100 hidden units
                      nn.ReLU(),           # nonlinear activation
                      nn.Linear(100, 1),   # 100 hidden units to 1 output
                      nn.Sigmoid())        # squash the output to a probability
loss_fn = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr = 0.05)
yhat = model(X)                        # forward pass
loss = loss_fn(yhat, Y.unsqueeze(-1))  # compare predictions to targets
optimizer.zero_grad()                  # clear old gradients
loss.backward()                        # backpropagate to compute new gradients
optimizer.step()                       # update the parameters
X.shape
torch.Size([569, 30])
X = (X - X.mean())/X.std()  # standardize; per-feature scaling would use X.mean(dim=0), X.std(dim=0)
losses = []
for epoch in range(1000):
    yhat = model(X)
    loss = loss_fn(yhat, Y.unsqueeze(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
plt.plot(losses);
[figure: training loss per epoch]

Making Predictions#

The model returns probabilities that we need to threshold in order to make class predictions.

yhat = model(X)
preds = torch.where(yhat > .5, 1, 0)
(preds.squeeze(-1) == Y).sum()/len(Y)
tensor(0.9227)
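Note this is accuracy on the very data the model was trained on, so it likely overstates performance on unseen cases. A minimal sketch of a held-out evaluation (the 80/20 split here is an assumption, not part of the lecture):

n = len(Y)
idx = torch.randperm(n)  # shuffle the row indices
train_idx, test_idx = idx[:int(0.8*n)], idx[int(0.8*n):]  # 80/20 split
# ...retrain the model on X[train_idx], Y[train_idx] as above, then:
test_preds = torch.where(model(X[test_idx]) > .5, 1, 0)
(test_preds.squeeze(-1) == Y[test_idx]).sum()/len(test_idx)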

Example: Handwritten Digits#

Below, the handwritten digits data is loaded and transformed to torch.tensor objects. Recall these are 8 x 8 images of handwritten digits, and there are ten categories to predict.

from sklearn.datasets import load_digits
digits = load_digits()
X, y = digits.data, digits.target
X = (X - X.mean())/X.std()
X = torch.tensor(X, dtype = torch.float32)
Y = torch.tensor(y, dtype = torch.long)  # CrossEntropyLoss expects integer class labels
X.shape
torch.Size([1797, 64])
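Since each row of X is a flattened 8 x 8 image, we can peek at one using the images attribute of the sklearn bunch (a quick look, not required for modeling):

plt.imshow(digits.images[0], cmap='gray')
plt.title(f'Label: {digits.target[0]}');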

Problem#

Below, build a model with the following:

  • A Linear layer with 100 neurons

  • A ReLU activation on this layer

  • A Linear layer with another 100 neurons

  • A ReLU activation on this layer

  • A Linear output layer with 10 neurons

Use the optim.SGD to optimize the parameters of the model.

Use nn.CrossEntropyLoss as your loss function for multiclass classification.
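One possible configuration meeting these requirements is sketched below; fill in the placeholders that follow accordingly (the learning rate here is an assumption):

# a possible sketch: 64 pixels -> 100 -> 100 -> 10 class scores (logits);
# nn.CrossEntropyLoss applies log-softmax internally, so no output activation is needed
model = nn.Sequential(nn.Linear(64, 100),
                      nn.ReLU(),
                      nn.Linear(100, 100),
                      nn.ReLU(),
                      nn.Linear(100, 10))
optimizer = optim.SGD(model.parameters(), lr = 0.05)
loss_fn = nn.CrossEntropyLoss()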

model = ''
optimizer = ''
loss_fn = ''
from tqdm import tqdm
def train_function(model, X, Y, epochs, optimizer, loss_fn, losses = None):
    '''
    Trains a pytorch model given a model, data, number of training epochs,
    optimizer, and loss function. The targets Y should already be shaped
    for the loss: Y.unsqueeze(-1) for nn.BCELoss, integer labels of shape
    (N,) for nn.CrossEntropyLoss. Returns a list of losses from training.
    '''
    if losses is None:  # avoid sharing one default list across calls
        losses = []
    for epoch in tqdm(range(epochs)):
        yhat = model(X)
        loss = loss_fn(yhat, Y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses
train_function(model, X, Y, epochs = 1000, optimizer = optimizer, loss_fn = loss_fn)
yhat = model(X)
predictions = torch.argmax(yhat, dim = 1)  # most likely class for each image
(predictions == Y).sum()/len(Y)