/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.iceberg.spark.data;

import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.sql.Timestamp;
import java.time.Instant;
import java.time.LocalDate;
import java.time.OffsetDateTime;
import java.time.ZoneOffset;
import java.time.temporal.ChronoUnit;
import java.util.Collection;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.apache.arrow.vector.ValueVector;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericData.Record;
import org.apache.iceberg.Schema;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.spark.data.vectorized.IcebergArrowColumnVector;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types;
import org.apache.orc.storage.serde2.io.DateWritable;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.DateTimeUtils;
import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.BinaryType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.unsafe.types.UTF8String;
import org.junit.Assert;
import scala.collection.Seq;

import static org.apache.iceberg.spark.SparkSchemaUtil.convert;
import static scala.collection.JavaConverters.mapAsJavaMapConverter;
import static scala.collection.JavaConverters.seqAsJavaListConverter;

@SuppressWarnings("checkstyle:OverloadMethodsDeclarationOrder")
public class TestHelpers {

  // Private constructor: this class only exposes static assertion helpers.
  private TestHelpers() {}

  /**
   * Asserts that every field of a Spark {@link Row} matches the field at the same position in an Avro record.
   *
   * @param struct the Iceberg struct type describing the fields to compare
   * @param rec the expected Avro record
   * @param row the actual Spark row
   */
  public static void assertEqualsSafe(Types.StructType struct, Record rec, Row row) {
    int pos = 0;
    for (Types.NestedField field : struct.fields()) {
      // fields are matched by position, not by name
      assertEqualsSafe(field.type(), rec.get(pos), row.get(pos));
      pos += 1;
    }
  }

  /**
   * Asserts that every row of a columnar batch matches the next record taken from the expected iterator.
   *
   * @param struct the Iceberg struct type describing the columns
   * @param expected iterator that supplies one expected record per row of the batch
   * @param batch the actual batch produced by the reader
   * @param checkArrowValidityVector when true, each column is cast to {@link IcebergArrowColumnVector} and the
   *                                 null flag of its Arrow vector is compared with the expected value's nullness
   */
  public static void assertEqualsBatch(Types.StructType struct, Iterator<Record> expected, ColumnarBatch batch,
                                       boolean checkArrowValidityVector) {
    int rowId = 0;
    while (rowId < batch.numRows()) {
      List<Types.NestedField> columns = struct.fields();
      InternalRow actualRow = batch.getRow(rowId);
      Record expectedRecord = expected.next();

      for (int col = 0; col < columns.size(); col += 1) {
        Type columnType = columns.get(col).type();
        Object expectedValue = expectedRecord.get(col);

        Object actualValue;
        if (actualRow.isNullAt(col)) {
          actualValue = null;
        } else {
          actualValue = actualRow.get(col, convert(columnType));
        }
        assertEqualsUnsafe(columnType, expectedValue, actualValue);

        if (!checkArrowValidityVector) {
          continue;
        }

        ColumnVector columnVector = batch.column(col);
        ValueVector arrowVector = ((IcebergArrowColumnVector) columnVector).vectorAccessor().getVector();
        Assert.assertEquals("Nullability doesn't match", expectedValue == null, arrowVector.isNull(rowId));
      }

      rowId++;
    }
  }


  /**
   * Asserts that each expected element matches the element at the same index of the actual list.
   * Only the indexes present in the expected collection are visited.
   */
  private static void assertEqualsSafe(Types.ListType list, Collection<?> expected, List actual) {
    Type elementType = list.elementType();
    int index = 0;
    for (Object expectedElement : Lists.newArrayList(expected)) {
      assertEqualsSafe(elementType, expectedElement, actual.get(index));
      index += 1;
    }
  }

  /**
   * Asserts that every expected map entry has an equivalent entry in the actual map.
   * <p>
   * Keys are matched by running the type-aware comparison against each actual key rather than relying on
   * {@code equals}; when several actual keys match, the last one visited is used.
   */
  private static void assertEqualsSafe(Types.MapType map,
                                       Map<?, ?> expected, Map<?, ?> actual) {
    Type keyType = map.keyType();
    Type valueType = map.valueType();

    for (Object wantedKey : expected.keySet()) {
      Object found = null;
      for (Object candidate : actual.keySet()) {
        if (safeKeysMatch(keyType, wantedKey, candidate)) {
          found = candidate;
        }
      }

      Assert.assertNotNull("Should have a matching key", found);
      assertEqualsSafe(valueType, expected.get(wantedKey), actual.get(found));
    }
  }

