import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.pyplot
from itertools import cycle

# this is so I can plot graphics on a headless server

matplotlib.pyplot.ioff()

HORIZONTAL_LEVELS = [1,2,3,4]


class _Plot(object):
    def __init__(self, filename='', height=0, width=0, dpi=0, fontsize=12,
                 scale=0):

        self.filename = filename
        self.scale = scale
        self.width = width
        self.height = height
        self.dpi = dpi
        self.fontsize = fontsize

        self.fig = plt.figure(
            figsize=(self.width, self.height), dpi=self.dpi, frameon=False
        )
        self.ax = self.fig.add_subplot(111)
        self.rr = self.fig.canvas.get_renderer()

    def save(self):

        self.fig.savefig(
            self.filename,
            dpi=self.dpi,
            bbox_inches='tight'
        )

        plt.close(self.fig)
        plt.clf()

    def _scale(self, seq_length):
        """
        scale the sequence (protein or DNA)
        """

        if self.scale is None or self.scale < seq_length:
            self.normalize = seq_length
        else:
            self.normalize = self.scale

        self.offset = 0.05 + (1.0 - float(seq_length)/self.normalize)*0.45

        assert self.normalize >= seq_length, "length normalization should be >= protein length"


class _PlotExons(_Plot):
    def __init__(self, *args, **kwargs):
        super(_PlotExons, self).__init__(*args, **kwargs)
        self.vertical_offset = 0.15

    def _draw_length_markers(self, basepair_length):
        # plot protein length markers

        self.basepair_length = basepair_length/float(self.normalize)*0.9

        self.line_end = basepair_length/float(self.normalize)*0.9 + self.offset

        self.ax.text(
            0.5,
            0.1,
            "Base pair position (kbp)",
            horizontalalignment='center',
            fontsize=self.fontsize
        )

        self.ax.add_line(plt.Line2D(
            (
                self.offset,
                self.offset+self.basepair_length
            ),
            (0.2+self.vertical_offset, 0.2+self.vertical_offset),
            color='black'
            )
        )

        # left marker

        self.left_marker_line = self.ax.add_line(plt.Line2D(
            (
                self.offset,
                self.offset
            ),
            (0.15+self.vertical_offset, 0.2+self.vertical_offset),
            color='black'
            )
        )
        self.left_marker_text = self.ax.text(
            self.offset,
            0.05+self.vertical_offset,
            "0",
            horizontalalignment='center',
            fontsize=self.fontsize
        )

        # draw markers for increments of 1000 base pairs

        for i in range(1, basepair_length+1):
            if (i % 10000) == 0:
                self.left_marker_line = self.ax.add_line(plt.Line2D(
                    (
                        self.offset+(i/float(self.normalize)*0.9),
                        self.offset+(i/float(self.normalize)*0.9)
                    ),
                    (0.175+self.vertical_offset, 0.2+self.vertical_offset),
                    color='black'
                    )
                )

        # right marker

        self.right_marker_line = self.ax.add_line(plt.Line2D(
            (
                self.offset+self.basepair_length,
                self.offset+self.basepair_length
            ),
            (0.15+self.vertical_offset, 0.2+self.vertical_offset),
            color='black'
            )
        )
        self.right_marker_text = self.ax.text(
            self.offset+self.basepair_length,
            0.05+self.vertical_offset,
            str(basepair_length/1000),
            horizontalalignment='center',
            fontsize=self.fontsize
        )


