ksddescent.ksdd_lbfgs

ksddescent.ksdd_lbfgs(x0, score, kernel='gaussian', bw=1.0, max_iter=10000, tol=1e-12, beta=0.5, store=False, verbose=False)

Kernel Stein Discrepancy descent with L-BFGS

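Perform Kernel Stein Discrepancy (KSD) descent with L-BFGS: the particles, initialized at x0, are moved so as to minimize the KSD between their empirical distribution and the target distribution whose score is given by score. L-BFGS is a fast and robust algorithm that has no critical hyper-parameter; in particular, no step size has to be tuned.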
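A minimal NumPy/SciPy sketch of the idea follows. It is not the ksddescent implementation. It assumes a standard Gaussian target, whose score is s(x) = -x, and the Gaussian base kernel k(x, y) = exp(-||x - y||^2 / (2h)). Under these assumptions the gradient of the squared KSD has a closed form, so scipy.optimize.minimize with method="L-BFGS-B" can be used directly, instead of the automatic differentiation that an arbitrary score callable would otherwise require. The names ksd_loss_and_grad and ksdd_lbfgs_sketch are hypothetical.

```python
import numpy as np
from scipy.optimize import minimize


def ksd_loss_and_grad(flat_x, n, d, h):
    """Squared KSD (V-statistic) of n particles against N(0, I), and its gradient.

    Assumptions (hypothetical sketch): target N(0, I) with score s(x) = -x,
    base kernel k(x, y) = exp(-||x - y||^2 / (2 h)).
    """
    x = flat_x.reshape(n, d)
    r = x[:, None, :] - x[None, :, :]   # r[i, j] = x_i - x_j, shape (n, n, d)
    sq = np.sum(r ** 2, axis=-1)        # ||x_i - x_j||^2, shape (n, n)
    k = np.exp(-sq / (2 * h))
    # The Stein kernel equals k * b; with s(x) = -x the bracket b simplifies to:
    b = x @ x.T - sq / h + d / h - sq / h ** 2
    loss = np.sum(k * b) / n ** 2
    # d/dx_i of k(x_i, x_j) * b(x_i, x_j), summed over j; the factor 2
    # accounts for x_i appearing in both arguments of the symmetric kernel
    coef = k * (-b / h - 2 * (1 / h + 1 / h ** 2))
    grad = 2 / n ** 2 * (np.einsum("ij,ijk->ik", coef, r) + k @ x)
    return loss, grad.ravel()


def ksdd_lbfgs_sketch(x0, h=1.0, max_iter=1000):
    """Move particles x0 (n_samples x n_features) by minimizing the KSD with L-BFGS."""
    n, d = x0.shape
    res = minimize(ksd_loss_and_grad, x0.ravel(), args=(n, d, h),
                   jac=True, method="L-BFGS-B", options={"maxiter": max_iter})
    return res.x.reshape(n, d)


rng = np.random.default_rng(0)
x0 = rng.normal(size=(50, 2)) + 3.0   # particles start away from the target mean
x = ksdd_lbfgs_sketch(x0)
print(x0.mean(axis=0), x.mean(axis=0))
```

Because L-BFGS-B chooses its step lengths by line search, no step size appears anywhere in the sketch.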

Parameters
x0 : torch.tensor, size n_samples x n_features

Initial positions of the particles.

score : callable

Function that computes the score of the target distribution, i.e. the gradient of its log-density, ∇ log p.

kernel : ‘gaussian’ or ‘imq’

Which kernel to use.

max_iter : int

Maximum number of iterations.

bw : float

Bandwidth of the Stein kernel.

tol : float

Tolerance used in the stopping criterion of L-BFGS.

store : bool

Whether to store the iterates.

verbose : bool

Whether to print the current loss.
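The score, kernel and bw parameters together define the Stein kernel. As a summary following the referenced paper, not a description of the library internals, the loss minimized over the particle positions is the squared KSD of their empirical distribution, with s = ∇ log π the score of the target π, k the base kernel, n = n_samples and d = n_features. The closed form on the last line assumes the Gaussian parametrization k(x, y) = exp(-||x - y||^2 / (2h)) with h = bw, which may differ from the exact bandwidth convention used by ksddescent.

```latex
k_\pi(x, y) = s(x)^\top s(y)\, k(x, y) + s(x)^\top \nabla_y k(x, y) + \nabla_x k(x, y)^\top s(y) + \sum_{a=1}^{d} \frac{\partial^2 k(x, y)}{\partial x_a \, \partial y_a}

\mathcal{L}(x_1, \dots, x_n) = \mathrm{KSD}^2\Big(\tfrac{1}{n}\textstyle\sum_{i=1}^{n} \delta_{x_i} \,\Big|\, \pi\Big) = \frac{1}{n^2} \sum_{i=1}^{n} \sum_{j=1}^{n} k_\pi(x_i, x_j)

k(x, y) = \exp\Big(-\frac{\|x - y\|^2}{2h}\Big) \;\Longrightarrow\;
k_\pi(x, y) = k(x, y) \Big[ s(x)^\top s(y) + \frac{(s(x) - s(y))^\top (x - y)}{h} + \frac{d}{h} - \frac{\|x - y\|^2}{h^2} \Big]
```

L-BFGS is then run on the n × d particle coordinates; evaluating the gradient of this loss requires differentiating through the score s as well as through k.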

Returns
x : torch.tensor, size n_samples x n_features

The final positions

References

A. Korba, P.-C. Aubin-Frankowski, S. Majewski, P. Ablin. Kernel Stein Discrepancy Descent. International Conference on Machine Learning (ICML), 2021.