Home Машинное обучение Точная настройка моделей Whisper на Amazon SageMaker с помощью LoRA | DeepTech

Точная настройка моделей Whisper на Amazon SageMaker с помощью LoRA | DeepTech

0
Точная настройка моделей Whisper на Amazon SageMaker с помощью LoRA
 | DeepTech

Whisper — это модель автоматического распознавания речи (ASR), которая была обучена с использованием 680 000 часов контролируемых данных из Интернета, охватывающих широкий спектр языков и задач. Одним из его ограничений является низкая производительность на языках с ограниченными ресурсами, таких как язык маратхи и дравидийские языки, которую можно исправить с помощью тонкой настройки. Однако точная настройка модели Whisper стала серьезной проблемой, как с точки зрения вычислительных ресурсов, так и требований к памяти. Пять-десять прогонов полной тонкой настройки моделей Whisper требуют примерно 100 часов работы графического процессора A100 (40 ГБ SXM4) (зависит от размеров и параметров модели), а для каждой точной настройки требуется около 7 ГБ дискового пространства. Такое сочетание высоких требований к вычислениям и хранению данных может создавать серьезные препятствия, особенно в средах с ограниченными ресурсами, что часто делает достижение значимых результатов исключительно трудным.

Адаптация низкого ранга, также известная как ЛоРА, использует уникальный подход к точной настройке модели. Он поддерживает предварительно обученные веса модели в статическом состоянии и вводит обучаемые матрицы ранговой декомпозиции на каждый уровень структуры Transformer. Этот метод может уменьшить количество обучаемых параметров, необходимых для последующих задач, в 10 000 раз и снизить требования к памяти графического процессора в 3 раза. Было показано, что с точки зрения качества модели LoRA соответствует или даже превосходит производительность традиционных методов точной настройки, несмотря на то, что работает с меньшим количеством обучаемых параметров (см. результаты оригинального метода). бумага лора). Это также дает преимущество увеличения производительности обучения. в отличие от адаптер Методы LoRA не вводят дополнительную задержку во время вывода, тем самым сохраняя эффективность модели на этапе развертывания. Точная настройка Whisper с использованием LoRA показала многообещающие результаты. Возьмем, к примеру, Whisper-Large-v2: обработка 3 эпох с 12-часовым общим голосовым набором данных на графическом процессоре с памятью 8 ГБ занимает 6–8 часов.что в 5 раз быстрее, чем полная тонкая настройка при сопоставимой производительности.

Amazon SageMaker — идеальная платформа для реализации тонкой настройки Whisper в соответствии с LoRA. Amazon SageMaker позволяет создавать, обучать и развертывать модели машинного обучения для любого сценария использования с полностью управляемой инфраструктурой, инструментами и рабочими процессами. Дополнительные преимущества обучения моделей могут включать снижение затрат на обучение с помощью управляемого спотового обучения, распределенных библиотек обучения для разделения моделей и наборов обучающих данных по экземплярам графического процессора AWS и многое другое. Обученные модели SageMaker можно легко развернуть для вывода непосредственно в SageMaker. В этом посте мы представляем пошаговое руководство по реализации тонкой настройки LoRA в SageMaker. Исходный код, связанный с этой реализацией, можно найти на GitHub.

Подготовьте набор данных для тонкой настройки

Для задачи точной настройки мы используем малоресурсный язык маратхи. Используя Наборы данных «Обнимающее лицо» В библиотеке вы можете загрузить и разделить набор данных Common Voice на наборы данных для обучения и тестирования. См. следующий код:

from datasets import load_dataset, DatasetDict

language = "Marathi"
language_abbr = "mr"
task = "transcribe"
dataset_name = "mozilla-foundation/common_voice_11_0"

common_voice = DatasetDict()
common_voice("train") = load_dataset(dataset_name, language_abbr, split="train+validation", use_auth_token=True)
common_voice("test") = load_dataset(dataset_name, language_abbr, split="test", use_auth_token=True)

Модель распознавания речи Whisper требует, чтобы аудиовходы были 16-битные целочисленные WAV-файлы со знаком, 16 кГц, моно. Поскольку набор данных Common Voice имеет частоту дискретизации 48 КБ, вам необходимо сначала понизить частоту дискретизации аудиофайлов. Затем вам нужно применить экстрактор функций Whisper к аудио, чтобы извлечь функции спектрограммы log-mel, и применить токенизатор Whisper к функциям в рамке, чтобы преобразовать каждое предложение в расшифровке в идентификатор токена. См. следующий код:

