package org.nd4j.linalg.shape.ones;

import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;

import static org.junit.Assert.assertEquals;

/**
 * Tests for C-ordered arrays whose shapes contain unit-length (size 1) dimensions at the
 * front, back, or middle of the shape, e.g. {@code [1, 10, 1, 1]}, {@code [2, 1, 2]} or
 * {@code [2, 2, 1, 1, 6, 6]}. Each test checks the contents of slices/views taken from
 * such arrays, not just that they can be created.
 *
 * @author Adam Gibson
 */
@RunWith(Parameterized.class)
public class LeadingAndTrailingOnesC extends BaseNd4jTest {

    /** Tolerance for floating-point comparisons of element values and sums. */
    private static final double EPS = 1e-6;

    public LeadingAndTrailingOnesC(Nd4jBackend backend) {
        super(backend);
    }

    @Test
    public void testCreateLeadingAndTrailingOnes() {
        INDArray arr = Nd4j.create(1, 10, 1, 1);
        arr.assign(1);
        // Ten elements, every one set to 1 by assign(...).
        assertEquals(10, arr.length());
        assertEquals(10.0, arr.sumNumber().doubleValue(), EPS);
        // Presumably also meant to exercise toString() on a shape with leading/trailing ones.
        System.out.println(arr);
    }

    @Test
    public void testMatrix() {
        // [[1, 2], [3, 4]]: slice(1) is the second row, [3, 4].
        INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2);
        INDArray slice1 = arr.slice(1);
        assertEquals(3.0, slice1.getDouble(0), EPS);
        assertEquals(4.0, slice1.getDouble(1), EPS);
        System.out.println(slice1);

        // Shape [2, 1, 2]: slice(1) starts two elements into the buffer and holds [3, 4].
        INDArray oneInMiddle = Nd4j.linspace(1, 4, 4).reshape(2, 1, 2);
        INDArray otherSlice = oneInMiddle.slice(1);
        assertEquals(2, otherSlice.offset());
        assertEquals(7.0, otherSlice.sumNumber().doubleValue(), EPS);
        System.out.println(otherSlice);

        // Shape [2, 1, 1, 2]: fixing index 1 on the first axis also yields [3, 4] at offset 2.
        INDArray twoOnesInMiddle = Nd4j.linspace(1, 4, 4).reshape(2, 1, 1, 2);
        INDArray sub = twoOnesInMiddle.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all(),
                        NDArrayIndex.all());
        assertEquals(2, sub.offset());
        assertEquals(7.0, sub.sumNumber().doubleValue(), EPS);
    }

    @Test
    public void testMultipleOnesInMiddle() {
        // Values 1..144 in shape [2, 2, 1, 1, 6, 6].
        INDArray tensor = Nd4j.linspace(1, 144, 144).reshape(2, 2, 1, 1, 6, 6);

        // slice(1) is the second half of the buffer: values 73..144.
        INDArray tensorSlice1 = tensor.slice(1);
        assertEquals(72, tensorSlice1.length());
        assertEquals(7812.0, tensorSlice1.sumNumber().doubleValue(), EPS);

        // slice(1) of that is the last quarter: values 109..144.
        INDArray tensorSlice1Slice1 = tensorSlice1.slice(1);
        assertEquals(36, tensorSlice1Slice1.length());
        assertEquals(4554.0, tensorSlice1Slice1.sumNumber().doubleValue(), EPS);

        System.out.println(tensor);
    }

    @Test
    public void testOnesInMiddleTensor() {
        INDArray im2colAssertion = Nd4j.create(new double[] {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                        0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                        5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 7.0, 8.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 9.0, 10.0,
                        0.0, 0.0, 0.0, 0.0, 11.0, 12.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 13.0, 14.0, 0.0, 0.0,
                        0.0, 0.0, 15.0, 16.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
                        new int[] {2, 2, 1, 1, 6, 6});
        // 2*2*1*1*6*6 = 144 elements, rank 6; the only non-zero values are 1..16 (sum 136).
        assertEquals(6, im2colAssertion.rank());
        assertEquals(144, im2colAssertion.length());
        assertEquals(136.0, im2colAssertion.sumNumber().doubleValue(), EPS);
        System.out.println(im2colAssertion);
    }

    @Override
    public char ordering() {
        return 'c';
    }
}