Stop using implicit-style differentiation #221

Closed
1 task
ablaom opened this issue Apr 17, 2023 · 18 comments · Fixed by #251
@ablaom
Collaborator

ablaom commented Apr 17, 2023

It seems the style used here is being deprecated and won't work with Flux 0.14:

gs = Flux.gradient(parameters) do


Edit: After discussion below, I suggest we wait on

and refactor to use an optimiser-based solution to weight regularisation, which will avoid the current limitations of explicit differentiation outlined in the discussion. Note, this will likely mean the reported training_loss must change, as it will no longer include the weight penalty. So this will be breaking.
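As a rough sketch, an optimiser-based version of a mixed L1/L2 penalty could be composed from decay rules. This is not a final design: the coefficients are illustrative, and `SignDecay` (the L1 counterpart of `WeightDecay`) only exists in newer Optimisers.jl releases.

```julia
using Flux, Optimisers

chain = Chain(Dense(3 => 5, relu), Dense(5 => 1))

# WeightDecay adds λ2 .* w to each gradient (an L2 penalty); SignDecay adds
# λ1 .* sign.(w) (an L1 penalty). Chaining them folds the regularisation
# into the update step, so it no longer appears in the reported loss.
rule = Optimisers.OptimiserChain(
    Optimisers.WeightDecay(0.01),   # illustrative L2 coefficient
    Optimisers.SignDecay(0.02),     # illustrative L1 coefficient
    Optimisers.Adam(),
)
opt_state = Flux.setup(rule, chain)
```

With this, the training loop computes only `loss(yhat, y[i])` inside `withgradient`, and the penalty is applied by `Flux.update!`.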

@mcabbott
Member

Relatedly, it would be nice if the MLJFlux models listed here https://github.com/FluxML/model-zoo#examples-elsewhere could be updated to use latest Flux, and avoid implicit gradients.

Examples of similar upgrades: https://github.com/FluxML/model-zoo/issues?q=is%3Aclosed+label%3Aupdate+explicit

In the end, Flux 0.14 did not drop support for implicit gradients, but 0.15 should.

@ablaom
Collaborator Author

ablaom commented Jul 30, 2023

@pat-alt Would you have any time and interest in addressing this issue?

@pat-alt
Collaborator

pat-alt commented Jul 31, 2023

That actually syncs well with some of my other outstanding issues, and I think I'll have to address this very same thing in CounterfactualExplanations.jl soon. So yes, please feel free to assign this one to me and I'll look at it in the coming weeks 👍

@pat-alt pat-alt self-assigned this Aug 1, 2023
@pat-alt pat-alt linked a pull request Aug 1, 2023 that will close this issue
2 tasks
@pat-alt
Collaborator

pat-alt commented Aug 1, 2023

I have added a draft for this with very minor changes in #230:

function train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y)
    opt_state = Flux.setup(optimiser, chain)
    loss = model.loss
    n_batches = length(y)
    training_loss = zero(Float32)
    parameters = Flux.params(chain)
    for i in 1:n_batches
        batch_loss, gs = Flux.withgradient(chain) do m
            yhat = m(X[i])
            pen = penalty(parameters) / n_batches
            loss(yhat, y[i]) + pen
        end
        training_loss += batch_loss
        Flux.update!(opt_state, chain, gs[1])
    end
    return training_loss / n_batches
end

Currently, the following test fails:

[ Info: regularization has an effect:
[ Info: acceleration = CPU1{Nothing}(nothing)
regularization has an effect (typename(CPU1)): Test Failed at /Users/patrickaltmeyer/code/MLJFlux.jl/test/integration.jl:25
  Expression: !(loss2 ≈ loss3)
   Evaluated: !(0.8354643267207931 ≈ 0.8354643267207931)

I'm not quite sure what's happening. @mcabbott can you spot anything obviously wrong with this?

@ToucheSir
Member

ToucheSir commented Aug 1, 2023

That's because the regularization term is still using implicit params. Something like FluxML/Flux.jl#2040 (comment) will be needed for explicit params.

@mcabbott
Member

mcabbott commented Aug 1, 2023

parameters = Flux.params(chain) outside the gradient context will only work in the implicit style -- changing the explicit local m will not change pen. (Edit -- as ToucheSir says, while I was typing!)

What is penalty? For L2 it will be better to use WeightDecay like this: http://fluxml.ai/Flux.jl/stable/training/training/#Regularisation
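The linked docs section boils down to something like the following (a sketch, with the coefficient and learning rate illustrative):

```julia
using Flux

model = Chain(Dense(3 => 5, relu), Dense(5 => 1))

# Instead of adding sum(abs2, w) terms to the loss, compose WeightDecay
# with the base optimiser. WeightDecay(0.42) adds 0.42 .* w to each
# gradient, i.e. the derivative of a (0.42/2) * sum(abs2, w) penalty.
opt_state = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model)
```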

@pat-alt
Collaborator

pat-alt commented Aug 1, 2023

Thanks both!

What is penalty? For L2 it will be better to use WeightDecay like this: http://fluxml.ai/Flux.jl/stable/training/training/#Regularisation

Currently, penalty functions are explicitly defined callable objects in MLJFlux (see here). I saw the note on WeightDecay in the Flux docs and was wondering if it's worth changing that.

In any case, I can't really get either of the approaches you suggest to work in this particular case, so we may indeed want to rethink the implementation of the penalty functions, for example by using WeightDecay instead. Will require a little extra work, but should be doable. @ablaom what do you think?

@ToucheSir
Member

I can't really get either of the approaches you suggest to work in this particular case

Can you elaborate? I'm not sure I understand why/how they wouldn't work.

@pat-alt
Collaborator

pat-alt commented Aug 1, 2023

Sure!

Moving the params call inside as follows

function train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y)
    opt_state = Flux.setup(optimiser, chain)
    loss = model.loss
    n_batches = length(y)
    training_loss = zero(Float32)
    for i in 1:n_batches
        batch_loss, gs = Flux.withgradient(chain) do m
            yhat = m(X[i])
            pen = penalty(Flux.params(m)) / n_batches
            loss(yhat, y[i]) + pen
        end
        training_loss += batch_loss
        Flux.update!(opt_state, chain, gs[1])
    end
    return training_loss / n_batches
end

the tests just seem to get stuck at some point. I may try committing this now, but at least locally on my machine things hang.

Alternatively, using the approach in FluxML/Flux.jl#2040 (comment) as follows

function train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y)
    opt_state = Flux.setup(optimiser, chain)
    loss = model.loss
    n_batches = length(y)
    training_loss = zero(Float32)
    for i in 1:n_batches
        batch_loss, gs = Flux.withgradient(chain) do m
            yhat = m(X[i])
            l = loss(yhat, y[i])
            reg = Functors.fmap(penalty, m; exclude=Flux.trainable)
            l + reg / n_batches
        end
        training_loss += batch_loss
        Flux.update!(opt_state, chain, gs[1])
    end
    return training_loss / n_batches
end

I get the following error:

[ Info: acceleration = CPU1{Nothing}(nothing)
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(5 => 15)      # 90 parameters
│   summary(x) = "5×20 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/n3cOc/src/layers/stateless.jl:60
fit! and dropout (typename(CPU1)): Error During Test at /Users/patrickaltmeyer/code/MLJFlux.jl/test/test_utils.jl:38
  Got exception outside of a @test
  TypeError: non-boolean (NamedTuple{(:layers,), Tuple{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dropout{Float64, Colon, Random.TaskLocalRNG}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}) used in boolean context

Perhaps it has to do with the fact that the penalizers aren't Functors?

@ToucheSir
Member

Yeah, I wouldn't try the first version you have there; I was referring to the second one, or @mcabbott's suggestion about moving things to the optimization step.

I get the following error: ...

Pretty sure that's due to a typo in the original example code snippet. See FluxML/Flux.jl#2040 (comment)

@pat-alt
Collaborator

pat-alt commented Aug 11, 2023

Hmm, in that case I get the following error: MethodError: no method matching Dense(::Float32, ::Float32, ::typeof(identity)). Any ideas?

@ablaom
Collaborator Author

ablaom commented Aug 13, 2023

Thanks @pat-alt for this work!

In any case, I can't really get either of the approaches you suggest to work in this particular case, so we may indeed want to rethink the implementation of the penalty functions, for example by using WeightDecay instead. Will require a little extra work, but should be doable. @ablaom what do you think?

WeightDecay only provides a mechanism for L2 regularisation, but the current implementation provides for a combination of both L1 regularisation (good for feature selection) and L2 regularisation. It seems a pity to drop support for a feature to accommodate the new explicit syntax.

I don't know what the source of your current issue is.

@ablaom
Collaborator Author

ablaom commented Sep 5, 2023

@pat-alt I don't think your use of Functors.fmap is valid here. The penalty function takes a tuple of matrices, as returned by Flux.params(chain), and returns a single aggregate number.

Your first suggestion (with params) actually works but is 3600 times slower than the implicit style code on the dev branch, when tested on a small model / dataset.

@ToucheSir To implement mixed L1/L2 penalties (not just L2 ones) I don't really see how to avoid the params in the withgradient block. (And this is after all a suggestion in the Flux documentation - second code block here). Am I to conclude that explicit-Zygote style AD is just no good on this problem?

@ToucheSir
Member

To implement mixed L1/L2 penalties (not just L2 ones) I don't really see how to avoid the params in the withgradient block. (And this is after all a suggestion in the Flux documentation - second code block here). Am I to conclude that explicit-Zygote style AD is just no good on this problem?

It's arguably better, but it requires some helper functionality that isn't currently nicely packaged up in a library. FluxML/Optimisers.jl#57 is one example of how to do this and how we're thinking about packaging it up going forwards, but the problem with general solutions is that they take time. For this work, you may be better served by implementing a similar but more constrained version on top of Functors.jl and Optimisers.jl which only includes as much as MLJFlux needs for regularization. If you do, feel free to ping me for input.

@ablaom
Collaborator Author

ablaom commented Sep 5, 2023

@ToucheSir Thanks for the prompt response and offer of help.

So, with the apparatus you describe (Functors.jl, etc ) what code replaces the following to avoid the params call, working for a generic Flux model, chain, and so that differentiating chain -> penalty is free of issues?

# function to return penalty on an array:
f(A) = 0.01*sum(abs2, A) + 0.02*sum(abs, A)

f(ones(2,3))
# 0.6000000000000001

chain = Chain(Dense(3=>5), Dense(5=>1, relu))
penalty = sum(f.(Flux.params(chain)))

@ablaom
Collaborator Author

ablaom commented Sep 5, 2023

Or if you prefer, how should the regularisation example in the Flux documentation be rewritten (without the weight-decay trick, which does not work for an L1 penalty)?

@ToucheSir
Member

f(A) = ...
penalty = mytotal(f, chain)

Where mytotal is a simplified form or direct copy of Optimisers.total as I mentioned earlier.
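For what it's worth, a minimal sketch of such a helper, assuming `Optimisers.trainables` (a later addition to Optimisers.jl that is designed to be differentiable, unlike `Flux.params`):

```julia
using Flux, Optimisers

# Simplified stand-in for the `mytotal` above: sum f over every trainable
# array in the model. It ignores parameter sharing, which a real helper
# would need to handle.
mytotal(f, model) = sum(f(A) for A in Optimisers.trainables(model))

# The mixed L1/L2 penalty from the earlier comment:
f(A) = 0.01 * sum(abs2, A) + 0.02 * sum(abs, A)

chain = Chain(Dense(3 => 5), Dense(5 => 1, relu))
penalty, grads = Flux.withgradient(m -> mytotal(f, m), chain)
```

Whether this composes cleanly with the train! loop above is worth testing, but it avoids any params call inside the withgradient block.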

...(without the weight-decay trick , which does not work for L1 penalty)?

Side note, but I remembered looking into this a few months back and coming across https://stackoverflow.com/questions/42704283/l1-l2-regularization-in-pytorch/66630301#66630301, which suggests that L1 could be implemented using a similar trick. Whether that would be compatible with MLJFlux's API I'm not sure, but we could consider adding it to Optimisers.jl.
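In that spirit, an L1 analogue of WeightDecay can be written as a custom Optimisers.jl rule. This is a sketch of the idea (essentially what later landed in Optimisers.jl as `SignDecay`), with an illustrative coefficient:

```julia
import Optimisers

# L1 decay: adds λ .* sign.(w) to the gradient, the subgradient of a
# λ * sum(abs, w) penalty. Mirrors how WeightDecay implements L2.
struct L1Decay{T} <: Optimisers.AbstractRule
    lambda::T
end

# This rule is stateless, so init returns nothing:
Optimisers.init(::L1Decay, x::AbstractArray) = nothing

function Optimisers.apply!(o::L1Decay, state, x, dx)
    return state, @. dx + o.lambda * sign(x)
end
```

Chained as `OptimiserChain(L1Decay(0.02), Adam())`, this applies the L1 penalty at the update step instead of inside the loss.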

@ablaom
Collaborator Author

ablaom commented Sep 7, 2023

Thanks for the help @ToucheSir . Unfortunately, Optimisers.total is not working for me. I've tried some variations on that approach but without any luck.

I suggest we wait on the WeightDecay extension referenced above and switch to that approach, which is likely more performant anyhow.
