import inspect
import operator
from abc import ABCMeta, abstractmethod
from contextlib import contextmanager
import six
import numpy as np
if hasattr(inspect, "getfullargspec"):
    def getargs(func):
        """Return the positional argument names of ``func`` minus the first.

        Intended for ``__init__`` methods, so the leading ``self`` is dropped.

        Returns:
            tuple: argument names, or ``()`` when ``func`` cannot be
            introspected.
        """
        try:
            return tuple(inspect.getfullargspec(func).args[1:])
        except TypeError:
            # getfullargspec raises TypeError for unsupported callables;
            # return () just like the legacy branch below instead of failing.
            return ()
else:
    def getargs(func):
        """Legacy fallback of ``getargs`` built on ``inspect.getargspec``.

        Returns:
            tuple: argument names after the first, or ``()`` when ``func``
            cannot be introspected.
        """
        try:
            return tuple(inspect.getargspec(func).args[1:])
        except TypeError:
            # getargspec rejects callables that are not Python functions.
            return ()
class MissingValueException(Exception):
    """Internally used exception.

    Signals that a descriptor value is missing; ``Descriptor.fail`` raises
    it. The exception describing why is kept on the ``error`` attribute.
    """

    __slots__ = ("error",)

    def __init__(self, error):
        # args is (error,) either way; the explicit call just spells it out.
        super(MissingValueException, self).__init__(error)
        self.error = error
class DescriptorMeta(ABCMeta):
    """Metaclass that records the ``__init__`` argument names of a class.

    The names (without ``self``) are stored as the class attribute
    ``parameter_names``. This overwrites any value defined in the class body.
    """

    def __new__(cls, classname, bases, namespace):
        init = namespace.get("__init__")
        if init is None:
            # The class body defines no __init__: use the one reachable
            # through the first base that provides one.
            for base in bases:
                init = getattr(base, "__init__", None)
                if init is not None:
                    break
        namespace["parameter_names"] = getargs(init)
        return ABCMeta.__new__(cls, classname, bases, namespace)
