"""ninetoothed kernel building blocks for a `quantile` operation.

`premake` returns an `(arrangement, application, tensors)` triple.  The
application functions are traced/compiled by ninetoothed, so their bodies use
`ntl` ops and assign the result to the `output` parameter (hence the
`# noqa: F841` markers on those assignments).
"""

import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor


def arrangement(input, q, dim_size, output, dim, block_size=None):
    """Describe how `input`, `q`, and `output` are tiled for the kernel.

    The reduction dimension `dim` of `input` is moved to the innermost
    position so each program instance receives one full slice along it, and
    `q` is broadcast to the arranged output shape.  `dim_size` is passed
    through untouched.

    NOTE(review): `block_size` is accepted but never used here — presumably
    kept for signature consistency with the other kernels; confirm.
    """

    def _arrange_input_or_output(tensor, dim):
        ndim = tensor.ndim

        # Normalize a negative dimension index.
        if dim < 0:
            dim += ndim

        non_target_dims = tuple(i for i in range(ndim) if i != dim)

        # Move the target dimension to the innermost position.
        arranged = tensor.permute(non_target_dims + (dim,))

        # One block per combination of non-target indices; each block spans
        # the whole target dimension (-1).
        block_shape = tuple(1 for _ in non_target_dims) + (-1,)
        non_target_dim_indices = tuple(range(len(non_target_dims)))

        arranged = arranged.tile(block_shape)
        arranged.dtype = arranged.dtype.squeeze(non_target_dim_indices)

        return arranged

    input_arranged = _arrange_input_or_output(input, dim)
    # The output's quantile axis leads (the torch wrapper inserts the q
    # dimension at position 0), so the output is arranged along dim 0.
    output_arranged = _arrange_input_or_output(output, 0)

    # Broadcast q across every output block dimension.
    q_arranged = q.tile((-1,))
    q_arranged = q_arranged.squeeze(0)

    for _ in range(output_arranged.ndim):
        q_arranged = q_arranged.unsqueeze(0)

    q_arranged = q_arranged.expand(output_arranged.shape)

    return input_arranged, q_arranged, dim_size, output_arranged


def linear_application(input, q, dim_size, output):
    """Linearly interpolate between the two sorted values bracketing q."""
    # Fractional index of the requested quantile within the sorted slice.
    pos = ntl.cast(q * (dim_size - 1), ntl.float32)
    i = ntl.cast(ntl.floor(pos), ntl.int32)
    j = ntl.cast(ntl.ceil(pos), ntl.int32)
    frac = pos - i

    sorted = ntl.sort(input)
    lower_value = ntl.gather(sorted, i, 0)
    higher_value = ntl.gather(sorted, j, 0)

    output = lower_value + frac * (higher_value - lower_value)  # noqa: F841


def lower_application(input, q, dim_size, output):
    """Take the sorted value at floor(pos) — no interpolation."""
    pos = ntl.cast(q * (dim_size - 1), ntl.float32)
    i = ntl.cast(ntl.floor(pos), ntl.int32)

    sorted = ntl.sort(input)
    lower_value = ntl.gather(sorted, i, 0)

    output = lower_value  # noqa: F841


def higher_application(input, q, dim_size, output):
    """Take the sorted value at ceil(pos) — no interpolation."""
    pos = ntl.cast(q * (dim_size - 1), ntl.float32)
    j = ntl.cast(ntl.ceil(pos), ntl.int32)

    sorted = ntl.sort(input)
    higher_value = ntl.gather(sorted, j, 0)

    output = higher_value  # noqa: F841


