package com.plugins.junk;

import com.android.build.gradle.BaseExtension;
import com.plugins.junk.junk.facatory.BigDecimalJunkCodeFactory;
import com.plugins.junk.junk.facatory.CalculateJunkCodeFactory;
import com.plugins.junk.junk.facatory.DateJunkCodeFactory;
import com.plugins.junk.junk.facatory.EncoderJunkCodeFactory;
import com.plugins.junk.junk.facatory.JunkCodeFactory;
import com.plugins.junk.junk.facatory.MathJunkCodeFactory;
import com.plugins.junk.junk.facatory.StringBuilderJunkCodeFactory;
import com.plugins.junk.junk.facatory.StringJunkCodeFactory;
import com.plugins.junk.junk.facatory.SwitchJunkCodeFactory;
import com.plugins.junk.junk.facatory.TimeJunkCodeFactory;
import com.plugins.junk.junk.facatory.UUIDJunkCodeFactory;
import com.plugins.junk.junk.insert.InsertCode;
import com.plugins.junk.utils.LogUtil;

import org.apache.commons.io.FileUtils;
import org.gradle.api.Project;

import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.util.Random;

import javassist.CannotCompileException;
import javassist.ClassPool;
import javassist.CtClass;
import javassist.CtMethod;
import javassist.NotFoundException;
import javassist.bytecode.ClassFile;
import javassist.bytecode.CodeAttribute;
import javassist.bytecode.LineNumberAttribute;
import javassist.bytecode.MethodInfo;

/**
 * Build-time helper that walks a directory of compiled {@code .class} files and uses
 * Javassist to insert randomly selected junk code into a subset of their methods.
 *
 * <p>All state is kept in static fields that are reset at the start of every
 * {@link #injectCode(File, Project)} call; the class performs no synchronization.
 */
public class CodeInjectUtil {
    /** Shared Javassist pool; class paths are inserted into it by {@link #injectCode}. */
    private static final ClassPool sClassPool = ClassPool.getDefault();
    /** Shared random source, also exposed through {@link #getRandom()}. */
    private static final Random mRandom = new Random();
    /** Number of methods for which the insertion reported success during the current run. */
    private static int mMethodSuccessCount = 0;
    /** Running index over all non-skipped methods visited during the current run. */
    private static int mMethodCurIndex = 0;
    /** Methods whose running index is a multiple of this value are left untouched. */
    private static int mModifyRate = 0;

    /**
     * Entry point: registers the class paths with the pool and injects junk code into the
     * class files found under {@code baseClassPath}.
     *
     * @param baseClassPath root directory containing the compiled classes
     * @param project       Gradle project used to look up the Android boot classpath
     * @throws NotFoundException      if the Android boot classpath entry cannot be added to the pool
     * @throws CannotCompileException declared for callers; kept for interface compatibility
     */
    public static void injectCode(File baseClassPath, Project project) throws NotFoundException, CannotCompileException {
        try {
            // Make the classes under the build directory resolvable through the pool.
            LogUtil.log("Class build path: " + baseClassPath.getPath());
            sClassPool.insertClassPath(baseClassPath.getPath());
        } catch (NotFoundException e) {
            // Best effort: keep going, individual classes will fail (and be logged) in inject().
            e.printStackTrace();
        }
        // Make Android framework classes resolvable (first boot classpath entry).
        BaseExtension android = project.getExtensions().getByType(BaseExtension.class);
        sClassPool.insertClassPath(android.getBootClasspath().get(0).toString());
        LogUtil.log("Android libraries: " + android.getBootClasspath());
        mMethodSuccessCount = 0;
        mMethodCurIndex = 0;
        // Rate is in [2, 5]; see injectClassMethod for how it is applied.
        mModifyRate = 2 + mRandom.nextInt(4);
        traverseFile(baseClassPath);
        LogUtil.log("Total method modified: " + mMethodSuccessCount);
    }

