"""
Visualisation of the copying process and ancestor generation using PIL
"""
import os
import sys

import numpy as np
import PIL.Image as Image
import PIL.ImageDraw as ImageDraw
import PIL.ImageColor as ImageColor
import PIL.ImageFont as ImageFont
import svgwrite

import tsinfer
import msprime


class AncestorBuilderViz(object):
    """
    Renders a single ancestor alongside the sample genotype matrix as SVG.

    The drawing area is a grid with one column per site. Grid coordinates are
    converted to pixel coordinates by :meth:`x_trans` and :meth:`y_trans`,
    with a fixed 20 pixel margin on every side.
    """

    def __init__(self, sample_data, ancestor_data, width=800, height=400):
        self.ancestor_data = ancestor_data
        self.sample_data = sample_data
        self.width = width
        self.height = height
        self.x_pad = 20
        self.y_pad = 20
        # Pixels per grid cell once the margins have been taken off. Two
        # extra rows are reserved above the sample rows.
        usable_width = width - 2 * self.x_pad
        usable_height = height - 2 * self.y_pad
        self.x_unit = usable_width / sample_data.num_sites
        self.y_unit = usable_height / (sample_data.num_samples + 2)

    def x_trans(self, v):
        """Convert a horizontal grid coordinate into a pixel x value."""
        return self.x_pad + v * self.x_unit

    def y_trans(self, v):
        """
        Convert a vertical grid coordinate into a pixel y value. Grid row 0
        sits at the bottom of the image, so larger values are drawn higher.
        """
        offset_from_bottom = self.y_pad + v * self.y_unit
        return self.height - offset_from_bottom

    def draw_matrix(self, dwg, focal_sites, ancestor, current_site=None):
        """
        Add the genotype grid, its values and the ancestor row to ``dwg``.

        Focal sites are shaded grey behind the grid. ``current_site`` is
        accepted but not used.
        """
        genotypes = self.sample_data.sites_genotypes[:].T
        num_rows, num_cols = genotypes.shape
        to_x, to_y = self.x_trans, self.y_trans
        grid_top = to_y(num_rows)
        grid_bottom = to_y(0)

        # Shaded column behind each focal site.
        column_size = (self.x_unit, num_rows * self.y_unit)
        for focal in focal_sites:
            dwg.add(dwg.rect((to_x(focal), grid_top), column_size, fill="grey"))

        labels = dwg.add(dwg.g(font_size=14, text_anchor="middle"))
        lines = dwg.add(dwg.g(id="lines", stroke="black", stroke_width=3))

        # Vertical grid lines, then horizontal ones.
        for col in range(num_cols + 1):
            x_pixel = to_x(col)
            lines.add(dwg.line((x_pixel, grid_bottom), (x_pixel, grid_top)))

        grid_left = to_x(0)
        grid_right = to_x(num_cols)
        for row in range(num_rows + 1):
            y_pixel = to_y(row)
            lines.add(dwg.line((grid_left, y_pixel), (grid_right, y_pixel)))

        # Genotype values, column by column.
        for col in range(num_cols):
            x_centre = to_x(col + 0.5)
            for row in range(num_rows):
                value = str(genotypes[row, col])
                labels.add(dwg.text(value, (x_centre, to_y(row + 0.5))))

        # The ancestor is written one empty row above the grid.
        ancestor_row = num_rows + 1
        y_ancestor = to_y(ancestor_row + 0.5)
        for col in range(num_cols):
            labels.add(dwg.text(str(ancestor[col]), (to_x(col + 0.5), y_ancestor)))

    def draw(self, ancestor_id, filename_pattern):
        """
        Draw the ancestor with the given ID and write the SVG to the file
        named by ``filename_pattern.format(0)``.
        """
        store = self.ancestor_data
        start = store.ancestors_start[ancestor_id]
        end = store.ancestors_end[ancestor_id]
        focal_sites = store.ancestors_focal_sites[ancestor_id]
        # Sites outside [start, end) are marked with -1.
        full_haplotype = np.full(self.sample_data.num_sites, -1, dtype=int)
        full_haplotype[start:end] = store.ancestors_haplotype[ancestor_id]
        print(start, end, focal_sites, full_haplotype)

        dwg = svgwrite.Drawing(size=(self.width, self.height), debug=True)
        self.draw_matrix(dwg, focal_sites, full_haplotype)
        with open(filename_pattern.format(0), "w") as f:
            f.write(dwg.tostring())