  /**
   * Returns true when the type-aware comparison of the two keys passes, false when it raises an AssertionError.
   */
  private static boolean safeKeysMatch(Type keyType, Object expectedKey, Object actualKey) {
    try {
      assertEqualsSafe(keyType, expectedKey, actualKey);
      return true;
    } catch (AssertionError ignored) {
      // a failed comparison only means this candidate is not the matching key
      return false;
    }
  }

  // Epoch millisecond 0 as a UTC OffsetDateTime.
  private static final OffsetDateTime EPOCH = Instant.ofEpochMilli(0L).atOffset(ZoneOffset.UTC);
  // Calendar date of EPOCH; expected DATE values are applied to it as day offsets.
  private static final LocalDate EPOCH_DAY = EPOCH.toLocalDate();

  /**
   * Asserts that a value taken from a Spark {@link Row} matches the expected Avro generic value.
   * <p>
   * The two sides use different representations, so each type is compared after checking the runtime
   * classes: for example an expected {@code Integer} day count against an actual {@link Date}, an expected
   * {@code Long} of microseconds against an actual {@link Timestamp}, and an expected {@link Collection}
   * against an actual Scala {@link Seq}.
   *
   * @param type the Iceberg type of the value
   * @param expected the expected value, may be null
   * @param actual the value read from Spark, may be null
   * @throws IllegalArgumentException if the type is TIME or any other type that is not handled here
   */
  @SuppressWarnings("unchecked")
  private static void assertEqualsSafe(Type type, Object expected, Object actual) {
    // two nulls are equal; a single null falls through to the per-type checks below
    if (expected == null && actual == null) {
      return;
    }

    switch (type.typeId()) {
      case BOOLEAN:
      case INTEGER:
      case LONG:
      case FLOAT:
      case DOUBLE:
        Assert.assertEquals("Primitive value should be equal to expected", expected, actual);
        break;
      case DATE:
        Assert.assertTrue("Should be an int", expected instanceof Integer);
        Assert.assertTrue("Should be a Date", actual instanceof Date);
        int daysFromEpoch = (Integer) expected;
        LocalDate date = ChronoUnit.DAYS.addTo(EPOCH_DAY, daysFromEpoch);
        // compare through the string forms of the two dates
        Assert.assertEquals("ISO-8601 date should be equal", date.toString(), actual.toString());
        break;
      case TIMESTAMP:
        Assert.assertTrue("Should be a long", expected instanceof Long);
        Assert.assertTrue("Should be a Timestamp", actual instanceof Timestamp);
        Timestamp ts = (Timestamp) actual;
        // milliseconds from nanos has already been added by getTime
        long tsMicros = (ts.getTime() * 1000) + ((ts.getNanos() / 1000) % 1000);
        Assert.assertEquals("Timestamp micros should be equal", expected, tsMicros);
        break;
      case STRING:
        Assert.assertTrue("Should be a String", actual instanceof String);
        // the expected value is converted with String.valueOf before comparing
        Assert.assertEquals("Strings should be equal", String.valueOf(expected), actual);
        break;
      case UUID:
        Assert.assertTrue("Should expect a UUID", expected instanceof UUID);
        Assert.assertTrue("Should be a String", actual instanceof String);
        Assert.assertEquals("UUID string representation should match",
            expected.toString(), actual);
        break;
      case FIXED:
        Assert.assertTrue("Should expect a Fixed", expected instanceof GenericData.Fixed);
        Assert.assertTrue("Should be a byte[]", actual instanceof byte[]);
        Assert.assertArrayEquals("Bytes should match",
            ((GenericData.Fixed) expected).bytes(), (byte[]) actual);
        break;
      case BINARY:
        Assert.assertTrue("Should expect a ByteBuffer", expected instanceof ByteBuffer);
        Assert.assertTrue("Should be a byte[]", actual instanceof byte[]);
        // compares against the whole backing array of the expected buffer
        Assert.assertArrayEquals("Bytes should match",
            ((ByteBuffer) expected).array(), (byte[]) actual);
        break;
      case DECIMAL:
        Assert.assertTrue("Should expect a BigDecimal", expected instanceof BigDecimal);
        Assert.assertTrue("Should be a BigDecimal", actual instanceof BigDecimal);
        Assert.assertEquals("BigDecimals should be equal", expected, actual);
        break;
      case STRUCT:
        Assert.assertTrue("Should expect a Record", expected instanceof Record);
        Assert.assertTrue("Should be a Row", actual instanceof Row);
        assertEqualsSafe(type.asNestedType().asStructType(), (Record) expected, (Row) actual);
        break;
      case LIST:
        Assert.assertTrue("Should expect a Collection", expected instanceof Collection);
        Assert.assertTrue("Should be a Seq", actual instanceof Seq);
        List<?> asList = seqAsJavaListConverter((Seq<?>) actual).asJava();
        assertEqualsSafe(type.asNestedType().asListType(), (Collection) expected, asList);
        break;
      case MAP:
        // message fixed: this check is for a Map, not a Collection
        Assert.assertTrue("Should expect a Map", expected instanceof Map);
        Assert.assertTrue("Should be a Map", actual instanceof scala.collection.Map);
        Map<String, ?> asMap = mapAsJavaMapConverter(
            (scala.collection.Map<String, ?>) actual).asJava();
        assertEqualsSafe(type.asNestedType().asMapType(), (Map<String, ?>) expected, asMap);
        break;
      case TIME:
      default:
        throw new IllegalArgumentException("Not a supported type: " + type);
    }
  }

