package jadx.core.clsp;

import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import jadx.core.dex.info.AccessInfo;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.GenericInfo;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.utils.exceptions.DecodeException;
import jadx.core.utils.exceptions.JadxRuntimeException;
import jadx.core.utils.files.FileUtils;
import jadx.core.utils.files.ZipSecurity;

/**
 * Classes list for import into the classpath graph.
 * <p>
 * The set can be filled from the {@code core.jcst} resource bundled next to this class
 * ({@link #loadFromClstFile()}) or from the classes of a loaded {@link RootNode} ({@link #loadFrom(RootNode)}),
 * and can be written back with {@link #save(OutputStream)}.
 * <p>
 * Binary layout, written with {@link DataOutputStream} (big-endian; {@code u1}/{@code u2} are unsigned
 * one/two byte values, strings are ASCII bytes prefixed by their length):
 * <pre>
 * header     : "jadx-cst", u1 version
 * names      : int classes count, then for every class: u1-length string (raw class name)
 * per class  : u1 parents count, int parent id (index in names) * count,
 *              u1 generics count, per generic: type, u1 extends count, type * count,
 *              u1 methods count, per method:
 *                  u2-length string (method short id),
 *                  u1 generic args count, (u1 arg index, type) * count in descending index order,
 *                  boolean has return type [, type], boolean varargs
 * type       : u1 tag ({@code TypeEnum} ordinal) followed by tag specific data
 * </pre>
 */
public class ClsSet {
	private static final Logger LOG = LoggerFactory.getLogger(ClsSet.class);

	private static final String CLST_EXTENSION = ".jcst";
	private static final String CLST_FILENAME = "core" + CLST_EXTENSION;
	private static final String CLST_PKG_PATH = ClsSet.class.getPackage().getName().replace('.', '/');

	private static final String JADX_CLS_SET_HEADER = "jadx-cst";
	private static final int VERSION = 2;

	/**
	 * Charset of every string stored in the file.
	 * Characters that US-ASCII cannot represent are replaced by '?' when encoding.
	 */
	private static final Charset STRING_CHARSET = StandardCharsets.US_ASCII;

	/** Largest value that fits into an unsigned one-byte field. */
	private static final int MAX_U1 = 0xFF;
	/** Largest value that fits into an unsigned two-byte field. */
	private static final int MAX_U2 = 0xFFFF;

	private static final NClass[] EMPTY_NCLASS_ARRAY = new NClass[0];

	/**
	 * Tags of serialized {@link ArgType} kinds.
	 * The ordinal of each constant is written into the file, so the declaration order is part of
	 * the file format and must not change.
	 */
	private enum TypeEnum {
		WILDCARD, GENERIC, GENERIC_TYPE, OBJECT, ARRAY, PRIMITIVE
	}

	/** Cached {@link TypeEnum#values()} to avoid allocating a new array for every decoded type. */
	private static final TypeEnum[] TYPE_ENUM_VALUES = TypeEnum.values();

	/** Classes of this set; the array index of a class equals its id. */
	private NClass[] classes;

	/**
	 * Loads the class set bundled as the {@code core.jcst} resource, resolved relative to this class.
	 *
	 * @throws JadxRuntimeException if the resource is not found
	 * @throws DecodeException      if the resource has a wrong header or version
	 * @throws IOException          on read errors
	 */
	public void loadFromClstFile() throws IOException, DecodeException {
		try (InputStream input = getClass().getResourceAsStream(CLST_FILENAME)) {
			if (input == null) {
				throw new JadxRuntimeException("Can't load classpath file: " + CLST_FILENAME);
			}
			load(input);
		}
	}

