Мне было очень любопытно посмотреть, как JAX сравнивается с Pytorch или Tensorflow. Я решил, что лучший способ для кого-то сравнить фреймворки — создать одно и то же с нуля в обоих из них. И это именно то, что я сделал. В этой статье я разрабатываю вариационный автоэнкодер с JAX, Tensorflow и Pytorch одновременно. Я представлю код для каждого компонента рядом, чтобы найти различия, сходства, слабые и сильные стороны.
Начнем?
Пролог
Некоторые вещи, на которые следует обратить внимание, прежде чем мы приступим к изучению кода:
-
я буду использовать Лен поверх JAX, который представляет собой библиотеку нейронной сети, разработанную Google. Он содержит множество готовых к использованию модулей глубокого обучения, слоев, функций и операций.
-
Для реализации Tensorflow я буду полагаться на Керас абстракции.
-
Для Pytorch я буду использовать стандартный
nn.module
.
Поскольку большинство из нас немного знакомы с Tensorflow и Pytorch, мы уделим больше внимания JAX и Flax. Вот почему я буду объяснять вещи по пути, которые могут быть незнакомы многим. Так что вы можете рассматривать эту статью как легкое руководство по Flax.
Кроме того, я предполагаю, что вы знакомы с основными принципами VAE. Если нет, то можете посоветовать мою предыдущую статью о моделях скрытых переменных. Если вроде все понятно, продолжим.
Краткое резюме: Ванильный автоэнкодер состоит из кодировщика и декодера. Кодер преобразует ввод в скрытое представление и декодер пытается восстановить ввод на основе этого представления. В вариационных автоэнкодерах к смеси также добавляется стохастичность в том смысле, что скрытое представление обеспечивает распределение вероятностей. Это происходит с трюком репараметризации.
Изображение автора
Кодер
Для кодировщика простого линейного слоя с последующей активацией RELU должно быть достаточно для игрушечного примера. Результатом слоя будет как среднее значение, так и стандартное отклонение распределения вероятностей.
Основным строительным блоком Flax API является Module
абстракция, которую мы будем использовать для реализации нашего кодировщика в JAX. module
является частью linen
подпакет. Похож на Pytorch nn.module
, нам снова нужно определить аргументы нашего класса. В Pytorch мы привыкли объявлять их внутри __init__
функцию и реализацию прямого прохода внутри forward
метод. Во льне дела обстоят немного иначе. Аргументы определяются либо как атрибуты класса данных, либо как аргументы метода. Обычно фиксированные свойства определяются как аргументы класса данных, а динамические свойства — как аргументы метода. Также вместо реализации forward
метод, мы реализуем __call__
Модуль класса данных появился в Python 3.7 как служебный инструмент для создания структурированных классов, специально предназначенных для хранения данных. Эти классы содержат определенные свойства и функции для работы с данными и их представлением. Они также сокращают объем шаблонного кода по сравнению с обычными классами.
Итак, чтобы создать новый модуль во Flax, нам нужно:
-
Инициализировать класс, который наследует
flax.linen.nn.Module
-
Определите статические аргументы как аргументы класса данных
-
Реализуйте прямой проход внутри
__call_
метод.
Чтобы связать аргументы с моделью и иметь возможность определять подмодули непосредственно внутри модуля, нам также необходимо аннотировать __call__
метод с @nn.compact
.
Обратите внимание, что вместо использования аргументов класса данных и @nn.compact
аннотацию, мы могли бы объявить все аргументы внутри setup
точно так же, как в Pytorch или Tensorflow. __init__
.
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
from flax import optim
class Encoder(nn.Module):
latents: int
@nn.compact
def __call__(self, x):
x = nn.Dense(500, name='fc1')(x)
x = nn.relu(x)
mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
return mean_x, logvar_x
import tensorflow as tf
from tensorflow.keras import layers
class Encoder(layers.Layer):
def __init__(self,
latent_dim =20,
name='encoder',
**kwargs):
super(Encoder, self).__init__(name=name, **kwargs)
self.enc1 = layers.Dense(500, activation='relu')
self.mean_x = layers.Dense(latent_dim)
self.logvar_x = layers.Dense(latent_dim)
def call(self, inputs):
x = self.enc1(inputs)
z_mean = self.mean_x(x)
z_log_var = self.logvar_x(x)
return z_mean, z_log_var
import torch
import torch.nn.functional as F
class Encoder(torch.nn.Module):
def __init__(self, latent_dim=20):
super(Encoder, self).__init__()
self.enc1 = torch.nn.Linear(784, 500)
self.mean_x = torch.nn.Linear(500,latent_dim)
self.logvar_x = torch.nn.Linear(500, latent_dim)
def forward(self,inputs):
x = self.enc1(inputs)
x= F.relu(x)
z_mean = self.mean_x(x)
z_log_var = self.logvar_x(x)
return z_mean, z_log_var
Еще несколько вещей, на которые следует обратить внимание, прежде чем мы продолжим:
-
льна
nn.linen
пакет содержит большинство уровней глубокого обучения и операции, такие какDense
,relu
и многое другое -
Код во Flax, Tensorflow и Pytorch практически неотличим друг от друга.
Декодер
Очень похожим образом мы можем разработать декодер во всех трех фреймворках. Декодер будет представлять собой два линейных слоя, получающих скрытое представление и выведите восстановленный ввод.
Опять же, реализации очень похожи.
class Decoder(nn.Module):
@nn.compact
def __call__(self, z):
z = nn.Dense(500, name='fc1')(z)
z = nn.relu(z)
z = nn.Dense(784, name='fc2')(z)
return z
class Decoder(layers.Layer):
def __init__(self,
name='decoder',
**kwargs):
super(Decoder, self).__init__(name=name, **kwargs)
self.dec1 = layers.Dense(500, activation='relu')
self.out = layers.Dense(784)
def call(self, z):
z = self.dec1(z)
return self.out(z)
class Decoder(torch.nn.Module):
def __init__(self, latent_dim=20):
super(Decoder, self).__init__()
self.dec1 = torch.nn.Linear(latent_dim, 500)
self.out = torch.nn.Linear(500, 784)
def forward(self,z):
z = self.dec1(z)
z = F.relu(z)
return self.out(z)
Вариационный автоэнкодер
Чтобы объединить кодировщик и декодер, давайте создадим еще один класс, называемый VAE
, который будет представлять всю архитектуру. Здесь нам также нужно написать код для трюка с репараметризацией. В целом имеем: скрытая переменная из кодировщика перепараметрируется и подается в декодер, который производит реконструированный ввод.
Напоминаем, вот интуитивное изображение, объясняющее трюк с репараметризацией:
Обратите внимание, что на этот раз в JAX мы используем setup
метод вместо nn.compact
аннотация. Кроме того, проверьте, насколько похожи функции репараметризации. Конечно, каждый фреймворк использует свои собственные функции и операции, но общий образ почти идентичен.
class VAE(nn.Module):
latents: int = 20
def setup(self):
self.encoder = Encoder(self.latents)
self.decoder = Decoder()
def __call__(self, x, z_rng):
mean, logvar = self.encoder(x)
z = reparameterize(z_rng, mean, logvar)
recon_x = self.decoder(z)
return recon_x, mean, logvar
def reparameterize(rng, mean, logvar):
std = jnp.exp(0.5 * logvar)
eps = random.normal(rng, logvar.shape)
return mean + eps * std
def model():
return VAE(latents=LATENTS)
class VAE(tf.keras.Model):
def __init__(self,
latent_dim=20,
name='vae',
**kwargs):
super(VAE, self).__init__(name=name, **kwargs)
self.encoder = Encoder(latent_dim=latent_dim)
self.decoder = Decoder()
def call(self, inputs):
z_mean, z_log_var = self.encoder(inputs)
z = self.reparameterize(z_mean, z_log_var)
reconstructed = self.decoder(z)
return reconstructed, z_mean, z_log_var
def reparameterize(self, mean, logvar):
eps = tf.random.normal(shape=mean.shape)
return mean + eps * tf.exp(logvar * .5)
class VAE(torch.nn.Module):
def __init__(self, latent_dim=20):
super(VAE, self).__init__()
self.encoder = Encoder(latent_dim)
self.decoder = Decoder(latent_dim)
def forward(self,inputs):
z_mean, z_log_var = self.encoder(inputs)
z = self.reparameterize(z_mean, z_log_var)
reconstructed = self.decoder(z)
return reconstructed, z_mean, z_log_var
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + (eps * std)
Потеря и шаг обучения
Все начинает меняться, когда мы начинаем реализовывать этап обучения и функцию потерь. Но не намного.
-
Чтобы в полной мере использовать возможности JAX, нам нужно добавить в наш код автоматическую векторизацию и XLA-компиляцию. Это можно легко сделать с помощью
vmap
иjit
аннотации. -
Более того, мы должны включить автоматическую дифференциацию, которую можно выполнить с помощью
grad_fn
трансформация -
Мы используем
flax.optim
пакет алгоритмов оптимизации
Еще одно небольшое отличие, о котором нам нужно знать, заключается в том, как мы передаем данные в нашу модель. Это может быть достигнуто с помощью метода apply в виде model().apply({'params': params}, batch, z_rng)
где batch
наши тренировочные данные.
@jax.vmap
def kl_divergence(mean, logvar):
return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))
@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
logits = nn.log_sigmoid(logits)
return -jnp.sum(labels * logits + (1. - labels) * jnp.log(-jnp.expm1(logits)))
@jax.jit
def train_step(optimizer, batch, z_rng):
def loss_fn(params):
recon_x, mean, logvar = model().apply({'params': params}, batch, z_rng)
bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
kld_loss = kl_divergence(mean, logvar).mean()
loss = bce_loss + kld_loss
return loss, recon_x
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
_, grad = grad_fn(optimizer.target)
optimizer = optimizer.apply_gradient(grad)
return optimizer
def kl_divergence(mean, logvar):
return -0.5 * tf.reduce_sum(
1 + logvar - tf.square(mean) -
tf.exp(logvar), axis=1)
def binary_cross_entropy_with_logits(logits, labels):
logits = tf.math.log(logits)
return - tf.reduce_sum(
labels * logits +
(1-labels) * tf.math.log(- tf.math.expm1(logits)),
axis=1
)
@tf.function
def train_step(model, x, optimizer):
with tf.GradientTape() as tape:
recon_x, mean, logvar = model(x)
bce_loss = tf.reduce_mean(binary_cross_entropy_with_logits(recon_x, batch))
kld_loss = tf.reduce_mean(kl_divergence(mean, logvar))
loss = bce_loss + kld_loss
print(loss, kld_loss, bce_loss)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
def final_loss(reconstruction, train_x, mu, logvar):
BCE = torch.nn.BCEWithLogitsLoss(reduction='sum')(reconstruction, train_x)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD
def train_step(train_x):
train_x = torch.from_numpy(train_x)
optimizer.zero_grad()
reconstruction, mu, logvar = model(train_x)
loss = final_loss(reconstruction, train_x, mu, logvar)
running_loss += loss.item()
loss.backward()
optimizer.step()
Помните, что VAE обучаются путем максимизации нижней границы доказательств, известной как ELBO.
Тренировочный цикл
Наконец, пришло время для всего цикла обучения, который будет выполнять train_step
функционировать итеративно.
Во Flax модель должна быть инициализирована перед обучением, которое выполняется init
функция, такая как: params = model().init(key, init_data, rng)('params')
. Аналогичная инициализация необходима и для оптимизатора: optimizer = optim.Adam( learning_rate = LEARNING_RATE ).create( params )
.
jax.device_put
используется для переноса оптимизатора в память графического процессора.
rng = random.PRNGKey(0)
rng, key = random.split(rng)
init_data = jnp.ones((BATCH_SIZE, 784), jnp.float32)
params = model().init(key, init_data, rng)('params')
optimizer = optim.Adam(learning_rate=LEARNING_RATE).create(params)
optimizer = jax.device_put(optimizer)
rng, z_key, eval_rng = random.split(rng, 3)
z = random.normal(z_key, (64, LATENTS))
steps_per_epoch = 50000 // BATCH_SIZE
for epoch in range(NUM_EPOCHS):
for _ in range(steps_per_epoch):
batch = next(train_ds)
rng, key = random.split(rng)
optimizer = train_step(optimizer, batch, key)
vae = VAE(latent_dim=LATENTS)
optimizer = tf.keras.optimizers.Adam(1e-4)
for epoch in range(NUM_EPOCHS):
for train_x in train_ds:
train_step(vae, train_x, optimizer)
def train(model,training_data):
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
running_loss = 0.0
for epoch in range(NUM_EPOCHS):
for i, train_x in enumerate(training_data, 0):
train_step(train_x)
vae = VAE(LATENTS)
train(vae, train_ds)
Загрузка и обработка данных
Одна вещь, которую я не упомянул, это данные. Как мы загружаем и предварительно обрабатываем данные во Flax? Что ж, Flax еще не включает в себя пакеты для обработки данных, кроме базовых операций jax.numpy
. Сейчас лучше всего заимствовать пакеты из других фреймворков, таких как наборы данных Tensorflow (tfds) или Torchvision. Чтобы сделать статью самодостаточной, я включу код, который использовал для загрузки примера обучающего набора данных. tfds
. Не стесняйтесь использовать свой собственный загрузчик данных, если вы планируете запускать реализации, представленные в этой статье.
import tensorflow_datasets as tfds
tf.config.experimental.set_visible_devices((), 'GPU')
def prepare_image(x):
x = tf.cast(x('image'), tf.float32)
x = tf.reshape(x, (-1,))
return x
ds_builder = tfds.builder('binarized_mnist')
ds_builder.download_and_prepare()
train_ds = ds_builder.as_dataset(split=tfds.Split.TRAIN)
train_ds = train_ds.map(prepare_image)
train_ds = train_ds.cache()
train_ds = train_ds.repeat()
train_ds = train_ds.shuffle(50000)
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = iter(tfds.as_numpy(train_ds))
test_ds = ds_builder.as_dataset(split=tfds.Split.TEST)
test_ds = test_ds.map(prepare_image).batch(10000)
test_ds = np.array(list(test_ds)(0))
Заключительные наблюдения
Чтобы закрыть статью, давайте обсудим несколько заключительных наблюдений, которые появляются после тщательного анализа кода:
-
Все 3 фреймворка сократили шаблонный код до минимума, а Flax требует немного больше, особенно в части обучения. Однако это делается только для того, чтобы гарантировать, что мы используем все доступные преобразования, такие как автоматическое дифференцирование, векторизация и компилятор «точно в срок».
-
Определение модулей, слоев и моделей во всех них практически идентично.
-
Flax и JAX по дизайну достаточно гибкие и расширяемые.
-
Flax пока не имеет возможности загрузки и обработки данных
-
Что касается готовых слоев и оптимизаторов, Flax не нужно завидовать Tensorflow и Pytorch. Конечно, ему не хватает такой гигантской библиотеки, как у его конкурентов, но он постепенно добирается до нее.
* Раскрытие информации: обратите внимание, что некоторые из приведенных выше ссылок могут быть партнерскими ссылками, и мы без дополнительных затрат для вас получим комиссию, если вы решите совершить покупку после перехода по ссылке.