Перейти к основному содержимому

Логистическая регрессия в Java

· 7 мин. чтения

1. Введение

Логистическая регрессия является важным инструментом в наборе инструментов для практиков машинного обучения (ML).

В этом уроке мы рассмотрим основную идею логистической регрессии .

Во-первых, давайте начнем с краткого обзора парадигм и алгоритмов машинного обучения.

2. Обзор

Машинное обучение позволяет нам решать проблемы, которые мы можем сформулировать в понятной для человека терминологии. Однако этот факт может представлять собой проблему для нас, разработчиков программного обеспечения. Мы привыкли решать проблемы, которые можем сформулировать в компьютерных терминах. Например, как люди, мы можем легко определить объекты на фотографии или определить настроение фразы. Как мы могли сформулировать такую задачу для компьютера?

Для того, чтобы придумать решение, в ML есть специальный этап, называемый обучением . На этом этапе мы передаем входные данные нашему алгоритму, чтобы он попытался подобрать оптимальный набор параметров (так называемые веса). Чем больше входных данных мы можем передать алгоритму, тем более точных прогнозов мы можем ожидать от него.

Обучение является частью итеративного рабочего процесса машинного обучения:

./8f5c495555c24ac9770676279bcfb959.png

Начнем с получения данных. Часто данные поступают из разных источников. Поэтому мы должны сделать его одного формата. Мы также должны контролировать, чтобы набор данных справедливо представлял область исследования. Если модель никогда не обучалась на красных яблоках, она вряд ли сможет это предсказать.

Затем мы должны построить модель, которая будет использовать данные и сможет делать прогнозы. В ML нет заранее определенных моделей, которые хорошо работают во всех ситуациях.

При поиске правильной модели может легко случиться так, что мы создадим модель, обучим ее, посмотрим ее предсказания и отбросим модель, потому что нас не устраивают сделанные ею предсказания. В этом случае мы должны сделать шаг назад, построить другую модель и повторить процесс снова.

3. Парадигмы машинного обучения

В ML, исходя из того, какие исходные данные мы имеем в нашем распоряжении, мы можем выделить три основные парадигмы:

  • контролируемое обучение (классификация изображений, распознавание объектов, анализ настроений)
  • неконтролируемое обучение (обнаружение аномалий)
  • обучение с подкреплением (игровые стратегии)

Случай, который мы собираемся описать в этом руководстве, относится к обучению с учителем.

4. Набор инструментов машинного обучения

В ML есть набор инструментов, которые мы можем применять при построении модели. Упомянем некоторые из них:

  • Линейная регрессия
  • Логистическая регрессия
  • Нейронные сети
  • Машина опорных векторов
  • k-ближайшие соседи

Мы можем комбинировать несколько инструментов при построении модели с высокой прогнозируемостью. Фактически, для этого урока наша модель будет использовать логистическую регрессию и нейронные сети.

5. Библиотеки машинного обучения

Несмотря на то, что Java не является самым популярным языком для создания прототипов моделей машинного обучения, `` он имеет репутацию надежного инструмента для создания надежного программного обеспечения во многих областях, включая машинное обучение. Поэтому мы можем найти библиотеки ML, написанные на Java.

В этом контексте мы можем упомянуть де-факто стандартную библиотеку Tensorflow , которая также имеет версию для Java. Еще стоит упомянуть библиотеку глубокого обучения под названием Deeplearning4j . Это очень мощный инструмент, и мы также будем использовать его в этом уроке.

6. Логистическая регрессия при распознавании цифр

Основная идея логистической регрессии заключается в построении модели, которая максимально точно предсказывает метки входных данных.

Мы обучаем модель до тех пор, пока так называемая функция потерь или целевая функция не достигнет некоторого минимального значения. Функция потерь зависит от фактических прогнозов модели и ожидаемых (меток входных данных). Наша цель — свести к минимуму расхождение фактических прогнозов модели и ожидаемых.

Если нас не устраивает это минимальное значение, мы должны построить другую модель и снова выполнить обучение.

Чтобы увидеть логистическую регрессию в действии, мы проиллюстрируем ее на распознавании рукописных цифр. Эта проблема уже стала классической. В библиотеке Deeplearning4j есть ряд реалистичных примеров , показывающих, как использовать ее API. Часть этого руководства, связанная с кодом, в значительной степени основана на классификаторе MNIST .

6.1. Входные данные

В качестве входных данных используется известная база рукописных цифр MNIST . В качестве входных данных у нас есть изображения 28×28 пикселей в оттенках серого. Каждое изображение имеет естественную метку, которая представляет собой цифру, которую представляет изображение:

./f54af7531b1848ae9a77aa24ab384d5f.png

Чтобы оценить эффективность модели, которую мы собираемся построить, разобьем входные данные на обучающую и тестовую выборки:

