Home Машинное обучение Детализация алгоритма рекомендаций ALS Spark | DeepTech

Детализация алгоритма рекомендаций ALS Spark | DeepTech

0
Детализация алгоритма рекомендаций ALS Spark
 | DeepTech
  • 25 февраля 2017 г.
  • Василис Вриниотис
  • . 3 комментария

Алгоритм ALS, представленный Ху и др., — очень популярный метод, используемый в задачах системы рекомендаций, особенно когда у нас есть неявные наборы данных (например, клики, лайки и т. д.). Он может достаточно хорошо обрабатывать большие объемы данных, и мы можем найти много хороших реализаций в различных средах машинного обучения. Spark включает алгоритм в компонент MLlib, который недавно был рефакторинг для улучшения читабельности и архитектуры кода.

Реализация Spark требует, чтобы Item и User id были числами в пределах целочисленного диапазона (целочисленного типа или Long в пределах целочисленного диапазона), что разумно, поскольку это может помочь ускорить операции и снизить потребление памяти. Одна вещь, которую я заметил при чтении кода, заключается в том, что эти столбцы id преобразуются в Doubles, а затем в Integer в начале методов fit/predict. Это кажется немного хакерским, и я видел, что это создает ненужную нагрузку на сборщик мусора. Вот строки на код АЛС которые превращают идентификаторы в двойники:

Чтобы понять, почему это сделано, нужно прочитать checkedCast():

Эта UDF получает Double и проверяет его диапазон, а затем приводит его к целому числу. Этот UDF используется для проверки схемы. Вопрос в том, можем ли мы добиться этого, не используя уродливые двойные приведения? я верю да:


  protected val checkedCast = udf { (n: Any) =>
    n match {
      case v: Int => v // Avoid unnecessary casting
      case v: Number =>
        val intV = v.intValue()
        // True for Byte/Short, Long within the Int range and Double/Float with no fractional part.
        if (v.doubleValue == intV) {
          intV
        }
        else {
          throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
            s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
        }
      case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
        s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n is not numeric.")
    }
  }

В приведенном выше коде показан модифицированный метод checkedCast(), который получает входные данные, проверяет, является ли значение числовым, и вызывает исключения в противном случае. Поскольку вводом является Any, мы можем безопасно удалить все операторы приведения к Double из остального кода. Более того, разумно ожидать, что, поскольку ALS требует идентификаторов в диапазоне целых чисел, большинство людей на самом деле используют целые типы. В результате в строке 3 этот метод явно обрабатывает целые числа, чтобы избежать приведения типов. Для всех других числовых значений он проверяет, находится ли ввод в диапазоне целых чисел. Эта проверка происходит в строке 7.

Можно написать это по-другому и явно обрабатывать все разрешенные типы. К сожалению, это приведет к дублированию кода. Вместо этого я здесь преобразовываю число в целое число и сравниваю его с исходным числом. Если значения идентичны, верно одно из следующего:

  1. Значение — Byte или Short.
  2. Значение длинное, но находится в диапазоне целых чисел.
  3. Значение — Double или Float, но без дробной части.

Чтобы убедиться, что код работает хорошо, я протестировал его с помощью стандартных модульных тестов Spark и вручную, проверив поведение метода для различных допустимых и недопустимых значений. Чтобы убедиться, что решение работает как минимум так же быстро, как и оригинал, я много раз тестировал его, используя приведенный ниже фрагмент кода. Это можно разместить в АЛСюит-класс в Спарке:



  test("Speed difference") {
    val (training, test) =
      genExplicitTestData(numUsers = 200, numItems = 400, rank = 2, noiseStd = 0.01)

    val runs = 100
    var totalTime = 0.0
    println("Performing "+runs+" runs")
    for(i <- 0 until runs) {
      val t0 = System.currentTimeMillis
      testALS(training, test, maxIter = 1, rank = 2, regParam = 0.01, targetRMSE = 0.1)
      val secs = (System.currentTimeMillis - t0)/1000.0
      println("Run "+i+" executed in "+secs+"s")
      totalTime += secs
    }
    println("AVG Execution Time: "+(totalTime/runs)+"s")

  }

После нескольких тестов мы видим, что новое исправление немного быстрее оригинала:

Код

Количество запусков

Общее время выполнения

Среднее время выполнения за прогон

Оригинал 100 588,458 с 5,88458 с
Зафиксированный 100 566,722 с 5,66722 с

Я повторил эксперименты несколько раз, чтобы подтвердить, и результаты согласуются. Здесь вы можете найти подробные результаты одного эксперимента для исходный код и исправить. Разница невелика для крошечного набора данных, но в прошлом мне удавалось добиться заметного сокращения накладных расходов на сборщик мусора с помощью этого исправления. Мы можем подтвердить это, запустив Spark локально и подключив профилировщик Java к экземпляру Spark. я открыл билет и Пулл-реквест в официальном репозитории Spark но поскольку неясно, будет ли он объединен, я решил поделиться им здесь с вами. и теперь он является частью Spark 2.2.

Любые мысли, комментарии или критика приветствуются! 🙂

LEAVE A REPLY

Please enter your comment!
Please enter your name here