from typing import List, Optional, Union

import tensorflow as tf
from docqa.configurable import Configurable
from docqa.nn.layers import SequenceBiMapper, MergeLayer, Mapper, get_keras_initialization, SequenceMapper, SequenceEncoder, \
    FixedMergeLayer, AttentionPredictionLayer, SequencePredictionLayer, SequenceMultiEncoder
from docqa.nn.span_prediction_ops import best_span_from_bounds, to_unpacked_coordinates, \
    to_packed_coordinates, packed_span_f1_mask
from tensorflow import Tensor
from tensorflow.contrib.layers import fully_connected

from docqa.model import Prediction
from docqa.nn.ops import VERY_NEGATIVE_NUMBER, exp_mask, segment_logsumexp

# Classes that take a sequence of vectors, build a loss function, and predict a span

class BoundaryPrediction(Prediction):
    """ Individual logits for the span start/end """

    def __init__(self, start_prob, end_prob,
                 start_logits, end_logits, mask):
        """
        :param start_prob: per-position start scores, exposed as `start_probs`
        :param end_prob: per-position end scores, exposed as `end_probs`
        :param start_logits: logits for the span start
        :param end_logits: logits for the span end
        :param mask: fed to `tf.sequence_mask` in `get_mean_logit`, so presumably the
                     per-example sequence lengths - TODO confirm against callers
        """
        self.start_probs = start_prob
        self.end_probs = end_prob
        self.start_logits = start_logits
        self.end_logits = end_logits
        self.mask = mask
        # Cache of bound -> best-span result, so asking for the same bound twice
        # does not build the ops a second time
        self._bound_predictions = {}

    def get_best_span(self, bound: int):
        """ Return `best_span_from_bounds` for `bound`, building it only on the first request """
        if bound not in self._bound_predictions:
            # Cache miss: build the prediction and remember it. Previously this code sat
            # after the `return` inside the cache-hit branch, so a miss returned None.
            self._bound_predictions[bound] = best_span_from_bounds(
                self.start_logits, self.end_logits, bound)
        return self._bound_predictions[bound]

    def get_span_scores(self):
        """
        Return exp(start_logit + end_logit) for every (start, end) pair; the start logits
        are expanded on axis 2 and the end logits on axis 1 before being added
        """
        return tf.exp(tf.expand_dims(self.start_logits, 2) + tf.expand_dims(self.end_logits, 1))

    def get_mean_logit(self):
        """ Return the average of (start_logit + end_logit) / 2 over the positions selected by `mask` """
        logits = (self.start_logits + self.end_logits) / 2.0
        bol_mask = tf.sequence_mask(self.mask, tf.shape(self.start_logits)[1])
        bol_mask = tf.cast(bol_mask, tf.float32)
        return tf.reduce_sum(logits*bol_mask, axis=[1]) / tf.reduce_sum(bol_mask, axis=[1])

class PackedSpanPrediction(Prediction):
    """ Per-span logits kept in the packed layout (batch, packed_coordinate) """

    def __init__(self, logits, l, bound):
        """
        :param logits: packed span logits; the best span is the argmax over axis 1
        :param l: length handed to `to_unpacked_coordinates`
                  (presumably the sequence length - TODO confirm)
        :param bound: span-length bound that `logits` was packed with
        """
        self.l = l
        self.bound = bound
        self.logits = logits
        best_ix = tf.argmax(logits, axis=1)
        self.best_score = tf.reduce_max(logits, axis=1)
        self.predicted_span = to_unpacked_coordinates(best_ix, l, bound)

    def get_best_span(self, bound):
        """
        Return (predicted_span, best_score) restricted to spans allowed by `bound`.

        :raises ValueError: if `bound` is larger than the bound the logits were built with
        """
        if bound > self.bound:
            raise ValueError()

        if not bound < self.bound:
            # Same bound as at construction time, reuse the precomputed result
            return self.predicted_span, self.best_score

        # Only the first `cutoff` packed columns are considered for the smaller bound
        n_skipped = bound * (bound - 1) // 2
        cutoff = self.l * bound - n_skipped
        truncated = self.logits[:, :cutoff]
        best_ix = tf.argmax(truncated, axis=1)
        top_score = tf.reduce_max(truncated, axis=1)
        span = to_unpacked_coordinates(best_ix, self.l, bound)
        return span, top_score

