Skip to content
Open
152 changes: 152 additions & 0 deletions tensorflow/core/kernels/conv_ops_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,41 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Maximum tensor size (in bytes) that cuDNN can handle safely.
// cuDNN has internal limits around 2GB for certain operations.
// We use a conservative threshold to avoid CUDA invalid resource handle errors.
constexpr int64_t kMaxCudnnTensorSizeBytes = 2LL * 1024 * 1024 * 1024; // 2GB

// Helper function to check if the tensor size exceeds the safe limit for cuDNN.
// Returns true if the tensor is too large and needs fallback processing.
// Helper function to check if the tensor size exceeds the safe limit for
// cuDNN. Returns true if the tensor is too large and needs fallback
// (batch-split) processing.
//
// Implementation note: we compare element counts against a per-type element
// budget instead of computing `NumElements() * sizeof(T)`. The original
// product mixes a signed int64 with an unsigned size_t, and for pathological
// element counts the multiplication can wrap, silently producing a small
// (wrong) byte size. Dividing the byte threshold by sizeof(T) once cannot
// overflow and keeps the whole comparison in signed 64-bit arithmetic.
template <typename T>
inline bool IsTensorTooLargeForCudnn(const Tensor& tensor) {
  // Largest number of T elements that still fits under the cuDNN byte limit.
  constexpr int64_t kMaxCudnnElements =
      kMaxCudnnTensorSizeBytes / static_cast<int64_t>(sizeof(T));
  return tensor.NumElements() > kMaxCudnnElements;
}

// Helper function to compute the maximum batch size that keeps the tensor
// under the cuDNN size limit.
// Helper function to compute the largest batch size whose slice of `tensor`
// stays under the cuDNN size limit.
//
// Returns a value in [1, current_batch]: degenerate inputs (non-positive
// batch, empty tensor, fewer elements than batches) collapse to 1, and a
// tensor that already fits yields `current_batch` unchanged.
// `data_format` is currently unused but kept for interface stability.
template <typename T>
inline int64_t ComputeSafeBatchSize(const Tensor& tensor, int64_t current_batch,
                                    TensorFormat data_format) {
  if (current_batch <= 0) return 1;
  const int64_t num_elements = tensor.NumElements();
  // Fewer elements than batches (which also covers an empty tensor) means
  // each batch averages under one element; fall back to a batch of 1.
  if (num_elements < current_batch) return 1;
  const int64_t per_batch_elements = num_elements / current_batch;
  if (per_batch_elements <= 0) return 1;
  // How many whole batches fit inside the cuDNN element budget, clamped to
  // the valid range [1, current_batch].
  const int64_t element_budget =
      kMaxCudnnTensorSizeBytes / static_cast<int64_t>(sizeof(T));
  const int64_t candidate = element_budget / per_batch_elements;
  return std::clamp<int64_t>(candidate, int64_t{1}, current_batch);
}

