Expose SNR in TDVPSchmitt #1441

Open
jwnys opened this issue Mar 1, 2023 · 7 comments

Comments

@jwnys
Collaborator

jwnys commented Mar 1, 2023

For TDVPSchmitt to work, one needs to analyze the SNR (signal-to-noise ratio) values and determine the cutoff by hand. To my knowledge, the SNR is not exposed, so there is currently no way to obtain it. Since it would typically be determined for a fixed operator, it would need a general signature independent of the driver, like this:

import netket as nk


def compute_snr(state, op, diag_shift=0, diag_scale=0, holomorphic=False):
    # Local estimators of the operator on the current samples
    E_loc = state.local_estimators(op)

    # Dense quantum geometric tensor built from the Jacobian
    S = nk.optimizer.qgt.QGTJacobianDense(
        state,
        diag_shift=diag_shift,
        diag_scale=diag_scale,
        holomorphic=holomorphic,
    )

    snr = _compute_snr(  # internal jitted function taken from _imp
        state.n_samples,
        E_loc,
        S,
    )

    return snr
@jwnys
Collaborator Author

jwnys commented Mar 1, 2023

Actually, this internal part would simply be:

from functools import partial

import jax
import jax.numpy as jnp
from netket import stats


@partial(jax.jit, static_argnames=("n_samples",))
def _compute_snr(n_samples, E_loc, S):
    E = stats.statistics(E_loc)
    ΔE_loc = E_loc.T.reshape(-1, 1) - E.mean

    stack_jacobian = S.mode == "complex"

    O = S.O / jnp.sqrt(n_samples)  # S.O is already divided by jnp.sqrt(n_samples)
    if stack_jacobian:
        O = O.reshape(-1, 2, S.O.shape[-1])
        O = O[:, 0, :] + 1j * O[:, 1, :]

    Sd = S.to_dense()
    ev, V = jnp.linalg.eigh(Sd)

    OEdata = O.conj() * ΔE_loc
    F = stats.sum(OEdata, axis=0)

    # Note: this implementation differs from Eq. 20 in Markus's paper, which I would
    # implement as `rho = mpi.mean(QEdata, axis=0)`. However, this is different from
    # changing the basis AFTER averaging over the samples, and leads to the wrong
    # normalisation of rho.
    Q = jnp.tensordot(V.conj().T, O.T, axes=1).T
    QEdata = Q.conj() * ΔE_loc
    rho = V.conj().T @ F

    # Compute the SNR according to Eq. 21
    snr = jnp.abs(rho) * jnp.sqrt(n_samples) / \
        jnp.sqrt(stats.var(QEdata, axis=0))
    return snr
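
For readers who want to see the idea without NetKet, here is a minimal NumPy sketch of a per-mode SNR in the eigenbasis of the S matrix, run on synthetic data. It is an illustration under stated assumptions, not the NetKet implementation: the function name `snr_per_mode` and the data are made up, parameters and local energies are taken to be real, there is no MPI, and the log-derivatives are unscaled, so the normalisation subtlety mentioned in the comment above does not arise.

```python
import numpy as np


def snr_per_mode(O, E_loc):
    """Hypothetical helper: SNR of the force in each eigenmode of S.

    O     : (n_samples, n_params) real log-derivatives, one row per sample
    E_loc : (n_samples,) real local energies
    """
    n_samples = O.shape[0]
    dO = O - O.mean(axis=0)          # centre the log-derivatives
    dE = E_loc - E_loc.mean()        # centre the local energies

    S = dO.T @ dO / n_samples        # sample estimate of the S matrix
    ev, V = np.linalg.eigh(S)        # eigenvalues in ascending order

    Q = dO @ V                       # log-derivatives rotated into the eigenbasis
    QE = Q * dE[:, None]             # per-sample force contributions per mode

    rho = QE.mean(axis=0)            # force component along each mode
    sigma = QE.std(axis=0)           # per-sample spread of that component
    return ev, np.abs(rho) * np.sqrt(n_samples) / sigma


# Synthetic data (assumption): three parameters with well-separated scales,
# and an energy that depends only on the third one plus a little noise.
rng = np.random.default_rng(0)
O = rng.normal(size=(4000, 3)) * np.array([1.0, 3.0, 10.0])
E_loc = 2.0 * O[:, 2] + 0.1 * rng.normal(size=4000)

ev, snr = snr_per_mode(O, E_loc)
for k in range(len(ev)):
    print(f"mode {k}: eigenvalue={ev[k]:.3g}, snr={snr[k]:.3g}")
```

In this toy setup only the mode with the largest eigenvalue carries signal, so its SNR stands out from the others; that separation is what one would look at when choosing a cutoff by hand.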

@jwnys
Collaborator Author

jwnys commented Mar 1, 2023

Of course, if we are computing them throughout the time evolution anyway, they should be available at any step.

@PhilipVinc
Member

Can you make a PR and add a quick test that checks that it is updated after a step?

@PhilipVinc
Member

In fact, isn't it already exposed, just undocumented?

@PhilipVinc
Member

PhilipVinc commented Mar 1, 2023

In fact, if you redefine this function for TDVPSchmitt to also log the SNR, it will end up in the data we log.
It would be a nice addition.

    def _log_additional_data(self, log_dict, step):
        log_dict["t"] = self.t
        log_dict["snr"] = self.snr
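
To make the override pattern and the requested "updated after a step" check concrete, here is a toy sketch in plain Python. `ToyDriver` and `ToySnrDriver` are hypothetical stand-ins, not NetKet classes, and the SNR values are placeholders; only the `_log_additional_data(self, log_dict, step)` hook mirrors the snippet above.

```python
class ToyDriver:
    """Hypothetical stand-in for a time-evolution driver (not NetKet's class)."""

    def __init__(self, dt):
        self.t = 0.0
        self.dt = dt
        self.history = []

    def _step(self):
        self.t += self.dt

    def _log_additional_data(self, log_dict, step):
        log_dict["t"] = self.t

    def run(self, n_steps):
        for step in range(n_steps):
            self._step()
            log_dict = {}
            self._log_additional_data(log_dict, step)
            self.history.append(log_dict)


class ToySnrDriver(ToyDriver):
    """Adds an `snr` attribute that is refreshed on every step and logged."""

    def __init__(self, dt):
        super().__init__(dt)
        self.snr = None

    def _step(self):
        super()._step()
        # Placeholder values: a real driver would store the SNR from its solver.
        self.snr = [1.0 / (1.0 + self.t), 2.0]

    def _log_additional_data(self, log_dict, step):
        super()._log_additional_data(log_dict, step)
        log_dict["snr"] = self.snr


driver = ToySnrDriver(dt=0.5)
driver.run(3)

# The kind of check asked for above: the logged SNR changes between steps.
assert all("snr" in entry for entry in driver.history)
assert driver.history[0]["snr"] != driver.history[-1]["snr"]
```

Calling `super()._log_additional_data` keeps the existing `t` entry, so the subclass only adds the new key.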

@jwnys
Collaborator Author

jwnys commented Mar 2, 2023

In principle, you just need to determine the SNR once, before any time evolution is carried out. The monitoring should only be done from time to time, to verify. It's kept constant throughout the time evolution (see supp. mat. of the paper). All the code is there; it just needs to be separated a bit and exposed as an understandable function that can be used without any time evolution driver.
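
As a rough illustration of the "by hand" step, one could tabulate how many modes survive a few candidate cutoffs before starting the evolution. This is only a sketch: the helper name `modes_kept`, the SNR values, and the candidate cutoffs are all made up, and a hard threshold is used purely for counting, which need not match how the driver applies its cutoff.

```python
def modes_kept(snr, cutoffs):
    """Hypothetical helper: number of modes whose SNR exceeds each cutoff."""
    return {c: sum(1 for s in snr if s > c) for c in cutoffs}


# Made-up SNR values, standing in for the output of a one-off SNR computation.
snr_values = [0.02, 0.3, 0.9, 1.5, 4.0, 12.0, 45.0]

table = modes_kept(snr_values, cutoffs=[0.5, 1.0, 2.0, 10.0])
for cutoff, kept in table.items():
    print(f"cutoff={cutoff:>5}: {kept}/{len(snr_values)} modes kept")
```

Repeating the same tabulation occasionally during the evolution is one way to do the periodic verification mentioned above.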

@PhilipVinc
Member

The SNR is computed anyway as part of the regularisation procedure, so logging it makes sense to me (I think that's why I'm even returning it from the jitted function; I was trying it out).

If you expose an understandable function somewhere it makes sense and can be more easily accessed, I think it's a good idea.
