Home Искусственный интеллект JAX против Tensorflow против Pytorch: создание вариационного автоэнкодера (VAE) | DeepTech

JAX против Tensorflow против Pytorch: создание вариационного автоэнкодера (VAE) | DeepTech

0
JAX против Tensorflow против Pytorch: создание вариационного автоэнкодера (VAE)
 | DeepTech

Мне было очень любопытно посмотреть, как JAX сравнивается с Pytorch или Tensorflow. Я решил, что лучший способ для кого-то сравнить фреймворки — создать одно и то же с нуля в обоих из них. И это именно то, что я сделал. В этой статье я разрабатываю вариационный автоэнкодер с JAX, Tensorflow и Pytorch одновременно. Я представлю код для каждого компонента рядом, чтобы найти различия, сходства, слабые и сильные стороны.

Начнем?

Пролог

Некоторые вещи, на которые следует обратить внимание, прежде чем мы приступим к изучению кода:

  • я буду использовать Лен поверх JAX, который представляет собой библиотеку нейронной сети, разработанную Google. Он содержит множество готовых к использованию модулей глубокого обучения, слоев, функций и операций.

  • Для реализации Tensorflow я буду полагаться на Керас абстракции.

  • Для Pytorch я буду использовать стандартный nn.module.

Поскольку большинство из нас немного знакомы с Tensorflow и Pytorch, мы уделим больше внимания JAX и Flax. Вот почему я буду объяснять вещи по пути, которые могут быть незнакомы многим. Так что вы можете рассматривать эту статью как легкое руководство по Flax.

Кроме того, я предполагаю, что вы знакомы с основными принципами VAE. Если нет, то можете посоветовать мою предыдущую статью о моделях скрытых переменных. Если вроде все понятно, продолжим.

Краткое резюме: Ванильный автоэнкодер состоит из кодировщика и декодера. Кодер преобразует ввод в скрытое представление гг и декодер пытается восстановить ввод на основе этого представления. В вариационных автоэнкодерах к смеси также добавляется стохастичность в том смысле, что скрытое представление обеспечивает распределение вероятностей. Это происходит с трюком репараметризации.


ваэ


Изображение автора

Кодер

Для кодировщика простого линейного слоя с последующей активацией RELU должно быть достаточно для игрушечного примера. Результатом слоя будет как среднее значение, так и стандартное отклонение распределения вероятностей.

Основным строительным блоком Flax API является Module абстракция, которую мы будем использовать для реализации нашего кодировщика в JAX. module является частью linen подпакет. Похож на Pytorch nn.module, нам снова нужно определить аргументы нашего класса. В Pytorch мы привыкли объявлять их внутри __init__ функцию и реализацию прямого прохода внутри forward метод. Во льне дела обстоят немного иначе. Аргументы определяются либо как атрибуты класса данных, либо как аргументы метода. Обычно фиксированные свойства определяются как аргументы класса данных, а динамические свойства — как аргументы метода. Также вместо реализации forward метод, мы реализуем __call__

Модуль класса данных появился в Python 3.7 как служебный инструмент для создания структурированных классов, специально предназначенных для хранения данных. Эти классы содержат определенные свойства и функции для работы с данными и их представлением. Они также сокращают объем шаблонного кода по сравнению с обычными классами.

Итак, чтобы создать новый модуль во Flax, нам нужно:

  • Инициализировать класс, который наследует flax.linen.nn.Module

  • Определите статические аргументы как аргументы класса данных

  • Реализуйте прямой проход внутри __call_ метод.

Чтобы связать аргументы с моделью и иметь возможность определять подмодули непосредственно внутри модуля, нам также необходимо аннотировать __call__ метод с @nn.compact.

Обратите внимание, что вместо использования аргументов класса данных и @nn.compact аннотацию, мы могли бы объявить все аргументы внутри setup точно так же, как в Pytorch или Tensorflow. __init__.

import numpy as np

import jax

import jax.numpy as jnp

from jax import random

from flax import linen as nn

from flax import optim

class Encoder(nn.Module):

latents: int

@nn.compact

def __call__(self, x):

x = nn.Dense(500, name='fc1')(x)

x = nn.relu(x)

mean_x = nn.Dense(self.latents, name='fc2_mean')(x)

logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)

return mean_x, logvar_x

import tensorflow as tf

from tensorflow.keras import layers

