"""Registry for the TF Encrypted Converter."""
import array
import logging
import os
from collections import OrderedDict
from typing import Any
from typing import List

import numpy as np
import tensorflow as tf
import yaml

import tf_encrypted as tfe

from ..keras.layers import BatchNormalization
from ..keras.layers import DepthwiseConv2D
from ..keras.layers import GlobalAveragePooling2D
from ..keras.layers import GlobalMaxPooling2D
from ..layers import AveragePooling2D
from ..layers import Conv2D
from ..layers import Dense
from ..layers import MaxPooling2D
from ..layers import Relu
from ..layers import Sigmoid
from ..protocol.pond import PondMaskedTensor
from ..protocol.pond import PondPrivateTensor


def registry():
    """Return the lookup table used by the converter.

    Keys are either TensorFlow op type names (e.g. "Conv2D", "MatMul") or
    lower-case scope names such as "conv2d", "dense" and "flatten". Values
    are the module-private conversion functions below. A new dict is built
    on every call.
    """
    return {
        "Placeholder": _placeholder,
        "Const": _constant,
        "Conv2D": _conv2d,
        "Relu": _relu,
        "Sigmoid": _sigmoid,
        "MatMul": _matmul,
        "Shape": _shape,
        "StridedSlice": _strided_slice,
        "Add": _add,
        "AddV2": _add,
        "Sub": _sub,
        "Transpose": _transpose,
        "Reshape": _reshape,
        "Pack": _pack,
        "Rsqrt": _rsqrt,
        "Mul": _mul,
        "ExpandDims": _expand_dims,
        "AvgPool": _avgpool,
        "Squeeze": _squeeze,
        "ConcatV2": _concat,
        "BiasAdd": _bias_add,
        "MaxPool": _maxpool,
        "Pad": _pad,
        "BatchToSpaceND": _batch_to_space_nd,
        "SpaceToBatchND": _space_to_batch_nd,
        "ArgMax": _argmax,
        "required_space_to_batch_paddings": _required_space_to_batch_paddings,
        "flatten": _flatten,
        "conv2d": _keras_conv2d,
        "Slice": _slice,
        "Neg": _negative,
        "Split": _split,
        "SplitV": _split,
        "Identity": _identity,
        "GatherV2": _gather,
        "dense": _keras_dense,
        "batch_normalization_v1": _keras_batchnorm,
        "depthwise_conv2d": _keras_depthwise_conv2d,
        "Mean": _keras_global_avgpool,
        "Max": _keras_global_maxpool,
    }


convert_dir = os.path.dirname(os.path.abspath(__file__))
specops_path = os.path.join(convert_dir, "specops.yaml")
# Load the special-op spec that ships next to this module. YAML is UTF-8 by
# spec, so decode explicitly instead of relying on the platform locale.
# safe_load only builds plain Python objects (no arbitrary constructors).
with open(specops_path, "r", encoding="utf-8") as stream:
    loaded_yaml = yaml.safe_load(stream)
    # An empty YAML document loads as None; treat it as "no special ops".
    if loaded_yaml is None:
        loaded_yaml = {}
    sorted_yaml = sorted(loaded_yaml.items(), key=lambda kv: kv[0])
    # Entries are ordered by spec-op name for deterministic iteration.
    REGISTERED_SPECOPS = OrderedDict(sorted_yaml)


# pylint: disable=unused-argument
# pylint: disable=missing-docstring
def _placeholder(converter, node: Any, inputs: List[str]) -> Any:
    """Recreate a ``Placeholder`` node as a plain TensorFlow placeholder.

    The dtype and shape are copied from the NodeDef's attributes.
    """
    placeholder_dtype = node.attr["dtype"].type
    placeholder_shape = node.attr["shape"].shape
    return tf.placeholder(placeholder_dtype, shape=placeholder_shape)


def _constant(converter, node: Any, inputs: List[str]) -> Any:
    # need to able to access the underlying weights return the node
    return node


def _identity(converter, node: Any, inputs: List[str]) -> Any:
    # need to able to access the underlying weights return the node
    return converter.outputs[inputs[0]]


def _matmul(converter, node: Any, inputs: List[str]) -> Any:
    """Convert a ``MatMul`` node into a TFE ``Dense`` layer.

    The second operand is expected to be a Const NodeDef holding the weights.
    They are decoded from ``tensor_content`` and supplied as a private input
    by ``converter.model_provider``.

    Raises:
        TypeError: if the weights are neither float32 nor float64.
    """
    a = converter.outputs[inputs[0]]
    b = converter.outputs[inputs[1]]

    tensor = b.attr["value"].tensor

    b_shape = [i.size for i in tensor.tensor_shape.dim]

    transpose_a = node.attr["transpose_a"].b
    transpose_b = node.attr["transpose_b"].b

    # MatMul with transpose_b computes a @ b^T, so the stored weights are
    # laid out (out, in) and the output width is dim 0, not dim 1.
    out_features = b_shape[0] if transpose_b else b_shape[1]

    layer = Dense(
        a.shape.as_list(),
        out_features,
        transpose_input=transpose_a,
        transpose_weight=transpose_b,
    )

    dtype = tensor.dtype

    if dtype == tf.float32:
        nums = array.array("f", tensor.tensor_content)
    elif dtype == tf.float64:
        nums = array.array("d", tensor.tensor_content)
    else:
        raise TypeError("Unsupported dtype for weights")

    def inputter_fn():
        return tf.constant(np.array(nums).reshape(b_shape))

    w = tfe.define_private_input(converter.model_provider, inputter_fn)

    layer.initialize(initial_weights=w)

    return layer.forward(a)


