# -*- coding: utf-8 -*-
"""This module provides functionality to compile PyDOLFIN compatible
extension modules."""

# Copyright (C) 2009-2016 Johan Hake
#
# This file is part of DOLFIN.
#
# DOLFIN is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# DOLFIN is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with DOLFIN. If not, see <http://www.gnu.org/licenses/>.
#
# Modified by Johannes Ring, 2011
# Modified by Mikael Mortensen, 2014
# Modified by Martin Sandve Alnæs, 2016

import sys
import re
import hashlib

import instant
from ufl.utils.sorting import canonicalize_metadata
import ffc
from dolfin_utils.cppparser import *

# Import PyDOLFIN
import dolfin
import dolfin.cpp as cpp
from dolfin.cpp import parameters

from dolfin.compilemodules.jit import mpi_jit_decorator
from dolfin.compilemodules.swigimportinfo import *

# Public API of this module
__all__ = ["compile_extension_module",
           "expression_to_code_fragments",
           "math_header"]

# Bump the interface version if anything changes that invalidates
# cached modules (not required for change in generated code, swig
# version or dolfin version). The value is hashed into the module
# signature computed by compile_extension_module.
_interface_version = 8

# Add max_signature_length to parameter system (don't need this in C++).
# A value of 0 means the JIT module signature is not truncated.
cpp.parameters.add("max_signature_length", 0)

# A list of supported math builtins. Each name is made available
# unqualified in generated code through a "using std::<name>;" line in
# math_header below.
# NOTE(review): "max" and "min" are not <cmath> functions (std::max/min
# live in <algorithm>) -- confirm they are intentionally listed here.
_cpp_math_builtins = [
    # <cmath> functions: from http://www.cplusplus.com/reference/cmath/
    "cos", "sin", "tan", "acos", "asin", "atan", "atan2",
    "cosh", "sinh", "tanh", "exp", "frexp", "ldexp", "log", "log10", "modf",
    "pow", "sqrt", "ceil", "fabs", "floor", "fmod",
    "max", "min"]

# C math library names that may appear in a C++ expression without being
# treated as member variables.
_math_builtins = [
    # math.h functions: http://en.wikibooks.org/wiki/C_Programming/C_Reference/math.h
    "acos", "asin", "atan", "atan2", "ceil", "cos", "cosh", "exp",
    "fabs", "floor", "fmod", "frexp", "ldexp", "log", "log10", "modf",
    "pow", "sin", "sinh", "sqrt", "tan", "tanh", "acosh", "asinh", "atanh",
    "cbrt", "copysign", "erf", "erfc", "exp2", "expm1", "fdim", "fma", "fmax",
    "fmin", "hypot", "ilogb", "lgamma", "llrint", "lrint", "llround", "lround",
    "log1p", "log2", "logb", "nan", "nearbyint", "nextafter", "nexttoward",
    "remainder", "remquo", "rint", "round", "scalbln", "scalbn", "tgamma", "trunc"]

# Names provided by DOLFIN. These are functions from dolfin::math plus the
# DOLFIN_EPS/DOLFIN_PI constants and the `pi` alias declared in math_header.
_math_dolfin = [
    # functions from dolfin::math:
    "sqr", "ipow", "rand", "near", "DOLFIN_EPS", "DOLFIN_PI", "pi"]

# Every math name that is never treated as a member variable of a compiled
# expression (see expression_to_code_fragments). Part of it is built from a
# set, so the order is arbitrary. Within this file it is only used in a set
# difference, where order does not matter.
_all_math = (list(set(_math_builtins).difference(_cpp_math_builtins))
            + _math_dolfin + _cpp_math_builtins)

# C++ preamble: using-declarations for _cpp_math_builtins and a `pi`
# constant aliasing DOLFIN_PI.
math_header = """
// cmath functions
%s

const double pi = DOLFIN_PI;
""" % "\n".join("using std::%s;" % mf for mf in _cpp_math_builtins)