def nearest_application(input, q, dim_size, output):
    """Take the sorted value nearest to pos, rounding ties to even."""
    pos = ntl.cast(q * (dim_size - 1), ntl.float32)

    # Rounding mode for float-to-int conversion is always towards zero, so we
    # have to manually implement `rtne` (round to nearest, ties to even):
    # for pos = i + 0.5 the nearest even of {i, i + 1} is i when i is even
    # and i + 1 when i is odd.  `frac > 0.5` and `frac == 0.5` are mutually
    # exclusive, so the two `where`s never both fire.
    i = ntl.cast(ntl.floor(pos), ntl.int32)
    frac = ntl.cast(pos - i, ntl.float32)
    i = ntl.where(frac > 0.5, ntl.minimum(i + 1, dim_size - 1), i)
    i = ntl.where((frac == 0.5) & (i % 2 == 1), ntl.minimum(i + 1, dim_size - 1), i)

    sorted = ntl.sort(input)
    output = ntl.gather(sorted, i, 0)  # noqa: F841


def midpoint_application(input, q, dim_size, output):
    """Average the two sorted values bracketing pos."""
    pos = ntl.cast(q * (dim_size - 1), ntl.float32)
    i = ntl.cast(ntl.floor(pos), ntl.int32)
    j = ntl.cast(ntl.ceil(pos), ntl.int32)

    sorted = ntl.sort(input)
    lower_value = ntl.gather(sorted, i, 0)
    higher_value = ntl.gather(sorted, j, 0)

    output = (higher_value + lower_value) / 2  # noqa: F841


def premake(in_ndim, out_ndim, dim, interpolation, dtype=None, block_size=None):
    """Build the (arrangement, application, tensors) triple for `quantile`.

    Args:
        in_ndim: rank of the input tensor.
        out_ndim: rank of the (kernel-facing) output tensor.
        dim: dimension to reduce over.
        interpolation: one of "linear", "lower", "higher", "nearest",
            "midpoint" — selects the application function.
        dtype: element dtype shared by input, q, and output.
        block_size: forwarded to `arrangement` (currently unused there).

    Raises:
        ValueError: if `interpolation` is not one of the supported names.
    """
    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)

    # (input, q, dim_size scalar, output); shapes are constexpr so the
    # generated kernel can specialize on them.
    tensors = (
        Tensor(in_ndim, dtype=dtype, shape_options={"constexpr": True}),
        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),
        Tensor(0),
        Tensor(out_ndim, dtype=dtype, shape_options={"constexpr": True}),
    )

    if interpolation == "linear":
        application = linear_application
    elif interpolation == "lower":
        application = lower_application
    elif interpolation == "higher":
        application = higher_application
    elif interpolation == "nearest":
        application = nearest_application
    elif interpolation == "midpoint":
        application = midpoint_application
    else:
        raise ValueError(f"Unsupported interpolation method: {interpolation}")

    return arrangement_, application, tensors
def application_0(input, output):
    """k % 4 == 0: the rotation is the identity — copy the block through."""
    output = input  # noqa: F841


def application_1_or_3(input, output):
    """k % 4 == 1 or 3: reverse the arranged axis of one quarter turn.

    The direction of the turn is encoded in the arrangement (for k % 4 == 1
    the output uses `reversed(dims)`, for k % 4 == 3 the input does), so a
    single application body serves both cases.
    """
    # A length-1 axis is its own reversal; skip the flip.
    if input.shape[0] == 1:
        output = input  # noqa: F841
    else:
        output = ntl.flip(input, 0)  # noqa: F841


def application_2(input, output):
    """k % 4 == 2: a half turn — reverse both arranged target axes."""
    output = ntl.flip(ntl.flip(input, 0), 1)  # noqa: F841
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor


def arrangement(input, index, output, dim, block_size=None):
    """Tile `input` so each program sees one full row along `dim`, and
    `output` so each program owns a single element; `index` picks which
    entry of the row is copied.

    NOTE(review): `block_size` is resolved but never used below — looks like
    leftover scaffolding; confirm before removing.
    """
    if block_size is None:
        block_size = ninetoothed.block_size()

    # One output element per program instance.
    if output.ndim < 1:
        output = output.unsqueeze(0)
    else:
        output = output.flatten()

    output_arranged = output.tile((1,))
    output_arranged.dtype = output_arranged.dtype.squeeze(0)

    # Move the selection dimension to the innermost position so each
    # program's row holds all candidates for its output element.
    if input.ndim < 2:
        input = input.unsqueeze(0)
    else:
        if dim < 0:
            dim += input.ndim

        non_target_dims = tuple(i for i in range(input.ndim) if i != dim)
        input = input.permute(non_target_dims + (dim,))

    input_arranged = input.flatten(end_dim=-1)
    input_arranged = input_arranged.tile((1, -1))
    input_arranged.dtype = input_arranged.dtype.squeeze(0)

    return input_arranged, index, output_arranged