  /**
   * Asserts that every field of a Spark {@link InternalRow} matches the field at the same position in an
   * Avro record.
   *
   * @param struct the Iceberg struct type describing the fields to compare
   * @param rec the expected Avro record
   * @param row the actual internal row
   */
  public static void assertEqualsUnsafe(Types.StructType struct, Record rec, InternalRow row) {
    int ordinal = 0;
    for (Types.NestedField field : struct.fields()) {
      Type fieldType = field.type();
      Object expectedValue = rec.get(ordinal);

      // only read the slot when it is not flagged as null
      Object actualValue = null;
      if (!row.isNullAt(ordinal)) {
        actualValue = row.get(ordinal, convert(fieldType));
      }

      assertEqualsUnsafe(fieldType, expectedValue, actualValue);
      ordinal += 1;
    }
  }

  /**
   * Asserts that each expected element matches the element at the same index of the actual {@link ArrayData}.
   * Only the indexes present in the expected collection are visited.
   */
  private static void assertEqualsUnsafe(Types.ListType list, Collection<?> expected, ArrayData actual) {
    Type elementType = list.elementType();
    int position = 0;
    for (Object expectedElement : Lists.newArrayList(expected)) {
      Object actualElement = actual.get(position, convert(elementType));
      assertEqualsUnsafe(elementType, expectedElement, actualElement);
      position += 1;
    }
  }

  /**
   * Asserts that the entries of an expected map match the entries of the actual {@link MapData}.
   * <p>
   * Entries are compared pairwise by position: the i-th entry of the expected map's entry set is compared
   * with the i-th key and i-th value of the actual key and value arrays.
   *
   * @param map the Iceberg map type supplying the key and value types
   * @param expected the expected map
   * @param actual the actual map data
   */
  private static void assertEqualsUnsafe(Types.MapType map, Map<?, ?> expected, MapData actual) {
    Type keyType = map.keyType();
    Type valueType = map.valueType();

    List<Map.Entry<?, ?>> expectedElements = Lists.newArrayList(expected.entrySet());
    ArrayData actualKeys = actual.keyArray();
    ArrayData actualValues = actual.valueArray();

    for (int i = 0; i < expectedElements.size(); i += 1) {
      Map.Entry<?, ?> expectedPair = expectedElements.get(i);
      Object actualKey = actualKeys.get(i, convert(keyType));
      // values must be read with the value type; reading them with the key type decodes the wrong type
      Object actualValue = actualValues.get(i, convert(valueType));

      assertEqualsUnsafe(keyType, expectedPair.getKey(), actualKey);
      assertEqualsUnsafe(valueType, expectedPair.getValue(), actualValue);
    }
  }

