Перейти к основному содержимому

Умножение матриц в Java

· 11 мин. чтения

1. Обзор

В этом уроке мы рассмотрим, как мы можем перемножить две матрицы в Java.

Поскольку в языке изначально не существует понятия матрицы, мы реализуем его сами, а также поработаем с несколькими библиотеками, чтобы посмотреть, как они обрабатывают умножение матриц.

В конце мы проведем небольшой сравнительный анализ различных решений, которые мы исследовали, чтобы определить самое быстрое из них.

2. Пример

Давайте начнем с создания примера, на который мы сможем ссылаться в этом руководстве.

Сначала представим матрицу 3×2:

./4af726a0a3cca04e3fe44ee375e20e8f.png

Давайте теперь представим вторую матрицу, на этот раз две строки по четыре столбца:

./cd8bdad52797894330fdff9b8620c230.png

Затем умножение первой матрицы на вторую матрицу, в результате чего получается матрица 3×4:

./f088d9b2ee1bafb313c2f39282665c00.png

Напоминаем, что этот результат получается путем вычисления каждой ячейки результирующей матрицы по следующей формуле :

./09f5148a54cd016f10c4d97fb2704692.png

Где r — количество строк матрицы A , c — количество столбцов матрицы B , n — количество столбцов матрицы A , которое должно совпадать с количеством строк матрицы B.

3. Умножение матриц

3.1. Собственная реализация

Начнем с собственной реализации матриц.

Мы не будем усложнять и будем использовать двумерные двойные массивы :

double[][] firstMatrix = {
new double[]{1d, 5d},
new double[]{2d, 3d},
new double[]{1d, 7d}
};

double[][] secondMatrix = {
new double[]{1d, 2d, 3d, 7d},
new double[]{5d, 2d, 8d, 1d}
};

Это две матрицы нашего примера. Создадим ожидаемый в результате их умножения:

double[][] expected = {
new double[]{26d, 12d, 43d, 12d},
new double[]{17d, 10d, 30d, 17d},
new double[]{36d, 16d, 59d, 14d}
};

Теперь, когда все настроено, давайте реализуем алгоритм умножения. Сначала мы создадим пустой массив результатов и пройдемся по его ячейкам, чтобы сохранить ожидаемое значение в каждой из них:

double[][] multiplyMatrices(double[][] firstMatrix, double[][] secondMatrix) {
double[][] result = new double[firstMatrix.length][secondMatrix[0].length];

for (int row = 0; row < result.length; row++) {
for (int col = 0; col < result[row].length; col++) {
result[row][col] = multiplyMatricesCell(firstMatrix, secondMatrix, row, col);
}
}

return result;
}

Наконец, давайте реализуем вычисление одной ячейки. Чтобы добиться этого, мы будем использовать формулу, показанную ранее в презентации примера :

double multiplyMatricesCell(double[][] firstMatrix, double[][] secondMatrix, int row, int col) {
double cell = 0;
for (int i = 0; i < secondMatrix.length; i++) {
cell += firstMatrix[row][i] * secondMatrix[i][col];
}
return cell;
}

Наконец, давайте проверим, что результат алгоритма соответствует нашему ожидаемому результату:

double[][] actual = multiplyMatrices(firstMatrix, secondMatrix);
assertThat(actual).isEqualTo(expected);

3.2. EJML

Первая библиотека, которую мы рассмотрим, — это EJML, что означает Efficient Java Matrix Library . На момент написания этого руководства это была одна из самых последних обновленных библиотек матриц Java . Его цель - быть максимально эффективным в отношении вычислений и использования памяти.

Нам нужно будет добавить зависимость к библиотеке в нашем pom.xml :

<dependency>
<groupId>org.ejml</groupId>
<artifactId>ejml-all</artifactId>
<version>0.38</version>
</dependency>

Мы будем использовать почти тот же шаблон, что и раньше: создадим две матрицы в соответствии с нашим примером и проверим, что результат их умножения совпадает с тем, который мы вычислили ранее.

Итак, давайте создадим наши матрицы с помощью EJML. Для этого мы будем использовать класс SimpleMatrix , предлагаемый библиотекой .

Он может принимать двумерный двойной массив в качестве входных данных для своего конструктора:

SimpleMatrix firstMatrix = new SimpleMatrix(
new double[][] {
new double[] {1d, 5d},
new double[] {2d, 3d},
new double[] {1d ,7d}
}
);

SimpleMatrix secondMatrix = new SimpleMatrix(
new double[][] {
new double[] {1d, 2d, 3d, 7d},
new double[] {5d, 2d, 8d, 1d}
}
);

А теперь давайте определим нашу ожидаемую матрицу для умножения:

SimpleMatrix expected = new SimpleMatrix(
new double[][] {
new double[] {26d, 12d, 43d, 12d},
new double[] {17d, 10d, 30d, 17d},
new double[] {36d, 16d, 59d, 14d}
}
);

