MLX: Apple silicon Machine Learning - 03.Multi-Layer Perceptron(MLP)
Implementing a Multi-Layer Perceptron (MLP)
Using the MNIST dataset, we will implement a Multi-Layer Perceptron (MLP) with MLX and with the CPU and GPU (MPS), and compare the results.
(The CPU and MPS versions use Torch.)
Implementing an MLP with MLX
First, import the relevant modules.
You can see that import mlx.nn as nn closely mirrors import torch.nn as nn.
Comparing MLX against the GPU (MPS) should therefore only require a few small code changes and different imports.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from matplotlib import pyplot as plt
from time import time
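For comparison, the rough PyTorch counterparts of these imports would look as follows (the PyTorch code later in this post simply does import torch and reaches torch.nn and torch.optim through it):

import torch                  # ~ import mlx.core as mx
import torch.nn as nn         # ~ import mlx.nn as nn
import torch.optim as optim   # ~ import mlx.optimizers as optim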
The MLP class is very similar to what you would write with torch.nn.
As mentioned above, it looks like MLX and the GPU (MPS) can be compared by changing only the imports, without modifying the code.
class MLP(nn.Module):
    def __init__(
        self,
        num_layers: int,
        input_dims: int,
        hidden_dims: int,
        output_dims: int
    ):
        super().__init__()
        layer_sizes = [input_dims] + [hidden_dims] * num_layers + [output_dims]
        self.layers = self._make_layers(layer_sizes)

    def _make_layers(self, layer_sizes):
        layers = []
        for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:]):
            layers += [nn.Linear(idim, odim), nn.ReLU()]
        # Drop the trailing ReLU so the last layer outputs raw logits
        return nn.Sequential(*layers[:-1])

    def __call__(self, x):
        return self.layers(x)
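As a quick sanity check (a hypothetical snippet, not part of the original post), you could instantiate the class and push a random batch through it; the dimensions assume 784 flattened pixels and 10 classes:

# Minimal sketch: forward a random batch through the untrained model
toy_model = MLP(num_layers=2, input_dims=784, hidden_dims=256, output_dims=10)
dummy = mx.random.normal((32, 784))   # a fake batch of 32 flattened images
logits = toy_model(dummy)
print(logits.shape)                   # expected: (32, 10)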
We also define a loss function and an evaluation function.
def loss_fn(model, X, y):
    return mx.mean(nn.losses.cross_entropy(model(X), y))  # nn.losses.cross_entropy computes the loss between logits and targets
def eval_fn(model, X, y):
    return mx.mean(mx.argmax(model(X), axis=1) == y)
Next, set the hyperparameters, download the MNIST dataset, and preprocess it.
Since we are using an MLP, each (28 x 28) image is flattened into 784 dimensions.
num_layers = 2
hidden_dim = 256
num_classes = 10
batch_size = 256
num_epochs = 10
learning_rate = 1e-2
# Data Load
import mnist
train_images, train_labels, test_images, test_labels = map(
    mx.array, [
        mnist.train_images(),
        mnist.train_labels(),
        mnist.test_images(),
        mnist.test_labels(),
    ]
)
# Flatten Images
train_images = mx.reshape(train_images, [train_images.shape[0],-1])
valid_images, test_images = test_images[:-10], test_images[-10:]
valid_labels, test_labels = test_labels[:-10], test_labels[-10:]
valid_images = mx.reshape(valid_images, [valid_images.shape[0],-1])
We also build a batch iterator.
In PyTorch you could simply use a DataLoader instead; a sketch is shown after the iterator below.
def batch_iterate(batch_size, X, y):
    perm = mx.array(np.random.permutation(y.size))
    for s in range(0, y.size, batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]
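For reference, a rough PyTorch equivalent using a DataLoader might look like the sketch below (this assumes torch tensors such as the ones created in the PyTorch section later; the original code keeps the custom iterator in both versions):

from torch.utils.data import DataLoader, TensorDataset

# Wrap the tensors in a Dataset and let DataLoader handle shuffling and batching
train_ds = TensorDataset(train_images, train_labels)   # assumes torch.Tensor inputs
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
for X, y in train_loader:
    ...  # same training step as in the loop below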
What is a generator?
- A function or object that produces an iterator.
- Inside the function, the yield keyword returns a value while preserving the function's state.
Characteristics of generators (a small example follows this list)
- Elements are produced in a defined order (every generator is an iterator).
- They are evaluated lazily, producing values only when needed, which makes memory usage efficient.
- They keep internal state, remembering information between calls.
- They can model objects with an infinite number of elements.
- They can be composed into pipelines for natural stream processing.
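As a small illustration (a hypothetical example, not from the original post), the generator below lazily yields squares and shows the yield/state behavior described above:

def squares(n):
    # The loop variable i is preserved between yields
    for i in range(n):
        yield i * i

gen = squares(3)
print(next(gen))   # 0
print(next(gen))   # 1
print(list(gen))   # [4] -- only the remaining values are produced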
We train the MLP with SGD.
Because the parameters are initialized randomly, training can occasionally fail to converge
(an issue caused by the initialization).
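If you want runs to be reproducible, one option (not in the original post) is to fix the random seeds before building the model, so that both the parameter initialization and the batch shuffling are repeatable:

# Optional: fix seeds for reproducible runs
mx.random.seed(0)   # controls MLX parameter initialization
np.random.seed(0)   # controls the permutation used in batch_iterate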
# Model Load
model = MLP(num_layers=num_layers,
            input_dims=train_images.shape[-1],
            hidden_dims=hidden_dim,
            output_dims=num_classes)
mx.eval(model.parameters())  # force the lazily-initialized parameters to materialize
# loss and grad fn
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
# optimizer
optimizer = optim.SGD(learning_rate=learning_rate)
accuracy = []
tic = time()
for epoch in range(num_epochs):
    for X, y in batch_iterate(batch_size, train_images, train_labels):
        loss, grads = loss_and_grad_fn(model, X, y)
        optimizer.update(model, grads)
    accuracy += [eval_fn(model, valid_images, valid_labels).item()]
    mx.eval(model.parameters(), optimizer.state)  # evaluate the lazy graph after each epoch
toc = time()
print(f"Training time: {(toc-tic)/num_epochs:.2f} sec/epoch")
plt.figure(figsize=(4,3))
plt.plot(range(1,num_epochs+1), accuracy)
plt.plot(range(1,num_epochs+1), [1.0]*num_epochs, ls='--')
plt.xlabel("Epoch")
plt.ylabel("accuracy")
plt.show()
Training time: 0.12 sec/epoch
You can see that the curve comes out differently on each run.
We picked a run where training went well and then evaluated on the test set.
num_images = len(test_images)
images_per_row = 5  # number of images shown per row
num_rows = (num_images + images_per_row - 1) // images_per_row  # total number of rows
fig, axes = plt.subplots(num_rows, images_per_row, figsize=(images_per_row * 2, num_rows * 2))  # subplots for every row and column
# Show the image, the prediction, and the ground truth in each subplot
for i, (test_img, test_lb) in enumerate(zip(test_images, test_labels)):
    row = i // images_per_row
    col = i % images_per_row
    ax = axes[row, col]
    pred = mx.argmax(model(test_img.reshape([1, -1])), axis=1).item()
    ax.imshow(np.array(test_img.reshape(28, 28) * 255), cmap='gray')
    ax.set_title(f"Predict: {pred} \n True: {test_lb.item()}")
    ax.axis("off")  # hide the axes
# Hide any remaining empty subplots
for i in range(num_images, num_rows * images_per_row):
    axes[i // images_per_row, i % images_per_row].axis("off")
plt.tight_layout()
plt.show()
Implementing an MLP with PyTorch
We now implement the same model with Torch, using the same code.
PyTorch uses the GPU via device = "mps".
MLX, in contrast, has the advantage that unified memory (memory shared between CPU and GPU) removes the time spent moving data to the GPU.
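To make the contrast concrete, here is a minimal sketch (not from the original post): with PyTorch on MPS every tensor has to be copied to the device explicitly, whereas MLX arrays live in unified memory and you only choose where the computation runs (mx.set_default_device is part of the MLX API):

# PyTorch (MPS): model and every batch must be copied to the device
# model.to(device); X, y = X.to(device), y.to(device)

# MLX: arrays stay in unified memory; just pick where ops execute
mx.set_default_device(mx.gpu)   # the default on Apple silicon; mx.cpu also works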
So let's see what happens when only the GPU is used.
PyTorch + GPU (MPS) training
import torch
import mnist
device = torch.device("mps:0") if torch.backends.mps.is_available() else 'cpu'
print(f"Device: {device}")
# Data Load
train_images, train_labels, test_images, test_labels = map(
    torch.Tensor, [
        mnist.train_images(),
        mnist.train_labels(),
        mnist.test_images(),
        mnist.test_labels(),
    ]
)
# Flatten Images
train_labels, test_labels = train_labels.long(), test_labels.long()
train_images = torch.reshape(train_images, [train_images.shape[0],-1])
valid_images, test_images = test_images[:-10], test_images[-10:]
valid_labels, test_labels = test_labels[:-10], test_labels[-10:]
valid_images = torch.reshape(valid_images, [valid_images.shape[0],-1])
class torchMLP(torch.nn.Module):
    def __init__(
        self,
        num_layers: int,
        input_dims: int,
        hidden_dims: int,
        output_dims: int
    ):
        super().__init__()
        layer_sizes = [input_dims] + [hidden_dims] * num_layers + [output_dims]
        self.layers = self._make_layers(layer_sizes)

    def _make_layers(self, layer_sizes):
        layers = []
        for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:]):
            layers += [
                torch.nn.Linear(idim, odim),
                torch.nn.ReLU()
            ]
        # Drop the trailing ReLU so the last layer outputs raw logits
        return torch.nn.Sequential(*layers[:-1])

    def __call__(self, x):
        return self.layers(x)
def loss_fn(model, X, y):
    return torch.nn.CrossEntropyLoss()(model(X), y)  # CrossEntropyLoss computes the loss between logits and targets
def eval_fn(model, X, y):
    return torch.mean((torch.argmax(model(X), axis=1) == y).float())
def batch_iterate(batch_size, X, y):
    perm = torch.randperm(y.size(0))
    for s in range(0, y.size(0), batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]
# Model Load
model = torchMLP(num_layers=num_layers,
                 input_dims=train_images.shape[-1],
                 hidden_dims=hidden_dim,
                 output_dims=num_classes)
model.to(device)
# optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
accuracy = [eval_fn(model, valid_images.to(device), valid_labels.to(device)).item()]
tic = time()
for epoch in range(num_epochs):
    for X, y in batch_iterate(batch_size, train_images, train_labels):
        X, y = X.to(device), y.to(device)  # copy each batch to the MPS device
        loss = loss_fn(model, X, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    accuracy += [eval_fn(model, valid_images.to(device), valid_labels.to(device)).item()]
toc = time()
print(f"Training time: {(toc-tic)/num_epochs:.2f} sec/epoch")
plt.figure(figsize=(4,3))
plt.plot(range(num_epochs+1), accuracy)
plt.plot(range(num_epochs+1),[1.0]*(num_epochs+1), ls='--')
plt.xlabel("Epoch")
plt.ylabel("accuracy")
plt.show()
Device: mps:0 Training time: 0.48 sec/epoch
You can see that the training time per epoch increases from 0.12 seconds with MLX to 0.48 seconds when using the GPU (MPS) through PyTorch.
Next, let's see what happens when we train only on the CPU.
PyTorch + CPU training
import torch
import mnist
# Data Load
train_images, train_labels, test_images, test_labels = map(
    torch.Tensor, [
        mnist.train_images(),
        mnist.train_labels(),
        mnist.test_images(),
        mnist.test_labels(),
    ]
)
# Flatten Images
train_labels, test_labels = train_labels.long(), test_labels.long()
train_images = torch.reshape(train_images, [train_images.shape[0],-1])
valid_images, test_images = test_images[:-10], test_images[-10:]
valid_labels, test_labels = test_labels[:-10], test_labels[-10:]
valid_images = torch.reshape(valid_images, [valid_images.shape[0],-1])
class torchMLP(torch.nn.Module):
    def __init__(
        self,
        num_layers: int,
        input_dims: int,
        hidden_dims: int,
        output_dims: int
    ):
        super().__init__()
        layer_sizes = [input_dims] + [hidden_dims] * num_layers + [output_dims]
        self.layers = self._make_layers(layer_sizes)

    def _make_layers(self, layer_sizes):
        layers = []
        for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:]):
            layers += [
                torch.nn.Linear(idim, odim),
                torch.nn.ReLU()
            ]
        return torch.nn.Sequential(*layers[:-1])

    def __call__(self, x):
        return self.layers(x)
def loss_fn(model, X, y):
    return torch.nn.CrossEntropyLoss()(model(X), y)  # CrossEntropyLoss computes the loss between logits and targets
def eval_fn(model, X, y):
    return torch.mean((torch.argmax(model(X), axis=1) == y).float())
def batch_iterate(batch_size, X, y):
    perm = torch.randperm(y.size(0))
    for s in range(0, y.size(0), batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]
# Model Load
model = torchMLP(num_layers=num_layers,
                 input_dims=train_images.shape[-1],
                 hidden_dims=hidden_dim,
                 output_dims=num_classes)
# optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
accuracy = [eval_fn(model, valid_images, valid_labels).item()]
tic = time()
for epoch in range(num_epochs):
    for X, y in batch_iterate(batch_size, train_images, train_labels):
        loss = loss_fn(model, X, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    accuracy += [eval_fn(model, valid_images, valid_labels).item()]
toc = time()
print(f"Training time: {(toc-tic)/num_epochs:.2f} sec/epoch")
plt.figure(figsize=(4,3))
plt.plot(range(num_epochs+1), accuracy)
plt.plot(range(num_epochs+1),[1.0]*(num_epochs+1), ls='--')
plt.xlabel("Epoch")
plt.ylabel("accuracy")
plt.show()
Training time: 0.27 sec/epoch
When training on the CPU only, the time drops to 0.27 seconds per epoch, faster than with the GPU (MPS).
This shows that for an MLP like this one, the GPU (MPS) is not being used effectively.
The reduced GPU (MPS) benefit is likely due to the overhead of copying data to the device (since there is no unified memory in this setup), or to other optimization issues.
With MLX, the simple linear regression example showed no large difference, but once the matrix operations get heavier, as in this multi-layer perceptron, a clear difference emerges.