- 25 февраля 2017 г.
- Василис Вриниотис
- . 3 комментария
Алгоритм ALS, представленный Ху и др., — очень популярный метод, используемый в задачах системы рекомендаций, особенно когда у нас есть неявные наборы данных (например, клики, лайки и т. д.). Он может достаточно хорошо обрабатывать большие объемы данных, и мы можем найти много хороших реализаций в различных средах машинного обучения. Spark включает алгоритм в компонент MLlib, который недавно был рефакторинг для улучшения читабельности и архитектуры кода.
Реализация Spark требует, чтобы Item и User id были числами в пределах целочисленного диапазона (целочисленного типа или Long в пределах целочисленного диапазона), что разумно, поскольку это может помочь ускорить операции и снизить потребление памяти. Одна вещь, которую я заметил при чтении кода, заключается в том, что эти столбцы id преобразуются в Doubles, а затем в Integer в начале методов fit/predict. Это кажется немного хакерским, и я видел, что это создает ненужную нагрузку на сборщик мусора. Вот строки на код АЛС которые превращают идентификаторы в двойники:
Чтобы понять, почему это сделано, нужно прочитать checkedCast():
Эта UDF получает Double и проверяет его диапазон, а затем приводит его к целому числу. Этот UDF используется для проверки схемы. Вопрос в том, можем ли мы добиться этого, не используя уродливые двойные приведения? я верю да:
protected val checkedCast = udf { (n: Any) => n match { case v: Int => v // Avoid unnecessary casting case v: Number => val intV = v.intValue() // True for Byte/Short, Long within the Int range and Double/Float with no fractional part. if (v.doubleValue == intV) { intV } else { throw new IllegalArgumentException(s"ALS only supports values in Integer range " + s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.") } case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " + s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n is not numeric.") } }
В приведенном выше коде показан модифицированный метод checkedCast(), который получает входные данные, проверяет, является ли значение числовым, и вызывает исключения в противном случае. Поскольку вводом является Any, мы можем безопасно удалить все операторы приведения к Double из остального кода. Более того, разумно ожидать, что, поскольку ALS требует идентификаторов в диапазоне целых чисел, большинство людей на самом деле используют целые типы. В результате в строке 3 этот метод явно обрабатывает целые числа, чтобы избежать приведения типов. Для всех других числовых значений он проверяет, находится ли ввод в диапазоне целых чисел. Эта проверка происходит в строке 7.
Можно написать это по-другому и явно обрабатывать все разрешенные типы. К сожалению, это приведет к дублированию кода. Вместо этого я здесь преобразовываю число в целое число и сравниваю его с исходным числом. Если значения идентичны, верно одно из следующего:
- Значение — Byte или Short.
- Значение длинное, но находится в диапазоне целых чисел.
- Значение — Double или Float, но без дробной части.
Чтобы убедиться, что код работает хорошо, я протестировал его с помощью стандартных модульных тестов Spark и вручную, проверив поведение метода для различных допустимых и недопустимых значений. Чтобы убедиться, что решение работает как минимум так же быстро, как и оригинал, я много раз тестировал его, используя приведенный ниже фрагмент кода. Это можно разместить в АЛСюит-класс в Спарке:
test("Speed difference") { val (training, test) = genExplicitTestData(numUsers = 200, numItems = 400, rank = 2, noiseStd = 0.01) val runs = 100 var totalTime = 0.0 println("Performing "+runs+" runs") for(i <- 0 until runs) { val t0 = System.currentTimeMillis testALS(training, test, maxIter = 1, rank = 2, regParam = 0.01, targetRMSE = 0.1) val secs = (System.currentTimeMillis - t0)/1000.0 println("Run "+i+" executed in "+secs+"s") totalTime += secs } println("AVG Execution Time: "+(totalTime/runs)+"s") }
После нескольких тестов мы видим, что новое исправление немного быстрее оригинала:
Код |
Количество запусков |
Общее время выполнения |
Среднее время выполнения за прогон |
Оригинал | 100 | 588,458 с | 5,88458 с |
Зафиксированный | 100 | 566,722 с | 5,66722 с |
Я повторил эксперименты несколько раз, чтобы подтвердить, и результаты согласуются. Здесь вы можете найти подробные результаты одного эксперимента для исходный код и исправить. Разница невелика для крошечного набора данных, но в прошлом мне удавалось добиться заметного сокращения накладных расходов на сборщик мусора с помощью этого исправления. Мы можем подтвердить это, запустив Spark локально и подключив профилировщик Java к экземпляру Spark. я открыл билет и Пулл-реквест в официальном репозитории Spark но поскольку неясно, будет ли он объединен, я решил поделиться им здесь с вами. и теперь он является частью Spark 2.2.
Любые мысли, комментарии или критика приветствуются! 🙂