def _conv2d(converter, node, inputs):
    """Convert a raw ``Conv2D`` node into a TFE ``Conv2D`` layer.

    The kernel may be a Const NodeDef, which is turned into a private input,
    or a tensor already produced by an upstream converter. The layer takes a
    single integer stride: the largest entry of the node's ``strides``.
    """
    x_in = converter.outputs[inputs[0]]
    kernel = converter.outputs[inputs[1]]

    if not isinstance(kernel, tf.NodeDef):
        kernel_shape = kernel.shape.as_list()
        weights = kernel
    else:
        kernel_dims = kernel.attr["value"].tensor.tensor_shape.dim
        kernel_shape = [dim.size for dim in kernel_dims]
        weights = _nodef_to_private_pond(converter, kernel)

    data_format = node.attr["data_format"].s.decode("ascii")
    stride = int(max(node.attr["strides"].list.i))
    padding_mode = node.attr["padding"].s.decode("ascii")

    conv_layer = Conv2D(
        x_in.shape.as_list(),
        kernel_shape,
        strides=stride,
        padding=padding_mode,
        channels_first=(data_format == "NCHW"),
    )
    conv_layer.initialize(initial_weights=weights)
    return conv_layer.forward(x_in)


def _keras_conv2d(converter, interiors, inputs):
    """Convert a Keras ``conv2d`` scope into a TFE ``Conv2D`` layer.

    ``interiors`` maps interior names to NodeDefs: "Conv2D" (the op) and
    "kernel" are required, "bias" is optional. The kernel and bias become
    private inputs of ``converter.model_provider``.
    """
    x_in = converter.outputs[inputs[0]]

    conv_op = interiors["Conv2D"]
    kernel = interiors["kernel"]
    k = _nodef_to_private_pond(converter, kernel)

    # Only a missing "bias" entry means "no bias". Errors raised while
    # converting an existing bias must not be mistaken for that case.
    try:
        bias = interiors["bias"]
    except KeyError:
        b = None
    else:
        b = _nodef_to_private_pond(converter, bias)
        # Expanding axes 0, -1, -1 turns a (C,) bias into (1, C, 1, 1),
        # assuming a 1-D bias -- presumably to match the layer's
        # channels-first internal layout; TODO confirm against Conv2D.
        for ax in [0, -1, -1]:
            b = b.expand_dims(axis=ax)

    input_shape = x_in.shape.as_list()
    shape = [i.size for i in kernel.attr["value"].tensor.tensor_shape.dim]
    fmt = conv_op.attr["data_format"].s.decode("ascii")
    # The layer takes a single stride: the largest per-dimension entry.
    strides = int(max(conv_op.attr["strides"].list.i))
    padding = conv_op.attr["padding"].s.decode("ascii")

    layer = Conv2D(
        input_shape,
        shape,
        strides=strides,
        padding=padding,
        channels_first=fmt == "NCHW",
    )

    layer.initialize(initial_weights=k, initial_bias=b)
    out = layer.forward(x_in)

    return out


def _keras_depthwise_conv2d(converter, interiors, inputs):
    """Convert a Keras ``depthwise_conv2d`` scope into a TFE DepthwiseConv2D.

    ``interiors`` holds the scope's NodeDefs: "depthwise" (the op),
    "depthwise_kernel" and an optional "bias". The weights are decoded to
    NumPy arrays and injected through Keras ``Constant`` initializers.
    """
    x_in = converter.outputs[inputs[0]]

    conv_op = interiors["depthwise"]

    kernel = interiors["depthwise_kernel"]
    k = _nodef_to_numpy_array(kernel)
    kernel_init = tf.keras.initializers.Constant(k)

    # Only a missing "bias" entry means "no bias"; keep decoding errors visible.
    try:
        bias = interiors["bias"]
    except KeyError:
        use_bias = False
        bias_init = "zeros"
    else:
        b = _nodef_to_numpy_array(bias)
        bias_init = tf.keras.initializers.Constant(b)
        use_bias = True

    # Depthwise kernels are laid out (kh, kw, in_channels, depth_multiplier).
    shape = [i.size for i in kernel.attr["value"].tensor.tensor_shape.dim]

    fmt = conv_op.attr["data_format"].s.decode("ascii")
    fmt = "channels_last" if fmt == "NHWC" else "channels_first"

    # The layer takes a single stride: the largest per-dimension entry.
    strides = int(max(conv_op.attr["strides"].list.i))
    padding = conv_op.attr["padding"].s.decode("ascii")

    layer = DepthwiseConv2D(
        kernel_size=(shape[0], shape[1]),
        strides=strides,
        padding=padding,
        # Read the multiplier from the kernel; hard-coding 1 mismatches any
        # kernel whose last dimension is larger than 1.
        depth_multiplier=shape[3],
        data_format=fmt,
        use_bias=use_bias,
        depthwise_initializer=kernel_init,
        bias_initializer=bias_init,
    )

    return layer(x_in)