class Descriptor(six.with_metaclass(DescriptorMeta, object)):
    r"""Abstract base class of descriptors.

    Subclasses implement :meth:`parameters` and :meth:`calculate`. The
    following are all derived from :meth:`parameters`, i.e. the ``__init__``
    arguments of the instance (their names are collected by the metaclass
    into ``parameter_names``):

    - pickling
    - hashing
    - comparison
    - ``repr``
    - JSON export

    Descriptors also support arithmetic: expressions such as ``d1 + d2``,
    ``2 * d`` or ``-d`` build operating descriptors combining the values.

    Attributes:
        mol(rdkit.Mol): target molecule
    """

    # Calculation context used by ``mol``/``coord``; it is assigned by the
    # caller, not by this class.
    __slots__ = ("_context",)

    # Class-level requirement flags (require_3D gates ``coord`` below); the
    # others are presumably read by the calculation context — TODO confirm.
    explicit_hydrogens = True
    kekulize = False
    require_connected = False
    require_3D = False

    def __reduce_ex__(self, version):
        # Pickle as "call the class again with the same __init__ arguments".
        return self.__class__, self.parameters()

    def description(self):
        """Return a human-readable description (``None`` in the base class)."""
        pass

    @classmethod
    def preset(cls):
        r"""Generate preset descriptor instances.

        Returns:
            Iterable[Descriptor]: preset descriptors (none by default)
        """
        return ()

    @abstractmethod
    def parameters(self):
        """[abstractmethod] Get the ``__init__`` arguments of this instance.

        The result is used for pickling and for identifying the instance
        (hashing, comparison, ``repr`` and JSON export).

        Returns:
            tuple: tuple of __init__ arguments
        """
        raise NotImplementedError("not implemented Descriptor.parameters method")

    def get_parameter_dict(self):
        """Map each name in ``parameter_names`` to its value from :meth:`parameters`."""
        return dict(zip(self.parameter_names, self.parameters()))

    def to_json(self):
        """Convert to json serializable dictionary.

        Returns:
            dict: ``{"name": ...}``, plus ``"args"`` when there are parameters
        """
        name, args = self._to_json()
        if not args:
            return {"name": name}
        return {"name": name, "args": args}

    def _to_json(self):
        # Parameters exposing ``as_argument`` are replaced by that form.
        name = self.__class__.__name__
        params = self.get_parameter_dict()
        return name, {k: getattr(v, "as_argument", v) for k, v in params.items()}

    @abstractmethod
    def calculate(self):
        r"""[abstractmethod] calculate descriptor value.

        Returns:
            rtype
        """
        raise NotImplementedError("not implemented Descriptor.calculate method")

    def dependencies(self):
        r"""Descriptor dependencies.

        Returns:
            dict[str, Descriptor or None] or None
        """
        pass

    @property
    def as_argument(self):
        """Argument representation of descriptor (the descriptor itself here).

        Returns:
            any
        """
        return self

    @staticmethod
    def _pretty(v):
        # repr of a parameter, using its ``as_argument`` form when present.
        v = getattr(v, "as_argument", v)
        return repr(v)

    def __repr__(self):
        return "{}.{}({})".format(
            self.__class__.__module__,
            self.__class__.__name__,
            ", ".join(self._pretty(a) for a in self.parameters()),
        )

    def __hash__(self):
        return hash((self.__class__, self.parameters()))

    def __compare_by_reduce(meth):
        # Build a rich comparison that applies the tuple method ``meth`` to
        # (class, parameters) pairs.
        # NOTE: on Python 3, ordering descriptors of different classes raises
        # TypeError, because classes themselves are not orderable.
        def compare(self, other):
            if not isinstance(other, Descriptor):
                # Defer to Python's default handling (identity for ==/!=,
                # TypeError for ordering). Without this, other.parameters()
                # would raise AttributeError.
                return NotImplemented
            l = self.__class__, self.parameters()
            r = other.__class__, other.parameters()
            return getattr(l, meth)(r)
        return compare

    __eq__ = __compare_by_reduce("__eq__")
    __ne__ = __compare_by_reduce("__ne__")
    __lt__ = __compare_by_reduce("__lt__")
    __gt__ = __compare_by_reduce("__gt__")
    __le__ = __compare_by_reduce("__le__")
    __ge__ = __compare_by_reduce("__ge__")

    # Result type named in calculate()'s docstring; unspecified in the base.
    rtype = None

    @property
    def mol(self):
        """Get molecule from the calculation context.

        Returns:
            rdkit.Mol
        """
        return self._context.get_mol(self)

    @property
    def coord(self):
        """Get 3D coordinate from the calculation context.

        Descriptors without ``require_3D`` get a MissingValueException that
        wraps an AttributeError.

        Returns:
            numpy.array[3, N]: coordinate matrix (shape as documented
            upstream — TODO confirm against the context)
        """
        if not self.require_3D:
            self.fail(AttributeError("use 3D coordinate in 2D descriptor"))
        return self._context.get_coord(self)

    def fail(self, exception):
        """Raise known exception and return missing value.

        Raises:
            MissingValueException: always, wrapping ``exception``
        """
        raise MissingValueException(exception)

    @contextmanager
    def rethrow_zerodiv(self):
        """[contextmanager] treat zero division as a known exception.

        Inside the block, numpy raises on divide-by-zero and invalid
        operations instead of warning. Both FloatingPointError and
        ZeroDivisionError are re-raised through :meth:`fail` as a
        MissingValueException that wraps a ZeroDivisionError.
        """
        with np.errstate(divide="raise", invalid="raise"):
            try:
                yield
            except (FloatingPointError, ZeroDivisionError) as e:
                self.fail(ZeroDivisionError(*e.args))

    @contextmanager
    def rethrow_na(self, exception):
        """[contextmanager] treat the given exception type(s) as known exceptions.

        Args:
            exception: exception class (or tuple of classes) to convert into a
                MissingValueException via :meth:`fail`. Other exceptions
                propagate unchanged.
        """
        try:
            yield
        except exception as e:
            self.fail(e)

    def _unary_common(name, operator):
        # Build a unary special method returning a UnaryOperatingDescriptor;
        # ``name`` is a format string filled with str(self).
        def unary(self):
            return UnaryOperatingDescriptor(name.format(self), operator, self)
        return unary

    def _binary_common(name, operator, reflected=False):
        # Build a binary special method returning a BinaryOperatingDescriptor.
        # Non-descriptor operands are wrapped in ConstDescriptor. With
        # ``reflected`` the method implements ``other <op> self`` (e.g.
        # __radd__), so the operands are swapped.
        def binary(self, other):
            if not isinstance(other, Descriptor):
                other = ConstDescriptor(other)
            if reflected:
                left, right = other, self
            else:
                left, right = self, other
            return BinaryOperatingDescriptor(
                name.format(left, right), operator, left, right,
            )
        return binary

    __add__ = _binary_common("({}+{})", "+")
    __sub__ = _binary_common("({}-{})", "-")
    __mul__ = _binary_common("({}*{})", "*")
    __truediv__ = _binary_common("({}/{})", "/")
    __floordiv__ = _binary_common("({}//{})", "//")
    __mod__ = _binary_common("({}%{})", "%")
    __pow__ = _binary_common("({}**{})", "**")

    # Reflected forms so that e.g. ``2 * d`` and ``sum([d1, d2])`` work.
    __radd__ = _binary_common("({}+{})", "+", reflected=True)
    __rsub__ = _binary_common("({}-{})", "-", reflected=True)
    __rmul__ = _binary_common("({}*{})", "*", reflected=True)
    __rtruediv__ = _binary_common("({}/{})", "/", reflected=True)
    __rfloordiv__ = _binary_common("({}//{})", "//", reflected=True)
    __rmod__ = _binary_common("({}%{})", "%", reflected=True)
    __rpow__ = _binary_common("({}**{})", "**", reflected=True)

    __neg__ = _unary_common("-{}", "-")
    __pos__ = _unary_common("+{}", "+")
    __abs__ = _unary_common("|{}|", "abs")
    __trunc__ = _unary_common("trunc({})", "trunc")
    if six.PY3:
        __ceil__ = _unary_common("ceil({})", "ceil")
        __floor__ = _unary_common("floor({})", "floor")