  /**
   * Asserts that a value in Spark's internal representation matches the expected Avro generic value.
   * <p>
   * Primitive-like types (including DATE and TIMESTAMP) are compared directly; the remaining types are
   * compared after checking the runtime classes on both sides, for example {@link UTF8String} for strings
   * and {@link Decimal} for decimals.
   *
   * @param type the Iceberg type of the value
   * @param expected the expected value, may be null
   * @param actual the internal Spark value, may be null
   * @throws IllegalArgumentException if the type is TIME or any other type that is not handled here
   */
  private static void assertEqualsUnsafe(Type type, Object expected, Object actual) {
    boolean bothNull = expected == null && actual == null;
    if (bothNull) {
      return;
    }

    switch (type.typeId()) {
      case BOOLEAN:
      case INTEGER:
      case LONG:
      case FLOAT:
      case DOUBLE:
      case DATE:
      case TIMESTAMP:
        Assert.assertEquals("Primitive value should be equal to expected", expected, actual);
        break;

      case STRING: {
        Assert.assertTrue("Should be a UTF8String", actual instanceof UTF8String);
        String actualString = actual.toString();
        Assert.assertEquals("Strings should be equal", expected, actualString);
        break;
      }

      case UUID: {
        Assert.assertTrue("Should expect a UUID", expected instanceof UUID);
        Assert.assertTrue("Should be a UTF8String", actual instanceof UTF8String);
        String expectedUuid = expected.toString();
        Assert.assertEquals("UUID string representation should match", expectedUuid, actual.toString());
        break;
      }

      case FIXED: {
        Assert.assertTrue("Should expect a Fixed", expected instanceof GenericData.Fixed);
        Assert.assertTrue("Should be a byte[]", actual instanceof byte[]);
        byte[] expectedFixed = ((GenericData.Fixed) expected).bytes();
        Assert.assertArrayEquals("Bytes should match", expectedFixed, (byte[]) actual);
        break;
      }

      case BINARY: {
        Assert.assertTrue("Should expect a ByteBuffer", expected instanceof ByteBuffer);
        Assert.assertTrue("Should be a byte[]", actual instanceof byte[]);
        byte[] expectedBinary = ((ByteBuffer) expected).array();
        Assert.assertArrayEquals("Bytes should match", expectedBinary, (byte[]) actual);
        break;
      }

      case DECIMAL: {
        Assert.assertTrue("Should expect a BigDecimal", expected instanceof BigDecimal);
        Assert.assertTrue("Should be a Decimal", actual instanceof Decimal);
        BigDecimal actualDecimal = ((Decimal) actual).toJavaBigDecimal();
        Assert.assertEquals("BigDecimals should be equal", expected, actualDecimal);
        break;
      }

      case STRUCT: {
        Assert.assertTrue("Should expect a Record", expected instanceof Record);
        Assert.assertTrue("Should be an InternalRow", actual instanceof InternalRow);
        Types.StructType structType = type.asNestedType().asStructType();
        assertEqualsUnsafe(structType, (Record) expected, (InternalRow) actual);
        break;
      }

      case LIST: {
        Assert.assertTrue("Should expect a Collection", expected instanceof Collection);
        Assert.assertTrue("Should be an ArrayData", actual instanceof ArrayData);
        Types.ListType listType = type.asNestedType().asListType();
        assertEqualsUnsafe(listType, (Collection) expected, (ArrayData) actual);
        break;
      }

      case MAP: {
        Assert.assertTrue("Should expect a Map", expected instanceof Map);
        Assert.assertTrue("Should be an ArrayBasedMapData", actual instanceof MapData);
        Types.MapType mapType = type.asNestedType().asMapType();
        assertEqualsUnsafe(mapType, (Map) expected, (MapData) actual);
        break;
      }

      case TIME:
      default:
        throw new IllegalArgumentException("Not a supported type: " + type);
    }
  }