def _keras_dense(converter, interiors, inputs):
    """Convert a Keras ``dense`` scope into a TFE ``Dense`` layer.

    ``interiors`` maps interior names to NodeDefs: "kernel" is required and
    "bias" is optional. Both become private inputs of
    ``converter.model_provider``.
    """
    x_in = converter.outputs[inputs[0]]

    kernel = interiors["kernel"]
    k = _nodef_to_private_pond(converter, kernel)

    # Only a missing "bias" entry means "no bias"; keep conversion errors visible.
    try:
        bias = interiors["bias"]
    except KeyError:
        b = None
    else:
        b = _nodef_to_private_pond(converter, bias)

    input_shape = x_in.shape.as_list()
    shape = [i.size for i in kernel.attr["value"].tensor.tensor_shape.dim]

    # The output width is taken from kernel dim 1 (kernel assumed (in, out)).
    layer = Dense(input_shape, out_features=shape[1])

    layer.initialize(initial_weights=k, initial_bias=b)
    out = layer.forward(x_in)

    return out


def _keras_batchnorm(converter, interiors, inputs):
    """Convert a Keras batch-normalization scope into a TFE BatchNormalization.

    gamma, beta, moving_mean and moving_variance are read from ``interiors``
    as NumPy arrays and passed in as Keras ``Constant`` initializers. The
    normalized axis is 3 for NHWC data and 1 otherwise.
    """
    x_in = converter.outputs[inputs[0]]

    fused_op = interiors["FusedBatchNorm"]
    channels_last = fused_op.attr["data_format"].s.decode("ascii") == "NHWC"

    initializers = {}
    for param_name in ("gamma", "beta", "moving_mean", "moving_variance"):
        param_value = _nodef_to_numpy_array(interiors[param_name])
        initializers[param_name + "_initializer"] = tf.keras.initializers.Constant(
            param_value
        )

    norm_layer = BatchNormalization(
        input_shape=x_in.shape.as_list(),
        axis=3 if channels_last else 1,
        **initializers
    )
    return norm_layer(x_in)


def _relu(converter, node: Any, inputs: List[str]) -> Any:
    """Apply a TFE ``Relu`` layer, sized to the input, to the first input."""
    operand = converter.outputs[inputs[0]]
    activation = Relu(operand.shape.as_list())
    return activation.forward(operand)


def _sigmoid(converter, node: Any, inputs: List[str]) -> Any:
    """Apply a TFE ``Sigmoid`` layer, sized to the input, to the first input."""
    operand = converter.outputs[inputs[0]]
    activation = Sigmoid(operand.shape.as_list())
    return activation.forward(operand)


def _strided_slice(converter, node: Any, inputs: List[str]) -> Any:
    """Convert a ``StridedSlice`` node into ``tfe.strided_slice``.

    A Const NodeDef operand is first turned into a private input. begin, end
    and strides come from Const NodeDefs and are rebuilt as TF constants from
    their stored TensorProtos. The five mask attributes are forwarded as-is.
    """
    source = converter.outputs[inputs[0]]
    sliced = (
        _nodef_to_private_pond(converter, source)
        if isinstance(source, tf.NodeDef)
        else source
    )

    begin_node = converter.outputs[inputs[1]]
    end_node = converter.outputs[inputs[2]]
    strides_node = converter.outputs[inputs[3]]

    mask_kwargs = {
        mask_name: node.attr[mask_name].i
        for mask_name in (
            "begin_mask",
            "end_mask",
            "ellipsis_mask",
            "new_axis_mask",
            "shrink_axis_mask",
        )
    }

    begin_const = tf.constant(begin_node.attr["value"].tensor)
    end_const = tf.constant(end_node.attr["value"].tensor)
    strides_const = tf.constant(strides_node.attr["value"].tensor)

    return tfe.strided_slice(
        sliced, begin_const, end_const, strides=strides_const, **mask_kwargs
    )


def _pack(converter, node: Any, inputs: List[str]) -> Any:
    """Convert a ``Pack`` node into ``tfe.stack`` along the node's ``axis``.

    Const NodeDef operands are turned into private inputs before stacking.
    """

    def as_tfe(input_name):
        value = converter.outputs[input_name]
        if isinstance(value, tf.NodeDef):
            return _nodef_to_private_pond(converter, value)
        return value

    tensors = [as_tfe(input_name) for input_name in inputs]
    return tfe.stack(tensors, axis=node.attr["axis"].i)


