Перейти к основному содержимому

Введение в Tensorflow для Java

· 9 мин. чтения

1. Обзор

TensorFlow — это библиотека с открытым исходным кодом для программирования потоков данных . Первоначально он был разработан Google и доступен для широкого спектра платформ. Хотя TensorFlow может работать на одном ядре, он также может легко получить выгоду от нескольких доступных процессоров, графических процессоров или TPU .

В этом руководстве мы рассмотрим основы TensorFlow и способы его использования в Java. Обратите внимание, что TensorFlow Java API — это экспериментальный API, поэтому на него не распространяется гарантия стабильности. Позже в этом руководстве мы рассмотрим возможные варианты использования TensorFlow Java API.

2. Основы

Вычисления TensorFlow в основном вращаются вокруг двух фундаментальных концепций: Graph и Session . Давайте быстро пройдемся по ним, чтобы получить фон, необходимый для прохождения остальной части урока.

2.1. График TensorFlow

Для начала давайте разберемся с основными строительными блоками программ TensorFlow. Вычисления представлены в виде графиков в TensorFlow . Граф обычно представляет собой ориентированный ациклический граф операций и данных, например:

[

./c5f2c08d3dac9c5a79b7003a2691da96.jpg

](/lessons/b/-wp-content-uploads-2019-03-TensorFlow-Graph-1-1-jpg)

На приведенном выше рисунке представлен расчетный график для следующего уравнения:

f(x, y) = z = a*x + b*y

Вычислительный граф TensorFlow состоит из двух элементов:

  1. Тензор: это основная единица данных в TensorFlow. Они представлены в виде ребер вычислительного графа, изображающего поток данных через граф. Тензор может иметь форму с любым количеством измерений. Количество измерений в тензоре обычно называют его рангом. Таким образом, скаляр — это тензор ранга 0, вектор — это тензор ранга 1, матрица — это тензор ранга 2 и так далее и тому подобное.
  2. Операция: это узлы вычислительного графа. Они относятся к широкому спектру вычислений, которые могут происходить с тензорами, входящими в операцию. Они также часто приводят к тензорам, которые возникают в результате операции в вычислительном графе.

2.2. Сессия TensorFlow

Теперь график TensorFlow — это просто схема вычислений, которая на самом деле не содержит значений. Такой граф должен быть запущен внутри так называемого сеанса TensorFlow для оценки тензоров в графе . Сеанс может принимать кучу тензоров для оценки из графика в качестве входных параметров. Затем он движется назад по графу и запускает все узлы, необходимые для оценки этих тензоров.

Обладая этими знаниями, мы теперь готовы применить их к Java API!

3. Настройка Мавена

Мы настроим быстрый проект Maven для создания и запуска графа TensorFlow в Java. Нам просто нужна зависимость тензорного потока :

<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.12.0</version>
</dependency>

4. Создание графика

Теперь давайте попробуем построить граф, который мы обсуждали в предыдущем разделе, используя Java API TensorFlow. Точнее, в этом руководстве мы будем использовать TensorFlow Java API для решения функции, представленной следующим уравнением:

z = 3*x + 2*y

Первый шаг — объявить и инициализировать граф:

Graph graph = new Graph()

Теперь нам нужно определить все необходимые операции. Помните, что операции в TensorFlow потребляют и производят ноль или более тензоров . Более того, каждый узел в графе — это операция, включающая константы и заполнители. Это может показаться нелогичным, но потерпите немного!

Класс Graph имеет общую функцию под названием opBuilder() для создания любых операций в TensorFlow.

4.1. Определение констант

Для начала давайте определим константные операции в нашем графе выше. Обратите внимание, что постоянной операции потребуется тензор для ее значения :

Operation a = graph.opBuilder("Const", "a")
.setAttr("dtype", DataType.fromClass(Double.class))
.setAttr("value", Tensor.<Double>create(3.0, Double.class))
.build();
Operation b = graph.opBuilder("Const", "b")
.setAttr("dtype", DataType.fromClass(Double.class))
.setAttr("value", Tensor.<Double>create(2.0, Double.class))
.build();

Здесь мы определили операцию постоянного типа, вводя тензор со значениями Double 2.0 и 3.0. Поначалу это может показаться немного ошеломляющим, но на данный момент именно так обстоит дело с Java API. Эти конструкции гораздо более лаконичны в таких языках, как Python.