class ConfidencePrediction(Prediction):
    """ boundary logits with an additional confidence logit """

    def __init__(self, span_probs,
                 start_logits, end_logits,
                 none_prob, non_op_logit,
                 mask):
        """
        :param span_probs: probability assigned to each span
        :param start_logits: logits for the span start
        :param end_logits: logits for the span end
        :param none_prob: probability assigned to the no-answer option
        :param non_op_logit: logit of the no-answer option, exposed as `none_logit`
        :param mask: fed to `tf.sequence_mask` in `get_mean_logit`, so presumably the
                     per-example sequence lengths - TODO confirm against callers
        """
        self.span_probs = span_probs
        self.none_prob = none_prob
        self.start_logits = start_logits
        self.end_logits = end_logits
        self.none_logit = non_op_logit
        self.start_probs = tf.nn.softmax(start_logits)
        self.end_probs = tf.nn.softmax(end_logits)
        self.mask = mask

    def get_best_span(self, bound: int):
        """ Return `best_span_from_bounds` for the start/end logits and `bound` """
        return best_span_from_bounds(self.start_logits, self.end_logits, bound)

    def get_span_scores(self):
        """
        Return exp(start_logit + end_logit) for every (start, end) pair; the start logits
        are expanded on axis 2 and the end logits on axis 1 before being added
        """
        return tf.exp(tf.expand_dims(self.start_logits, 2) + tf.expand_dims(self.end_logits, 1))

    def get_mean_logit(self):
        """ Return the average of start_logit + end_logit over the positions selected by `mask` """
        # NOTE(review): unlike BoundaryPrediction.get_mean_logit the sum is not halved here
        logits = self.start_logits + self.end_logits
        bol_mask = tf.sequence_mask(self.mask, tf.shape(self.start_logits)[1])
        bol_mask = tf.cast(bol_mask, tf.float32)
        return tf.reduce_sum(logits*bol_mask, axis=[1]) / tf.reduce_sum(bol_mask, axis=[1])

class SpanFromBoundsPredictor(Configurable):
    """
    Adds a loss function and returns a prediction given start/end span bounds logits.
    There are a few loss functions we could consider at this point, so this class provides
    an abstraction over those options
    """

    def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
        """
        Build the loss for `answer` and return the resulting `Prediction`.

        :param answer: sequence of answer tensors, the accepted encodings depend on the subclass
        :param start_logits: logits for the span start
        :param end_logits: logits for the span end
        :param mask: mask forwarded to the masking ops of the subclass
        :raises NotImplementedError: always, subclasses must override this method
        """
        raise NotImplementedError()

class IndependentBounds(SpanFromBoundsPredictor):
    """
    Scores the start and the end bound with two independent softmax distributions and
    adds the sum of both negative log-likelihoods to the `tf.GraphKeys.LOSSES` collection
    """

    def __init__(self, aggregate="sum"):
        """
        :param aggregate: how to combine several correct bounds when the answer is given as
                          dense boolean masks, either "sum" or "max"
        """
        self.aggregate = aggregate

    def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
        """
        :param answer: either `[spans]`, where column 0/1 of `spans` hold the start/end labels,
                       or `[start_mask, end_mask]` as two boolean tensors
        :param start_logits: logits for the span start
        :param end_logits: logits for the span end
        :param mask: mask applied to both logits through `exp_mask`
        :return: a `BoundaryPrediction` over the masked logits
        :raises ValueError: if `aggregate` is neither "sum" nor "max" (dense answers only)
        :raises NotImplementedError: if `answer` uses an unsupported encoding
        """
        masked_start_logits = exp_mask(start_logits, mask)
        masked_end_logits = exp_mask(end_logits, mask)

        if len(answer) == 1:
            # answer span is encoded in a sparse int array
            answer_spans = answer[0]
            losses1 = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=masked_start_logits, labels=answer_spans[:, 0])
            losses2 = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=masked_end_logits, labels=answer_spans[:, 1])
            loss = tf.add_n([tf.reduce_mean(losses1), tf.reduce_mean(losses2)], name="loss")
        elif len(answer) == 2 and all(x.dtype == tf.bool for x in answer):
            # all correct start/end bounds are marked in a dense bool array
            # In this case there might be multiple answer spans, so we need an aggregation strategy
            losses = []
            for answer_mask, logits in zip(answer, [masked_start_logits, masked_end_logits]):
                log_norm = tf.reduce_logsumexp(logits, axis=1)
                # Push every incorrect position far down so only correct bounds contribute
                correct_logits = logits + VERY_NEGATIVE_NUMBER * (1 - tf.cast(answer_mask, tf.float32))
                if self.aggregate == "sum":
                    log_score = tf.reduce_logsumexp(correct_logits, axis=1)
                elif self.aggregate == "max":
                    log_score = tf.reduce_max(correct_logits, axis=1)
                else:
                    raise ValueError()
                losses.append(tf.reduce_mean(-(log_score - log_norm)))
            loss = tf.add_n(losses)
        else:
            # `NotImplemented` is a constant, not an exception, so it cannot be raised
            raise NotImplementedError()
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
        return BoundaryPrediction(tf.nn.softmax(masked_start_logits),
                                  tf.nn.softmax(masked_end_logits),
                                  masked_start_logits, masked_end_logits, mask)