from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer

feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name_or_path)
tokenizer = WhisperTokenizer.from_pretrained(model_name_or_path, language=language, task=task)

def prepare_dataset(batch):
# load and resample audio data from 48 to 16kHz
audio = batch("audio")

# compute log-Mel input features from input audio array
batch("input_features") = feature_extractor(audio("array"), sampling_rate=audio("sampling_rate")).input_features(0)

# encode target text to label ids
batch("labels") = tokenizer(batch("sentence")).input_ids
return batch

#apply the data preparation function to all of our fine-tuning dataset samples using dataset's .map method.
common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names("train"), num_proc=2)
common_voice.save_to_disk("marathi-common-voice-processed")
!aws s3 cp --recursive "marathi-common-voice-processed" s3://<Your-S3-Bucket>

После того как вы обработали все обучающие выборки, загрузите обработанные данные в Amazon S3, чтобы при использовании обработанных обучающих данных на этапе тонкой настройки можно было использовать FastFile для монтирования файла S3 напрямую, а не копировать его на локальный диск:

from sagemaker.inputs import TrainingInput
training_input_path=s3uri
training = TrainingInput(
s3_data_type="S3Prefix", # Available Options: S3Prefix | ManifestFile | AugmentedManifestFile
s3_data=training_input_path,
distribution='FullyReplicated', # Available Options: FullyReplicated | ShardedByS3Key
input_mode="FastFile"
)

Обучение модели

Для демонстрации мы используем шепот-large-v2 в качестве предварительно обученной модели (теперь доступен шепот v3), которую можно импортировать через библиотеку трансформеров Hugging Face. Вы можете использовать 8-битное квантование для дальнейшего повышения эффективности обучения. 8-битное квантование обеспечивает оптимизацию памяти путем округления от чисел с плавающей запятой до 8-битных целых чисел. Это широко используемый метод сжатия модели, позволяющий сэкономить за счет сокращения памяти, не слишком жертвуя при этом точностью вывода.

Чтобы загрузить предварительно обученную модель в 8-битном квантованном формате, мы просто добавляем аргумент load_in_8bit=True при создании экземпляра модели, как показано в следующем коде. Это загрузит веса модели, квантованные до 8 бит, что уменьшит объем памяти.

from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, load_in_8bit=True, device_map="auto")

Мы используем реализацию LoRA от Hugging Face. пефт упаковка. Для точной настройки модели с использованием LoRA необходимо выполнить четыре шага:

  1. Создайте экземпляр базовой модели (как мы это делали на последнем шаге).
  2. Создайте конфигурацию (LoraConfig), где определены параметры, специфичные для LoRA.
  3. Оберните базовую модель get_peft_model() чтобы получить обучаемый PeftModel.
  4. Тренируйте PeftModel в качестве базовой модели.

См. следующий код:

from peft import LoraConfig, get_peft_model

config = LoraConfig(r=32, lora_alpha=64, target_modules=("q_proj", "v_proj"), lora_dropout=0.05, bias="none")
model = get_peft_model(model, config)

training_args = Seq2SeqTrainingArguments(
output_dir=args.model_dir,
per_device_train_batch_size=int(args.train_batch_size),
gradient_accumulation_steps=1,
learning_rate=float(args.learning_rate),
warmup_steps=args.warmup_steps,
num_train_epochs=args.num_train_epochs,
evaluation_strategy="epoch",
fp16=True,
per_device_eval_batch_size=args.eval_batch_size,
generation_max_length=128,
logging_steps=25,
remove_unused_columns=False,
label_names=("labels"),
)
trainer = Seq2SeqTrainer(
args=training_args,
model=model,
train_dataset=train_dataset("train"),
eval_dataset=train_dataset.get("test", train_dataset("test")),
data_collator=data_collator,
tokenizer=processor.feature_extractor,
)

Для запуска учебного задания SageMaker мы используем собственный контейнер Docker. Вы можете скачать образ Docker по адресу GitHub, где ffmpeg4 и git-lfs упакованы вместе с другими требованиями Python. Дополнительные сведения о том, как адаптировать собственный контейнер Docker для работы с SageMaker, см. в разделе Адаптация собственного обучающего контейнера. Затем вы можете использовать Hugging Face Estimator и начать обучающее задание SageMaker:

