Home Технологии Использование JAX для ускорения наших исследований | DeepTech

Использование JAX для ускорения наших исследований | DeepTech

0
Использование JAX для ускорения наших исследований
 | DeepTech

Инженеры DeepMind ускоряют наши исследования, создавая инструменты, масштабируя алгоритмы и создавая сложные виртуальные и физические миры для обучения и тестирования систем искусственного интеллекта (ИИ). В рамках этой работы мы постоянно оцениваем новые библиотеки и фреймворки для машинного обучения.

В последнее время мы обнаружили, что все больше проектов хорошо обслуживаются ДЖАКСинфраструктура машинного обучения, разработанная Исследования Google команды. JAX хорошо согласуется с нашей инженерной философией и за последний год получил широкое распространение в нашем исследовательском сообществе. Здесь мы делимся нашим опытом работы с JAX, объясняем, почему мы считаем его полезным для наших исследований в области ИИ, и даем обзор экосистемы, которую мы создаем для поддержки исследователей во всем мире.

Почему ДЖАКС?

JAX — это библиотека Python, предназначенная для высокопроизводительных численных вычислений, особенно для исследований в области машинного обучения. Его API для числовых функций основан на NumPy, набор функций, используемых в научных вычислениях. И Python, и NumPy широко используются и знакомы, что делает JAX простым, гибким и легким в использовании.

В дополнение к NumPy API JAX включает в себя расширяемую систему составные преобразования функций которые помогают поддерживать исследования в области машинного обучения, в том числе:

  • Дифференциация: Оптимизация на основе градиента является фундаментальной для ML. JAX изначально поддерживает как прямой, так и обратный режим. автоматическая дифференциация произвольных числовых функций с помощью преобразований функций, таких как grad, hessian, jacfwd и jacrev.
  • Векторизация: В исследованиях машинного обучения мы часто применяем одну функцию к большому количеству данных, например, вычисляем потери в партии или оценка градиентов для каждого примера для дифференцированно частного обучения. JAX обеспечивает автоматическую векторизацию посредством преобразования vmap, которое упрощает эту форму программирования. Например, исследователям не нужно думать о пакетной обработке при реализации новых алгоритмов. JAX также поддерживает крупномасштабный параллелизм данных посредством связанного преобразования pmap, элегантно распределяя данные, которые слишком велики для памяти одного ускорителя.
  • JIT-компиляция: XLA используется для своевременной (JIT) компиляции и выполнения программ JAX на GPU и Облачный ТПУ ускорители. JIT-компиляция вместе с NumPy-совместимым API JAX позволяет исследователям, не имевшим опыта работы с высокопроизводительными вычислениями, легко масштабироваться до одного или нескольких ускорителей.

Мы обнаружили, что JAX позволяет быстро экспериментировать с новыми алгоритмами и архитектурами, и теперь он лежит в основе многих наших недавних публикаций. Чтобы узнать больше, присоединяйтесь к нашему круглому столу JAX, который состоится в среду, 9 декабря, в 19:00 по Гринвичу, на НейриПС виртуальная конференция.

JAX в DeepMind

Поддержка современных исследований в области искусственного интеллекта означает сочетание быстрого прототипирования и быстрой итерации с возможностью развертывания экспериментов в масштабах, традиционно связанных с производственными системами. Что делает такие проекты особенно сложными, так это то, что исследовательская среда быстро развивается и ее трудно прогнозировать. В любой момент новый исследовательский прорыв может и регулярно меняет траекторию и требования целых команд. В этом постоянно меняющемся ландшафте основная обязанность нашей инженерной команды состоит в том, чтобы убедиться, что извлеченные уроки и код, написанный для одного исследовательского проекта, эффективно используются в следующем.

Одним из подходов, доказавших свою эффективность, является модульность: мы извлекаем наиболее важные и критически важные строительные блоки, разработанные в каждом исследовательском проекте, в хорошо проверенные и эффективные компоненты. Это позволяет исследователям сосредоточиться на своих исследованиях, а также получать выгоду от повторного использования кода, исправления ошибок и повышения производительности алгоритмических компонентов, реализованных в наших основных библиотеках. Мы также обнаружили, что важно убедиться, что каждая библиотека имеет четко определенную область действия, а также обеспечить их совместимость, но независимость. Дополнительный бай-инспособность выбирать функции, не привязываясь к другим, имеет решающее значение для обеспечения максимальной гибкости для исследователей и постоянной поддержки их в выборе правильного инструмента для работы.