	/**
	 * Fills this set from the classes of {@code root}.
	 * <p>
	 * Only public classes are stored (each one is loaded first). Non-public classes are only remembered
	 * by name, and parent references to them (or to unknown classes) are dropped from parent lists.
	 *
	 * @throws JadxRuntimeException on duplicate public class names
	 */
	public void loadFrom(RootNode root) {
		List<ClassNode> list = root.getClasses(true);
		Map<String, NClass> names = new HashMap<>(list.size());
		int k = 0;
		for (ClassNode cls : list) {
			String clsRawName = cls.getRawName();
			if (cls.getAccessFlags().isPublic()) {
				cls.load();
				NClass nClass = new NClass(clsRawName, k);
				if (names.put(clsRawName, nClass) != null) {
					throw new JadxRuntimeException("Duplicate class: " + clsRawName);
				}
				k++;
				nClass.setGenerics(cls.getGenerics());
				nClass.setMethods(getMethodsDetails(cls));
			} else {
				names.put(clsRawName, null);
			}
		}
		// second pass: all NClass objects exist now, so parent links can be resolved
		classes = new NClass[k];
		k = 0;
		for (ClassNode cls : list) {
			if (cls.getAccessFlags().isPublic()) {
				NClass nClass = getCls(cls.getRawName(), names);
				if (nClass == null) {
					throw new JadxRuntimeException("Missing class: " + cls);
				}
				nClass.setParents(makeParentsArray(cls, names));
				classes[k] = nClass;
				k++;
			}
		}
	}

	/**
	 * Collects details of public and protected methods of {@code cls}
	 * that have generic arguments, a generic return type, or varargs.
	 */
	private List<NMethod> getMethodsDetails(ClassNode cls) {
		List<NMethod> methods = new ArrayList<>();
		for (MethodNode m : cls.getMethods()) {
			AccessInfo accessFlags = m.getAccessFlags();
			if (accessFlags.isPublic() || accessFlags.isProtected()) {
				processMethodDetails(methods, m, accessFlags);
			}
		}
		return methods;
	}

	/**
	 * Adds an {@link NMethod} for {@code mth} to {@code methods} if the method has at least one generic
	 * argument, a generic return type, or is varargs.
	 * In the stored args array only generic arguments are set, other positions stay {@code null};
	 * the stored return type is {@code null} when it is not generic.
	 */
	private void processMethodDetails(List<NMethod> methods, MethodNode mth, AccessInfo accessFlags) {
		List<ArgType> args = mth.getArgTypes();
		boolean genericArg = false;
		ArgType[] genericArgs;
		if (args.isEmpty()) {
			genericArgs = null;
		} else {
			int argsCount = args.size();
			genericArgs = new ArgType[argsCount];
			for (int i = 0; i < argsCount; i++) {
				ArgType argType = args.get(i);
				if (argType.isGeneric() || argType.isGenericType()) {
					genericArgs[i] = argType;
					genericArg = true;
				}
			}
		}
		ArgType retType = mth.getReturnType();
		if (!retType.isGeneric() && !retType.isGenericType()) {
			retType = null;
		}
		boolean varArgs = accessFlags.isVarArgs();
		if (genericArg || retType != null || varArgs) {
			methods.add(new NMethod(mth.getMethodInfo().getShortId(), genericArgs, retType, varArgs));
		}
	}

	/**
	 * Resolves the super class and interfaces of {@code cls} through {@code names}.
	 * Parents that do not resolve to an {@link NClass} are skipped.
	 *
	 * @return resolved parents (super class first), or a shared empty array if none resolved
	 */
	public static NClass[] makeParentsArray(ClassNode cls, Map<String, NClass> names) {
		List<NClass> parents = new ArrayList<>(1 + cls.getInterfaces().size());
		ArgType superClass = cls.getSuperClass();
		if (superClass != null) {
			NClass c = getCls(superClass.getObject(), names);
			if (c != null) {
				parents.add(c);
			}
		}
		for (ArgType iface : cls.getInterfaces()) {
			NClass c = getCls(iface.getObject(), names);
			if (c != null) {
				parents.add(c);
			}
		}
		int size = parents.size();
		if (size == 0) {
			return EMPTY_NCLASS_ARRAY;
		}
		return parents.toArray(new NClass[size]);
	}

