Expose SNR in TDVPSchmitt #1441
Comments
Actually, this internal part would be simply:

from functools import partial

import jax
import jax.numpy as jnp
from netket import stats


@partial(jax.jit, static_argnames=("n_samples",))
def _compute_snr(n_samples, E_loc, S):
    E = stats.statistics(E_loc)
    ΔE_loc = E_loc.T.reshape(-1, 1) - E.mean

    stack_jacobian = S.mode == "complex"

    O = S.O / jnp.sqrt(n_samples)  # already divided by jnp.sqrt(n_s)
    if stack_jacobian:
        # Recombine the stacked real and imaginary parts into a complex Jacobian
        O = O.reshape(-1, 2, S.O.shape[-1])
        O = O[:, 0, :] + 1j * O[:, 1, :]

    Sd = S.to_dense()
    ev, V = jnp.linalg.eigh(Sd)

    OEdata = O.conj() * ΔE_loc
    F = stats.sum(OEdata, axis=0)

    # Note: this implementation differs from Eq. 20 in Markus's paper, which I would
    # implement as `rho = mpi.mean(QEdata, axis=0)`. However, this is different from
    # changing the basis AFTER averaging over the samples, and leads to the wrong
    # normalisation of rho.
    Q = jnp.tensordot(V.conj().T, O.T, axes=1).T
    QEdata = Q.conj() * ΔE_loc
    rho = V.conj().T @ F

    # Compute the SNR according to Eq. 21
    snr = jnp.abs(rho) * jnp.sqrt(n_samples) / jnp.sqrt(stats.var(QEdata, axis=0))
    return snr
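For readers without the driver internals at hand, here is a minimal dense sketch of the same idea in plain JAX. It assumes you already have the per-sample log-derivatives as an array `O` of shape `(n_samples, n_params)` and the local energies `E_loc` of shape `(n_samples,)`. The function name `snr_dense` is hypothetical, and the normalisation is the textbook SNR of a sample mean, |mean| / (std / sqrt(n)), which may not match the library's internal conventions exactly.

```python
import jax.numpy as jnp


def snr_dense(O, E_loc):
    """Hypothetical helper: SNR of the force components in the eigenbasis of S.

    O:     (n_samples, n_params) raw log-derivatives (centred inside the function)
    E_loc: (n_samples,) local energies
    Returns the eigenvalues of S and one SNR value per eigen-direction.
    """
    n_samples = O.shape[0]
    dO = O - jnp.mean(O, axis=0, keepdims=True)
    dE = (E_loc - jnp.mean(E_loc)).reshape(-1, 1)

    # Dense S matrix and its eigenbasis, S = V diag(ev) V^†
    S = dO.conj().T @ dO / n_samples
    ev, V = jnp.linalg.eigh(S)

    # Per-sample force contributions, rotated into the eigenbasis:
    # row s of X is V^† (conj(dO_s) * dE_s)
    X = (dO.conj() * dE) @ V.conj()

    rho = jnp.mean(X, axis=0)      # rotated force, equals V^† F
    sigma = jnp.std(X, axis=0)     # per-component sample standard deviation
    snr = jnp.abs(rho) * jnp.sqrt(n_samples) / sigma
    return ev, snr
```

Because the rotation is applied per sample before averaging, the mean of `X` equals the rotated force while its spread gives the noise estimate, so no second pass over the samples is needed.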
Of course, if we are computing them throughout the time evolution anyway, they should be available at any step.
Can you make a PR and add a quick test that checks that it is updated after a step?
In fact, isn't it already exposed, just undocumented?
In fact, if you redefine this function for TDVPSchmitt to also log the SNR, it will end up in the data we log:

def _log_additional_data(self, log_dict, step):
    log_dict["t"] = self.t
    log_dict["snr"] = self.snr
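To make the logging idea concrete, here is a toy sketch of how such a hook override behaves, including the kind of "updated after a step" check requested earlier in the thread. Only the hook name `_log_additional_data` comes from the comment above. `ToyDriver`, `ToySchmittDriver`, their `run` loop, and the placeholder SNR values are hypothetical stand-ins, not NetKet classes.

```python
class ToyDriver:
    """Hypothetical stand-in for a time-evolution driver base class."""

    def __init__(self):
        self.t = 0.0
        self.step_count = 0

    def _step(self, dt):
        self.t += dt
        self.step_count += 1

    def _log_additional_data(self, log_dict, step):
        # The base class only logs the current time.
        log_dict["t"] = self.t

    def run(self, n_steps, dt):
        history = []
        for _ in range(n_steps):
            self._step(dt)
            log_dict = {}
            self._log_additional_data(log_dict, self.step_count)
            history.append(log_dict)
        return history


class ToySchmittDriver(ToyDriver):
    """Hypothetical subclass that also logs an SNR recomputed at every step."""

    def __init__(self):
        super().__init__()
        self.snr = None

    def _step(self, dt):
        super()._step(dt)
        # Placeholder value; a real driver would store the SNR obtained
        # while regularising the linear solve.
        self.snr = [1.0 / self.step_count]

    def _log_additional_data(self, log_dict, step):
        super()._log_additional_data(log_dict, step)
        log_dict["snr"] = self.snr


history = ToySchmittDriver().run(n_steps=3, dt=0.5)
```

Each entry of `history` then carries both `"t"` and `"snr"`, and comparing consecutive entries is enough to check that the stored value changes after a step.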
In principle, you just need to determine the SNR once before any time evolution is carried out. The monitoring should just be done from time to time to verify. It's kept constant throughout the time evolution (see supp. mat. of the paper). All the code is there, it just needs to be separated a bit and be made available in an understandable function that is available without any time evolution driver.
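Once a vector of SNR values is available, choosing the cutoff by hand mostly means checking how many components survive each candidate value. A small sketch of that inspection step follows. The helper `components_kept` is hypothetical, and it applies a hard threshold purely for illustration, whereas the driver's actual regularisation may be smooth.

```python
import jax.numpy as jnp


def components_kept(snr, cutoffs):
    """Hypothetical helper: count SNR components above each candidate cutoff."""
    snr = jnp.asarray(snr)
    return {float(c): int(jnp.sum(snr > c)) for c in cutoffs}


# Example with made-up SNR values for four eigen-directions
summary = components_kept([0.1, 0.5, 2.0, 10.0], cutoffs=[0.2, 1.0, 5.0])
```

Running this once on the initial state, and again occasionally during the evolution, gives a quick check that the chosen cutoff still separates signal from noise.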
The SNR is computed anyway as part of the regularisation procedure, so logging it makes sense to me (I think that's why I'm even returning it from the jitted function; I was trying it out). If you expose the understandable function somewhere it makes sense and can be more easily accessed, I think it's a good idea.
In order for TDVPSchmitt to work, one needs to analyze the SNR values and determine the cutoff by hand. To my knowledge, the SNR is not exposed, so there is currently no way to determine it. Since this would typically be determined for fixed operators, it would need a general signature independent of the driver, like this: