package neural; import java.util.Random; import java.util.function.BiFunction; import java.util.function.DoubleFunction; /** * Matrix class for storage & calculations * * @author Sebastian Gössl * @version 1.2 26.03.2018 */ public class Matrix { /** Matrix dimensions */ private final int height, width; /** Matrix elements */ private final double[][] matrix; /** * Constructs a new copy of an existing matrix * @param input Matrix to copy */ public Matrix(Matrix input) { this(input.getHeight(), input.getWidth()); for(int j=0; j<getHeight(); j++) { for(int i=0; i<getWidth(); i++) { set(j, i, input.get(j, i)); } } } /** * Constructs a new Matrix with the content of a 2-dimensional array * @param array Array which contents should be copied into this matrix * @throws IllegalArgumentException If the input array is not rectangular */ public Matrix(double[][] array) { this(array.length, array[0].length); for(int j=0; j<getHeight(); j++) { if(array[j].length != getWidth()) { throw new IllegalArgumentException("Input array not rectangular!"); } for(int i=0; i<getWidth(); i++) { set(j, i, array[j][i]); } } } /** * Constructs a new matrix with the given dimensions * @param height Number of rows * @param width Number of columns * @throws IllegalArgumentException If the dimensions are * impossible to implement */ public Matrix(int height, int width) { if(height < 1 || width < 1) { throw new IllegalArgumentException("Dimension(s) less than 1!"); } this.height = height; this.width = width; matrix = new double[height][width]; } /** * Sets the value of a specific element * @param row Row index of the element * @param column Column index of the element * @param value Value to set the element to * @throws ArrayIndexOutOfBoundsException If the indices are smaller than 0 * or bigger than the width/height -1 */ public void set(int row, int column, double value) { if(row < 0 || row >= getHeight() || column < 0 || column >= getWidth()) { throw new ArrayIndexOutOfBoundsException("Indices out of bounds!"); } matrix[row][column] = value; } /** * Returns the value of a specific element * @param row Row index of the element * @param column Column index of the element * @return The value of the element * @throws ArrayIndexOutOfBoundsException If the indices are smaller than 0 * or bigger than the width/height -1 */ public double get(int row, int column) { if(row < 0 || row >= getHeight() || column < 0 || column >= getWidth()) { throw new ArrayIndexOutOfBoundsException("Indices out of bounds!"); } return matrix[row][column]; } /** * Returns the height (number of rows) of the matrix * @return Height of the matrix */ public int getHeight() { return height; } /** * Returns the width (number of columns) of the matrix * @return Width of the matrix */ public int getWidth() { return width; } /** * Sets every element of the matrix to the given value * @param value Value to set every element to */ public void fill(double value) { for(int j=0; j<getHeight(); j++) { for(int i=0; i<getWidth(); i++) { set(j, i, value); } } } /** * Adds the given matrix to this matrix. * C = A + B * C[i, j] = A[i, j] + B[i, j] * @param matrix2 Second matrix to add (summand) * @return Sum of the two matricies * @throws IllegalArgumentException If the matricies are not * of the same dimensions */ public Matrix add(Matrix matrix2) { if(matrix2.getHeight() != getHeight() || matrix2.getWidth() != getWidth()) { throw new IllegalArgumentException("Dimensions not compatible!"); } final Matrix result = new Matrix(getHeight(), getWidth()); for(int j=0; j<result.getHeight(); j++) { for(int i=0; i<result.getWidth(); i++) { result.set(j, i, get(j, i) + matrix2.get(j, i)); } } return result; } /** * Subtracts the given matrix from this matrix. * C = A - B * C[i, j] = A[i, j] - B[i, j] * @param matrix2 Second matrix to subtract (subtrahend) * @return Difference of the two matricies * @throws IllegalArgumentException If the matricies are not * of the same dimensions */ public Matrix subtract(Matrix matrix2) { if(matrix2.getHeight() != getHeight() || matrix2.getWidth() != getWidth()){ throw new IllegalArgumentException("Dimensions not compatible!"); } final Matrix result = new Matrix(getHeight(), getWidth()); for(int j=0; j<result.getHeight(); j++) { for(int i=0; i<result.getWidth(); i++) { result.set(j, i, get(j, i) - matrix2.get(j, i)); } } return result; } /** * Multiplies this matrix with the given one. * C = AB * C[i, j] = A[i, 1]*B[1, j] + ... + A[i, m]*B[m, j] * @param matrix2 Second matrix to multiply (factor) * @return Matrix product of the two matricies * @throws IllegalArgumentException If the second matrix' height does not * equal this matrix' width */ public Matrix multiply(Matrix matrix2) { if(matrix2.getHeight() != getWidth()) { throw new IllegalArgumentException("Matrix dimensions not compatible!"); } final Matrix result = new Matrix(getHeight(), matrix2.getWidth()); for(int j=0; j<result.getHeight(); j++) { for(int i=0; i<result.getWidth(); i++) { double sum = 0; for(int k=0; k<getWidth(); k++) { sum += get(j, k) * matrix2.get(k, i); } result.set(j, i, sum); } } return result; } /** * Multiplies this matrix with the given matrix elementwise. * C = A o B * C[i, j] = A[i, j] * B[i, j] * @param matrix2 Second matrix to multiply elementwise (factor) * @return Elementwise/Hadamard product * @throws IllegalArgumentException If the matricies are not * of the same dimensions */ public Matrix multiplyElementwise(Matrix matrix2) { if(matrix2.getHeight() != getHeight() || matrix2.getWidth() != getWidth()) { throw new IllegalArgumentException("Dimensions not compatible!"); } final Matrix result = new Matrix(getHeight(), getWidth()); for(int j=0; j<result.getHeight(); j++) { for(int i=0; i<result.getWidth(); i++) { result.set(j, i, get(j, i) * matrix2.get(j, i)); } } return result; } /** * Scalar multiplies this matrix with the given factor. * C = b * A * C[i, j] = b * A[i, j] * @param value Scalar to multiply every element with * @return Scaled matrix */ public Matrix multiply(double value) { final Matrix result = new Matrix(getHeight(), getWidth()); for(int j=0; j<result.getHeight(); j++) { for(int i=0; i<result.getWidth(); i++) { result.set(j, i, value * get(j, i)); } } return result; } /** * Divides this matrix by the given matrix elementwise. * C[i, j] = A[i, j] / B[i, j] * @param matrix2 Second matrix to divide by elementwise (divisor) * @return Elementwise quotient * @throws IllegalArgumentException If the matricies are not * of the same dimensions */ public Matrix divideElementwise(Matrix matrix2) { if(matrix2.getHeight() != getHeight() || matrix2.getWidth() != getWidth()) { throw new IllegalArgumentException("Dimensions not compatible!"); } final Matrix result = new Matrix(getHeight(), getWidth()); for(int j=0; j<result.getHeight(); j++) { for(int i=0; i<result.getWidth(); i++) { result.set(j, i, get(j, i) / matrix2.get(j, i)); } } return result; } /** * Applies the given function on every element of the matrix. * B[i, j] = f(A[i, j]) * @param function Function that gets applied on every element * @return Resulting matrix */ public Matrix apply(DoubleFunction<Double> function) { final Matrix result = new Matrix(getHeight(), getWidth()); for(int j=0; j<result.getHeight(); j++) { for(int i=0; i<result.getWidth(); i++) { result.set(j, i, function.apply(get(j, i))); } } return result; } /** * Applies the given function, * on every element of this and the given matrix, * and writes the output into the corresponding element of a new matrix. * C[i, j] = f(A[i, j], B[i, j]) * @param matrix2 * @param function Function that gets applied on every element * @return Resulting matrix */ public Matrix apply(Matrix matrix2, BiFunction<Double, Double, Double> function) { final Matrix result = new Matrix(getHeight(), getWidth()); for(int j=0; j<result.getHeight(); j++) { for(int i=0; i<result.getWidth(); i++) { result.set(j, i, function.apply(get(j, i), matrix2.get(j, i))); } } return result; } /** * Transposes this matrix * @return Transpose */ public Matrix transpose() { final Matrix result = new Matrix(getWidth(), getHeight()); for(int j=0; j<result.getHeight(); j++) { for(int i=0; i<result.getWidth(); i++) { result.set(j, i, get(i, j)); } } return result; } /** * Extracts a single row as a new Matrix * @param index Index of the row that should be extracted * @return The single row as a new Matrix * @throws ArrayIndexOutOfBoundsException If the index does not point * to an existing row */ public Matrix getRow(int index) { if(index < 0 || index >= getWidth()) { throw new ArrayIndexOutOfBoundsException("Index out of bounds!"); } return getRows(index, index + 1); } /** * Extracts multiple rows as a new Matrix * @param fromIndex Index of the first row * that should be extracted (inclusive) * @param toIndex Index of the last row that should be extracted (exclusive) * @return The rows as a new Matrix * @throws ArrayIndexOutOfBoundsException If an index does not point * to an existing row */ public Matrix getRows(int fromIndex, int toIndex) { if(fromIndex < 0 || fromIndex >= getHeight() || toIndex < 0 || toIndex > getHeight()) { throw new ArrayIndexOutOfBoundsException("Indices out of bounds!"); } if(fromIndex >= toIndex) { throw new IllegalArgumentException("Illegal index direction!"); } final Matrix result = new Matrix(toIndex - fromIndex, getWidth()); for(int j=0; j<result.getHeight(); j++) { for(int i=0; i<result.getWidth(); i++) { result.set(j, i, get(fromIndex + j, i)); } } return result; } /** * Appends a matrix to the bottom end of this matrix * @param rows Matrix to append * @return Merged matrix * @throws IllegalArgumentException If the givenn matrix * is not as wide as this matrix */ public Matrix appendRows(Matrix rows) { if(rows.getWidth() != getWidth()) { throw new IllegalArgumentException("Rows not compatible!"); } final Matrix result = new Matrix(getHeight() + rows.getHeight(), getWidth()); for(int j=0; j<getHeight(); j++) { for(int i=0; i<result.getWidth(); i++) { result.set(j, i, get(j, i)); } } for(int j=0; j<rows.getHeight(); j++) { for(int i=0; i<result.getWidth(); i++) { result.set(getHeight() + j, i, rows.get(j, i)); } } return result; } /** * Removes a single row of this matrix * @param index Index of the row that should be removed * @return Resulting matrix * @throws ArrayIndexOutOfBoundsException If this matrix is to small to * remove a row or the index does not point to an existing row */ public Matrix removeRow(int index) { if(getHeight() < 2) { throw new ArrayIndexOutOfBoundsException("Matrix to small!"); } if(index < 0 || index >= getHeight()) { throw new ArrayIndexOutOfBoundsException("Index out of bounds!"); } return removeRows(index, index + 1); } /** * Removes multiple rows of this matrix * @param fromIndex Index of the first row that should be removed (inclusive) * @param toIndex Index of the last row that should be removed (exclusive) * @return Resulting matrix * @throws ArrayIndexOutOfBoundsException If this matrix is to small to * remove the rows or an index does not point to an existing row */ public Matrix removeRows(int fromIndex, int toIndex) { if(getHeight() <= toIndex - fromIndex) { throw new ArrayIndexOutOfBoundsException("Matrix to small!"); } if(fromIndex < 0 || fromIndex >= getHeight() || toIndex < 0 || toIndex > getHeight()) { throw new ArrayIndexOutOfBoundsException("Indices out of bounds!"); } if(fromIndex >= toIndex) { throw new IllegalArgumentException("Illegal index direction!"); } final Matrix result = new Matrix(getHeight() - (toIndex - fromIndex), getWidth()); for(int j=0; j<fromIndex; j++) { for(int i=0; i<result.getWidth(); i++) { result.set(j, i, get(j, i)); } } for(int j=fromIndex; j<result.getHeight(); j++) { for(int i=0; i<result.getWidth(); i++) { result.set(j, i, get((toIndex - fromIndex) + j, i)); } } return result; } /** * Extracts a single column as a new Matrix * @param index Index of the column that should be extracted * @return The single column as a new Matrix * @throws ArrayIndexOutOfBoundsException If the index does not point * to an existing column */ public Matrix getColumn(int index) { if(index < 0 || index >= getHeight()) { throw new ArrayIndexOutOfBoundsException("Index out of bounds!"); } return getColumns(index, index + 1); } /** * Extracts multiple columns as a new Matrix * @param fromIndex Index of the first column * that should be extracted (inclusive) * @param toIndex Index of the last column * that should be extracted (exclusive) * @return The columns as a new Matrix * @throws ArrayIndexOutOfBoundsException If an index does not point * to an existing column */ public Matrix getColumns(int fromIndex, int toIndex) { if(fromIndex < 0 || fromIndex >= getWidth() || toIndex < 0 || toIndex > getWidth()) { throw new ArrayIndexOutOfBoundsException("Indices out of bounds!"); } if(fromIndex >= toIndex) { throw new IllegalArgumentException("Illegal index direction!"); } final Matrix result = new Matrix(getHeight(), toIndex - fromIndex); for(int j=0; j<result.getHeight(); j++) { for(int i=0; i<result.getWidth(); i++) { result.set(j, i, get(j, fromIndex + i)); } } return result; } /** * Appends a matrix to the right end of this matrix * @param columns Matrix to append * @return Merged matrix * @throws IllegalArgumentException If the given matrix * is not as high as this matrix */ public Matrix appendColumns(Matrix columns) { if(columns.getHeight() != getHeight()) { throw new IllegalArgumentException("Column not compatible!"); } final Matrix result = new Matrix(getHeight(), getWidth() + columns.getWidth()); for(int j=0; j<result.getHeight(); j++) { for(int i=0; i<getWidth(); i++) { result.set(j, i, get(j, i)); } for(int i=0; i<columns.getWidth(); i++) { result.set(j, getWidth() + i, columns.get(j, i)); } } return result; } /** * Removes a single column of this matrix * @param index Index of the column that should be remove * @return Resulting matrix * @throws ArrayIndexOutOfBoundsException If this matrix is to small to * remove a column or the index does not point to an existing column */ public Matrix removeColumn(int index) { if(getWidth() < 2) { throw new ArrayIndexOutOfBoundsException("Matrix to small!"); } if(index < 0 || index >= getWidth()) { throw new ArrayIndexOutOfBoundsException("Index out of bounds!"); } return removeColumns(index, index + 1); } /** * Removes multiple columns of this matrix * @param fromIndex Index of the first column * that should be removed (inclusive) * @param toIndex Index of the last column that should be removed (exclusive) * @return Resulting matrix * @throws ArrayIndexOutOfBoundsException If this matrix is to small to * remove the columns or an index does not point to an existing column */ public Matrix removeColumns(int fromIndex, int toIndex) { if(getWidth() <= toIndex - fromIndex) { throw new ArrayIndexOutOfBoundsException("Matrix to small!"); } if(fromIndex < 0 || fromIndex >= getWidth() || toIndex < 0 || toIndex > getWidth()) { throw new ArrayIndexOutOfBoundsException("Indices out of bounds!"); } if(fromIndex >= toIndex) { throw new IllegalArgumentException("Illegal index direction!"); } final Matrix result = new Matrix(getHeight(), getWidth() - (toIndex - fromIndex)); for(int j=0; j<result.getHeight(); j++) { for(int i=0; i<fromIndex; i++) { result.set(j, i, get(j, i)); } for(int i=fromIndex; i<result.getWidth(); i++) { result.set(j, i, get(j, (toIndex - fromIndex) + i)); } } return result; } /** * Fills the matrix with random values */ public void rand() { rand(new Random(), -Double.MAX_VALUE, Double.MAX_VALUE); } /** * Fills the matrix with random values, * from minimum (inclusive) to maximum (exclusive), * given by the random number generator * @param rand Random number generator * @param minimum Minimum value (inclusive) * @param maximum Maximum value (exclusive) */ public void rand(Random rand, double minimum, double maximum) { final double range = maximum - minimum; for(int j=0; j<getHeight(); j++) { for(int i=0; i<getWidth(); i++) { set(j, i, range*rand.nextDouble() + minimum); } } } /** * Copies the content of the matrix into a 2-dimensional array * @return Array copy */ public double[][] toArray() { final double[][] array = new double[getHeight()][getWidth()]; for(int j=0; j<array.length; j++) { for(int i=0; i<array[j].length; i++) { array[j][i] = get(j, i); } } return array; } @Override public String toString() { final StringBuilder builder = new StringBuilder("[[").append(get(0, 0)); for(int i=1; i<getWidth(); i++) { builder.append(", ").append(get(0, i)); } builder.append("]"); for(int j=1; j<getHeight(); j++) { builder.append("\n [").append(get(j, 0)); for(int i=1; i<getWidth(); i++) { builder.append(", ").append(get(j, i)); } builder.append("]"); } builder.append("]"); return builder.toString(); } public static void main(String[] args) { double scalar = 2; Matrix matrix1 = new Matrix(new double[][] { {1, 2, 3}, {4, 42, 6}, {7, 8, 9}}); Matrix matrix2 = new Matrix(new double[][] { {1, 4, 7}, {2, 5, 8}, {3, 6, 9}}); System.out.println("Scalar:"); System.out.println(scalar + "\n"); System.out.println("Matrix1:"); System.out.println(matrix1 + "\n"); System.out.println("Matrix2:"); System.out.println(matrix2 + "\n"); System.out.println(); System.out.println("Set [1, 1] to 5:"); matrix1.set(1, 1, 5); System.out.println(matrix1 + "\n"); System.out.println("Get [1, 1]:" + matrix1.get(1, 1)); System.out.println("Height: " + matrix1.getHeight()); System.out.println("Width: " + matrix1.getWidth() + "\n\n"); System.out.println("Addition (Matrix1 & Matrix2):"); System.out.println(matrix1.add(matrix2) + "\n"); System.out.println("Subtraction (Matrix1 & Matrix2):"); System.out.println(matrix1.subtract(matrix2) + "\n"); System.out.println("Matrix multiplication:"); System.out.println(matrix1.multiply(matrix2) + "\n"); System.out.println("Elementwise multiplication:"); System.out.println(matrix1.multiplyElementwise(matrix2) + "\n"); System.out.println("Scalar multiplication:"); System.out.println(matrix1.multiply(scalar) + "\n"); System.out.println("Elementwise division:"); System.out.println(matrix1.divideElementwise(matrix2) + "\n"); System.out.println("Applying sine:"); System.out.println(matrix1.apply(x -> Math.sin(x)) + "\n\n"); System.out.println("Transpose:"); System.out.println(matrix1.transpose() + "\n\n"); System.out.println("Get row 1:"); System.out.println(matrix1.getRow(1) + "\n"); System.out.println("Get rows 1 & 2:"); System.out.println(matrix1.getRows(1, 3) + "\n"); System.out.println("Append rows:"); System.out.println(matrix1.appendRows(matrix2) + "\n"); System.out.println("Remove row 1:"); System.out.println(matrix1.removeRow(1) + "\n"); System.out.println("Remove rows 0 & 1:"); System.out.println(matrix1.removeRows(0, 2) + "\n\n"); System.out.println("Get column 1:"); System.out.println(matrix1.getColumn(1) + "\n"); System.out.println("Get columns 1 & 2:"); System.out.println(matrix1.getColumns(1, 3) + "\n"); System.out.println("Append columns:"); System.out.println(matrix1.appendColumns(matrix2) + "\n"); System.out.println("Remove column 1:"); System.out.println(matrix1.removeColumn(1) + "\n"); System.out.println("Remove columns 0 & 1:"); System.out.println(matrix1.removeColumns(0, 2) + "\n"); System.out.println("Randomize within [-1, 1[:"); matrix1.rand(new Random(), -1, 1); System.out.println(matrix1 + "\n"); } }