class ForwardSpansOnly(SpanFromBoundsPredictor):
    """
    Explicitly compute the per-span score, then mask out the negative spans. Surprisingly I
    found this to hurt performance on SQuAD (similar f1, worse em)
    """

    def __init__(self, aggregate="sum", bound: int=-1):
        """
        :param aggregate: stored on the instance, not read by `predict`
        :param bound: number of super-diagonals kept by `tf.matrix_band_part`
        """
        self.aggregate = aggregate
        self.bound = bound

    def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
        """
        :param answer: `[spans]`, where column 0/1 of `spans` hold the start/end labels
        :param start_logits: logits for the span start
        :param end_logits: logits for the span end
        :param mask: passed to `exp_mask` and `tf.sequence_mask`
        :return: a `BoundaryPrediction` over the masked logits
        :raises NotImplementedError: if `answer` does not contain exactly one tensor
        """
        l = tf.shape(start_logits)[1]
        masked_start_logits = exp_mask(start_logits, mask)
        masked_end_logits = exp_mask(end_logits, mask)

        # Explicit score for each span
        span_scores = tf.expand_dims(start_logits, 2) + tf.expand_dims(end_logits, 1)

        # Mask for in-bound spans, now (batch, start, end) matrix. It gets its own name so
        # that `mask` still refers to the caller's mask when the prediction is built below;
        # previously the span mask overwrote it and was handed to BoundaryPrediction
        span_mask = tf.sequence_mask(mask, l)
        span_mask = tf.logical_and(tf.expand_dims(span_mask, 2), tf.expand_dims(span_mask, 1))

        # Also mask out spans that are negative/inverse by taking only the upper triangle
        span_mask = tf.matrix_band_part(span_mask, 0, self.bound)

        # Apply the mask
        span_mask = tf.cast(span_mask, tf.float32)
        span_scores = span_scores * span_mask + (1 - span_mask) * VERY_NEGATIVE_NUMBER

        if len(answer) != 1:
            # `NotImplemented` is a constant, not an exception, so it cannot be raised
            raise NotImplementedError()

        answer = answer[0]
        span_scores = tf.reshape(span_scores, (tf.shape(start_logits)[0], -1))
        # Index of the (start, end) cell inside the flattened (l * l) score matrix
        answer = answer[:, 0] * l + answer[:, 1]
        losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=span_scores, labels=answer)
        loss = tf.reduce_mean(losses)

        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
        return BoundaryPrediction(tf.nn.softmax(masked_start_logits),
                                  tf.nn.softmax(masked_end_logits),
                                  masked_start_logits, masked_end_logits, mask)

class IndependentBoundsNoAnswerOption(SpanFromBoundsPredictor):
    """
    Return start_logits and end_logit, and also learn a scalar no-answer option. I have generally used
    `ConfidencePredictor` over this class, although possibly forcing the no-answer option to be scalar
    will help ensure the score for the remaining spans are well calibrated
    """

    def __init__(self, aggregate="sum", non_init=-1.0):
        """
        :param aggregate: stored on the instance, not read by `predict`
        :param non_init: initial value of the learned "none-logit" variable
        """
        self.aggregate = aggregate
        self.non_init = non_init

    def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
        """
        :param answer: `[start_mask, end_mask]` as two boolean tensors
        :param start_logits: logits for the span start
        :param end_logits: logits for the span end
        :param mask: mask applied to both logits through `exp_mask`
        :return: a `ConfidencePrediction` including the no-answer probability
        :raises NotImplementedError: if `answer` is not a pair of boolean tensors
        """
        if not (len(answer) == 2 and all(x.dtype == tf.bool for x in answer)):
            # `NotImplemented` is a constant, not an exception, so it cannot be raised
            raise NotImplementedError()

        masked_start_logits = exp_mask(start_logits, mask)
        masked_end_logits = exp_mask(end_logits, mask)
        batch_dim = tf.shape(start_logits)[0]

        # One learned scalar, repeated for every example in the batch
        none_logit = tf.get_variable("none-logit", initializer=self.non_init, dtype=tf.float32)
        none_logit = tf.tile(tf.expand_dims(none_logit, 0), [batch_dim])

        # Logits for each (start, end) pair, flattened per example
        all_logits = tf.reshape(tf.expand_dims(masked_start_logits, 1) +
                                tf.expand_dims(masked_end_logits, 2),
                                (batch_dim, -1))

        # (batch, (l * l) + 1) logits including the none option
        all_logits = tf.concat([all_logits, tf.expand_dims(none_logit, 1)], axis=1)
        log_norms = tf.reduce_logsumexp(all_logits, axis=1)

        # Now build a "correctness" mask in the same format, the none option is
        # correct exactly when no start position is marked
        correct_mask = tf.logical_and(tf.expand_dims(answer[0], 1), tf.expand_dims(answer[1], 2))
        correct_mask = tf.reshape(correct_mask, (batch_dim, -1))
        correct_mask = tf.concat([correct_mask, tf.logical_not(tf.reduce_any(answer[0], axis=1, keep_dims=True))],
                                 axis=1)

        log_correct = tf.reduce_logsumexp(
            all_logits + VERY_NEGATIVE_NUMBER * (1 - tf.cast(correct_mask, tf.float32)), axis=1)
        loss = tf.reduce_mean(-(log_correct - log_norms))
        probs = tf.nn.softmax(all_logits)
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
        # `mask` is forwarded because ConfidencePrediction stores it as `self.mask`
        return ConfidencePrediction(probs[:, :-1], masked_start_logits, masked_end_logits,
                                    probs[:, -1], none_logit, mask)