def application(input, index, output):
    """Copy the `index`-th element of the arranged row into the output."""
    idx = ntl.cast(index, ntl.int32)
    output = input[idx]  # noqa: F841


def premake(in_ndim, out_ndim, dim, dtype=None, block_size=None):
    """Build the (arrangement, application, tensors) triple for select_copy.

    Args:
        in_ndim: rank of the input tensor.
        out_ndim: rank of the output tensor (input rank minus the selected
            dimension, per the torch wrapper).
        dim: dimension to select along.
        dtype: element dtype of input/output; `index` is a 0-d int32 scalar.
        block_size: forwarded to `arrangement` (currently unused there).
    """
    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)

    tensors = (
        Tensor(in_ndim, dtype=dtype, shape_options={"constexpr": True}),
        Tensor(0, dtype=ninetoothed.int32),
        Tensor(out_ndim, dtype=dtype),
    )

    return arrangement_, application, tensors
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, output):
    """Element-wise sign: 1 for positive, -1 for negative, 0 otherwise.

    NOTE(review): NaN inputs fall into the final `0` branch of the nested
    `where`, whereas `torch.sign` propagates NaN — confirm this divergence
    is acceptable (the tests use `randn`, which never produces NaN).
    """
    output = ntl.where(input > 0, 1, ntl.where(input < 0, -1, 0))  # noqa: F841


def premake(ndim, dtype=None, block_size=None):
    """Build (arrangement, application, tensors) for element-wise sign.

    Reuses the shared element-wise arrangement; input and output have the
    same rank and dtype.
    """
    arrangement_ = functools.partial(arrangement, block_size=block_size)

    tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))

    return arrangement_, application, tensors
import torch

import ntops
from ntops.torch.utils import _cached_make, _pad_dims_to_next_power_of_2