# C++ keywords (including alternative operator tokens such as "and"/"or")
# that may legitimately appear in a C++ expression. They are removed from
# the autodetected member variables in expression_to_code_fragments.
# Copied from http://en.cppreference.com/w/cpp/keyword 2017/02/02
_allowed_cpp_keywords = """and
and_eq
bitand
bitor
bool
char
char16_t
char32_t
double
false
float
int
long
not
not_eq
or
or_eq
short
signed
sizeof
true
unsigned
wchar_t
xor
xor_eq
""".splitlines()

# C++ keywords that must not appear in a C++ expression. _check_valid_symbol
# reports an error through cpp.dolfin_error when it finds one.
_disallowed_cpp_keywords = """alignas
alignof
asm
atomic_cancel
atomic_commit
atomic_noexcept
auto
break
case
catch
class
compl
concept
const
constexpr
const_cast
continue
decltype
default
delete
do
dynamic_cast
else
enum
explicit
export
extern
for
friend
goto
if
import
inline
module
mutable
namespace
new
noexcept
nullptr
operator
private
protected
public
register
reinterpret_cast
requires
return
static
static_assert
static_cast
struct
switch
synchronized
template
this
thread_local
throw
try
typedef
typeid
typename
union
using
virtual
void
volatile
while""".splitlines()


# User defined attributes with same names as attributes
# of any compiled expression base class cause confusion
# and segfaults with swig.
# _check_valid_symbol reports an error for any C++ expression symbol
# found in this list.
_fenics_keywords = [
    # Attributes of ufl.Expr and ufl.Coefficient
    "T",
    "dx",
    "evaluate",
    "geometric_dimension",
    "count",
    # Attributes of dolfin Variable, GenericFunction, and Function
    "id",
    "eval",
    "value_rank",
    "value_dimension",
    "value_size",
    "restrict",
    "update",
    "vector",
    "in",
    "interpolate",
    "extrapolate",
    ]


# SWIG interface template prepended to every JIT-compiled extension module.
# It is %-formatted with a dict in compile_extension_module, so literal
# percent signs are written as "%%". The placeholders are:
#   %(shared_ptr_declarations)s  - output of extract_shared_ptr_declaration
#   %(additional_declarations)s  - user supplied SWIG declarations
#   %(dolfin_import_statement)s  - %import lines for the used DOLFIN submodules
# The %pythoncode section defines the module attribute `swigversion`, which
# check_swig_version compares against cpp.__swigversion__. It assumes
# SWIGVERSION is hex-encoded like 0xMMmmpp -- confirm for new SWIG releases.
# NOTE(review): %shared_ptr(ufc::function) is declared twice below; the
# second occurrence looks redundant.
_additional_declarations = r"""
%%init%%{
import_array();
%%}

// Include global SWIG interface files:
// Typemaps, shared_ptr declarations, exceptions, version
%%include <std_shared_ptr.i>

// Global typemaps and forward declarations
%%include "dolfin/swig/typemaps/includes.i"
%%include "dolfin/swig/forwarddeclarations.i"

// Global exceptions
%%include <exception.i>
%%include "dolfin/swig/exceptions.i"

// Do not expand default arguments in C++ by generating two an extra
// function in the SWIG layer. This reduces code bloat.
%%feature("compactdefaultargs");

// STL SWIG string class
%%include <std_string.i>

// Manually import ufc:
%%{
#include <ufc.h>
%%}
%%shared_ptr(ufc::function)
%%shared_ptr(ufc::dofmap)
%%shared_ptr(ufc::finite_element)
%%shared_ptr(ufc::coordinate_mapping)
%%shared_ptr(ufc::function)
%%shared_ptr(ufc::form)
%%ignore ufc::integral;
%%ignore ufc::cell_integral;
%%ignore ufc::exterior_facet_integral;
%%ignore ufc::interior_facet_integral;
%%ignore ufc::vertex_integral;
%%ignore ufc::custom_integral;
%%ignore ufc::cutcell_integral;
%%ignore ufc::interface_integral;
%%ignore ufc::overlap_integral;
%%include <ufc.h>

// Local shared_ptr declarations
%(shared_ptr_declarations)s

%(additional_declarations)s

// Import statements
%(dolfin_import_statement)s

%%feature("autodoc", "1");

%%inline %%{
int get_swigversion() { return  SWIGVERSION; }
%%}

%%pythoncode %%{
tmp = hex(get_swigversion())
swigversion = "%%d.%%d.%%d"%%(tuple(map(int, [tmp[-5], tmp[-3], tmp[-2:]])))
del tmp, get_swigversion
%%}

"""