	/** Looks up {@code fullName} in {@code names}, logging at debug level when it is absent or mapped to null. */
	private static NClass getCls(String fullName, Map<String, NClass> names) {
		NClass cls = names.get(fullName);
		if (cls == null) {
			LOG.debug("Class not found: {}", fullName);
		}
		return cls;
	}

	/**
	 * Saves this set to {@code path}.
	 * <ul>
	 * <li>{@code *.jcst}: the file is (over)written with the set;</li>
	 * <li>{@code *.jar}: the existing jar is rewritten with the set stored under
	 * this class's package path, replacing any previous entry with that name.</li>
	 * </ul>
	 *
	 * @throws JadxRuntimeException for any other file extension
	 */
	void save(Path path) throws IOException {
		FileUtils.makeDirsForFile(path);
		String outputName = path.getFileName().toString();
		if (outputName.endsWith(CLST_EXTENSION)) {
			try (BufferedOutputStream outputStream = new BufferedOutputStream(Files.newOutputStream(path))) {
				save(outputStream);
			}
		} else if (outputName.endsWith(".jar")) {
			saveToJar(path);
		} else {
			throw new JadxRuntimeException("Unknown file format: " + outputName);
		}
	}

	/** Rewrites the jar at {@code jarPath}: the set entry goes first, then every other original entry is copied. */
	private void saveToJar(Path jarPath) throws IOException {
		Path temp = FileUtils.createTempFile(".zip");
		try {
			// copy the original jar aside, since jarPath is truncated by the output stream below
			Files.copy(jarPath, temp, StandardCopyOption.REPLACE_EXISTING);
			try (ZipOutputStream out = new ZipOutputStream(Files.newOutputStream(jarPath));
					ZipInputStream in = new ZipInputStream(Files.newInputStream(temp))) {
				String clst = CLST_PKG_PATH + '/' + CLST_FILENAME;
				out.putNextEntry(new ZipEntry(clst));
				save(out);
				ZipEntry entry = in.getNextEntry();
				while (entry != null) {
					if (!entry.getName().equals(clst)) {
						out.putNextEntry(new ZipEntry(entry.getName()));
						FileUtils.copyStream(in, out);
					}
					entry = in.getNextEntry();
				}
			}
		} finally {
			try {
				Files.deleteIfExists(temp);
			} catch (IOException e) {
				LOG.warn("Failed to delete temp file: {}", temp, e);
			}
		}
	}

	/**
	 * Writes this set to {@code output} in the format described in the class documentation.
	 * The stream is flushed but not closed; the caller owns it.
	 *
	 * @throws JadxRuntimeException if a count, index, or string length does not fit its field,
	 *                              or a type references a class that is not part of this set
	 */
	public void save(OutputStream output) throws IOException {
		DataOutputStream out = new DataOutputStream(output);
		out.writeBytes(JADX_CLS_SET_HEADER);
		out.writeByte(VERSION);

		LOG.info("Classes count: {}", classes.length);
		Map<String, NClass> names = new HashMap<>(classes.length);
		out.writeInt(classes.length);
		for (NClass cls : classes) {
			writeString(out, cls.getName());
			names.put(cls.getName(), cls);
		}
		for (NClass cls : classes) {
			NClass[] parents = cls.getParents();
			writeU1(out, parents.length, "parents count");
			for (NClass parent : parents) {
				out.writeInt(parent.getId());
			}
			writeGenerics(out, cls, names);
			List<NMethod> methods = cls.getMethodsList();
			writeU1(out, methods.size(), "methods count");
			for (NMethod method : methods) {
				writeMethod(out, method, names);
			}
		}
		out.flush();
	}

	/** Writes the generic parameters of {@code cls}: count, then each type with its extends list. */
	private static void writeGenerics(DataOutputStream out, NClass cls, Map<String, NClass> names) throws IOException {
		List<GenericInfo> genericsList = cls.getGenerics();
		writeU1(out, genericsList.size(), "generics count");
		for (GenericInfo genericInfo : genericsList) {
			writeArgType(out, genericInfo.getGenericType(), names);
			List<ArgType> extendsList = genericInfo.getExtendsList();
			writeU1(out, extendsList.size(), "generic extends count");
			for (ArgType type : extendsList) {
				writeArgType(out, type, names);
			}
		}
	}

