import inspect
import operator
from abc import ABCMeta, abstractmethod
from contextlib import contextmanager
from distutils.version import StrictVersion
import six
import numpy as np
if hasattr(inspect, "getfullargspec"):
    _argspec = inspect.getfullargspec
else:  # Python 2: only the legacy API is available
    _argspec = inspect.getargspec


def getargs(func):
    """Return the positional parameter names of *func*, minus the first one.

    The first parameter is dropped because *func* is expected to be an
    ``__init__`` method whose first argument is ``self``.

    Args:
        func: function/method to introspect (any object is accepted).

    Returns:
        tuple[str, ...]: parameter names after the first, or ``()`` when
        ``inspect`` cannot introspect *func* (e.g. ``None`` or some builtins).
    """
    try:
        spec = _argspec(func)
    except TypeError:
        # inspect signals "unsupported callable" with TypeError on both
        # Python 2 and Python 3; treat such objects as having no parameters.
        return ()
    return tuple(spec.args[1:])
class MissingValueException(Exception):
    """Signal (used internally) that a descriptor value is missing.

    Attributes:
        error: the exception describing why the value could not be computed.
    """

    __slots__ = ("error",)

    def __init__(self, error):
        # Keep the underlying cause available to whoever catches this.
        self.error = error
class DescriptorMeta(ABCMeta):
    """Metaclass that annotates descriptor classes at creation time.

    For each class it creates it:

    * sets ``parameter_names`` to the ``__init__`` argument names (without
      ``self``), as computed by ``getargs``;
    * converts a ``since`` attribute defined in the class body into a
      ``StrictVersion``.
    """

    def __new__(cls, classname, bases, namespace):
        init = namespace.get("__init__")
        if init is None:
            # __init__ is inherited: take it from the first base providing one.
            for base in bases:
                init = getattr(base, "__init__", None)
                if init is not None:
                    break
            else:
                # No base supplied one (e.g. no explicit bases at all), so the
                # class ends up with object's initializer.
                init = object.__init__

        namespace["parameter_names"] = getargs(init)

        if "since" in namespace:
            namespace["since"] = StrictVersion(namespace["since"])

        return ABCMeta.__new__(cls, classname, bases, namespace)