# Symbols in a C++ expression: runs of word characters and dots, optionally
# qualified by one or more "name::" prefixes (e.g. "std::sin", "x", "0").
re_symbol = re.compile(r"(?:[a-zA-Z_]*::)*[\w\.]+")
# Numeric literals (integer or decimal, optional exponent). Used to drop
# numerals from the detected symbols. This must be a raw string because
# "\d" and "\." in a plain literal are invalid escape sequences
# (DeprecationWarning, SyntaxWarning since Python 3.12).
re_number = re.compile(r"[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?")


def _check_valid_symbol(sym):
    """Report an error if ``sym`` may not be used as a name in a C++ expression.

    Rejected symbols are:
      * C++ keywords listed in ``_disallowed_cpp_keywords``,
      * names reserved by FEniCS (``_fenics_keywords``),
      * names with a leading or trailing underscore,
      * names starting with ``ufl_``.

    Errors are reported through ``cpp.dolfin_error``.
    """
    if sym in _disallowed_cpp_keywords:
        cpp.dolfin_error("compilemodule.py",
                         "parse expression string",
                         "Detected C++ keyword (%s) in C++ expression" % sym)
    if sym in _fenics_keywords:
        cpp.dolfin_error("compilemodule.py",
                         "parse expression string",
                         "Detected name reserved by FEniCS (%s) in C++ expression" % sym)
    if sym.startswith("_") or sym.endswith("_"):
        # Both ends are checked, so the message names both cases
        cpp.dolfin_error("compilemodule.py",
                         "parse expression string",
                         "Detected invalid name with leading or trailing "
                         "underscores (%s) in C++ expression" % sym)
    if sym.startswith("ufl_"):
        cpp.dolfin_error("compilemodule.py",
                         "parse expression string",
                         "Detected invalid name with initial ufl_ (%s) in C++ expression" % sym)


def expression_to_code_fragments(expr, arguments,
        generic_function_members=(), mesh_function_members=()):
    """Extract a dict with C++ code snippets for the members of an expression.

    Every symbol found in the expression strings is validated with
    ``_check_valid_symbol``. Symbols that are arguments, math names
    (``_all_math``), allowed C++ keywords, numerals or namespace-qualified
    names are discarded. The remaining symbols are member variables.

    *Arguments*
        expr
            An iterable of C++ expression strings. A single string is
            accepted too and treated as a one-element list.
        arguments
            Names that must not be treated as member variables.
        generic_function_members
            Names of members that hold a ``dolfin::GenericFunction``.
        mesh_function_members
            ``(name, typename)`` pairs of members that hold a
            ``dolfin::MeshFunction<typename>``.

    *Returns*
        A tuple ``(fragments, variables)``:

        ``fragments``
            A dict with ``"members"`` (member declarations) and
            ``"constructor"`` (zero-initialisation of the scalar members).
        ``variables``
            The set of detected member names. It includes generic and
            mesh function member names that occur in ``expr``.
    """
    # list(expr) on a lone string would split it into single characters
    if isinstance(expr, str):
        expr = [expr]

    # Autodetect variables from the expression strings
    variables = set()
    for code in expr:
        # Find groups of connected alphanumeric letters, optionally
        # namespace-qualified (see re_symbol)
        symbols = re_symbol.findall(code)
        for sym in symbols:
            _check_valid_symbol(sym)
        variables.update(symbols)

    # Remove any variables defined in the arguments list
    variables.difference_update(arguments)

    # Remove the builtin math functions from the variables
    variables.difference_update(_all_math)

    # Remove allowed C++ keywords
    variables.difference_update(_allowed_cpp_keywords)

    # Remove the numerals from the variables
    numerals = [v for v in variables if re_number.match(v)]
    variables.difference_update(numerals)

    # Remove namespace-specified variables
    namespacevars = [v for v in variables if "::" in v]
    variables.difference_update(namespacevars)

    # Variables are either GenericFunction, MeshFunction, or scalar
    mesh_function_member_names = [item[0] for item in mesh_function_members]
    scalar_members = sorted(variables - set(generic_function_members)
                            - set(mesh_function_member_names))

    # Generate code for member variables
    members_code = ["  double %s;" % name for name in scalar_members]
    members_code.extend("  std::shared_ptr<dolfin::MeshFunction<%s>> shared_%s;" % (typename, name)
                        for name, typename in mesh_function_members)
    members_code.extend("  std::shared_ptr<dolfin::GenericFunction> shared_%s;" % name
                        for name in generic_function_members)

    # Generate constructor code for initialization of scalar member variables
    constructor_code = "\n".join("    %s = 0;" % name for name in scalar_members)

    # Connect the code fragments using the function template code
    fragments = {"members": "\n".join(members_code),
                 "constructor": constructor_code}

    return fragments, variables