Другие соображения, которые учитывались при разработке нашей экосистемы JAX, включают в себя обеспечение ее соответствия (где это возможно) дизайну нашей существующей ТензорФлоу библиотеки (например Сонет и TRFL). Мы также стремились создавать компоненты, которые (там, где это уместно) максимально точно соответствуют лежащей в их основе математике, чтобы они были самоописательными и сводили к минимуму мыслительные переходы «от бумаги к коду». Наконец, мы решили Открытый исходный код наши библиотеки для облегчения обмена результатами исследований и поощрения более широкого сообщества к изучению экосистемы JAX.

Наша экосистема сегодня

Хайку

Модель программирования JAX компонуемых преобразований функций может усложнить работу с объектами с состоянием, например нейронными сетями с обучаемыми параметрами. Haiku — это библиотека нейронных сетей, которая позволяет пользователям использовать знакомые модели объектно-ориентированного программирования, одновременно используя мощь и простоту чисто функциональной парадигмы JAX.

Haiku активно используется сотнями исследователей из DeepMind и Google и уже нашла применение в нескольких внешних проектах (например, Коаксиальный, ДипХим, NumPyro). Он основан на API для Сонетнашу модульную модель программирования для нейронных сетей в TensorFlow, и мы стремились максимально упростить перенос из Sonnet в Haiku.

Узнайте больше на GitHub

Оптакс

Оптимизация на основе градиента является фундаментальной для ML. Optax предоставляет библиотеку преобразований градиента вместе с операторами композиции (например, цепочкой), которые позволяют реализовать множество стандартных оптимизаторов (например, RMSProp или Adam) всего в одной строке кода.

Композиционная природа Optax естественным образом поддерживает повторное комбинирование одних и тех же основных ингредиентов в пользовательских оптимизаторах. Кроме того, он предлагает ряд утилит для оценки стохастического градиента и оптимизации второго порядка.

Многие пользователи Optax переняли Haiku, но в соответствии с нашей философией постепенного приобретения поддерживается любая библиотека, представляющая параметры в виде древовидной структуры JAX (например, Элегия, Лен и Стакс). Пожалуйста, посмотри здесь для получения дополнительной информации об этой богатой экосистеме библиотек JAX.

Узнайте больше на GitHub

Рлакс

Многие из наших самых успешных проектов находятся на стыке глубокого обучения и обучения с подкреплением (RL), также известного как глубокое обучение с подкреплением. RLax — это библиотека, которая предоставляет полезные стандартные блоки для создания агентов RL.

Компоненты RLax охватывают широкий спектр алгоритмов и идей: TD-обучение, градиенты политики, критика акторов, MAP, проксимальная оптимизация политики, нелинейное преобразование значений, общие функции значений и ряд методов исследования.

Хотя некоторые вводные пример агентов предоставляются, RLax не предназначен для создания и развертывания полных систем агентов RL. Одним из примеров полнофункциональной среды агентов, основанной на компонентах RLax, является Акме.

Узнайте больше на GitHub

Чекс

Тестирование имеет решающее значение для надежности программного обеспечения, и исследовательский код не является исключением. Чтобы делать научные выводы из исследовательских экспериментов, нужно быть уверенным в правильности своего кода. Chex — это набор утилит для тестирования, используемых авторами библиотек для проверки правильности и надежности общих строительных блоков, а конечными пользователями — для проверки своего экспериментального кода.

Chex предоставляет набор утилит, включая модульное тестирование с поддержкой JAX, утверждения свойств типов данных JAX, имитации и подделки, а также среды тестирования с несколькими устройствами. Chex используется во всей экосистеме DeepMind JAX и внешними проектами, такими как Коаксиальный и Минерл.

Узнайте больше на GitHub

Джраф

Граф нейронных сетей (GNN) — захватывающая область исследований со многими многообещающими приложениями. См., например, нашу недавнюю работу по предсказание трафика в Картах Google и нашей работе над моделирование физики. Jraph (произносится как «жираф») — это облегченная библиотека для поддержки работы с GNN в JAX.

Jraph предоставляет стандартизированную структуру данных для графов, набор утилит для работы с графами и «зоопарк» легко разветвляемых и расширяемых графовых нейросетевых моделей. Другие ключевые функции включают в себя: группирование GraphTuples, которые эффективно используют аппаратные ускорители, поддержку JIT-компиляции графов переменной формы с помощью заполнения и маскирования, а также потери, определенные для входных разделов. Подобно Optax и другим нашим библиотекам, Jraph не накладывает ограничений на выбор пользователем библиотеки нейронной сети.

Узнайте больше об использовании библиотеки из нашей богатой коллекции Примеры.

Узнайте больше на GitHub

Наша экосистема JAX постоянно развивается, и мы призываем исследовательское сообщество машинного обучения исследовать наши библиотеки и потенциал JAX для ускорения собственных исследований.

LEAVE A REPLY

Please enter your comment!
Please enter your name here