def _bias_add(converter, node: Any, inputs: List[str]) -> Any:
    """Convert a ``BiasAdd`` node into ``tfe.add``.

    Either operand may be a Const NodeDef, which is turned into a private
    input first.
    """
    lhs = converter.outputs[inputs[0]]
    rhs = converter.outputs[inputs[1]]

    lhs = _nodef_to_private_pond(converter, lhs) if isinstance(lhs, tf.NodeDef) else lhs
    rhs = _nodef_to_private_pond(converter, rhs) if isinstance(rhs, tf.NodeDef) else rhs

    return tfe.add(lhs, rhs)


def _maxpool(converter, node: Any, inputs: List[str]) -> Any:
    """Convert a ``MaxPool`` node into a TFE ``MaxPooling2D`` layer.

    The node's ``ksize`` and ``strides`` are per-dimension lists laid out like
    its ``data_format``. Only their spatial (height, width) entries are passed
    to the layer.
    """
    x_in = converter.outputs[inputs[0]]

    ksize = node.attr["ksize"].list.i
    s = node.attr["strides"].list.i

    padding = node.attr["padding"].s.decode("ascii")
    channels_first = node.attr["data_format"].s.decode("ascii") == "NCHW"

    # Spatial entries sit at positions 1-2 for NHWC and 2-3 for NCHW. Always
    # reading 1-2 would pick up the channel entry for NCHW graphs.
    h_axis, w_axis = (2, 3) if channels_first else (1, 2)
    pool_size = [ksize[h_axis], ksize[w_axis]]
    strides = [s[h_axis], s[w_axis]]

    shape = [int(i) for i in x_in.shape]

    pooler = MaxPooling2D(shape, pool_size, strides, padding, channels_first)

    out = pooler.forward(x_in)

    return out


def _shape(converter, node: Any, inputs: List[str]) -> Any:
    x_in = converter.outputs[inputs[0]]

    return x_in.shape


def _reshape(converter, node: Any, inputs: List[str]) -> Any:
    """Convert a ``Reshape`` node whose target shape is a Const NodeDef.

    Raises:
        TypeError: if the shape constant is neither int32 nor int64.
    """
    x_in = converter.outputs[inputs[0]]
    shape = converter.outputs[inputs[1]]

    tensor = shape.attr["value"].tensor
    dtype = shape.attr["dtype"].type
    # "q" is always 8 bytes. "l" is only 4 bytes on some platforms (e.g.
    # Windows) and would mis-decode int64 content.
    if dtype == tf.int32:
        type_code, typed_vals = "i", tensor.int_val
    elif dtype == tf.int64:
        type_code, typed_vals = "q", tensor.int64_val
    else:
        raise TypeError("Unsupported dtype for reshape shape")

    # Single-element constants (e.g. the shape [-1]) are stored in the typed
    # repeated field, not in tensor_content.
    content = tensor.tensor_content
    nums = array.array(type_code, content if content else typed_vals)

    return tfe.reshape(x_in, list(nums))


def _transpose(converter, node: Any, inputs: List[str]) -> Any:
    """Convert a ``Transpose`` node whose permutation is a Const NodeDef.

    Raises:
        TypeError: if the permutation is neither int32 nor int64.
    """
    x_in = converter.outputs[inputs[0]]
    perm = converter.outputs[inputs[1]]

    tensor = perm.attr["value"].tensor
    shape = [i.size for i in tensor.tensor_shape.dim]

    dtype = perm.attr["dtype"].type
    # "q" is always 8 bytes; "l" is only 4 bytes on some platforms.
    if dtype == tf.int32:
        type_code, typed_vals = "i", tensor.int_val
    elif dtype == tf.int64:
        type_code, typed_vals = "q", tensor.int64_val
    else:
        raise TypeError("Unsupported dtype for transpose perm")

    # Single-element constants are stored in the typed field instead.
    content = tensor.tensor_content
    nums = array.array(type_code, content if content else typed_vals)

    return tfe.transpose(x_in, np.array(nums).reshape(shape))


def _expand_dims(converter, node: Any, inputs: List[str]) -> Any:
    """Convert an ``ExpandDims`` node with a constant axis.

    A Const NodeDef operand becomes a private input. The axis is the first
    entry of the axis constant's ``int_val``.
    """
    source = converter.outputs[inputs[0]]
    operand = (
        _nodef_to_private_pond(converter, source)
        if isinstance(source, tf.NodeDef)
        else source
    )

    axis_values = converter.outputs[inputs[1]].attr["value"].tensor.int_val
    return tfe.expand_dims(operand, int(axis_values[0]))


def _negative(converter, node: Any, inputs: List[str]) -> Any:
    """Convert a ``Neg`` node into ``tfe.negative``.

    A Const NodeDef operand is turned into a private input first.
    """
    source = converter.outputs[inputs[0]]
    if not isinstance(source, tf.NodeDef):
        return tfe.negative(source)
    return tfe.negative(_nodef_to_private_pond(converter, source))