class IndependentBoundsGrouped(SpanFromBoundsPredictor):
    """ The shared norm loss, where the normalizer is shared between paragraph with the same group id """

    def __init__(self, aggregate="sum"):
        """
        :param aggregate: how to combine several correct bounds, only "sum" is supported
        """
        self.aggregate = aggregate

    def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
        """
        :param answer: `[start_mask, end_mask, group_ids]`; the first two entries are paired
                       with the start/end logits and cast to float, the third is passed to `tf.unique`
        :param start_logits: logits for the span start
        :param end_logits: logits for the span end
        :param mask: mask applied to both logits through `exp_mask`
        :return: a `BoundaryPrediction` over the masked logits
        :raises ValueError: if `aggregate` is not "sum"
        :raises NotImplementedError: if `answer` does not contain exactly three tensors
        """
        masked_start_logits = exp_mask(start_logits, mask)
        masked_end_logits = exp_mask(end_logits, mask)

        if len(answer) != 3:
            # `NotImplemented` is a constant, not an exception, so it cannot be raised
            raise NotImplementedError()

        group_ids = answer[2]
        # Turn the ids into segment ids using tf.unique
        _, group_segments = tf.unique(group_ids, out_idx=tf.int32)

        losses = []
        # zip stops after the two logit tensors, so the group ids in answer[2] are skipped
        for answer_mask, logits in zip(answer, [masked_start_logits, masked_end_logits]):
            group_norms = segment_logsumexp(logits, group_segments)
            if self.aggregate == "sum":
                log_score = segment_logsumexp(
                    logits + VERY_NEGATIVE_NUMBER * (1 - tf.cast(answer_mask, tf.float32)),
                    group_segments)
            else:
                raise ValueError()
            losses.append(tf.reduce_mean(-(log_score - group_norms)))
        loss = tf.add_n(losses)

        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
        return BoundaryPrediction(tf.nn.softmax(masked_start_logits),
                                  tf.nn.softmax(masked_end_logits),
                                  masked_start_logits, masked_end_logits, mask)

class IndependentBoundsSigmoidLoss(SpanFromBoundsPredictor):
    """ Independent sigmoid loss for each start/end span """

    def __init__(self, aggregate="sum"):
        """
        :param aggregate: stored on the instance, not read by `predict`
        """
        self.aggregate = aggregate

    def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
        """
        :param answer: `[start_mask, end_mask]` as two boolean tensors
        :param start_logits: logits for the span start
        :param end_logits: logits for the span end
        :param mask: mask applied to both logits through `exp_mask`
        :return: a `BoundaryPrediction` whose probabilities are sigmoids of the masked logits
        :raises NotImplementedError: if `answer` is not a pair of boolean tensors
        """
        masked_start_logits = exp_mask(start_logits, mask)
        masked_end_logits = exp_mask(end_logits, mask)

        if len(answer) == 1:
            raise NotImplementedError()
        elif len(answer) == 2 and all(x.dtype == tf.bool for x in answer):
            losses = []
            for answer_mask, logits in zip(answer, [masked_start_logits, masked_end_logits]):
                losses.append(tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.cast(answer_mask, tf.float32),
                    logits=logits))
            # Element-wise sum of the start and end losses, averaged when registered below
            loss = tf.add_n(losses)
        else:
            # `NotImplemented` is a constant, not an exception, so it cannot be raised
            raise NotImplementedError()
        tf.add_to_collection(tf.GraphKeys.LOSSES, tf.reduce_mean(loss, name="sigmoid-loss"))
        return BoundaryPrediction(tf.nn.sigmoid(masked_start_logits),
                                  tf.nn.sigmoid(masked_end_logits),
                                  masked_start_logits, masked_end_logits, mask)