4.2. Определение заполнителей

Хотя нам нужно предоставлять значения нашим константам, заполнителям не нужно значение во время определения . Значения заполнителей необходимо указывать, когда граф запускается внутри сеанса. Мы рассмотрим эту часть позже в уроке.

А пока давайте посмотрим, как мы можем определить наши заполнители:

Operation x = graph.opBuilder("Placeholder", "x")
.setAttr("dtype", DataType.fromClass(Double.class))
.build();
Operation y = graph.opBuilder("Placeholder", "y")
.setAttr("dtype", DataType.fromClass(Double.class))
.build();

Обратите внимание, что нам не нужно было указывать какое-либо значение для наших заполнителей. Эти значения будут переданы как тензоры при запуске.

4.3. Определение функций

Наконец, нам нужно определить математические операции нашего уравнения, а именно умножение и сложение, чтобы получить результат.

Это снова не что иное, как операции в TensorFlow, и Graph.opBuilder() снова удобен:

Operation ax = graph.opBuilder("Mul", "ax")
.addInput(a.output(0))
.addInput(x.output(0))
.build();
Operation by = graph.opBuilder("Mul", "by")
.addInput(b.output(0))
.addInput(y.output(0))
.build();
Operation z = graph.opBuilder("Add", "z")
.addInput(ax.output(0))
.addInput(by.output(0))
.build();

Здесь мы определили там Operation , две для умножения наших входов и последнюю для суммирования промежуточных результатов. Обратите внимание, что операции здесь получают тензоры, которые являются не чем иным, как результатом наших предыдущих операций.

Обратите внимание, что мы получаем выходной тензор из операции , используя индекс «0». Как мы обсуждали ранее, операция может привести к одному или нескольким тензорам , и, следовательно, при получении дескриптора для него нам нужно упомянуть индекс. Поскольку мы знаем, что наши операции возвращают только один Tensor , '0' работает отлично!

5. Визуализация графика

Трудно следить за графиком, поскольку он увеличивается в размерах. Это делает важным визуализировать его каким-то образом . Мы всегда можем создать рисунок от руки, подобный маленькому графику, который мы создали ранее, но это нецелесообразно для больших графиков. Для этого TensorFlow предоставляет утилиту под названием TensorBoard .

К сожалению, Java API не имеет возможности генерировать файл событий, который используется TensorBoard. Но с помощью API в Python мы можем сгенерировать файл событий, например:

writer = tf.summary.FileWriter('.')
......
writer.add_graph(tf.get_default_graph())
writer.flush()

Пожалуйста, не беспокойтесь, если это не имеет смысла в контексте Java, это было добавлено сюда только для полноты картины и не обязательно для продолжения остальной части руководства.

Теперь мы можем загрузить и визуализировать файл событий в TensorBoard, например:

tensorboard --logdir .

./d45216cc7c068c58fcc9e8265391bddf.png

TensorBoard входит в состав установки TensorFlow.

Обратите внимание на сходство между этим и ранее нарисованным вручную графиком!

6. Работа с сеансом

Теперь мы создали вычислительный граф для нашего простого уравнения в TensorFlow Java API. Но как мы его запускаем? Прежде чем обратиться к этому, давайте посмотрим, в каком состоянии находится Graph , который мы только что создали. Если мы попытаемся напечатать вывод нашей последней операции «z»:

System.out.println(z.output(0));

Это приведет к чему-то вроде:

<Add 'z:0' shape=<unknown> dtype=DOUBLE>

Это не то, что мы ожидали! Но если вспомнить то, что мы обсуждали ранее, в этом действительно есть смысл. Граф , который мы только что определили, еще не запускался, поэтому тензоры в нем на самом деле не имеют никакого фактического значения. Вывод выше просто говорит, что это будет Tensor типа Double .

Давайте теперь определим сеанс для запуска нашего графика :

Session sess = new Session(graph)

Наконец, теперь мы готовы запустить наш график и получить ожидаемый результат:

Tensor<Double> tensor = sess.runner().fetch("z")
.feed("x", Tensor.<Double>create(3.0, Double.class))
.feed("y", Tensor.<Double>create(6.0, Double.class))
.run().get(0).expect(Double.class);
System.out.println(tensor.doubleValue());

