Инженеры DeepMind ускоряют наши исследования, создавая инструменты, масштабируя алгоритмы и создавая сложные виртуальные и физические миры для обучения и тестирования систем искусственного интеллекта (ИИ). В рамках этой работы мы постоянно оцениваем новые библиотеки и фреймворки для машинного обучения.
В последнее время мы обнаружили, что все больше проектов хорошо обслуживаются ДЖАКСинфраструктура машинного обучения, разработанная Исследования Google команды. JAX хорошо согласуется с нашей инженерной философией и за последний год получил широкое распространение в нашем исследовательском сообществе. Здесь мы делимся нашим опытом работы с JAX, объясняем, почему мы считаем его полезным для наших исследований в области ИИ, и даем обзор экосистемы, которую мы создаем для поддержки исследователей во всем мире.
Почему ДЖАКС?
JAX — это библиотека Python, предназначенная для высокопроизводительных численных вычислений, особенно для исследований в области машинного обучения. Его API для числовых функций основан на NumPy, набор функций, используемых в научных вычислениях. И Python, и NumPy широко используются и знакомы, что делает JAX простым, гибким и легким в использовании.
В дополнение к NumPy API JAX включает в себя расширяемую систему составные преобразования функций которые помогают поддерживать исследования в области машинного обучения, в том числе:
- Дифференциация: Оптимизация на основе градиента является фундаментальной для ML. JAX изначально поддерживает как прямой, так и обратный режим. автоматическая дифференциация произвольных числовых функций с помощью преобразований функций, таких как grad, hessian, jacfwd и jacrev.
- Векторизация: В исследованиях машинного обучения мы часто применяем одну функцию к большому количеству данных, например, вычисляем потери в партии или оценка градиентов для каждого примера для дифференцированно частного обучения. JAX обеспечивает автоматическую векторизацию посредством преобразования vmap, которое упрощает эту форму программирования. Например, исследователям не нужно думать о пакетной обработке при реализации новых алгоритмов. JAX также поддерживает крупномасштабный параллелизм данных посредством связанного преобразования pmap, элегантно распределяя данные, которые слишком велики для памяти одного ускорителя.
- JIT-компиляция: XLA используется для своевременной (JIT) компиляции и выполнения программ JAX на GPU и Облачный ТПУ ускорители. JIT-компиляция вместе с NumPy-совместимым API JAX позволяет исследователям, не имевшим опыта работы с высокопроизводительными вычислениями, легко масштабироваться до одного или нескольких ускорителей.
Мы обнаружили, что JAX позволяет быстро экспериментировать с новыми алгоритмами и архитектурами, и теперь он лежит в основе многих наших недавних публикаций. Чтобы узнать больше, присоединяйтесь к нашему круглому столу JAX, который состоится в среду, 9 декабря, в 19:00 по Гринвичу, на НейриПС виртуальная конференция.
JAX в DeepMind
Поддержка современных исследований в области искусственного интеллекта означает сочетание быстрого прототипирования и быстрой итерации с возможностью развертывания экспериментов в масштабах, традиционно связанных с производственными системами. Что делает такие проекты особенно сложными, так это то, что исследовательская среда быстро развивается и ее трудно прогнозировать. В любой момент новый исследовательский прорыв может и регулярно меняет траекторию и требования целых команд. В этом постоянно меняющемся ландшафте основная обязанность нашей инженерной команды состоит в том, чтобы убедиться, что извлеченные уроки и код, написанный для одного исследовательского проекта, эффективно используются в следующем.
Одним из подходов, доказавших свою эффективность, является модульность: мы извлекаем наиболее важные и критически важные строительные блоки, разработанные в каждом исследовательском проекте, в хорошо проверенные и эффективные компоненты. Это позволяет исследователям сосредоточиться на своих исследованиях, а также получать выгоду от повторного использования кода, исправления ошибок и повышения производительности алгоритмических компонентов, реализованных в наших основных библиотеках. Мы также обнаружили, что важно убедиться, что каждая библиотека имеет четко определенную область действия, а также обеспечить их совместимость, но независимость. Дополнительный бай-инспособность выбирать функции, не привязываясь к другим, имеет решающее значение для обеспечения максимальной гибкости для исследователей и постоянной поддержки их в выборе правильного инструмента для работы.
Другие соображения, которые учитывались при разработке нашей экосистемы JAX, включают в себя обеспечение ее соответствия (где это возможно) дизайну нашей существующей ТензорФлоу библиотеки (например Сонет и TRFL). Мы также стремились создавать компоненты, которые (там, где это уместно) максимально точно соответствуют лежащей в их основе математике, чтобы они были самоописательными и сводили к минимуму мыслительные переходы «от бумаги к коду». Наконец, мы решили Открытый исходный код наши библиотеки для облегчения обмена результатами исследований и поощрения более широкого сообщества к изучению экосистемы JAX.
Наша экосистема сегодня
Хайку
Модель программирования JAX компонуемых преобразований функций может усложнить работу с объектами с состоянием, например нейронными сетями с обучаемыми параметрами. Haiku — это библиотека нейронных сетей, которая позволяет пользователям использовать знакомые модели объектно-ориентированного программирования, одновременно используя мощь и простоту чисто функциональной парадигмы JAX.
Haiku активно используется сотнями исследователей из DeepMind и Google и уже нашла применение в нескольких внешних проектах (например, Коаксиальный, ДипХим, NumPyro). Он основан на API для Сонетнашу модульную модель программирования для нейронных сетей в TensorFlow, и мы стремились максимально упростить перенос из Sonnet в Haiku.
Оптакс
Оптимизация на основе градиента является фундаментальной для ML. Optax предоставляет библиотеку преобразований градиента вместе с операторами композиции (например, цепочкой), которые позволяют реализовать множество стандартных оптимизаторов (например, RMSProp или Adam) всего в одной строке кода.
Композиционная природа Optax естественным образом поддерживает повторное комбинирование одних и тех же основных ингредиентов в пользовательских оптимизаторах. Кроме того, он предлагает ряд утилит для оценки стохастического градиента и оптимизации второго порядка.
Многие пользователи Optax переняли Haiku, но в соответствии с нашей философией постепенного приобретения поддерживается любая библиотека, представляющая параметры в виде древовидной структуры JAX (например, Элегия, Лен и Стакс). Пожалуйста, посмотри здесь для получения дополнительной информации об этой богатой экосистеме библиотек JAX.
Рлакс
Многие из наших самых успешных проектов находятся на стыке глубокого обучения и обучения с подкреплением (RL), также известного как глубокое обучение с подкреплением. RLax — это библиотека, которая предоставляет полезные стандартные блоки для создания агентов RL.
Компоненты RLax охватывают широкий спектр алгоритмов и идей: TD-обучение, градиенты политики, критика акторов, MAP, проксимальная оптимизация политики, нелинейное преобразование значений, общие функции значений и ряд методов исследования.
Хотя некоторые вводные пример агентов предоставляются, RLax не предназначен для создания и развертывания полных систем агентов RL. Одним из примеров полнофункциональной среды агентов, основанной на компонентах RLax, является Акме.
Чекс
Тестирование имеет решающее значение для надежности программного обеспечения, и исследовательский код не является исключением. Чтобы делать научные выводы из исследовательских экспериментов, нужно быть уверенным в правильности своего кода. Chex — это набор утилит для тестирования, используемых авторами библиотек для проверки правильности и надежности общих строительных блоков, а конечными пользователями — для проверки своего экспериментального кода.
Chex предоставляет набор утилит, включая модульное тестирование с поддержкой JAX, утверждения свойств типов данных JAX, имитации и подделки, а также среды тестирования с несколькими устройствами. Chex используется во всей экосистеме DeepMind JAX и внешними проектами, такими как Коаксиальный и Минерл.
Джраф
Граф нейронных сетей (GNN) — захватывающая область исследований со многими многообещающими приложениями. См., например, нашу недавнюю работу по предсказание трафика в Картах Google и нашей работе над моделирование физики. Jraph (произносится как «жираф») — это облегченная библиотека для поддержки работы с GNN в JAX.
Jraph предоставляет стандартизированную структуру данных для графов, набор утилит для работы с графами и «зоопарк» легко разветвляемых и расширяемых графовых нейросетевых моделей. Другие ключевые функции включают в себя: группирование GraphTuples, которые эффективно используют аппаратные ускорители, поддержку JIT-компиляции графов переменной формы с помощью заполнения и маскирования, а также потери, определенные для входных разделов. Подобно Optax и другим нашим библиотекам, Jraph не накладывает ограничений на выбор пользователем библиотеки нейронной сети.
Узнайте больше об использовании библиотеки из нашей богатой коллекции Примеры.
Наша экосистема JAX постоянно развивается, и мы призываем исследовательское сообщество машинного обучения исследовать наши библиотеки и потенциал JAX для ускорения собственных исследований.