from collections import namedtuple
import logging
import functools
from timeit import default_timer
import numpy as np
import vayesta
from vayesta.core.util import log_time, memory_string
from vayesta.mpi.rma import RMA_Dict
from vayesta.mpi.scf import scf_with_mpi
from vayesta.mpi.scf import gdf_with_mpi
# Small picklable header (shape and dtype) describing a NumPy array; `bcast`
# sends it ahead of the raw array data.
NdArrayMetadata = namedtuple("NdArrayMetadata", "shape dtype")
class MPI_Interface:
    """Convenience wrapper around an optional mpi4py ``MPI`` module.

    If no MPI module is available the interface falls back to a serial
    configuration (``rank == 0``, ``size == 1``, ``world is None``), in which
    the decorators return the wrapped function unchanged and `bcast` /
    `nreduce` simply hand their inputs back.

    Parameters
    ----------
    mpi : module, str or None
        An mpi4py-like ``MPI`` module, the string ``"mpi4py"`` to import it
        lazily, or a falsy value to run without MPI.
    required : bool, optional
        If True and ``mpi == "mpi4py"`` cannot be imported, the import error
        is re-raised instead of falling back to serial mode. Default: False.
    log : logging.Logger, optional
        Logger to use. Defaults to ``logging.getLogger(__name__)``.

    Attributes
    ----------
    MPI : module or None
        The MPI module in use.
    world : communicator or None
        ``MPI.COMM_WORLD`` of the MPI module.
    rank, size : int
        Rank of this process and number of processes in `world`.
    timer : callable
        ``MPI.Wtime`` with MPI, ``timeit.default_timer`` otherwise.
    """

    def __init__(self, mpi, required=False, log=None):
        self.log = log or logging.getLogger(__name__)
        if mpi == "mpi4py":
            mpi = self._import_mpi4py(required=required)
        if mpi:
            self.MPI = mpi
            self.world = mpi.COMM_WORLD
            self.rank = self.world.Get_rank()
            self.size = self.world.Get_size()
            self.timer = mpi.Wtime
        else:
            # Serial fallback: behave like a single-process "communicator".
            self.MPI = None
            self.world = None
            self.rank = 0
            self.size = 1
            self.timer = default_timer
        # Incremented before use, so the first tag handed out is 0.
        self._tag = -1

    def _import_mpi4py(self, required=True):
        """Import and return ``mpi4py.MPI``, or None if it cannot be imported.

        Parameters
        ----------
        required : bool, optional
            If True, a failed import is logged as critical and re-raised.

        Raises
        ------
        ImportError
            If the import fails and `required` is True.
        """
        try:
            import mpi4py

            # Runtime configuration has to be set before `mpi4py.MPI` is imported.
            mpi4py.rc.threads = False
            from mpi4py import MPI as mpi
        except ImportError:
            # ModuleNotFoundError is a subclass of ImportError.
            if required:
                self.log.critical("mpi4py not found.")
                raise
            self.log.debug("mpi4py not found.")
            return None
        return mpi

    def __len__(self):
        """Number of MPI processes."""
        return self.size

    def __bool__(self):
        """True if more than one MPI process is in use."""
        return self.enabled

    @property
    def enabled(self):
        """True if more than one MPI process is in use."""
        return self.size > 1

    @property
    def disabled(self):
        """True if running on a single process (with or without an MPI module)."""
        return not self.enabled

    @property
    def is_master(self):
        """True on the process with rank 0."""
        return self.rank == 0

    def get_new_tag(self):
        """Return the next value of an internal, monotonically increasing tag counter."""
        self._tag += 1
        return self._tag

    def nreduce(self, *args, target=None, logfunc=None, **kwargs):
        """(All)reduce multiple arguments.

        Parameters
        ----------
        *args
            Objects to reduce; each one is reduced separately.
        target : int, optional
            Root rank for a ``reduce``. If None, an ``allreduce`` is performed.
        logfunc : callable, optional
            Logging function used for the timing message.
            Default: ``vayesta.log.timingv``.
        **kwargs
            Passed on to ``world.allreduce`` / ``world.reduce``.

        Returns
        -------
        The reduced object if a single argument was given, else a tuple of
        the reduced objects. Without an MPI communicator the arguments are
        returned unchanged.

        TODO:
        * Use Allreduce/Reduce for NumPy types
        * combine multiple *args of same dtype into a single array,
        to reduce communication overhead.
        """
        if self.world is None:
            # No communicator: a reduction over one process is the identity.
            res = list(args)
        else:
            if logfunc is None:
                logfunc = vayesta.log.timingv
            if target is None:
                with log_time(logfunc, "Time for MPI allreduce: %s"):
                    res = [self.world.allreduce(x, **kwargs) for x in args]
            else:
                with log_time(logfunc, "Time for MPI reduce: %s"):
                    res = [self.world.reduce(x, root=target, **kwargs) for x in args]
        if len(res) == 1:
            return res[0]
        return tuple(res)

    def bcast(self, obj, root=0):
        """Common function to broadcast NumPy arrays or general objects.

        Parameters
        ----------
        obj: ndarray or Any
            Array or object to be broadcasted.
        root: int
            Root MPI process.

        Returns
        -------
        obj: ndarray or Any
            Broadcasted array or object. Without an MPI communicator, `obj`
            itself is returned.
        """
        if self.world is None:
            # No communicator: nothing to broadcast to.
            return obj

        # --- First bcast using the pickle interface
        # if obj is a NumPy array, only broadcast the shape and dtype information,
        # within a NdArrayMetadata object. Otherwise, broadcast the object itself.
        # Object-dtype arrays store Python references instead of raw data, so they
        # are not suitable for the buffer interface and are pickled as a whole.
        if self.rank == root:
            if isinstance(obj, np.ndarray) and not obj.dtype.hasobject:
                data = NdArrayMetadata(obj.shape, obj.dtype)
            else:
                data = obj
        else:
            data = None
        data = self.world.bcast(data, root=root)
        # If the object itself was broadcasted, we can return here:
        if not isinstance(data, NdArrayMetadata):
            return data

        # --- Second Bcast using buffer interface
        if self.rank == root:
            obj = np.ascontiguousarray(obj)
        else:
            # Receive buffer with the shape/dtype announced by the root.
            obj = np.empty(data.shape, dtype=data.dtype)
        self.log.debug("Broadcasting array: size= %d memory= %s", obj.size, memory_string(obj.nbytes))
        self.world.Bcast(obj, root=root)
        return obj

    # --- Function wrapper at embedding level
    # ---------------------------------------

    def with_reduce(self, **mpi_kwargs):
        """Decorator factory: pass the function's result through ``world.reduce``.

        `mpi_kwargs` are forwarded to ``world.reduce``. If MPI is disabled when
        the decorator is applied, the function is returned unchanged.
        """

        def decorator(func):
            # No MPI:
            if self.disabled:
                return func

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                res = func(*args, **kwargs)
                return self.world.reduce(res, **mpi_kwargs)

            return wrapper

        return decorator

    def with_allreduce(self, **mpi_kwargs):
        """Decorator factory: pass the function's result through ``world.allreduce``.

        `mpi_kwargs` are forwarded to ``world.allreduce``. If MPI is disabled
        when the decorator is applied, the function is returned unchanged.
        """

        def decorator(func):
            # No MPI:
            if self.disabled:
                return func

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                res = func(*args, **kwargs)
                return self.world.allreduce(res, **mpi_kwargs)

            return wrapper

        return decorator

    def only_master(self):
        """Decorator factory: run the function on rank 0 only.

        On all other ranks the wrapped function is not called and None is
        returned. If MPI is disabled when the decorator is applied, the
        function is returned unchanged.
        """

        def decorator(func):
            # No MPI:
            if self.disabled:
                return func

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if not self.is_master:
                    return None
                return func(*args, **kwargs)

            return wrapper

        return decorator

    # --- Function wrapper at fragment level
    # --------------------------------------

    def with_send(self, source, dest=0, tag=None, **mpi_kwargs):
        """Decorator factory: evaluate on one rank and send the result to another.

        Parameters
        ----------
        source : int or callable
            Rank which evaluates the function. If callable, it is called with
            the positional arguments of each call to determine the rank.
        dest : int, optional
            Rank which receives the result. Default: 0.
        tag : int, optional
            Message tag. If None, a new tag is drawn from `get_new_tag` once,
            when the decorator is applied.
        **mpi_kwargs
            Passed on to ``world.send`` / ``world.recv``.

        Returns
        -------
        Decorator. The wrapped function returns the result on the source and
        destination ranks, and None on all other ranks. If MPI is disabled
        when the decorator is applied, the function is returned unchanged.
        """

        def decorator(func):
            # No MPI:
            if self.disabled:
                return func
            # With MPI:
            # The tag is fixed per decorated function, not per call.
            tag2 = self.get_new_tag() if tag is None else tag

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                src = source(*args) if callable(source) else source
                if self.rank == src:
                    res = func(*args, **kwargs)
                    # No message is needed if source and destination coincide.
                    if self.rank != dest:
                        self.log.debugv("MPI[%d]<send>: func=%s dest=%d", self.rank, func.__name__, dest)
                        self.world.send(res, dest=dest, tag=tag2, **mpi_kwargs)
                        self.log.debugv("MPI[%d]<send>: done", self.rank)
                    return res
                if self.rank == dest:
                    self.log.debugv("MPI[%d] <recv>: func=%s source=%d", self.rank, func.__name__, src)
                    res = self.world.recv(source=src, tag=tag2, **mpi_kwargs)
                    self.log.debugv("MPI[%d] <recv>: type= %r done!", self.rank, type(res))
                    return res
                self.log.debugv("MPI[%d] <do nothing> func=%s source=%d ", self.rank, func.__name__, src)
                return None

            return wrapper

        return decorator

    def create_rma_dict(self, dictionary):
        """Create an `RMA_Dict` from `dictionary` via ``RMA_Dict.from_dict``."""
        return RMA_Dict.from_dict(self, dictionary)

    # --- PySCF decorators
    # --------------------

    def scf(self, mf, mpi_rank=0, log=None):
        """Return ``scf_with_mpi(self, mf, mpi_rank=mpi_rank, log=log)``."""
        return scf_with_mpi(self, mf, mpi_rank=mpi_rank, log=log)

    def gdf(self, df, mpi_rank=0, log=None):
        """Return ``gdf_with_mpi(self, df, mpi_rank=mpi_rank, log=log)``."""
        return gdf_with_mpi(self, df, mpi_rank=mpi_rank, log=log)