JAX — новичок в мире машинного обучения (ML), и он обещает сделать программирование ML более интуитивно понятным, структурированным и чистым. Возможно, он сможет заменить подобные Tensorflow и PyTorch, несмотря на то, что он сильно отличается по своей сути.
Как сказал мой друг, у нас были всевозможные тузы, короли и дамы. Теперь у нас есть JAX.
В этой статье мы рассмотрим, что такое JAX и почему его следует использовать вместо всех других библиотек. Мы выскажем свое мнение, используя фрагменты кода, демонстрирующие мощь JAX, и представим некоторые его полезные функции.
Если это звучит интересно, заходите.
Что такое Джакс?
Jax — это библиотека Python, предназначенная для высокопроизводительных исследований машинного обучения. Jax — это не что иное, как библиотека числовых вычислений, такая же, как Numpy, но с некоторыми ключевыми улучшениями. Он был разработан Google и использовался внутри команд Google и Deepmind.
Источник: JAX-документация
Установить JAX
Прежде чем мы обсудим основные преимущества JAX, я предлагаю вам установить JAX в вашей среде Python или в совместной лаборатории Google, чтобы вы могли самостоятельно выполнять код. Ссылку на полный код я, конечно же, оставлю в конце статьи.
Чтобы установить JAX, мы можем просто использовать pip
из нашей командной строки:
$ pip install --upgrade jax jaxlib
Обратите внимание, что это будет поддерживать выполнение только на ЦП. Если вы также хотите поддерживать GPU, вам сначала нужно CUDA и cuDNN а затем выполните следующую команду (не забудьте сопоставить версию jaxlib с вашей версией CUDA):
$ pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
Для устранения неполадок обратитесь к официальному Инструкции на гитхабе.
Теперь давайте импортируем JAX вместе с Numpy. Мы будем использовать Numpy для сравнения различных вариантов использования.
import jax
import jax.numpy as jnp
import numpy as np
Основы JAX
Начнем с основ. Как мы уже говорили, основная и единственная цель JAX — выполнять числовые операции выразительным и высокопроизводительным способом. Это означает, что синтаксис почти идентичен Numpy. Например, если мы хотим создать массив нулей, у нас будет:
x = np.zeros(10)
y= jnp.zeros(10)
Разница кроется за кадром.
Массив устройств
Вы видите, что одно из основных преимуществ JAX заключается в том, что мы можем запустить одну и ту же программу без каких-либо изменений в аппаратных ускорителях, таких как GPU и TPU..
Это достигается базовой структурой, называемой Массив устройствчто по существу заменяет Стандартный массив Numpy.
DeviceArrays ленивы, что означает, что они сохраняют значения в ускорителе и вытягивают их только при необходимости.
x
y
Мы можем использовать DeviceArrays точно так же, как мы используем стандартные массивы. Мы можем передать его другим библиотекам, построить графики, выполнить дифференцирование, и все заработает. Также обратите внимание, что большинство API Numpy (функций и операций) поддерживаются JAX, поэтому ваш код JAX будет почти идентичен Numpy.
Другая важная вещь — скорость. Ну JAX быстрее. Намного быстрее. Давайте рассмотрим простой пример. Мы создаем два массива размером (1000, 1000), один с Numpy и один с JAX, и вычисляем внутренний продукт с самим собой.
Давайте timeit
две операции
x = np.random.rand(1000,1000)
y = jnp.array(x)
%timeit -n 1 -r 1 np.dot(x,x)
%timeit -n 1 -r 1 jnp.dot(y,y).block_until_ready()
Впечатляет, правда? Что ж, это ожидаемо. Вычисления выполняются быстрее в графических процессорах. Также вы заметили, block_until_ready()
функция. Поскольку JAX является асинхронным, нам нужно дождаться завершения выполнения, чтобы правильно измерить время.
Вы не можете поверить, что это все, что может предложить JAX, верно?
Теперь о хорошем…
Почему ДЖАКС?
Если вам недостаточно скорости и автоматической поддержки графических процессоров, я вас не виню. Кажется, что любая другая библиотека может справиться с этим. Чтобы лучше понять преимущества JAX, нам нужно погрузиться глубже. JAX можно рассматривать как набор функциональных преобразований обычного кода Python и Numpy.
Примером таких преобразований является дифференциация. Поддерживает ли JAX автоматическую дифференциацию?
Я уверен, что вы правильно догадались.
Автоматическое дифференцирование с помощью функции grad()
JAX может различать все виды функций Python и NumPy, включая циклы, ветки, рекурсии и многое другое.
Это невероятно полезно для приложений глубокого обучения, поскольку мы можем без особых усилий запускать обратное распространение. Основная функция для достижения этого называется grad()
. Вот пример. Определим простую квадратичную функцию и возьмем ее производную в точке 1.0.
Чтобы доказать, что результат правильный, мы также вычислим производную вручную.
from jax import grad
def f(x):
return 3*x**2 + 2*x + 5
def f_prime(x):
return 6*x +2
grad(f)(1.0)
f_prime(1.0)
Меня очень удивило, что JAX на самом деле делает аналитический градиентное решение под капотом вместо какой-либо другой причудливой техники. Он просто принимает форму функции и выполняет цепное правило. Поскольку автоматическая дифференциация — это гораздо больше, я настоятельно рекомендую ознакомиться с официальная документация для более полного понимания.
Ускоренная линейная алгебра (компилятор XLA)
Одним из факторов, делающих JAX таким быстрым, является ускоренная линейная алгебра или XLA.
XLA — это предметно-ориентированный компилятор для линейной алгебры, который широко используется Tensorflow.
Чтобы выполнять матричные операции как можно быстрее, код компилируется в набор вычислительных ядер, которые можно значительно оптимизировать в зависимости от характера кода.
Примеры таких оптимизаций включают:
Своевременная компиляция (jit)
Компиляция «как раз вовремя» идет рука об руку с XLA. Чтобы воспользоваться преимуществами XLA, код должен быть скомпилирован в ядра XLA. Это где jit
вступает в игру.
Компиляция «точно в срок» (JIT) — это способ выполнения компьютерного кода, который включает компиляцию во время выполнения программы — во время выполнения, а не перед выполнением.
Чтобы использовать XLA и jit, можно использовать либо jit()
функция или @jit
аннотация.
from jax import jit
x = np.random.rand(1000,1000)
y = jnp.array(x)
def f(x):
for _ in range(10):
x = 0.5*x + 0.1* jnp.sin(x)
return x
g = jit(f)
%timeit -n 5 -r 5 f(y).block_until_ready()
%timeit -n 5 -r 5 g(y).block_until_ready()
И снова улучшение времени выполнения более чем очевидно. Конечно, jit
также можно сочетать с grad
преобразование (или любое другое преобразование в этом отношении), что делает обратное распространение очень быстрым.
Также обратите внимание, что jit
имеет некоторые недостатки: например, если он не может точно представить функцию (что обычно происходит с ветвями «если»), он, скорее всего, потерпит неудачу. Однако для большинства случаев использования, связанных с глубоким обучением, это невероятно полезно.
Повторяйте вычисления на устройствах с помощью pmap
Pmap — это еще одно преобразование, которое позволяет нам реплицировать вычисления на несколько ядер или устройств и выполнять их параллельно (p в pmap означает «параллельный»).
Он автоматически распределяет вычисления по всем текущим устройствам и обрабатывает всю связь между ними. Чтобы проверить доступные устройства, вы можете запустить jax.devices()
.
from jax import pmap
def f(x):
return jnp.sin(x) + x**2
f(np.arange(4))
pmap(f)(np.arange(4))
Обратите внимание, что DeviceArray теперь стал ShardedDeviceArray, который представляет собой структуру, которая обрабатывает параллельное выполнение.
Еще одна очень крутая вещь, которую JAX позволяет нам делать, это коллективное общение между устройствами. Допустим, мы хотим выполнить операцию «уменьшить» между значениями на всех устройствах (например, взять сумму). Для этого нам нужно собрать все данные со всех устройств и выполнить сумму. Это легко сделать следующим образом:
from functools import partial
from jax.lax import psum
@partial(pmap, axis_name="i")
def normalize(x):
return x/ psum(x,'i')
normalize(np.arange(8.))
Приведенный выше код сопоставляет вектор x со всеми устройствами и запускает коллективную операцию связи для выполнения psum
(параллельная сумма). Другими словами, он собирает все «x» с устройств, суммирует их и возвращает результат каждому устройству для продолжения параллельных вычислений. Я позаимствовал приведенный выше пример из этого потрясающий разговор Мэтью Джонсона во время GTC 2020.
Вы также можете представить, что с pmap
мы можем определить наши собственные схемы вычислений и наилучшим образом использовать наши устройства. Так же, как мы обычно делаем с CUDA для отдельных ядер, но на этот раз для отдельных устройств.
Автоматическая векторизация с vmap
Vmap, как следует из названия, представляет собой преобразование функции, которое позволяет нам векторизовать функции (v означает вектор!).
Мы можем взять функцию, которая работает с одной точкой данных, и векторизовать ее, чтобы она могла принимать пакет этих точек данных (или вектор) произвольного размера. Вот пример:
from jax import vmap
def f(x):
return jnp.square(x)
f(jnp.arange(10))
vmap(f)(jnp.arange(10))
Вы можете задаться вопросом, что мы получили здесь. Чтобы понять это, давайте взглянем на то, что происходит, когда f(x)
выполняется без vmap
:
-
Выходной список инициализируется.
-
Квадрат 0 вычисляется и возвращается.
-
Результат 0 добавляется к списку.
-
Квадрат 1 вычисляется и возвращается.
-
Результат 1 добавляется к списку.
-
Квадрат 2 вычисляется и возвращается.
-
Результат 4 добавляется к списку.
-
И так далее…
Что vmap делает, так это то, что он выполняет квадратную операцию только один раз, потому что он объединяет все значения вместе и передает их через функцию. И, конечно же, это приводит к увеличению как скорости, так и потребления памяти.
Хотя вышеупомянутые преобразования — это те, которые вам обязательно нужно знать, я хотел бы упомянуть еще несколько вещей, которые удивили меня во время моего путешествия по JAX.
Генератор псевдослучайных чисел
Генератор случайных чисел JAX работает немного иначе, чем Numpy. Вместо того, чтобы быть стандартным генератором псевдослучайных чисел (PRNG) с отслеживанием состояния, как в Numpy и Scipy, все случайные функции JAX требуют, чтобы в качестве первого аргумента передавалось явное состояние PRNG.
Генератор случайных чисел имеет состояние. Следующее «случайное» число является функцией предыдущего числа и начального числа/состояния. Последовательность случайных значений конечна и повторяется.
Важно отметить, что PRNG хорошо работают как с точки зрения векторизации, так и с точки зрения параллельных вычислений между устройствами.
from jax import random
key = random.PRNGKey(5)
random.uniform(key)
Асинхронная отправка
Другой аспект JAX, который меня впечатлил, заключается в том, что он использует асинхронную диспетчеризацию. Это означает, что он не ждет завершения операций, прежде чем вернуть управление программе Python. Вместо этого он возвращает DeviceArray
это будущее (точно так же, как Дополняемое будущее в Java)
Будущее — это значение, которое будет создано в будущем на устройстве-ускорителе, но не обязательно будет доступно немедленно.
Будущее можно передавать другим операциям, не дожидаясь завершения вычислений. Таким образом, JAX позволяет коду Python работать перед ускорителем, гарантируя, что он может ставить в очередь операции для аппаратного ускорителя (например, GPU) без необходимости ожидания.
Профилирование JAX и профилировщик памяти устройства
Последняя функция, о которой я хочу упомянуть, это профилирование. Вам будет приятно узнать, что Tensoboard поддерживает профилирование JAX.
!(Профилирование Tensorboard JAX)(Профилирование Tensorboard JAX.png)
Источник: JAX-документация
То же самое верно для Nvidia Nsight, который используется для отладки и профилирования кода графического процессора. Наряду с этим можно также использовать встроенный в JAX профилировщик памяти устройств, который обеспечивает представление о том, как код JAX выполняется на графических процессорах и TPU. Вот фрагмент из документации:
import jax
import jax.numpy as jnp
import jax.profiler
def func1(x):
return jnp.tile(x, 10) * 0.5
def func2(x):
y = func1(x)
return y, jnp.tile(x, 10) + 1
x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))
y, z = func2(x)
z.block_until_ready()
jax.profiler.save_device_memory_profile("memory.prof")
Если вы установили ппрофбиблиотеку от Google, вы можете выполнить следующую команду, которая откроет окно браузера со всей необходимой информацией.
$ pprof --web memory.prof
!(Профилирование памяти устройства)(Профилирование памяти устройства.png)
Источник: JAX-документация
Это круто или что?
Не стесняйтесь играть с ним. Я знаю, что я сделал.
Заключение
В этом посте я попытался дать обзор преимуществ JAX по сравнению с другими библиотеками и представить простые фрагменты кода, чтобы изучить его основной синтаксис и тонкости. Кстати, вы можете найти полный код в этом коллаб блокнот или в нашем репозиторий github.
В следующих статьях мы сделаем еще один шаг и рассмотрим, как создавать и обучать глубокие нейронные сети с помощью JAX, а также взглянем на различные фреймворки, построенные на его основе.
Если статья показалась вам интересной, не забудьте поделиться ею в социальных сетях.
Рекомендации
* Раскрытие информации: обратите внимание, что некоторые из приведенных выше ссылок могут быть партнерскими ссылками, и мы без дополнительных затрат для вас получим комиссию, если вы решите совершить покупку после перехода по ссылке.