class Encoder(layers.Layer):

def __init__(self,

latent_dim =20,

name='encoder',

**kwargs):

super(Encoder, self).__init__(name=name, **kwargs)

self.enc1 = layers.Dense(500, activation='relu')

self.mean_x = layers.Dense(latent_dim)

self.logvar_x = layers.Dense(latent_dim)

def call(self, inputs):

x = self.enc1(inputs)

z_mean = self.mean_x(x)

z_log_var = self.logvar_x(x)

return z_mean, z_log_var

import torch

import torch.nn.functional as F

class Encoder(torch.nn.Module):

def __init__(self, latent_dim=20):

super(Encoder, self).__init__()

self.enc1 = torch.nn.Linear(784, 500)

self.mean_x = torch.nn.Linear(500,latent_dim)

self.logvar_x = torch.nn.Linear(500, latent_dim)

def forward(self,inputs):

x = self.enc1(inputs)

x= F.relu(x)

z_mean = self.mean_x(x)

z_log_var = self.logvar_x(x)

return z_mean, z_log_var

Еще несколько вещей, на которые следует обратить внимание, прежде чем мы продолжим:

  • льна nn.linen пакет содержит большинство уровней глубокого обучения и операции, такие как Dense, reluи многое другое

  • Код во Flax, Tensorflow и Pytorch практически неотличим друг от друга.

Декодер

Очень похожим образом мы можем разработать декодер во всех трех фреймворках. Декодер будет представлять собой два линейных слоя, получающих скрытое представление гг и выведите восстановленный ввод.

Опять же, реализации очень похожи.

class Decoder(nn.Module):

@nn.compact

def __call__(self, z):

z = nn.Dense(500, name='fc1')(z)

z = nn.relu(z)

z = nn.Dense(784, name='fc2')(z)

return z

class Decoder(layers.Layer):

def __init__(self,

name='decoder',

**kwargs):

super(Decoder, self).__init__(name=name, **kwargs)

self.dec1 = layers.Dense(500, activation='relu')

self.out = layers.Dense(784)

def call(self, z):

z = self.dec1(z)

return self.out(z)

class Decoder(torch.nn.Module):

def __init__(self, latent_dim=20):

super(Decoder, self).__init__()

self.dec1 = torch.nn.Linear(latent_dim, 500)

self.out = torch.nn.Linear(500, 784)

def forward(self,z):

z = self.dec1(z)

z = F.relu(z)

return self.out(z)

Вариационный автоэнкодер

Чтобы объединить кодировщик и декодер, давайте создадим еще один класс, называемый VAE, который будет представлять всю архитектуру. Здесь нам также нужно написать код для трюка с репараметризацией. В целом имеем: скрытая переменная из кодировщика перепараметрируется и подается в декодер, который производит реконструированный ввод.

Напоминаем, вот интуитивное изображение, объясняющее трюк с репараметризацией:


репараметризация-трюк


Источник: Александр Амини и Ава Солеймани, Глубокое генеративное моделирование | Массачусетский технологический институт 6.S191, http://introtodeeplearning.com/

Обратите внимание, что на этот раз в JAX мы используем setup метод вместо nn.compact аннотация. Кроме того, проверьте, насколько похожи функции репараметризации. Конечно, каждый фреймворк использует свои собственные функции и операции, но общий образ почти идентичен.

class VAE(nn.Module):

latents: int = 20

def setup(self):

self.encoder = Encoder(self.latents)

self.decoder = Decoder()

def __call__(self, x, z_rng):

mean, logvar = self.encoder(x)

z = reparameterize(z_rng, mean, logvar)

recon_x = self.decoder(z)

return recon_x, mean, logvar

def reparameterize(rng, mean, logvar):

std = jnp.exp(0.5 * logvar)

eps = random.normal(rng, logvar.shape)

return mean + eps * std

def model():

return VAE(latents=LATENTS)

class VAE(tf.keras.Model):

def __init__(self,

latent_dim=20,

name='vae',

**kwargs):

super(VAE, self).__init__(name=name, **kwargs)

self.encoder = Encoder(latent_dim=latent_dim)

self.decoder = Decoder()

def call(self, inputs):

z_mean, z_log_var = self.encoder(inputs)

z = self.reparameterize(z_mean, z_log_var)

reconstructed = self.decoder(z)

return reconstructed, z_mean, z_log_var