def draw_edges(ts, width=800, height=600):
    """
    Build an SVG picture of every edge in ``ts`` and return it as a string.

    Each node gets its own horizontal level (node 0 at the bottom). An edge is
    drawn as a line along its child's level spanning ``[left, right)``, with a
    dot at each end and the parent ID written above the midpoint. Tree
    breakpoints are marked with thin grey vertical rules. At every site the
    first mutation is drawn as a red dot and any further ones as blue dots.
    """
    margin_x = 20
    margin_y = 20
    scale_x = (width - 2 * margin_x) / ts.sequence_length
    scale_y = (height - 2 * margin_y) / (ts.num_nodes + 1)

    def to_x(position):
        return margin_x + position * scale_x

    def to_y(node):
        return height - (margin_y + node * scale_y)

    svg = svgwrite.Drawing(size=(width, height), debug=True)
    edge_lines = svg.add(svg.g(id="lines", stroke="black", stroke_width=3))
    node_labels = svg.add(svg.g(font_size=14, text_anchor="start"))
    parent_labels = svg.add(svg.g(font_size=14, text_anchor="middle"))

    # Node IDs down the left-hand edge.
    for node in range(ts.num_nodes):
        node_labels.add(svg.text(str(node), (0, to_y(node))))

    # A vertical rule and a rotated coordinate label per breakpoint.
    for position in ts.breakpoints():
        x_pixel = to_x(position)
        rule = svg.line(
            (x_pixel, 2 * margin_y),
            (x_pixel, height),
            stroke="grey",
            stroke_width=1,
        )
        svg.add(rule)
        svg.add(svg.text(str(position), (x_pixel, margin_y), writing_mode="tb"))

    for edge in ts.edges():
        level = to_y(edge.child)
        start = (to_x(edge.left), level)
        stop = (to_x(edge.right), level)
        midpoint = edge.left + (edge.right - edge.left) / 2
        parent_labels.add(svg.text(str(edge.parent), (to_x(midpoint), level - 5)))
        for endpoint in (start, stop):
            svg.add(svg.circle(center=endpoint, r=3, fill="black"))
        edge_lines.add(svg.line(start, stop))

    for site in ts.sites():
        assert len(site.mutations) >= 1
        # First mutation at the site in red, every later one in blue.
        coloured = [(site.mutations[0], "red")]
        coloured.extend((later, "blue") for later in site.mutations[1:])
        for mutation, colour in coloured:
            centre = (to_x(site.position), to_y(mutation.node))
            svg.add(svg.circle(center=centre, r=1, fill=colour))

    return svg.tostring()