  /**
   * Asserts that a Spark {@link Row} holds the same values as an {@link InternalRow}, field by field.
   * If either row is null, both must be null.
   *
   * @param prefix context for error messages
   * @param type the Iceberg struct type of the rows
   * @param expected the expected row, in Spark's internal representation
   * @param actual the actual row
   * @throws IllegalArgumentException if a field has a type that is not handled here
   */
  public static void assertEquals(String prefix, Types.StructType type,
                                  InternalRow expected, Row actual) {
    if (expected == null || actual == null) {
      Assert.assertEquals(prefix, expected, actual);
      return;
    }

    List<Types.NestedField> fields = type.fields();
    for (int pos = 0; pos < fields.size(); pos += 1) {
      Types.NestedField field = fields.get(pos);
      Type childType = field.type();
      String context = prefix + "." + field.name();

      switch (childType.typeId()) {
        case BOOLEAN:
        case INTEGER:
        case LONG:
        case FLOAT:
        case DOUBLE:
        case STRING:
        case DECIMAL:
        case DATE:
        case TIMESTAMP:
          // both sides are converted to plain Java values before comparing
          Assert.assertEquals(context + " - " + childType,
              getValue(expected, pos, childType),
              getPrimitiveValue(actual, pos, childType));
          break;
        case UUID:
        case FIXED:
        case BINARY:
          assertEqualBytes(context,
              (byte[]) getValue(expected, pos, childType),
              (byte[]) actual.get(pos));
          break;
        case STRUCT: {
          Types.StructType nestedType = (Types.StructType) childType;
          InternalRow expectedStruct = expected.getStruct(pos, nestedType.fields().size());
          assertEquals(context, nestedType, expectedStruct, actual.getStruct(pos));
          break;
        }
        case LIST:
          assertEqualsLists(context, childType.asListType(),
              expected.getArray(pos),
              toList((Seq<?>) actual.get(pos)));
          break;
        case MAP:
          assertEqualsMaps(context, childType.asMapType(), expected.getMap(pos),
              toJavaMap((scala.collection.Map<?, ?>) actual.getMap(pos)));
          break;
        default:
          throw new IllegalArgumentException("Unhandled type " + childType);
      }
    }
  }

  /**
   * Asserts that a Java list holds the same elements as an {@link ArrayData}, element by element.
   * If either side is null, both must be null; otherwise the lengths must match.
   *
   * @param prefix context for error messages
   * @param type the Iceberg list type supplying the element type
   * @param expected the expected elements, in Spark's internal representation
   * @param actual the actual elements
   */
  private static void assertEqualsLists(String prefix, Types.ListType type,
                                        ArrayData expected, List actual) {
    if (expected == null || actual == null) {
      Assert.assertEquals(prefix, expected, actual);
      return;
    }

    Assert.assertEquals(prefix + " length", expected.numElements(), actual.size());
    Type childType = type.elementType();

    for (int pos = 0; pos < expected.numElements(); pos += 1) {
      String context = prefix + ".elem " + pos;
      Object actualElement = actual.get(pos);

      switch (childType.typeId()) {
        case BOOLEAN:
        case INTEGER:
        case LONG:
        case FLOAT:
        case DOUBLE:
        case STRING:
        case DECIMAL:
        case DATE:
        case TIMESTAMP:
          Assert.assertEquals(context + " - " + childType,
              getValue(expected, pos, childType),
              actualElement);
          break;
        case UUID:
        case FIXED:
        case BINARY:
          assertEqualBytes(context,
              (byte[]) getValue(expected, pos, childType),
              (byte[]) actualElement);
          break;
        case STRUCT: {
          Types.StructType nestedType = (Types.StructType) childType;
          InternalRow expectedStruct = expected.getStruct(pos, nestedType.fields().size());
          assertEquals(context, nestedType, expectedStruct, (Row) actualElement);
          break;
        }
        case LIST:
          assertEqualsLists(context, childType.asListType(),
              expected.getArray(pos),
              toList((Seq<?>) actualElement));
          break;
        case MAP:
          assertEqualsMaps(context, childType.asMapType(),
              expected.getMap(pos), toJavaMap((scala.collection.Map<?, ?>) actualElement));
          break;
        default:
          throw new IllegalArgumentException("Unhandled type " + childType);
      }
    }
  }

