
What is MLX?

MLX is an array framework built for machine learning on Apple silicon.

It runs on both the CPU and the GPU of Apple silicon, which can substantially speed up vector and graph computations.

Installation requirements

  1. An M-series Apple silicon chip

  2. Native Python >= 3.8

  3. macOS >= 13.3
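
The three requirements above can also be checked from Python before installing. The following is a minimal sketch using only the standard library; the function name `meets_mlx_requirements` is made up for this post and is not part of MLX. A Python running under Rosetta reports `x86_64` rather than `arm64`, which is why a native build is required.

```python
import platform
import sys


def meets_mlx_requirements(machine, python_version, macos_version):
    """Check the three requirements listed above (hypothetical helper)."""
    # 1. A native Python on Apple silicon reports "arm64" from platform.machine().
    is_apple_silicon = machine == "arm64"
    # 2. Python >= 3.8
    python_ok = tuple(python_version[:2]) >= (3, 8)
    # 3. macOS >= 13.3; an empty string means we are not on macOS at all.
    if macos_version:
        macos_ok = tuple(int(p) for p in macos_version.split(".")[:2]) >= (13, 3)
    else:
        macos_ok = False
    return is_apple_silicon and python_ok and macos_ok


# Check the machine this script is running on.
print(meets_mlx_requirements(
    platform.machine(), sys.version_info, platform.mac_ver()[0]
))
```

If this prints `False`, the `pip install mlx` step is unlikely to find a matching wheel.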

!python -c "import platform; print(platform.processor())" # It must be arm
!pip install mlx
arm
Collecting mlx
  Obtaining dependency information for mlx from https://files.pythonhosted.org/packages/8f/e7/40e631abca0823399ad5f89e2fd849393d7e6a8f3efd2cf1a3ef4ceb0df0/mlx-0.0.11-cp311-cp311-macosx_14_0_arm64.whl.metadata
  Downloading mlx-0.0.11-cp311-cp311-macosx_14_0_arm64.whl.metadata (4.9 kB)
Downloading mlx-0.0.11-cp311-cp311-macosx_14_0_arm64.whl (17.1 MB)
   17.1/17.1 MB 5.9 MB/s eta 0:00:00
Installing collected packages: mlx
Successfully installed mlx-0.0.11

Basic Quick Start

To create arrays, first import mlx.core.

import mlx.core as mx

a = mx.array([1,2,3,4])
print(f'a shape: {a.shape}')
print(f'a dtype: {a.dtype}')

b = mx.array([1.0, 2.0, 3.0, 4.0])
print(f'b shape: {b.shape}')
print(f'b dtype: {b.dtype}')

c = mx.array([[1,2,3,4],[5,6,7,8]])
print(f'c shape: {c.shape}')
print(f'c dtype: {c.dtype}')

d = mx.array([[1,2,3,4],[5.0,6.0,7.0,8.0]])
print(f'd shape: {d.shape}')
print(f'd dtype: {d.dtype}')
a shape: [4]
a dtype: mlx.core.int32
b shape: [4]
b dtype: mlx.core.float32
c shape: [2, 4]
c dtype: mlx.core.int32
d shape: [2, 4]
d dtype: mlx.core.float32

MLX uses lazy evaluation.

What is lazy evaluation?
It is a programming strategy that defers a computation until its result is actually needed somewhere.

Lazy evaluation can improve performance by skipping work whose result is never used, avoid errors from computations that are never needed, and make it possible to work with infinite data structures.
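
To make the idea concrete, here is a toy sketch of lazy evaluation in plain Python. It is not how MLX is implemented; the `Lazy` class and the `constant` helper are hypothetical names used only for illustration. Building an expression just records it, and the work happens when `eval()` is called.

```python
class Lazy:
    """A deferred computation: stores a function and its inputs, runs on demand."""

    evaluations = 0  # counts how many nodes have actually been computed

    def __init__(self, fn, *inputs):
        self.fn = fn
        self.inputs = inputs
        self._value = None
        self._done = False

    def __add__(self, other):
        # Building the expression only records it; nothing is computed yet.
        return Lazy(lambda x, y: x + y, self, other)

    def eval(self):
        # Compute the inputs first, then this node, and cache the result.
        if not self._done:
            args = [i.eval() if isinstance(i, Lazy) else i for i in self.inputs]
            self._value = self.fn(*args)
            self._done = True
            Lazy.evaluations += 1
        return self._value


def constant(v):
    """Wrap a plain value as a lazy node."""
    return Lazy(lambda: v)


a = constant(1)
b = constant(2)

unused = a + b           # recorded, never computed
c = (a + b) + a          # recorded
print(Lazy.evaluations)  # 0: nothing has run yet
print(c.eval())          # 4: forces the computation
print(Lazy.evaluations)  # 4: a, b, a + b and c were computed; `unused` was not
```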

Let's compare the time taken when the computation is deferred with the time taken when it is forced.

import time

start = time.time()
for _ in range(100):
    c = a + b # No computation actually happens here, because c is never used.
print(f'lazy evaluation time: {time.time()-start}')

start = time.time()
for _ in range(100):
    c = a + b
    mx.eval(c) # The computation runs here. (mx.eval forces evaluation.)
print(f'forced evaluation time: {time.time()-start}')
lazy evaluation time: 0.0003581047058105469
forced evaluation time: 0.02145099639892578
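
A single `time.time()` measurement like the one above is noisy, and with lazy evaluation it only measures real work if the timed code forces evaluation. As a rough sketch, the hypothetical helper below (`best_time` is my own name, not an MLX function) repeats a callable and keeps the fastest run using `time.perf_counter()`.

```python
import time


def best_time(fn, repeats=5):
    """Run fn() `repeats` times and return the fastest wall-clock time in seconds."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn()  # for MLX code, fn should call mx.eval(...) so the work really runs
        best = min(best, time.perf_counter() - start)
    return best


# Stand-in workload so the sketch runs without MLX installed.
elapsed = best_time(lambda: sum(range(100_000)))
print(f"best of 5: {elapsed:.6f} s")
```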

Unified Memory

On Apple silicon, the CPU and GPU are not separate devices with their own memory.

The chip uses a single Unified Memory Architecture (UMA), so the CPU and GPU can both directly access the same memory pool.

MLX is designed to take advantage of this.

(torch can already use the GPU through MPS, but it does not seem to make full use of the hardware. I'm looking forward to MLX because it looks likely to get the most out of an M3 MacBook Pro.)

# Create two arrays
a = mx.random.normal((100,))
b = mx.random.normal((100,))

In MLX, you can specify the device on which each operation runs.

In other words, the same arrays can be used for both CPU and GPU operations without moving them in memory.

# No dependency
mx.add(a, b, stream=mx.cpu)
mx.add(a, b, stream=mx.gpu)

# With a dependency: d uses c
c = mx.add(a, b, stream=mx.cpu)
d = mx.add(a, c, stream=mx.gpu)

When there is no dependency, the two operations can run in parallel.

When there is a dependency, the second operation starts only after the first one has finished.

(First 'c' is computed, then 'd'.)
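
The difference between the two cases can be illustrated with an analogy in plain Python. This is only a sketch using `concurrent.futures`, with list addition standing in for `mx.add`; it is not how MLX schedules its streams internally.

```python
from concurrent.futures import ThreadPoolExecutor


def add(x, y):
    """Element-wise addition of two lists (stand-in for mx.add)."""
    return [p + q for p, q in zip(x, y)]


a = [1.0, 2.0, 3.0]
b = [10.0, 20.0, 30.0]

with ThreadPoolExecutor(max_workers=2) as pool:
    # Independent: neither task needs the other's result, so both may run at once.
    first = pool.submit(add, a, b)
    second = pool.submit(add, a, b)
    independent = (first.result(), second.result())

    # Dependent: d needs c, so c must finish before d can even be submitted.
    c = pool.submit(add, a, b).result()
    d = pool.submit(add, a, c).result()

print(independent)
print(d)
```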


Depending on the kind of operation, either the CPU or the GPU may be the better choice.

A matmul is well suited to the GPU. A long run of small operations, such as the for loop below, is better suited to the CPU.

The following code compares the speeds.

def fun(a, b, d1, d2):
    x = mx.matmul(a, b, stream=d1)
    mx.eval(x) # mx.eval forces the computation to run.
    for _ in range(500):
        b = mx.exp(b, stream=d2)
        mx.eval(b)
    return x, b

a = mx.random.uniform(shape=(4096, 512))
b = mx.random.uniform(shape=(512, 4))

start = time.time()
fun(a, b, mx.cpu, mx.cpu)
print(f"cpu elapsed time: {time.time()-start}")

start = time.time()
fun(a, b, mx.gpu, mx.gpu)
print(f"gpu elapsed time: {time.time()-start}")

start = time.time()
fun(a, b, mx.cpu, mx.gpu)
print(f"cpu-gpu elapsed time: {time.time()-start}")

start = time.time()
fun(a, b, mx.gpu, mx.cpu)
print(f"gpu-cpu elapsed time: {time.time()-start}")
cpu elapsed time: 0.024873018264770508
gpu elapsed time: 0.11129403114318848
cpu-gpu elapsed time: 0.062744140625
gpu-cpu elapsed time: 0.0065081119537353516
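
To make the recorded numbers easier to compare, the following sketch ranks the four timings printed above and expresses each relative to the fastest. The values are copied from the output; they come from a single run, so treat the ratios as rough.

```python
# Timings in seconds, copied from the output above.
# Each key reads "matmul device - loop device".
timings = {
    "cpu": 0.024873018264770508,
    "gpu": 0.11129403114318848,
    "cpu-gpu": 0.062744140625,
    "gpu-cpu": 0.0065081119537353516,
}

fastest = min(timings, key=timings.get)
for name, seconds in sorted(timings.items(), key=lambda kv: kv[1]):
    ratio = seconds / timings[fastest]
    print(f"{name:8s} {seconds:.4f} s  ({ratio:.1f}x the fastest)")
```

In this run, putting the matmul on the GPU and the loop on the CPU was the fastest combination, which matches the explanation above.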

If no stream is specified, MLX runs the operation on default_device.

On an M3 Pro, the default device is the GPU.

print(mx.default_device())
print(mx.default_stream(mx.default_device()))
Device(gpu, 0)
Stream(Device(gpu, 0), 0)
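
If you would rather have operations without an explicit stream run on the CPU, the default device can be changed. The sketch below assumes `mx.set_default_device` behaves as described in the MLX documentation. The import is guarded so the snippet still runs on machines without MLX, and `add_on_cpu_by_default` is a hypothetical helper name.

```python
try:
    import mlx.core as mx
except ImportError:
    mx = None  # MLX is not installed on this machine


def add_on_cpu_by_default():
    """Temporarily make the CPU the default device, add two arrays, then restore."""
    if mx is None:
        return None
    previous = mx.default_device()
    mx.set_default_device(mx.cpu)
    try:
        x = mx.add(mx.array([1.0, 2.0]), mx.array([3.0, 4.0]))  # no stream given
        mx.eval(x)  # force the computation
        return x.tolist()
    finally:
        mx.set_default_device(previous)  # put the original default back


print(add_on_cpu_by_default())
```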