class Descriptor(six.with_metaclass(DescriptorMeta, object)):
    r"""Abstract base class of descriptors.

    Subclasses implement :meth:`parameters` and :meth:`calculate`.
    Instances are pickled, hashed and compared by ``(class, parameters())``.
    Arithmetic operators build composite descriptors
    (``BinaryOperatingDescriptor`` / ``UnaryOperatingDescriptor``).

    Attributes:
        mol(rdkit.Mol): target molecule
    """

    __slots__ = ("_context",)

    explicit_hydrogens = True
    kekulize = False
    require_connected = False
    require_3D = False

    def __reduce_ex__(self, version):
        # Pickle by re-invoking the constructor with the same arguments.
        return self.__class__, self.parameters()

    def description(self):
        """Return a description of this descriptor (None in the base class)."""
        pass

    @classmethod
    def preset(cls, version):
        r"""Generate preset descriptor instances.

        Returns:
            Iterable[Descriptor]: preset descriptors
        """
        return ()

    @abstractmethod
    def parameters(self):
        """[abstractmethod] get __init__ arguments of this descriptor instance.

        This method is used in pickling and identifying descriptor instances.

        Returns:
            tuple: tuple of __init__ arguments
        """
        raise NotImplementedError("not implemented Descriptor.parameters method")

    def get_parameter_dict(self):
        """Map ``parameter_names`` (set by the metaclass) to ``parameters()``.

        Returns:
            dict: argument name -> argument value
        """
        return dict(zip(self.parameter_names, self.parameters()))

    def to_json(self):
        """Convert to json serializable dictionary.

        Returns:
            dict: ``{"name": ...}``, plus ``"args"`` when there are parameters
        """
        name, params = self._to_json()
        if len(params) == 0:
            return {"name": name}
        return {"name": name, "args": params}

    def _to_json(self):
        # Parameter values that expose ``as_argument`` (e.g. nested
        # descriptors) are replaced by that representation.
        params = self.get_parameter_dict()
        return self.__class__.__name__, {
            k: getattr(v, "as_argument", v) for k, v in params.items()
        }

    @abstractmethod
    def calculate(self):
        r"""[abstractmethod] calculate descriptor value.

        Returns:
            rtype
        """
        raise NotImplementedError("not implemented Descriptor.calculate method")

    def dependencies(self):
        r"""Descriptor dependencies.

        Returns:
            dict[str, Descriptor or None] or None
        """
        pass

    @property
    def as_argument(self):
        """Argument representation of descriptor.

        Returns:
            any
        """
        return self

    @staticmethod
    def _pretty(v):
        v = getattr(v, "as_argument", v)
        return repr(v)

    def __repr__(self):
        return "{}.{}({})".format(
            self.__class__.__module__,
            self.__class__.__name__,
            ", ".join(self._pretty(a) for a in self.parameters()),
        )

    def __hash__(self):
        return hash((self.__class__, self.parameters()))

    def __compare_by_reduce(meth):
        # Build a rich-comparison method comparing (class, parameters()) tuples.
        def compare(self, other):
            if not isinstance(other, Descriptor):
                # Objects without ``parameters()`` are not comparable here;
                # let Python try the reflected method / identity fallback
                # instead of raising AttributeError.
                return NotImplemented
            lhs = self.__class__, self.parameters()
            rhs = other.__class__, other.parameters()
            return getattr(lhs, meth)(rhs)

        return compare

    __eq__ = __compare_by_reduce("__eq__")
    __ne__ = __compare_by_reduce("__ne__")
    __lt__ = __compare_by_reduce("__lt__")
    __gt__ = __compare_by_reduce("__gt__")
    __le__ = __compare_by_reduce("__le__")
    __ge__ = __compare_by_reduce("__ge__")

    rtype = None

    @property
    def mol(self):
        """Get molecule.

        Returns:
            rdkit.Mol
        """
        # NOTE(review): ``_context`` is assigned outside this class; confirm
        # it is set before descriptors are calculated.
        return self._context.get_mol(self)

    @property
    def coord(self):
        """Get 3D coordinate.

        Returns:
            numpy.array[3, N]: coordinate matrix

        Raises:
            MissingValueException: when used from a descriptor whose
                ``require_3D`` is false.
        """
        if not self.require_3D:
            self.fail(AttributeError("use 3D coordinate in 2D descriptor"))
        return self._context.get_coord(self)

    def fail(self, exception):
        """Raise known exception and return missing value.

        Raises:
            MissingValueException
        """
        raise MissingValueException(exception)

    @contextmanager
    def rethrow_zerodiv(self):
        """[contextmanager] treat zero div as known exception.

        NumPy division-by-zero and invalid operations are made to raise, and
        both they and Python's ZeroDivisionError are reported via ``fail``.
        """
        with np.errstate(divide="raise", invalid="raise"):
            try:
                yield
            except (FloatingPointError, ZeroDivisionError) as e:
                self.fail(ZeroDivisionError(*e.args))

    @contextmanager
    def rethrow_na(self, exception):
        """[contextmanager] treat any exceptions as known exception.

        Args:
            exception: exception type (or tuple of types) to report via ``fail``.
        """
        try:
            yield
        except exception as e:
            self.fail(e)

    def _unary_common(name, operator):
        # Build a unary-operator method producing a UnaryOperatingDescriptor.
        def unary(self):
            return UnaryOperatingDescriptor(name.format(self), operator, self)

        return unary

    def _binary_common(name, operator, reflected=False):
        # Build a binary-operator method producing a BinaryOperatingDescriptor.
        # Non-descriptor operands are wrapped in ConstDescriptor; ``reflected``
        # puts the other operand on the left (for ``__radd__`` and friends).
        def binary(self, other):
            if not isinstance(other, Descriptor):
                other = ConstDescriptor(other)
            if reflected:
                left, right = other, self
            else:
                left, right = self, other
            return BinaryOperatingDescriptor(
                name.format(left, right), operator, left, right
            )

        return binary

    __add__ = _binary_common("({}+{})", "+")
    __sub__ = _binary_common("({}-{})", "-")
    __mul__ = _binary_common("({}*{})", "*")
    __truediv__ = _binary_common("({}/{})", "/")
    __floordiv__ = _binary_common("({}//{})", "//")
    __mod__ = _binary_common("({}%{})", "%")
    __pow__ = _binary_common("({}**{})", "**")

    # Reflected forms so that e.g. ``2 * desc`` works like ``desc * 2``.
    __radd__ = _binary_common("({}+{})", "+", reflected=True)
    __rsub__ = _binary_common("({}-{})", "-", reflected=True)
    __rmul__ = _binary_common("({}*{})", "*", reflected=True)
    __rtruediv__ = _binary_common("({}/{})", "/", reflected=True)
    __rfloordiv__ = _binary_common("({}//{})", "//", reflected=True)
    __rmod__ = _binary_common("({}%{})", "%", reflected=True)
    __rpow__ = _binary_common("({}**{})", "**", reflected=True)

    __neg__ = _unary_common("-{}", "-")
    __pos__ = _unary_common("+{}", "+")
    __abs__ = _unary_common("|{}|", "abs")
    __trunc__ = _unary_common("trunc({})", "trunc")

    if six.PY3:
        __ceil__ = _unary_common("ceil({})", "ceil")
        __floor__ = _unary_common("floor({})", "floor")
def is_descriptor_class(desc, include_abstract=False):
    r"""Tell whether *desc* is a descriptor class that can be calculated.

    Args:
        desc: object to inspect.
        include_abstract: when true, abstract descriptor classes also qualify.

    Returns:
        bool
    """
    if not isinstance(desc, type):
        return False
    if not issubclass(desc, Descriptor):
        return False
    if include_abstract:
        return True
    return not inspect.isabstract(desc)
class UnaryOperatingDescriptor(Descriptor):
    """Descriptor applying a unary operator to another descriptor's value.

    Built by ``Descriptor`` operator overloads such as ``-desc``,
    ``abs(desc)`` or ``math.trunc(desc)``.
    """

    @classmethod
    def preset(cls, version):
        # No presets: instances are derived from other descriptors, and
        # ``cls()`` cannot be constructed because __init__ requires
        # arguments. Return an empty iterable as the base contract specifies.
        return ()

    # operator symbol -> function applied to the operand's value
    operators = {
        "+": operator.pos,
        "-": operator.neg,
        "abs": operator.abs,
        "trunc": np.trunc,
        "ceil": np.ceil,  # noqa: S001
        "floor": np.floor,
    }

    def parameters(self):
        return self._name, self._operator, self._value

    def __init__(self, name, operator, value):
        """Create the operating descriptor.

        Args:
            name (str): display name, returned by ``str()``.
            operator (str): key of ``operators``.
            value (Descriptor): operand descriptor.

        Raises:
            KeyError: if *operator* is not a key of ``operators``.
        """
        self._name = name
        self._operator = operator
        self._fn = self.operators[operator]
        self._value = value

    def _to_json(self):
        return self.__class__.__name__, {
            "name": self._name,
            "operator": self._operator,
            "value": self._value.to_json(),
        }

    def __str__(self):
        return self._name

    def dependencies(self):
        return {
            "value": self._value,
        }

    def calculate(self, value):
        # ``value`` is the computed result of the "value" dependency.
        return self._fn(value)
class ConstDescriptor(Descriptor):
    """Descriptor that always evaluates to a fixed value.

    Used to wrap non-descriptor operands of arithmetic on descriptors.
    """

    @classmethod
    def preset(cls, version):
        # No presets: ``cls()`` cannot be constructed because __init__
        # requires ``value``. Return an empty iterable as the base contract
        # specifies.
        return ()

    def parameters(self):
        return (self._value,)

    def __init__(self, value):
        """Create the constant descriptor.

        Args:
            value: the constant returned by ``calculate``.
        """
        self._value = value

    def __str__(self):
        return str(self._value)

    def calculate(self):
        return self._value
class BinaryOperatingDescriptor(Descriptor):
    """Descriptor combining two descriptors' values with a binary operator.

    Built by ``Descriptor`` operator overloads such as ``a + b`` or ``a ** 2``.
    """

    @classmethod
    def preset(cls, version):
        # No presets: instances are derived from other descriptors, and
        # ``cls()`` cannot be constructed because __init__ requires
        # arguments. Return an empty iterable as the base contract specifies.
        return ()

    # operator symbol -> function applied to (left value, right value)
    operators = {
        "+": operator.add,
        "-": operator.sub,
        "*": operator.mul,  # noqa: S001
        "/": operator.truediv,
        "//": operator.floordiv,
        "%": operator.mod,  # noqa: S001
        "**": operator.pow,
    }

    def _to_json(self):
        return self.__class__.__name__, {
            "name": self._name,
            "operator": self._operator,
            "left": self._left.to_json(),  # noqa: S001
            "right": self._right.to_json(),
        }

    def parameters(self):
        return self._name, self._operator, self._left, self._right

    def __init__(self, name, operator, left, right):
        """Create the operating descriptor.

        Args:
            name (str): display name, returned by ``str()``.
            operator (str): key of ``operators``.
            left (Descriptor): left operand descriptor.
            right (Descriptor): right operand descriptor.

        Raises:
            KeyError: if *operator* is not a key of ``operators``.
        """
        self._name = name
        self._operator = operator
        self._fn = self.operators[operator]
        self._left = left
        self._right = right

    def __str__(self):
        return self._name

    def dependencies(self):
        return {
            "left": self._left,
            "right": self._right,
        }

    def calculate(self, left, right):
        # ``left``/``right`` are the computed results of the dependencies.
        return self._fn(left, right)