  /**
   * Asserts that a Java map holds the same entries as a {@link MapData}.
   * If either side is null, both must be null; otherwise the sizes must match. Each expected key is
   * converted to a Java value and looked up in the actual map; when the lookup yields null, the expected
   * value slot must be null as well.
   *
   * @param prefix context for error messages
   * @param type the Iceberg map type supplying the key and value types
   * @param expected the expected entries, in Spark's internal representation
   * @param actual the actual entries
   */
  private static void assertEqualsMaps(String prefix, Types.MapType type,
                                       MapData expected, Map<?, ?> actual) {
    if (expected == null || actual == null) {
      Assert.assertEquals(prefix, expected, actual);
      return;
    }

    Type keyType = type.keyType();
    Type valueType = type.valueType();
    ArrayData expectedKeyArray = expected.keyArray();
    ArrayData expectedValueArray = expected.valueArray();
    Assert.assertEquals(prefix + " length", expected.numElements(), actual.size());

    for (int pos = 0; pos < expected.numElements(); pos += 1) {
      Object expectedKey = getValue(expectedKeyArray, pos, keyType);
      Object actualValue = actual.get(expectedKey);
      String context = prefix + ".key=" + expectedKey;

      if (actualValue == null) {
        // nothing found under this key: acceptable only when the expected value is null too
        Assert.assertEquals(context + " has null", true,
            expected.valueArray().isNullAt(pos));
        continue;
      }

      switch (valueType.typeId()) {
        case BOOLEAN:
        case INTEGER:
        case LONG:
        case FLOAT:
        case DOUBLE:
        case STRING:
        case DECIMAL:
        case DATE:
        case TIMESTAMP:
          Assert.assertEquals(context + " - " + valueType,
              getValue(expectedValueArray, pos, valueType),
              actualValue);
          break;
        case UUID:
        case FIXED:
        case BINARY:
          assertEqualBytes(context,
              (byte[]) getValue(expectedValueArray, pos, valueType),
              (byte[]) actualValue);
          break;
        case STRUCT: {
          Types.StructType nestedType = (Types.StructType) valueType;
          InternalRow expectedStruct = expectedValueArray.getStruct(pos, nestedType.fields().size());
          assertEquals(context, nestedType, expectedStruct, (Row) actualValue);
          break;
        }
        case LIST:
          assertEqualsLists(context,
              valueType.asListType(),
              expectedValueArray.getArray(pos),
              toList((Seq<?>) actualValue));
          break;
        case MAP:
          assertEqualsMaps(context, valueType.asMapType(),
              expectedValueArray.getMap(pos),
              toJavaMap((scala.collection.Map<?, ?>) actualValue));
          break;
        default:
          throw new IllegalArgumentException("Unhandled type " + valueType);
      }
    }
  }

  /**
   * Reads the value at the given ordinal from a Spark internal container and converts it to a Java value.
   * Strings become {@code String}, dates and timestamps are converted through {@link DateWritable} and
   * {@link DateTimeUtils}, decimals become {@code BigDecimal}, and structs become a {@link GenericRow}
   * built from their converted fields.
   *
   * @param container the row or array to read from
   * @param ord the position to read
   * @param type the Iceberg type of the value at that position
   * @return the converted value, or null when the position is null
   * @throws IllegalArgumentException if the type is not handled here
   */
  private static Object getValue(SpecializedGetters container, int ord,
                                 Type type) {
    if (container.isNullAt(ord)) {
      return null;
    }

    switch (type.typeId()) {
      case BOOLEAN:
        return container.getBoolean(ord);
      case INTEGER:
        return container.getInt(ord);
      case LONG:
        return container.getLong(ord);
      case FLOAT:
        return container.getFloat(ord);
      case DOUBLE:
        return container.getDouble(ord);
      case STRING:
        return container.getUTF8String(ord).toString();
      case BINARY:
      case FIXED:
      case UUID:
        return container.getBinary(ord);
      case DATE: {
        int days = container.getInt(ord);
        return new DateWritable(days).get();
      }
      case TIMESTAMP: {
        long timestampValue = container.getLong(ord);
        return DateTimeUtils.toJavaTimestamp(timestampValue);
      }
      case DECIMAL: {
        Types.DecimalType decimalType = (Types.DecimalType) type;
        Decimal decimal = container.getDecimal(ord, decimalType.precision(), decimalType.scale());
        return decimal.toJavaBigDecimal();
      }
      case STRUCT: {
        List<Types.NestedField> structFields = type.asStructType().fields();
        InternalRow nested = container.getStruct(ord, structFields.size());
        Object[] data = new Object[structFields.size()];
        for (int i = 0; i < data.length; i += 1) {
          // the recursive call already yields null for null slots
          data[i] = getValue(nested, i, structFields.get(i).type());
        }
        return new GenericRow(data);
      }
      default:
        throw new IllegalArgumentException("Unhandled type " + type);
    }
  }