Теперь, когда мы все настроили, давайте посмотрим, как перемножить две матрицы вместе. Класс SimpleMatrix предлагает метод mult () , принимающий в качестве параметра еще одну SimpleMatrix и возвращающий произведение двух матриц:

SimpleMatrix actual = firstMatrix.mult(secondMatrix);

Проверим, соответствует ли полученный результат ожидаемому.

Поскольку SimpleMatrix не переопределяет метод equals() , мы не можем полагаться на него при выполнении проверки. Но он предлагает альтернативу: метод isIdentical () , который принимает не только еще один матричный параметр, но и двойную отказоустойчивость, чтобы игнорировать небольшие различия из-за двойной точности:

assertThat(actual).matches(m -> m.isIdentical(expected, 0d));

На этом умножение матриц с помощью библиотеки EJML завершено. Посмотрим, что предлагают другие.

3.3. ND4J

Давайте теперь попробуем библиотеку ND4J . ND4J — это вычислительная библиотека, являющаяся частью проекта deeplearning4j . Помимо прочего, ND4J предлагает функции вычисления матриц.

Прежде всего, мы должны получить зависимость от библиотеки :

<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>1.0.0-beta4</version>
</dependency>

Обратите внимание, что мы используем здесь бета-версию, потому что, похоже, в выпуске GA есть некоторые ошибки.

Для краткости мы не будем переписывать двухмерные двойные массивы, а просто сосредоточимся на том, как они используются с каждой библиотекой. Таким образом, с ND4J мы должны создать INDArray . Для этого мы вызовем фабричный метод Nd4j.create() и передадим ему двойной массив, представляющий нашу матрицу :

INDArray matrix = Nd4j.create(/* a two dimensions double array */);

Как и в предыдущем разделе, мы создадим три матрицы: две мы собираемся перемножить вместе, а одна — ожидаемый результат.

После этого мы хотим выполнить умножение между первыми двумя матрицами, используя метод INDArray.mmul() :

INDArray actual = firstMatrix.mmul(secondMatrix);

Затем мы снова проверяем, соответствует ли фактический результат ожидаемому. На этот раз мы можем положиться на проверку на равенство:

assertThat(actual).isEqualTo(expected);

Это демонстрирует, как можно использовать библиотеку ND4J для выполнения матричных вычислений.

3.4. Апач Коммонс

Давайте теперь поговорим о модуле Apache Commons Math3 , который предоставляет нам математические вычисления, включая манипуляции с матрицами.

Опять же, нам нужно будет указать зависимость в нашем pom.xml :

<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.6.1</version>
</dependency>

После настройки мы можем использовать интерфейс RealMatrix и его реализацию Array2DRowRealMatrix для создания наших обычных матриц. Конструктор класса реализации принимает в качестве параметра двумерный массив типа double :

RealMatrix matrix = new Array2DRowRealMatrix(/* a two dimensions double array */);

Что касается умножения матриц, интерфейс RealMatrix предлагает методmulti () , принимающий другой параметр RealMatrix :

RealMatrix actual = firstMatrix.multiply(secondMatrix);

Наконец, мы можем убедиться, что результат равен ожидаемому:

assertThat(actual).isEqualTo(expected);

Давайте посмотрим следующую библиотеку!

3.5. LA4J

Он называется LA4J, что означает Linear Algebra for Java .

Давайте также добавим зависимость для этого:

<dependency>
<groupId>org.la4j</groupId>
<artifactId>la4j</artifactId>
<version>0.6.0</version>
</dependency>

Теперь LA4J работает почти так же, как и другие библиотеки. Он предлагает интерфейс Matrix с реализацией Basic2DMatrix , которая принимает двумерный двойной массив в качестве входных данных:

Matrix matrix = new Basic2DMatrix(/* a two dimensions double array */);

Как и в модуле Apache Commons Math3, метод умножения — умножение () и принимает другую матрицу в качестве параметра:

Matrix actual = firstMatrix.multiply(secondMatrix);

Еще раз можем проверить, соответствует ли результат нашим ожиданиям:

assertThat(actual).isEqualTo(expected);

Давайте теперь посмотрим на нашу последнюю библиотеку: Colt.

3.6. Кольт

Colt — это библиотека, разработанная CERN. Он предоставляет функции, обеспечивающие высокую производительность научных и технических вычислений.

Как и в случае с предыдущими библиотеками, мы должны получить правильную зависимость :

<dependency>
<groupId>colt</groupId>
<artifactId>colt</artifactId>
<version>1.2.0</version>
</dependency>

Чтобы создавать матрицы с помощью Colt, мы должны использовать класс DoubleFactory2D . Он поставляется с тремя фабричными экземплярами: плотным, разреженным и rowCompressed . Каждый из них оптимизирован для создания соответствующей матрицы.