class BoundedSpanPredictor(SpanFromBoundsPredictor):
    """ Loss based on only using span that are up to a fixed bound in length """

    def __init__(self, bound: int, f1_weight=0, aggregate:str=None):
        """
        :param bound: number of start/end offsets (0 .. bound - 1) that get a span logit
        :param f1_weight: 0 selects the plain cross entropy loss for int32 answers, any other
                          value selects the loss weighted by `packed_span_f1_mask`
        :param aggregate: "sum" or "max", used when the answer is not an int32 tensor
        """
        self.bound = bound
        self.f1_weight = f1_weight
        self.aggregate = aggregate

    def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
        """
        :param answer: `[answer]`, either an int32 tensor of spans or a tensor that is cast
                       to float and used as a mask over the packed span logits
        :param start_logits: logits for the span start
        :param end_logits: logits for the span end
        :param mask: mask applied to both logits through `exp_mask`
        :return: a `PackedSpanPrediction` over the packed span logits
        :raises NotImplementedError: if `answer` does not contain exactly one tensor, or
                                     `aggregate` is neither "sum" nor "max" for a mask answer
        """
        bound = self.bound
        f1_weight = self.f1_weight
        aggregate = self.aggregate
        masked_logits1 = exp_mask(start_logits, mask)
        masked_logits2 = exp_mask(end_logits, mask)

        # Packed layout: all spans with offset 0, then all spans with offset 1, and so on
        span_logits = []
        for i in range(self.bound):
            if i == 0:
                span_logits.append(masked_logits1 + masked_logits2)
            else:
                # Start at position p is paired with the end at position p + i
                span_logits.append(masked_logits1[:, :-i] + masked_logits2[:, i:])
        span_logits = tf.concat(span_logits, axis=1)
        l = tf.shape(start_logits)[1]

        if len(answer) != 1:
            raise NotImplementedError()

        answer = answer[0]
        if answer.dtype == tf.int32:
            if f1_weight == 0:
                answer_ix = to_packed_coordinates(answer, l, bound)
                loss = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(logits=span_logits, labels=answer_ix))
            else:
                f1_mask = packed_span_f1_mask(answer, l, bound)
                if f1_weight < 1:
                    f1_mask *= f1_weight
                    # NOTE(review): the one-hot depth is `l` while `span_logits` is wider than `l`
                    # whenever bound > 1 - confirm this is the intended depth
                    f1_mask += (1 - f1_weight) * tf.one_hot(to_packed_coordinates(answer, l, bound), l)
                # TODO can we stay in log space?  (actually its tricky since f1_mask can have zeros...)
                probs = tf.nn.softmax(span_logits)
                loss = -tf.reduce_mean(tf.log(tf.reduce_sum(probs * f1_mask, axis=1)))
        else:
            log_norm = tf.reduce_logsumexp(span_logits, axis=1)
            # Push every incorrect span far down so only correct spans contribute
            correct_logits = span_logits + VERY_NEGATIVE_NUMBER * (1 - tf.cast(answer, tf.float32))
            if aggregate == "sum":
                log_score = tf.reduce_logsumexp(correct_logits, axis=1)
            elif aggregate == "max":
                log_score = tf.reduce_max(correct_logits, axis=1)
            else:
                raise NotImplementedError()
            loss = tf.reduce_mean(-(log_score - log_norm))

        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
        return PackedSpanPrediction(span_logits, l, bound)

