FIGS (быстро интерпретируемые суммы жадного дерева): Метод построения интерпретируемых моделей путем одновременного выращивания множества деревьев решений, конкурирующих друг с другом.
Недавние достижения в области машинного обучения привели к созданию все более сложных прогностических моделей, часто за счет интерпретируемости. Нам часто нужна интерпретируемость, особенно в приложениях с высокими ставками, таких как принятие клинических решений; интерпретируемые модели помогают во многих вещах, таких как выявление ошибок, использование знаний предметной области и быстрое прогнозирование.
В этом сообщении блога мы рассмотрим ФИГновый метод установки интерпретируемая модель который принимает форму суммы деревьев. Реальные эксперименты и теоретические результаты показывают, что FIGS может эффективно адаптироваться к широкому спектру структур данных, достигая передовой производительности в различных условиях без ущерба для интерпретируемости.
Как работает ФИС?
Интуитивно FIGS работает, расширяя CART, типичный жадный алгоритм для выращивания дерева решений, чтобы учитывать рост сумма деревьев одновременно (см. рис. 1). На каждой итерации FIGS может вырастить любое существующее дерево, которое уже запущено, или начать новое дерево; он жадно выбирает то правило, которое больше всего уменьшает общую необъяснимую дисперсию (или альтернативный критерий разделения). Чтобы деревья были синхронизированы друг с другом, каждое дерево предсказывает остатки оставшиеся после суммирования предсказаний всех остальных деревьев (см. бумага Больше подробностей).
FIGS интуитивно похож на ансамблевые подходы, такие как повышение градиента/случайный лес, но, что важно, поскольку все деревья выращиваются, чтобы конкурировать друг с другом, модель может больше адаптироваться к базовой структуре данных. Количество деревьев и размер/форма каждого дерева появляются автоматически из данных, а не задаются вручную.
Рисунок 1. Высокоуровневая интуиция о том, как FIGS соответствует модели.
Пример использования FIGS
Использование FIGS чрезвычайно просто. Он легко устанавливается через пакет imodels (pip install imodels
), а затем их можно использовать так же, как стандартные модели scikit-learn: просто импортируйте классификатор или регрессор и используйте fit
и predict
методы. Вот полный пример его использования в образце набора клинических данных, целью которого является риск травмы шейного отдела позвоночника (CSI).
from imodels import FIGSClassifier, get_clean_dataset
from sklearn.model_selection import train_test_split
# prepare data (in this a sample clinical dataset)
X, y, feat_names = get_clean_dataset('csi_pecarn_pred')
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.33, random_state=42)
# fit the model
model = FIGSClassifier(max_rules=4) # initialize a model
model.fit(X_train, y_train) # fit model
preds = model.predict(X_test) # discrete predictions: shape is (n_test, 1)
preds_proba = model.predict_proba(X_test) # predicted probabilities: shape is (n_test, n_classes)
# visualize the model
model.plot(feature_names=feat_names, filename='out.svg', dpi=300)
В результате получается простая модель – в ней всего 4 разбиения (поскольку мы указали, что в модели должно быть не более 4 разбиений (max_rules=4
). Прогнозы делаются путем сбрасывания выборки на каждое дерево и подведение итогов значения корректировки риска, полученные из результирующих листьев каждого дерева. Эта модель чрезвычайно интерпретируема, так как теперь врач может (i) легко делать прогнозы, используя 4 соответствующих функции, и (ii) проверять модель, чтобы убедиться, что она соответствует его опыту в предметной области. Обратите внимание, что эта модель предназначена только для иллюстрации и обеспечивает точность ~84\%.
Рис 2. Простая модель, изученная FIGS, для прогнозирования риска травмы шейного отдела позвоночника.
Если нам нужна более гибкая модель, мы также можем снять ограничение на количество правил (изменив код на model = FIGSClassifier()
), что привело к увеличению модели (см. рис. 3). Обратите внимание, что количество деревьев и то, насколько они сбалансированы, зависят от структуры данных — можно указать только общее количество правил.
Рис 3. Немного более крупная модель, изученная FIGS для прогнозирования риска травмы шейного отдела позвоночника.
Насколько хорошо работает FIGS?
Во многих случаях, когда требуется интерпретируемость, например, моделирование правил клинического принятия решений, FIGS может достичь самой современной производительности. Например, на рис. 4 показаны различные наборы данных, в которых FIGS достигает отличной производительности, особенно при ограничении использованием очень небольшого количества разбиений.
Рис 4. FIGS дает хорошие прогнозы с очень небольшим количеством расщеплений.
Почему FIGS работает хорошо?
FIGS мотивирован наблюдением, что отдельные деревья решений часто имеют разбиения, которые повторяются в разных ветвях, что может произойти, когда существует аддитивная структура в данных. Наличие нескольких деревьев помогает избежать этого, разделяя аддитивные компоненты на отдельные деревья.
Заключение
В целом, интерпретируемое моделирование предлагает альтернативу обычному моделированию черного ящика и во многих случаях может предложить значительные улучшения с точки зрения эффективности и прозрачности без потери производительности.
Этот пост основан на двух документах: ФИГ и G-фиг. – весь код доступен через пакет imodels. Это совместная работа с Киян Нассери, Абхинит Агарвал, Джеймс Дункан, Омер Ронени Аарон Корнблит.