template <typename Device, typename T>
struct LaunchGeneric {
void operator()(OpKernelContext* ctx, const Tensor& input,
Expand Down Expand Up @@ -773,6 +808,123 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune,
absl::InvalidArgumentError("filter must not have zero elements "
"(i.e. all dimensions must be non-zero)"));

// Check if input tensor is too large for cuDNN and needs batch splitting.
// This addresses CUDA invalid resource handle errors with large tensors.
// NOTE(review): the flat-buffer offsets below (`start * elements_per_batch`)
// are only contiguous slices if the batch dimension is outermost in memory.
// That holds for NHWC and NCHW (batch is dim 0 in both) — confirm before
// extending to any other data format.
if (IsTensorTooLargeForCudnn<T>(input) && in_batch > 1) {
int64_t safe_batch = ComputeSafeBatchSize<T>(input, in_batch, data_format);
if (safe_batch < in_batch && safe_batch > 0) {
VLOG(2) << "Input tensor too large for cuDNN, splitting batch from "
<< in_batch << " to chunks of " << safe_batch;

// Locate the batch ('N') dimension for the current layout.
int64_t batch_idx = GetTensorDimIndex(data_format, 'N', input.dims());

// Validate batch dimension before proceeding
OP_REQUIRES(context, batch_idx >= 0 && batch_idx < input.dims(),
absl::InternalError("Invalid batch dimension index"));
OP_REQUIRES(context, input.dim_size(batch_idx) > 0,
absl::InternalError("Input batch dimension is zero"));
OP_REQUIRES(context, output->dim_size(batch_idx) > 0,
absl::InternalError("Output batch dimension is zero"));

// Process the batch in chunks of at most `safe_batch`.
// NOTE(review): the input/output scratch tensors are allocated fresh on
// every iteration; hoisting the two allocate_temp calls out of the loop
// (all chunks except possibly the last share a shape) would reduce
// allocator churn — left as-is here since this is a documentation pass.
for (int64_t start = 0; start < in_batch; start += safe_batch) {
int64_t chunk_size = std::min(safe_batch, in_batch - start);

// Create sliced input tensor
std::vector<int64_t> input_slice_shape;
for (int i = 0; i < input.dims(); ++i) {
if (i == batch_idx) {
input_slice_shape.push_back(chunk_size);
} else {
input_slice_shape.push_back(input.dim_size(i));
}
}
TensorShape input_slice_ts(input_slice_shape);
Tensor input_slice;
OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
input_slice_ts,
&input_slice));

// Create sliced output tensor
std::vector<int64_t> output_slice_shape;
for (int i = 0; i < output->dims(); ++i) {
if (i == batch_idx) {
output_slice_shape.push_back(chunk_size);
} else {
output_slice_shape.push_back(output->dim_size(i));
}
}
TensorShape output_slice_ts(output_slice_shape);
Tensor output_slice;
OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
output_slice_ts,
&output_slice));

// Calculate elements per batch with validated dimensions
int64_t input_batch_dim = input.dim_size(batch_idx);
int64_t elements_per_batch = input.NumElements() / input_batch_dim;

// Validate bounds before pointer arithmetic
int64_t input_offset = start * elements_per_batch;
OP_REQUIRES(context, input_offset + chunk_size * elements_per_batch <=
input.NumElements(),
absl::InternalError("Input slice bounds check failed"));

// Copy input slice from input tensor (device to device).
// NOTE(review): `stream` is presumably the op's compute stream obtained
// earlier in LaunchConvOpImpl (outside this excerpt). MemcpyD2D is
// asynchronous; correctness here appears to rely on the copy, the conv
// launch below, and the output copy-back all being enqueued on the same
// stream so they execute in order — confirm against the enclosing code.
int64_t copy_size_bytes = chunk_size * elements_per_batch * sizeof(T);
auto src_ptr = se::DeviceMemoryBase(
const_cast<T*>(input.template flat<T>().data() + input_offset),
copy_size_bytes);
auto dst_ptr = se::DeviceMemoryBase(
const_cast<T*>(input_slice.template flat<T>().data()),
copy_size_bytes);
OP_REQUIRES_OK(context,
stream->MemcpyD2D(&dst_ptr, src_ptr, copy_size_bytes));

// Recursively call LaunchConvOpImpl with the smaller batch.
// Safety note: The recursive call is guaranteed not to re-enter this
// batch-splitting code path because:
// 1. safe_batch is computed to keep sliced tensors under the size limit
// 2. IsTensorTooLargeForCudnn will return false for the sliced tensor
// 3. Even if it were to trigger, in_batch would equal chunk_size,
// and safe_batch would equal chunk_size, so the condition
// "safe_batch < in_batch" would be false
LaunchConvOpImpl<T>(context, cudnn_use_autotune, input_slice, filter,
dilations, strides, padding, explicit_paddings,
data_format, &output_slice);

// Check for errors from recursive call; context status is set by
// OP_REQUIRES failures inside the recursion.
if (!context->status().ok()) return;

// Calculate output elements per batch with validated dimensions
int64_t output_batch_dim = output->dim_size(batch_idx);
int64_t output_elements_per_batch =
output->NumElements() / output_batch_dim;

