
Implementing a Multi-Layer Perceptron (MLP)

MNIST ๋ฐ์ดํ„ฐ์…‹์„ ์ด์šฉํ•˜์—ฌ Multi-Layer Perceptron(MLP)์„ MLX์™€ CPU, GPU(MPS)๋ฅผ ์ด์šฉํ•ด ๊ตฌํ˜„ํ•˜๊ณ  ๋น„๊ตํ•ด๋ณด๋„๋ก ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.

(CPU์™€ MPS๋Š” Torch ์‚ฌ์šฉ)

MLX๋ฅผ ์ด์šฉํ•˜์—ฌ MLP ๊ตฌํ˜„

๊ด€๋ จ ๋ชจ๋“ˆ์„ import ํ•ฉ๋‹ˆ๋‹ค.

import mlx.nn as nn์ด import torch.nn as nn๊ณผ ๋งค์šฐ ์œ ์‚ฌํ•˜๋‹ค๋Š” ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

This suggests that comparing MLX against the GPU (MPS) should take only minor code changes and different imports.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

import numpy as np
from matplotlib import pyplot as plt

from time import time

MLP class ๋ถ€๋ถ„์ด torch.nn์„ ์‚ฌ์šฉํ•  ๋•Œ์™€ ์œ ์‚ฌํ•จ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์•ž์—์„œ ์–ธ๊ธ‰ํ•œ๋Œ€๋กœ ์ฝ”๋“œ ์ˆ˜์ • ์—†์ด import๋งŒ ๋ณ€๊ฒฝํ•˜์—ฌ MLX์™€ GPU(MPS)๋ฅผ ๋น„๊ตํ•  ์ˆ˜ ์žˆ์„ ๊ฒƒ ๊ฐ™์Šต๋‹ˆ๋‹ค.

class MLP(nn.Module):
    def __init__(
        self,
        num_layers: int,
        input_dims: int,
        hidden_dims: int,
        output_dims: int
    ):
        super().__init__()
        layer_sizes = [input_dims] + [hidden_dims] * num_layers + [output_dims]
        self.layers = self._make_layers(layer_sizes)

    def _make_layers(self, layer_sizes):
        layers = []
        for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:]):
            layers += [nn.Linear(idim, odim), nn.ReLU()]
        
        return nn.Sequential(*layers[:-1]) # drop the trailing ReLU so the final layer outputs raw logits
    
    def __call__(self, x):
        return self.layers(x)
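
As a quick sanity check (a minimal sketch with hypothetical shapes, not part of the original post), you can instantiate the model and pass a dummy batch through it:

# Hypothetical sanity check: a 2-hidden-layer MLP on 784-dim inputs
model = MLP(num_layers=2, input_dims=784, hidden_dims=256, output_dims=10)
x = mx.random.normal([32, 784])  # dummy batch of 32 flattened images
print(model(x).shape)            # (32, 10): one logit per class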

We also define a loss function and an evaluation function.

def loss_fn(model, X, y):
    return mx.mean(nn.losses.cross_entropy(model(X), y)) # nn.losses.cross_entropy computes the loss between logits and targets

def eval_fn(model, X, y):
    return mx.mean(mx.argmax(model(X), axis=1) == y)
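
As a small worked example (dummy data, purely illustrative), the accuracy computation inside eval_fn is just an arg-max comparison against the labels:

# Dummy illustration of the accuracy computation in eval_fn
logits = mx.array(np.eye(10)[[7, 7, 2]])  # fake logits, confident in classes 7, 7, 2
y = mx.array([7, 3, 2])                   # true labels
print(mx.mean(mx.argmax(logits, axis=1) == y).item())  # 0.666... (2 of 3 correct)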

Hyperparam์„ ์„ค์ •ํ•˜๊ณ , MNIST ๋ฐ์ดํ„ฐ์…‹์„ ๋‹ค์šด๋ฐ›์•„ ์ „์ฒ˜๋ฆฌ ํ•ด์ฃผ๋„๋ก ํ•ฉ๋‹ˆ๋‹ค.

Since we are using an MLP, each (28 × 28) image is flattened into a 784-dimensional vector.

num_layers = 2
hidden_dim = 256
num_classes = 10
batch_size = 256
num_epochs = 10
learning_rate = 1e-2

# Data Load (the mnist package from PyPI downloads and parses the dataset)
import mnist
train_images, train_labels, test_images, test_labels = map(
    mx.array, [
        mnist.train_images(),
        mnist.train_labels(),
        mnist.test_images(),
        mnist.test_labels(),
    ]
)

# Flatten the train/validation images; hold out the last 10 test images (kept 28x28 for plotting)
train_images = mx.reshape(train_images, [train_images.shape[0],-1])
valid_images, test_images = test_images[:-10], test_images[-10:]
valid_labels, test_labels = test_labels[:-10], test_labels[-10:]
valid_images = mx.reshape(valid_images, [valid_images.shape[0],-1])
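
One optional tweak (not in the original script): the mnist package returns raw pixel values in [0, 255], and scaling them to [0, 1] often makes training with plain SGD more stable:

# Optional (not in the original): scale pixels from [0, 255] to [0, 1]
train_images = train_images / 255.0
valid_images = valid_images / 255.0
test_images = test_images / 255.0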

Batch iterator๋„ ๋งŒ๋“ค์–ด์ค๋‹ˆ๋‹ค.

torch์—์„œ๋Š” dataloader๋ฅผ ์‚ฌ์šฉํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค.

def batch_iterate(batch_size, X, y):
    perm = mx.array(np.random.permutation(y.size))
    for s in range(0, y.size, batch_size):
        ids = perm[s: s + batch_size]
        yield X[ids], y[ids]
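
For reference, the PyTorch counterpart mentioned above would look like this (a sketch that assumes the data are torch tensors, as prepared in the PyTorch sections below):

# Hypothetical PyTorch counterpart of batch_iterate
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(train_images, train_labels),
                    batch_size=batch_size, shuffle=True)
for X, y in loader:
    pass  # one shuffled mini-batch per iteration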

What is a generator?

  • A generator is a function or object that produces an iterator.
  • Using the yield keyword inside a function returns a value while preserving the function's state.

Properties of generators (a minimal example follows the list)

  • It defines an iterable sequence (every generator is an iterator).
  • It is evaluated lazily, producing values only on demand, which makes it memory-efficient.
  • It keeps internal state, remembering where it left off between calls.
  • It can model infinite sequences.
  • It composes naturally into stream-processing pipelines.

Now we train the MLP with SGD.

Because the parameters are initialized randomly, training can occasionally fail to converge (an initialization-dependent issue).
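
If you want reproducible runs (an optional addition, not in the original code), you can fix the random seeds before building the model:

# Optional (not in the original): fix seeds for reproducible runs.
# mx.random.seed seeds MLX's global PRNG (parameter initialization);
# np.random.seed covers the shuffling inside batch_iterate.
mx.random.seed(0)
np.random.seed(0)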

# Model Load
model = MLP(num_layers=num_layers,
            input_dims=train_images.shape[-1],
            hidden_dims=hidden_dim,
            output_dims=num_classes)
mx.eval(model.parameters()) # materialize the lazily initialized parameters

# loss and grad fn
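# nn.value_and_grad wraps loss_fn so that a single call returns both the loss
# value and the gradients with respect to the model's trainable parameters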
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

# optimizer
optimizer = optim.SGD(learning_rate=learning_rate)

accuracy = []
tic = time()
for epoch in range(num_epochs):
    for X, y in batch_iterate(batch_size, train_images, train_labels):
        loss, grads = loss_and_grad_fn(model, X, y)
        optimizer.update(model, grads)

    accuracy += [eval_fn(model, valid_images, valid_labels).item()]

mx.eval(model.parameters(), optimizer.state) # flush any remaining lazy computation before stopping the timer
toc = time()

print(f"Training time: {(toc-tic)/num_epochs:.2f} sec/epoch")

plt.figure(figsize=(4,3))
plt.plot(range(1,num_epochs+1), accuracy)
plt.plot(range(1,num_epochs+1), [1.0]*num_epochs, ls='--')
plt.xlabel("Epoch")
plt.ylabel("accuracy")
plt.show()
Training time: 0.12 sec/epoch

The accuracy curve comes out differently on every run.

After a run where training went well, we evaluated on the test set.

num_images = len(test_images)

images_per_row = 5 # number of images to display per row

num_rows = (num_images + images_per_row - 1) // images_per_row # total number of rows, rounded up

fig, axes = plt.subplots(num_rows, images_per_row, figsize = (images_per_row * 2, num_rows * 2)) # one subplot per image

# ๊ฐ subplot์— ์ด๋ฏธ์ง€์™€ ์˜ˆ์ธก๊ฐ’, ์ •๋‹ต์„ ํ‘œ์‹œ
for i, (test_img, test_lb) in enumerate(zip(test_images, test_labels)):
    row = i // images_per_row
    col = i % images_per_row
    ax = axes[row, col]

    pred = mx.argmax(model(test_img.reshape([1,-1])), axis=1).item()
    ax.imshow(np.array(test_img.reshape(28, 28) * 255), cmap='gray')
    ax.set_title(f"Predict: {pred} \n True: {test_lb.item()}")
    ax.axis("off") # ์ถ•์„ ์ˆจ๊น€

# ๋‚จ์€ ๋นˆ subplot์„ ์ˆจ๊น€
for i in range(num_images, num_rows * images_per_row):
    axes[i // images_per_row, i % images_per_row].axis("off")

plt.tight_layout()
plt.show()

PyTorch๋ฅผ ์ด์šฉํ•œ MLP ๊ตฌํ˜„

๋™์ผํ•œ ์ฝ”๋“œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ Torch๋กœ ๊ตฌํ˜„ํ•ฉ๋‹ˆ๋‹ค.

PyTorch uses the GPU through device = "mps".

MLX์€ Unified Memory(๋ฉ”๋ชจ๋ฆฌ ๊ณต์œ )๋ฅผ ํ†ตํ•ด GPU๋กœ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ์ด๋™์‹œํ‚ค๋Š” ์‹œ๊ฐ„์„ ์ค„์—ฌ์ฃผ๋Š” ์žฅ์ ์„ ๊ฐ€์ง€๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.

So let's see what happens when we train on the GPU alone.

Training with PyTorch + GPU (MPS)

import torch
import mnist

device = torch.device("mps:0") if torch.backends.mps.is_available() else torch.device("cpu")
print(f"Device: {device}")

# Data Load
train_images, train_labels, test_images, test_labels = map(
    torch.Tensor, [
        mnist.train_images(),
        mnist.train_labels(),
        mnist.test_images(),
        mnist.test_labels(),
    ]
)
# Cast labels to long and flatten the images
train_labels, test_labels = train_labels.long(), test_labels.long()
train_images = torch.reshape(train_images, [train_images.shape[0],-1])
valid_images, test_images = test_images[:-10], test_images[-10:]
valid_labels, test_labels = test_labels[:-10], test_labels[-10:]
valid_images = torch.reshape(valid_images, [valid_images.shape[0],-1])

class torchMLP(torch.nn.Module):
    def __init__(
        self, 
        num_layers: int,
        input_dims: int, 
        hidden_dims: int,
        output_dims: int
    ):
        super().__init__()
        layer_sizes = [input_dims] + [hidden_dims] * num_layers + [output_dims]
        self.layers = self._make_layers(layer_sizes)
    
    def _make_layers(self, layer_sizes):
        layers = []
        for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:]):
            layers += [
                            torch.nn.Linear(idim, odim), 
                            torch.nn.ReLU()
                       ]
        
        return torch.nn.Sequential(*layers[:-1])
    
    def __call__(self, x):
        return self.layers(x)

def loss_fn(model, X, y):
    return torch.nn.CrossEntropyLoss()(model(X), y) # CrossEntropyLoss computes the loss between logits and targets

def eval_fn(model, X, y):
    return torch.mean((torch.argmax(model(X), axis=1) == y).float())

def batch_iterate(batch_size, X, y):
    perm = torch.randperm(y.size(0))
    for s in range(0, y.size(0), batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]

# Model Load
model = torchMLP(num_layers=num_layers, 
            input_dims=train_images.shape[-1],
            hidden_dims=hidden_dim,
            output_dims=num_classes)
model.to(device)

# optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

accuracy = [eval_fn(model, valid_images.to(device), valid_labels.to(device)).item()] # pre-training accuracy as the epoch-0 point
tic = time()
for epoch in range(num_epochs):
    for X, y in batch_iterate(batch_size, train_images, train_labels):
        X, y = X.to(device), y.to(device)
        loss = loss_fn(model, X, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
    accuracy += [eval_fn(model, valid_images.to(device), valid_labels.to(device)).item()]

toc = time()
print(f"Training time: {(toc-tic)/num_epochs:.2f} sec/epoch")

plt.figure(figsize=(4,3))
plt.plot(range(num_epochs+1), accuracy)
plt.plot(range(num_epochs+1),[1.0]*(num_epochs+1), ls='--')
plt.xlabel("Epoch")
plt.ylabel("accuracy")
plt.show()
Device: mps:0
Training time: 0.48 sec/epoch

ํ•™์Šต์‹œ๊ฐ„์ด GPU(MPS)๋ฅผ ์‚ฌ์šฉํ–ˆ์„ ๋•Œ๊ฐ€ MLX๋ฅผ ์‚ฌ์šฉํ–ˆ์„ ๋•Œ ๋ณด๋‹ค 1 epoch๋‹น 0.1์ดˆ์—์„œ 0.48์ดˆ๋กœ ์ฆ๊ฐ€ํ•œ ๊ฒƒ์„ ๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Next, let's see how CPU-only training compares.

Training with PyTorch + CPU

import torch
import mnist

# Data Load
train_images, train_labels, test_images, test_labels = map(
    torch.Tensor, [
        mnist.train_images(),
        mnist.train_labels(),
        mnist.test_images(),
        mnist.test_labels(),
    ]
)
# Cast labels to long and flatten the images
train_labels, test_labels = train_labels.long(), test_labels.long()
train_images = torch.reshape(train_images, [train_images.shape[0],-1])
valid_images, test_images = test_images[:-10], test_images[-10:]
valid_labels, test_labels = test_labels[:-10], test_labels[-10:]
valid_images = torch.reshape(valid_images, [valid_images.shape[0],-1])

class torchMLP(torch.nn.Module):
    def __init__(
        self, 
        num_layers: int,
        input_dims: int, 
        hidden_dims: int,
        output_dims: int
    ):
        super().__init__()
        layer_sizes = [input_dims] + [hidden_dims] * num_layers + [output_dims]
        self.layers = self._make_layers(layer_sizes)
    
    def _make_layers(self, layer_sizes):
        layers = []
        for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:]):
            layers += [
                            torch.nn.Linear(idim, odim), 
                            torch.nn.ReLU()
                       ]
        
        return torch.nn.Sequential(*layers[:-1])
    
    def __call__(self, x):
        return self.layers(x)

def loss_fn(model, X, y):
    return torch.nn.CrossEntropyLoss()(model(X), y) # CrossEntropyLoss computes the loss between logits and targets

def eval_fn(model, X, y):
    return torch.mean((torch.argmax(model(X), axis=1) == y).float())

def batch_iterate(batch_size, X, y):
    perm = torch.randperm(y.size(0))
    for s in range(0, y.size(0), batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]

# Model Load
model = torchMLP(num_layers=num_layers, 
            input_dims=train_images.shape[-1],
            hidden_dims=hidden_dim,
            output_dims=num_classes)

# optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

accuracy = [eval_fn(model, valid_images, valid_labels).item()] # pre-training accuracy as the epoch-0 point
tic = time()
for epoch in range(num_epochs):
    for X, y in batch_iterate(batch_size, train_images, train_labels):
        loss = loss_fn(model, X, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
    accuracy += [eval_fn(model, valid_images, valid_labels).item()]

toc = time()
print(f"Training time: {(toc-tic)/num_epochs:.2f} sec/epoch")

plt.figure(figsize=(4,3))
plt.plot(range(num_epochs+1), accuracy)
plt.plot(range(num_epochs+1),[1.0]*(num_epochs+1), ls='--')
plt.xlabel("Epoch")
plt.ylabel("accuracy")
plt.show()
Training time: 0.27 sec/epoch

Training on the CPU took 0.27 s per epoch, down from 0.48 s with the GPU (MPS).

์ด๊ฒƒ์„ ํ†ตํ•ด MLP์˜ ๊ฒฝ์šฐ GPU(MPS) ํ™œ์šฉ๋„๊ฐ€ ๋–จ์–ด์ง„๋‹ค๋Š” ๊ฒƒ์„ ๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

GPU(MPS)ํ™œ์šฉ๋„๊ฐ€ ๋–จ์–ด์ง€๋Š” ๊ฒƒ์€ Unified Memory๊ฐ€ ์•„๋‹ˆ๊ธฐ ๋•Œ๋ฌธ์— ๋ฉ”๋ชจ๋ฆฌ๋ฅผ device๋กœ ์˜ฎ๊ธฐ๋Š” ๊ณผ์ •์—์„œ ์‹œ๊ฐ„ ์†ํ•ด๊ฐ€ ์ผ์–ด๋‚ฌ๊ฑฐ๋‚˜, ์ตœ์ ํ™” ๋ฌธ์ œ ๋•Œ๋ฌธ์— ๋ฐœ์ƒ๋˜๋Š” ๊ฒƒ์œผ๋กœ ๋ณด์—ฌ์ง‘๋‹ˆ๋‹ค.

MLX๋ฅผ ์‚ฌ์šฉํ•˜์˜€์„ ๋•Œ ๊ฐ„๋‹จํ•œ linear regression์—์„œ๋Š” ํฐ ์ฐจ์ด๊ฐ€ ์—†์—ˆ์ง€๋งŒ, multi-layer perceptron์ฒ˜๋Ÿผ ํ–‰๋ ฌ ์—ฐ์‚ฐ์ด ๋ฌด๊ฑฐ์›Œ์ง€๋Š” ๊ฒฝ์šฐ ์ฐจ์ด๊ฐ€ ๋ฐœ์ƒํ•˜๋Š” ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์—ˆ์Šต๋‹ˆ๋‹ค.