def draw_ancestors(ts, width=800, height=600):
    """
    Returns an SVG depiction of the ancestors in the specified tree sequence.

    Each node is given its own horizontal level (node 0 at the bottom). Every
    edgeset is drawn as a line along its *parent's* level spanning
    ``[left, right)``, with a dot at each end and the list of children
    written above the midpoint. Tree breakpoints are shown as thin grey
    vertical rules. At each site the first mutation is drawn as a red dot
    and any further mutations as blue dots; sites with no mutations are
    skipped.

    :param ts: The tree sequence to draw.
    :param width: Width of the SVG canvas in pixels.
    :param height: Height of the SVG canvas in pixels.
    :return: The SVG document as a string.
    """
    dwg = svgwrite.Drawing(size=(width, height), debug=True)
    x_pad = 20
    y_pad = 20
    x_unit = (width - 2 * x_pad) / ts.sequence_length
    y_unit = (height - 2 * y_pad) / (ts.num_nodes + 1)

    def x_trans(v):
        # Genome coordinate -> pixel x.
        return x_pad + v * x_unit

    def y_trans(v):
        # Node ID -> pixel y; larger IDs are drawn higher up.
        return height - (y_pad + v * y_unit)

    lines = dwg.add(dwg.g(id="lines", stroke="black", stroke_width=3))
    left_labels = dwg.add(dwg.g(font_size=14, text_anchor="start"))
    mid_labels = dwg.add(dwg.g(font_size=14, text_anchor="middle"))
    for u in range(ts.num_nodes):
        left_labels.add(dwg.text(str(u), (0, y_trans(u))))
    for x in ts.breakpoints():
        dwg.add(
            dwg.line(
                (x_trans(x), 2 * y_pad),
                (x_trans(x), height),
                stroke="grey",
                stroke_width=1,
            )
        )
        dwg.add(dwg.text("{}".format(x), (x_trans(x), y_pad), writing_mode="tb"))

    for e in ts.edgesets():
        a = x_trans(e.left), y_trans(e.parent)
        b = x_trans(e.right), y_trans(e.parent)
        c = x_trans(e.left + (e.right - e.left) / 2), y_trans(e.parent) - 5
        mid_labels.add(dwg.text(str(e.children), c))
        dwg.add(dwg.circle(center=a, r=3, fill="black"))
        dwg.add(dwg.circle(center=b, r=3, fill="black"))
        lines.add(dwg.line(a, b))

    for site in ts.sites():
        # Iterating (rather than indexing mutations[0]) means a site that
        # carries no mutations is simply not drawn instead of raising
        # IndexError.
        for index, mutation in enumerate(site.mutations):
            a = x_trans(site.position), y_trans(mutation.node)
            fill = "red" if index == 0 else "blue"
            dwg.add(dwg.circle(center=a, r=1, fill=fill))
    return dwg.tostring()