def is_descriptor_class(desc):
    r"""Check calculatable descriptor class or not.

    A calculatable class is a concrete (non-abstract) subclass of
    ``Descriptor``. Instances and other non-class objects are rejected.

    Returns:
        bool
    """
    if not isinstance(desc, type):
        return False
    if not issubclass(desc, Descriptor):
        return False
    return not inspect.isabstract(desc)
class UnaryOperatingDescriptor(Descriptor):
    """Descriptor that applies a unary operator to another descriptor's value.

    ``Descriptor``'s unary special methods create these, for expressions such
    as ``-d``, ``+d``, ``abs(d)`` and ``math.trunc(d)``.
    """

    @classmethod
    def preset(cls):
        """Return no preset instances.

        ``__init__`` requires a name, an operator and an operand, so no
        argument-free instance exists. Calling ``cls()`` would raise TypeError.

        Returns:
            tuple: always empty
        """
        return ()

    # Operator keys accepted by __init__, mapped to the function that
    # calculate() applies to the operand's value.
    operators = {
        "+": operator.pos,
        "-": operator.neg,
        "abs": operator.abs,
        "trunc": np.trunc,
        "ceil": np.ceil,  # noqa: S001
        "floor": np.floor,
    }

    def parameters(self):
        """Return the ``__init__`` arguments ``(name, operator, value)``."""
        return self._name, self._operator, self._value

    def __init__(self, name, operator, value):
        """Create the descriptor.

        Args:
            name (str): display name, returned by ``str()``
            operator (str): key of ``operators``
            value (Descriptor): operand descriptor

        Raises:
            KeyError: if ``operator`` is not a key of ``operators``
        """
        self._name = name
        self._operator = operator
        self._fn = self.operators[operator]
        self._value = value

    def _to_json(self):
        # The operand is serialized recursively through its own to_json().
        return self.__class__.__name__, {
            "name": self._name,
            "operator": self._operator,
            "value": self._value.to_json(),
        }

    def __str__(self):
        return self._name

    def dependencies(self):
        # The key matches calculate()'s parameter name; presumably the caller
        # passes each dependency's computed value by keyword — TODO confirm.
        return {
            "value": self._value,
        }

    def calculate(self, value):
        """Apply the operator function to the operand's computed value."""
        return self._fn(value)