def expect_list_of(argtype, arg, argname):
    """Validate that ``arg`` is a list or tuple whose items are ``argtype``.

    ``None`` is normalised to an empty list. A valid list or tuple is
    returned unchanged (the same object). Anything else is reported through
    ``cpp.dolfin_error`` on behalf of ``compile_extension_module``.
    """
    if arg is None:
        return []

    is_sequence = isinstance(arg, (list, tuple))
    if is_sequence and all(isinstance(item, argtype) for item in arg):
        return arg

    message = "Provide a 'tuple' or 'list' with '%s', for the '%s' argument" \
              % (argtype.__name__, argname)
    cpp.dolfin_error("compilemodule.py",
                     "ensure correct argument for compile_extension_module",
                     message)


def expect_arg(argtype, arg, argname):
    """Validate that ``arg`` is an instance of ``argtype``.

    Nothing happens for a valid argument. Otherwise the problem is reported
    through ``cpp.dolfin_error`` on behalf of ``compile_extension_module``.
    """
    if not isinstance(arg, argtype):
        message = "Provide a '%s', for the '%s' argument" % \
                  (argtype.__name__, argname)
        cpp.dolfin_error("compilemodule.py",
                         "ensure correct argument for compile_extension_module",
                         message)


def check_swig_version(compiled_module):
    """Report an error if a compiled module was built with another SWIG version.

    A falsy ``compiled_module`` (e.g. ``None``) is accepted silently.
    Otherwise its ``swigversion`` attribute is compared with
    ``cpp.__swigversion__``, and a mismatch is reported through
    ``cpp.dolfin_error``.
    """
    if not compiled_module:
        return

    module_swig = compiled_module.swigversion
    dolfin_swig = cpp.__swigversion__
    if module_swig != dolfin_swig:
        cpp.dolfin_error("compilemodule.py",
                         "compiling extension module",
                         "Incompatible swig versions detected. DOLFIN swig "
                         "version is not the same as extension module swig "
                         "version: '%s' != '%s' " % (dolfin_swig, module_swig))