def _gather(converter, node: Any, inputs: List[str]) -> Any:
    """Convert a ``GatherV2`` node with constant indices and axis.

    A Const NodeDef ``params`` operand becomes a private input. The indices
    are decoded to a Python list, and the axis is taken from ``int_val[0]``.
    """
    params = converter.outputs[inputs[0]]
    indices_node = converter.outputs[inputs[1]]
    axis_node = converter.outputs[inputs[2]]

    if not isinstance(params, tf.NodeDef):
        gather_source = params
    else:
        gather_source = _nodef_to_private_pond(converter, params)

    index_list = list(_nodef_to_numpy_array(indices_node))
    gather_axis = axis_node.attr["value"].tensor.int_val[0]

    return tfe.gather(gather_source, index_list, gather_axis)


def _squeeze(converter, node: Any, inputs: List[str]) -> Any:
    """Convert a ``Squeeze`` node, removing the axes in its ``squeeze_dims``."""
    operand = converter.outputs[inputs[0]]
    squeeze_axes = list(node.attr["squeeze_dims"].list.i)
    return tfe.squeeze(operand, squeeze_axes)


def _split(converter, node: Any, inputs: List[str]) -> Any:
    """Convert a ``Split`` or ``SplitV`` node into ``tfe.split``.

    ``SplitV`` carries explicit (possibly -1 = inferred) split sizes as a
    Const NodeDef in ``(value, size_splits, axis)`` order. ``Split`` uses
    ``(axis, value)`` order and a ``num_split`` attribute. A Const NodeDef
    value is turned into a private input.
    """
    if node.op == "SplitV":
        # node.op is SplitV when num_or_size_splits is a list
        x_in = converter.outputs[inputs[0]]
        size_splits = converter.outputs[inputs[1]]
        axis = converter.outputs[inputs[2]]

        tensor = size_splits.attr["value"].tensor
        # Decode as *signed* ints so an inferred size of -1 survives, and use
        # the 8-byte code for int64 split sizes.
        if size_splits.attr["dtype"].type == tf.int64:
            type_code, typed_vals = "q", tensor.int64_val
        else:
            type_code, typed_vals = "i", tensor.int_val
        content = tensor.tensor_content
        num_or_size_splits = list(
            array.array(type_code, content if content else typed_vals)
        )

    else:
        # node.op is Split when num_or_size_splits is an integer
        axis = converter.outputs[inputs[0]]
        x_in = converter.outputs[inputs[1]]

        num_or_size_splits = node.attr["num_split"].i

    if isinstance(x_in, tf.NodeDef):
        input_out = _nodef_to_private_pond(converter, x_in)
    else:
        input_out = x_in

    axis_val = axis.attr["value"].tensor.int_val[0]

    return tfe.split(input_out, num_or_size_splits, axis_val)


def _pad(converter, node: Any, inputs: List[str]) -> Any:
    """Convert a ``Pad`` node whose paddings are a Const NodeDef.

    The flat paddings are regrouped into ``[[before, after], ...]`` pairs, one
    per input dimension.
    """
    x_in = converter.outputs[inputs[0]]
    p = converter.outputs[inputs[1]]

    paddings_t = p.attr["value"].tensor
    # Paddings may be int32 or int64; decode with the matching width instead
    # of assuming 4-byte ints, and fall back to the typed field when the
    # values were not packed into tensor_content.
    if p.attr["dtype"].type == tf.int64:
        type_code, typed_vals = "q", paddings_t.int64_val
    else:
        type_code, typed_vals = "i", paddings_t.int_val
    content = paddings_t.tensor_content

    paddings_arr = list(array.array(type_code, content if content else typed_vals))
    paddings_lst = [paddings_arr[i : i + 2] for i in range(0, len(paddings_arr), 2)]

    return tfe.pad(x_in, paddings_lst)


def _rsqrt(converter, node: Any, inputs: List[str]) -> Any:
    """Convert an ``Rsqrt`` node into a public input holding ``1 / sqrt(x)``.

    For a Const NodeDef operand the result is computed with NumPy. Otherwise
    the operand's ``value_on_0`` is decoded and ``tf.rsqrt`` is applied to it.
    In both cases the result is defined as a public input by
    ``converter.model_provider``.

    Raises:
        TypeError: if a constant operand is neither float32 nor float64.
    """
    x_in = converter.outputs[inputs[0]]

    if isinstance(x_in, tf.NodeDef):
        tensor = x_in.attr["value"].tensor
        shape = [i.size for i in tensor.tensor_shape.dim]

        dtype = x_in.attr["dtype"].type
        if dtype == tf.float32:
            type_code, typed_vals = "f", tensor.float_val
        elif dtype == tf.float64:
            type_code, typed_vals = "d", tensor.double_val
        else:
            raise TypeError("Unsupported dtype for rsqrt")

        # Scalars and single-element constants are stored in the typed field
        # rather than in tensor_content.
        content = tensor.tensor_content
        nums = array.array(type_code, content if content else typed_vals)

        def inputter_fn():
            return tf.constant(1 / np.sqrt(np.array(nums).reshape(shape)))

    else:
        # XXX this is a little weird but the input into rsqrt is public and
        # being used only for batchnorm at the moment
        prot = tfe.get_protocol()
        # pylint: disable=protected-access
        decoded = prot._decode(x_in.value_on_0, True)

        # pylint: enable=protected-access

        def inputter_fn():
            return tf.rsqrt(decoded)

    x = tfe.define_public_input(converter.model_provider, inputter_fn)

    return x


