/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.graphx.util

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import scala.collection.mutable.HashSet
import scala.language.existentials

import org.apache.spark.util.Utils

import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor}
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._


/**
 * Includes a utility function to test whether a function accesses a specific attribute
 * of an object.
 */
private[graphx] object BytecodeUtils {

  /**
   * Test whether the given closure invokes the specified method in the specified class.
   *
   * The closure's own `apply` method is analyzed first, following the methods it invokes
   * transitively. If that finds nothing, every closure stored in one of its
   * `scala.Function*`-typed fields is analyzed the same way.
   *
   * @param closure the function object whose bytecode is inspected
   * @param targetClass the class owning the method being looked for
   * @param targetMethod the name of the method being looked for
   * @return true if an invocation of `targetMethod` on `targetClass` was found
   */
  def invokedMethod(closure: AnyRef, targetClass: Class[_], targetMethod: String): Boolean = {
    _invokedMethod(closure.getClass, "apply", targetClass, targetMethod) || {
      // look at closures enclosed in this closure
      closure.getClass.getDeclaredFields
        .filter(_.getType.getName.startsWith("scala.Function"))
        .exists { f =>
          f.setAccessible(true)
          val enclosed = f.get(closure)
          // A function-typed field may hold null; there is nothing to inspect in that case.
          enclosed != null && invokedMethod(enclosed, targetClass, targetMethod)
        }
    }
  }

  /**
   * Depth-first search of the methods reachable from `method` in `cls`, looking for an
   * invocation of `targetMethod` on `targetClass`.
   *
   * Methods are identified by (class, name) only, so overloads sharing a name are analyzed
   * together.
   */
  private def _invokedMethod(cls: Class[_], method: String,
      targetClass: Class[_], targetMethod: String): Boolean = {

    // Methods whose bytecode has already been analyzed.
    val seen = new HashSet[(Class[_], String)]
    var stack = List[(Class[_], String)]((cls, method))
    var found = false

    while (!found && stack.nonEmpty) {
      val current = stack.head
      stack = stack.tail
      // The same method can be pushed several times before it is first popped. `add` returns
      // false for such repeats, so each method's bytecode is read at most once.
      if (seen.add(current)) {
        val finder = new MethodInvocationFinder(current._1.getName, current._2)
        getClassReader(current._1).accept(finder, 0)
        val invoked = finder.methodsInvoked
        if (invoked.exists(cm => cm._1 == targetClass && cm._2 == targetMethod)) {
          found = true
        } else {
          for (classMethod <- invoked if !seen.contains(classMethod)) {
            stack = classMethod :: stack
          }
        }
      }
    }
    found
  }

  /**
   * Get an ASM class reader for a given class from the JAR that loaded it.
   *
   * Throws a `java.io.IOException` if the class file cannot be located as a resource.
   */
  private def getClassReader(cls: Class[_]): ClassReader = {
    // Resource name relative to the class's own package: the simple binary name + ".class".
    val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
    val resourceStream = cls.getResourceAsStream(className)
    if (resourceStream == null) {
      // Previously a null stream was handed to ClassReader, which rejects it with an
      // IOException. Throw the same exception type directly, but say which class is missing.
      throw new java.io.IOException("Class not found: " + cls.getName)
    }

    // Copy data over, before delegating to ClassReader - else we can run out of open file handles.
    val baos = new ByteArrayOutputStream(128)
    // NOTE(review): the `true` flag presumably makes copyStream close both streams - confirm
    // against Utils.copyStream.
    Utils.copyStream(resourceStream, baos, true)
    new ClassReader(new ByteArrayInputStream(baos.toByteArray))
  }

  /**
   * Given the class name, return whether we should skip looking into the class. This is used to
   * skip examining a large quantity of Java or Scala classes that we know for sure wouldn't
   * access the closures, as well as array types, which have no class file to read.
   * Note that the class name is expected in ASM style (i.e. use "/" instead of ".").
   */
  private def skipClass(className: String): Boolean = {
    val c = className
    // Array types (e.g. "[I" or "[Lfoo/Bar;") can appear as the owner of a method call such as
    // clone(). They resolve via Class.forName but have no bytecode, so reading them would fail.
    c.startsWith("[") ||
      c.startsWith("java/") || c.startsWith("scala/") || c.startsWith("javax/")
  }

  /**
   * Find the set of methods invoked by the specified method in the specified class.
   * For example, after running the visitor,
   *   MethodInvocationFinder("spark/graph/Foo", "test")
   * its methodsInvoked variable will contain the set of methods invoked directly by
   * Foo.test(). Interface invocations are not returned as part of the result set because we cannot
   * determine the actual method invoked by inspecting the bytecode.
   *
   * Only `methodName` is used for matching: every method of the visited class with that name
   * is inspected. `className` is not consulted by the visitor itself.
   */
  private class MethodInvocationFinder(className: String, methodName: String)
    extends ClassVisitor(ASM4) {

    /** (owner class, method name) pairs collected while visiting the class. */
    val methodsInvoked = new HashSet[(Class[_], String)]

    override def visitMethod(access: Int, name: String, desc: String,
                             sig: String, exceptions: Array[String]): MethodVisitor = {
      if (name == methodName) {
        new MethodVisitor(ASM4) {
          override def visitMethodInsn(
              op: Int, owner: String, name: String, desc: String): Unit = {
            if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) {
              if (!skipClass(owner)) {
                // Convert the ASM internal name ("a/b/C") to a binary name ("a.b.C").
                methodsInvoked.add((Class.forName(owner.replace("/", ".")), name))
              }
            }
          }
        }
      } else {
        // Returning null tells ASM not to visit the body of this method.
        null
      }
    }
  }
}