class PlotWTExons(_PlotExons):
    def __init__(self, ensembl_transcript, *args, **kwargs):
        super(PlotWTExons, self).__init__(*args, **kwargs)
        self.ensembl_transcript = ensembl_transcript

    def _draw_main_body(self, name_symbols, name_isoform):
        """
        main protein frame
        """

        length = (self.ensembl_transcript.end-self.ensembl_transcript.start)/float(self.normalize)*0.9

        self.ax.add_line(plt.Line2D(
            (
                self.offset,
                self.offset+length
            ),
            (0.5, 0.5),
            color='black'
            )
        )

        self.ax.text(
            0.5,
            0.9,
            name_symbols,
            horizontalalignment='center',
            fontsize=self.fontsize
        )
        self.ax.text(
            0.5,
            0.83,
            name_isoform,
            horizontalalignment='center',
            fontsize=self.fontsize-3
        )

    def _draw_exons(self):
        for exon in self.ensembl_transcript.exon_intervals:

            if self.ensembl_transcript.strand == '+':
                start = exon[0] - self.ensembl_transcript.start
                end = exon[1] - self.ensembl_transcript.start
            else:

                # this is so the transcription direction is not plotted
                # in reverse for genes on minus strand

                start = -(exon[1] - self.ensembl_transcript.end)
                end = -(exon[0] - self.ensembl_transcript.end)

            exon_start = (int(start)/float(self.normalize))*0.9 + self.offset
            exon_end = (int(end)/float(self.normalize))*0.9 + self.offset
            exon_center = (exon_end-exon_start)/2. + exon_start

            self.ax.add_patch(
                patches.Rectangle(
                    (
                        exon_start,
                        0.45,
                    ),
                    exon_end-exon_start,
                    0.1,
                    color="black"
                )
            )


    def draw(self):
        self._scale(self.ensembl_transcript.end-self.ensembl_transcript.start)
        self._draw_exons()
        self._draw_length_markers(self.ensembl_transcript.end-self.ensembl_transcript.start)
        self._draw_main_body(
            self.ensembl_transcript.gene.gene_name,
            self.ensembl_transcript.id
        )
        self.ax.axis('off')
        self.ax.set_xlim(0, 1)
        self.ax.set_ylim(0, 1)


class PlotFusionExons(_PlotExons):
    def __init__(self, transcript, *args, **kwargs):
        super(PlotFusionExons, self).__init__(*args, **kwargs)
        self.transcript = transcript

    def _draw_fusion_junction(self, junction_location):

        junction_location_norm = junction_location/float(self.normalize)*0.9

        self.ax.add_line(plt.Line2D(
            (
                self.offset+junction_location_norm,
                self.offset+junction_location_norm
            ),
            (0.15+self.vertical_offset, 0.2+self.vertical_offset),
            color='black'
            )
        )
        self.right_marker_text = self.ax.text(
            self.offset+junction_location_norm,
            0.05+self.vertical_offset,
            str(junction_location/1000),
            horizontalalignment='center',
            fontsize=self.fontsize
        )

    def _draw_exons(self):
        for exon in self.transcript.gene5prime_exon_intervals:

            if self.transcript.transcript1.strand == '+':
                start = exon[0] - self.transcript.transcript1.start
                end = exon[1] - self.transcript.transcript1.start
            else:

                # this is so the transcription direction is not plotted
                # in reverse for genes on minus strand

                start = -(exon[1] - self.transcript.transcript1.end)
                end = -(exon[0] - self.transcript.transcript1.end)

            exon_start = (int(start)/float(self.normalize))*0.9 + self.offset
            exon_end = (int(end)/float(self.normalize))*0.9 + self.offset
            exon_center = (exon_end-exon_start)/2. + exon_start

            self.ax.add_patch(
                patches.Rectangle(
                    (
                        exon_start,
                        0.45,
                    ),
                    exon_end-exon_start,
                    0.1,
                    color="black"
                )
            )

        if self.transcript.transcript1.strand == '+':
            distance_to_add = self.transcript.gene5prime.junction - \
                self.transcript.transcript1.start
        else:
            distance_to_add = self.transcript.transcript1.end - \
                self.transcript.gene5prime.junction

        for exon in self.transcript.gene3prime_exon_intervals:

            if self.transcript.transcript2.strand == '+':
                start = exon[0] - self.transcript.gene3prime.junction + \
                    distance_to_add
                end = exon[1] - self.transcript.gene3prime.junction + \
                    distance_to_add
            else:

                # this is so the transcription direction is not plotted
                # in reverse for genes on minus strand

                start = (self.transcript.gene3prime.junction - exon[1]) + \
                    distance_to_add
                end = (self.transcript.gene3prime.junction - exon[0]) + \
                    distance_to_add

            exon_start = (int(start)/float(self.normalize))*0.9 + self.offset
            exon_end = (int(end)/float(self.normalize))*0.9 + self.offset
            exon_center = (exon_end-exon_start)/2. + exon_start

            self.ax.add_patch(
                patches.Rectangle(
                    (
                        exon_start,
                        0.45,
                    ),
                    exon_end-exon_start,
                    0.1,
                    color="red"
                )
            )

    def _draw_main_body(self, name_symbols, name_isoform, length):
        """
        main protein frame
        """

        gene5prime_length = 0
        gene3prime_length = 0

        if self.transcript.transcript1.strand == '+':
            gene5prime_length = (self.transcript.gene5prime.junction -
                                 self.transcript.transcript1.start) \
                                 / float(self.normalize)*0.9
        else:
            gene5prime_length = (self.transcript.transcript1.end -
                                 self.transcript.gene5prime.junction) \
                                 / float(self.normalize)*0.9

        if self.transcript.transcript2.strand == '+':
            gene3prime_length = (self.transcript.transcript2.end -
                                 self.transcript.gene3prime.junction) \
                                 / float(self.normalize)*0.9
        else:
            gene3prime_length = (self.transcript.gene3prime.junction -
                                 self.transcript.transcript2.start) \
                                 / float(self.normalize)*0.9

        self.ax.add_line(plt.Line2D(
            (
                self.offset,
                self.offset + gene5prime_length
            ),
            (0.5, 0.5),
            color='black'
            )
        )
        self.ax.add_line(plt.Line2D(
            (
                self.offset + gene5prime_length,
                self.offset + gene5prime_length + gene3prime_length
            ),
            (0.5, 0.5),
            color='red'
            )
        )

        self.ax.text(
            0.5,
            0.9,
            name_symbols,
            horizontalalignment='center',
            fontsize=self.fontsize
        )
        self.ax.text(
            0.5,
            0.83,
            name_isoform,
            horizontalalignment='center',
            fontsize=self.fontsize-3
        )

    def draw(self):

        if self.transcript.transcript1.strand == '+':
            gene5prime_length = self.transcript.gene5prime.junction - \
                self.transcript.transcript1.start
        else:
            gene5prime_length = self.transcript.transcript1.end - \
                self.transcript.gene5prime.junction

        if self.transcript.transcript2.strand == '+':
            gene3prime_length = self.transcript.transcript2.end - \
                self.transcript.gene3prime.junction
        else:
            gene3prime_length = self.transcript.gene3prime.junction - \
                self.transcript.transcript2.start

        self._scale(gene5prime_length+gene3prime_length)
        self._draw_exons()
        self._draw_length_markers(gene5prime_length+gene3prime_length)
        self._draw_fusion_junction(gene5prime_length)
        self._draw_main_body(
            self.transcript.transcript1.gene.gene_name + '-' +
            self.transcript.transcript2.gene.gene_name,
            self.transcript.transcript1.id + '-' +
            self.transcript.transcript2.id,
            gene5prime_length+gene3prime_length
        )
        self.ax.axis('off')
        self.ax.set_xlim(0, 1)
        self.ax.set_ylim(0, 1)


