Source code for vayesta.solver.callback

import dataclasses
from typing import Callable
import numpy as np

from vayesta.core.types import CISD_WaveFunction, CCSD_WaveFunction, FCI_WaveFunction, RDM_WaveFunction
from vayesta.solver.solver import ClusterSolver

class CallbackSolver(ClusterSolver):
    """Cluster solver that delegates the actual solve to a user-supplied callback.

    The callback is taken from ``self.opts.callback``. It is called with the
    first object returned by ``self.hamil.to_pyscf_mf(...)`` and must return a
    mapping supporting ``in`` and item access/assignment. Which wavefunction
    object is built depends on the keys present in that mapping (see
    :meth:`kernel`).
    """

    @dataclasses.dataclass
    class Options(ClusterSolver.Options):
        # Dataclass fields must carry an annotation to be registered as fields.
        # The callback is invoked as ``callback(mf_clus)`` in ``kernel``.
        callback: Callable = None

    def _wavefunction_from_results(self, results):
        """Build the wavefunction object matching the keys found in ``results``.

        Key sets are checked in this order, and the first match wins:

        1. ``'civec'``                  -> ``FCI_WaveFunction``
        2. ``'t1'`` and ``'t2'``        -> ``CCSD_WaveFunction`` (``'l1'``/``'l2'``
           are passed on only when both are present, otherwise ``None``)
        3. ``'c0'``, ``'c1'``, ``'c2'`` -> ``CISD_WaveFunction``
        4. ``'dm1'`` and ``'dm2'``      -> ``RDM_WaveFunction``

        Returns ``None`` (after logging a warning) when no key set matches.
        """
        mo = self.hamil.mo

        if 'civec' in results:
            self.log.info("FCI WaveFunction found in callback results.")
            return FCI_WaveFunction(mo, results['civec'])

        if 't1' in results and 't2' in results:
            self.log.info("CCSD WaveFunction found in callback results.")
            t1, t2 = results['t1'], results['t2']
            # Lambda amplitudes are optional and only used as a pair.
            if 'l1' in results and 'l2' in results:
                l1, l2 = results['l1'], results['l2']
            else:
                l1, l2 = None, None
            return CCSD_WaveFunction(mo, t1, t2, l1=l1, l2=l2)

        if 'c0' in results and 'c1' in results and 'c2' in results:
            self.log.info("CISD WaveFunction found in callback results.")
            c0, c1, c2 = results['c0'], results['c1'], results['c2']
            return CISD_WaveFunction(mo, c0, c1, c2)

        if 'dm1' in results and 'dm2' in results:
            self.log.info("RDM WaveFunction found in callback results.")
            return RDM_WaveFunction(mo, results['dm1'], results['dm2'])

        self.log.warn("No wavefunction results returned by callback!")
        return None

    def kernel(self, *args, **kwargs):
        """Run the user callback and store its results on the solver.

        Positional and keyword arguments are accepted but not used.

        Side effects:

        * ``self.wf`` is set to the wavefunction built from the callback
          results, or to ``None`` if the results contain no recognised
          wavefunction keys (previously this case raised ``UnboundLocalError``
          right after the warning was logged).
        * ``self.hole_moments`` / ``self.particle_moments`` are set only when
          the corresponding key is present in the results.
        * ``self.converged`` is ``results['converged']`` if present, else ``False``.
        * ``self.callback_results`` is the results mapping itself, with an
          additional ``'wf'`` entry written into it.
        """
        # The second return value (frozen orbitals) is not needed here.
        mf_clus, _ = self.hamil.to_pyscf_mf(allow_dummy_orbs=True, allow_df=True)
        results = self.opts.callback(mf_clus)

        # Build appropriate wavefunction object
        wf = self._wavefunction_from_results(results)

        if 'hole_moments' in results:
            self.log.info("Hole moments found in callback results.")
            self.hole_moments = results['hole_moments']
        if 'particle_moments' in results:
            self.log.info("Particle moments found in callback results.")
            self.particle_moments = results['particle_moments']

        results['wf'] = wf
        self.wf = wf
        self.converged = results['converged'] if 'converged' in results else False
        self.callback_results = results