# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time

from flask import (
    Flask, abort, request, redirect, url_for, render_template, g,
    send_from_directory)
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.sql.expression import func
from PIL import Image, ImageDraw, ImageFont

from configuration import (
    get_args, get_db_uri, get_templates_list,
    BASE_DIR, MEME_DIR, FONT_PATH)

app = Flask(__name__)
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
app.config['SQLALCHEMY_DATABASE_URI'] = get_db_uri()
db = SQLAlchemy(app)


# Model for representing created Memes
class Meme(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    template = db.Column(db.String(80), nullable=False)
    top_text = db.Column(db.String(80), nullable=False)
    bot_text = db.Column(db.String(80), nullable=False)

    def __repr__(self):
        return '<Meme %r>' % self.id


@app.before_first_request
def setup_db():
    # Create folder for memes if it doesn't exist
    if not os.path.exists(MEME_DIR):
        os.makedirs(MEME_DIR)
    # Create tables for models if they don't exist
    db.create_all()


@app.before_request
def setup_request_time():
    start_time = time.time()
    g.request_time = lambda: "%d ms" % ((time.time() - start_time) * 1000)


@app.route('/')
def index():
    return redirect(url_for("get_create_menu"))


@app.route('/recent', methods=['GET'])
def view_recent():
    memes = Meme.query.order_by(Meme.id.desc()).limit(20).all()
    return render_template('recent.html', memes=memes)


@app.route('/random', methods=['GET'])
def view_random():
    meme = Meme.query.order_by(func.random()).first()
    return redirect(url_for('view_meme', meme_id=meme.id))


@app.route('/template', methods=['GET'])
def get_create_menu():
    templates = get_templates_list()
    return render_template('view.html', templates=templates)


@app.route('/template/<string:template>', methods=['GET'])
def get_create(template):
    if template not in get_templates_list():
        abort(400, "Template does not exist.")
    return render_template('create_meme.html', template=template)


@app.route('/meme/<int:meme_id>', methods=['GET'])
def view_meme(meme_id):
    meme_file = os.path.join(MEME_DIR, '%d.png' % meme_id)
    if not os.path.exists(meme_file):
        generate_meme(meme_file, meme_id)
    print(meme_file)
    return send_from_directory(MEME_DIR, '%d.png' % meme_id)


@app.route('/meme', methods=['POST'])
def create_meme():
    try:
        meme = Meme(
            template=request.form['template'],
            top_text=request.form['top'],
            bot_text=request.form['bottom']
        )
        db.session.add(meme)
        db.session.commit()

        return redirect(url_for('view_meme', meme_id=meme.id))
    except KeyError:
        abort(400, "Incorrect parameters.")


def generate_meme(file, meme_id):
    # Query for meme
    meme = Meme.query.filter(Meme.id == meme_id).first()
    if meme is None:
        abort(400, 'Meme does not exist.')
    # Load template
    template_file = os.path.join(
        BASE_DIR, 'static', 'templates', meme.template)
    if not os.path.exists(template_file):
        abort(400, 'Template does not exist')
    template = Image.open(template_file)
    # Get Font Details
    font, top_loc, bot_loc = calc_font_details(
        meme.top_text, meme.bot_text, template.size)
    draw = ImageDraw.Draw(template)
    draw_text(draw, top_loc[0], top_loc[1], meme.top_text, font)
    draw_text(draw, bot_loc[0], bot_loc[1], meme.bot_text, font)

    template.save(file)


# Calculate font size and location
def calc_font_details(top, bot, img_size):
    font_size = 50
    font = ImageFont.truetype(FONT_PATH, font_size)
    max_width = img_size[0] - 20
    # Get ideal font size
    while font.getsize(top)[0] > max_width or font.getsize(bot)[0] > max_width:
        font_size = font_size - 1
        font = ImageFont.truetype(FONT_PATH, font_size)
    # Get font locations
    top_loc = ((img_size[0] - font.getsize(top)[0])/2, -5)
    bot_size = font.getsize(bot)
    bot_loc = ((img_size[0] - bot_size[0])/2, img_size[1] - bot_size[1] - 5)
    return font, top_loc, bot_loc


# Draws the given text with a border
def draw_text(draw, x, y, text, font):
    # Draw border
    draw.text((x-1, y-1), text, font=font, fill="black")
    draw.text((x+1, y-1), text, font=font, fill="black")
    draw.text((x-1, y+1), text, font=font, fill="black")
    draw.text((x+1, y+1), text, font=font, fill="black")
    # Draw text
    draw.text((x, y), text, font=font, fill="white")


if __name__ == '__main__':
    # Run dev server (for debugging only)
    args = get_args()
    app.run(host=args.host, port=args.port, debug=True)