class Visualiser(object):
    """
    Draws ancestor and sample haplotypes, and the copying paths inferred
    between them, as raster images using PIL.

    A base image holding all haplotypes, the true breakpoints and any extra
    mutations is rendered once in the constructor. The ``draw_*`` methods
    then save that image, or a copy of it with a copying path overlaid.
    """

    def __init__(
        self, original_ts, sample_data, ancestor_data, inferred_ts, box_size=8
    ):
        """
        :param original_ts: The true tree sequence; supplies the sample
            genotypes, site positions and true breakpoints.
        :param sample_data: Stored on the instance; not otherwise read here.
        :param ancestor_data: Source of the ancestral haplotypes.
        :param inferred_ts: The inferred tree sequence whose edges define the
            copying paths.
        :param box_size: Side length in pixels of one haplotype cell.
        :raises ValueError: If any variant in ``original_ts`` has genotypes
            that sum to less than 2.
        """
        # Make sure the singletons have been removed.
        for v in original_ts.variants():
            if np.sum(v.genotypes) < 2:
                raise ValueError("Only non singletons will be considered")
        self.box_size = box_size
        self.sample_data = sample_data
        self.original_ts = original_ts
        self.inferred_ts = inferred_ts
        self.ancestor_data = ancestor_data
        self.samples = original_ts.genotype_matrix().T
        self.num_samples = self.original_ts.num_samples
        self.num_sites = self.ancestor_data.num_sites
        node_time = inferred_ts.tables.nodes.time
        # Ancestors are counted as the inferred nodes with time > 0.
        self.num_ancestors = np.where(node_time > 0)[0].shape[0]
        self.ancestors = np.zeros(
            (self.num_ancestors, original_ts.num_sites), dtype=np.uint8
        )
        for j, a in enumerate(ancestor_data.ancestors()):
            self.ancestors[j, a.start : a.end] = a.haplotype
            self.ancestors[j, : a.start] = tsinfer.UNKNOWN_ALLELE
            self.ancestors[j, a.end :] = tsinfer.UNKNOWN_ALLELE

        # TODO This only partially works for extra ancestors created by path
        # compression. We'll get -1 lines for extra ancestors created from
        # ancestors. However, extra ancestors created from matching samples
        # will break this code. We really need to just match node IDs to
        # y coordinates. Breaking up into samples and ancestors is awkward.

        # Find the site indexes for the true breakpoints
        breakpoints = list(original_ts.breakpoints())
        # Drop the first and last entries, keeping only interior breakpoints.
        self.true_breakpoints = breakpoints[1:-1]

        self.top_padding = box_size
        self.left_padding = box_size
        self.bottom_padding = box_size
        self.mid_padding = 2 * box_size
        self.right_padding = box_size
        self.background_colour = ImageColor.getrgb("white")
        self.copying_outline_colour = ImageColor.getrgb("white")
        # Cell colours keyed by allele value, for the base haplotypes ...
        self.colours = {
            255: ImageColor.getrgb("pink"),
            0: ImageColor.getrgb("blue"),
            1: ImageColor.getrgb("red"),
        }
        # ... for cells that lie on a copying path ...
        self.copy_colours = {
            255: ImageColor.getrgb("white"),
            0: ImageColor.getrgb("black"),
            1: ImageColor.getrgb("green"),
        }
        # ... and for mutations after the first at a site.
        self.error_colours = {
            0: ImageColor.getrgb("purple"),
            1: ImageColor.getrgb("orange"),
        }

        # Make the haplotype box
        num_haplotype_rows = 1
        # Maps a node index to its row in the image. Ancestors use keys
        # 0..num_ancestors-1 and samples use num_ancestors + sample index;
        # one row is left empty before each group.
        self.row_map = {0: 0}

        # print(inferred_ts.tables.nodes)
        print("Ancestors = ", self.ancestors.shape, self.num_ancestors)

        num_haplotype_rows += 1
        for j in range(self.num_ancestors):
            self.row_map[j] = num_haplotype_rows
            num_haplotype_rows += 1
        num_haplotype_rows += 1
        for j in range(self.num_samples):
            self.row_map[self.num_ancestors + j] = num_haplotype_rows
            num_haplotype_rows += 1

        self.width = box_size * self.num_sites + self.left_padding + self.right_padding
        self.height = (
            self.top_padding
            + self.bottom_padding
            + self.mid_padding
            + num_haplotype_rows * box_size
        )
        self.ts_origin = (self.left_padding, self.top_padding)
        self.haplotype_origin = (self.left_padding, self.top_padding + self.mid_padding)
        self.base_image = Image.new(
            "RGB", (self.width, self.height), color=self.background_colour
        )

        b = self.box_size
        origin = self.haplotype_origin
        # Maps a site's genome position to the x pixel of its column.
        self.x_coordinate_map = {
            site.position: origin[0] + site.id * b for site in original_ts.sites()
        }
        self.draw_base()

    @staticmethod
    def _label_size(font, label):
        """
        Return the ``(width, height)`` in pixels needed to render ``label``
        with ``font`` when the text is drawn at the origin.

        ``getbbox`` is preferred because ``getsize`` was removed from PIL
        fonts in Pillow 10; ``getsize`` is kept as a fallback for fonts that
        do not provide ``getbbox``.
        """
        if hasattr(font, "getbbox"):
            _, _, right, bottom = font.getbbox(label)
            # Image sizes must be integers; round up so nothing is clipped.
            return int(np.ceil(right)), int(np.ceil(bottom))
        return font.getsize(label)

    def draw_base(self):
        """Render the haplotypes, true breakpoints and errors on the base image."""
        draw = ImageDraw.Draw(self.base_image)
        self.draw_base_haplotypes(draw)
        self.draw_true_breakpoints(draw)
        self.draw_errors(draw)

    def draw_errors(self, draw):
        """
        Recolour the cell of every mutation after the first at each site of
        the original tree sequence, and print a line describing it.
        """
        b = self.box_size
        origin = self.haplotype_origin
        for site in self.original_ts.sites():
            for mut in site.mutations[1:]:
                row = self.row_map[self.num_ancestors + mut.node]
                y = row * b + origin[1]
                x = site.id * b + origin[0]
                fill = self.error_colours[int(mut.derived_state)]
                print("error at", site.id, mut.node, mut.derived_state)
                draw.rectangle([(x, y), (x + b, y + b)], fill=fill)

    def draw_true_breakpoints(self, draw):
        """
        Draw a purple vertical line for each true breakpoint, placed at the
        first site whose position is >= the breakpoint.
        """
        b = self.box_size
        origin = self.haplotype_origin
        coordinates = sorted(self.x_coordinate_map.keys())
        for bp in self.true_breakpoints:
            # Find the smallest coordinate >= the breakpoint. If there is
            # none, ``position`` is left at the last coordinate.
            for position in coordinates:
                if position >= bp:
                    break
            x = self.x_coordinate_map[position]
            # NOTE(review): y1 is built from origin[0] (the x origin) while
            # y2 uses origin[1] -- confirm this is intended.
            y1 = origin[0] + self.row_map[0] * b
            y2 = origin[1] + (self.row_map[len(self.row_map) - 1] + 1) * b
            draw.line([(x, y1), (x, y2)], fill="purple", width=3)

    def draw_base_haplotypes(self, draw):
        """
        Draw the row labels and one coloured cell per site for every
        ancestor and sample haplotype.
        """
        b = self.box_size
        origin = self.haplotype_origin
        for node in self.row_map.keys():
            y = self.row_map[node] * b + origin[1] + b / 2
            x = origin[0]
            # Row key on the left-hand side ...
            draw.text((x - b, y), str(node), fill="black")
            x = self.width - self.right_padding
            # ... and a second label, counted from the other end, on the right.
            mapped = (node - len(self.row_map) + 1) * -1
            if mapped < self.num_samples:
                mapped = (mapped - self.num_samples + 1) * -1
            draw.text((x + b / 4, y), str(mapped), fill="black")

        # Draw the ancestors
        for j in range(self.ancestors.shape[0]):
            a = self.ancestors[j]
            row = self.row_map[j]
            y = row * b + origin[1]
            for k in range(self.num_sites):
                x = k * b + origin[0]
                # NOTE(review): self.ancestors is uint8, so a[k] can never
                # equal -1 and every cell is drawn -- confirm intended.
                if a[k] != -1:
                    draw.rectangle([(x, y), (x + b, y + b)], fill=self.colours[a[k]])
        # Draw the samples
        for j in range(self.samples.shape[0]):
            a = self.samples[j]
            row = self.row_map[self.num_ancestors + j]
            y = row * b + origin[1]
            for k in range(self.num_sites):
                x = k * b + origin[0]
                draw.rectangle([(x, y), (x + b, y + b)], fill=self.colours[a[k]])

    def draw_haplotypes(self, filename):
        """Save the base image (with no copying path) to ``filename``."""
        self.base_image.save(filename)

    def draw_copying_path(self, filename, child_row, parents, breakpoints):
        """
        Save a copy of the base image with one copying path overlaid.

        :param filename: Path of the image file to write.
        :param child_row: ``row_map`` key of the haplotype doing the copying;
            its row is outlined.
        :param parents: For each site, the index of the parent copied from,
            or -1 for none. Each copied cell is recoloured on the parent's
            row.
        :param breakpoints: Site positions at which to draw a black vertical
            line; each must be a key of ``x_coordinate_map``.
        """
        origin = self.haplotype_origin
        b = self.box_size
        m = self.num_sites
        image = self.base_image.copy()
        draw = ImageDraw.Draw(image)
        y = self.row_map[child_row] * b + origin[1]
        x = origin[0]
        draw.rectangle(
            [(x, y), (x + m * b, y + b)], outline=self.copying_outline_colour
        )
        for k in range(m):
            if parents[k] != -1:
                row = self.row_map[parents[k]]
                y = row * b + origin[1]
                x = k * b + origin[0]
                a = self.ancestors[parents[k], k]
                draw.rectangle([(x, y), (x + b, y + b)], fill=self.copy_colours[a])

        for position in breakpoints:
            x = self.x_coordinate_map[position]
            # NOTE(review): y1 is built from origin[0] (the x origin) while
            # y2 uses origin[1] -- confirm this is intended.
            y1 = origin[0] + self.row_map[0] * b
            y2 = origin[1] + (self.row_map[len(self.row_map) - 1] + 1) * b
            draw.line([(x, y1), (x, y2)], fill="black")

        # Draw the positions of the sites.
        font = ImageFont.load_default()
        for site in self.original_ts.sites():
            label = "{} {:.6f}".format(site.id, site.position)
            # Render the label on its own small image so it can be rotated
            # before being pasted above the site's column.
            img_txt = Image.new("L", self._label_size(font, label), color="white")
            draw_txt = ImageDraw.Draw(img_txt)
            draw_txt.text((0, 0), label, font=font)
            t = img_txt.rotate(90, expand=1)
            x = origin[0] + site.id * b
            y = origin[1] - b
            image.paste(t, (x, y))
        # print("Saving", filename)
        image.save(filename)

    def draw_copying_paths(self, pattern):
        """
        Write one copying-path image for each node from 1 up to
        ``num_ancestors + num_samples - 1`` of the inferred tree sequence.

        :param pattern: Filename pattern with one ``{}`` placeholder; it is
            formatted with ``node - 1``.
        """
        N = self.num_ancestors + self.samples.shape[0]
        # P[child, site] is the parent that child copies from at that site,
        # or -1 where no edge covers the site.
        P = np.zeros((N, self.num_sites), dtype=int) - 1
        ts = self.inferred_ts
        # Translate edge coordinates (site positions) into site indexes.
        site_index = {}
        sites = list(ts.sites())
        for site in ts.sites():
            site_index[site.position] = site.id
        site_index[ts.sequence_length] = ts.num_sites
        site_index[0] = 0
        for e in ts.edges():
            left = site_index[e.left]
            right = site_index[e.right]
            assert left < right
            P[e.child, left:right] = e.parent
        n = self.samples.shape[0]
        # NOTE(review): this list is shared across iterations, so each image
        # also shows the breakpoints of all earlier children -- confirm that
        # the accumulation is intended.
        breakpoints = []
        for j in range(1, self.num_ancestors + n):
            # A breakpoint is any site where the parent differs from the
            # parent at the previous site.
            for k in np.where(P[j][1:] != P[j][:-1])[0]:
                breakpoints.append(sites[k + 1].position)
            self.draw_copying_path(pattern.format(j - 1), j, P[j], breakpoints)