	/** Writes one method record: short id, generic args with their indices, optional return type, varargs flag. */
	private static void writeMethod(DataOutputStream out, NMethod method, Map<String, NClass> names) throws IOException {
		writeLongString(out, method.getShortId());

		ArgType[] argTypes = method.getGenericArgs();
		if (argTypes == null) {
			out.writeByte(0);
		} else {
			int argCount = 0;
			for (ArgType arg : argTypes) {
				if (arg != null) {
					argCount++;
				}
			}
			writeU1(out, argCount, "generic args count");
			// last argument first: the reader sizes its array from the first (highest) index it reads
			for (int i = argTypes.length - 1; i >= 0; i--) {
				ArgType argType = argTypes[i];
				if (argType != null) {
					writeU1(out, i, "generic arg index");
					writeArgType(out, argType, names);
				}
			}
		}
		ArgType retType = method.getReturnType();
		out.writeBoolean(retType != null);
		if (retType != null) {
			writeArgType(out, retType, names);
		}
		out.writeBoolean(method.isVarArgs());
	}

	/**
	 * Writes {@code argType} as a {@link TypeEnum} tag followed by tag-specific data.
	 * Object and generic types are stored as the id of the referenced class in {@code names}.
	 *
	 * @throws JadxRuntimeException for unsupported types or classes missing from {@code names}
	 */
	private static void writeArgType(DataOutputStream out, ArgType argType, Map<String, NClass> names) throws IOException {
		if (argType.getWildcardType() != null) {
			out.writeByte(TypeEnum.WILDCARD.ordinal());
			ArgType.WildcardBound bound = argType.getWildcardBound();
			// stored as a raw byte and read back with signed readByte(), so not range-checked as unsigned
			out.writeByte(bound.getNum());
			if (bound != ArgType.WildcardBound.UNBOUND) {
				writeArgType(out, argType.getWildcardType(), names);
			}
		} else if (argType.isGeneric()) {
			out.writeByte(TypeEnum.GENERIC.ordinal());
			out.writeInt(getClsId(argType, names));
			ArgType[] types = argType.getGenericTypes();
			if (types == null) {
				out.writeByte(0);
			} else {
				writeU1(out, types.length, "generic types count");
				for (ArgType type : types) {
					writeArgType(out, type, names);
				}
			}
		} else if (argType.isGenericType()) {
			out.writeByte(TypeEnum.GENERIC_TYPE.ordinal());
			writeString(out, argType.getObject());
		} else if (argType.isObject()) {
			out.writeByte(TypeEnum.OBJECT.ordinal());
			out.writeInt(getClsId(argType, names));
		} else if (argType.isArray()) {
			out.writeByte(TypeEnum.ARRAY.ordinal());
			writeArgType(out, argType.getArrayElement(), names);
		} else if (argType.isPrimitive()) {
			out.writeByte(TypeEnum.PRIMITIVE.ordinal());
			out.writeByte(argType.getPrimitiveType().getShortName().charAt(0));
		} else {
			throw new JadxRuntimeException("Cannot save type: " + argType);
		}
	}

	/** Returns the id of the class referenced by {@code argType}, failing with context instead of an NPE. */
	private static int getClsId(ArgType argType, Map<String, NClass> names) {
		NClass cls = names.get(argType.getObject());
		if (cls == null) {
			throw new JadxRuntimeException("Cannot save type: " + argType + ", class not found in set");
		}
		return cls.getId();
	}

	/**
	 * Writes {@code value} as an unsigned byte.
	 *
	 * @throws JadxRuntimeException if the value does not fit, instead of silently truncating it
	 */
	private static void writeU1(DataOutputStream out, int value, String what) throws IOException {
		if (value < 0 || value > MAX_U1) {
			throw new JadxRuntimeException("Value of '" + what + "' doesn't fit into unsigned byte: " + value);
		}
		out.writeByte(value);
	}