def reparameterize(self, mean, logvar):

eps = tf.random.normal(shape=mean.shape)

return mean + eps * tf.exp(logvar * .5)

class VAE(torch.nn.Module):

def __init__(self, latent_dim=20):

super(VAE, self).__init__()

self.encoder = Encoder(latent_dim)

self.decoder = Decoder(latent_dim)

def forward(self,inputs):

z_mean, z_log_var = self.encoder(inputs)

z = self.reparameterize(z_mean, z_log_var)

reconstructed = self.decoder(z)

return reconstructed, z_mean, z_log_var

def reparameterize(self, mu, log_var):

std = torch.exp(0.5 * log_var)

eps = torch.randn_like(std)

return mu + (eps * std)

Потеря и шаг обучения

Все начинает меняться, когда мы начинаем реализовывать этап обучения и функцию потерь. Но не намного.

  1. Чтобы в полной мере использовать возможности JAX, нам нужно добавить в наш код автоматическую векторизацию и XLA-компиляцию. Это можно легко сделать с помощью vmap и jit аннотации.

  2. Более того, мы должны включить автоматическую дифференциацию, которую можно выполнить с помощью grad_fn трансформация

  3. Мы используем flax.optim пакет алгоритмов оптимизации

Еще одно небольшое отличие, о котором нам нужно знать, заключается в том, как мы передаем данные в нашу модель. Это может быть достигнуто с помощью метода apply в виде model().apply({'params': params}, batch, z_rng)где batch наши тренировочные данные.

@jax.vmap

def kl_divergence(mean, logvar):

return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))

@jax.vmap

def binary_cross_entropy_with_logits(logits, labels):

logits = nn.log_sigmoid(logits)

return -jnp.sum(labels * logits + (1. - labels) * jnp.log(-jnp.expm1(logits)))

@jax.jit

def train_step(optimizer, batch, z_rng):

def loss_fn(params):

recon_x, mean, logvar = model().apply({'params': params}, batch, z_rng)

bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()

kld_loss = kl_divergence(mean, logvar).mean()

loss = bce_loss + kld_loss

return loss, recon_x

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

_, grad = grad_fn(optimizer.target)

optimizer = optimizer.apply_gradient(grad)

return optimizer

def kl_divergence(mean, logvar):

return -0.5 * tf.reduce_sum(

1 + logvar - tf.square(mean) -

tf.exp(logvar), axis=1)

def binary_cross_entropy_with_logits(logits, labels):

logits = tf.math.log(logits)

return - tf.reduce_sum(

labels * logits +

(1-labels) * tf.math.log(- tf.math.expm1(logits)),

axis=1

)

@tf.function

def train_step(model, x, optimizer):

with tf.GradientTape() as tape:

recon_x, mean, logvar = model(x)

bce_loss = tf.reduce_mean(binary_cross_entropy_with_logits(recon_x, batch))

kld_loss = tf.reduce_mean(kl_divergence(mean, logvar))

loss = bce_loss + kld_loss

print(loss, kld_loss, bce_loss)

gradients = tape.gradient(loss, model.trainable_variables)

optimizer.apply_gradients(zip(gradients, model.trainable_variables))

def final_loss(reconstruction, train_x, mu, logvar):

BCE = torch.nn.BCEWithLogitsLoss(reduction='sum')(reconstruction, train_x)

KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

return BCE + KLD

def train_step(train_x):

train_x = torch.from_numpy(train_x)

optimizer.zero_grad()

reconstruction, mu, logvar = model(train_x)

loss = final_loss(reconstruction, train_x, mu, logvar)

running_loss += loss.item()

loss.backward()

optimizer.step()

Помните, что VAE обучаются путем максимизации нижней границы доказательств, известной как ELBO.

лθ,ф(Икс)“=”Едф(гИкс)(логпθ(Иксг))КЛ(дф(гИкс)пθ(г))L _ {\ theta, \ phi} (x) = \ textbf {E} _ {q _ {\ phi} (z | x)} ( log p _ {\ theta} (x | z)) – \ textbf {KL} ( q _ {\ phi} (z | x) || p _ {\ theta} (z))

Тренировочный цикл

Наконец, пришло время для всего цикла обучения, который будет выполнять train_step функционировать итеративно.

