diff --git a/tensorflow/core/kernels/conv_ops_impl.h b/tensorflow/core/kernels/conv_ops_impl.h
index 0d3fc798bbe3c2..e4a80a1524e19a 100644
--- a/tensorflow/core/kernels/conv_ops_impl.h
+++ b/tensorflow/core/kernels/conv_ops_impl.h
@@ -90,6 +90,41 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
+// Maximum tensor size (in bytes) that cuDNN can handle safely.
+// cuDNN has internal limits around 2GB for certain operations.
+// We use a conservative threshold to avoid CUDA invalid resource handle errors.
+constexpr int64_t kMaxCudnnTensorSizeBytes = 2LL * 1024 * 1024 * 1024;  // 2GB
+
+// Helper function to check if the tensor size exceeds the safe limit for cuDNN.
+// Returns true if the tensor is too large and needs fallback processing.
+template <typename T>
+inline bool IsTensorTooLargeForCudnn(const Tensor& tensor) {
+  int64_t tensor_size_bytes = tensor.NumElements() * sizeof(T);
+  return tensor_size_bytes > kMaxCudnnTensorSizeBytes;
+}
+
+// Helper function to compute the maximum batch size that keeps the tensor
+// under the cuDNN size limit.
+template <typename T>
+inline int64_t ComputeSafeBatchSize(const Tensor& tensor, int64_t current_batch,
+                                    TensorFormat data_format) {
+  if (current_batch <= 0) return 1;
+  int64_t total_elements = tensor.NumElements();
+  if (total_elements <= 0) return 1;
+  // Handle edge case where total_elements < current_batch
+  if (total_elements < current_batch) {
+    // Each batch has less than 1 element on average, return 1
+    return 1;
+  }
+  int64_t elements_per_batch = total_elements / current_batch;
+  if (elements_per_batch <= 0) return 1;
+  int64_t max_elements = kMaxCudnnTensorSizeBytes / sizeof(T);
+  int64_t safe_batch = max_elements / elements_per_batch;
+  // Ensure at least batch size of 1, and cap at current batch size
+  return std::max(static_cast<int64_t>(1),
+                  std::min(safe_batch, current_batch));
+}
+
 template <typename Device, typename T>
 struct LaunchGeneric {
   void operator()(OpKernelContext* ctx, const Tensor& input,
@@ -773,6 +808,123 @@ void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune,
       absl::InvalidArgumentError("filter must not have zero elements "
                                  "(i.e. all dimensions must be non-zero)"));
 
+  // Check if input tensor is too large for cuDNN and needs batch splitting.
+  // This addresses CUDA invalid resource handle errors with large tensors.
+  if (IsTensorTooLargeForCudnn<T>(input) && in_batch > 1) {
+    int64_t safe_batch = ComputeSafeBatchSize<T>(input, in_batch, data_format);
+    if (safe_batch < in_batch && safe_batch > 0) {
+      VLOG(2) << "Input tensor too large for cuDNN, splitting batch from "
+              << in_batch << " to chunks of " << safe_batch;
+
+      // Process in batches to avoid cuDNN memory limits
+      int64_t batch_idx = GetTensorDimIndex(data_format, 'N', input.dims());
+
+      // Validate batch dimension before proceeding
+      OP_REQUIRES(context, batch_idx >= 0 && batch_idx < input.dims(),
+                  absl::InternalError("Invalid batch dimension index"));
+      OP_REQUIRES(context, input.dim_size(batch_idx) > 0,
+                  absl::InternalError("Input batch dimension is zero"));
+      OP_REQUIRES(context, output->dim_size(batch_idx) > 0,
+                  absl::InternalError("Output batch dimension is zero"));
+
+      for (int64_t start = 0; start < in_batch; start += safe_batch) {
+        int64_t chunk_size = std::min(safe_batch, in_batch - start);
+
+        // Create sliced input tensor
+        std::vector<int64_t> input_slice_shape;
+        for (int i = 0; i < input.dims(); ++i) {
+          if (i == batch_idx) {
+            input_slice_shape.push_back(chunk_size);
+          } else {
+            input_slice_shape.push_back(input.dim_size(i));
+          }
+        }
+        TensorShape input_slice_ts(input_slice_shape);
+        Tensor input_slice;
+        OP_REQUIRES_OK(context,
+                       context->allocate_temp(DataTypeToEnum<T>::value,
+                                              input_slice_ts, &input_slice));
+
+        // Create sliced output tensor
+        std::vector<int64_t> output_slice_shape;
+        for (int i = 0; i < output->dims(); ++i) {
+          if (i == batch_idx) {
+            output_slice_shape.push_back(chunk_size);
+          } else {
+            output_slice_shape.push_back(output->dim_size(i));
+          }
+        }
+        TensorShape output_slice_ts(output_slice_shape);
+        Tensor output_slice;
+        OP_REQUIRES_OK(context,
+                       context->allocate_temp(DataTypeToEnum<T>::value,
+                                              output_slice_ts, &output_slice));
+
+        // Calculate elements per batch with validated dimensions
+        int64_t input_batch_dim = input.dim_size(batch_idx);
+        int64_t elements_per_batch = input.NumElements() / input_batch_dim;
+
+        // Validate bounds before pointer arithmetic
+        int64_t input_offset = start * elements_per_batch;
+        OP_REQUIRES(context,
+                    input_offset + chunk_size * elements_per_batch <=
+                        input.NumElements(),
+                    absl::InternalError("Input slice bounds check failed"));
+
+        // Copy input slice from input tensor (device to device)
+        int64_t copy_size_bytes = chunk_size * elements_per_batch * sizeof(T);
+        auto src_ptr = se::DeviceMemoryBase(
+            const_cast<T*>(input.template flat<T>().data() + input_offset),
+            copy_size_bytes);
+        auto dst_ptr = se::DeviceMemoryBase(
+            input_slice.template flat<T>().data(), copy_size_bytes);
+        OP_REQUIRES_OK(context,
+                       stream->MemcpyD2D(&dst_ptr, src_ptr, copy_size_bytes));
+
+        // Recursively call LaunchConvOpImpl with the smaller batch.
+        // Safety note: The recursive call is guaranteed not to re-enter this
+        // batch-splitting code path because:
+        // 1. safe_batch is computed to keep sliced tensors under the size limit
+        // 2. IsTensorTooLargeForCudnn will return false for the sliced tensor
+        // 3. Even if it were to trigger, in_batch would equal chunk_size,
+        //    and safe_batch would equal chunk_size, so the condition
+        //    "safe_batch < in_batch" would be false
+        LaunchConvOpImpl<T>(context, cudnn_use_autotune, input_slice, filter,
+                            dilations, strides, padding, explicit_paddings,
+                            data_format, &output_slice);
+
+        // Check for errors from recursive call
+        if (!context->status().ok()) return;
+
+        // Calculate output elements per batch with validated dimensions
+        int64_t output_batch_dim = output->dim_size(batch_idx);
+        int64_t output_elements_per_batch =
+            output->NumElements() / output_batch_dim;
+
+        // Validate bounds before pointer arithmetic
+        int64_t output_offset = start * output_elements_per_batch;
+        OP_REQUIRES(
+            context,
+            output_offset + chunk_size * output_elements_per_batch <=
+                output->NumElements(),
+            absl::InternalError("Output slice bounds check failed"));
+
+        // Copy output slice to output tensor (device to device)
+        int64_t output_copy_size_bytes =
+            chunk_size * output_elements_per_batch * sizeof(T);
+        auto out_src_ptr = se::DeviceMemoryBase(
+            output_slice.template flat<T>().data(), output_copy_size_bytes);
+        auto out_dst_ptr = se::DeviceMemoryBase(
+            output->template flat<T>().data() + output_offset,
+            output_copy_size_bytes);
+        OP_REQUIRES_OK(context, stream->MemcpyD2D(&out_dst_ptr, out_src_ptr,
+                                                  output_copy_size_bytes));
+      }
+      return;
+    }
+  }
+
   bool is_grouped_convolution = filter_depth != in_depth;
   // check if filter is 1x1 and stride/dilation are all ones
   bool one_filter = true;
diff --git a/tensorflow/python/util/mixed_dict_keys_test.py b/tensorflow/python/util/mixed_dict_keys_test.py
new file mode 100644
index 00000000000000..d742457ea7f99d
--- /dev/null
+++ b/tensorflow/python/util/mixed_dict_keys_test.py
@@ -0,0 +1,182 @@
+from tensorflow.python.framework import constant_op
+from tensorflow.python.util import nest_util
+# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for XLA JIT compilation with mixed-type dictionary keys.
+
+This test validates the fix for issue #105333 where @tf.function(jit_compile=True)
+fails when returning dictionaries with mixed key types (e.g., strings and integers).
+"""
+
+import tensorflow as tf
+
+from tensorflow.python.platform import test
+from tensorflow.python.util import nest
+
+
+class XLAMixedDictKeysTest(test.TestCase):
+  """Test XLA JIT compilation with mixed-type dictionary keys."""
+
+  def test_mixed_string_int_keys_flatten(self):
+    """Test flattening dict with mixed string and int keys."""
+    mixed_dict = {'string_key': 1, 123: 2, 'another': 3, 456: 4}
+    flattened = nest.flatten(mixed_dict)
+    # Should flatten successfully with deterministic order
+    # Keys sorted by type name first (int < str), then by value
+    self.assertEqual(len(flattened), 4)
+    self.assertIn(1, flattened)
+    self.assertIn(2, flattened)
+    self.assertIn(3, flattened)
+    self.assertIn(4, flattened)
+
+  def test_mixed_keys_with_xla_simple(self):
+    """Test simple XLA function with mixed dict keys."""
+    @tf.function(jit_compile=True)
+    def simple_mixed_dict(x):
+      results = {}
+      results['string_key'] = x
+      results[123] = x + 1
+      return results
+
+    input_tensor = constant_op.constant([1.0, 2.0, 3.0])
+    output = simple_mixed_dict(input_tensor)
+
+    self.assertIn('string_key', output)
+    self.assertIn(123, output)
+    self.assertAllClose(output['string_key'], [1.0, 2.0, 3.0])
+    self.assertAllClose(output[123], [2.0, 3.0, 4.0])
+
+  def test_mixed_keys_with_xla_in_model(self):
+    """Test XLA with mixed dict keys in Keras model (original issue #105333)."""
+    class SimpleModel(tf.keras.Model):
+      @tf.function(jit_compile=True)
+      def call(self, x):
+        results = {}
+        results['string_key'] = x
+        results[123] = x + 1
+        return x, results
+
+    model = SimpleModel()
+    input_tensor = tf.random.normal([2, 16, 16, 16, 32])
+    output_tensor, output_dict = model(input_tensor)
+
+    self.assertEqual(output_tensor.shape, (2, 16, 16, 16, 32))
+    self.assertIn('string_key', output_dict)
+    self.assertIn(123, output_dict)
+
+  def test_multiple_mixed_types(self):
+    """Test dict with multiple mixed key types."""
+    @tf.function(jit_compile=True)
+    def multi_type_dict(x):
+      results = {}
+      results['str1'] = x
+      results[1] = x + 1
+      results['str2'] = x + 2
+      results[2] = x + 3
+      results[3] = x + 4
+      results['str3'] = x + 5
+      return results
+
+    input_tensor = constant_op.constant(10.0)
+    output = multi_type_dict(input_tensor)
+
+    # Verify all keys are present
+    self.assertIn('str1', output)
+    self.assertIn('str2', output)
+    self.assertIn('str3', output)
+    self.assertIn(1, output)
+    self.assertIn(2, output)
+    self.assertIn(3, output)
+
+    # Verify values
+    self.assertAlmostEqual(output['str1'].numpy(), 10.0)
+    self.assertAlmostEqual(output[1].numpy(), 11.0)
+    self.assertAlmostEqual(output['str2'].numpy(), 12.0)
+    self.assertAlmostEqual(output[2].numpy(), 13.0)
+
+  def test_nested_mixed_keys(self):
+    """Test nested dicts with mixed keys."""
+    @tf.function(jit_compile=True)
+    def nested_mixed_dict(x):
+      inner = {
+          'inner_str': x,
+          100: x + 1
+      }
+      outer = {
+          'outer': inner,
+          200: x + 2
+      }
+      return outer
+
+    input_tensor = constant_op.constant(5.0)
+    output = nested_mixed_dict(input_tensor)
+
+    self.assertIn('outer', output)
+    self.assertIn(200, output)
+    self.assertIn('inner_str', output['outer'])
+    self.assertIn(100, output['outer'])
+
+  def test_pack_sequence_as_with_mixed_keys(self):
+    """Test pack_sequence_as with mixed key types."""
+    structure = {'a': 1, 10: 2, 'b': 3, 20: 4}
+    flat_sequence = [100, 200, 300, 400]
+
+    packed = nest.pack_sequence_as(structure, flat_sequence)
+
+    # Verify repacking works correctly
+    self.assertEqual(len(packed), 4)
+    # Values should be assigned in sorted key order (int keys first, then str keys)
+
+  def test_without_xla_still_works(self):
+    """Verify mixed keys work without XLA as well."""
+    @tf.function(jit_compile=False)
+    def no_xla_mixed_dict(x):
+      results = {}
+      results['string_key'] = x
+      results[123] = x + 1
+      return results
+
+    input_tensor = constant_op.constant([1.0, 2.0])
+    output = no_xla_mixed_dict(input_tensor)
+
+    self.assertIn('string_key', output)
+    self.assertIn(123, output)
+
+  def test_consistent_ordering(self):
+    """Ensure consistent ordering across multiple calls."""
+    @tf.function(jit_compile=True)
+    def consistent_dict(x):
+      results = {}
+      results['z'] = x
+      results[3] = x + 1
+      results['a'] = x + 2
+      results[1] = x + 3
+      return results
+
+    input_tensor = constant_op.constant(1.0)
+
+    # Call multiple times and verify same order
+    output1 = consistent_dict(input_tensor)
+    output2 = consistent_dict(input_tensor)
+    output3 = consistent_dict(input_tensor)
+
+    keys1 = sorted(output1.keys(), key=lambda x: (type(x).__name__, x))
+    keys2 = sorted(output2.keys(), key=lambda x: (type(x).__name__, x))
+    keys3 = sorted(output3.keys(), key=lambda x: (type(x).__name__, x))
+
+    self.assertEqual(keys1, keys2)
+    self.assertEqual(keys2, keys3)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/util/nest_util.py b/tensorflow/python/util/nest_util.py
index 54f8cf1026a0e0..3c6279630aeddb 100644
--- a/tensorflow/python/util/nest_util.py
+++ b/tensorflow/python/util/nest_util.py
@@ -272,8 +272,14 @@ def _tf_core_sorted(dict_):
   try:
     return sorted(dict_.keys())
   except TypeError:
-    # pylint: disable=raise-missing-from
-    raise TypeError("nest only supports dicts with sortable keys.")
+    # If direct sorting fails (e.g., mixed types like int and str),
+    # try sorting by (type name, key) to group by type first, then by value
+    try:
+      return sorted(dict_.keys(), key=lambda x: (type(x).__name__, x))
+    except TypeError:
+      # If that still fails, fall back to sorting by string representation
+      # This ensures deterministic ordering even with complex mixed types
+      return sorted(dict_.keys(), key=lambda x: (type(x).__name__, str(x)))
 
 
 def _tf_data_sorted(dict_):
@@ -281,10 +287,14 @@ def _tf_data_sorted(dict_):
   try:
     return sorted(list(dict_))
   except TypeError as e:
-    # pylint: disable=raise-missing-from
-    raise TypeError(
-        f"nest only supports dicts with sortable keys. Error: {e.message}"
-    )
+    # If direct sorting fails (e.g., mixed types like int and str),
+    # try sorting by (type name, key) to group by type first, then by value
+    try:
+      return sorted(list(dict_), key=lambda x: (type(x).__name__, x))
+    except TypeError:
+      # If that still fails, fall back to sorting by string representation
+      # This ensures deterministic ordering even with complex mixed types
+      return sorted(list(dict_), key=lambda x: (type(x).__name__, str(x)))
 
 
 def yield_value(modality, iterable):