  /**
   * Reads a non-nested value from a Spark {@link Row} using the typed getter that matches the Iceberg type.
   *
   * @param row the row to read from
   * @param ord the position to read
   * @param type the Iceberg type of the value at that position
   * @return the value, or null when the position is null
   * @throws IllegalArgumentException if the type is not handled here
   */
  private static Object getPrimitiveValue(Row row, int ord, Type type) {
    if (row.isNullAt(ord)) {
      return null;
    }

    Object value;
    switch (type.typeId()) {
      case BOOLEAN:
        value = row.getBoolean(ord);
        break;
      case INTEGER:
        value = row.getInt(ord);
        break;
      case LONG:
        value = row.getLong(ord);
        break;
      case FLOAT:
        value = row.getFloat(ord);
        break;
      case DOUBLE:
        value = row.getDouble(ord);
        break;
      case STRING:
        value = row.getString(ord);
        break;
      case BINARY:
      case FIXED:
      case UUID:
        // no typed getter is used for byte values; return the raw object
        value = row.get(ord);
        break;
      case DATE:
        value = row.getDate(ord);
        break;
      case TIMESTAMP:
        value = row.getTimestamp(ord);
        break;
      case DECIMAL:
        value = row.getDecimal(ord);
        break;
      default:
        throw new IllegalArgumentException("Unhandled type " + type);
    }
    return value;
  }

  /**
   * Converts a Scala map to a Java map using the Scala-to-Java converter; a null input yields null.
   */
  private static <K, V> Map<K, V> toJavaMap(scala.collection.Map<K, V> map) {
    if (map == null) {
      return null;
    }
    return mapAsJavaMapConverter(map).asJava();
  }

  /**
   * Converts a Scala sequence to a Java list using the Scala-to-Java converter; a null input yields null.
   */
  private static List toList(Seq<?> val) {
    if (val == null) {
      return null;
    }
    return seqAsJavaListConverter(val).asJava();
  }

  /**
   * Asserts that two byte arrays have equal contents. When either array is null, the arrays are compared
   * with a plain equality assertion instead, so the check passes only if both are null.
   *
   * @param context message used when the assertion fails
   * @param expected the expected bytes, may be null
   * @param actual the actual bytes, may be null
   */
  private static void assertEqualBytes(String context, byte[] expected,
                                       byte[] actual) {
    if (expected != null && actual != null) {
      Assert.assertArrayEquals(context, expected, actual);
    } else {
      Assert.assertEquals(context, expected, actual);
    }
  }

  /**
   * Asserts that two values are equal according to the Spark type converted from the given Iceberg schema.
   *
   * @param schema the Iceberg schema describing the values
   * @param expected the expected value
   * @param actual the actual value
   */
  static void assertEquals(Schema schema, Object expected, Object actual) {
    DataType sparkType = convert(schema);
    assertEquals("schema", sparkType, expected, actual);
  }

