From 3a982a965351dcd6dd1bc43b2ac8d71b3205839c Mon Sep 17 00:00:00 2001 From: Sam Armstrong <88863522+Sam-Armstrong@users.noreply.github.com> Date: Thu, 9 Jan 2025 13:07:30 +0000 Subject: [PATCH] fix: unify the torch.Tensor.cuda frontends (#28827) --- ivy/functional/frontends/torch/tensor.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 897c5a909ffa7..cf6425bf4d3ef 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -126,11 +126,6 @@ def itemsize(self): # Setters # # --------# - @device.setter - def cuda(self, device=None): - self.device = device - return self - @ivy_array.setter def ivy_array(self, array): self._ivy_array = array if isinstance(array, ivy.Array) else ivy.array(array) @@ -728,8 +723,8 @@ def detach_(self): def cpu(self): return ivy.to_device(self.ivy_array, "cpu") - def cuda(self): - return ivy.to_device(self.ivy_array, "gpu:0") + def cuda(self, device=None, non_blocking=False, memory_format=None): + return self.to("cuda" if device is None else device) @with_unsupported_dtypes({"2.2 and below": ("uint16",)}, "torch") @numpy_to_torch_style_args