# Source code for vayesta.core.spinalg

"""Utilities to perform operations for RHF and UHF using the
same functions."""

import numpy as np
import scipy.linalg

from vayesta.core import util

# Names exported by a star-import of this module. Only these two functions are
# listed; every other helper defined below must be imported by name.
__all__ = ["add_numbers", "hstack_matrices"]

def add_numbers(*args):
    """Add numbers given either as scalars (RHF) or as two-component objects (UHF).

    Parameters
    ----------
    *args
        Either all scalars (as judged by `numpy.isscalar`), or all non-scalar
        objects that can be indexed with ``[0]`` and ``[1]`` (one entry per
        spin channel).

    Returns
    -------
    The plain sum if all arguments are scalars, otherwise a 2-tuple with the
    sums of the ``[0]`` and of the ``[1]`` components. Calling with no
    arguments returns ``0``.

    Raises
    ------
    ValueError
        If scalar and non-scalar arguments are mixed.
    """
    scalar_flags = [np.isscalar(arg) for arg in args]
    # RHF: every argument is a plain number
    if all(scalar_flags):
        return sum(args)
    # UHF: every argument carries one value per spin channel
    if not any(scalar_flags):
        return (sum(arg[0] for arg in args), sum(arg[1] for arg in args))
    raise ValueError("cannot mix scalar (RHF) and spin-resolved (UHF) arguments")
def hstack_matrices(*args, ignore_none=True):
    """Horizontally stack matrices, with or without a leading spin dimension.

    The case is decided from ``arg[0].ndim + 1`` of every argument:

    * 2 for all arguments (RHF): the arguments are passed directly to
      `util.hstack`.
    * 3 for all arguments (UHF): the ``[0]`` components and the ``[1]``
      components are stacked separately and returned as a 2-tuple.

    Parameters
    ----------
    *args
        Matrices (RHF) or two-component collections of matrices (UHF).
    ignore_none : bool, optional
        If True (default), arguments that are None are dropped before
        stacking.

    Returns
    -------
    The result of `util.hstack` (RHF) or a 2-tuple of such results (UHF).

    Raises
    ------
    ValueError
        If the arguments are neither all RHF-like nor all UHF-like.
    """
    if ignore_none:
        args = [x for x in args if x is not None]
    ndims = np.asarray([(arg[0].ndim + 1) for arg in args])
    # RHF
    if np.all(ndims == 2):
        return util.hstack(*args)
    # UHF
    if np.all(ndims == 3):
        stacked_0 = util.hstack(*[arg[0] for arg in args])
        stacked_1 = util.hstack(*[arg[1] for arg in args])
        return (stacked_0, stacked_1)
    raise ValueError("ndims= %r" % ndims)
def dot(*args, out=None): """Generalizes dot with or without spin channel: ij,jk->ik or Sij,Sjk->Sik Additional non spin-dependent matrices can be present, eg. Sij,jk,Skl->Skl. Note that unlike, this does not support vectors.""" maxdim = np.max([np.ndim(x[0]) for x in args]) + 1 # No spin-dependent arguments present if maxdim == 2: return*args, out=out) # Spin-dependent arguments present assert maxdim == 3 if out is None: out = 2 * [None] args_a = [(x if np.ndim(x[0]) < 2 else x[0]) for x in args] args_b = [(x if np.ndim(x[1]) < 2 else x[1]) for x in args] return (*args_a, out=out[0]),*args_b, out=out[1])) def eigh(a, b=None, *args, **kwargs): ndim = np.ndim(a[0]) + 1 # RHF if ndim == 2: return scipy.linalg.eigh(a, b=b, *args, **kwargs) # UHF if b is None or np.ndim(b[0]) == 1: b = (b, b) results = (scipy.linalg.eigh(a[0], b=b[0], *args, **kwargs), scipy.linalg.eigh(a[1], b=b[1], *args, **kwargs)) return tuple(zip(*results)) def transpose(a, axes=None): if np.ndim(a[0]) == 1: return np.transpose(a, axes=axes) return (transpose(a[0], axes=axes), transpose(a[1], axes=axes)) T = transpose def _guess_spinsym(a): if isinstance(a, (tuple, list)): return "unrestricted" return "restricted" def _make_func(func, nargs=1): def _func(a, *args, spinsym=None, **kwargs): spinsym = spinsym or _guess_spinsym(a) if spinsym == "restricted": return func(a, *args, **kwargs) if nargs == 1: return tuple(func(a[i], *args, **kwargs) for i in range(len(a))) if nargs == 2: assert len(args) >= 1 b, args = args[0], args[1:] return tuple(func(a[i], b[i], *args, **kwargs) for i in range(len(a))) if nargs == 3: assert len(args) >= 2 b, c, args = args[0], args[1], args[2:] return tuple(func(a[i], b[i], c[i], *args, **kwargs) for i in range(len(a))) raise NotImplementedError return _func zeros_like = _make_func(np.zeros_like) add = _make_func(np.add, nargs=2) subtract = _make_func(np.subtract, nargs=2) multiply = _make_func(np.multiply, nargs=2) copy = _make_func(np.copy, nargs=1) norm = 
_make_func(np.linalg.norm, nargs=1)