/*
 * Copyright (c) 2018, 2021, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 3 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 3 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 3 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
package com.oracle.truffle.r.ffi.codegen;

import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import com.oracle.truffle.r.ffi.impl.upcalls.StdUpCallsRFFI;
import com.oracle.truffle.r.ffi.processor.RFFICpointer;
import com.oracle.truffle.r.ffi.processor.RFFICstring;
import com.oracle.truffle.r.ffi.processor.RFFIInject;

/**
 * Generates the testrffi glue code that exposes RFFI functions to R. Depending on the command line
 * option it emits:
 * <ul>
 * <li>(no option) C wrappers that call each RFFI function and convert between SEXP and primitive
 * types,</li>
 * <li>{@code -h} a header with the declarations of those C wrappers,</li>
 * <li>{@code -init} a sequence of {@code CALLDEF} entries that register the wrappers as ".Call"
 * targets,</li>
 * <li>{@code -r} R functions that invoke the wrappers via {@code .Call}.</li>
 * </ul>
 *
 * This creates an R interface to all applicable RFFI functions. RFFI functions working with raw C
 * pointers are excluded.
 *
 * The generated code is to be used in the testrffi package located in
 * "com.oracle.truffle.r.test.native/packages/testrffi/testrffi".
 */
public final class FFITestsCodeGen extends CodeGenBase {
    private static final String FUN_PREFIX = "api_";

    /**
     * Java method names of {@link StdUpCallsRFFI} functions for which no wrapper is generated, even
     * if their signatures would otherwise qualify.
     */
    private static final HashSet<String> IGNORE_FUNS = new HashSet<>(Arrays.asList(
                    "Rf_cospi", "Rf_sinpi", "Rf_tanpi", "R_forceAndCall", "Rf_duplicate", "R_ToplevelExec", "R_CleanUp",
                    "R_ParseVector", "octsize", "R_NewHashedEnv", "Rf_ScalarComplex", "Rf_ScalarRaw", "Rf_allocList",
                    "DispatchPRIMFUN", "COMPLEX_ELT",
                    "match5" // match5 is not public
    ));

    public static void main(String[] args) {
        new FFITestsCodeGen().run(args);
    }

    /**
     * Initializes the output from {@code args} and writes the artifact selected by the options. If
     * several options are given, {@code -init} takes precedence over {@code -h}, which takes
     * precedence over {@code -r}. Without any of them the C wrapper definitions are generated.
     */
    private void run(String[] args) {
        this.initOutput(args);
        List<String> options = Arrays.asList(args);
        if (options.contains("-init")) {
            generateCInit();
        } else if (options.contains("-h")) {
            generateH();
        } else if (options.contains("-r")) {
            generateR();
        } else {
            generateC();
        }
    }

    /** Writes one R function {@code api.NAME} per wrapper, forwarding all arguments to .Call. */
    private void generateR() {
        out.print(COPYRIGHT_HASH);
        out.println();
        out.println("#############");
        out.printf("# Code generated by %s class run with option '-r'\n", FFITestsCodeGen.class.getName());
        printMxHelp("#");
        out.println("# R wrappers for all the generated RFFI C wrappers\n");
        getFFIMethods().forEach(method -> {
            out.printf("api.%s <- function(...) .Call(C_api_%s, ...)\n", getName(method), getName(method));
        });
    }

    /**
     * Writes one {@code CALLDEF(name, arity)} entry per wrapper. The arity counts only the
     * parameters that are not {@link RFFIInject injected}, i.e. the ones visible from R.
     */
    private void generateCInit() {
        out.println(COPYRIGHT);
        out.printf("// Code generated by %s class run with option '-init'\n", FFITestsCodeGen.class.getName());
        printMxHelp("//");
        out.println("// The following code registers all C functions that wrap RFFI functions and convert SEXP <-> primitive types.");
        out.println("// The definitions of the C functions could be generated by the same Java class (but run without any option)");
        out.println("// RFFI functions that take/return C pointers are ignored");
        out.println("// This code is '#included' into init.c ");
        getFFIMethods().forEach(method -> {
            out.printf("CALLDEF(%s%s, %d),\n", FUN_PREFIX, getName(method), getNonInjectedParameterCount(method));
        });
        out.println("// ---- end of generated code");
    }