Во Flax модель должна быть инициализирована перед обучением, которое выполняется init функция, такая как: params = model().init(key, init_data, rng)('params'). Аналогичная инициализация необходима и для оптимизатора: optimizer = optim.Adam( learning_rate = LEARNING_RATE ).create( params ).

jax.device_put используется для переноса оптимизатора в память графического процессора.

rng = random.PRNGKey(0)

rng, key = random.split(rng)

init_data = jnp.ones((BATCH_SIZE, 784), jnp.float32)

params = model().init(key, init_data, rng)('params')

optimizer = optim.Adam(learning_rate=LEARNING_RATE).create(params)

optimizer = jax.device_put(optimizer)

rng, z_key, eval_rng = random.split(rng, 3)

z = random.normal(z_key, (64, LATENTS))

steps_per_epoch = 50000 // BATCH_SIZE

for epoch in range(NUM_EPOCHS):

for _ in range(steps_per_epoch):

batch = next(train_ds)

rng, key = random.split(rng)

optimizer = train_step(optimizer, batch, key)

vae = VAE(latent_dim=LATENTS)

optimizer = tf.keras.optimizers.Adam(1e-4)

for epoch in range(NUM_EPOCHS):

for train_x in train_ds:

train_step(vae, train_x, optimizer)

def train(model,training_data):

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

running_loss = 0.0

for epoch in range(NUM_EPOCHS):

for i, train_x in enumerate(training_data, 0):

train_step(train_x)

vae = VAE(LATENTS)

train(vae, train_ds)

Загрузка и обработка данных

Одна вещь, которую я не упомянул, это данные. Как мы загружаем и предварительно обрабатываем данные во Flax? Что ж, Flax еще не включает в себя пакеты для обработки данных, кроме базовых операций jax.numpy. Сейчас лучше всего заимствовать пакеты из других фреймворков, таких как наборы данных Tensorflow (tfds) или Torchvision. Чтобы сделать статью самодостаточной, я включу код, который использовал для загрузки примера обучающего набора данных. tfds. Не стесняйтесь использовать свой собственный загрузчик данных, если вы планируете запускать реализации, представленные в этой статье.

import tensorflow_datasets as tfds

tf.config.experimental.set_visible_devices((), 'GPU')

def prepare_image(x):

x = tf.cast(x('image'), tf.float32)

x = tf.reshape(x, (-1,))

return x

ds_builder = tfds.builder('binarized_mnist')

ds_builder.download_and_prepare()

train_ds = ds_builder.as_dataset(split=tfds.Split.TRAIN)

train_ds = train_ds.map(prepare_image)

train_ds = train_ds.cache()

train_ds = train_ds.repeat()

train_ds = train_ds.shuffle(50000)

train_ds = train_ds.batch(BATCH_SIZE)

train_ds = iter(tfds.as_numpy(train_ds))

test_ds = ds_builder.as_dataset(split=tfds.Split.TEST)

test_ds = test_ds.map(prepare_image).batch(10000)

test_ds = np.array(list(test_ds)(0))

Заключительные наблюдения

Чтобы закрыть статью, давайте обсудим несколько заключительных наблюдений, которые появляются после тщательного анализа кода:

  • Все 3 фреймворка сократили шаблонный код до минимума, а Flax требует немного больше, особенно в части обучения. Однако это делается только для того, чтобы гарантировать, что мы используем все доступные преобразования, такие как автоматическое дифференцирование, векторизация и компилятор «точно в срок».

  • Определение модулей, слоев и моделей во всех них практически идентично.

  • Flax и JAX по дизайну достаточно гибкие и расширяемые.

  • Flax пока не имеет возможности загрузки и обработки данных

  • Что касается готовых слоев и оптимизаторов, Flax не нужно завидовать Tensorflow и Pytorch. Конечно, ему не хватает такой гигантской библиотеки, как у его конкурентов, но он постепенно добирается до нее.

Книга «Глубокое обучение в производстве» 📖

Узнайте, как создавать, обучать, развертывать, масштабировать и поддерживать модели глубокого обучения. Изучите инфраструктуру машинного обучения и MLOps на практических примерах.

Узнать больше

* Раскрытие информации: обратите внимание, что некоторые из приведенных выше ссылок могут быть партнерскими ссылками, и мы без дополнительных затрат для вас получим комиссию, если вы решите совершить покупку после перехода по ссылке.

LEAVE A REPLY

Please enter your comment!
Please enter your name here