Итак, что мы здесь делаем? Это должно быть довольно интуитивно понятно:

  • Получить бегуна из сеанса
  • Определите операцию для извлечения по ее имени «z»
  • Вставьте тензоры для наших заполнителей «x» и «y».
  • Запустите график в сеансе

И теперь мы видим скалярный вывод:

21.0

Это то, что мы ожидали, не так ли!

7. Пример использования Java API

На данный момент TensorFlow может показаться излишним для выполнения основных операций. Но, конечно же, TensorFlow предназначен для работы с графиками гораздо большего размера .

Кроме того, тензоры, с которыми он имеет дело в реальных моделях, намного больше по размеру и рангу . Это настоящие модели машинного обучения, в которых TensorFlow находит свое реальное применение.

Нетрудно заметить, что работа с основным API в TensorFlow может стать очень громоздкой по мере увеличения размера графа. С этой целью TensorFlow предоставляет высокоуровневые API, такие как Keras , для работы со сложными моделями . К сожалению, официальной поддержки Keras на Java пока практически нет.

Однако мы можем использовать Python для определения и обучения сложных моделей либо непосредственно в TensorFlow, либо с помощью высокоуровневых API, таких как Keras. Впоследствии мы можем экспортировать обученную модель и использовать ее в Java с помощью TensorFlow Java API.

Теперь, почему мы хотим сделать что-то подобное? Это особенно полезно в ситуациях, когда мы хотим использовать функции машинного обучения в существующих клиентах, работающих на Java. Например, рекомендовать подписи к пользовательским изображениям на устройстве Android. Тем не менее, есть несколько случаев, когда мы заинтересованы в выводе модели машинного обучения, но не обязательно хотим создавать и обучать эту модель на Java.

Именно здесь TensorFlow Java API находит большую часть своего использования. Мы рассмотрим, как этого можно достичь в следующем разделе.

8. Использование сохраненных моделей

Теперь мы поймем, как мы можем сохранить модель в TensorFlow в файловой системе и загрузить ее обратно, возможно, на совершенно другом языке и платформе. TensorFlow предоставляет API для создания файлов моделей в независимой от языка и платформы структуре, которая называется Protocol Buffer .

8.1. Сохранение моделей в файловой системе

Мы начнем с определения того же графа, который мы создали ранее в Python, и сохранения его в файловой системе.

Давайте посмотрим, что мы можем сделать в Python:

import tensorflow as tf
graph = tf.Graph()
builder = tf.saved_model.builder.SavedModelBuilder('./model')
with graph.as_default():
a = tf.constant(2, name='a')
b = tf.constant(3, name='b')
x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')
z = tf.math.add(a*x, b*y, name='z')
sess = tf.Session()
sess.run(z, feed_dict = {x: 2, y: 3})
builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING])
builder.save()

Так как это руководство посвящено Java, давайте не будем обращать особого внимания на детали этого кода на Python, за исключением того факта, что он создает файл с именем «saved_model.pb». Обратите внимание на краткость определения аналогичного графа по сравнению с Java!

8.2. Загрузка моделей из файловой системы

Теперь мы загрузим «saved_model.pb» в Java. Java TensorFlow API имеет SavedModelBundle для работы с сохраненными моделями:

SavedModelBundle model = SavedModelBundle.load("./model", "serve"); 
Tensor<Integer> tensor = model.session().runner().fetch("z")
.feed("x", Tensor.<Integer>create(3, Integer.class))
.feed("y", Tensor.<Integer>create(3, Integer.class))
.run().get(0).expect(Integer.class);
System.out.println(tensor.intValue());

Теперь должно быть довольно интуитивно понятно, что делает приведенный выше код. Он просто загружает граф модели из буфера протокола и делает доступной сессию в нем. С этого момента мы можем делать с этим графом все, что угодно, как если бы мы делали это с локально определенным графом.

9. Заключение

Подводя итог, в этом уроке мы рассмотрели основные понятия, связанные с вычислительным графом TensorFlow. Мы увидели, как использовать TensorFlow Java API для создания и запуска такого графа. Затем мы обсудили варианты использования Java API в отношении TensorFlow.

В процессе мы также поняли, как визуализировать график с помощью TensorBoard, а также сохранять и перезагружать модель с помощью Protocol Buffer.

Как всегда, код примеров доступен на GitHub .