	/**
	 * Loads a set from a {@code .jcst} file, or from every {@code .jcst} entry of a {@code .jar} file
	 * (the last loaded entry wins).
	 */
	private void load(File input) throws IOException, DecodeException {
		String name = input.getName();
		try (InputStream inputStream = new FileInputStream(input)) {
			if (name.endsWith(CLST_EXTENSION)) {
				load(inputStream);
			} else if (name.endsWith(".jar")) {
				try (ZipInputStream in = new ZipInputStream(inputStream)) {
					ZipEntry entry = in.getNextEntry();
					while (entry != null) {
						if (entry.getName().endsWith(CLST_EXTENSION) && ZipSecurity.isValidZipEntry(entry)) {
							load(in);
						}
						entry = in.getNextEntry();
					}
				}
			} else {
				throw new JadxRuntimeException("Unknown file format: " + name);
			}
		}
	}

	/**
	 * Reads a set from {@code input}, replacing the current classes.
	 * The stream is NOT closed: callers own it, and a {@link ZipInputStream} must stay open
	 * to continue with its next entry.
	 *
	 * @throws DecodeException on wrong header, version, or classes count
	 */
	private void load(InputStream input) throws IOException, DecodeException {
		DataInputStream in = new DataInputStream(input);
		byte[] header = new byte[JADX_CLS_SET_HEADER.length()];
		in.readFully(header);
		int version = in.readUnsignedByte();
		if (!JADX_CLS_SET_HEADER.equals(new String(header, STRING_CHARSET)) || version != VERSION) {
			throw new DecodeException("Wrong jadx class set header");
		}
		int count = in.readInt();
		if (count < 0) {
			throw new DecodeException("Wrong classes count in jadx class set: " + count);
		}
		classes = new NClass[count];
		for (int i = 0; i < count; i++) {
			String name = readString(in);
			classes[i] = new NClass(name, i);
		}
		// all classes are known now, so ids in parents and types can be resolved
		for (int i = 0; i < count; i++) {
			int pCount = in.readUnsignedByte();
			NClass[] parents = new NClass[pCount];
			for (int j = 0; j < pCount; j++) {
				parents[j] = getClsById(in.readInt());
			}
			NClass nClass = classes[i];
			nClass.setParents(parents);
			nClass.setGenerics(readGenerics(in));
			nClass.setMethods(readClsMethods(in));
		}
	}

	/** Returns the class with the given id, validating ids read from the file. */
	private NClass getClsById(int id) {
		if (id < 0 || id >= classes.length) {
			throw new JadxRuntimeException("Wrong class id in jadx class set: " + id);
		}
		return classes[id];
	}

	/** Reads a generics list written by {@code writeGenerics}. */
	private List<GenericInfo> readGenerics(DataInputStream in) throws IOException {
		int count = in.readUnsignedByte();
		if (count == 0) {
			return Collections.emptyList();
		}
		List<GenericInfo> list = new ArrayList<>(count);
		for (int i = 0; i < count; i++) {
			ArgType genericType = readArgType(in);
			List<ArgType> extendsList;
			int extCount = in.readUnsignedByte();
			if (extCount == 0) {
				extendsList = Collections.emptyList();
			} else {
				extendsList = new ArrayList<>(extCount);
				for (int j = 0; j < extCount; j++) {
					extendsList.add(readArgType(in));
				}
			}
			list.add(new GenericInfo(genericType, extendsList));
		}
		return list;
	}

	/** Reads the methods list of one class. */
	private List<NMethod> readClsMethods(DataInputStream in) throws IOException {
		int mCount = in.readUnsignedByte();
		List<NMethod> methods = new ArrayList<>(mCount);
		for (int j = 0; j < mCount; j++) {
			methods.add(readMethod(in));
		}
		return methods;
	}