    /**
     * Recursively visits {@code baseClassFile}, skipping META-INF directories, and injects
     * code into every file accepted by {@link #checkClassFile(File)}.
     */
    private static void traverseFile(File baseClassFile) {
        File[] files = baseClassFile.listFiles();
        if (files == null) {
            // listFiles() returns null when the path is not a directory or cannot be read.
            LogUtil.log("Cannot list directory, skipped: " + baseClassFile.getPath());
            return;
        }
        for (File file : files) {
            if (file.isDirectory()) {
                if (file.getName().contains("META-INF")) {
                    LogUtil.log("文件夹 META-INF 跳过  :" + file.getName());
                    continue;
                }
                traverseFile(file);
            } else if (file.isFile() && checkClassFile(file)) {
                inject(file.getPath());
            }
        }
    }

    /**
     * Performs the actual injection for a single class file and writes the modified
     * bytecode back to the same path. Failures are logged and do not abort the run.
     */
    private static void inject(String classFilePath) {
        CtClass ctClass = null;
        try {
            String className;
            // Read only the class name from the file, and close the stream before the same
            // file is overwritten below.
            try (FileInputStream fileStream = new FileInputStream(classFilePath);
                 DataInputStream dataStream = new DataInputStream(new BufferedInputStream(fileStream))) {
                className = new ClassFile(dataStream).getName();
            }
            ctClass = sClassPool.get(className);
            if (ctClass.isFrozen()) {
                ctClass.defrost();
            }
            injectClassMethod(ctClass);
            // Write the bytes ourselves; the original author noted that CtClass.writeFile
            // corrupted files larger than 8K.
            byte[] classBytes = ctClass.toBytecode();
            FileUtils.writeByteArrayToFile(new File(classFilePath), classBytes);
        } catch (Exception e) {
            LogUtil.log("Inject failed: " + classFilePath + " : " + e);
            e.printStackTrace();
        } finally {
            // Always release the class from the pool, even when the injection failed.
            if (ctClass != null) {
                ctClass.detach();
            }
        }
    }

    /**
     * Applies junk code to the declared methods of {@code ctClass}. Methods rejected by
     * {@link #isSkipMethod(CtMethod)} are ignored entirely; of the remaining ones, those
     * whose running index is a multiple of {@code mModifyRate} are left unmodified.
     */
    private static void injectClassMethod(CtClass ctClass) {
        for (CtMethod ctMethod : ctClass.getDeclaredMethods()) {
            if (isSkipMethod(ctMethod)) {
                continue;
            }
            // This condition controls how many methods receive junk code.
            if (mMethodCurIndex % mModifyRate != 0 && modifyMethod(ctClass, ctMethod)) {
                mMethodSuccessCount++;
            }
            mMethodCurIndex++;
        }
    }

    /**
     * Tells whether a method must not be modified.
     *
     * @param ctMethod method to inspect
     * @return {@code true} for Kotlin coroutine {@code invokeSuspend} methods and for methods
     *         that take a {@code kotlin.coroutines.Continuation} parameter
     */
    private static boolean isSkipMethod(CtMethod ctMethod) {
        // invokeSuspend is generated for Kotlin coroutines; modifying it causes errors.
        if ("invokeSuspend".equals(ctMethod.getName())) {
            return true;
        }
        // Kotlin suspend functions receive an extra Continuation parameter; modifying them
        // causes errors as well.
        try {
            for (CtClass parameterType : ctMethod.getParameterTypes()) {
                if ("kotlin.coroutines.Continuation".equals(parameterType.getName())) {
                    return true;
                }
            }
        } catch (NotFoundException ignored) {
            // A parameter type could not be resolved from the pool; treat the method as
            // not skippable, as before.
        }
        return false;
    }

    /**
     * Picks one of the ten junk-code factories at random and lets its {@link InsertCode}
     * modify the method.
     *
     * @return the result reported by {@link InsertCode#insert}
     */
    private static boolean modifyMethod(CtClass ctClass, CtMethod ctMethod) {
        JunkCodeFactory junkCodeFactory;
        switch (mRandom.nextInt(10)) {
            case 0:
                junkCodeFactory = new CalculateJunkCodeFactory();
                break;
            case 1:
                junkCodeFactory = new StringJunkCodeFactory();
                break;
            case 2:
                junkCodeFactory = new UUIDJunkCodeFactory();
                break;
            case 3:
                junkCodeFactory = new SwitchJunkCodeFactory();
                break;
            case 4:
                junkCodeFactory = new MathJunkCodeFactory();
                break;
            case 5:
                junkCodeFactory = new EncoderJunkCodeFactory();
                break;
            case 6:
                junkCodeFactory = new DateJunkCodeFactory();
                break;
            case 7:
                junkCodeFactory = new BigDecimalJunkCodeFactory();
                break;
            case 8:
                junkCodeFactory = new StringBuilderJunkCodeFactory();
                break;
            default:
                junkCodeFactory = new TimeJunkCodeFactory();
                break;
        }
        InsertCode insertCode = junkCodeFactory.getInsertCode();
        return insertCode.insert(sClassPool, ctClass, ctMethod);
    }