class SpanFromVectorBound(SequencePredictionLayer):
    """
    RaSoR style prediction, combining a vector at the start/end
    of each span. In practice I have struggled to make this work well on TriviaQA
    """

    # NOTE(review): the parameter list was truncated in the source; the position and defaults
    # of `f1_weight` and `aggregate` are reconstructed - confirm against callers/saved configs
    def __init__(self,
                 mapper: SequenceBiMapper,
                 pre_process: Optional[SequenceMapper],
                 merge: MergeLayer,
                 post_process: Optional[Mapper],
                 bound: int,
                 f1_weight=0,
                 init: str="glorot_uniform",
                 aggregate="sum"):
        """
        :param mapper: produces the two sequences (m1, m2) used for span starts and ends
        :param pre_process: optional mapper applied to m1 and m2 separately
        :param merge: combines a start vector and an end vector into a span vector
        :param post_process: optional mapper applied to the concatenated span vectors
        :param bound: number of start/end offsets (0 .. bound - 1) that get a span vector
        :param f1_weight: 0 selects the plain cross entropy loss for int32 answers, any other
                          value selects the loss weighted by `packed_span_f1_mask`
        :param init: name of the initialization resolved by `get_keras_initialization`
        :param aggregate: "sum" or "max", used when the answer is not an int32 tensor
        """
        self.mapper = mapper
        self.pre_process = pre_process
        self.merge = merge
        self.post_process = post_process
        self.init = init
        self.f1_weight = f1_weight
        self.bound = bound
        self.aggregate = aggregate

    def apply(self, is_train, context_embed, answer, context_mask=None):
        """
        :param is_train: forwarded to every sub-layer
        :param context_embed: sequence of context vectors; axis 1 is used as the sequence length
        :param answer: `[answer]`, either an int32 tensor of spans or a tensor that is cast
                       to float and used as a mask over the packed span logits
        :param context_mask: passed to `tf.sequence_mask` and to the sub-layers
        :return: a `PackedSpanPrediction` over the packed span logits
        :raises NotImplementedError: if `answer` does not contain exactly one tensor, or
                                     `aggregate` is neither "sum" nor "max" for a mask answer
        """
        init_fn = get_keras_initialization(self.init)
        bool_mask = tf.sequence_mask(context_mask, tf.shape(context_embed)[1])

        with tf.variable_scope("predict"):
            m1, m2 = self.mapper.apply(is_train, context_embed, context_mask)

        if self.pre_process is not None:
            with tf.variable_scope("pre-process1"):
                m1 = self.pre_process.apply(is_train, m1, context_mask)
            with tf.variable_scope("pre-process2"):
                m2 = self.pre_process.apply(is_train, m2, context_mask)

        # Packed layout: all spans with offset 0, then all spans with offset 1, and so on.
        # The mask list must get one entry per offset, including offset 0, so that the
        # concatenated mask lines up with the concatenated span vectors
        span_vector_lst = []
        mask_lst = []
        with tf.variable_scope("merge"):
            span_vector_lst.append(self.merge.apply(is_train, m1, m2))
        mask_lst.append(bool_mask)
        for i in range(1, self.bound):
            # Same merge variables for every offset
            with tf.variable_scope("merge", reuse=True):
                span_vector_lst.append(self.merge.apply(is_train, m1[:, :-i], m2[:, i:]))
            mask_lst.append(bool_mask[:, i:])

        mask = tf.concat(mask_lst, axis=1)
        span_vectors = tf.concat(span_vector_lst, axis=1)  # all logits -> flattened per-span predictions

        if self.post_process is not None:
            with tf.variable_scope("post-process"):
                span_vectors = self.post_process.apply(is_train, span_vectors)

        with tf.variable_scope("compute_logits"):
            logits = fully_connected(span_vectors, 1, activation_fn=None, weights_initializer=init_fn)

        logits = tf.squeeze(logits, squeeze_dims=[2])
        # `mask` is already a single concatenated tensor, so it only needs the cast
        logits = logits + VERY_NEGATIVE_NUMBER * (1 - tf.cast(mask, tf.float32))

        l = tf.shape(context_embed)[1]

        if len(answer) != 1:
            raise NotImplementedError()

        answer = answer[0]
        if answer.dtype == tf.int32:
            if self.f1_weight == 0:
                answer_ix = to_packed_coordinates(answer, l, self.bound)
                loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=answer_ix))
            else:
                f1_mask = packed_span_f1_mask(answer, l, self.bound)
                if self.f1_weight < 1:
                    f1_mask *= self.f1_weight
                    # NOTE(review): the one-hot depth is `l` while `logits` is wider than `l`
                    # whenever bound > 1 - confirm this is the intended depth
                    f1_mask += (1-self.f1_weight) * tf.one_hot(to_packed_coordinates(answer, l, self.bound), l)

                # TODO can we stay in log space?  (actually its tricky since f1_mask can have zeros...)
                probs = tf.nn.softmax(logits)
                loss = -tf.reduce_mean(tf.log(tf.reduce_sum(probs * f1_mask, axis=1)))
        else:
            log_norm = tf.reduce_logsumexp(logits, axis=1)
            # Push every incorrect span far down so only correct spans contribute
            correct_logits = logits + VERY_NEGATIVE_NUMBER * (1 - tf.cast(answer, tf.float32))
            if self.aggregate == "sum":
                log_score = tf.reduce_logsumexp(correct_logits, axis=1)
            elif self.aggregate == "max":
                log_score = tf.reduce_max(correct_logits, axis=1)
            else:
                raise NotImplementedError()
            loss = tf.reduce_mean(-(log_score - log_norm))

        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
        return PackedSpanPrediction(logits, l, self.bound)

