Error when calculating gradients with Zygote #75

Open
radudiaconu0 opened this issue Apr 26, 2024 · 5 comments

@radudiaconu0

I get the following error when I try to calculate gradients:

ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations do not execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.

If you want to allow scalar iteration, use allowscalar or @allowscalar
to enable scalar iteration globally or for the operations in question.
Stacktrace:
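
(For reference, the allowscalar escape hatch the message mentions looks roughly like this, assuming AMDGPU re-exports GPUArrays' helpers the way CUDA.jl does; it only silences the error while debugging, it is not a fix:)

using AMDGPU

a = AMDGPU.randn(Float32, 4, 4)

# allow scalar indexing for a single expression, to locate the offending call:
AMDGPU.@allowscalar a[1, 1]

# or globally while debugging (very slow, not a fix):
AMDGPU.allowscalar(true)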

using TensorCast
using Flux, AMDGPU, OMEinsum
using Statistics: mean, var

struct LayerNorm{T<:AbstractArray}
    g::T
end

function LayerNorm(dim::Int)
    g = ones(Float32, (1, 1, dim, 1))
    LayerNorm(g)
end

Flux.@layer :expand LayerNorm

function (ln::LayerNorm)(x)
    dims = 3
    eps::Float32 = 1e-5
    μ = mean(x; dims=dims)
    σ² = var(x; dims=dims)
    x_normalized = (x .- μ) ./ sqrt.(σ² .+ eps)  # normalize by the standard deviation
    return x_normalized .* ln.g
end


struct LinearAttention{T<:Chain,S<:Conv}
    scale::Float32
    heads::Int
    to_qkv::S
    to_out::T
end


function LinearAttention(dim::Int; heads::Int=4, dim_head::Int=32)
    scale::Float32 = dim_head^-0.5
    hidden_dim = dim_head * heads
    to_qkv = Conv((1, 1), dim => hidden_dim * 3, bias=false)
    to_out = Chain(
        Conv((1, 1), hidden_dim => dim),
        LayerNorm(dim)  # Assuming a LayerNorm implementation as earlier
    )
    return LinearAttention(scale, heads, to_qkv, to_out)
end

Flux.@layer :expand LinearAttention
function (la::LinearAttention)(x::ROCArray)
    h, w, c, b = size(x)
    qkv = Flux.chunk(la.to_qkv(x), 3; dims=3)
    q, k, v = qkv
    q, k, v = q[:, :, :, :], k[:, :, :, :], v[:, :, :, :]  # all-colon slices copy, giving plain ROCArrays
    @cast q[c, (x, y), h, b] |= q[x, y, (c, h), b] h in 1:la.heads
    @cast k[c, (x, y), h, b] |= k[x, y, (c, h), b] h in 1:la.heads
    @cast v[c, (x, y), h, b] |= v[x, y, (c, h), b] h in 1:la.heads
    println(typeof(q))

    q = softmax(q, dims=1)
    k = softmax(k, dims=2)

    q *= la.scale

    v /= (h * w)
    println("typeof of k", typeof(k))
    println("typeof v", typeof(v))


    context = ein"dnhb,enhb->dehb"(k, v)
    println(typeof(context))

    # println("context: ", size(context))
    out = ein"dehb,dnhb->enhb"(context, q)
    println(typeof(out))
    # println("out: ", size(out))

    @cast out[x, y, (c, h), b] |= out[c, (x, y), h, b] (h in 1:la.heads, x in 1:h, y in 1:w)
    println(typeof(out))
    return la.to_out(out)
end


x = AMDGPU.randn(256, 256, 64, 8)
layer = LinearAttention(64) |> gpu  # assumed: construct the layer (dim = 64 to match x) and move it to the GPU
loss, grad = Flux.withgradient(layer) do l
    a = l(x)
    sum(a)
end  # this doesn't work
layer(x)  # this works

The type of q is ROCArray during inference, but when calculating gradients it becomes
Base.ReshapedArray{Float32, 4, TransmuteDims.TransmutedDimsArray{Float32, 5, (3, 1, 2, 4, 5), (2, 3, 1, 4, 5), ROCArray{Float32, 5, AMDGPU.Runtime.Mem.HIPBuffer}}, NTuple{4, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}
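
A minimal probe of that pattern (names and shapes are illustrative only, not the real model):

using TensorCast, AMDGPU, Zygote

heads = 4
q0 = AMDGPU.randn(Float32, 8, 8, 32 * heads, 2)   # (x, y, c*heads, batch)

function probe(q)
    @cast qr[c, (x, y), h, b] |= q[x, y, (c, h), b]  h in 1:heads
    println(typeof(qr))        # compare what prints in the two calls below
    return sum(qr)
end

probe(q0)                      # plain forward pass: a ROCArray, per the report above
Zygote.gradient(probe, q0)     # same code under AD: the ReshapedArray wrapper shows up here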

@mcabbott (Owner)

Does OMEinsum support AMD at all?

TensorCast should in principle support everything. It is by default as lazy as possible, and hence returns things like Base.ReshapedArray, and some GPU things have trouble working through these wrappers. Replacing := with |= is how you ask this package to be less lazy.
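
For example, on the CPU (toy array; the exact wrapper types vary with TensorCast version):

using TensorCast

A = rand(Float32, 2, 3, 4)

@cast B[j, (i, k)] := A[i, j, k]   # lazy: typically a ReshapedArray over a lazy permutation
@cast C[j, (i, k)] |= A[i, j, k]   # |= collects into a plain Array{Float32, 2}

typeof(B), typeof(C)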

@radudiaconu0 (Author)

Yes, I personally added support for it.

@radudiaconu0 (Author)

Yeah, I did that (as I was saying on Slack) to directly get plain arrays and not ReshapedArray.
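
(An all-colon slice always copies, so it does strip the wrappers; a quick CPU check:)

x = rand(Float32, 2, 3)
v = view(x, :, :)      # SubArray wrapper
y = v[:, :]            # all-colon indexing copies into a plain Matrix{Float32}
typeof(v), typeof(y)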

@mcabbott (Owner)

Sorry, the code does have |=.

Can you figure out which operation is causing the problem? There are four @cast calls and two ein"..." calls; what's the minimal example that triggers this? Can you make gradient return a surprising type?

(I cannot run this locally)
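
Something along these lines, testing each piece in isolation, would help (shapes are only illustrative):

using AMDGPU, OMEinsum, TensorCast, Zygote

k = AMDGPU.randn(Float32, 32, 64, 4, 2)   # (dim_head, x*y, heads, batch)
v = AMDGPU.randn(Float32, 32, 64, 4, 2)

# each einsum on its own:
Zygote.gradient((k, v) -> sum(ein"dnhb,enhb->dehb"(k, v)), k, v)

# the scalar rescaling on its own (h*w = 64 here):
Zygote.gradient(v -> sum(v / 64), v)

# ... and likewise for each @cast, until a single call reproduces the
# scalar-indexing error on its own.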

@radudiaconu0 (Author)

Actually, yes. It's this:

v /= (h * w)
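
(One thing worth trying, though I have not verified it on AMDGPU: make that rescaling an explicit broadcast with a Float32 scalar, so the pullback stays a plain broadcast:)

# instead of:  v /= (h * w)
v = v .* (1f0 / (h * w))      # or: v = v ./ Float32(h * w)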
