This might be of interest: Compositional Linear Algebra (CoLA). CoLA is a framework for scalable linear algebra that automatically exploits the structure often found in machine learning problems and beyond. CoLA supports both PyTorch and JAX.
-
Pinging the author of #1314, @hawkinsp :-) Happy to help if there are some instructions (#1099).
-
Hello JAX developers and community,
I've been exploring JAX's support for singular value decompositions (SVDs). JAX wraps cuSOLVER's gesvdjBatched (#1314), but that routine comes with a matrix size restriction (m, n <= 32). Looking further through cuSOLVER's offerings, I came across gesvdaStridedBatched, which JAX does not currently seem to expose. Some considerations regarding gesvdaStridedBatched:
- General matrix support and no size restriction: unlike gesvdjBatched, gesvdaStridedBatched supports general matrices without the strict size limitation, which could offer more flexibility.
- Usage in CuPy: CuPy has integrated gesvdaStridedBatched into its framework. Here's their implementation. Seeing how they use it made me wonder what advantages it might bring if introduced in JAX.
- Performance: an NVIDIA presentation (page 14) discusses the speed-ups achievable with batched GESVDA. It may be worth evaluating whether these numbers match what JAX is aiming for.
On a related note, AMD's ROCm equivalent does not appear to have such matrix size restrictions, as seen here. This makes me curious about the technical challenges or considerations behind the restriction in cuSOLVER.
I understand that many factors go into deciding on an integration. I'd appreciate hearing the community's and developers' perspectives on this topic, and whether gesvdaStridedBatched might find a place on JAX's roadmap.
Thank you for continually pushing the boundaries with JAX. Looking forward to learning more.
Edit (2023-09-18): note that gesvdaStridedBatched can be called through the PyTorch interface since commit pytorch/pytorch@d136852 from pull request pytorch/pytorch#74521.
Best wishes,
Ray