В этом руководстве мы рассмотрим, как разработать нейронную сеть (NN) с помощью JAX. И какую лучше модель выбрать, чем Трансформер. По мере роста популярности JAX все больше и больше команд разработчиков начинают экспериментировать с ним и включать его в свои проекты. Несмотря на то, что ему не хватает зрелости Tensorflow или Pytorch, он предоставляет некоторые отличные функции для создания и обучения моделей глубокого обучения.
Чтобы получить полное представление об основах JAX, прочтите мою предыдущую статью, если вы еще этого не сделали. Также вы можете найти полный код в нашем Репозиторий на гитхабе.
Одна из распространенных проблем, с которой сталкиваются люди, начинающие работать с JAX, — это выбор фреймворка. Похоже, что люди в Deepmind очень заняты и уже выпустили множество фреймворков поверх JAX. Вот список самых известных из них:
-
Хайку: Haiku — это фреймворк для глубокого обучения, который используется многими внутренними командами Google и Deepmind. Он предоставляет несколько простых компонуемых абстракций для исследований в области машинного обучения, а также готовые к использованию модули и слои.
-
Оптакс: Optax — это библиотека обработки и оптимизации градиентов, содержащая готовые оптимизаторы и связанные с ними математические операции.
-
Рлакс: RLax — это система обучения с подкреплением со множеством подкомпонентов и операций RL.
-
Чекс: Chex — это библиотека утилит для тестирования и отладки JAX-кода.
-
Джраф: Jraph — это библиотека графовых нейронных сетей в JAX.
-
Лен: Flax — еще одна нейросетевая библиотека с множеством готовых к использованию модулей, оптимизаторов и утилит. Это, скорее всего, самое близкое, что у нас есть во всей JAX-инфраструктуре.
-
Обьякс: Objax — это третья библиотека мл, ориентированная на объектно-ориентированное программирование и читабельность кода. Опять же, он содержит самые популярные модули, функции активации, потери, оптимизаторы, а также несколько предварительно обученные модели.
-
Тракс: Trax — это сквозная библиотека для глубокого обучения, ориентированная на трансформеров.
-
JAXline: JAXline — это библиотека обучения с учителем, которая используется для распределенное обучение JAX и оценка.
-
АКМЕ: ACME — еще одна исследовательская структура для обучения с подкреплением.
-
ДЖАКС-МД: JAX-MD — нишевая структура, занимающаяся молекулярной динамикой.
-
Джакшем: JAXChem — еще одна нишевая библиотека, в которой особое внимание уделяется химическому моделированию.
Конечно, вопрос в том, что мне выбрать?
Честно говоря, я не уверен.
Но если бы я был на вашем месте и хотел бы изучить JAX, я бы начал с самых популярных. Haiku и Flax часто используются в Google/Deepmind и имеют самое активное сообщество на Github. В этой статье я начну с первого и посмотрю, понадобится ли мне еще один в будущем.
Итак, вы готовы построить Transformer с помощью JAX и Haiku? Кстати, я предполагаю, что вы хорошо разбираетесь в трансформерах. Если нет, посоветуйте наши статьи о внимании и трансформерах.
Начнем с блока внимания к себе.
Блок внимания к себе
Во-первых, нам нужно импортировать JAX и Haiku.
import jax
import jax.numpy as jnp
import haiku as hk
Import numpy as np
К счастью для нас, в Haiku есть встроенный MultiHeadAttention
блок, который можно расширить, чтобы построить замаскированный блок внутреннего внимания. Наш блок принимает запрос, ключ, значение, а также маску и возвращает результат в виде массива JAX. Вы можете видеть, что код очень похож на стандартный код Pytorch или Tensorflow. Все, что мы делаем, это строим каузальную маску, используя np.trill()
которые обнуляют все элементы массива выше k-го, умножают на нашу маску и передают все в hk.MultiHeadAttention
модуль.
class SelfAttention(hk.MultiHeadAttention):
"""Self attention with a causal mask applied."""
def __call__(
self,
query: jnp.ndarray,
key: Optional(jnp.ndarray) = None,
value: Optional(jnp.ndarray) = None,
mask: Optional(jnp.ndarray) = None,
) -> jnp.ndarray:
key = key if key is not None else query
value = value if value is not None else query
seq_len = query.shape(1)
causal_mask = np.tril(np.ones((seq_len, seq_len)))
mask = mask * causal_mask if mask is not None else causal_mask
return super().__call__(query, key, value, mask)
Этот фрагмент позволяет мне представить первый ключевой принцип Хайку. Все модули должны быть подклассом hk.Module
. Это означает, что они должны реализовать __init__
и __call__
, наряду с любым другим методом. В некотором смысле это та же архитектура с модулями Pytorch, где мы реализуем __init__
и forward
.
Чтобы сделать это кристально ясным, давайте создадим простой двухслойный многослойный персептрон. hk.Module
который удобно будет использовать в Transformer ниже.
Линейный слой
Простой двухслойный MLP будет выглядеть так. Еще раз, вы можете заметить, как знакомо это выглядит.
class DenseBlock(hk.Module):
"""A 2-layer MLP"""
def __init__(self,
init_scale: float,
widening_factor: int = 4,
name: Optional(str) = None):
super().__init__(name=name)
self._init_scale = init_scale
self._widening_factor = widening_factor
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
hiddens = x.shape(-1)
initializer = hk.initializers.VarianceScaling(self._init_scale)
x = hk.Linear(self._widening_factor * hiddens, w_init=initializer)(x)
x = jax.nn.gelu(x)
return hk.Linear(hiddens, w_init=initializer)(x)
Несколько вещей, на которые следует обратить внимание:
-
Haiku предоставляет нам набор инициализаторов весов под
hk.initializers
где мы можем найти наиболее общие подходы. -
Он также имеет встроенные многие популярные слои и модули, такие как
hk.Linear
. Полный список см. официальная документация. -
Функции активации не предоставляются, поскольку в JAX уже есть подпакет с именем
jax.nn
где мы можем найти функции активации такой какrelu
илиsoftmax
.
Слой нормализации
Нормализация слоев — еще один неотъемлемый блок архитектуры преобразователя, который мы также можем найти в общих модулях внутри Haiku.
def layer_norm(x: jnp.ndarray, name: Optional(str) = None) -> jnp.ndarray:
"""Apply a unique LayerNorm to x with default settings."""
return hk.LayerNorm(axis=-1,
create_scale=True,
create_offset=True,
name=name)(x)
Трансформер
А теперь о хорошем. Ниже вы можете найти очень упрощенный Transformer, в котором используются наши предустановленные модули. Внутри __init__
, мы определяем основные переменные, такие как количество слоев, внимание и процент отсева. Внутри __call__
составляем список блоков с помощью for
петля.
Как видите, каждый блок включает в себя:
В конце мы также добавляем финальный слой нормализации.
class Transformer(hk.Module):
"""A transformer stack."""
def __init__(self,
num_heads: int,
num_layers: int,
dropout_rate: float,
name: Optional(str) = None):
super().__init__(name=name)
self._num_layers = num_layers
self._num_heads = num_heads
self._dropout_rate = dropout_rate
def __call__(self,
h: jnp.ndarray,
mask: Optional(jnp.ndarray),
is_training: bool) -> jnp.ndarray:
"""Connects the transformer.
Args:
h: Inputs, (B, T, H).
mask: Padding mask, (B, T).
is_training: Whether we're training or not.
Returns:
Array of shape (B, T, H).
"""
init_scale = 2. / self._num_layers
dropout_rate = self._dropout_rate if is_training else 0.
if mask is not None:
mask = mask(:, None, None, :)
for i in range(self._num_layers):
h_norm = layer_norm(h, name=f'h{i}_ln_1')
h_attn = SelfAttention(
num_heads=self._num_heads,
key_size=64,
w_init_scale=init_scale,
name=f'h{i}_attn')(h_norm, mask=mask)
h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
h = h + h_attn
h_norm = layer_norm(h, name=f'h{i}_ln_2')
h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
h = h + h_dense
h = layer_norm(h, name='ln_f')
return h
Я думаю, вы уже поняли, что построить нейронную сеть с помощью JAX очень просто.
Слой встраивания
Для завершения давайте также включим слой встраивания. Полезно знать, что Haiku также предоставляет слой внедрения, который будет создавать токены из нашего входного предложения. Затем токен добавляется к позиционным вложениям, которые производят окончательный ввод.
def embeddings(data: Mapping(str, jnp.ndarray), vocab_size: int) :
tokens = data('obs')
input_mask = jnp.greater(tokens, 0)
seq_length = tokens.shape(1)
embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
token_embs = token_embedding_map(tokens)
positional_embeddings = hk.get_parameter(
'pos_embs', (seq_length, d_model), init=embed_init)
input_embeddings = token_embs + positional_embeddings
return input_embeddings, input_mask
hk.get_parameter(param_name, ...)
используется для доступа к обучаемым параметрам модуля. Но вы можете спросить, почему бы просто не использовать свойства объекта, как мы это делаем в Pytorch. Здесь вступает в действие второй ключевой принцип хайку. Мы используем этот API, чтобы преобразовать код в чистую функцию, используя hk.transform
. Это не очень просто понять, но я постараюсь сделать это как можно более ясным.
Почему чистые функции?
Сила JAX заключается в преобразованиях функций: возможность векторизовать функцию с помощью vmap
автоматическое распараллеливание с pmap
своевременная компиляция с jit
. Предостережение здесь заключается в том, что для преобразования функции она должна быть чистой.
А чистая функция это функция, обладающая следующими свойствами:
-
Возвращаемые значения функции идентичны для идентичных аргументов (никаких различий с локальными статическими переменными, нелокальными переменными, изменяемыми ссылочными аргументами или входными потоками).
-
Приложение функции не имеет побочных эффектов (без изменения локальных статических переменных, нелокальных переменных, изменяемых ссылочных аргументов или потоков ввода/вывода).
Источник: Чистые функции Scala от O’Reily
Практически это означает, что чистая функция всегда будет:
-
вернуть тот же результат, если он вызывается с теми же входными данными
-
все входные данные передаются через аргументы функции, все результаты выводятся через результаты функции
Haiku предоставляет функциональную трансформацию, называемую hk.transform
, который превращает функции с объектно-ориентированными, функционально «нечистыми» модулями в чистые функции, которые можно использовать с JAX. Чтобы увидеть это на практике, давайте продолжим обучение нашей модели Transformer.
Проход вперед
Типичный прямой проход включает в себя:
-
Получение входных данных и вычисление входного встраивания
-
Пробегите блоки Трансформера
-
Вернуть вывод
Вышеупомянутые шаги можно легко составить с помощью JAX следующим образом:
def build_forward_fn(vocab_size: int, d_model: int, num_heads: int,
num_layers: int, dropout_rate: float):
"""Create the model's forward pass."""
def forward_fn(data: Mapping(str, jnp.ndarray),
is_training: bool = True) -> jnp.ndarray:
"""Forward pass."""
input_embeddings, input_mask = embeddings(data, vocab_size)
transformer = Transformer(
num_heads=num_heads, num_layers=num_layers, dropout_rate=dropout_rate)
output_embeddings = transformer(input_embeddings, input_mask, is_training)
return hk.Linear(vocab_size)(output_embeddings)
return forward_fn
Хотя код прост, его структура может показаться немного странной. Фактический прямой проход выполняется через forward_fn
функция. Тем не менее, мы обернем это с помощью build_forward_fn
функция, которая возвращает forward_fn
. Какого черта?
В дальнейшем нам нужно будет преобразовать forward_fn
функцию в чистую функцию, используя hk.transform
чтобы мы могли воспользоваться преимуществами автоматического дифференцирования, распараллеливания и т. д.
Это будет достигнуто за счет:
forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)
Вот почему вместо простого определения функции мы упаковываем и возвращаем саму функцию или вызываемый если быть точнее. Затем этот вызываемый объект может быть передан в hk.transform
и стать чистой функцией. Если это понятно, давайте продолжим нашу функцию потерь.
Функция потерь
Функция потерь — это наша хорошо известная кросс-энтропийная функция с той разницей, что мы также учитываем маску. И снова JAX предоставляет one_hot
и log_softmax
функциональные возможности.
def lm_loss_fn(forward_fn,
vocab_size: int,
params,
rng,
data: Mapping(str, jnp.ndarray),
is_training: bool = True) -> jnp.ndarray:
"""Compute the loss on data wrt params."""
logits = forward_fn(params, rng, data, is_training)
targets = jax.nn.one_hot(data('target'), vocab_size)
assert logits.shape == targets.shape
mask = jnp.greater(data('obs'), 0)
loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
loss = jnp.sum(loss * mask) / jnp.sum(mask)
return loss
Если вы все еще со мной, сделайте глоток кофе, потому что с этого момента все станет серьезно. Пришло время построить наш тренировочный цикл.
Тренировочный цикл
Поскольку ни Jax, ни Haiku не имеют встроенных функций оптимизации, мы воспользуемся другой структурой, называемой Оптакс. Как упоминалось в начале, Optax — это пакет goto для обработки градиентов.
Во-первых, вот некоторые вещи, которые вам нужно знать об Optax:
Ключевым преобразованием Optax является GradientTransformation
. Преобразование определяется двумя функциями: __init__
и __update__
. __init__
инициализирует состояние и __update__
преобразует градиенты относительно состояния и текущего значения параметров
state = init(params)
grads, state = update(grads, state, params=None)
Еще одна вещь, которую нужно знать, прежде чем мы увидим код, это встроенный в Python functools.partial
функция. functools
package имеет дело с функциями и операциями высшего порядка над вызываемыми объектами.
Функция называется функцией высшего порядка, если она содержит другие функции в качестве параметра или возвращает функцию в качестве вывода.
partial
, который также можно использовать в качестве аннотации, возвращает новую функцию, основанную на исходной, но с меньшим количеством или фиксированными аргументами. Если, например, f умножает два значения x, y, партиал создаст новую функцию, где x будет фиксированным и равным 2.
from functools import partial
def f(x,y):
return x * y
g = partial(f,2)
print(g(4))
После этого короткого обхода давайте продолжим. Чтобы разгрузить наш main
мы извлечем обновление градиентов в отдельный класс.
прежде всего GradientUpdater
принимает модель, функцию потерь и оптимизатор.
- Модель будет чистой
forward_fn
функция, преобразованнаяhk.transform
forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)
- Функция потерь будет результатом частичного с фиксированным
forward_fn
и `vocab_size
loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)
- Оптимизатор представляет собой набор оптимизационных преобразований, которые будут выполняться последовательно (операции можно комбинировать с помощью
optax.chain
)
optimizer = optax.chain(
optax.clip_by_global_norm(grad_clip_value),
optax.adam(learning_rate, b1=0.9, b2=0.99))
Средство обновления Gradient будет инициализировано следующим образом:
updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)
и будет выглядеть так:
class GradientUpdater:
"""A stateless abstraction around an init_fn/update_fn pair.
This extracts some common boilerplate from the training loop.
"""
def __init__(self, net_init, loss_fn,
optimizer: optax.GradientTransformation):
self._net_init = net_init
self._loss_fn = loss_fn
self._opt = optimizer
@functools.partial(jax.jit, static_argnums=0)
def init(self, master_rng, data):
"""Initializes state of the updater."""
out_rng, init_rng = jax.random.split(master_rng)
params = self._net_init(init_rng, data)
opt_state = self._opt.init(params)
out = dict(
step=np.array(0),
rng=out_rng,
opt_state=opt_state,
params=params,
)
return out
@functools.partial(jax.jit, static_argnums=0)
def update(self, state: Mapping(str, Any), data: Mapping(str, jnp.ndarray)):
"""Updates the state using some data and returns metrics."""
rng, new_rng = jax.random.split(state('rng'))
params = state('params')
loss, g = jax.value_and_grad(self._loss_fn)(params, rng, data)
updates, opt_state = self._opt.update(g, state('opt_state'))
params = optax.apply_updates(params, updates)
new_state = {
'step': state('step') + 1,
'rng': new_rng,
'opt_state': opt_state,
'params': params,
}
metrics = {
'step': state('step'),
'loss': loss,
}
return new_state, metrics
Внутри __init__
мы инициализируем наш оптимизатор с помощью self._opt.init(params)
и мы объявляем состояние оптимизации. Состояние будет словарем с:
update
будет обновлять как состояние оптимизатора, так и обучаемые параметры. В конце концов, он вернет новое состояние.
updates, opt_state = self._opt.update(g, state('opt_state'))
params = optax.apply_updates(params, updates)
Еще две вещи, на которые следует обратить внимание:
-
jax.value_and_grad()
это специальная функция который возвращает дифференцируемую функцию с ее градиентами -
Оба
__init__
и__update__
аннотированы@functools.partial(jax.jit, static_argnums=0)
, который вызовет JIT-компилятор и скомпилирует их в XLA во время выполнения. Обратите внимание, что если мы не преобразовалиforward_fn
в чистую функцию, это было бы невозможно.
Наконец, мы готовы построить весь цикл обучения, который объединяет все идеи и код, упомянутые до сих пор.
def main():
train_dataset, vocab_size = load(batch_size,
sequence_length)
forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)
loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)
optimizer = optax.chain(
optax.clip_by_global_norm(grad_clip_value),
optax.adam(learning_rate, b1=0.9, b2=0.99))
updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)
logging.info('Initializing parameters...')
rng = jax.random.PRNGKey(428)
data = next(train_dataset)
state = updater.init(rng, data)
logging.info('Starting train loop...')
prev_time = time.time()
for step in range(MAX_STEPS):
data = next(train_dataset)
state, metrics = updater.update(state, data)
Обратите внимание, как мы включаем GradientUpdate
. Это всего две строчки кода:
-
state = updater.init(rng, data)
-
state, metrics = updater.update(state, data)
Вот и все. Я надеюсь, что теперь у вас есть более четкое представление о JAX и его возможностях.
Благодарности
Представленный код сильно вдохновлен официальными примерами фреймворка Haiku. Он был изменен, чтобы соответствовать потребностям этой статьи. Полный список примеров см. официальный репозиторий
Заключение
В этой статье мы увидели, как можно разработать и обучить ванильного трансформера в JAX с помощью Haiku. Хотя код не всегда сложен для понимания, ему все же не хватает читабельности Pytorch или Tensorflow. Я настоятельно рекомендую поиграть с ним, открыть для себя сильные и слабые стороны JAX и посмотреть, подойдет ли он для вашего следующего проекта. По моему опыту, JAX очень силен для исследовательских приложений, требующих высокой производительности, но совершенно незрел для реальных проектов. Дайте нам знать, что вы думаете в нашем Дискорд-канал.
* Раскрытие информации: обратите внимание, что некоторые из приведенных выше ссылок могут быть партнерскими ссылками, и мы без дополнительных затрат для вас получим комиссию, если вы решите совершить покупку после перехода по ссылке.