    /** Writes the C declarations of all wrappers. */
    private void generateH() {
        out.println(COPYRIGHT);
        out.printf("// Code generated by %s class run with option '-h'\n", FFITestsCodeGen.class.getName());
        printMxHelp("//");
        out.println("// See the corresponding C file for more details");
        printIncludes();
        getFFIMethods().forEach(method -> {
            out.println(getDeclaration(method) + ";\n");
        });
    }

    /**
     * Writes the C definitions of all wrappers. Each wrapper converts its SEXP arguments to the C
     * values the RFFI function expects, calls it, and converts the result back to a SEXP (or
     * returns {@code R_NilValue} for void functions).
     */
    private void generateC() {
        out.println(COPYRIGHT);
        out.printf("// Code generated by %s class\n", FFITestsCodeGen.class.getName());
        printMxHelp("//");
        out.println("// The following code defines a 'SEXP' variant of every RFFI function implemented in FastR");
        out.println("// Run the same Java class with '-init' option to get sequence of CALLDEF statements that register those functions for use from R");
        out.println("// RFFI functions that take/return C pointers are ignored");
        printIncludes();
        out.println("#include \"rffiwrappers.h\"\n");
        // Both 'ignored' pragmas belong to a single push, so exactly one pop restores the state.
        out.println("#pragma GCC diagnostic push");
        out.println("#pragma GCC diagnostic ignored \"-Wint-conversion\"\n");
        out.println("#pragma GCC diagnostic ignored \"-Wincompatible-pointer-types\"\n");
        getFFIMethods().forEach(method -> {
            out.println(getDeclaration(method) + " {");
            String stmt = String.format("%s(%s)", getName(method),
                            Arrays.stream(method.getParameters()).filter(FFITestsCodeGen::isNotInjected).map(FFITestsCodeGen::toCValue).collect(Collectors.joining(", ")));
            out.println("    " + getReturnStmt(method.getReturnType(), stmt) + ';');
            if (method.getReturnType() == void.class) {
                out.println("    return R_NilValue;");
            }
            out.println("}\n");
        });
        out.println("#pragma GCC diagnostic pop");
    }

    /**
     * Returns the number of parameters of {@code m} that are exposed to R, i.e. those without the
     * {@link RFFIInject} annotation. Uses the same filter as the generated declarations and calls.
     */
    private static int getNonInjectedParameterCount(Method m) {
        return (int) Arrays.stream(m.getParameters()).filter(FFITestsCodeGen::isNotInjected).count();
    }

    /**
     * Returns the C signature {@code SEXP api_NAME(SEXP p1, ...)} of the wrapper for
     * {@code method}. Injected parameters are omitted.
     */
    private static String getDeclaration(Method method) {
        return String.format("SEXP %s%s(", FUN_PREFIX, getName(method)) +
                        Arrays.stream(method.getParameters()).filter(FFITestsCodeGen::isNotInjected).map(p -> "SEXP " + p.getName()).collect(Collectors.joining(", ")) + ')';
    }

    /** Returns the C-level name of the RFFI function, with "_FASTR"/"FASTR_" markers removed. */
    private static String getName(Method m) {
        return m.getName().replace("_FASTR", "").replace("FASTR_", "");
    }

    /** Writes the R headers needed by the generated C code and header. */
    private void printIncludes() {
        out.print("#define NO_FASTR_REDEFINE\n" +
                        "#include <R.h>\n" +
                        "#include <Rdefines.h>\n" +
                        "#include <Rinterface.h>\n" +
                        "#include <Rinternals.h>\n" +
                        "#include <R_ext/Parse.h>\n" +
                        "#include <R_ext/Connections.h>\n" +
                        "#include <Rmath.h>\n\n");
    }