def visualise(
    ts,
    recombination_rate,
    error_rate,
    engine="C",
    box_size=8,
    perfect_ancestors=False,
    path_compression=False,
    time_chunking=False,
):
    """
    Run inference on ``ts``, write one copying-path image per node into the
    ``tmp__NOBACKUP__/`` directory, and print a comparison of the true and
    inferred trees.

    The output directory is created if it does not already exist.

    :param ts: The source tree sequence.
    :param recombination_rate: Accepted but not used by this function.
    :param error_rate: Accepted but not used by this function.
    :param engine: Engine name passed to the tsinfer matching functions.
    :param box_size: Cell size in pixels passed to :class:`Visualiser`.
    :param perfect_ancestors: If true, build ancestors from ``ts`` with
        ``tsinfer.build_simulated_ancestors`` instead of generating them
        from the sample data.
    :param path_compression: Passed to the matching steps that feed the
        images; the final ``match_samples`` call always disables it.
    :param time_chunking: Passed to ``build_simulated_ancestors``; only
        relevant when ``perfect_ancestors`` is true.
    """

    sample_data = tsinfer.SampleData.from_tree_sequence(ts)

    if perfect_ancestors:
        ancestor_data = tsinfer.AncestorData(sample_data)
        tsinfer.build_simulated_ancestors(
            sample_data, ancestor_data, ts, time_chunking=time_chunking
        )
        ancestor_data.finalise()
    else:
        ancestor_data = tsinfer.generate_ancestors(sample_data, engine=engine)

    ancestors_ts = tsinfer.match_ancestors(
        sample_data,
        ancestor_data,
        engine=engine,
        path_compression=path_compression,
        extended_checks=True,
    )
    # Unsimplified result, used for drawing the copying paths.
    inferred_ts = tsinfer.match_samples(
        sample_data,
        ancestors_ts,
        engine=engine,
        simplify=False,
        path_compression=path_compression,
        extended_checks=True,
    )

    prefix = "tmp__NOBACKUP__/"
    # Saving the images fails if the scratch directory is missing.
    os.makedirs(prefix, exist_ok=True)
    visualiser = Visualiser(
        ts, sample_data, ancestor_data, inferred_ts, box_size=box_size
    )
    visualiser.draw_copying_paths(os.path.join(prefix, "copying_{}.png"))

    # tsinfer.print_tree_pairs(ts, inferred_ts, compute_distances=False)
    # Simplified result, used for the tree comparison printed below.
    inferred_ts = tsinfer.match_samples(
        sample_data,
        ancestors_ts,
        engine=engine,
        simplify=True,
        path_compression=False,
        stabilise_node_ordering=True,
    )

    tsinfer.print_tree_pairs(ts, inferred_ts, compute_distances=True)
    sys.stdout.flush()
    print(
        "num_sites = ",
        inferred_ts.num_sites,
        "num_mutations= ",
        inferred_ts.num_mutations,
    )

    for site in inferred_ts.sites():
        if len(site.mutations) > 1:
            print(
                "Multiple mutations at ",
                site.id,
                "over",
                [mut.node for mut in site.mutations],
            )