OUTPUT_PATH= f's3://{BUCKET}/{PREFIX}/{TRAINING_JOB_NAME}/output/'

huggingface_estimator = HuggingFace(entry_point="train.sh",
source_dir="./src",
output_path= OUTPUT_PATH,
instance_type=instance_type,
instance_count=1,
# transformers_version='4.17.0',
# pytorch_version='1.10.2',
py_version='py310',
image_uri=<ECR-PATH>,
role=ROLE,
metric_definitions = metric_definitions,
volume_size=200,
distribution=distribution,
keep_alive_period_in_seconds=1800,
environment=environment,
)

huggingface_estimator.fit(job_name=TRAINING_JOB_NAME, wait=False)

Реализация LoRA позволила нам запустить задачу тонкой настройки Whisper big на одном экземпляре графического процессора (например, ml.g5.2xlarge). Для сравнения, задача полной тонкой настройки Whisper big требует нескольких графических процессоров (например, ml.p4d.24xlarge) и гораздо большего времени обучения. В частности, наш эксперимент продемонстрировал, что полная задача тонкой настройки требует в 24 раза больше часов работы графического процессора по сравнению с подходом LoRA.

Оцените производительность модели

Чтобы оценить производительность точно настроенной модели Whisper, мы рассчитываем коэффициент ошибок в словах (WER) на проверенном наборе тестов. WER измеряет разницу между предсказанной расшифровкой и истинной расшифровкой. Более низкий WER указывает на лучшую производительность. Вы можете запустить следующий скрипт для предварительно обученной модели и точно настроенной модели и сравнить их разницу в WER:

metric = evaluate.load("wer")

eval_dataloader = DataLoader(common_voice("test"), batch_size=8, collate_fn=data_collator)

model.eval()
for step, batch in enumerate(tqdm(eval_dataloader)):
with torch.cuda.amp.autocast():
with torch.no_grad():
generated_tokens = (
model.generate(
input_features=batch("input_features").to("cuda"),
decoder_input_ids=batch("labels")(:, :4).to("cuda"),
max_new_tokens=255,
)
.cpu()
.numpy()
)
labels = batch("labels").cpu().numpy()
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
metric.add_batch(
predictions=decoded_preds,
references=decoded_labels,
)
del generated_tokens, labels, batch
gc.collect()
wer = 100 * metric.compute()
print(f"{wer=}")

Заключение

В этом посте мы продемонстрировали тонкую настройку Whisper, современной модели распознавания речи. В частности, мы использовали PEFT LoRA от Hugging Face и включили 8-битное квантование для эффективного обучения. Мы также продемонстрировали, как запустить задание обучения в SageMaker.

Хотя это важный первый шаг, на основе этой работы можно использовать несколько способов дальнейшего улучшения модели шепота. В дальнейшем рассмотрите возможность использования распределенного обучения SageMaker для масштабирования обучения на гораздо большем наборе данных. Это позволит модели обучаться на более разнообразных и полных данных, повышая точность. Вы также можете оптимизировать задержку при использовании модели Whisper, чтобы обеспечить распознавание речи в реальном времени. Кроме того, вы можете расширить работу для обработки более длинных аудиотранскрипций, что потребует изменений в архитектуре модели и схемах обучения.

Подтверждение

Авторы выражают благодарность Парасу Мехре, Джону Солу и Эвандро Франко за их содержательные отзывы и рецензию на публикацию.


Об авторах

Цзюнь Ши — старший архитектор решений в Amazon Web Services (AWS). В настоящее время его сферой деятельности являются инфраструктура и приложения AI/ML. Он имеет более чем десятилетний опыт работы в сфере финансовых технологий в качестве инженера-программиста.

Доктор Чанша Ма — специалист по AI/ML в AWS. Она технолог с докторской степенью в области компьютерных наук, степенью магистра в области педагогической психологии и многолетним опытом работы в области науки о данных и независимого консультирования в области искусственного интеллекта и машинного обучения. Она увлечена исследованием методологических подходов к машинному и человеческому интеллекту. Вне работы она любит пешие походы, готовку, охоту за едой и проводить время с друзьями и семьей.

LEAVE A REPLY

Please enter your comment!
Please enter your name here