from __future__ import print_function
import sys
import warnings
from types import ModuleType
from contextlib import contextmanager
from multiprocessing import cpu_count
from distutils.version import StrictVersion
from .._util import Capture, DummyBar
from ..error import Error, Missing, MultipleFragments, DuplicatedDescriptorName
from .result import Result
from .context import Context
from .._version import __version__
from .descriptor import Descriptor, MissingValueException, is_descriptor_class
try:
from tqdm import tqdm
from .._util import NotebookWrapper
except ImportError:
tqdm = NotebookWrapper = DummyBar
class Calculator(object):
r"""descriptor calculator.
Parameters:
descs: see Calculator.register() method
ignore_3D: see Calculator.register() method
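
    Example (an illustrative sketch of typical construction; assumes RDKit and
    the bundled descriptors submodule are importable):

        >>> from mordred import Calculator, descriptors
        >>> calc = Calculator(descriptors, ignore_3D=True)
        >>> len(calc) > 0
        True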
"""
__slots__ = (
"_descriptors",
"_name_dict",
"_explicit_hydrogens",
"_kekulizes",
"_require_3D",
"_cache",
"_debug",
"_progress_bar",
"_config",
)
def __setstate__(self, dict):
ds = self._descriptors = dict.get("_descriptors", [])
self._name_dict = {str(d): d for d in ds}
self._explicit_hydrogens = dict.get("_explicit_hydrogens", {True, False})
self._kekulizes = dict.get("_kekulizes", {True, False})
        self._require_3D = dict.get("_require_3D", False)
        self._config = dict.get("_config", {})
    @classmethod
def from_json(cls, obj):
"""Create Calculator from json descriptor objects.
Parameters:
obj(list or dict): descriptors to register
Returns:
Calculator: calculator
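
        Example (an illustrative round trip; the saved objects come from
        Calculator.to_json(), and the bundled descriptors submodule is assumed
        to be importable):

            >>> from mordred import Calculator, descriptors
            >>> calc = Calculator(descriptors)
            >>> restored = Calculator.from_json(calc.to_json())
            >>> len(restored) == len(calc)
            True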
"""
calc = cls()
calc.register_json(obj)
return calc
    def register_json(self, obj):
"""Register Descriptors from json descriptor objects.
Parameters:
obj(list or dict): descriptors to register
"""
if not isinstance(obj, list):
obj = [obj]
self.register(Descriptor.from_json(j) for j in obj)
    def to_json(self):
"""Convert descriptors to json serializable data.
Returns:
list: descriptors
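
        Example (illustrative; the result can be passed to json.dumps and
        later restored with Calculator.from_json()):

            >>> import json
            >>> js = json.dumps(Calculator().to_json())
            >>> js
            '[]'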
"""
return [d.to_json() for d in self.descriptors]
def __reduce_ex__(self, version):
return (
self.__class__,
(),
{
"_config": self._config,
"_descriptors": self._descriptors,
"_explicit_hydrogens": self._explicit_hydrogens,
"_kekulizes": self._kekulizes,
"_require_3D": self._require_3D,
},
)
def __getitem__(self, key):
return self._name_dict[key]
def __init__(self, descs=None, version=None, ignore_3D=False, config=None):
if descs is None:
descs = []
if config is None:
config = {}
self._descriptors = []
self._name_dict = {}
self._explicit_hydrogens = set()
self._kekulizes = set()
self._require_3D = False
self._debug = False
self._config = config
self.register(descs, version=version, ignore_3D=ignore_3D)
    def config(self, **configs):
r"""Set global configuration."""
self._config.update(configs)
@property
def descriptors(self):
r"""All descriptors.
you can get/set/delete descriptor.
Returns:
tuple[Descriptor]: registered descriptors
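
        Example (illustrative; assumes the bundled descriptors submodule is
        importable):

            >>> from mordred import Calculator, descriptors
            >>> calc = Calculator(descriptors)
            >>> first = calc.descriptors[0]
            >>> del calc.descriptors           # unregister everything
            >>> calc.descriptors = [first]     # re-register a single descriptor
            >>> len(calc)
            1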
"""
return tuple(self._descriptors)
@descriptors.setter
def descriptors(self, descs):
del self.descriptors
self.register(descs)
@descriptors.deleter
def descriptors(self):
self._descriptors = []
self._name_dict = {}
self._explicit_hydrogens.clear()
self._kekulizes.clear()
self._require_3D = False
def __len__(self):
return len(self._descriptors)
def _register_one(self, desc, check_only=False, ignore_3D=False):
if not isinstance(desc, Descriptor):
            raise ValueError("{!r} is not a Descriptor".format(desc))
if ignore_3D and desc.require_3D:
return
self._explicit_hydrogens.add(bool(desc.explicit_hydrogens))
self._kekulizes.add(bool(desc.kekulize))
self._require_3D |= desc.require_3D
for dep in (desc.dependencies() or {}).values():
if isinstance(dep, Descriptor):
self._register_one(dep, check_only=True)
if not check_only:
sdesc = str(desc)
old = self._name_dict.get(sdesc)
if old is not None:
raise DuplicatedDescriptorName(desc, old)
self._name_dict[sdesc] = desc
self._descriptors.append(desc)
    def register(self, desc, version=None, ignore_3D=False):
r"""Register descriptors.
Descriptor-like:
* Descriptor instance: self
* Descriptor class: use Descriptor.preset() method
* module: use Descriptor-likes in module
* Iterable: use Descriptor-likes in Iterable
Parameters:
desc(Descriptor-like): descriptors to register
version(str): version
ignore_3D(bool): ignore 3D descriptors
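
        Example (illustrative; assumes the bundled descriptors submodule is
        importable):

            >>> from mordred import Calculator, descriptors
            >>> calc = Calculator()
            >>> calc.register(descriptors, ignore_3D=True)
            >>> len(calc) > 0
            True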
"""
if version is None:
version = __version__
version = StrictVersion(version)
return self._register(desc, version, ignore_3D)
def _register(self, desc, version, ignore_3D):
if not hasattr(desc, "__iter__"):
if is_descriptor_class(desc):
if desc.since > version:
return
for d in desc.preset(version=version):
self._register_one(d, ignore_3D=ignore_3D)
elif isinstance(desc, ModuleType):
self._register(
get_descriptors_in_module(desc),
version=version,
ignore_3D=ignore_3D,
)
else:
self._register_one(desc, ignore_3D=ignore_3D)
else:
for d in desc:
self._register(d, version=version, ignore_3D=ignore_3D)
def _calculate_one(self, cxt, desc, reset):
if desc in self._cache:
return self._cache[desc]
if reset:
cxt.reset()
desc._context = cxt
cxt.add_stack(desc)
if desc.require_connected and desc._context.n_frags != 1:
return False, Missing(MultipleFragments(), desc._context.get_stack())
args = {}
for name, dep in (desc.dependencies() or {}).items():
if dep is None:
args[name] = None
else:
ok, r = self._calculate_one(cxt, dep, False)
if ok:
args[name] = r
else:
return False, r
ok = False
try:
r = desc.calculate(**args)
if self._debug:
self._check_rtype(desc, r)
ok = True
except MissingValueException as e:
r = Missing(e.error, desc._context.get_stack())
except Exception as e:
r = Error(e, desc._context.get_stack())
self._cache[desc] = ok, r
return ok, r
def _check_rtype(self, desc, result):
if desc.rtype is None:
return
if isinstance(result, Error):
return
if not isinstance(result, desc.rtype):
raise TypeError("{} not match {}".format(result, desc.rtype))
def _calculate(self, cxt):
self._cache = {}
for desc in self.descriptors:
_, r = self._calculate_one(cxt, desc, True)
yield r
def __call__(self, mol, id=-1):
r"""Calculate descriptors.
:type mol: rdkit.Chem.Mol
        :param mol: molecule
:type id: int
:param id: conformer id
:rtype: Result[scalar or Error]
:returns: iterator of descriptor and value
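
        Example (illustrative; assumes RDKit and the bundled descriptors
        submodule are installed):

            >>> from rdkit import Chem
            >>> from mordred import Calculator, descriptors
            >>> calc = Calculator(descriptors, ignore_3D=True)
            >>> result = calc(Chem.MolFromSmiles("c1ccccc1"))
            >>> len(list(result)) == len(calc.descriptors)
            True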
"""
return self._wrap_result(
mol, self._calculate(Context.from_calculator(self, mol, id))
)
def _wrap_result(self, mol, r):
return Result(mol, r, self._descriptors)
def _serial(self, mols, nmols, quiet, ipynb, id):
with self._progress(quiet, nmols, ipynb) as bar:
for m in mols:
with Capture() as capture:
r = self._wrap_result(
m, self._calculate(Context.from_calculator(self, m, id))
)
for e in capture.result:
e = e.rstrip()
if not e:
continue
bar.write(e, file=capture.orig)
yield r
bar.update()
@contextmanager
def _progress(self, quiet, total, ipynb):
args = {"dynamic_ncols": True, "leave": True, "total": total}
if quiet:
Bar = DummyBar
elif ipynb:
Bar = NotebookWrapper
else:
Bar = tqdm
try:
with Bar(**args) as self._progress_bar:
yield self._progress_bar
finally:
if hasattr(self, "_progress_bar"):
del self._progress_bar
    def echo(self, s, file=sys.stdout, end="\n"):
"""Output message.
Parameters:
s(str): message to output
file(file-like): output to
end(str): end mark of message
        Returns:
None
"""
p = getattr(self, "_progress_bar", None)
if p is not None:
            p.write(s, file=file, end=end)
            return
        print(s, file=file, end=end)  # noqa: T003
    def map(self, mols, nproc=None, nmols=None, quiet=False, ipynb=False, id=-1):
r"""Calculate descriptors over mols.
Parameters:
mols(Iterable[rdkit.Mol]): moleculars
nproc(int): number of process to use. default: multiprocessing.cpu_count()
nmols(int): number of all mols to use in progress-bar. default: mols.__len__()
quiet(bool): don't show progress bar. default: False
ipynb(bool): use ipython notebook progress bar. default: False
id(int): conformer id to use. default: -1.
Returns:
Iterator[Result[scalar]]
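
        Example (illustrative; assumes RDKit and the bundled descriptors
        submodule are installed):

            >>> from rdkit import Chem
            >>> from mordred import Calculator, descriptors
            >>> calc = Calculator(descriptors, ignore_3D=True)
            >>> mols = [Chem.MolFromSmiles(smi) for smi in ("CCO", "c1ccccc1")]
            >>> results = list(calc.map(mols, nproc=1, quiet=True))
            >>> len(results)
            2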
"""
if nproc is None:
nproc = cpu_count()
        if nmols is None and hasattr(mols, "__len__"):
nmols = len(mols)
if nproc == 1:
return self._serial(mols, nmols=nmols, quiet=quiet, ipynb=ipynb, id=id)
else:
return self._parallel(
mols, nproc, nmols=nmols, quiet=quiet, ipynb=ipynb, id=id
)
    def pandas(self, mols, nproc=None, nmols=None, quiet=False, ipynb=False, id=-1):
r"""Calculate descriptors over mols.
Returns:
pandas.DataFrame
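
        Example (illustrative; assumes RDKit and pandas are installed):

            >>> from rdkit import Chem
            >>> from mordred import Calculator, descriptors
            >>> calc = Calculator(descriptors, ignore_3D=True)
            >>> mols = [Chem.MolFromSmiles(smi) for smi in ("CCO", "c1ccccc1")]
            >>> df = calc.pandas(mols, nproc=1, quiet=True)
            >>> df.shape[0]
            2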
"""
from .pandas_module import MordredDataFrame, Series
if isinstance(mols, Series):
index = mols.index
else:
index = None
return MordredDataFrame(
(list(r) for r in self.map(mols, nproc, nmols, quiet, ipynb, id)),
columns=[str(d) for d in self.descriptors],
index=index,
)
def get_descriptors_from_module(mdl, submodule=False):
r"""[DEPRECATED] Get descriptors from module.
Parameters:
mdl(module): module to search
Returns:
[Descriptor]
"""
warnings.warn("use get_descriptors_in_module", DeprecationWarning)
__all__ = getattr(mdl, "__all__", None)
if __all__ is None:
__all__ = dir(mdl)
all_functions = (getattr(mdl, name) for name in __all__ if name[:1] != "_")
if submodule:
descs = [
d
for fn in all_functions
if is_descriptor_class(fn) or isinstance(fn, ModuleType)
for d in (
[fn]
if is_descriptor_class(fn)
else get_descriptors_from_module(fn, submodule=True)
)
]
else:
descs = [fn for fn in all_functions if is_descriptor_class(fn)]
return descs
def get_descriptors_in_module(mdl, submodule=True):
r"""Get descriptors in module.
Parameters:
mdl(module): module to search
        submodule(bool): search submodules recursively
Returns:
Iterator[Descriptor]
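
    Example (illustrative; assumes the bundled descriptors submodule is
    importable):

        >>> from mordred import descriptors
        >>> descs = list(get_descriptors_in_module(descriptors, submodule=True))
        >>> len(descs) > 0
        True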
"""
__all__ = getattr(mdl, "__all__", None)
if __all__ is None:
__all__ = dir(mdl)
all_values = (getattr(mdl, name) for name in __all__ if name[:1] != "_")
if submodule:
for v in all_values:
if is_descriptor_class(v):
yield v
if isinstance(v, ModuleType):
for v in get_descriptors_in_module(v, submodule=True):
yield v
else:
for v in all_values:
if is_descriptor_class(v):
yield v