class BoundsPredictor(SequencePredictionLayer):
    """ Standard start/end bound prediction """

    def __init__(self, predictor: SequenceBiMapper, init: str="glorot_uniform",
                 span_predictor: SpanFromBoundsPredictor = IndependentBounds()):
        """
        :param predictor: produces the two sequences the start/end logits are computed from
        :param init: name of the initialization resolved by `get_keras_initialization`
        :param span_predictor: builds the loss and the prediction from the start/end logits
        """
        self.predictor = predictor
        self.init = init
        self.span_predictor = span_predictor

    def apply(self, is_train, context_embed, answer, context_mask=None):
        """
        :param is_train: forwarded to `predictor`
        :param context_embed: sequence of context vectors
        :param answer: answer tensors, forwarded unchanged to `span_predictor`
        :param context_mask: forwarded to `predictor` and `span_predictor`
        :return: whatever `span_predictor.predict` returns
        """
        init_fn = get_keras_initialization(self.init)
        with tf.variable_scope("bounds_encoding"):
            m1, m2 = self.predictor.apply(is_train, context_embed, context_mask)

        with tf.variable_scope("start_pred"):
            logits1 = fully_connected(m1, 1, activation_fn=None,
                                      weights_initializer=init_fn)
            logits1 = tf.squeeze(logits1, squeeze_dims=[2])

        with tf.variable_scope("end_pred"):
            logits2 = fully_connected(m2, 1, activation_fn=None, weights_initializer=init_fn)
            logits2 = tf.squeeze(logits2, squeeze_dims=[2])

        with tf.variable_scope("predict_span"):
            return self.span_predictor.predict(answer, logits1, logits2, context_mask)

    def __setstate__(self, state):
        """ Upgrade state saved with the older `aggregate` field, then restore it """
        if "state" in state:
            # NOTE(review): the key written here is "bound_predictor" while the attribute is
            # `span_predictor`, and the `elif` inspects `state` rather than `state["state"]`;
            # left untouched because saved models may depend on it - confirm
            if "aggregate" in state["state"]:
                state["state"]["bound_predictor"] = IndependentBounds(state["state"]["aggregate"])
            elif "bound_predictor" not in state:
                state["state"]["bound_predictor"] = IndependentBounds()
        # Without delegating, the patched state was never applied to the instance.
        # Assumes a parent class defines `__setstate__` for this state layout - TODO confirm
        super().__setstate__(state)

class WithFixedContextPredictionLayer(AttentionPredictionLayer):
    """ Bound prediction integrating a fixed length representation of the question """

    # NOTE(review): `init` was read in the body but absent from the signature; its position
    # and default are reconstructed - confirm against callers/saved configs
    def __init__(self, context_mapper: SequenceMapper, context_encoder: SequenceEncoder,
                 merge: FixedMergeLayer, bounds_predictor: SequenceBiMapper,
                 init="glorot_uniform",
                 span_predictor: SpanFromBoundsPredictor = IndependentBounds()):
        """
        :param context_mapper: mapper applied to `memories`
        :param context_encoder: encodes the mapped memories
        :param merge: merges the encoding into `x`
        :param bounds_predictor: produces the two sequences the start/end logits are computed from
        :param init: name of the initialization resolved by `get_keras_initialization`
        :param span_predictor: builds the loss and the prediction from the start/end logits
        """
        self.context_mapper = context_mapper
        self.context_encoder = context_encoder
        self.bounds_predictor = bounds_predictor
        self.merge = merge
        self.init = init
        self.span_predictor = span_predictor

    def apply(self, is_train, x, memories, answer: List[Tensor], x_mask=None, memory_mask=None):
        """
        :param is_train: forwarded to every sub-layer
        :param x: sequence the span is predicted over
        :param memories: sequence that is mapped, encoded and merged into `x`
        :param answer: answer tensors, forwarded unchanged to `span_predictor`
        :param x_mask: mask for `x`
        :param memory_mask: mask for `memories`
        :return: whatever `span_predictor.predict` returns
        """
        with tf.variable_scope("map_context"):
            memories = self.context_mapper.apply(is_train, memories, memory_mask)
        with tf.variable_scope("encode_context"):
            encoded = self.context_encoder.apply(is_train, memories, memory_mask)
        with tf.variable_scope("merge"):
            x = self.merge.apply(is_train, x, encoded, x_mask)
        with tf.variable_scope("predict"):
            m1, m2 = self.bounds_predictor.apply(is_train, x, x_mask)

        init = get_keras_initialization(self.init)
        with tf.variable_scope("logits1"):
            l1 = fully_connected(m1, 1, activation_fn=None, weights_initializer=init)
            l1 = tf.squeeze(l1, squeeze_dims=[2])
        with tf.variable_scope("logits2"):
            l2 = fully_connected(m2, 1, activation_fn=None, weights_initializer=init)
            l2 = tf.squeeze(l2, squeeze_dims=[2])

        with tf.variable_scope("predict_span"):
            return self.span_predictor.predict(answer, l1, l2, x_mask)