Для нашей цели мы будем использовать плотный экземпляр. На этот раз вызывается метод make() , и он снова принимает двумерный массив double , создавая объект DoubleMatrix2D :

DoubleMatrix2D matrix = doubleFactory2D.make(/* a two dimensions double array */);

После создания экземпляров наших матриц мы хотим их умножить. На этот раз для матричного объекта нет метода для этого. Нам нужно создать экземпляр класса Algebra , который имеет метод mult() , принимающий две матрицы в качестве параметров:

Algebra algebra = new Algebra();
DoubleMatrix2D actual = algebra.mult(firstMatrix, secondMatrix);

Затем мы можем сравнить фактический результат с ожидаемым:

assertThat(actual).isEqualTo(expected);

4. Сравнительный анализ

Теперь, когда мы закончили изучение различных возможностей умножения матриц, давайте проверим, какие из них наиболее эффективны.

4.1. Маленькие матрицы

Начнем с малых матриц. Здесь матрицы 3×2 и 2×4.

Для реализации теста производительности воспользуемся бенчмаркинговой библиотекой JMH . Давайте настроим класс бенчмаркинга со следующими параметрами:

public static void main(String[] args) throws Exception {
Options opt = new OptionsBuilder()
.include(MatrixMultiplicationBenchmarking.class.getSimpleName())
.mode(Mode.AverageTime)
.forks(2)
.warmupIterations(5)
.measurementIterations(10)
.timeUnit(TimeUnit.MICROSECONDS)
.build();

new Runner(opt).run();
}

Таким образом, JMH выполнит два полных прогона для каждого метода, аннотированного @Benchmark , каждый с пятью прогревочными итерациями (не учитываемыми при расчете среднего) и десятью измерениями. Что касается измерений, то они будут собирать среднее время выполнения различных библиотек в микросекундах.

Затем нам нужно создать объект состояния, содержащий наши массивы:

@State(Scope.Benchmark)
public class MatrixProvider {
private double[][] firstMatrix;
private double[][] secondMatrix;

public MatrixProvider() {
firstMatrix =
new double[][] {
new double[] {1d, 5d},
new double[] {2d, 3d},
new double[] {1d ,7d}
};

secondMatrix =
new double[][] {
new double[] {1d, 2d, 3d, 7d},
new double[] {5d, 2d, 8d, 1d}
};
}
}

Таким образом, мы гарантируем, что инициализация массивов не является частью бенчмаркинга. После этого нам еще нужно создать методы, выполняющие умножение матриц, используя объект MatrixProvider в качестве источника данных. Мы не будем повторять здесь код, так как мы видели каждую библиотеку ранее.

Наконец, мы запустим процесс бенчмаркинга, используя наш основной метод. Это дает нам следующий результат:

Benchmark                                                           Mode  Cnt   Score   Error  Units
MatrixMultiplicationBenchmarking.apacheCommonsMatrixMultiplication avgt 20 1,008 ± 0,032 us/op
MatrixMultiplicationBenchmarking.coltMatrixMultiplication avgt 20 0,219 ± 0,014 us/op
MatrixMultiplicationBenchmarking.ejmlMatrixMultiplication avgt 20 0,226 ± 0,013 us/op
MatrixMultiplicationBenchmarking.homemadeMatrixMultiplication avgt 20 0,389 ± 0,045 us/op
MatrixMultiplicationBenchmarking.la4jMatrixMultiplication avgt 20 0,427 ± 0,016 us/op
MatrixMultiplicationBenchmarking.nd4jMatrixMultiplication avgt 20 12,670 ± 2,582 us/op

Как мы видим, EJML и Colt действительно хорошо работают с примерно одной пятой микросекунды на операцию, где ND4j менее производительен с чуть более десяти микросекунд на операцию . В других библиотеках спектакли проходят между ними.

Также стоит отметить, что при увеличении количества итераций прогрева с 5 до 10 производительность увеличивается для всех библиотек.

4.2. Большие матрицы

А что произойдет, если мы возьмем матрицы большего размера, например 3000×3000? Чтобы проверить, что происходит, давайте сначала создадим еще один класс состояния, предоставляющий сгенерированные матрицы такого размера:

@State(Scope.Benchmark)
public class BigMatrixProvider {
private double[][] firstMatrix;
private double[][] secondMatrix;

public BigMatrixProvider() {}

@Setup
public void setup(BenchmarkParams parameters) {
firstMatrix = createMatrix();
secondMatrix = createMatrix();
}

private double[][] createMatrix() {
Random random = new Random();

double[][] result = new double[3000][3000];
for (int row = 0; row < result.length; row++) {
for (int col = 0; col < result[row].length; col++) {
result[row][col] = random.nextDouble();
}
}
return result;
}
}