    /**
     * @return the shared {@link Random} instance used by this class
     */
    public static Random getRandom() {
        return mRandom;
    }

    /**
     * Returns a random line number for the given method.
     *
     * @param ctMethod method to inspect
     * @return the line number reported for bytecode position 0, plus a random offset in
     *         {@code [0, tableLength - 1)} when the line number table has more than one
     *         entry; otherwise just the line number for position 0
     */
    public static int getRandomMethodLine(CtMethod ctMethod) {
        MethodInfo methodInfo = ctMethod.getMethodInfo();
        int firstLine = methodInfo.getLineNumber(0);
        CodeAttribute codeAttribute = methodInfo.getCodeAttribute();
        if (codeAttribute == null) {
            // No code attribute to read a line number table from.
            return firstLine;
        }
        try {
            LineNumberAttribute lineNumberAttribute =
                    (LineNumberAttribute) codeAttribute.getAttribute(LineNumberAttribute.tag);
            if (lineNumberAttribute != null && lineNumberAttribute.tableLength() > 1) {
                return firstLine + mRandom.nextInt(lineNumberAttribute.tableLength() - 1);
            }
        } catch (RuntimeException ignored) {
            // Best effort: fall back to the first line if the attribute cannot be read.
        }
        return firstLine;
    }

    /**
     * Builds a Javassist method body that calls {@code newName} with all original arguments
     * ({@code $$}), logs the elapsed time through {@code android.util.Log.e}, and returns
     * the call's result for non-void methods.
     *
     * @param ctClass  class owning the method, used in the log message
     * @param ctMethod method being wrapped, used for the return type and the log message
     * @param newName  name of the method that the generated body delegates to
     * @return Javassist source text of the generated body
     * @throws NotFoundException if the return type of {@code ctMethod} cannot be resolved
     */
    private static String generateBody(CtClass ctClass, CtMethod ctMethod, String newName) throws NotFoundException {
        String returnType = ctMethod.getReturnType().getName();
        boolean returnsVoid = "void".equals(returnType);
        // Delegate to the renamed method; $$ stands for all arguments of the wrapper.
        String methodResult = newName + "($$);";
        if (!returnsVoid) {
            methodResult = returnType + " result = " + methodResult;
        }
        return "{long costStartTime = System.currentTimeMillis();" +
                methodResult +
                "android.util.Log.e(\"METHOD_COST\", \"" + ctClass.getName() + "." + ctMethod.getName()
                + "() 耗时：\" + (System.currentTimeMillis() - costStartTime) + \"ms\");" +
                // void methods have nothing to return.
                (returnsVoid ? "}" : "return result;}");
    }

    /**
     * Filters out files that must not be modified.
     *
     * @param file candidate file
     * @return {@code true} only for {@code .class} files that are not under a META-INF name,
     *         not a directory, and not a generated {@code R}, {@code R$*} or
     *         {@code BuildConfig} class
     */
    private static boolean checkClassFile(File file) {
        String fileName = file.getName();
        if (fileName.contains("META-INF")) {
            LogUtil.log("META-INF跳过   :" + fileName);
            return false;
        }
        if (file.isDirectory()) {
            LogUtil.log("文件夹跳过   :" + fileName);
            return false;
        }
        if (!fileName.endsWith(".class")) {
            // Anything else cannot be parsed as a class file.
            return false;
        }
        // Match on the file name, not the whole path, so unrelated classes whose path merely
        // contains "R$", "R.class" or "BuildConfig.class" are not skipped by accident.
        return !fileName.startsWith("R$")
                && !fileName.equals("R.class")
                && !fileName.equals("BuildConfig.class");
    }

}