	/**
	 * Reads one method record written by {@code writeMethod}.
	 * Generic args arrive in descending index order, so the first index sizes the args array.
	 */
	private NMethod readMethod(DataInputStream in) throws IOException {
		String shortId = readLongString(in);
		int argCount = in.readUnsignedByte();
		ArgType[] argTypes = null;
		for (int i = 0; i < argCount; i++) {
			int index = in.readUnsignedByte();
			ArgType argType = readArgType(in);
			if (argTypes == null) {
				argTypes = new ArgType[index + 1];
			}
			argTypes[index] = argType;
		}
		ArgType retType = in.readBoolean() ? readArgType(in) : null;
		boolean varArgs = in.readBoolean();
		return new NMethod(shortId, argTypes, retType, varArgs);
	}

	/**
	 * Reads one type written by {@code writeArgType}.
	 *
	 * @throws JadxRuntimeException on unknown type tag or invalid class id
	 */
	private ArgType readArgType(DataInputStream in) throws IOException {
		int ordinal = in.readUnsignedByte();
		if (ordinal >= TYPE_ENUM_VALUES.length) {
			throw new JadxRuntimeException("Unsupported Arg Type: " + ordinal);
		}
		switch (TYPE_ENUM_VALUES[ordinal]) {
			case WILDCARD: {
				// signed read: mirrors writeByte(bound.getNum()) in the writer
				int bounds = in.readByte();
				return bounds == 0
						? ArgType.wildcard()
						: ArgType.wildcard(readArgType(in), ArgType.WildcardBound.getByNum(bounds));
			}
			case GENERIC: {
				String obj = getClsById(in.readInt()).getName();
				int typeLength = in.readUnsignedByte();
				ArgType[] generics;
				if (typeLength == 0) {
					generics = null;
				} else {
					generics = new ArgType[typeLength];
					for (int i = 0; i < typeLength; i++) {
						generics[i] = readArgType(in);
					}
				}
				return ArgType.generic(obj, generics);
			}
			case GENERIC_TYPE:
				return ArgType.genericType(readString(in));

			case OBJECT:
				return ArgType.object(getClsById(in.readInt()).getName());

			case ARRAY:
				return ArgType.array(readArgType(in));

			case PRIMITIVE: {
				char shortName = (char) in.readByte();
				return ArgType.parse(shortName);
			}
			default:
				throw new JadxRuntimeException("Unsupported Arg Type: " + ordinal);
		}
	}

	/** Writes a string with an unsigned one-byte length prefix. */
	private static void writeString(DataOutputStream out, String name) throws IOException {
		byte[] bytes = name.getBytes(STRING_CHARSET);
		writeU1(out, bytes.length, "string length");
		out.write(bytes);
	}

	/** Writes a string with an unsigned two-byte length prefix. */
	private static void writeLongString(DataOutputStream out, String name) throws IOException {
		byte[] bytes = name.getBytes(STRING_CHARSET);
		if (bytes.length > MAX_U2) {
			throw new JadxRuntimeException("String too long for jadx class set (" + bytes.length + " bytes): " + name);
		}
		out.writeShort(bytes.length);
		out.write(bytes);
	}

	/** Reads a string written by {@code writeString}. */
	private static String readString(DataInputStream in) throws IOException {
		int len = in.readUnsignedByte();
		return readString(in, len);
	}

	/** Reads a string written by {@code writeLongString}. */
	private static String readLongString(DataInputStream in) throws IOException {
		int len = in.readUnsignedShort();
		return readString(in, len);
	}

	/**
	 * Reads exactly {@code len} bytes and decodes them.
	 *
	 * @throws java.io.EOFException if the stream ends first
	 */
	private static String readString(DataInputStream in, int len) throws IOException {
		byte[] bytes = new byte[len];
		in.readFully(bytes);
		return new String(bytes, STRING_CHARSET);
	}

	/** Returns the number of classes in this set. */
	public int getClassesCount() {
		return classes.length;
	}

	/** Puts every class of this set into {@code nameMap}, keyed by its raw name. */
	public void addToMap(Map<String, NClass> nameMap) {
		for (NClass cls : classes) {
			nameMap.put(cls.getName(), cls);
		}
	}
}