    /**
     * Returns the RFFI up-call methods for which wrappers are generated. {@link Class#getMethods()}
     * returns methods in no particular order. Sorting them (by name, then by full signature to
     * break ties between overloads) makes the generated files reproducible across runs.
     */
    private static Stream<Method> getFFIMethods() {
        return Arrays.stream(StdUpCallsRFFI.class.getMethods()).filter(m -> !ignoreMethod(m)).sorted(Comparator.comparing(Method::getName).thenComparing(Method::toString));
    }

    /**
     * Ignores methods that are explicitly listed in {@link #IGNORE_FUNS} or that take or return C
     * pointers. Only SEXP, strings and primitives are supported.
     */
    private static boolean ignoreMethod(Method method) {
        return IGNORE_FUNS.contains(method.getName()) || method.getAnnotation(RFFICpointer.class) != null ||
                        Arrays.stream(method.getParameterAnnotations()).anyMatch(FFITestsCodeGen::anyCPointer);
    }

    /** Returns true if {@code param} is passed from R, i.e. it is not {@link RFFIInject injected}. */
    private static boolean isNotInjected(Parameter param) {
        return param.getAnnotation(RFFIInject.class) == null;
    }

    /**
     * Returns the C expression that converts the SEXP argument named after {@code param} to the
     * value expected by the RFFI function. Strings take the first element of a character vector,
     * and int/long/boolean use {@code INTEGER_VALUE}. Doubles use {@code NUMERIC_VALUE}. Any
     * other type is passed through as a SEXP.
     */
    private static String toCValue(Parameter param) {
        if (param.getAnnotation(RFFICstring.class) != null || param.getType() == String.class) {
            return "R_CHAR(STRING_ELT(" + param.getName() + ", 0))";
        } else if (param.getType() == int.class || param.getType() == long.class || param.getType() == boolean.class) {
            return "INTEGER_VALUE(" + param.getName() + ")";
        } else if (param.getType() == double.class) {
            return "NUMERIC_VALUE(" + param.getName() + ")";
        } else {
            return param.getName();
        }
    }

    /**
     * Returns the wrapper's call statement. For void functions this is {@code value} as is.
     * Otherwise it is a {@code return} of the result converted to a SEXP.
     */
    private static String getReturnStmt(Class<?> returnType, String value) {
        return returnType == void.class ? value : ("return " + fromCValueToSEXP(returnType, value));
    }

    /**
     * Returns the C expression that wraps the C value {@code value} of Java type {@code fromType}
     * into a SEXP.
     *
     * @throws RuntimeException if {@code fromType} has no supported conversion
     */
    private static String fromCValueToSEXP(Class<?> fromType, String value) {
        if (fromType == int.class || fromType == long.class) {
            // NOTE(review): a C long result is narrowed to int by ScalarInteger; R has no 64-bit
            // integer type. Confirm this is acceptable for the affected functions.
            return "ScalarInteger(" + value + ")";
        } else if (fromType == double.class) {
            return "ScalarReal(" + value + ")";
        } else if (fromType == boolean.class) {
            return "ScalarLogical(" + value + ")";
        } else if (fromType == String.class) {
            // Rf_mkString already produces a length-1 character vector (STRSXP). Wrapping it in
            // ScalarString, which expects a CHARSXP, would store a STRSXP into a string vector.
            return "Rf_mkString(" + value + ")";
        } else if (fromType == Object.class) {
            return value;
        } else {
            throw new RuntimeException("Unsupported return type of RFFI function: " + fromType.getSimpleName());
        }
    }

    /** Writes a comment line, using the given comment {@code prefix}, explaining how to regenerate. */
    private void printMxHelp(String prefix) {
        out.println(prefix + " All the generated files in testrffi can be regenerated by running 'mx testrfficodegen'");
    }

    /** Returns true if any of the given parameter annotations is {@link RFFICpointer}. */
    private static boolean anyCPointer(Annotation[] items) {
        return Arrays.stream(items).anyMatch(a -> a.annotationType() == RFFICpointer.class);
    }
}
