
MLX: Linear Regression

Let's walk through a simple linear regression example using MLX.

We'll synthesize data from an arbitrary linear function, then use that data to recover an approximation of the function's parameters.

First, import the required modules and set the hyperparameters.

import mlx.core as mx
import time

num_features = 100
num_examples = 1_000
test_examples = 100
num_iters = 10_000 # iterations of SGD
lr = 0.01 # learning rate for SGD

The definition of "batch" in machine learning

  • The set of examples used in one iteration (one repetition) of model training.
  • An iteration is one repetition of the training step (forward and backward pass) using the configured batch size.
  • A single epoch requires multiple iterations (see the short sketch below).
  • To prevent overfitting, choose the epoch just before the point where validation error starts to increase while training error is still decreasing.
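
To make the epoch/iteration relationship concrete, here is a quick sketch (the numbers are illustrative assumptions, not the hyperparameters used below):

# With 1,000 training examples and a batch size of 100,
# one epoch consists of 10 iterations
num_train = 1_000
batch_size = 100
iters_per_epoch = num_train // batch_size
print(iters_per_epoch)  # 10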

The definition of "batch size"

  • The number of examples contained in a single batch.
  • Training can be divided into Batch / Mini-Batch / Stochastic styles (see the figure below).

    (figure: Batch vs. Mini-Batch vs. Stochastic gradient descent)

  • SGD (Stochastic Gradient Descent) uses a batch size of 1; Mini-Batch typically uses something between 10 and 1,000, usually a power of 2 (32, 64, 128, …).

Characteristics and trade-offs of each approach

Batch

  • Every sample contributes to each update, so training converges smoothly in a single agreed-upon direction.
  • Each update requires computing over every sample, so it takes a long time.
  • Uses the entire training data set for every update.

SGD (Stochastic Gradient Descent)

  • Processes one example at a time and repeats this over the whole data set.
  • Converges quickly but with a larger error (it may fail to find the global minimum).
  • Inefficient because it cannot fully utilize the GPU (samples are processed one at a time).

Mini-Batch

  • Splits the full training data into batch-size chunks and processes each batch sequentially.
  • Faster than full Batch, with a lower error than SGD.
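
To make the three variants concrete, here is a minimal mini-batch training sketch in MLX; the data, batch_size, and learning rate below are illustrative assumptions, not part of the example that follows:

import mlx.core as mx

# Illustrative data: 1,000 examples with 10 features (assumed)
X_all = mx.random.normal((1_000, 10))
y_all = mx.random.normal((1_000,))
w = mx.zeros((10,))
batch_size = 100  # batch_size = 1 would be SGD; batch_size = 1_000 would be full Batch

def mb_loss_fn(w, X, y):
    return 0.5 * mx.mean(mx.square(X @ w - y))

# Gradient with respect to the first argument, w
mb_grad_fn = mx.grad(mb_loss_fn)

# One epoch = one sequential pass over all batches
for i in range(0, X_all.shape[0], batch_size):
    X_b = X_all[i : i + batch_size]
    y_b = y_all[i : i + batch_size]
    w = w - 0.01 * mb_grad_fn(w, X_b, y_b)
mx.eval(w)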

์ž„์˜์˜ ์„ ํ˜• ํ•จ์ˆ˜๋ฅผ ๋งŒ๋“ค๊ณ , ์ž„์˜์˜ input ๋ฐ์ดํ„ฐ๋ฅผ ๋งŒ๋“ค์–ด์ค๋‹ˆ๋‹ค.

mx.random.normal์„ ์ด์šฉํ•˜์—ฌ ๋žœ๋คํ•˜๊ฒŒ ๋งŒ๋“ค์–ด์ค๋‹ˆ๋‹ค.

label ๊ฐ’์˜ ๊ฒฝ์šฐ ๋งŒ๋“ค์–ด์ง„ input๋ฐ์ดํ„ฐ๋ฅผ ํ•จ์ˆ˜์— ํ†ต๊ณผ์‹œํ‚ค๊ณ , ์ž‘์€ noise๋ฅผ ๋ถ€์—ฌํ•˜์—ฌ ๋งŒ๋“ค์–ด์ค๋‹ˆ๋‹ค.

# ์ž„์˜์˜ ์„ ํ˜• ํ•จ์ˆ˜ True parameters
w_start = mx.random.normal((num_features,))

# Input examples(design matrix)
X = mx.random.normal((num_examples, num_features))

# Noisy labels
eps = 1e-2 * mx.random.normal((num_examples,))
mx.random.normal((num_examples,))
y = X @ w_start + eps
  • @๋Š” NumPy๋‚˜ MXNet๊ณผ ๊ฐ™์€ ๋ฐฐ์—ด ๊ณ„์‚ฐ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์—์„œ ํ–‰๋ ฌ ๊ณฑ์…ˆ(๋˜๋Š” ํ–‰๋ ฌ-๋ฒกํ„ฐ ๊ณฑ์…ˆ)์„ ๋‚˜ํƒ€๋‚ด๋Š” ์—ฐ์‚ฐ์ž์ž…๋‹ˆ๋‹ค.
  • X @ w_start๋Š” ํ–‰๋ ฌ X์™€ ๋ฒกํ„ฐ w_start์‚ฌ์ด์˜ ํ–‰๋ ฌ ๊ณฑ์…‰์„ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค.
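
For reference, @ is equivalent to calling the library's matmul function explicitly; a minimal check using the arrays defined above:

out_op = X @ w_start                # shape: (num_examples,)
out_fn = mx.matmul(X, w_start)      # same product via the explicit function
print(mx.allclose(out_op, out_fn))  # True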

๋ณ„๋„์˜ ํ…Œ์ŠคํŠธ ์…‹๋„ ๋งŒ๋“ค์–ด์ค๋‹ˆ๋‹ค.

# Test set generation
X_test = mx.random.normal((test_examples, num_features))
y_test = X_test @ w_start

Next, we define the loss function and the gradient function (using mx.grad).

# MSE Loss function
def loss_fn(w):
    return 0.5 * mx.mean(mx.square(X @ w - y))

# mx.grad returns a new function that computes the gradient
# of loss_fn with respect to its first argument, w
grad_fn = mx.grad(loss_fn)
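
As a quick sanity check (a sketch added here, not part of the original walkthrough), the gradient from mx.grad should match the analytic gradient of the MSE loss, X.T @ (X @ w - y) / num_examples:

# Analytic gradient of 0.5 * mean((X @ w - y)^2) with respect to w
def analytic_grad(w):
    return X.T @ (X @ w - y) / num_examples

w_check = mx.random.normal((num_features,))
print(mx.allclose(grad_fn(w_check), analytic_grad(w_check), atol=1e-5))  # True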

Now we initialize the parameters for linear regression and train with SGD (Stochastic Gradient Descent). Strictly speaking, since loss_fn computes the loss over the entire data set, each update below is a full-batch gradient step; we keep the SGD name for consistency with the hyperparameter comments above.

# Initialize random parameter
w = 1e-2 * mx.random.normal((num_features,))

# Test Error(MSE)
pred_test = X_test @ w
test_error = mx.mean(mx.square(y_test - pred_test))

print(f"Initial Test Error(MSE): {test_error.item():.6f}")
Initial Test Error(MSE): 114.406784
# Training by SGD
start = time.time()
for _ in range(num_iters):
    grad = grad_fn(w)
    w = w - lr * grad
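# MLX evaluates lazily; mx.eval forces the pending
# computation graph to run before we stop the timer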
mx.eval(w)
end = time.time()

print(f"Training elapsed time: {end-start} seconds")
print(f"Throughput: {num_iters/(end-start):.3f} iter/s")
Training elapsed time: 0.8959159851074219 seconds
Throughput: 11161.761 iter/s
# Test Error(MSE)
pred_test = X_test @ w
test_error = mx.mean(mx.square(y_test - pred_test))

print(f"Final Test Error(MSE): {test_error.item():.6f}")
Final Test Error(MSE): 0.000011

The MSE on the test set dropped dramatically (test set MSE: 0.000011).
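
Since the labels were synthesized from w_start, we can also check parameter recovery directly; a small sketch, not in the original post:

# Distance between the learned and true parameters
param_error = mx.mean(mx.square(w - w_start))
print(f"Parameter MSE: {param_error.item():.6f}")  # should be near 0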

์ถ”๊ฐ€์ ์œผ๋กœ ์ˆ˜ํ–‰์‹œ๊ฐ„์€ ์•ฝ 0.8์ดˆ ๊ฑธ๋ฆฐ ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.(M3 Macbook pro ๊ธฐ์ค€)

Comparing running time against CPU computation

Let's look at the speed difference when we use NumPy arrays (CPU) instead of MLX (GPU).

import numpy as np

# True parameters
w_star = np.random.normal(size=(num_features, 1))

# Input examples (design matrix)
X = np.random.normal(size=(num_examples, num_features))

# Noisy labels
eps = 1e-2 * np.random.normal(size=(num_examples, 1))
y = np.matmul(X, w_star) + eps

# Test set generation
X_test = np.random.normal(size=(test_examples, num_features))
y_test = np.matmul(X_test, w_star)

# MSE loss function (mirrors the MLX version)
def loss_fn(w):
    return 0.5 * np.mean(np.square(np.matmul(X, w) - y))

# Gradient function, derived by hand since NumPy has no autograd
def grad_fn(w):
    return np.matmul(X.T, np.matmul(X, w) - y) * (1 / num_examples)

# Initialize random parameters
w = 1e-2 * np.random.normal(size=(num_features, 1))
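
Because this gradient is hand-derived rather than produced by autograd, a finite-difference check is a cheap way to catch mistakes; the directional_fd helper below is an illustrative assumption, not from the post:

# Compare the analytic gradient against a central finite difference
# along one random direction d: grad(w) . d ~ (f(w + h*d) - f(w - h*d)) / (2h)
def directional_fd(f, w, d, h=1e-6):
    return (f(w + h * d) - f(w - h * d)) / (2 * h)

d = np.random.normal(size=w.shape)
d /= np.linalg.norm(d)
print(np.allclose(float(np.sum(grad_fn(w) * d)),
                  directional_fd(loss_fn, w, d), atol=1e-5))  # True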

pred_test = np.matmul(X_test, w)
test_error = np.mean(np.square(y_test - pred_test))

print(f"Initial Test Error(MSE): {test_error.item():.6f}")
Initial Test Error(MSE): 93.005015
start = time.time()
for _ in range(num_iters):
    grad = grad_fn(w)
    w = w - lr * grad

end = time.time()

print(f"Training elapsed time: {end-start} seconds")
print(f"Throughput: {num_iters/(end-start):.3f} iter/s")
Training elapsed time: 0.8565518856048584 seconds
Throughput: 11674.716 iter/s
pred_test = np.matmul(X_test, w)
test_error = np.mean(np.square(y_test - pred_test))

print(f"Final Test Error(MSE): {test_error.item():.6f}")
Final Test Error(MSE): 0.000010

For this simple linear regression there is no significant difference: each iteration is only a small matrix-vector product, which both backends handle quickly.

In the next post, we'll check whether a gap appears once the matrix operations get heavier, as in a multi-layer perceptron.