def _add(converter, node: Any, inputs: List[str]) -> Any:
    """Convert an ``Add``/``AddV2`` node into ``tfe.add``.

    Const NodeDef operands are turned into *public* inputs first.
    """
    lhs = converter.outputs[inputs[0]]
    rhs = converter.outputs[inputs[1]]

    lhs = _nodef_to_public_pond(converter, lhs) if isinstance(lhs, tf.NodeDef) else lhs
    rhs = _nodef_to_public_pond(converter, rhs) if isinstance(rhs, tf.NodeDef) else rhs

    return tfe.add(lhs, rhs)


def _sub(converter, node: Any, inputs: List[str]) -> Any:
    """Convert a ``Sub`` node into ``tfe.sub`` (first input minus second).

    Const NodeDef operands are turned into *public* inputs first.
    """
    minuend = converter.outputs[inputs[0]]
    subtrahend = converter.outputs[inputs[1]]

    if isinstance(minuend, tf.NodeDef):
        minuend = _nodef_to_public_pond(converter, minuend)
    if isinstance(subtrahend, tf.NodeDef):
        subtrahend = _nodef_to_public_pond(converter, subtrahend)

    return tfe.sub(minuend, subtrahend)


def _mul(converter, node: Any, inputs: List[str]) -> Any:
    """Convert a ``Mul`` node into ``tfe.mul``.

    Const NodeDef operands are turned into *public* inputs first.
    """
    lhs = converter.outputs[inputs[0]]
    rhs = converter.outputs[inputs[1]]

    if isinstance(lhs, tf.NodeDef):
        lhs = _nodef_to_public_pond(converter, lhs)
    if isinstance(rhs, tf.NodeDef):
        rhs = _nodef_to_public_pond(converter, rhs)

    return tfe.mul(lhs, rhs)


def _avgpool(converter, node: Any, inputs: List[str]) -> Any:
    """Convert an ``AvgPool`` node into a TFE ``AveragePooling2D`` layer.

    The node's ``ksize`` and ``strides`` are per-dimension lists laid out like
    its ``data_format``. Only their spatial (height, width) entries are passed
    to the layer.
    """
    x_in = converter.outputs[inputs[0]]

    ksize = node.attr["ksize"].list.i
    s = node.attr["strides"].list.i

    padding = node.attr["padding"].s.decode("ascii")
    channels_first = node.attr["data_format"].s.decode("ascii") == "NCHW"

    # Spatial entries sit at positions 1-2 for NHWC and 2-3 for NCHW. Always
    # reading 1-2 would pick up the channel entry for NCHW graphs.
    h_axis, w_axis = (2, 3) if channels_first else (1, 2)
    pool_size = [ksize[h_axis], ksize[w_axis]]
    strides = [s[h_axis], s[w_axis]]

    shape = [int(i) for i in x_in.shape]

    avg = AveragePooling2D(shape, pool_size, strides, padding, channels_first)

    out = avg.forward(x_in)

    return out


def _keras_global_avgpool(converter, node: Any, inputs: List[str]) -> Any:
    """Convert a ``Mean`` reduction into a TFE ``GlobalAveragePooling2D``.

    Reducing over axes ``[1, 2]`` is treated as channels-last pooling; any
    other set of reduction indices is treated as channels-first.
    """
    operand = converter.outputs[inputs[0]]

    indices_node = converter.outputs[inputs[1]]
    raw_indices = indices_node.attr["value"].tensor.tensor_content
    reduced_axes = list(array.array("i", raw_indices))

    data_format = "channels_last" if reduced_axes == [1, 2] else "channels_first"
    return GlobalAveragePooling2D(data_format=data_format)(operand)


def _keras_global_maxpool(converter, node: Any, inputs: List[str]) -> Any:
    """Convert a ``Max`` reduction into a TFE ``GlobalMaxPooling2D``.

    Reducing over axes ``[1, 2]`` is treated as channels-last pooling; any
    other set of reduction indices is treated as channels-first.
    """
    operand = converter.outputs[inputs[0]]

    indices_node = converter.outputs[inputs[1]]
    raw_indices = indices_node.attr["value"].tensor.tensor_content
    reduced_axes = list(array.array("i", raw_indices))

    data_format = "channels_last" if reduced_axes == [1, 2] else "channels_first"
    return GlobalMaxPooling2D(data_format=data_format)(operand)


def _concat(converter, node: Any, inputs: List[str]) -> Any:
    """Convert a ``ConcatV2`` node into ``tfe.concat``.

    All inputs except the last are the values to join. The last input is a
    Const NodeDef holding the axis. As in ``_pack``, raw Const NodeDef values
    are turned into private inputs so ``tfe.concat`` only sees TFE tensors.
    """
    input_list = []
    for name in inputs[:-1]:
        value = converter.outputs[name]
        if isinstance(value, tf.NodeDef):
            value = _nodef_to_private_pond(converter, value)
        input_list.append(value)

    axis = converter.outputs[inputs[-1]]
    axis_int = axis.attr["value"].tensor.int_val[0]

    return tfe.concat(input_list, axis_int)


