# Source code for vayesta.mpi.scf

import functools
import logging
import pyscf
import pyscf.df
import pyscf.pbc
import pyscf.pbc.df
from vayesta.core.util import log_time


def scf_with_mpi(mpi, mf, mpi_rank=0, log=None):
    """Use to run SCF only on the master node and broadcast result afterwards.

    The ``kernel`` method of ``mf`` is replaced in place by a wrapper that
    calls the original kernel on a single rank and then broadcasts the
    results to all ranks.

    Parameters
    ----------
    mpi
        MPI interface object. The wrapper uses ``mpi.rank``,
        ``mpi.world.barrier()``, ``mpi.bcast(obj, root=...)`` and ``mpi.log``.
        If falsy, ``mf`` is returned unmodified.
    mf
        Mean-field object whose ``kernel`` method is wrapped.
    mpi_rank : int, optional
        Rank on which the original kernel is executed and from which the
        results are broadcast. Default: 0.
    log : optional
        Logger. Falls back to ``mpi.log`` and then to the module logger.
        NOTE(review): ``log.timing`` is passed to ``log_time``; this is not a
        standard ``logging.Logger`` attribute and is presumably provided
        elsewhere in the project -- confirm.

    Returns
    -------
    mf
        The same object that was passed in. Unless ``mpi`` is falsy, its
        ``kernel`` is replaced and ``mf.with_mpi`` is set to ``True``.
    """
    if not mpi:
        return mf

    kernel_orig = mf.kernel
    log = log or mpi.log or logging.getLogger(__name__)

    def mpi_kernel(self, *args, **kwargs):
        """Run the original kernel on rank `mpi_rank` and broadcast the results."""
        df = getattr(self, "with_df", None)
        if mpi.rank == mpi_rank:
            log.info("MPI rank= %3d is running SCF", mpi.rank)
            with log_time(log.timing, "Time for SCF: %s"):
                res = kernel_orig(*args, **kwargs)
            log.info("MPI rank= %3d finished SCF", mpi.rank)
        else:
            res = None
            # Generate auxiliary cell, compensation basis etc,..., but not 3c integrals:
            if df is not None:
                # The `False` default distinguishes "attribute missing" from
                # "attribute present but still None (not yet built)".
                # Molecules
                if getattr(df, "auxmol", False) is None:
                    df.auxmol = pyscf.df.addons.make_auxmol(df.mol, df.auxbasis)
                # Solids
                elif getattr(df, "auxcell", False) is None:
                    df.build(with_j3c=False)
            log.info("MPI rank= %3d is waiting for SCF results", mpi.rank)
        # All ranks synchronize here before the results are distributed.
        mpi.world.barrier()

        # Broadcast results
        bcast = functools.partial(mpi.bcast, root=mpi_rank)
        with log_time(log.timing, "Time for MPI broadcast of SCF results: %s"):
            res = bcast(res)
            if df is not None:
                df._cderi = bcast(df._cderi)
            self.converged = bcast(self.converged)
            self.e_tot = bcast(self.e_tot)
            self.mo_energy = bcast(self.mo_energy)
            self.mo_occ = bcast(self.mo_occ)
            self.mo_coeff = bcast(self.mo_coeff)
        return res

    # Bind the wrapper to this instance only; the class itself is left untouched.
    mf.kernel = mpi_kernel.__get__(mf)
    mf.with_mpi = True
    return mf
def gdf_with_mpi(mpi, df, mpi_rank=0, log=None):
    """Wrap ``df.build`` so that only one MPI rank runs the full build.

    The file name stored in ``df._cderi_to_save`` is broadcast from rank
    ``mpi_rank`` so that all ranks refer to the same file. The ``build``
    method of ``df`` is then replaced in place by a wrapper:

    - on rank ``mpi_rank`` the original ``build`` is called unchanged;
    - on every other rank the original ``build`` is called with
      ``with_j3c=False`` if ``df`` is a ``pyscf.pbc.df.GDF`` instance and is
      skipped otherwise; ``_cderi`` is then set to ``_cderi_to_save``;
    - all ranks meet at a barrier before the wrapper returns.

    Parameters
    ----------
    mpi
        MPI interface object. Uses ``mpi.rank``, ``mpi.world.bcast``,
        ``mpi.world.barrier`` and ``mpi.log``.
    df
        Density-fitting object whose ``build`` method is wrapped.
    mpi_rank : int, optional
        Rank which performs the full build. Default: 0.
    log : optional
        Logger. Falls back to ``mpi.log`` and then to the module logger.

    Returns
    -------
    df
        The same object that was passed in, with ``build`` replaced and
        ``df.with_mpi`` set to ``True``.

    Raises
    ------
    NotImplementedError
        If ``df._cderi_to_save`` is not a string.
    """
    log = log or mpi.log or logging.getLogger(__name__)

    if not isinstance(df._cderi_to_save, str):
        raise NotImplementedError(
            "gdf_with_mpi requires df._cderi_to_save to be a file name (str), got %s"
            % type(df._cderi_to_save).__name__
        )

    build_orig = df.build
    pbc = isinstance(df, pyscf.pbc.df.GDF)

    # After the check above this is always the string itself; the `name`
    # lookup is kept from the original implementation.
    cderi_file = getattr(df._cderi_to_save, "name", df._cderi_to_save)
    df._cderi_to_save = mpi.world.bcast(cderi_file, root=mpi_rank)
    log.debug("df._cderi_to_save= %s", df._cderi_to_save)

    def mpi_build(self, *args, **kwargs):
        """Run the full build on rank `mpi_rank` only; other ranks skip the 3c integrals."""
        if mpi.rank == mpi_rank:
            res = build_orig(*args, **kwargs)
        else:
            if pbc:
                # Override rather than pass `with_j3c` a second time: a caller
                # supplying `with_j3c=...` would otherwise trigger a
                # "multiple values for keyword argument" TypeError here.
                res = build_orig(*args, **{**kwargs, "with_j3c": False})
            else:
                res = None
            # `self` is `df` (the wrapper is bound to it below).
            self._cderi = self._cderi_to_save
        mpi.world.barrier()
        return res

    # Bind the wrapper to this instance only; the class itself is left untouched.
    df.build = mpi_build.__get__(df)
    df.with_mpi = True
    return df