Как мы видим, мы создадим двумерные двойные массивы размером 3000×3000, заполненные случайными вещественными числами.

Давайте теперь создадим класс бенчмаркинга:

public class BigMatrixMultiplicationBenchmarking {
public static void main(String[] args) throws Exception {
Map<String, String> parameters = parseParameters(args);

ChainedOptionsBuilder builder = new OptionsBuilder()
.include(BigMatrixMultiplicationBenchmarking.class.getSimpleName())
.mode(Mode.AverageTime)
.forks(2)
.warmupIterations(10)
.measurementIterations(10)
.timeUnit(TimeUnit.SECONDS);

new Runner(builder.build()).run();
}

@Benchmark
public Object homemadeMatrixMultiplication(BigMatrixProvider matrixProvider) {
return HomemadeMatrix
.multiplyMatrices(matrixProvider.getFirstMatrix(), matrixProvider.getSecondMatrix());
}

@Benchmark
public Object ejmlMatrixMultiplication(BigMatrixProvider matrixProvider) {
SimpleMatrix firstMatrix = new SimpleMatrix(matrixProvider.getFirstMatrix());
SimpleMatrix secondMatrix = new SimpleMatrix(matrixProvider.getSecondMatrix());

return firstMatrix.mult(secondMatrix);
}

@Benchmark
public Object apacheCommonsMatrixMultiplication(BigMatrixProvider matrixProvider) {
RealMatrix firstMatrix = new Array2DRowRealMatrix(matrixProvider.getFirstMatrix());
RealMatrix secondMatrix = new Array2DRowRealMatrix(matrixProvider.getSecondMatrix());

return firstMatrix.multiply(secondMatrix);
}

@Benchmark
public Object la4jMatrixMultiplication(BigMatrixProvider matrixProvider) {
Matrix firstMatrix = new Basic2DMatrix(matrixProvider.getFirstMatrix());
Matrix secondMatrix = new Basic2DMatrix(matrixProvider.getSecondMatrix());

return firstMatrix.multiply(secondMatrix);
}

@Benchmark
public Object nd4jMatrixMultiplication(BigMatrixProvider matrixProvider) {
INDArray firstMatrix = Nd4j.create(matrixProvider.getFirstMatrix());
INDArray secondMatrix = Nd4j.create(matrixProvider.getSecondMatrix());

return firstMatrix.mmul(secondMatrix);
}

@Benchmark
public Object coltMatrixMultiplication(BigMatrixProvider matrixProvider) {
DoubleFactory2D doubleFactory2D = DoubleFactory2D.dense;

DoubleMatrix2D firstMatrix = doubleFactory2D.make(matrixProvider.getFirstMatrix());
DoubleMatrix2D secondMatrix = doubleFactory2D.make(matrixProvider.getSecondMatrix());

Algebra algebra = new Algebra();
return algebra.mult(firstMatrix, secondMatrix);
}
}

Когда мы запускаем этот бенчмаркинг, мы получаем совершенно другие результаты:

Benchmark                                                              Mode  Cnt    Score    Error  Units
BigMatrixMultiplicationBenchmarking.apacheCommonsMatrixMultiplication avgt 20 511.140 ± 13.535 s/op
BigMatrixMultiplicationBenchmarking.coltMatrixMultiplication avgt 20 197.914 ± 2.453 s/op
BigMatrixMultiplicationBenchmarking.ejmlMatrixMultiplication avgt 20 25.830 ± 0.059 s/op
BigMatrixMultiplicationBenchmarking.homemadeMatrixMultiplication avgt 20 497.493 ± 2.121 s/op
BigMatrixMultiplicationBenchmarking.la4jMatrixMultiplication avgt 20 35.523 ± 0.102 s/op
BigMatrixMultiplicationBenchmarking.nd4jMatrixMultiplication avgt 20 0.548 ± 0.006 s/op

Как мы видим, самодельные реализации и библиотека Apache теперь намного хуже, чем раньше: умножение двух матриц занимает почти 10 минут.

Кольт занимает чуть больше 3 минут, что лучше, но все равно очень долго. EJML и LA4J работают довольно хорошо, поскольку выполняются почти за 30 секунд. Но именно ND4J побеждает в этом бенчмаркинге, работая менее чем за секунду на бэкенде ЦП .

4.3. Анализ

Это показывает нам, что результаты бенчмаркинга действительно зависят от характеристик матриц, и поэтому сложно выделить единственного победителя.

5. Вывод

В этой статье мы узнали, как умножать матрицы в Java самостоятельно или с помощью внешних библиотек. Изучив все решения, мы провели бенчмарк всех из них и увидели, что все они, за исключением ND4J, неплохо показали себя на небольших матрицах. С другой стороны, на больших матрицах лидирует ND4J.

Как обычно, полный код этой статьи можно найти на GitHub .