DataSetIterator train = new RecordReaderDataSetIterator(...);
DataSetIterator test = new RecordReaderDataSetIterator(...);

После того, как мы разметили входные изображения и разделили их на два набора, этап «обработки данных» завершен, и мы можем перейти к «построению модели».

6.2. Построение модели

Как мы уже упоминали, нет моделей, которые хорошо работают в любой ситуации. Тем не менее, после многих лет исследований в области машинного обучения ученые нашли модели, которые очень хорошо распознают рукописные цифры. Здесь мы используем так называемую модель LeNet-5 .

LeNet-5 — это нейронная сеть, состоящая из ряда слоев, которые преобразуют изображение размером 28×28 пикселей в десятимерный вектор:

./c9715458d0f1de0a09447c9a9061713c.png

Десятимерный выходной вектор содержит вероятности того, что метка входного изображения равна 0, 1, 2 и так далее.

Например, если выходной вектор имеет следующий вид:

{0.1, 0.0, 0.3, 0.2, 0.1, 0.1, 0.0, 0.1, 0.1, 0.0}

это означает, что вероятность того, что входное изображение будет нулевым, равно 0,1, единице — 0, двум — 0,3 и т. д. Мы видим, что максимальная вероятность (0,3) соответствует метке 3.

Давайте углубимся в детали построения модели. Мы опускаем подробности, относящиеся к Java, и концентрируемся на концепциях ML.

Мы настраиваем модель, создавая объект MultiLayerNetwork :

MultiLayerNetwork model = new MultiLayerNetwork(config);

В его конструктор мы должны передать объект MultiLayerConfiguration . Это тот самый объект, который описывает геометрию нейронной сети. Чтобы определить геометрию сети, мы должны определить каждый слой.

Давайте покажем, как мы это делаем с первым и вторым:

ConvolutionLayer layer1 = new ConvolutionLayer
.Builder(5, 5).nIn(channels)
.stride(1, 1)
.nOut(20)
.activation(Activation.IDENTITY)
.build();
SubsamplingLayer layer2 = new SubsamplingLayer
.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2)
.stride(2, 2)
.build();

Мы видим, что определения слоев содержат значительное количество специальных параметров, которые существенно влияют на производительность сети в целом. Именно здесь наша способность найти хорошую модель среди всех становится решающей.

Теперь мы готовы создать объект MultiLayerConfiguration :

MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
// preparation steps
.list()
.layer(layer1)
.layer(layer2)
// other layers and final steps
.build();

который мы передаем конструктору MultiLayerNetwork .

6.3. Подготовка

Построенная нами модель содержит 431080 параметров или весов. Мы не собираемся приводить здесь точное вычисление этого числа, но мы должны знать, что только первый слой имеет более 24x24x20 = 11520 весов.

Стадия обучения проста:

model.fit(train);

Изначально параметры 431080 имеют некоторые случайные значения, но после обучения они приобретают некоторые значения, определяющие производительность модели. Мы можем оценить предсказательность модели:

Evaluation eval = model.evaluate(test);
logger.info(eval.stats());

Модель LeNet-5 достигает довольно высокой точности почти 99% даже всего за одну обучающую итерацию (эпоху). Если мы хотим добиться более высокой точности, мы должны сделать больше итераций, используя простой цикл for :

for (int i = 0; i < epochs; i++) {
model.fit(train);
train.reset();
test.reset();
}

6.4. Прогноз

Теперь, когда мы обучили модель и довольны ее предсказаниями на тестовых данных, мы можем попробовать модель на каких-то абсолютно новых входных данных. Для этого давайте создадим новый класс MnistPrediction, в который мы будем загружать изображение из файла, который мы выбираем из файловой системы:

INDArray image = new NativeImageLoader(height, width, channels).asMatrix(file);
new ImagePreProcessingScaler(0, 1).transform(image);

Переменное изображение содержит наше изображение, уменьшенное до 28×28 оттенков серого. Мы можем скормить его нашей модели:

INDArray output = model.output(image);

Выходная переменная будет содержать вероятности того, что изображение будет равно нулю, единице, двум и т. д.

Давайте теперь немного поиграем и напишем цифру 2, оцифруем это изображение и скормим его модели. Мы можем получить что-то вроде этого:

./1a4abc3be2cb7101c9b897132576cca1.png

Как видим, компонент с максимальным значением 0,99 имеет индекс два. Это означает, что модель правильно распознала нашу рукописную цифру.

7. Заключение

В этом уроке мы описали общие концепции машинного обучения. Мы проиллюстрировали эти концепции на примере логистической регрессии, который мы применили к распознаванию рукописных цифр.

Как всегда, мы можем найти соответствующие фрагменты кода в нашем репозитории GitHub .