// Validate bounds before pointer arithmetic
int64_t output_offset = start * output_elements_per_batch;
OP_REQUIRES(
context,
output_offset + chunk_size * output_elements_per_batch <=
output->NumElements(),
absl::InternalError("Output slice bounds check failed"));

// Copy output slice back into the full output tensor (device to device).
int64_t output_copy_size_bytes =
chunk_size * output_elements_per_batch * sizeof(T);
auto out_src_ptr = se::DeviceMemoryBase(
const_cast<T*>(output_slice.template flat<T>().data()),
output_copy_size_bytes);
auto out_dst_ptr = se::DeviceMemoryBase(
const_cast<T*>(output->template flat<T>().data() + output_offset),
output_copy_size_bytes);
OP_REQUIRES_OK(context, stream->MemcpyD2D(&out_dst_ptr, out_src_ptr,
output_copy_size_bytes));
}
// All chunks processed; the un-split cuDNN path below must not run.
return;
}
}

bool is_grouped_convolution = filter_depth != in_depth;
// check if filter is 1x1 and stride/dilation are all ones
bool one_filter = true;
Expand Down
1 change: 1 addition & 0 deletions tensorflow/core/kernels/dense_update_functor_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ struct DenseUpdate<GPUDevice, T, SUB> {
template struct functor::DenseUpdate<GPUDevice, T, SUB>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
TF_CALL_INTEGRAL_TYPES(DEFINE_GPU_KERNELS);
TF_CALL_COMPLEX_TYPES(DEFINE_GPU_KERNELS);
TF_CALL_float8_e5m2(DEFINE_GPU_KERNELS);
TF_CALL_float8_e4m3fn(DEFINE_GPU_KERNELS);
#undef DEFINE_GPU_KERNELS
Expand Down
1 change: 1 addition & 0 deletions tensorflow/core/kernels/resource_variable_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,7 @@ TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_GPU_KERNELS);
TF_CALL_COMPLEX_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1864,5 +1864,64 @@ def testGatherBatchDimsNeg(self):
)
self.evaluate(result)

@test_util.run_in_graph_and_eager_modes
@test_util.run_gpu_only
def testComplexVariableAssignAddWithConj(self):
  """Regression test for issue #105367.

  assign_add on a complex ResourceVariable with the output of Conj used to
  segfault on GPU; verify the update also produces the right values for
  both complex dtypes.
  """
  for tf_dtype, np_dtype in ((dtypes.complex64, np.complex64),
                             (dtypes.complex128, np.complex128)):
    initial = constant_op.constant([1 + 2j, 3 + 4j], dtype=tf_dtype)
    var = resource_variable_ops.ResourceVariable(initial, dtype=tf_dtype)
    self.evaluate(var.initializer)

    conjugated = math_ops.conj(initial)
    result = self.evaluate(var.assign_add(conjugated))

    # [1+2j, 3+4j] + [1-2j, 3-4j] = [2+0j, 6+0j]
    expected = np.array([2 + 0j, 6 + 0j], dtype=np_dtype)
    self.assertAllClose(result, expected)

@test_util.run_in_graph_and_eager_modes
def testComplexVariableAssignAddCPU(self):
  """assign_add on a complex ResourceVariable computes correct values on CPU."""
  for tf_dtype, np_dtype in ((dtypes.complex64, np.complex64),
                             (dtypes.complex128, np.complex128)):
    initial = constant_op.constant([1 + 2j, 3 + 4j], dtype=tf_dtype)
    var = resource_variable_ops.ResourceVariable(initial, dtype=tf_dtype)
    self.evaluate(var.initializer)

    delta = constant_op.constant([0.5 - 1j, 1 + 0.5j], dtype=tf_dtype)
    result = self.evaluate(var.assign_add(delta))

    # [1+2j, 3+4j] + [0.5-1j, 1+0.5j] = [1.5+1j, 4+4.5j]
    expected = np.array([1.5 + 1j, 4 + 4.5j], dtype=np_dtype)
    self.assertAllClose(result, expected)

# Run the test suite when this file is executed directly.
if __name__ == "__main__":
  test.main()
Loading