def run_viz(
    n,
    L,
    rate,
    seed,
    mutation_rate=0,
    engine="C",
    perfect_ancestors=True,
    perfect_mutations=True,
    path_compression=False,
    time_chunking=True,
    error_rate=0,
):
    """
    Simulate a tree sequence, write SVG drawings of its edges and ancestors
    into ``tmp__NOBACKUP__/`` and then run :func:`visualise` on it.

    The output directory is created if it does not already exist.

    :param n: Sample size passed to ``msprime.simulate``.
    :param L: Length and number of loci of the uniform recombination map.
    :param rate: Recombination rate of the map; also passed to
        :func:`visualise`.
    :param seed: Random seed for the simulation and for error insertion.
    :param mutation_rate: Mutation rate passed to ``msprime.simulate``.
    :param engine: Engine name passed to :func:`visualise`.
    :param perfect_ancestors: Passed to :func:`visualise`.
    :param perfect_mutations: If true, replace the mutations with
        ``tsinfer.insert_perfect_mutations``; otherwise insert errors at
        ``error_rate`` and strip singletons.
    :param path_compression: Passed to :func:`visualise`.
    :param time_chunking: Passed to :func:`visualise`.
    :param error_rate: Error rate used only when ``perfect_mutations`` is
        false. :func:`visualise` is always called with an error rate of 0.
    """
    recomb_map = msprime.RecombinationMap.uniform_map(length=L, rate=rate, num_loci=L)
    ts = msprime.simulate(
        n,
        recombination_map=recomb_map,
        random_seed=seed,
        model="smc_prime",
        mutation_rate=mutation_rate,
    )
    if perfect_mutations:
        ts = tsinfer.insert_perfect_mutations(ts, delta=1 / 512)
    else:
        ts = tsinfer.strip_singletons(tsinfer.insert_errors(ts, error_rate, seed))
    print("num_sites = ", ts.num_sites)

    # open() below fails with FileNotFoundError if the scratch directory is
    # missing, so make sure it exists first.
    os.makedirs("tmp__NOBACKUP__", exist_ok=True)
    with open("tmp__NOBACKUP__/edges.svg", "w") as f:
        f.write(draw_edges(ts))
    with open("tmp__NOBACKUP__/ancestors.svg", "w") as f:
        f.write(draw_ancestors(ts))
    visualise(
        ts,
        rate,
        0,
        engine=engine,
        box_size=26,
        perfect_ancestors=perfect_ancestors,
        path_compression=path_compression,
        time_chunking=time_chunking,
    )


