package org.nd4j.linalg.slicing;

import lombok.val;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;

import static org.junit.Assert.assertEquals;

/**
 * @author Adam Gibson
 */
@RunWith(Parameterized.class)
public class SlicingTests extends BaseNd4jTest {

    public SlicingTests(Nd4jBackend backend) {
        super(backend);
    }

    @Test
    public void testSlices() {
        INDArray arr = Nd4j.create(Nd4j.linspace(1, 24, 24).data(), new int[] {4, 3, 2});
        for (int i = 0; i < arr.slices(); i++) {
            INDArray slice = arr.slice(i).slice(1);
            val slices = slice.slices();
            assertEquals(2, slices);
        }

    }



    @Test
    public void testSlice() {
        INDArray arr = Nd4j.linspace(1, 24, 24).reshape(4, 3, 2);
        INDArray assertion = Nd4j.create(new double[][] {{1, 13}, {5, 17}, {9, 21}});

        INDArray firstSlice = arr.slice(0);
        INDArray slice1Assertion = Nd4j.create(new double[][] {{2, 14}, {6, 18}, {10, 22},

        });

        INDArray secondSlice = arr.slice(1);
        assertEquals(assertion, firstSlice);
        assertEquals(slice1Assertion, secondSlice);

    }


    @Override
    public char ordering() {
        return 'f';
    }
}