class _PlotProtein(_Plot):

    def __init__(self, transcript=None, colors=None,
                 rename=None, no_domain_labels=False,
                 exclude = [], *args, **kwargs):
        super(_PlotProtein, self).__init__(*args, **kwargs)
        self.transcript = transcript
        self.colors = colors
        self.rename = rename
        self.no_domain_labels = no_domain_labels
        self.exclude = exclude
        self.vertical_offset = 0.55

    def _draw_domains(self, domains):
        # plot domains

        domain_labels = {i:[] for i in HORIZONTAL_LEVELS}
        domain_labels_levels = {}
        domain_label_boxes = {i:[] for i in HORIZONTAL_LEVELS}
        lowest_level_plotted = HORIZONTAL_LEVELS[0]

        domains.sort(key=lambda x: x[3])

        domain_count = 0

        for domain in domains:

            # use domain name if available, otherwise use its ID

            if domain[1] is None:
                domain_name = str(domain[0])
            else:
                domain_name = str(domain[1])

            if domain_name in self.exclude:
                continue

            if domain_name in self.rename:
                domain_name = self.rename[domain_name]

            domain_start = (int(domain[3])/float(self.normalize))*0.9 + self.offset
            domain_end = (int(domain[4])/float(self.normalize))*0.9 + self.offset
            domain_center = (domain_end-domain_start)/2. + domain_start

            if not self.no_domain_labels:

                # for each newly plotted domain, loop through all previous
                # plotted domains and calculated the extent of overlap
                # then horizontally adjust the domain label to be plotted
                # closest to the protein domain structure or be
                # on the same level as the label it overlaps with the least

                #domain_stack_level = cycle()
                overlaps = {i:0.0 for i in HORIZONTAL_LEVELS}
                overlaps_all_levels = True
                min_overlap = [float("inf"),HORIZONTAL_LEVELS[0]]
                # plot domain at 1st level

                for level in HORIZONTAL_LEVELS:

                    level_pos = self.vertical_offset - 0.15 - (level-1.0)*0.1

                    tmp_domain_label = self.ax.text(
                        domain_center,
                        level_pos,
                        domain_name,
                        horizontalalignment='center',
                        verticalalignment='center',
                        fontsize=self.fontsize
                    )
                    tmp_domain_label_box = tmp_domain_label.get_window_extent(renderer=self.rr)

                    #check to see if it overlaps with anything

                    if len(domain_label_boxes[level])>0:
                        max_overlap = max([i.x1 - tmp_domain_label_box.x0 for i in domain_label_boxes[level]])

                        if max_overlap > 0.0:
                            overlaps[level] = max_overlap
                            tmp_domain_label.remove()

                            if max_overlap <= min_overlap[0]:
                                min_overlap = [max_overlap, level]
                        else:
                            domain_labels[level].append(tmp_domain_label)
                            domain_label_boxes[level].append(tmp_domain_label_box)
                            overlaps_all_levels = False

                            domain_labels_levels[domain_count] = level

                            if level > lowest_level_plotted:
                                lowest_level_plotted = level

                            break
                    else:
                        domain_labels[level].append(tmp_domain_label)
                        domain_label_boxes[level].append(tmp_domain_label_box)
                        overlaps_all_levels = False

                        domain_labels_levels[domain_count] = level

                        if level > lowest_level_plotted:
                            lowest_level_plotted = level

                        break

                # if the domain label overlaps with something on all levels
                # then plot it on the level where is overlaps the least

                if overlaps_all_levels:

                    level_pos = self.vertical_offset - 0.15 - (min_overlap[1]-1.0)*0.1

                    tmp_domain_label = self.ax.text(
                        domain_center,
                        level_pos,
                        domain_name,
                        horizontalalignment='center',
                        verticalalignment='center',
                        fontsize=self.fontsize
                    )
                    tmp_domain_label_box = tmp_domain_label.get_window_extent(renderer=self.rr)

                    domain_labels[min_overlap[1]].append(tmp_domain_label)
                    domain_label_boxes[min_overlap[1]].append(tmp_domain_label_box)

                    domain_labels_levels[domain_count] = level

                    if min_overlap[1] > lowest_level_plotted:
                        lowest_level_plotted = min_overlap[1]
            domain_count += 1

        # now we know how many levels of domains labels are needed, so
        # remove all levels, make the correction to self.vertical_offset
        # and replot all labels.

        for level, label in list(domain_labels.items()):
            for ll in label:
                ll.remove()

        self.levels_plotted = HORIZONTAL_LEVELS.index(lowest_level_plotted)
        self.vertical_offset += (0.05 * self.levels_plotted)

        domain_count = 0

        for domain in domains:

            if domain[1] is None:
                domain_name = str(domain[0])
            else:
                domain_name = str(domain[1])

            if domain_name in self.exclude:
                continue

            if domain_name in self.rename:
                domain_name = self.rename[domain_name]

            color = '#3385ff'
            if domain_name in self.colors:
                color = self.colors[domain_name]

            domain_start = (int(domain[3])/float(self.normalize))*0.9 + self.offset
            domain_end = (int(domain[4])/float(self.normalize))*0.9 + self.offset
            domain_center = (domain_end-domain_start)/2. + domain_start

            self.ax.add_patch(
                patches.Rectangle(
                    (
                        domain_start,
                        self.vertical_offset,
                    ),
                    domain_end - domain_start,
                    0.1,
                    color=color
                )
            )

            # fetch the level the domain label was determined it was to be
            # plotted on

            level = domain_labels_levels[domain_count]

            level_pos = self.vertical_offset - 0.15 - (level-1.0)*0.1

            tmp_domain_label = self.ax.text(
                domain_center,
                level_pos,
                domain_name,
                horizontalalignment='center',
                verticalalignment='center',
                fontsize=self.fontsize
            )

            domain_count += 1


    def _draw_protein_length_markers(self, protein_length):
        # plot protein length markers

        self.line_end = protein_length/float(self.normalize)*0.9 + self.offset

        self.ax.text(
            0.5,
            self.vertical_offset - (0.5 + self.levels_plotted * 0.1),
            "Amino acid position",
            horizontalalignment='center',
            verticalalignment='center',
            fontsize=self.fontsize
        )

        self.ax.add_line(plt.Line2D(
            (
                self.offset,
                self.offset+self.protein_frame_length
            ),
            (
                self.vertical_offset - (0.35 + self.levels_plotted * 0.05),
                self.vertical_offset - (0.35 + self.levels_plotted * 0.05)
            ),
            color='black'
            )
        )

        # left marker

        self.left_marker_line = self.ax.add_line(plt.Line2D(
            (
                self.offset,
                self.offset
            ),
            (
                self.vertical_offset - (0.38 + self.levels_plotted * 0.05),
                self.vertical_offset - (0.35 + self.levels_plotted * 0.05)
            ),
            color='black'
            )
        )
        self.left_marker_text = self.ax.text(
            self.offset,
            self.vertical_offset - (0.43 + self.levels_plotted * 0.05),
            "0",
            horizontalalignment='center',
            verticalalignment='center',
            fontsize=self.fontsize
        )

        # draw markers for increments of 100 amino acids

        for i in range(1, protein_length+1):
            if (i % 100) == 0:
                self.left_marker_line = self.ax.add_line(plt.Line2D(
                    (
                        self.offset+(i/float(self.normalize)*0.9),
                        self.offset+(i/float(self.normalize)*0.9)
                    ),
                    (
                        self.vertical_offset - (0.38 + self.levels_plotted * 0.05),
                        self.vertical_offset - (0.35 + self.levels_plotted * 0.05)
                    ),
                    color='black'
                    )
                )

        # right marker

        self.right_marker_line = self.ax.add_line(plt.Line2D(
            (
                self.offset+self.protein_frame_length,
                self.offset+self.protein_frame_length
            ),
            (
                self.vertical_offset - (0.38 + self.levels_plotted * 0.05),
                self.vertical_offset - (0.35 + self.levels_plotted * 0.05)
            ),
            color='black'
            )
        )

        self.right_marker_text = self.ax.text(
            self.offset+self.protein_frame_length,
            self.vertical_offset - (0.43 + self.levels_plotted * 0.05),
            str(protein_length),
            horizontalalignment='center',
            verticalalignment='center',
            fontsize=self.fontsize
        )

    def _draw_main_body(self, name_symbols, name_isoform):
        """
        main protein frame
        """

        self.ax.add_patch(
            patches.Rectangle(
                (self.offset, self.vertical_offset),
                self.protein_frame_length,
                0.1,
                fill=False
            )
        )

        self.ax.text(
            0.5,
            0.95,
            name_symbols,
            horizontalalignment='center',
            fontsize=self.fontsize
        )

        self.ax.text(
            0.5,
            0.88,
            name_isoform,
            horizontalalignment='center',
            fontsize=self.fontsize-3
        )


