MLX: Apple silicon Machine Learning - 01.Quick Start Guide
MLX๋?
MLX๋ Apple silicon์์ ๋จธ์ ๋ฌ๋์ ์ํํ๊ธฐ ์ํด ๋ง๋ค์ด์ง array framework์ ๋๋ค.
Apple silicon๋ง์ CPU์ GPU๋ฅผ ์ฌ์ฉํ์ฌ ๋ฒกํฐ์ ๊ทธ๋ํ ์ฐ์ฐ ์๋๋ฅผ ํฌ๊ฒ ๋์ผ ์ ์์ต๋๋ค.
์ค์น ์๊ตฌ ์ฌํญ
-
M ์๋ฆฌ์ฆ apple silicon
-
native Python >= 3.8
-
MacOS >= 13.3
!python -c "import platform; print(platform.processor())" # It must be arm
!pip install mlx
arm Collecting mlx Obtaining dependency information for mlx from https://files.pythonhosted.org/packages/8f/e7/40e631abca0823399ad5f89e2fd849393d7e6a8f3efd2cf1a3ef4ceb0df0/mlx-0.0.11-cp311-cp311-macosx_14_0_arm64.whl.metadata Downloading mlx-0.0.11-cp311-cp311-macosx_14_0_arm64.whl.metadata (4.9 kB) Downloading mlx-0.0.11-cp311-cp311-macosx_14_0_arm64.whl (17.1 MB) [2K [38;2;114;156;31mโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ[0m [32m17.1/17.1 MB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m[36m0:00:01[0mm eta [36m0:00:01[0m [?25hInstalling collected packages: mlx Successfully installed mlx-0.0.11
Basic Quick Start
array(๋ฐฐ์ด)๋ฅผ ๋ง๋ค๊ธฐ ์ํด mlx.core๋ฅผ import ํฉ๋๋ค.
import mlx.core as mx
a = mx.array([1,2,3,4])
print(f'a shape: {a.shape}')
print(f'a dtype: {a.dtype}')
b = mx.array([1.0, 2.0, 3.0, 4.0])
print(f'b shape: {b.shape}')
print(f'b dtype: {b.dtype}')
c = mx.array([[1,2,3,4],[5,6,7,8]])
print(f'c shape: {c.shape}')
print(f'c dtype: {c.dtype}')
d = mx.array([[1,2,3,4],[5.0,6.0,7.0,8.0]])
print(f'd shape: {d.shape}')
print(f'd dtype: {d.dtype}')
a shape: [4] a dtype: mlx.core.int32 b shape: [4] b dtype: mlx.core.float32 c shape: [2, 4] c dtype: mlx.core.int32 d shape: [2, 4] d dtype: mlx.core.float32
MLX๋ lazy evaluation์ ์ฌ์ฉํฉ๋๋ค.
- lazy evaluation์ด๋?
- ์ค์ ๋ก ์ฐ์ฐ ๊ฒฐ๊ณผ๊ฐ ์ด๋๊ฐ์ ์ฌ์ฉ๋๊ธฐ ์ ๊น์ง ์ฐ์ฐ์ ๋ฏธ๋ฃจ๋ ํ๋ก๊ทธ๋๋ฐ ๋ฐฉ๋ฒ๋ก ์ ๋๋ค.
lazy evaluation์ ์ฑ๋ฅ ๊ด์ ์์ ์ด๋์ ๋ณด๊ฑฐ๋ ์ค๋ฅ๋ฅผ ํํผ ํน์ ๋ฌดํ ์๋ฃ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ ์ ์๋ค๋ ์ฅ์ ์ด ์์ต๋๋ค.
์ฐ์ฐ์ด ์ผ์ด๋์ง ์์ ๊ฒ๊ณผ ์ฐ์ฐ์ด ์ผ์ด๋๋ ๊ฒ์ ๋ํ ์๊ฐ ์๋ชจ๋ฅผ ํ์ธํด ๋ณด๊ฒ ์ต๋๋ค.
import time
start = time.time()
for _ in range(100):
c = a + b # ์ค์ ๋ก ์ฐ์ฐ์ด ์ผ์ด๋์ง ์๋๋ค, c๊ฐ ์ฌ์ฉ๋์ง ์์๊ธฐ ๋๋ฌธ์ด๋ค.
print(f'lazy evaluation time: {time.time()-start}')
start = time.time()
for _ in range(100):
c = a + b
mx.eval(c) # ์ฐ์ฐ์ด ์ํ๋๋ค. (mx.eval ํจ์๋ ๊ฐ์ ๋ก ์ฐ์ฐ์ ์ํ์ํจ๋ค.)
print(f'forced evaluation time: {time.time()-start}')
lazy evaluation time: 0.0003581047058105469 forced evaluation time: 0.02145099639892578
Unified Memory
Apple Silicon์ CPU์ GPU๊ฐ ๋ณ๊ฐ์ ์ฅ์น๋ก ์กด์ฌํ์ง ์์ต๋๋ค.
ํ๋์ Unifired Memory Architecture(UMA)๋ก ๊ตฌ์ฑ๋์ด ์์ผ๋ฉฐ, CPU์ GPU๊ฐ ๋์ผํ memory pool์์ ์ง์ ์ ์ผ๋ก ์ ๊ทผ ๊ฐ๋ฅํฉ๋๋ค.
MLX๋ ์ด๋ฌํ ์ฅ์ ์ ๋๋ฆด ์ ์๋๋ก ๋์์ธ ๋์์ต๋๋ค.
(ํ์ฌ torch์์ MPS๋ฅผ ์ฌ์ฉํ์ฌ GPU๋ฅผ ์ฌ์ฉํ ์ ์์ง๋ง ์ฑ๋ฅ์ ์ ๋๋ก ์ฌ์ฉํ์ง ๋ชปํ๊ณ ์์ต๋๋ค. MLX๋ M3 Macbook pro ์ฑ๋ฅ์ ์ ๋๋ก ์ด๋ ๊ฒ ๊ฐ์์ ๊ธฐ๋๊ฐ ๋ฉ๋๋ค.)
# ๋๊ฐ์ array ์์ฑ
a = mx.random.normal((100,))
b = mx.random.normal((100,))
MLX์์๋ operation์ ์ํ device๋ฅผ ์ง์ ํด์ค ์ ์์ต๋๋ค.
์ฆ, memory ์์น์ ์ด๋ ์์ด CPU ์ฐ์ฐ๊ณผ GPU ์ฐ์ฐ์ ๋ชจ๋ ํ ์ ์์ต๋๋ค.
# dependency ์กด์ฌ X
mx.add(a, b, stream = mx.cpu)
mx.add(a, b, stream = mx.gpu)
# dependency ์กด์ฌ O
c = mx.add(a, b, stream = mx.cpu)
d = mx.add(a, b, stream = mx.gpu)
dependency๊ฐ ์กด์ฌํ์ง ์์ ๊ฒฝ์ฐ ๋ณ๋ ฌ์ ์ผ๋ก ๊ฐ๊ฐ ์ฐ์ฐ์ด ๋ฉ๋๋ค.
ํ์ง๋ง dependency๊ฐ ์กด์ฌํ ๊ฒฝ์ฐ ์ฒซ ๋ฒ์งธ ์ฐ์ฐ์ด ๋๋ ํ ๋ ๋ฒ์งธ ์ฐ์ฐ์ด ์์๋ฉ๋๋ค.
(โcโ ์ฐ์ฐ ํ โdโ ์ฐ์ฐ)
์ฐ์ฐ์ ์ข ๋ฅ์ ๋ฐ๋ผ์ CPU๊ฐ ์ ๋ฆฌํ ์๋ ์๊ณ GPU๊ฐ ์ ๋ฆฌํ ์๋ ์์ต๋๋ค.
matmul ์ฐ์ฐ์ GPU์์ธ ์ ๋ฆฌํ ์ฐ์ฐ์ ๋๋ค. ํ์ง๋ง for loop๋ก ์ด๋ฃจ์ด์ง ์ฐ์ฐ์ CPU์ ์ ๋ฆฌํ ์ฐ์ฐ์ ๋๋ค.
์๋์ ๋ด์ฉ์ผ๋ก ์ฐ์ฐ ์๋๋ฅผ ํ์ธํ ์ ์์ต๋๋ค.
def fun(a, b, d1, d2):
x = mx.matmul(a, b, stream=d1)
mx.eval(x) # mx.eval ํจ์๋ ๊ฐ์ ๋ก ์ฐ์ฐ์ ์ํ์ํจ๋ค.
for _ in range(500):
b = mx.exp(b, stream=d2)
mx.eval(b)
return x, b
a = mx.random.uniform(shape=(4096, 512))
b = mx.random.uniform(shape=(512, 4))
start = time.time()
fun(a, b, mx.cpu, mx.cpu)
print(f"cpu elapsed time: {time.time()-start}")
start = time.time()
fun(a, b, mx.gpu, mx.gpu)
print(f"gpu elapsed time: {time.time()-start}")
start = time.time()
fun(a, b, mx.cpu, mx.gpu)
print(f"cpu-gpu elapsed time: {time.time()-start}")
start = time.time()
fun(a, b, mx.gpu, mx.cpu)
print(f"gpu-cpu elapsed time: {time.time()-start}")
cpu elapsed time: 0.024873018264770508 gpu elapsed time: 0.11129403114318848 cpu-gpu elapsed time: 0.062744140625 gpu-cpu elapsed time: 0.0065081119537353516
MXL์ stream์ ์ง์ ํ์ง ์์ผ๋ฉด default_device๋ก ์ค์ ๋์ด ์์ต๋๋ค.
M3 pro ๊ธฐ์ค์ GPU์ ๋๋ค.
print(mx.default_device())
print(mx.default_stream(mx.default_device()))
Device(gpu, 0) Stream(Device(gpu, 0), 0)
References
-
MLX(MLX ํํ์ด์ง)
-
SKT Enterprise(MLX ์ค๋ช )
-
Medium(Lazy computation ์ค๋ช )