  /**
   * Asserts that two values in Spark's internal representation are equal, dispatching on the Spark type.
   * Struct, array and map types are compared recursively, binary values are compared as byte arrays, and
   * every other type falls back to a plain equality assertion.
   *
   * @param context path used in failure messages
   * @param type the Spark type of both values
   * @param expected the expected value, may be null
   * @param actual the actual value, may be null
   */
  private static void assertEquals(String context, DataType type, Object expected, Object actual) {
    if (expected == null && actual == null) {
      return;
    }

    if (type instanceof StructType) {
      Assert.assertTrue("Expected should be an InternalRow: " + context,
          expected instanceof InternalRow);
      Assert.assertTrue("Actual should be an InternalRow: " + context,
          actual instanceof InternalRow);
      StructType structType = (StructType) type;
      assertEquals(context, structType, (InternalRow) expected, (InternalRow) actual);
      return;
    }

    if (type instanceof ArrayType) {
      Assert.assertTrue("Expected should be an ArrayData: " + context,
          expected instanceof ArrayData);
      Assert.assertTrue("Actual should be an ArrayData: " + context,
          actual instanceof ArrayData);
      ArrayType arrayType = (ArrayType) type;
      assertEquals(context, arrayType, (ArrayData) expected, (ArrayData) actual);
      return;
    }

    if (type instanceof MapType) {
      Assert.assertTrue("Expected should be a MapData: " + context,
          expected instanceof MapData);
      Assert.assertTrue("Actual should be a MapData: " + context,
          actual instanceof MapData);
      MapType mapType = (MapType) type;
      assertEquals(context, mapType, (MapData) expected, (MapData) actual);
      return;
    }

    if (type instanceof BinaryType) {
      assertEqualBytes(context, (byte[]) expected, (byte[]) actual);
      return;
    }

    Assert.assertEquals("Value should match expected: " + context, expected, actual);
  }

  /**
   * Asserts that two internal rows are equal field by field. The actual row must have as many fields as
   * the Spark struct type declares.
   *
   * @param context path used in failure messages; each field name is appended to it
   * @param struct the Spark struct type of both rows
   * @param expected the expected row
   * @param actual the actual row
   */
  private static void assertEquals(String context, StructType struct,
                           InternalRow expected, InternalRow actual) {
    int fieldCount = actual.numFields();
    Assert.assertEquals("Should have correct number of fields", struct.size(), fieldCount);

    StructField[] structFields = struct.fields();
    for (int pos = 0; pos < fieldCount; pos += 1) {
      StructField field = structFields[pos];
      DataType fieldType = field.dataType();
      Object expectedValue = expected.get(pos, fieldType);
      Object actualValue = actual.get(pos, fieldType);
      assertEquals(context + "." + field.name(), fieldType, expectedValue, actualValue);
    }
  }

  /**
   * Asserts that two internal arrays have the same length and equal elements at every index.
   *
   * @param context path used in failure messages
   * @param array the Spark array type supplying the element type
   * @param expected the expected array
   * @param actual the actual array
   */
  private static void assertEquals(String context, ArrayType array, ArrayData expected, ArrayData actual) {
    int actualCount = actual.numElements();
    Assert.assertEquals("Should have the same number of elements",
        expected.numElements(), actualCount);

    DataType elementType = array.elementType();
    String elementContext = context + ".element";
    for (int pos = 0; pos < actualCount; pos += 1) {
      Object expectedElement = expected.get(pos, elementType);
      Object actualElement = actual.get(pos, elementType);
      assertEquals(elementContext, elementType, expectedElement, actualElement);
    }
  }

  /**
   * Asserts that two internal maps have the same number of entries and that their keys and values are
   * equal position by position in the key and value arrays.
   *
   * @param context path used in failure messages
   * @param map the Spark map type supplying the key and value types
   * @param expected the expected map
   * @param actual the actual map
   */
  private static void assertEquals(String context, MapType map, MapData expected, MapData actual) {
    int actualCount = actual.numElements();
    Assert.assertEquals("Should have the same number of elements",
        expected.numElements(), actualCount);

    DataType keyType = map.keyType();
    DataType valueType = map.valueType();

    ArrayData expectedKeys = expected.keyArray();
    ArrayData actualKeys = actual.keyArray();
    ArrayData expectedValues = expected.valueArray();
    ArrayData actualValues = actual.valueArray();

    String keyContext = context + ".key";
    String valueContext = context + ".value";
    for (int pos = 0; pos < actualCount; pos += 1) {
      Object expectedKey = expectedKeys.get(pos, keyType);
      Object actualKey = actualKeys.get(pos, keyType);
      assertEquals(keyContext, keyType, expectedKey, actualKey);

      Object expectedValue = expectedValues.get(pos, valueType);
      Object actualValue = actualValues.get(pos, valueType);
      assertEquals(valueContext, valueType, expectedValue, actualValue);
    }
  }
}