class ConfidencePredictor(SequencePredictionLayer):
    """
    Bound prediction where we compute a non-answer logit/option using soft attention over
    the start/end logit and a `SequenceEncoder`.
    """

    # NOTE(review): the parameter list was truncated in the source; the default of
    # `aggregate` is reconstructed - confirm against callers/saved configs
    def __init__(self,
                 predictor: SequenceBiMapper,
                 encoder: Union[SequenceEncoder, SequenceMultiEncoder],
                 confidence_predictor: Mapper,
                 init: str="glorot_uniform",
                 aggregate=None):
        """
        :param predictor: produces the two sequences the start/end logits are computed from
        :param encoder: encodes the context into fixed-size vector(s) for the confidence score
        :param confidence_predictor: mapper applied to the features of the no-answer logit
        :param init: name of the initialization resolved by `get_keras_initialization`
        :param aggregate: stored on the instance, not read by `apply`
        """
        self.predictor = predictor
        self.init = init
        self.aggregate = aggregate
        self.confidence_predictor = confidence_predictor
        self.encoder = encoder

    def version(self):
        return 1  # Fix masking

    def apply(self, is_train, context_embed, answer, context_mask=None):
        """
        :param is_train: forwarded to every sub-layer
        :param context_embed: sequence of context vectors
        :param answer: `answer[0]` / `answer[1]` mark the correct start / end positions and
                       are combined with `tf.logical_and`
        :param context_mask: mask applied to the logits and forwarded to the sub-layers
        :return: a `ConfidencePrediction` including the no-answer probability
        """
        init_fn = get_keras_initialization(self.init)
        m1, m2 = self.predictor.apply(is_train, context_embed, context_mask)

        # A projection is only needed when the predictor did not already emit one value per position
        if m1.shape.as_list()[-1] != 1:
            with tf.variable_scope("start_pred"):
                start_logits = fully_connected(m1, 1, activation_fn=None,
                                               weights_initializer=init_fn)
        else:
            start_logits = m1
        start_logits = tf.squeeze(start_logits, squeeze_dims=[2])

        # The end branch has to inspect `m2`, it previously re-tested `m1`
        if m2.shape.as_list()[-1] != 1:
            with tf.variable_scope("end_pred"):
                end_logits = fully_connected(m2, 1, activation_fn=None, weights_initializer=init_fn)
        else:
            end_logits = m2
        end_logits = tf.squeeze(end_logits, squeeze_dims=[2])

        masked_start_logits = exp_mask(start_logits, context_mask)
        masked_end_logits = exp_mask(end_logits, context_mask)

        # Soft attention: average of m1/m2 weighted by the softmax of the start/end logits
        start_atten = tf.einsum("ajk,aj->ak", m1, tf.nn.softmax(masked_start_logits))
        end_atten = tf.einsum("ajk,aj->ak", m2, tf.nn.softmax(masked_end_logits))
        with tf.variable_scope("encode_context"):
            enc = self.encoder.apply(is_train, context_embed, context_mask)
        if len(enc.shape) == 3:
            # Several encodings per example, flatten them into one vector
            _, encodings, fe = enc.shape.as_list()
            enc = tf.reshape(enc, (-1, encodings*fe))

        with tf.variable_scope("confidence"):
            conf = [start_atten, end_atten, enc]
            none_logit = self.confidence_predictor.apply(is_train, tf.concat(conf, axis=1))
        with tf.variable_scope("confidence_logits"):
            none_logit = fully_connected(none_logit, 1, activation_fn=None,
                                         weights_initializer=init_fn)
            none_logit = tf.squeeze(none_logit, axis=1)

        batch_dim = tf.shape(start_logits)[0]

        # (batch, (l * l)) logits for each (start, end) pair
        all_logits = tf.reshape(tf.expand_dims(masked_start_logits, 1) +
                                tf.expand_dims(masked_end_logits, 2),
                                (batch_dim, -1))

        # (batch, (l * l) + 1) logits including the none option
        all_logits = tf.concat([all_logits, tf.expand_dims(none_logit, 1)], axis=1)
        log_norms = tf.reduce_logsumexp(all_logits, axis=1)

        # Now build a "correctness" mask in the same format, the none option is
        # correct exactly when no start position is marked
        correct_mask = tf.logical_and(tf.expand_dims(answer[0], 1), tf.expand_dims(answer[1], 2))
        correct_mask = tf.reshape(correct_mask, (batch_dim, -1))
        correct_mask = tf.concat([correct_mask, tf.logical_not(tf.reduce_any(answer[0], axis=1, keep_dims=True))], axis=1)

        # Note we are happily allowing the model to place weights on "backwards" spans, and also giving
        # it points for predicting spans that start and end at different answer spans. It would be easy to
        # fix by masking out some of the `all_logit` matrix and specify a more accuracy correct_mask, but I
        # in general left it this way to be consistent with the independent bound models that do the same.
        # Some early tests found properly masking things to not make much difference (or even to hurt), but it
        # still could be an avenue for improvement

        log_correct = tf.reduce_logsumexp(all_logits + VERY_NEGATIVE_NUMBER * (1 - tf.cast(correct_mask, tf.float32)), axis=1)
        loss = tf.reduce_mean(-(log_correct - log_norms))
        probs = tf.nn.softmax(all_logits)
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
        return ConfidencePrediction(probs[:, :-1], masked_start_logits, masked_end_logits,
                                    probs[:, -1], none_logit, context_mask)