def quantile(input, q, dim=None, keepdim=False, interpolation="linear", out=None):
    """`torch.quantile`-compatible wrapper around the ntops quantile kernel.

    Args:
        input: the tensor to compute quantiles of.
        q: a Python float, a 0-d tensor, or a 1-d tensor of quantiles.
        dim: dimension to reduce; if None, `input` is flattened first.
        keepdim: whether the reduced dimension is retained with size 1.
        interpolation: "linear", "lower", "higher", "nearest", or
            "midpoint" (validated by the kernel's `premake`).
        out: optional output tensor; a non-contiguous `out` is handled via a
            contiguous scratch copy that is copied back afterwards.

    Returns:
        The quantile values; when `q` is a tensor, its dimension leads.
    """
    # Remember whether the caller passed a scalar q so the leading q
    # dimension can be stripped from the result at the end.
    is_scalar = False

    if isinstance(q, float):
        q = torch.tensor([q], dtype=input.dtype, device=input.device)
        # Use `from_list` method to create a tensor in infinicore.
        # q = torch.from_list([q], dtype=input.dtype, device=input.device)
        is_scalar = True
    elif q.ndim == 0:
        # Promote a 0-d tensor q to 1-d so the kernel can tile it.
        q = q.unsqueeze(0)

    # If `dim` is None, `input` will be flattened before computation.
    # `ndim` records the pre-flatten rank (None means no flattening).
    ndim = None

    if dim is None:
        ndim = input.ndim
        # `flatten` is not supported in `infinicore.tensor`, so use `view`.
        flattened_size = 1

        for s in input.shape:
            flattened_size *= s

        input = input.contiguous().view([flattened_size])
        dim = 0

    # Pad `input` and `q` to the next power of 2 along the worked dimensions.
    # The inf fill ends up past the valid data after sorting (assumed
    # ascending) and the gather indices are computed from the unpadded
    # `dim_size`, so padding never leaks into the result — TODO confirm.
    input_padded = _pad_dims_to_next_power_of_2(input, dim, value=float("inf"))
    q_padded = _pad_dims_to_next_power_of_2(q, 0)

    copy_back = False

    if out is None:
        # Build the output shape: reduce `dim` (honoring keepdim) and
        # prepend the q dimension.
        out_shape = list(input.shape)
        out_shape[dim] = 1

        if not keepdim:
            out_shape.pop(dim)
        elif ndim is not None:
            # Flattened input with keepdim: restore the original rank as 1s.
            out_shape.extend([1] * (ndim - 1))

        out_shape.insert(0, q.shape[0])
        out = torch.empty(out_shape, dtype=input.dtype, device=input.device)
    else:
        if not out.is_contiguous():
            # A non-contiguous tensor does not work correctly here, so we
            # create a new contiguous tensor and copy back the result after
            # computation.
            original_out = out
            copy_back = True
            out = out.contiguous()

        if is_scalar:
            # If `q` is a scalar, the corresponding `output` will also be a
            # scalar, but the application uses `gather` to get the sorted
            # values, which requires the `output` to have at least 1
            # dimension. We can unsqueeze the `output` to make it compatible
            # with the application.
            out = out.unsqueeze(0)

    if keepdim:
        # Drop the kept size-1 dim for the kernel; `squeeze` returns a view
        # sharing storage, so the kernel's writes still land in `out`.
        out_adjust = out.squeeze(dim + 1)
    else:
        out_adjust = out

    kernel = _cached_make(
        ntops.kernels.quantile.premake, input.ndim, out_adjust.ndim, dim, interpolation
    )

    # Note: the *unpadded* `input.shape[dim]` is passed as the logical size.
    kernel(input_padded, q_padded, input.shape[dim], out_adjust)

    if is_scalar:
        out = out.squeeze(0)

    if copy_back:
        original_out.copy_(out)
        out = original_out

    return out
def sgn(input, *, out=None):
    """Complex-aware sign.

    Complex inputs are routed to the `sgn` kernel on their real/imag view;
    every other dtype goes through the plain `sign` kernel.  Writes land in
    `out`'s storage either way, so `out` is returned directly.
    """
    if out is None:
        out = torch.empty_like(input)

    if input.dtype in (torch.complex64, torch.complex128):
        # Work on the (..., 2) real/imag views; they alias input/out.
        real_view = torch.view_as_real(input)
        out_view = torch.view_as_real(out)

        kernel = _cached_make(ntops.kernels.sgn.premake, real_view.ndim)
        kernel(real_view, out_view)
    else:
        kernel = _cached_make(ntops.kernels.sign.premake, input.ndim)
        kernel(input, out)

    return out


def sign(input, *, out=None):
    """Element-wise sign via the ntops `sign` kernel."""
    if out is None:
        out = torch.empty_like(input)

    _cached_make(ntops.kernels.sign.premake, input.ndim)(input, out)

    return out


def signbit(input, *, out=None):
    """Element-wise sign-bit test; allocates a boolean output by default."""
    if out is None:
        out = torch.empty_like(input, dtype=torch.bool)

    _cached_make(ntops.kernels.signbit.premake, input.ndim)(input, out)

    return out