class PlotFusionProtein(_PlotProtein):
    def __init__(self, *args, **kwargs):
        super(PlotFusionProtein, self).__init__(*args, **kwargs)

    def _draw_junction(self):
        # add the junction

        self.ax.add_line(plt.Line2D(
            (
                (self.transcript.transcript_protein_junction_5prime/float(self.normalize))*0.9 + self.offset,
                (self.transcript.transcript_protein_junction_5prime/float(self.normalize))*0.9 + self.offset
            ),
            (
                self.vertical_offset - 0.05,
                self.vertical_offset + 0.15
            ),
            color='black'
            )
        )

        # middle marker, loop until it does not overlap with right marker

        overlaps = True
        line_offset = (self.transcript.transcript_protein_junction_5prime/float(self.normalize))*0.9 + self.offset
        text_offset = (self.transcript.transcript_protein_junction_5prime/float(self.normalize))*0.9 + self.offset
        junction_label_vertical_offset = 0.0

        right_marker_text_box = self.right_marker_text.get_window_extent(renderer=self.rr)
        left_marker_text_box = self.left_marker_text.get_window_extent(renderer=self.rr)

        while overlaps:

            # middle_marker_line_1/2/3 are to draw angled line

            middle_marker_line_1 = self.ax.add_line(plt.Line2D(
                (
                    (self.transcript.transcript_protein_junction_5prime/float(self.normalize))*0.9 + self.offset,
                    (self.transcript.transcript_protein_junction_5prime/float(self.normalize))*0.9 + self.offset
                ),
                (
                    self.vertical_offset - (0.37 + self.levels_plotted * 0.05) - junction_label_vertical_offset,
                    self.vertical_offset - (0.35 + self.levels_plotted * 0.05)
                ),
                color='black'
                )
            )

            middle_marker_line_2 = self.ax.add_line(plt.Line2D(
                (
                    line_offset,
                    (self.transcript.transcript_protein_junction_5prime/float(self.normalize))*0.9 + self.offset
                ),
                (
                    self.vertical_offset - (0.37 + self.levels_plotted * 0.05) - junction_label_vertical_offset,
                    self.vertical_offset - (0.37 + self.levels_plotted * 0.05) - junction_label_vertical_offset
                ),
                color='black'
                )
            )

            middle_marker_line_3 = self.ax.add_line(plt.Line2D(
                (
                    line_offset,
                    line_offset
                ),
                (
                    self.vertical_offset - (0.4 + self.levels_plotted * 0.05) - junction_label_vertical_offset,
                    self.vertical_offset - (0.37 + self.levels_plotted * 0.05) - junction_label_vertical_offset
                ),
                color='black'
                )
            )

            middle_marker_text = self.ax.text(
                text_offset,
                self.vertical_offset - (0.45 + self.levels_plotted * 0.05) - junction_label_vertical_offset,
                str(self.transcript.transcript_protein_junction_5prime),
                horizontalalignment='center',
                verticalalignment='center',
                fontsize=self.fontsize
            )

            # detect if text overlaps

            middle_marker_text_box = middle_marker_text.get_window_extent(
                renderer=self.rr
            )

            # if overlaps then offset the junction text to the left

            if (right_marker_text_box.fully_overlaps(middle_marker_text_box)) and (left_marker_text_box.fully_overlaps(middle_marker_text_box)):
                junction_label_vertical_offset = junction_label_vertical_offset + 0.01

                middle_marker_line_1.remove()
                middle_marker_line_2.remove()
                middle_marker_line_3.remove()
                middle_marker_text.remove()

            elif right_marker_text_box.fully_overlaps(middle_marker_text_box):
                line_offset = line_offset - 0.01
                text_offset = text_offset - 0.01

                middle_marker_line_1.remove()
                middle_marker_line_2.remove()
                middle_marker_line_3.remove()
                middle_marker_text.remove()
            elif left_marker_text_box.fully_overlaps(middle_marker_text_box):
                line_offset = line_offset + 0.01
                text_offset = text_offset + 0.01

                middle_marker_line_1.remove()
                middle_marker_line_2.remove()
                middle_marker_line_3.remove()
                middle_marker_text.remove()
            else:
                overlaps = False

    def draw(self):
        self._scale(self.transcript.protein_length)
        self.protein_frame_length = self.transcript.protein_length/float(self.normalize)*0.9
        self._draw_domains(self.transcript.domains['fusion'])
        self._draw_protein_length_markers(self.transcript.protein_length)
        self._draw_junction()

        name_symbols = self.transcript.gene5prime.gene.gene_name + ' - ' + \
            self.transcript.gene3prime.gene.gene_name
        name_isoform = self.transcript.transcript1.id + ' - ' + \
            self.transcript.transcript2.id
        self._draw_main_body(name_symbols, name_isoform)

        self.ax.axis('off')
        self.ax.set_xlim(0, 1)
        self.ax.set_ylim(0, 1)


class PlotWTProtein(_PlotProtein):
    def __init__(self, ensembl_transcript, *args, **kwargs):
        super(PlotWTProtein, self).__init__(*args, **kwargs)
        self.ensembl_transcript = ensembl_transcript

    def draw(self):
        self._scale(len(self.ensembl_transcript.coding_sequence)/3)
        self.protein_frame_length = len(self.ensembl_transcript.coding_sequence)/3/float(self.normalize)*0.9
        self._draw_domains(self.transcript.domains[self.ensembl_transcript.id])
        self._draw_protein_length_markers(int(len(self.ensembl_transcript.coding_sequence)/3))
        self._draw_main_body(
            self.ensembl_transcript.gene.gene_name,
            self.ensembl_transcript.id
        )

        self.ax.axis('off')
        self.ax.set_xlim(0, 1)
        self.ax.set_ylim(0, 1)