#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Script to calculate FLOPs & PARAMs of a tf.keras model.
Compatible with TF 1.x and TF 2.x
"""
import os
import sys
import argparse

# TensorFlow's C++ logging may already print (and read its level) while
# `import tensorflow` runs, so the level must be set *before* that import to
# reliably take effect. '2' filters out INFO and WARNING messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow.keras.models import load_model
# Make the directory two levels above this script importable (presumably the
# repository root) so that `common.utils` resolves when run as a script.
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', '..'))
from common.utils import get_custom_objects

# The profiling code below uses the TF1 graph-mode API (tf.RunMetadata,
# tf.get_default_graph, tf.profiler.profile). Under TF 2.x, rebind `tf` to the
# v1 compatibility module and disable eager execution so the model is built
# into a static graph that the profiler can inspect.
# NOTE(review): Keras 3 (TF >= 2.16) presumably no longer honours graph mode
# via tf.keras; this path likely needs TF <= 2.15 or tf_keras -- verify.
if tf.__version__.startswith('2'):
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

def get_flops(model):
    """Profile a graph-mode tf.keras model and print its FLOPs and parameter count.

    The static graph containing ``model`` is run through the TF1 profiler
    twice: once with the float-operation options and once with the
    trainable-parameter options.

    Args:
        model: a tf.keras model built in graph mode (eager execution is
            disabled at import time under TF 2.x).

    Returns:
        A ``(total_float_ops, total_parameters)`` tuple with the raw totals
        reported by ``tf.profiler``. The same values are printed in millions.
    """
    # Profile the graph the model's output tensors actually live in, so the
    # result does not depend on which graph happens to be the default. Fall
    # back to the default graph when that cannot be determined (no outputs,
    # or tensors without a `.graph` attribute).
    outputs = getattr(model, 'outputs', None) or []
    graph = getattr(outputs[0], 'graph', None) if outputs else None
    if graph is None:
        graph = tf.get_default_graph()

    # No session is run; an empty RunMetadata makes the profiler report
    # statically derived statistics only.
    run_meta = tf.RunMetadata()

    # NOTE(review): float_operation() counts are derived statically from op
    # shapes; ops whose shapes are not fully defined (e.g. a None batch
    # dimension) may be left out of the total -- confirm for your model.
    flops_opts = tf.profiler.ProfileOptionBuilder.float_operation()
    flops = tf.profiler.profile(graph=graph, run_meta=run_meta, cmd='op', options=flops_opts)

    params_opts = tf.profiler.ProfileOptionBuilder.trainable_variables_parameter()
    params = tf.profiler.profile(graph=graph, run_meta=run_meta, cmd='op', options=params_opts)

    print('Total FLOPs: {}m float_ops'.format(flops.total_float_ops/1e6))
    print('Total PARAMs: {}m'.format(params.total_parameters/1e6))
    return flops.total_float_ops, params.total_parameters


def main():
    """Command-line entry point: load the model named by --model_path and report FLOPs/PARAMs."""
    arg_parser = argparse.ArgumentParser(description='tf.keras model FLOPs & PARAMs checking tool')
    arg_parser.add_argument('--model_path', type=str, required=True, help='model file to evaluate')
    cli_args = arg_parser.parse_args()

    # The model is loaded uncompiled (compile=False); only the graph is needed
    # for profiling. The project's custom layers/objects are registered so
    # deserialization can resolve them.
    loaded_model = load_model(
        cli_args.model_path,
        compile=False,
        custom_objects=get_custom_objects(),
    )

    get_flops(loaded_model)


# Run the tool only when executed as a script, not when imported.
if __name__ == '__main__':
    main()