def _batch_to_space_nd(converter, node, inputs):
    """Convert a ``BatchToSpaceND`` node into ``tfe.batch_to_space_nd``.

    ``block_shape`` and ``crops`` are forwarded as the raw TensorProtos stored
    on their Const NodeDefs.
    """

    def const_proto(position):
        return converter.outputs[inputs[position]].attr["value"].tensor

    operand = converter.outputs[inputs[0]]
    block_shape = const_proto(1)
    crops = const_proto(2)
    return tfe.batch_to_space_nd(operand, block_shape, crops)


def _space_to_batch_nd(converter, node, inputs):
    """Convert a ``SpaceToBatchND`` node into ``tfe.space_to_batch_nd``.

    ``block_shape`` and ``paddings`` are forwarded as the raw TensorProtos
    stored on their Const NodeDefs.
    """

    def const_proto(position):
        return converter.outputs[inputs[position]].attr["value"].tensor

    operand = converter.outputs[inputs[0]]
    block_shape = const_proto(1)
    paddings = const_proto(2)
    return tfe.space_to_batch_nd(operand, block_shape, paddings)


def _flatten(converter, node, inputs):
    """Reshape the input to ``[-1, prod(non-batch dims)]``.

    Every dimension after the first is multiplied together, so all of them
    must be statically known.
    """
    operand = converter.outputs[inputs[0]]

    flat_size = 1
    for extent in operand.shape.as_list()[1:]:
        flat_size = flat_size * extent

    return tfe.reshape(operand, [-1, flat_size])


def _required_space_to_batch_paddings(converter, node, inputs: List[str]):
    """Compute ``tf.required_space_to_batch_paddings`` as two public inputs.

    Operands are either Const NodeDefs (decoded to NumPy) or Pond private or
    masked tensors. Private operands are revealed, with a warning, because
    the computation assumes public values. Two operands are read as
    ``(input_shape, block_shape)``, otherwise as
    ``(base_paddings, input_shape, block_shape)``.

    Returns:
        ``(paddings, crops)``, each cast to float64 and defined as a public
        input by ``converter.model_provider``.

    Raises:
        TypeError: for an operand of any other type.
    """
    candidates = [converter.outputs[name] for name in inputs]

    resolved = []
    for candidate in candidates:
        if isinstance(candidate, (PondPrivateTensor, PondMaskedTensor)):
            logging.warning(
                (
                    "Revealing private input: "
                    "required_space_to_batch_paddings assumes public "
                    "input."
                )
            )
            resolved.append(tf.cast(candidate.reveal().decode(), tf.int32))
        elif isinstance(candidate, tf.NodeDef):
            resolved.append(_nodef_to_numpy_array(candidate))
        else:
            raise TypeError("Unexpected input of type {}.".format(type(candidate)))

    if len(resolved) == 2:
        input_shape, block_shape = resolved
        extra_kwargs = {}
    else:
        base_paddings, input_shape, block_shape = resolved
        extra_kwargs = {"base_paddings": base_paddings}

    def inputter_pad():
        pads, _ = tf.required_space_to_batch_paddings(
            input_shape, block_shape, **extra_kwargs
        )
        return tf.cast(pads, tf.float64)

    def inputter_crop():
        _, crops = tf.required_space_to_batch_paddings(
            input_shape, block_shape, **extra_kwargs
        )
        return tf.cast(crops, tf.float64)

    paddings_public = tfe.define_public_input(converter.model_provider, inputter_pad)
    crops_public = tfe.define_public_input(converter.model_provider, inputter_crop)

    return (paddings_public, crops_public)


def _argmax(converter, node, inputs):
    """Convert an ``ArgMax`` node whose axis is a Const NodeDef (``int_val[0]``)."""
    operand = converter.outputs[inputs[0]]
    axis_proto = converter.outputs[inputs[1]].attr["value"].tensor
    return tfe.argmax(operand, axis=axis_proto.int_val[0])


def _slice(converter, node, inputs):
    """Convert a ``Slice`` node into a TFE strided slice.

    ``Slice`` takes ``begin`` and ``size``, while ``strided_slice`` takes
    ``begin`` and ``end``. For each axis ``end = begin + size``, except that a
    negative size (TF uses ``-1``) means "everything from ``begin`` to the end
    of the axis", so ``end`` is the axis length.
    """
    x_in = converter.outputs[inputs[0]]
    begin = _nodef_to_numpy_array(converter.outputs[inputs[1]])
    size = _nodef_to_numpy_array(converter.outputs[inputs[2]])

    if isinstance(x_in, tf.NodeDef):
        input_out = _nodef_to_private_pond(converter, x_in)
    else:
        input_out = x_in

    # Read the shape from the converted tensor: a raw NodeDef has no .shape.
    input_shape = input_out.shape.as_list()

    # Keep `end` integral like `begin`; slice bounds must be integer indices.
    end = np.array(
        [
            input_shape[i] if size[i] < 0 else begin[i] + size[i]
            for i in range(len(begin))
        ],
        dtype=begin.dtype,
    )

    return tfe.strided_slice(input_out, begin, end)