def visualise_ancestors():
    """
    Simulate a small data set, generate its ancestors and write an SVG of
    ancestor 6 to a file named from the pattern ``ancestors_{}.svg``.
    """
    simulated = msprime.simulate(
        10, mutation_rate=2, recombination_rate=2, random_seed=3
    )
    without_singletons = tsinfer.strip_singletons(simulated)
    samples = tsinfer.SampleData.from_tree_sequence(without_singletons)
    ancestors = tsinfer.generate_ancestors(samples)

    AncestorBuilderViz(samples, ancestors).draw(6, "ancestors_{}.svg")


def main():
    """
    Script entry point: run the copying-path visualisation with the Python
    engine on a 15-sample simulation, using generated (not perfect)
    ancestors.
    """
    # Alternative runs, kept for reference:
    #
    # visualise_ancestors()
    #
    # run_viz(
    #     15, 1000, 0.0020, 11, mutation_rate=0.02, perfect_ancestors=True,
    #     perfect_mutations=True, time_chunking=True, engine="C", path_compression=False,
    #     error_rate=0.00)

    run_viz(
        n=15,
        L=1000,
        rate=0.002,
        seed=2,
        engine=tsinfer.PY_ENGINE,
        perfect_ancestors=False,
    )


# Run the visualisation only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    main()