class ConstDescriptor(Descriptor):
    """Descriptor whose value is a fixed constant.

    ``Descriptor`` arithmetic wraps non-descriptor operands in this class,
    e.g. the ``2`` in ``d * 2``.
    """

    @classmethod
    def preset(cls):
        """Return no preset instances.

        ``__init__`` requires ``value``, so no argument-free instance exists.
        Calling ``cls()`` would raise TypeError.

        Returns:
            tuple: always empty
        """
        return ()

    def parameters(self):
        """Return the ``__init__`` arguments ``(value,)``."""
        return (self._value,)

    def __init__(self, value):
        self._value = value

    def __str__(self):
        return str(self._value)

    def calculate(self):
        """Return the stored constant."""
        return self._value
class BinaryOperatingDescriptor(Descriptor):
    """Descriptor that combines two descriptors' values with a binary operator.

    ``Descriptor``'s arithmetic special methods create these, for expressions
    such as ``d1 + d2`` and ``d ** 2``.
    """

    @classmethod
    def preset(cls):
        """Return no preset instances.

        ``__init__`` requires a name, an operator and two operands, so no
        argument-free instance exists. Calling ``cls()`` would raise TypeError.

        Returns:
            tuple: always empty
        """
        return ()

    # Operator keys accepted by __init__, mapped to the function that
    # calculate() applies to the operands' values.
    operators = {
        "+": operator.add,
        "-": operator.sub,
        "*": operator.mul,  # noqa: S001
        "/": operator.truediv,
        "//": operator.floordiv,
        "%": operator.mod,  # noqa: S001
        "**": operator.pow,
    }

    def _to_json(self):
        # Both operands are serialized recursively through their to_json().
        return self.__class__.__name__, {
            "name": self._name,
            "operator": self._operator,
            "left": self._left.to_json(),  # noqa: S001
            "right": self._right.to_json(),
        }

    def parameters(self):
        """Return the ``__init__`` arguments ``(name, operator, left, right)``."""
        return self._name, self._operator, self._left, self._right

    def __init__(self, name, operator, left, right):
        """Create the descriptor.

        Args:
            name (str): display name, returned by ``str()``
            operator (str): key of ``operators``
            left (Descriptor): left operand descriptor
            right (Descriptor): right operand descriptor

        Raises:
            KeyError: if ``operator`` is not a key of ``operators``
        """
        self._name = name
        self._operator = operator
        self._fn = self.operators[operator]
        self._left = left
        self._right = right

    def __str__(self):
        return self._name

    def dependencies(self):
        # Keys match calculate()'s parameter names; presumably the caller
        # passes each dependency's computed value by keyword — TODO confirm.
        return {
            "left": self._left,
            "right": self._right,
        }

    def calculate(self, left, right):
        """Apply the operator function to the operands' computed values."""
        return self._fn(left, right)