# `ninetoothed` (0.23.0) does not support `ninetoothed.Tensor.pad` yet, so we
# rely on `torch.nn.functional.pad` for now.
# TODO: switch to `ninetoothed.Tensor.pad` once it is supported.
def _pad_dims_to_next_power_of_2(tensor, dims, padding_right=True, value=0):
    """Pad `tensor` so each dimension in `dims` has a power-of-two length.

    Args:
        tensor: the tensor to pad.
        dims: a dimension index, or a list/tuple of indices (negatives OK).
        padding_right: pad after the existing data if True, before otherwise.
        value: constant fill value for the padded region.

    Returns:
        `tensor` itself when every targeted length is already a power of two
        (or 0/1); otherwise a padded copy.

    Raises:
        ValueError: if `dims` has the wrong type or an index is out of range.
    """
    if isinstance(dims, int):
        requested = [dims]
    elif isinstance(dims, (list, tuple)):
        requested = list(dims)
    else:
        raise ValueError("dims must be an int or a list/tuple of ints")

    rank = tensor.ndim
    resolved = []

    for dim in requested:
        if dim < 0:
            dim += rank

        if not 0 <= dim < rank:
            raise ValueError(f"Invalid dims: {dims}")

        resolved.append(dim)

    # `F.pad` takes a flat (left, right) pair per dimension, starting from
    # the LAST dimension.
    padding = [0] * (2 * rank)
    needs_padding = False

    for dim in resolved:
        length = tensor.size(dim)

        # Powers of two (and lengths 0/1) need no padding.
        if length & (length - 1):
            needs_padding = True
            pad_len = (1 << length.bit_length()) - length
            side = 1 if padding_right else 0
            padding[2 * (rank - 1 - dim) + side] = pad_len

    if not needs_padding:
        return tensor

    return F.pad(tensor, padding, mode="constant", value=value)
@skip_if_cuda_not_available
@pytest.mark.parametrize("k", (0, 1, 2, 3))
@pytest.mark.parametrize(*generate_arguments())
def test_rot90(shape, k, dtype, device, rtol, atol):
    """Check ntops.torch.rot90 against torch.rot90 for random dims and k."""
    # `shape` is a parametrize argument and the same list object is shared
    # across test cases; extend a copy instead of mutating it in place (the
    # previous `shape.append(2)` leaked the extra dimension into every later
    # case using that shape).
    if len(shape) < 2:
        shape = [*shape, 2]

    input = torch.randn(shape, dtype=dtype, device=device)
    # rot90 depends only on k modulo 4; shifting by a random multiple of 4
    # exercises normalization of large and negative k values.
    k += random.randint(-100, 100) * 4

    dim_0 = random.randint(0, len(shape) - 1)
    dim_1 = random.randint(0, len(shape) - 1)

    # The two rotation dims must differ.
    if dim_0 == dim_1:
        dim_1 = (dim_1 + 1) % len(shape)

    dims = (dim_0, dim_1)

    ninetoothed_output = ntops.torch.rot90(input, k=k, dims=dims)
    reference_output = torch.rot90(input, k=k, dims=dims)

    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
@skip_if_cuda_not_available
@pytest.mark.parametrize("is_complex", (False, True))
@pytest.mark.parametrize(*generate_arguments())
def test_sgn(shape, is_complex, dtype, device, rtol, atol):
    """Compare ntops.torch.sgn with torch.sgn on real and complex inputs."""
    use_complex = is_complex and dtype != torch.float16

    if use_complex:
        # Build a complex tensor from independent real/imaginary draws.
        component_dtype = dtype.to_real()
        input_tensor = torch.complex(
            torch.randn(shape, dtype=component_dtype, device=device),
            torch.randn(shape, dtype=component_dtype, device=device),
        )
    else:
        # float16 has no complex counterpart here, so fall back to real input.
        input_tensor = torch.randn(shape, dtype=dtype, device=device)

    actual = ntops.torch.sgn(input_tensor)
    expected = torch.sgn(input_tensor)

    assert torch.allclose(actual, expected, rtol=rtol, atol=atol)
@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_signbit(shape, dtype, device, rtol, atol):
    """Check ntops.torch.signbit against torch.signbit."""
    input = torch.randn(shape, dtype=dtype, device=device)

    ninetoothed_output = ntops.torch.signbit(input)
    reference_output = torch.signbit(input)

    # Both results are bool tensors: `torch.allclose` subtracts its operands
    # and is not implemented for Bool, so it would raise here.  Boolean masks
    # should be compared exactly with `torch.equal` anyway.
    assert torch.equal(ninetoothed_output, reference_output)