@mpi_jit_decorator
def compile_extension_module(code, module_name="",
                             additional_declarations="",
                             additional_system_headers=None,
                             mpi_comm=None,
                             **instant_kwargs):
    """
    Just In Time compile DOLFIN C++ code into a Python module.

    *Arguments*
        code
            C++ code which implements any function or C++ class. Any function
            or class available in the C++ DOLFIN namespace can be used and/or
            subclassed. All typemaps from the original Python interface are
            available, making it possible to interface with for example NumPy
            for Array<double/int> arguments. Source which is not wrapped in
            a dolfin namespace will be automatically wrapped. The code must
            not contain 'using namespace dolfin'.

        module_name
            Force a name of the module. If not set, the name is "dolfin_"
            followed by a SHA-1 hex digest of the code together with
            version information, the additional declarations, the
            additional system headers and the instant arguments. The
            digest is truncated to the global parameter
            "max_signature_length" when that is non-zero.

        additional_declarations
            Additional SWIG declarations can be passed using this argument.

        additional_system_headers
            System headers needed to compile the generated code can be
            included using this argument. The headers are passed as a list
            or tuple of 'str'.

        mpi_comm
            An MPI communicator used to determine which processes should
            compile the module. By default mpi_world_comm is used, but
            mpi_group communicators are also accepted. It is not read in
            this function body; presumably mpi_jit_decorator consumes it.

        **instant_kwargs
            Forwarded to instant.build_module. "system_headers" and
            "cmake_packages" are extended with the entries DOLFIN needs,
            and "signature" is set to the module name.

    *Returns*

        The JIT compiled extension module

    *Errors*

        Reported through cpp.dolfin_error for invalid argument types, for
        code containing 'using namespace dolfin', and when the SWIG version
        of the compiled module differs from the one DOLFIN was built with.

    *Examples of usage*

        The following toy example shows how one can use compiled extension
        modules to access low level PETSc routines:

        .. code-block:: python

            from numpy import arange
            code = '''
            namespace dolfin {

              void PETSc_exp(std::shared_ptr<dolfin::PETScVector> vec)
              {
                Vec x = vec->vec();
                assert(x);
                VecExp(x);
              }
            }
            '''
            ext_module = compile_extension_module(code,
                         additional_system_headers=["petscvec.h"])
            comm = mpi_comm_world()
            vec = PETScVector(comm, 10)
            vec[:] = arange(10)
            print(vec[-1])
            ext_module.PETSc_exp(vec)
            print(vec[-1])

    """
    # Check the provided arguments
    expect_arg(str, code, "first")
    expect_arg(str, module_name, "module_name")
    expect_arg(str, additional_declarations, "additional_declarations")
    additional_system_headers = expect_list_of(
        str, additional_system_headers, "additional_system_headers")

    # Check that the code does not use 'using namespace dolfin'.
    # Raw strings: "\s" etc. are invalid escapes in plain string literals.
    if re.search(r"using\s+namespace\s+dolfin", code):
        cpp.dolfin_error("compilemodule.py",
                         "ensure correct argument to compile_extension_module",
                         "Do not use 'using namespace dolfin'. "\
                         "Include the code in namespace dolfin {...} instead")

    # Check if the code does not use namespace dolfin {...}
    if not re.search(r"namespace\s+dolfin\s*\{[\s\S]+\}", code):

        # Wrap and indent code in namespace dolfin
        codelines = ["namespace dolfin", "{"]
        codelines += ["  " + line for line in code.split("\n")]
        codelines += ["}"]
        code = "\n".join(codelines)

    # Create unique module name for this application run
    if not module_name:
        module_signature = hashlib.sha1((repr(code) +
                                         dolfin.__version__ +
                                         str(_interface_version) +
                                         ffc.ufc_signature() +
                                         sys.version +
                                         additional_declarations +
                                         str(additional_system_headers) +
                                         repr(canonicalize_metadata(instant_kwargs))
                                         ).encode("utf-8")
                                        ).hexdigest()
        max_signature_length = parameters["max_signature_length"]
        if max_signature_length:
            module_signature = module_signature[:max_signature_length]
        module_name = "dolfin_%s" % module_signature

    # Extract dolfin dependencies and class names
    used_types, declared_types = parse_and_extract_type_info(code)

    # Add any bases of the declared types to used_types
    for bases in declared_types.values():
        used_types.update(bases)

    # Filter out dolfin types and add the bases of each one.
    # NOTE: "dolfin_type in used_type" is a containment test. Assuming
    # used_types holds type-name strings, it is a substring match, so e.g.
    # "Function" also matches a used type "FunctionSpace".
    used_dolfin_types = []
    for dolfin_type in dolfin_type_def:
        if any(dolfin_type in used_type for used_type in used_types):
            # Add bases first, then the dolfin type itself
            used_dolfin_types.extend(dolfin_type_def[dolfin_type]["bases"])
            used_dolfin_types.append(dolfin_type)

    # Generate dependency info: submodule name -> list of headers
    dependencies = {}
    for dolfin_type in used_dolfin_types:
        type_info = dolfin_type_def[dolfin_type]
        dependencies.setdefault(type_info["submodule"], []).append(
            type_info["header"])

    # Need special treatment for template definitions in function/pre.i
    if "function" in dependencies:
        for dolfin_type in ["FunctionSpace", "Function"]:
            dependencies["function"].append(dolfin_type_def[dolfin_type]["header"])

    # Always depend on the constants and types headers (without duplicates)
    common_headers = dependencies.setdefault("common", [])
    for header in ["dolfin/common/constants.h", "dolfin/common/types.h"]:
        if header not in common_headers:
            common_headers.append(header)

    # Sort the dependencies
    dependencies = sort_submodule_dependencies(dependencies, submodule_info)

    import_lines, _headers_includes, file_dependencies = \
                  build_swig_import_info(dependencies, submodule_info, "dolfin.cpp.")

    # Extract header info, skipping the SWIG "pre.i" interface files
    dolfin_system_headers = [header for header in file_dependencies
                             if "pre.i" not in header]

    # Check cache
    compiled_module = instant.import_module(module_name)

    if not compiled_module:
        # Configure instant before building
        sys.stdout.flush()
        dolfin.info("Calling DOLFIN just-in-time (JIT) compiler, this may take some time.")

        # Add various system headers
        instant_kwargs["system_headers"] = ["cmath", "iostream", "complex",
                                            "stdexcept", "numpy/arrayobject.h",
                                            "memory",
                                            "dolfin/common/types.h",
                                            "dolfin/math/basic.h",
                                            ] + instant_kwargs.get("system_headers", [])
        instant_kwargs["system_headers"] += dolfin_system_headers
        instant_kwargs["system_headers"] += additional_system_headers

        # Add cmake packages
        instant_kwargs["cmake_packages"] = ["DOLFIN"] + \
                                           instant_kwargs.get("cmake_packages", [])
        instant_kwargs["signature"] = module_name

        # Add additional swig declarations
        dolfin_import_statement = "\n".join(import_lines)
        shared_ptr_declarations = extract_shared_ptr_declaration(
            declared_types, used_dolfin_types, shared_ptr_classes)
        additional_declarations = _additional_declarations % {
            "additional_declarations": additional_declarations,
            "shared_ptr_declarations": shared_ptr_declarations,
            "dolfin_import_statement": dolfin_import_statement,
            }

        # Compile extension module with instant
        compiled_module = instant.build_module(
            code=code,
            additional_declarations=additional_declarations,
            **instant_kwargs)

    sys.stdout.flush()

    # Check that the swig version of the compiled module is the same as
    # dolfin was compiled with
    check_swig_version(compiled_module)

    return compiled_module


def extract_shared_ptr_declaration(declared_types, used_dolfin_types, shared_ptr_classes):
    """Build SWIG ``%shared_ptr`` declarations for the types that need them.

    A type gets a declaration when it is either:
      * a used DOLFIN type listed in ``shared_ptr_classes``, or
      * a declared class with at least one base in ``shared_ptr_classes``.

    *Arguments*
        declared_types
            Mapping from a declared class name to an iterable of its bases.
        used_dolfin_types
            Names of the DOLFIN types used by the code.
        shared_ptr_classes
            Names of the classes that PyDOLFIN wraps as shared_ptr.

    *Returns*
        One ``%shared_ptr(dolfin::<name>)`` line per type, joined with
        newlines. Used DOLFIN types come first, then declared classes.
    """
    def _has_shared_ptr_base(bases):
        return any(base in shared_ptr_classes for base in bases)

    needs_declaration = [name for name in used_dolfin_types
                         if name in shared_ptr_classes]
    needs_declaration += [name for name, bases in declared_types.items()
                          if _has_shared_ptr_base(bases)]

    lines = ["%%shared_ptr(dolfin::%s)" % name for name in needs_declaration]
    return "\n".join(lines)
