
Allow linalg.lstsq to use SVD to compute the result for rank-deficient matrices #126652

Open

ZelboK wants to merge 3 commits into main
Conversation

@ZelboK (Contributor) commented May 19, 2024

Fixes #117122

This PR adds the logic so that, in the case of rank-deficient matrices, linalg.lstsq can fall back to an SVD backend for batched mode.
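For context, here is a minimal sketch of the case this targets (the values are illustrative; per the docs, the only CUDA driver today is gels, which assumes the input is full rank):

import torch

# A is rank deficient: the second column is twice the first.
A = torch.tensor([[1.0, 2.0],
                  [2.0, 4.0],
                  [3.0, 6.0]], device='cuda')
B = torch.tensor([[1.0], [2.0], [3.0]], device='cuda')

# With the SVD fallback, this should return the minimum-norm
# least-squares solution rather than erroring or returning an
# incorrect result.
X = torch.linalg.lstsq(A, B).solution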

I apologize for the previous PR... I messed up a rebase and it ended up showing a million changes.

cc @lezcano

pytorch-bot bot commented May 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126652

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit fd07d5e with merge base 853081a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: linalg_frontend label May 19, 2024
solution.set_(solution.storage(), solution_view.storage_offset(),
solution_view.sizes(), solution_view.strides());
} else {
solution = at::zeros({solution.size(-1), n}, solution.options());
Collaborator commented:

what is going on here??

@ZelboK (Contributor, Author) commented May 20, 2024

You're referring to everything inside the else clause, correct?

I found that with a tensor A that has more rows than columns:

import torch

# Tensor A with shape (4, 2)
A = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0],
                  [5.0, 6.0],
                  [7.0, 8.0]], device='cuda')

# Create tensor B with shape (4, 1)
B = torch.tensor([[1.0],
                  [2.0],
                  [3.0],
                  [4.0]], device='cuda')

X_lstsq = torch.linalg.lstsq(A, B, driver='gelss').solution

would lead to
RuntimeError: start (2) + length (2) exceeds dimension size (2).

Is this incorrect? I'm refreshing my linear algebra here and I might not have the correct understanding.

def svd_lstsq(AA, BB, tol=1e-5):
    # Minimum-norm least-squares solution via the pseudoinverse:
    # X = V @ S^+ @ U^H @ B, zeroing out singular values below tol.
    U, S, Vh = torch.linalg.svd(AA, full_matrices=False)
    Spinv = torch.zeros_like(S)
    Spinv[S > tol] = 1 / S[S > tol]
    UhBB = U.adjoint() @ BB
    if Spinv.ndim != UhBB.ndim:
        Spinv = Spinv.unsqueeze(-1)
    SpinvUhBB = Spinv * UhBB
    return Vh.adjoint() @ SpinvUhBB

X_svd = svd_lstsq(A, B)

This, for example, will not throw an error with the same tensors.
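As a sanity check (just a sketch, using the CPU gelsd driver as the reference), the SVD route matches the rank-revealing CPU result:

X_cpu = torch.linalg.lstsq(A.cpu(), B.cpu(), driver='gelsd').solution
print(torch.allclose(X_svd.cpu(), X_cpu, atol=1e-5))  # expected: True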

Also, I should have clarified this earlier. Sorry.

Collaborator commented:

I don't see why you should allocate a new tensor when you already have a solution allocated in the else path. And why a tensor of zeros?

@ZelboK (Contributor, Author) commented May 21, 2024

Ah, you're right, I shouldn't allocate a new tensor.
As for the zeros: in hindsight, it makes no sense to zero the solution out, as that's not the correct behavior (this system should have a solution, right?). An exception at least tells the user what happened, whereas this is just silent UB. How do we handle this case, though? Does solution need to be reshaped before using set_, or something similar, in the else path?

Since I'm still new, I'm curious whether this is out of scope for this PR. The exception occurs in general whenever solution.size(-2) < n. I don't mind doing it in this PR since it is small (and it's a better use of GitHub runners than two split PRs).

@drisspg drisspg added the triaged label May 20, 2024
if (input.numel() == 0) {
auto output_shape = input.sizes().vec();
output_shape.back() = other.size(-1);
rank.zero_();
@ZelboK (Contributor, Author) commented:

rank is required later on; zeroing it here fixes the integer overflow that occurred when toInt() was called, because rank was never set to anything in the empty-input path.
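A sketch of the empty-input case this hunk guards (shapes are illustrative; rank is only populated by rank-revealing drivers such as the CPU default, gelsy):

import torch

# Zero-row system: the backend never writes to rank, so it must be
# zeroed explicitly before .item()/toInt() reads it.
A = torch.empty(0, 3)
B = torch.empty(0, 2)
out = torch.linalg.lstsq(A, B)
print(out.rank)  # expected: tensor(0) with this fix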

Labels
open source · release notes: linalg_frontend · triaged
Development

Successfully merging this pull request may close these issues.

Improve behaviour of torch.linalg.lstsq on CUDA GPU for rank deficient matrices