# pylint: enable=unused-argument
# pylint: enable=missing-docstring
def _nodef_to_public_pond(converter, x):
    """Map a Const NodeDef ``x`` to a public Pond tensor.

    The constant is decoded on the host and defined as a public input by
    ``converter.model_provider``. Scalars are reshaped to ``(1, 1)``. Other
    tensors keep their declared shape, and a single stored value is repeated
    to fill that shape.

    Raises:
        TypeError: if the dtype is not float32, float64 or int32.
    """
    dtype = x.attr["dtype"].type
    tensor = x.attr["value"].tensor
    x_shape = [i.size for i in tensor.tensor_shape.dim]

    # float64 values live in ``double_val``; ``float_val`` is only for float32.
    if dtype == tf.float32:
        type_code, typed_vals = "f", tensor.float_val
    elif dtype == tf.float64:
        type_code, typed_vals = "d", tensor.double_val
    elif dtype == tf.int32:
        type_code, typed_vals = "i", tensor.int_val
    else:
        raise TypeError("Unsupported dtype")

    if not x_shape:
        nums = typed_vals

        def inputter_fn():
            return tf.constant(np.array(nums).reshape(1, 1))

    else:
        # Single-element constants are stored in the typed field, leaving
        # tensor_content empty.
        content = tensor.tensor_content
        nums = array.array(type_code, content if content else typed_vals)

        def inputter_fn():
            values = np.array(nums)
            if values.size == 1:
                # One stored value fills the whole tensor (TF convention).
                values = np.repeat(values, int(np.prod(x_shape)))
            return tf.constant(values.reshape(x_shape))

    x_public = tfe.define_public_input(converter.model_provider, inputter_fn)

    return x_public


def _nodef_to_private_pond(converter, x):
    """Map a Const NodeDef ``x`` to a private Pond tensor.

    The constant is decoded on the host and defined as a private input by
    ``converter.model_provider``. Scalars are reshaped to ``(1, 1)``. Other
    tensors keep their declared shape, and a single stored value is repeated
    to fill that shape. int32 constants are accepted but logged as
    unexpected.

    Raises:
        TypeError: if the dtype is not float32, float64 or int32.
    """
    dtype = x.attr["dtype"].type
    # %-style placeholders so logging actually formats the lazy args.
    warn_msg = "Unexpected dtype %s found at node %s"
    err_msg = "Unsupported dtype {} found at node {}"

    tensor = x.attr["value"].tensor
    x_shape = [i.size for i in tensor.tensor_shape.dim]

    # float64 values live in ``double_val``; ``float_val`` is only for float32.
    if dtype == tf.float32:
        type_code, typed_vals = "f", tensor.float_val
    elif dtype == tf.float64:
        type_code, typed_vals = "d", tensor.double_val
    elif dtype == tf.int32:
        logging.warning(warn_msg, dtype, x.name)
        type_code, typed_vals = "i", tensor.int_val
    else:
        raise TypeError(err_msg.format(dtype, x.name))

    if not x_shape:
        nums = typed_vals

        def inputter_fn():
            return tf.constant(np.array(nums).reshape(1, 1))

    else:
        # Single-element constants are stored in the typed field, leaving
        # tensor_content empty.
        content = tensor.tensor_content
        nums = array.array(type_code, content if content else typed_vals)

        def inputter_fn():
            values = np.array(nums)
            if values.size == 1:
                # One stored value fills the whole tensor (TF convention).
                values = np.repeat(values, int(np.prod(x_shape)))
            return tf.constant(values.reshape(x_shape))

    x_private = tfe.define_private_input(converter.model_provider, inputter_fn)

    return x_private


def _nodef_to_numpy_array(x):
    """Map a Const NodeDef ``x`` to a NumPy array with the node's shape.

    Values come from the packed ``tensor_content`` bytes when present,
    otherwise from the typed repeated field (``float_val``, ``double_val``,
    ``int_val`` or ``int64_val``). A single stored value is repeated to fill
    the declared shape.

    Raises:
        TypeError: if the dtype is not float32, float64, int32 or int64.
    """
    dtype = x.attr["dtype"].type
    tensor = x.attr["value"].tensor
    x_shape = [i.size for i in tensor.tensor_shape.dim]

    content = tensor.tensor_content

    if dtype == tf.float32:
        type_code, typed_vals = "f", tensor.float_val
    elif dtype == tf.float64:
        type_code, typed_vals = "d", tensor.double_val
    elif dtype == tf.int32:
        type_code, typed_vals = "i", tensor.int_val
    elif dtype == tf.int64:
        # "q" is 8 bytes on every platform ("l" is not).
        type_code, typed_vals = "q", tensor.int64_val
    else:
        raise TypeError("Unsupported dtype")

    nums = array.array(type_code, content if content else typed_vals)

    values = np.array(nums)
    if values.size == 1:
        # One stored value fills the whole tensor (TF convention); this is a
        # no-op for scalars and other single-element shapes.
        values = np.repeat(values, int(np.prod(x_shape)))

    return values.reshape(x_shape)