diff --git a/.bazelrc b/.bazelrc index 843c0aac12b80e..a42ff862e855c9 100644 --- a/.bazelrc +++ b/.bazelrc @@ -302,9 +302,11 @@ common:cuda --@local_config_cuda//:enable_cuda common:cuda --config=cuda_version # This flag is needed to include CUDA libraries. common:cuda --@local_config_cuda//cuda:include_cuda_libs=true +common:cuda --@cuda_driver//:include_cuda_umd_libs=true # This configuration is used for building the wheels. common:cuda_wheel --@local_config_cuda//cuda:include_cuda_libs=false +common:cuda_wheel --@cuda_driver//:include_cuda_umd_libs=false # CUDA: This config refers to building CUDA op kernels with clang. common:cuda_clang --config=cuda @@ -596,7 +598,6 @@ common:use_tar_archive_files --repo_env=USE_LLVM_TAR_ARCHIVE_FILES=1 common:use_tar_archive_files --repo_env=USE_MIRRORED_TAR_ARCHIVE_FILES=1 # Make Bazel not try to probe the host system for a C++ toolchain. -common:rbe_base --config=use_tar_archive_files common:rbe_base --config=resultstore common:rbe_base --repo_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1 common:rbe_base --define=EXECUTOR=remote @@ -639,8 +640,8 @@ common:rbe_linux_cpu --remote_instance_name=projects/tensorflow-testing/instance # Download CUDA/CUDNN redistributions to preserve the repositories cache between # CPU and GPU builds. # TODO(ybaturina): Uncomment when RBE is ready to support this. -commonld:rbe_linux_cpu --repo_env USE_CUDA_REDISTRIBUTIONS=1 -commonld:rbe_linux_cpu --config=cuda_version +common:rbe_linux_cpu --repo_env USE_CUDA_REDISTRIBUTIONS=1 +common:rbe_linux_cpu --config=cuda_version # Deprecated RBE config with non-hermetic toolchains. 
common:rbe_linux_cpu_clang_local --config=rbe_linux_cpu @@ -666,9 +667,6 @@ common:rbe_linux_cuda --config=cuda_clang_official common:rbe_linux_cuda --config=rbe_linux_cpu # For Remote build execution -- GPU configuration common:rbe_linux_cuda --repo_env=REMOTE_GPU_TESTING=1 -# Enable forward compatibility for CUDA builds because RBE docker image doesn't -# have latest CUDA drivers installed. -common:rbe_linux_cuda --@cuda_driver//:enable_forward_compatibility=true common:rbe_linux_cuda_nvcc --config=rbe_linux_cuda common:rbe_linux_cuda_nvcc --config=cuda_nvcc @@ -861,7 +859,7 @@ test:linux_cpu_wheel_test --@local_xla//third_party/py:wheel_dependency=true --c test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310,-no_oss_py313 test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310,-no_oss_py313 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cuda_wheel_test --@local_xla//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_gpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_wheel_test --repo_env=HERMETIC_CUDA_UMD_VERSION=12.8.1 --@local_xla//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_gpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... 
# ARM64 WHEEL test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310,-no_oss_py313 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310,-no_oss_py313 diff --git a/.bazelversion b/.bazelversion index 5c733d6c13a497..26c75fe8ad4fc9 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1,2 +1,2 @@ -7.4.1 +7.7.0 # NOTE: Update Bazel version in tensorflow/tools/ci_build/release/common.sh.oss \ No newline at end of file diff --git a/.github/workflows/osv-scanner-scheduled.yml b/.github/workflows/osv-scanner-scheduled.yml index c0682a4cac7035..07896a48470753 100644 --- a/.github/workflows/osv-scanner-scheduled.yml +++ b/.github/workflows/osv-scanner-scheduled.yml @@ -28,7 +28,7 @@ permissions: jobs: scan-scheduled: if: github.repository == 'tensorflow/tensorflow' - uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v2.2.3" + uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v2.2.4" with: scan-args: |- --lockfile=requirements.txt:./requirements_lock_3_9.txt diff --git a/.github/workflows/scorecards-analysis.yml b/.github/workflows/scorecards-analysis.yml index 75339c6b4f6bd7..e635c4cd8ccc88 100644 --- a/.github/workflows/scorecards-analysis.yml +++ b/.github/workflows/scorecards-analysis.yml @@ -55,7 +55,7 @@ jobs: # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF # format to the repository Actions tab. 
- name: "Upload artifact" - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 with: name: SARIF file path: results.sarif @@ -64,6 +64,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard (optional). # Commenting out will disable upload of results to your repo's Code Scanning dashboard - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@3599b3baa15b485a2e49ef411a7a4bb2452e7f93 # v3.29.5 + uses: github/codeql-action/upload-sarif@0499de31b99561a6d14a36a5f662c2a54f91beee # v3.29.5 with: sarif_file: results.sarif diff --git a/.github/workflows/stale-issues.yml b/.github/workflows/stale-issues.yml index d9408810eb32ac..53f272bd5b9d8a 100644 --- a/.github/workflows/stale-issues.yml +++ b/.github/workflows/stale-issues.yml @@ -31,7 +31,7 @@ jobs: pull-requests: write steps: - name: Awaiting response issues - uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v10.0.0 + uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0 with: #Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale exempt-issue-labels: 'override-stale' @@ -59,7 +59,7 @@ jobs: close-pr-message: "This PR was closed because it has been inactive for 14 days since being marked as stale. Please reopen if you'd like to work on this further." 
repo-token: ${{ secrets.GITHUB_TOKEN }} - name: Contribution issues - uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v10.0.0 + uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0 with: #Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale exempt-issue-labels: 'override-stale' diff --git a/RELEASE.md b/RELEASE.md index 7ac60de2539cc0..6255a4a1d8679e 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -22,6 +22,10 @@ * `tf.lite` * Adds int8 and int16x8 support for SQRT operator. * Adds int16x8 support for EQUAL and NOT_EQUAL operators. + * Adds support for int2 type. + * Adds support for int2/int4 in tfl.cast . + * Adds support for SRQ int2 in tfl.fully_connected. + * Adds support for int4 in tfl.slice. ### Bug Fixes and Other Changes diff --git a/ci/official/containers/ml_build/Dockerfile b/ci/official/containers/ml_build/Dockerfile index d12c886cc6d57a..a4fb0cd9b1640a 100644 --- a/ci/official/containers/ml_build/Dockerfile +++ b/ci/official/containers/ml_build/Dockerfile @@ -12,14 +12,6 @@ COPY builder.packages.txt /builder.packages.txt RUN /setup.sources.sh && /setup.packages.sh /builder.packages.txt -# Install devtoolset-9 in /dt9 with glibc 2.17 and libstdc++ 4.8, for building -# manylinux2014-compatible packages. 
-COPY builder.devtoolset/fixlinks.sh /fixlinks.sh -COPY builder.devtoolset/rpm-patch.sh /rpm-patch.sh -COPY builder.devtoolset/build_devtoolset.sh /build_devtoolset.sh -COPY builder.devtoolset/glibc2.17-inline.patch /glibc2.17-inline.patch -RUN /build_devtoolset.sh devtoolset-9 /dt9 - # Setup Python COPY setup.python.sh /setup.python.sh COPY builder.requirements.txt /builder.requirements.txt @@ -56,9 +48,6 @@ RUN ln -sf /usr/bin/python3.12 /usr/bin/python3 RUN ln -sf /usr/bin/python3.12 /usr/bin/python RUN ln -sf /usr/lib/python3.12 /usr/lib/tf_python -# Make sure clang is on the path -RUN ln -s /usr/lib/llvm-18/bin/clang /usr/bin/clang - # Link the compat driver to the location if available. RUN if [ -e "/usr/local/cuda/compat/libcuda.so.1" ]; then ln -s /usr/local/cuda/compat/libcuda.so.1 /usr/lib/x86_64-linux-gnu/libcuda.so.1; fi diff --git a/ci/official/containers/ml_build/builder.packages.txt b/ci/official/containers/ml_build/builder.packages.txt index 8dbbf4196440da..cf914a0425ef11 100644 --- a/ci/official/containers/ml_build/builder.packages.txt +++ b/ci/official/containers/ml_build/builder.packages.txt @@ -1,28 +1,9 @@ -# Packages to be installed for the new Docker image. - -# Packages needed to build devtoolset -file -flex -g++ -make -patch -rpm2cpio -unar -wget -xz-utils -cpio - # Other build-related tools apt-transport-https autoconf automake build-essential ca-certificates -llvm-18 -clang-18 -clang-tidy-18 -lld-18 -clang-format-12 curl git parallel @@ -32,4 +13,6 @@ unzip zip openjdk-21-jdk vim +wget jq +file diff --git a/ci/official/containers/ml_build/builder.requirements.txt b/ci/official/containers/ml_build/builder.requirements.txt index 114efaf9dc9757..ae113c68c2f03c 100644 --- a/ci/official/containers/ml_build/builder.requirements.txt +++ b/ci/official/containers/ml_build/builder.requirements.txt @@ -5,6 +5,9 @@ id urllib3 requests +# For XLA +pyyaml + # For JAX build ~= 1.2.2 # uv is faster than pip for installing Python packages. 
diff --git a/ci/official/containers/ml_build/cuda13.0_cudnn9.15.packages.txt b/ci/official/containers/ml_build/cuda13.0_cudnn9.15.packages.txt new file mode 100644 index 00000000000000..dcc171ac5af019 --- /dev/null +++ b/ci/official/containers/ml_build/cuda13.0_cudnn9.15.packages.txt @@ -0,0 +1,23 @@ +# All required CUDA packages +cuda-compat-13-0 +cuda-command-line-tools-13-0 +cuda-cudart-dev-13-0 +cuda-nvcc-13-0 +cuda-cupti-13-0 +cuda-nvprune-13-0 +cuda-libraries-13-0 +cuda-libraries-dev-13-0 +cuda-nvml-dev-13-0 +libcufft-13-0 +libcurand-13-0 +libcusolver-dev-13-0 +libcusparse-dev-13-0 +libcublas-13-0 +libcublas-dev-13-0 +libnccl-dev=2.27.7-1+cuda13.0 +libnccl2=2.27.7-1+cuda13.0 +# CuDNN: https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#ubuntu-network-installation +libcudnn9-headers-cuda-13=9.15.1.9-1 +libcudnn9-static-cuda-13=9.15.1.9-1 +libcudnn9-dev-cuda-13=9.15.1.9-1 +libcudnn9-cuda-13=9.15.1.9-1 \ No newline at end of file diff --git a/ci/official/containers/ml_build/setup.python.sh b/ci/official/containers/ml_build/setup.python.sh index cd56f3ca552d0f..b849457420f522 100755 --- a/ci/official/containers/ml_build/setup.python.sh +++ b/ci/official/containers/ml_build/setup.python.sh @@ -45,16 +45,6 @@ fi /setup.packages.sh pythons.txt -# Re-link pyconfig.h from x86_64-linux-gnu into the devtoolset directory -# for any Python version present -pushd /usr/include/x86_64-linux-gnu -for f in $(ls | grep python); do - # set up symlink for devtoolset-9 - rm -f /dt9/usr/include/x86_64-linux-gnu/$f - ln -s /usr/include/x86_64-linux-gnu/$f /dt9/usr/include/x86_64-linux-gnu/$f -done -popd - # Python 3.10 include headers fix: # sysconfig.get_path('include') incorrectly points to /usr/local/include/python # map /usr/include/python3.10 to /usr/local/include/python3.10 diff --git a/ci/official/envs/linux_arm64 b/ci/official/envs/linux_arm64 index 52aa80518b4b9c..026cc1bee85bf7 100644 --- a/ci/official/envs/linux_arm64 +++ b/ci/official/envs/linux_arm64 @@ 
-28,5 +28,5 @@ TFCI_OUTPUT_DIR=build_output TFCI_WHL_AUDIT_ENABLE=1 TFCI_WHL_AUDIT_PLAT=manylinux2014_aarch64 TFCI_WHL_BAZEL_TEST_ENABLE=1 -TFCI_WHL_SIZE_LIMIT=265M +TFCI_WHL_SIZE_LIMIT=270M TFCI_WHL_SIZE_LIMIT_ENABLE=1 diff --git a/ci/official/envs/windows_x86_2022 b/ci/official/envs/windows_x86_2022 index 56187ad78eca17..3c57bcfb8114ee 100644 --- a/ci/official/envs/windows_x86_2022 +++ b/ci/official/envs/windows_x86_2022 @@ -15,7 +15,7 @@ TFCI_DOCKER_ENABLE=1 TFCI_DOCKER_PULL_ENABLE=1 TFCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/tf-win2022@sha256:915cb093630432c38b028f56bd31116a5559ebbc688d427b6092d86828ae03bc" -TFCI_BAZEL_BAZELRC_ARGS="--output_user_root=C:/t" +TFCI_BAZEL_BAZELRC_ARGS="--output_user_root=C:/x" TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --repo_env=USE_PYWRAP_RULES=True --config=windows_x86_cpu_2022" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=windows_x86_cpu_2022 TFCI_BUILD_PIP_PACKAGE_WHEEL_NAME_ARG="--repo_env=WHEEL_NAME=tensorflow" diff --git a/ci/official/requirements_updater/numpy1_requirements/requirements.in b/ci/official/requirements_updater/numpy1_requirements/requirements.in index c6a88054433ec0..a24dc1a57e3683 100644 --- a/ci/official/requirements_updater/numpy1_requirements/requirements.in +++ b/ci/official/requirements_updater/numpy1_requirements/requirements.in @@ -1,7 +1,7 @@ # Requirements for NumPy 1.x numpy ~= 1.26.0 wheel ~= 0.41.2 -h5py >= 3.11.0 +h5py >= 3.11.0, < 3.15.0 lit ~= 17.0.2 opt_einsum == 3.3.0 astunparse == 1.6.3 diff --git a/ci/official/requirements_updater/requirements.in b/ci/official/requirements_updater/requirements.in index 2a1fb43664c408..86d5526834753f 100644 --- a/ci/official/requirements_updater/requirements.in +++ b/ci/official/requirements_updater/requirements.in @@ -1,7 +1,7 @@ # Note that numpy 2.1.0 does not support python 3.9 numpy >= 2.0.0, < 2.2.0 wheel ~= 0.41.2 -h5py >= 3.11.0 +h5py >= 3.11.0, < 3.15.0 lit ~= 17.0.2 opt_einsum == 3.3.0 astunparse == 1.6.3 diff 
--git a/ci/official/utilities/setup_docker.sh b/ci/official/utilities/setup_docker.sh index d928272d5ae1a3..03f49d85797225 100755 --- a/ci/official/utilities/setup_docker.sh +++ b/ci/official/utilities/setup_docker.sh @@ -62,6 +62,12 @@ if ! docker container inspect tf >/dev/null 2>&1 ; then # Additional setup is contained in ci/official/envs/rbe. CONTAINER_IP_ADDR=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' tf) netsh advfirewall firewall add rule name="Allow Metadata Proxy" dir=in action=allow protocol=TCP localport=80 remoteip="$CONTAINER_IP_ADDR" + + # Stop non-essential indexing and link tracking services that + # may lock new files or symlinks. + # They may be causing sporadic "Permission denied" errors during Bazel builds. + # b/461500885 + docker exec tf powershell -NoProfile -Command 'Stop-Service -Name SysMain,DiagTrack -Force -ErrorAction SilentlyContinue' fi fi diff --git a/tensorflow/BUILD b/tensorflow/BUILD index f000821983b779..558b59368e615b 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -1033,6 +1033,7 @@ package_group( "//tensorflow_models/google/recml/...", "//third_party/cloud_tpu/convergence_tools/sdc_monitoring/...", "//third_party/cloud_tpu/inference_converter/...", + "//third_party/pathways/...", "//third_party/py/cloud_ml_autoflow/...", "//third_party/py/envlogger/...", "//third_party/py/gldm/...", @@ -1180,38 +1181,31 @@ tf_cc_shared_library( linkstatic = 1, per_os_targets = True, roots = [ - "//tensorflow/c/experimental/filesystem:filesystem_interface", - "//tensorflow/c/experimental/stream_executor:stream_executor", - "//tensorflow/c:env", - "//tensorflow/c:kernels", - "//tensorflow/c:kernels_experimental", - "//tensorflow/c:logging", - "//tensorflow/c:ops", - "//tensorflow/cc/saved_model:fingerprinting_impl", - "//tensorflow/cc/saved_model:loader_lite_impl", - "//tensorflow/cc/saved_model:metrics_impl", - "//tensorflow/compiler/tf2tensorrt:op_converter_registry_impl", - 
"//tensorflow/core/common_runtime:core_cpu_impl", - "//tensorflow/core/common_runtime/gpu:gpu_runtime_impl", - "//tensorflow/core/common_runtime/pluggable_device:pluggable_device_runtime_impl", - "//tensorflow/core:framework_internal_impl", - "//tensorflow/core/framework:tensor", - "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl", - "//tensorflow/core:lib_internal_impl", - "//tensorflow/core/profiler:profiler_impl", - "//tensorflow/core/util:determinism", # Must be linked and exported to libtensorflow_framework.so. - "//tensorflow/lite/kernels/shim:tf_kernel_shim", - "@local_xla//xla/stream_executor:stream_executor_impl", - "@local_xla//xla/tsl/framework:bfc_allocator", - "@local_xla//xla/tsl/framework:metrics", - ] + tf_additional_binary_deps() + - # TODO(b/259305727): Remove this select and include captured_function in macos builds. - select({ - "//tensorflow:macos": [], - "//conditions:default": [ - "//tensorflow/core/data:captured_function", - ], - }), + "//tensorflow/c/experimental/filesystem:filesystem_interface", + "//tensorflow/c/experimental/stream_executor:stream_executor", + "//tensorflow/c:env", + "//tensorflow/c:kernels", + "//tensorflow/c:kernels_experimental", + "//tensorflow/c:ops", + "//tensorflow/cc/saved_model:fingerprinting_impl", + "//tensorflow/cc/saved_model:loader_lite_impl", + "//tensorflow/cc/saved_model:metrics_impl", + "//tensorflow/compiler/tf2tensorrt:op_converter_registry_impl", + "//tensorflow/core/common_runtime:core_cpu_impl", + "//tensorflow/core/common_runtime/gpu:gpu_runtime_impl", + "//tensorflow/core/common_runtime/pluggable_device:pluggable_device_runtime_impl", + "//tensorflow/core:framework_internal_impl", + "//tensorflow/core/framework:tensor", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl", + "//tensorflow/core:lib_internal_impl", + "//tensorflow/core/profiler:profiler_impl", + "//tensorflow/core/util:determinism", # Must be linked and exported to 
libtensorflow_framework.so. + "//tensorflow/lite/kernels/shim:tf_kernel_shim", + "@local_xla//xla/stream_executor:stream_executor_impl", + "@local_xla//xla/tsl/framework:bfc_allocator", + "@local_xla//xla/tsl/framework:metrics", + "//tensorflow/core/data:captured_function", + ] + tf_additional_binary_deps(), soversion = VERSION, static_deps = PACKAGE_STATIC_DEPS, visibility = ["//visibility:public"], diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 726433bafded24..3f4ec98028e8c3 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -298,7 +298,6 @@ tf_cuda_library( ], "//conditions:default": [ ":env", - ":logging", ":tf_status", ":tf_tensor", "//tensorflow/c/experimental/filesystem:modular_filesystem", @@ -325,18 +324,6 @@ tf_cuda_library( alwayslink = 1, ) -cc_library( - name = "logging", - srcs = ["logging.cc"], - hdrs = ["logging.h"], - visibility = ["//visibility:public"], - deps = [ - ":c_api_macros", - "//tensorflow/core/platform:logging", - "//tensorflow/core/platform:stringprintf", - ], -) - tf_cuda_library( name = "tf_status_internal", hdrs = [ diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index b919be52b0bf68..4dd78e4cd7bbb1 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -1171,7 +1171,7 @@ TEST_F(CApiFunctionTest, InvalidOutputTensor_BadNodePtr) { EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); EXPECT_EQ(string("Node is null\n\tEncountered while processing output 0 " "from function 'MyFunc'"), - string(TF_Message(s_))); + std::string(TF_Message(s_))); } TEST_F(CApiFunctionTest, NodeMissingInput) { diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index e3e7d812b15838..f59a73a0871945 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -2478,7 +2478,7 @@ TEST_F(CApiAttributesTest, Names) { TF_OperationGetAttrName(oper, 0, value.get(), s_); EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - EXPECT_EQ("v", 
string(static_cast(value.get()), 1)); + EXPECT_EQ("v", std::string(static_cast(value.get()), 1)); } TEST_F(CApiAttributesTest, Errors) { diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index 97a5bbd4b6077a..9dae0d3afd46fe 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -119,8 +119,7 @@ CheckpointReader::BuildV2VarMaps() { BundleEntryProto entry; v2_reader_->Seek(kHeaderEntryKey); for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { - CHECK(entry.ParseFromArray(v2_reader_->value().data(), - v2_reader_->value().size())) + CHECK(entry.ParseFromString(v2_reader_->value())) << entry.InitializationErrorString(); for (int i = 0; i < entry.slices_size(); ++i) { const auto& slice_proto = entry.slices(i); @@ -140,8 +139,7 @@ CheckpointReader::BuildV2VarMaps() { v2_reader_->Seek(kHeaderEntryKey); for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { if (filtered_keys.count(string(v2_reader_->key())) > 0) continue; - CHECK(entry.ParseFromArray(v2_reader_->value().data(), - v2_reader_->value().size())) + CHECK(entry.ParseFromString(v2_reader_->value())) << entry.InitializationErrorString(); string key(v2_reader_->key()); (*var_to_shape_map)[key] = TensorShape(entry.shape()); diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index ccde2ba3d9b769..91f83b3f88967d 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -939,7 +939,8 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, const char* serialized_function_def, size_t size, TF_Status* status) { tensorflow::FunctionDef function_def; - if (!function_def.ParseFromArray(serialized_function_def, size)) { + if (!function_def.ParseFromString( + absl::string_view(serialized_function_def, size))) { status->status = tensorflow::errors::InvalidArgument("Invalid FunctionDef proto"); return; diff --git a/tensorflow/c/eager/c_api_experimental_reader.cc 
b/tensorflow/c/eager/c_api_experimental_reader.cc index 0959580a10438b..e93469bd4c1cfd 100644 --- a/tensorflow/c/eager/c_api_experimental_reader.cc +++ b/tensorflow/c/eager/c_api_experimental_reader.cc @@ -1,6 +1,6 @@ /* Copyright 2023 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License");; +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/tensorflow/c/eager/c_api_experimental_reader.h b/tensorflow/c/eager/c_api_experimental_reader.h index 71c2e4650f0520..d8bc2f6c65716b 100644 --- a/tensorflow/c/eager/c_api_experimental_reader.h +++ b/tensorflow/c/eager/c_api_experimental_reader.h @@ -1,6 +1,6 @@ /* Copyright 2023 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License");; +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD index 0802cc46267f66..d96de81bfa4365 100644 --- a/tensorflow/c/eager/parallel_device/BUILD +++ b/tensorflow/c/eager/parallel_device/BUILD @@ -177,5 +177,6 @@ tf_cc_test( "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/platform:strcat", "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/c/eager/parallel_device/parallel_device_remote_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_remote_test.cc index a231fc74033fdd..fcdbd4ea9c2a2f 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_remote_test.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_remote_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include "absl/log/log.h" +#include "absl/strings/str_cat.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/parallel_device/parallel_device_lib.h" #include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h" diff --git a/tensorflow/c/env.cc b/tensorflow/c/env.cc index 03dd862f95cb0f..7d25709df2dfc7 100644 --- a/tensorflow/c/env.cc +++ b/tensorflow/c/env.cc @@ -34,7 +34,7 @@ limitations under the License. #include "tensorflow/core/platform/types.h" struct TF_StringStream { - std::vector<::tensorflow::string>* list; + std::vector* list; size_t position; }; @@ -134,7 +134,7 @@ void TF_StringStreamDone(TF_StringStream* list) { delete list; } TF_StringStream* TF_GetChildren(const char* dirname, TF_Status* status) { - auto* children = new std::vector<::tensorflow::string>; + auto* children = new std::vector; TF_SetStatus(status, TF_OK, ""); ::tensorflow::Set_TF_Status_from_Status( @@ -147,7 +147,7 @@ TF_StringStream* TF_GetChildren(const char* dirname, TF_Status* status) { } TF_StringStream* TF_GetLocalTempDirectories() { - auto* tmpdirs = new std::vector<::tensorflow::string>; + auto* tmpdirs = new std::vector; ::tensorflow::Env::Default()->GetLocalTempDirectories(tmpdirs); diff --git a/tensorflow/c/env_test.cc b/tensorflow/c/env_test.cc index d4c9bfce3c2127..3d338d4377366b 100644 --- a/tensorflow/c/env_test.cc +++ b/tensorflow/c/env_test.cc @@ -35,14 +35,12 @@ TEST(TestEnv, TestDirHandling) { TF_Status* s = TF_NewStatus(); - ::tensorflow::string dirpath = - ::tensorflow::io::JoinPath(tempdir, "somedir"); + std::string dirpath = ::tensorflow::io::JoinPath(tempdir, "somedir"); TF_CreateDir(dirpath.c_str(), s); ASSERT_TF_OK(s) << "TF_CreateDir failed for " << dirpath << ": " << TF_Message(s); - ::tensorflow::string filepath = - ::tensorflow::io::JoinPath(dirpath, "somefile.txt"); + std::string filepath = ::tensorflow::io::JoinPath(dirpath, "somefile.txt"); TF_WritableFileHandle* handle; TF_NewWritableFile(filepath.c_str(), &handle, 
s); ASSERT_TF_OK(s) << "NewWritableFile failed for " << filepath << ": " @@ -61,7 +59,7 @@ TEST(TestEnv, TestDirHandling) { ASSERT_TF_OK(s) << "TF_GetChildren failed for " << dirpath; const char* childpath; ASSERT_TRUE(TF_StringStreamNext(children, &childpath)); - ASSERT_EQ(::tensorflow::string(childpath), "somefile.txt"); + ASSERT_EQ(std::string(childpath), "somefile.txt"); // There should only be one file in this directory. ASSERT_FALSE(TF_StringStreamNext(children, &childpath)); ASSERT_EQ(childpath, nullptr); diff --git a/tensorflow/c/experimental/filesystem/BUILD b/tensorflow/c/experimental/filesystem/BUILD index 1f3f66b36681a0..ec446fd8389687 100644 --- a/tensorflow/c/experimental/filesystem/BUILD +++ b/tensorflow/c/experimental/filesystem/BUILD @@ -49,6 +49,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@local_xla//xla/tsl/platform:env", "@local_xla//xla/tsl/platform:errors", ], diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem.h b/tensorflow/c/experimental/filesystem/modular_filesystem.h index b8482bbdb4f85d..5a8c4ba3ccb56c 100644 --- a/tensorflow/c/experimental/filesystem/modular_filesystem.h +++ b/tensorflow/c/experimental/filesystem/modular_filesystem.h @@ -24,6 +24,7 @@ limitations under the License. 
#include #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "tensorflow/c/experimental/filesystem/filesystem_interface.h" #include "xla/tsl/platform/file_system.h" #include "tensorflow/core/platform/file_statistics.h" diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD index 8fa3e726e6a837..f0f6e5351372e1 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD @@ -31,10 +31,10 @@ cc_library( ":gcs_helper", ":ram_file_block_cache", "//tensorflow/c:env", - "//tensorflow/c:logging", "//tensorflow/c:tf_status", "//tensorflow/c/experimental/filesystem:filesystem_interface", "@com_github_googlecloudplatform_google_cloud_cpp//:storage_client", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud:google_cloud_cpp_common", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", @@ -65,7 +65,6 @@ cc_library( deps = [ ":cleanup", "//tensorflow/c:env", - "//tensorflow/c:logging", "//tensorflow/c:tf_status", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", @@ -86,6 +85,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core/platform/cloud:now_seconds_env", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache_test.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache_test.cc index b0d283fff82d9b..e639f9a7dda476 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache_test.cc +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache_test.cc @@ -27,7 +27,7 @@ namespace tensorflow { namespace { TEST(ExpiringLRUCacheTest, MaxAge) 
{ - const string key = "a"; + const std::string key = "a"; std::unique_ptr env(new NowSecondsEnv); tf_gcs_filesystem::ExpiringLRUCache cache( 1, 0, [&env]() { return env->NowSeconds(); }); @@ -95,9 +95,10 @@ TEST(ExpiringLRUCacheTest, MaxEntries) { TEST(ExpiringLRUCacheTest, LookupOrCompute) { // max_age of 0 means we should always compute. - uint64 num_compute_calls = 0; + uint64_t num_compute_calls = 0; tf_gcs_filesystem::ExpiringLRUCache::ComputeFunc compute_func = - [&num_compute_calls](const string& key, int* value, TF_Status* status) { + [&num_compute_calls](const std::string& key, int* value, + TF_Status* status) { *value = num_compute_calls; num_compute_calls++; return TF_SetStatus(status, TF_OK, ""); diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc index 3b9650b7416315..f61208c7b4a174 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc @@ -40,7 +40,6 @@ limitations under the License. #include "google/cloud/storage/client.h" #include "tensorflow/c/env.h" #include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h" -#include "tensorflow/c/logging.h" #include "tensorflow/c/tf_status.h" // Implementation of a filesystem for GCS environments. diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h index 0060abc76699c3..3e972fa6292995 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h @@ -33,7 +33,6 @@ limitations under the License. 
#include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" #include "tensorflow/c/env.h" -#include "tensorflow/c/logging.h" #include "tensorflow/c/tf_status.h" namespace tf_gcs_filesystem { diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache_test.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache_test.cc index 4ad4a8ea1868f3..23645ed8e878bf 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache_test.cc +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h" #include -#include #include #include #include @@ -25,6 +24,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/strings/ascii.h" #include "absl/synchronization/blocking_counter.h" #include "absl/synchronization/notification.h" #include "absl/time/time.h" @@ -39,7 +39,7 @@ namespace tensorflow { namespace { absl::Status ReadCache(tf_gcs_filesystem::RamFileBlockCache* cache, - const string& filename, size_t offset, size_t n, + const std::string& filename, size_t offset, size_t n, std::vector* out) { out->clear(); out->resize(n, 0); @@ -54,7 +54,7 @@ absl::Status ReadCache(tf_gcs_filesystem::RamFileBlockCache* cache, } TEST(RamFileBlockCacheTest, IsCacheEnabled) { - auto fetcher = [](const string& filename, size_t offset, size_t n, + auto fetcher = [](const std::string& filename, size_t offset, size_t n, char* buffer, TF_Status* status) -> int64_t { // Do nothing. 
TF_SetStatus(status, TF_OK, ""); @@ -73,14 +73,14 @@ TEST(RamFileBlockCacheTest, IsCacheEnabled) { TEST(RamFileBlockCacheTest, ValidateAndUpdateFileSignature) { int calls = 0; - auto fetcher = [&calls](const string& filename, size_t offset, size_t n, + auto fetcher = [&calls](const std::string& filename, size_t offset, size_t n, char* buffer, TF_Status* status) -> int64_t { calls++; memset(buffer, 'x', n); TF_SetStatus(status, TF_OK, ""); return n; }; - string filename = "file"; + std::string filename = "file"; tf_gcs_filesystem::RamFileBlockCache cache(16, 32, 0, fetcher); std::vector out; @@ -101,12 +101,12 @@ TEST(RamFileBlockCacheTest, ValidateAndUpdateFileSignature) { } TEST(RamFileBlockCacheTest, PassThrough) { - const string want_filename = "foo/bar"; + const std::string want_filename = "foo/bar"; const size_t want_offset = 42; const size_t want_n = 1024; int calls = 0; auto fetcher = [&calls, want_filename, want_offset, want_n]( - const string& got_filename, size_t got_offset, + const std::string& got_filename, size_t got_offset, size_t got_n, char* buffer, TF_Status* status) -> int64_t { EXPECT_EQ(got_filename, want_filename); EXPECT_EQ(got_offset, want_offset); @@ -143,7 +143,7 @@ TEST(RamFileBlockCacheTest, BlockAlignment) { buf.push_back(i); } // The fetcher just fetches slices of the buffer. 
- auto fetcher = [&buf](const string& filename, size_t offset, size_t n, + auto fetcher = [&buf](const std::string& filename, size_t offset, size_t n, char* buffer, TF_Status* status) -> int64_t { int64_t bytes_transferred; if (offset < buf.size()) { @@ -191,8 +191,8 @@ TEST(RamFileBlockCacheTest, BlockAlignment) { TEST(RamFileBlockCacheTest, CacheHits) { const size_t block_size = 16; std::set calls; - auto fetcher = [&calls, block_size](const string& filename, size_t offset, - size_t n, char* buffer, + auto fetcher = [&calls, block_size](const std::string& filename, + size_t offset, size_t n, char* buffer, TF_Status* status) -> int64_t { EXPECT_EQ(n, block_size); EXPECT_EQ(offset % block_size, 0); @@ -202,7 +202,7 @@ TEST(RamFileBlockCacheTest, CacheHits) { TF_SetStatus(status, TF_OK, ""); return n; }; - const uint32 block_count = 256; + const uint32_t block_count = 256; tf_gcs_filesystem::RamFileBlockCache cache( block_size, block_count * block_size, 0, fetcher); std::vector out; @@ -225,7 +225,7 @@ TEST(RamFileBlockCacheTest, OutOfRange) { bool first_block = false; bool second_block = false; auto fetcher = [block_size, file_size, &first_block, &second_block]( - const string& filename, size_t offset, size_t n, + const std::string& filename, size_t offset, size_t n, char* buffer, TF_Status* status) -> int64_t { EXPECT_EQ(n, block_size); EXPECT_EQ(offset % block_size, 0); @@ -269,8 +269,9 @@ TEST(RamFileBlockCacheTest, Inconsistent) { // where we expected complete blocks. const size_t block_size = 16; // This fetcher returns OK but only fills in one byte for any offset. 
- auto fetcher = [block_size](const string& filename, size_t offset, size_t n, - char* buffer, TF_Status* status) -> int64_t { + auto fetcher = [block_size](const std::string& filename, size_t offset, + size_t n, char* buffer, + TF_Status* status) -> int64_t { EXPECT_EQ(n, block_size); EXPECT_EQ(offset % block_size, 0); EXPECT_GE(n, 1); @@ -293,8 +294,8 @@ TEST(RamFileBlockCacheTest, Inconsistent) { TEST(RamFileBlockCacheTest, LRU) { const size_t block_size = 16; std::list calls; - auto fetcher = [&calls, block_size](const string& filename, size_t offset, - size_t n, char* buffer, + auto fetcher = [&calls, block_size](const std::string& filename, + size_t offset, size_t n, char* buffer, TF_Status* status) -> int64_t { EXPECT_EQ(n, block_size); EXPECT_FALSE(calls.empty()) << "at offset = " << offset; @@ -306,7 +307,7 @@ TEST(RamFileBlockCacheTest, LRU) { TF_SetStatus(status, TF_OK, ""); return n; }; - const uint32 block_count = 2; + const uint32_t block_count = 2; tf_gcs_filesystem::RamFileBlockCache cache( block_size, block_count * block_size, 0, fetcher); std::vector out; @@ -342,7 +343,7 @@ TEST(RamFileBlockCacheTest, LRU) { TEST(RamFileBlockCacheTest, MaxStaleness) { int calls = 0; - auto fetcher = [&calls](const string& filename, size_t offset, size_t n, + auto fetcher = [&calls](const std::string& filename, size_t offset, size_t n, char* buffer, TF_Status* status) -> int64_t { calls++; memset(buffer, 'x', n); @@ -386,13 +387,13 @@ TEST(RamFileBlockCacheTest, MaxStaleness) { TEST(RamFileBlockCacheTest, RemoveFile) { int calls = 0; - auto fetcher = [&calls](const string& filename, size_t offset, size_t n, + auto fetcher = [&calls](const std::string& filename, size_t offset, size_t n, char* buffer, TF_Status* status) -> int64_t { calls++; char c = (filename == "a") ? 'a' : (filename == "b") ? 'b' : 'x'; if (offset > 0) { // The first block is lower case and all subsequent blocks are upper case. 
- c = toupper(c); + c = absl::ascii_toupper(c); } memset(buffer, c, n); TF_SetStatus(status, TF_OK, ""); @@ -448,7 +449,7 @@ TEST(RamFileBlockCacheTest, RemoveFile) { TEST(RamFileBlockCacheTest, Prune) { int calls = 0; - auto fetcher = [&calls](const string& filename, size_t offset, size_t n, + auto fetcher = [&calls](const std::string& filename, size_t offset, size_t n, char* buffer, TF_Status* status) -> int64_t { calls++; memset(buffer, 'x', n); @@ -458,7 +459,7 @@ TEST(RamFileBlockCacheTest, Prune) { std::vector out; // Our fake environment is initialized with the current timestamp. std::unique_ptr env(new NowSecondsEnv); - uint64 now = Env::Default()->NowSeconds(); + uint64_t now = Env::Default()->NowSeconds(); env->SetNowSeconds(now); tf_gcs_filesystem::RamFileBlockCache cache( 8, 32, 1 /* max staleness */, fetcher, @@ -487,7 +488,7 @@ TEST(RamFileBlockCacheTest, Prune) { // timestamp of `now` + 2, file "a" is stale because its first block is stale, // but file "b" is not stale yet. Thus, once the pruning thread wakes up (in // one second of wall time), it should remove "a" and leave "b" alone. 
- uint64 start = Env::Default()->NowSeconds(); + uint64_t start = Env::Default()->NowSeconds(); do { Env::Default()->SleepForMicroseconds(100000); } while (cache.CacheSize() == 24 && Env::Default()->NowSeconds() - start < 3); @@ -515,7 +516,7 @@ TEST(RamFileBlockCacheTest, ParallelReads) { absl::BlockingCounter counter(callers); absl::Notification notification; auto fetcher = [&counter, ¬ification]( - const string& filename, size_t offset, size_t n, + const std::string& filename, size_t offset, size_t n, char* buffer, TF_Status* status) -> int64_t { if (counter.DecrementCount()) { notification.Notify(); @@ -560,7 +561,7 @@ TEST(RamFileBlockCacheTest, CoalesceConcurrentReads) { int num_requests = 0; absl::Notification notification; auto fetcher = [&num_requests, ¬ification, block_size]( - const string& filename, size_t offset, size_t n, + const std::string& filename, size_t offset, size_t n, char* buffer, TF_Status* status) -> int64_t { EXPECT_EQ(n, block_size); EXPECT_EQ(offset, 0); @@ -591,7 +592,7 @@ TEST(RamFileBlockCacheTest, CoalesceConcurrentReads) { TEST(RamFileBlockCacheTest, Flush) { int calls = 0; - auto fetcher = [&calls](const string& filename, size_t offset, size_t n, + auto fetcher = [&calls](const std::string& filename, size_t offset, size_t n, char* buffer, TF_Status* status) -> int64_t { calls++; memset(buffer, 'x', n); diff --git a/tensorflow/c/experimental/gradients/tape/BUILD b/tensorflow/c/experimental/gradients/tape/BUILD index 20bc4a080f30ee..c0ae70b64abec7 100644 --- a/tensorflow/c/experimental/gradients/tape/BUILD +++ b/tensorflow/c/experimental/gradients/tape/BUILD @@ -50,6 +50,7 @@ cc_library( "//tensorflow/core/platform:strcat", "//tensorflow/core/platform:stringpiece", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@local_xla//xla/tsl/platform:errors", ], diff --git a/tensorflow/c/experimental/gradients/tape/tape_operation.cc 
b/tensorflow/c/experimental/gradients/tape/tape_operation.cc index 7cd3acffbc9cec..2839616c63991b 100644 --- a/tensorflow/c/experimental/gradients/tape/tape_operation.cc +++ b/tensorflow/c/experimental/gradients/tape/tape_operation.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_operation.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" diff --git a/tensorflow/c/experimental/grappler/grappler_test.cc b/tensorflow/c/experimental/grappler/grappler_test.cc index 32ac04832551c1..205aeec55ebf8c 100644 --- a/tensorflow/c/experimental/grappler/grappler_test.cc +++ b/tensorflow/c/experimental/grappler/grappler_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -70,11 +71,11 @@ TEST(Grappler, SuccessfulRegistration) { TF_ASSERT_OK(InitGraphPlugin(plugin_init)); ASSERT_EQ(PluginGraphOptimizerRegistry::CreateOptimizers( - std::set{"Success"}) + std::set{"Success"}) .size(), 1); ConfigList config = PluginGraphOptimizerRegistry::GetPluginConfigs( - true, std::set{"Success"}); + true, std::set{"Success"}); ASSERT_EQ(config.toggle_config["remapping"], RewriterConfig::OFF); } @@ -95,7 +96,7 @@ TEST(Grappler, MultiplePluginRegistration) { TF_ASSERT_OK(InitGraphPlugin(plugin_init_0)); TF_ASSERT_OK(InitGraphPlugin(plugin_init_1)); ASSERT_EQ(PluginGraphOptimizerRegistry::CreateOptimizers( - std::set{"Device0", "Device1"}) + std::set{"Device0", "Device1"}) .size(), 2); } @@ -132,12 +133,12 @@ TEST(Grappler, OptimizeFuncNotSet) { TEST(TF_GrapplerItem, NodesToPreserve) { GrapplerItem item; - item.fetch = std::vector{"Conv", "BiasAdd"}; - std::unordered_set nodes_preserved = item.NodesToPreserve(); + item.fetch = std::vector{"Conv", "BiasAdd"}; + std::unordered_set nodes_preserved = item.NodesToPreserve(); TF_GrapplerItem* c_item = reinterpret_cast(&item); int 
list_total_size = 0; - for (const string& s : nodes_preserved) { + for (const std::string& s : nodes_preserved) { list_total_size += s.size(); } @@ -158,20 +159,21 @@ TEST(TF_GrapplerItem, NodesToPreserve) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); for (size_t i = 0; i < nodes_preserved.size(); ++i) { - EXPECT_EQ(nodes_preserved.find(string(static_cast(values[i]), - lens[i])) != nodes_preserved.end(), - true); + EXPECT_EQ( + nodes_preserved.find(std::string(static_cast(values[i]), + lens[i])) != nodes_preserved.end(), + true); } TF_DeleteStatus(status); } TEST(TF_GrapplerItem, FetchNodes) { GrapplerItem item; - item.fetch = std::vector{"Conv", "BiasAdd"}; + item.fetch = std::vector{"Conv", "BiasAdd"}; TF_GrapplerItem* c_item = reinterpret_cast(&item); int list_total_size = 0; - for (const string& s : item.fetch) { + for (const std::string& s : item.fetch) { list_total_size += s.size(); } @@ -193,7 +195,7 @@ TEST(TF_GrapplerItem, FetchNodes) { for (size_t i = 0; i < item.fetch.size(); ++i) { EXPECT_EQ(item.fetch[i].size(), lens[i]) << i; EXPECT_EQ(item.fetch[i], - string(static_cast(values[i]), lens[i])) + std::string(static_cast(values[i]), lens[i])) << i; } TF_DeleteStatus(status); @@ -307,13 +309,13 @@ TEST(TF_FunctionLibraryDefinition, LookUpOpDef) { TF_NewFunctionLibraryDefinition(g_buf, status); TF_LookUpOpDef(func, "Add", op_buf, status); - string actual_string(reinterpret_cast(op_buf->data), - op_buf->length); + std::string actual_string(reinterpret_cast(op_buf->data), + op_buf->length); ASSERT_EQ(TF_OK, TF_GetCode(status)); const OpDef* expected_op_def; TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef("Add", &expected_op_def)); - string expected_serialized; + std::string expected_serialized; expected_op_def->SerializeToString(&expected_serialized); EXPECT_EQ(expected_serialized, actual_string); TF_DeleteBuffer(g_buf); diff --git a/tensorflow/c/experimental/next_pluggable_device/BUILD b/tensorflow/c/experimental/next_pluggable_device/BUILD 
index 348f5c5d6d0341..f4a57a7d265420 100644 --- a/tensorflow/c/experimental/next_pluggable_device/BUILD +++ b/tensorflow/c/experimental/next_pluggable_device/BUILD @@ -33,10 +33,10 @@ cc_library( "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@local_xla//xla/pjrt:pjrt_c_api_client", "@local_xla//xla/pjrt:pjrt_client", "@local_xla//xla/pjrt/c:pjrt_c_api_hdrs", "@local_xla//xla/pjrt/c:pjrt_c_api_helpers", + "@local_xla//xla/pjrt/c_api_client:pjrt_c_api_client", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_agent", ], ) @@ -70,9 +70,9 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_xla//xla/pjrt:pjrt_c_api_client", "@local_xla//xla/pjrt:pjrt_client", "@local_xla//xla/pjrt/c:pjrt_c_api_hdrs", + "@local_xla//xla/pjrt/c_api_client:pjrt_c_api_client", "@local_xla//xla/tsl/platform:errors", "@local_xla//xla/tsl/platform:statusor", ], @@ -96,10 +96,10 @@ tf_cc_test( "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/pjrt:pjrt_api", - "@local_xla//xla/pjrt:pjrt_c_api_client", "@local_xla//xla/pjrt/c:pjrt_c_api_cpu", "@local_xla//xla/pjrt/c:pjrt_c_api_hdrs", "@local_xla//xla/pjrt/c:pjrt_c_api_wrapper_impl", + "@local_xla//xla/pjrt/c_api_client:pjrt_c_api_client", "@local_xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", "@local_xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", "@local_xla//xla/tsl/lib/core:status_test_util", diff --git a/tensorflow/c/experimental/next_pluggable_device/c_api.cc b/tensorflow/c/experimental/next_pluggable_device/c_api.cc index fdb8a9e7f47794..569e7d0eed0ca4 100644 --- a/tensorflow/c/experimental/next_pluggable_device/c_api.cc +++ b/tensorflow/c/experimental/next_pluggable_device/c_api.cc @@ -40,7 +40,7 @@ limitations under the License. 
#include "tensorflow/compiler/jit/variable_info_util.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_helpers.h" -#include "xla/pjrt/pjrt_c_api_client.h" +#include "xla/pjrt/c_api_client/pjrt_c_api_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" #include "tensorflow/core/common_runtime/next_pluggable_device/plugin_resource.h" diff --git a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.cc b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.cc index 5344db87abcae0..4df0e5d336273f 100644 --- a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.cc +++ b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.cc @@ -22,7 +22,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/pjrt_tensor_buffer_util.h" #include "xla/pjrt/c/pjrt_c_api.h" -#include "xla/pjrt/pjrt_c_api_client.h" +#include "xla/pjrt/c_api_client/pjrt_c_api_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" diff --git a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.h b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.h index c2378b68109fc9..24fc0cc20d3c3a 100644 --- a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.h +++ b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.h @@ -18,7 +18,7 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/pjrt/c/pjrt_c_api.h" -#include "xla/pjrt/pjrt_c_api_client.h" +#include "xla/pjrt/c_api_client/pjrt_c_api_client.h" #include "tensorflow/core/framework/tensor.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc index 7220877fad0ed8..c5d2b18dac36aa 100644 --- a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc +++ b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc @@ -28,8 +28,8 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_cpu.h" #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h" +#include "xla/pjrt/c_api_client/pjrt_c_api_client.h" #include "xla/pjrt/pjrt_api.h" -#include "xla/pjrt/pjrt_c_api_client.h" #include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" #include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" #include "xla/shape.h" diff --git a/tensorflow/c/experimental/ops/gen/common/case_format.cc b/tensorflow/c/experimental/ops/gen/common/case_format.cc index 82acc32f623fd8..1992357201af18 100644 --- a/tensorflow/c/experimental/ops/gen/common/case_format.cc +++ b/tensorflow/c/experimental/ops/gen/common/case_format.cc @@ -14,7 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/common/case_format.h" -#include +#include #include "absl/strings/ascii.h" #include "tensorflow/core/platform/types.h" @@ -31,14 +31,14 @@ enum CaseFormatType { UPPER_SNAKE, }; -string FormatStringCase(const string &str, CaseFormatType to, - const char delimiter = '_') { +std::string FormatStringCase(const std::string& str, CaseFormatType to, + const char delimiter = '_') { const bool from_snake = (str == absl::AsciiStrToUpper(str)) || (str == absl::AsciiStrToLower(str)); const bool toUpper = (to == UPPER_CAMEL || to == UPPER_SNAKE); const bool toSnake = (to == LOWER_SNAKE || to == UPPER_SNAKE); - string result; + std::string result; bool inputStart = true; bool wordStart = true; @@ -52,7 +52,7 @@ string FormatStringCase(const string &str, CaseFormatType to, wordStart = true; continue; } - if (!from_snake && isupper(c)) { + if (!from_snake && absl::ascii_isupper(c)) { wordStart = true; } @@ -65,9 +65,9 @@ string FormatStringCase(const string &str, CaseFormatType to, const bool shouldCapIfSnake = toUpper; const bool shouldCapIfCamel = wordStart && (toUpper || !inputStart); if ((toSnake && shouldCapIfSnake) || (!toSnake && shouldCapIfCamel)) { - result += toupper(c); + result += absl::ascii_toupper(c); } else { - result += tolower(c); + result += absl::ascii_tolower(c); } // at this point we are no longer at the start of a word: @@ -90,16 +90,16 @@ string FormatStringCase(const string &str, CaseFormatType to, // Public interface // -string toLowerCamel(const string &s, const char delimiter) { +std::string toLowerCamel(const std::string& s, const char delimiter) { return FormatStringCase(s, LOWER_CAMEL, delimiter); } -string toLowerSnake(const string &s, const char delimiter) { +std::string toLowerSnake(const std::string& s, const char delimiter) { return FormatStringCase(s, LOWER_SNAKE, delimiter); } -string toUpperCamel(const string &s, const 
char delimiter) { +std::string toUpperCamel(const std::string& s, const char delimiter) { return FormatStringCase(s, UPPER_CAMEL, delimiter); } -string toUpperSnake(const string &s, const char delimiter) { +std::string toUpperSnake(const std::string& s, const char delimiter) { return FormatStringCase(s, UPPER_SNAKE, delimiter); } diff --git a/tensorflow/c/experimental/ops/gen/common/case_format.h b/tensorflow/c/experimental/ops/gen/common/case_format.h index f8255f6aa21c17..880f286788e0a2 100644 --- a/tensorflow/c/experimental/ops/gen/common/case_format.h +++ b/tensorflow/c/experimental/ops/gen/common/case_format.h @@ -35,10 +35,10 @@ namespace generator { // "__OneTwo__" (in camel case) <==> "__ONE_TWO__" (in snake case) // // Note: performance not yet tested. -string toLowerCamel(const string &s, const char delimiter = '_'); -string toLowerSnake(const string &s, const char delimiter = '_'); -string toUpperCamel(const string &s, const char delimiter = '_'); -string toUpperSnake(const string &s, const char delimiter = '_'); +std::string toLowerCamel(const std::string& s, const char delimiter = '_'); +std::string toLowerSnake(const std::string& s, const char delimiter = '_'); +std::string toUpperCamel(const std::string& s, const char delimiter = '_'); +std::string toUpperSnake(const std::string& s, const char delimiter = '_'); } // namespace generator } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/gen/common/case_format_test.cc b/tensorflow/c/experimental/ops/gen/common/case_format_test.cc index 302bcc42453169..e769acb94bff73 100644 --- a/tensorflow/c/experimental/ops/gen/common/case_format_test.cc +++ b/tensorflow/c/experimental/ops/gen/common/case_format_test.cc @@ -14,6 +14,8 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/common/case_format.h" +#include + #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -25,13 +27,13 @@ namespace { // For each test case, we manually construct the 4 variations in string case and // test all 16 conversions: from and to each of the 4 string case variations. struct Variations { - string lower_camel; - string lower_snake; - string upper_camel; - string upper_snake; + std::string lower_camel; + std::string lower_snake; + std::string upper_camel; + std::string upper_snake; }; -void TestSingleVariation(const string &str, Variations expected, +void TestSingleVariation(const std::string& str, Variations expected, char delimiter = '_') { EXPECT_EQ(expected.lower_camel, toLowerCamel(str, delimiter)); EXPECT_EQ(expected.lower_snake, toLowerSnake(str, delimiter)); diff --git a/tensorflow/c/experimental/ops/gen/common/controller.cc b/tensorflow/c/experimental/ops/gen/common/controller.cc index fb3e321714b108..7c9bf279fdcd2a 100644 --- a/tensorflow/c/experimental/ops/gen/common/controller.cc +++ b/tensorflow/c/experimental/ops/gen/common/controller.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/common/controller.h" +#include #include #include "absl/log/check.h" @@ -43,7 +44,7 @@ Controller::Controller(PathConfig path_config, Env* env) } Controller::~Controller() { delete api_def_map_; } -const void Controller::WriteFile(const string& file_path, +const void Controller::WriteFile(const std::string& file_path, const SourceCode& code) const { TF_CHECK_OK(WriteStringToFile(env_, file_path, code.Render())) << file_path; } @@ -60,8 +61,9 @@ void Controller::InitializeOpApi() { api_def_map_ = new ApiDefMap(op_list_); for (const auto& op : op_list_.op()) { for (const auto& dir : path_config_.api_dirs) { - const string file_name = absl::Substitute("api_def_$0.pbtxt", op.name()); - const string file_path = io::JoinPath(dir, file_name); + const std::string file_name = + absl::Substitute("api_def_$0.pbtxt", op.name()); + const std::string file_path = io::JoinPath(dir, file_name); if (env_->FileExists(file_path).ok()) { TF_CHECK_OK(api_def_map_->LoadFile(env_, file_path)) << file_path; } else { diff --git a/tensorflow/c/experimental/ops/gen/common/controller.h b/tensorflow/c/experimental/ops/gen/common/controller.h index e152efeb6d8f9f..c33891f963d7a6 100644 --- a/tensorflow/c/experimental/ops/gen/common/controller.h +++ b/tensorflow/c/experimental/ops/gen/common/controller.h @@ -32,7 +32,8 @@ class Controller { public: explicit Controller(PathConfig path_config, Env* env = Env::Default()); virtual ~Controller(); - const void WriteFile(const string& file_path, const SourceCode& code) const; + const void WriteFile(const std::string& file_path, + const SourceCode& code) const; const std::vector& GetModelOps() const; private: diff --git a/tensorflow/c/experimental/ops/gen/common/path_config.cc b/tensorflow/c/experimental/ops/gen/common/path_config.cc index 2ec57d67c9d6f7..6de98c242b1afa 100644 --- 
a/tensorflow/c/experimental/ops/gen/common/path_config.cc +++ b/tensorflow/c/experimental/ops/gen/common/path_config.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/c/experimental/ops/gen/common/path_config.h" #include +#include #include #include "absl/strings/str_join.h" @@ -24,9 +25,10 @@ limitations under the License. namespace tensorflow { namespace generator { -PathConfig::PathConfig(const string& output_dir, const string& source_dir, - const string& api_dir_list, - const std::vector op_names) +PathConfig::PathConfig(const std::string& output_dir, + const std::string& source_dir, + const std::string& api_dir_list, + const std::vector op_names) : output_path(output_dir), op_names(op_names) { api_dirs = str_util::Split(api_dir_list, ",", str_util::SkipEmpty()); @@ -39,7 +41,7 @@ PathConfig::PathConfig(const string& output_dir, const string& source_dir, tf_root_dir = "tensorflow"; // Prefix, e.g. "third_party" given root_dir "third_party/tensorflow/...." - std::vector source_path_components = + std::vector source_path_components = tensorflow::str_util::Split(source_dir, "/"); auto source_tfroot_pos = std::find(source_path_components.begin(), source_path_components.end(), tf_root_dir); @@ -51,7 +53,7 @@ PathConfig::PathConfig(const string& output_dir, const string& source_dir, } // TF subdir, e.g. 
"c/ops" given output_dir "blah/blah/tensorflow/c/ops" - std::vector output_path_components = + std::vector output_path_components = tensorflow::str_util::Split(output_dir, "/"); auto output_tfroot_pos = std::find(output_path_components.begin(), output_path_components.end(), tf_root_dir); diff --git a/tensorflow/c/experimental/ops/gen/common/path_config.h b/tensorflow/c/experimental/ops/gen/common/path_config.h index ce29063be5f682..d47266f86e38ef 100644 --- a/tensorflow/c/experimental/ops/gen/common/path_config.h +++ b/tensorflow/c/experimental/ops/gen/common/path_config.h @@ -23,17 +23,18 @@ namespace tensorflow { namespace generator { struct PathConfig { - string output_path; - std::vector op_names; - std::vector api_dirs; - string tf_prefix_dir; - string tf_root_dir; - string tf_output_dir; + std::string output_path; + std::vector op_names; + std::vector api_dirs; + std::string tf_prefix_dir; + std::string tf_root_dir; + std::string tf_output_dir; explicit PathConfig() = default; - explicit PathConfig(const string &output_dir, const string &source_dir, - const string &api_dir_list, - const std::vector op_names); + explicit PathConfig(const std::string& output_dir, + const std::string& source_dir, + const std::string& api_dir_list, + const std::vector op_names); }; } // namespace generator diff --git a/tensorflow/c/experimental/ops/gen/common/source_code.cc b/tensorflow/c/experimental/ops/gen/common/source_code.cc index 2b7bce6a263184..28e55659c1cc90 100644 --- a/tensorflow/c/experimental/ops/gen/common/source_code.cc +++ b/tensorflow/c/experimental/ops/gen/common/source_code.cc @@ -14,10 +14,13 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/common/source_code.h" +#include + #include "absl/log/log.h" #include "absl/strings/ascii.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stringpiece.h" @@ -25,20 +28,20 @@ limitations under the License. namespace tensorflow { namespace generator { -string SourceCode::Render() const { - string code; +std::string SourceCode::Render() const { + std::string code; for (const Line& line : lines_) { - absl::StrAppend(&code, string(line.indent * spaces_per_indent_, ' '), + absl::StrAppend(&code, std::string(line.indent * spaces_per_indent_, ' '), line.text, "\n"); } return code; } -void SourceCode::AddLineWithIndent(const string& line) { +void SourceCode::AddLineWithIndent(const std::string& line) { ValidateAndAddLine(current_indent_, line); } -void SourceCode::AddLineWithoutIndent(const string& line) { +void SourceCode::AddLineWithoutIndent(const std::string& line) { ValidateAndAddLine(0, line); } @@ -48,7 +51,7 @@ void SourceCode::IncreaseIndent() { current_indent_++; } void SourceCode::DecreaseIndent() { current_indent_--; } -void SourceCode::ValidateAndAddLine(int indent, const string& raw_line) { +void SourceCode::ValidateAndAddLine(int indent, const std::string& raw_line) { absl::string_view line(raw_line); bool had_trailing_newline = absl::ConsumeSuffix(&line, "\n"); @@ -57,7 +60,8 @@ void SourceCode::ValidateAndAddLine(int indent, const string& raw_line) { } else if (had_trailing_newline) { LOG(WARNING) << "Superfluous trailing newline in '" << line << "'"; } - lines_.push_back({indent, string(absl::StripTrailingAsciiWhitespace(line))}); + lines_.push_back( + {indent, std::string(absl::StripTrailingAsciiWhitespace(line))}); } } // namespace generator diff --git 
a/tensorflow/c/experimental/ops/gen/common/source_code.h b/tensorflow/c/experimental/ops/gen/common/source_code.h index df1aa90acf7b8c..9fd7f7eec5e174 100644 --- a/tensorflow/c/experimental/ops/gen/common/source_code.h +++ b/tensorflow/c/experimental/ops/gen/common/source_code.h @@ -24,13 +24,13 @@ namespace generator { class SourceCode { public: - string Render() const; + std::string Render() const; void SetSpacesPerIndent(int spaces_per_indent) { spaces_per_indent_ = spaces_per_indent; } - void AddLineWithIndent(const string &line); - void AddLineWithoutIndent(const string &line); + void AddLineWithIndent(const std::string& line); + void AddLineWithoutIndent(const std::string& line); void AddBlankLine(); void IncreaseIndent(); void DecreaseIndent(); @@ -38,10 +38,10 @@ class SourceCode { private: struct Line { int indent; - string text; + std::string text; }; - void ValidateAndAddLine(int indent_level, const string &raw_line); + void ValidateAndAddLine(int indent_level, const std::string& raw_line); int spaces_per_indent_ = 2; int current_indent_ = 0; diff --git a/tensorflow/c/experimental/ops/gen/common/view_util.cc b/tensorflow/c/experimental/ops/gen/common/view_util.cc index 388aa0646db82b..d8095aca80cf51 100644 --- a/tensorflow/c/experimental/ops/gen/common/view_util.cc +++ b/tensorflow/c/experimental/ops/gen/common/view_util.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/common/view_util.h" +#include #include #include "absl/strings/str_join.h" @@ -23,17 +24,20 @@ limitations under the License. 
namespace tensorflow { namespace generator { -string Call(const string& object, const string& method, - std::vector arguments, const char* oper) { +std::string Call(const std::string& object, const std::string& method, + std::vector arguments, const char* oper) { return absl::Substitute("$0$1$2($3)", object, oper, method, absl::StrJoin(arguments, ", ")); } -string Call(const string& function, std::vector arguments) { +std::string Call(const std::string& function, + std::vector arguments) { return absl::Substitute("$0($1)", function, absl::StrJoin(arguments, ", ")); } -string Quoted(const string& s) { return absl::Substitute("\"$0\"", s); } +std::string Quoted(const std::string& s) { + return absl::Substitute("\"$0\"", s); +} } // namespace generator } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/gen/common/view_util.h b/tensorflow/c/experimental/ops/gen/common/view_util.h index 7ab437a90e4fd8..f23831ce8a07dd 100644 --- a/tensorflow/c/experimental/ops/gen/common/view_util.h +++ b/tensorflow/c/experimental/ops/gen/common/view_util.h @@ -22,10 +22,11 @@ limitations under the License. 
namespace tensorflow { namespace generator { -string Call(const string &function, std::vector arguments); -string Call(const string &object, const string &method, - std::vector arguments, const char *oper = "->"); -string Quoted(const string &s); +std::string Call(const std::string& function, + std::vector arguments); +std::string Call(const std::string& object, const std::string& method, + std::vector arguments, const char* oper = "->"); +std::string Quoted(const std::string& s); } // namespace generator } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.cc b/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.cc index 3fe5c059ca4e70..45e7b87069e361 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.cc @@ -52,11 +52,11 @@ SourceCode CppGenerator::SourceFileContents() const { return GenerateOneFile(cpp::RendererContext::kSource); } -string CppGenerator::HeaderFileName() const { +std::string CppGenerator::HeaderFileName() const { return io::JoinPath(path_config_.output_path, cpp_config_.unit + "_ops.h"); } -string CppGenerator::SourceFileName() const { +std::string CppGenerator::SourceFileName() const { return io::JoinPath(path_config_.output_path, cpp_config_.unit + "_ops.cc"); } diff --git a/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.h b/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.h index 0a7b08cd9b171f..b4d016e0ecca44 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.h +++ b/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.h @@ -30,8 +30,8 @@ class CppGenerator { explicit CppGenerator(cpp::CppConfig cpp_config, PathConfig path_config); SourceCode HeaderFileContents() const; SourceCode SourceFileContents() const; - string HeaderFileName() const; - string SourceFileName() const; + std::string HeaderFileName() const; + std::string SourceFileName() const; void WriteHeaderFile() const; void WriteSourceFile() const; diff 
--git a/tensorflow/c/experimental/ops/gen/cpp/cpp_generator_test.cc b/tensorflow/c/experimental/ops/gen/cpp/cpp_generator_test.cc index f4a4d82bbce423..e1db2c9b8ce14b 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/cpp_generator_test.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/cpp_generator_test.cc @@ -30,12 +30,12 @@ namespace generator { namespace { TEST(CppGeneratorTest, typical_usage) { - string category = "testing"; - string name_space = "tensorflow::ops"; - string output_dir = "tensorflow/c/experimental/ops/gen/cpp/golden"; - string source_dir = "tensorflow"; - string api_dirs = ""; - std::vector ops = { + std::string category = "testing"; + std::string name_space = "tensorflow::ops"; + std::string output_dir = "tensorflow/c/experimental/ops/gen/cpp/golden"; + std::string source_dir = "tensorflow"; + std::string api_dirs = ""; + std::vector ops = { "Neg", // Simple unary Op "MatMul", // 2 inputs & attrs with default values "IdentityN", // Variadic input+output @@ -50,17 +50,19 @@ TEST(CppGeneratorTest, typical_usage) { CppGenerator generator(cpp_config, controller_config); Env *env = Env::Default(); - string golden_dir = io::JoinPath(testing::TensorFlowSrcRoot(), - controller_config.tf_output_dir); + std::string golden_dir = io::JoinPath(testing::TensorFlowSrcRoot(), + controller_config.tf_output_dir); - string generated_header = generator.HeaderFileContents().Render(); - string generated_source = generator.SourceFileContents().Render(); - string expected_header; - string header_file_name = io::JoinPath(golden_dir, "testing_ops.h.golden"); + std::string generated_header = generator.HeaderFileContents().Render(); + std::string generated_source = generator.SourceFileContents().Render(); + std::string expected_header; + std::string header_file_name = + io::JoinPath(golden_dir, "testing_ops.h.golden"); TF_CHECK_OK(ReadFileToString(env, header_file_name, &expected_header)); - string expected_source; - string source_file_name = io::JoinPath(golden_dir, 
"testing_ops.cc.golden"); + std::string expected_source; + std::string source_file_name = + io::JoinPath(golden_dir, "testing_ops.cc.golden"); TF_CHECK_OK(ReadFileToString(env, source_file_name, &expected_source)); // Remove carriage returns (for Windows) diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.cc index 4f0e64e3b0f8eb..7c8231a71133f5 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.cc @@ -22,7 +22,7 @@ namespace tensorflow { namespace generator { namespace cpp { -CppConfig::CppConfig(const string &category, const string &name_space) +CppConfig::CppConfig(const std::string& category, const std::string& name_space) : category(category), unit(absl::AsciiStrToLower(category)), namespaces(absl::StrSplit(name_space, "::")) {} diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h index fa7571d98a1214..eec5888e17e7cf 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h @@ -24,13 +24,13 @@ namespace generator { namespace cpp { struct CppConfig { - string category; - string unit; - std::vector namespaces; + std::string category; + std::string unit; + std::vector namespaces; explicit CppConfig() = default; - explicit CppConfig(const string &category, - const string &name_space = "tensorflow::ops"); + explicit CppConfig(const std::string& category, + const std::string& name_space = "tensorflow::ops"); }; } // namespace cpp diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc index 1a685cac0c405c..50db08df1db988 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc +++ 
b/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc @@ -27,10 +27,10 @@ namespace generator { namespace cpp { GuardRenderer::GuardRenderer(RendererContext context) : Renderer(context) { - string self_path = io::JoinPath(context_.path_config.tf_root_dir, - context_.path_config.tf_output_dir, - context_.cpp_config.unit + "_ops.h"); - string with_underscores(self_path); + std::string self_path = io::JoinPath(context_.path_config.tf_root_dir, + context_.path_config.tf_output_dir, + context_.cpp_config.unit + "_ops.h"); + std::string with_underscores(self_path); std::replace(with_underscores.begin(), with_underscores.end(), '/', '_'); std::replace(with_underscores.begin(), with_underscores.end(), '.', '_'); guard_ = toUpperSnake(with_underscores) + "_"; diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h index a45fe89a7a011c..bbd29e4620e2c2 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h @@ -31,7 +31,7 @@ class GuardRenderer : public Renderer { void Close(); private: - string guard_; + std::string guard_; }; } // namespace cpp diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.cc index 38f31209f6da24..0ec8108bee7aaf 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.cc @@ -30,13 +30,13 @@ void IncludeRenderer::SelfHeader() { BlankLine(); } -string IncludeRenderer::SelfHeaderPath() const { +std::string IncludeRenderer::SelfHeaderPath() const { return io::JoinPath(context_.path_config.tf_root_dir, context_.path_config.tf_output_dir, context_.cpp_config.unit + "_ops.h"); } -void IncludeRenderer::Include(const string &tf_file_path) { +void IncludeRenderer::Include(const 
std::string& tf_file_path) { CodeLine("#include \"$0\"", io::JoinPath(context_.path_config.tf_prefix_dir, tf_file_path)); } diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h index e43715a62e45b0..4178f0da5beeb9 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h @@ -27,12 +27,12 @@ class IncludeRenderer : public Renderer { public: explicit IncludeRenderer(RendererContext context); - string SelfHeaderPath() const; + std::string SelfHeaderPath() const; void SelfHeader(); void Headers(); private: - void Include(const string &tf_file_path); + void Include(const std::string& tf_file_path); }; } // namespace cpp diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.cc index db28ab303ae5c6..b490cc7fe9e86a 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.cc @@ -26,7 +26,7 @@ NamespaceRenderer::NamespaceRenderer(RendererContext context) : Renderer(context) {} void NamespaceRenderer::Open() { - for (const string& ns : context_.cpp_config.namespaces) { + for (const std::string& ns : context_.cpp_config.namespaces) { CodeLine("namespace " + ns + " {"); } } diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.cc index c459d239ca699f..63cb5f30eb1d9d 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.cc @@ -31,11 +31,11 @@ namespace tensorflow { namespace generator { namespace cpp { -string OpRenderer::Signature() const { - std::vector args_with_default_val; - std::vector args_without_default_val; +std::string 
OpRenderer::Signature() const { + std::vector args_with_default_val; + std::vector args_without_default_val; for (OpArgumentView const& argument : op_.AllArguments()) { - string text = argument.Declaration(); + std::string text = argument.Declaration(); if (context_.mode == RendererContext::kHeader) { absl::StrAppend(&text, argument.Initializer()); } @@ -45,7 +45,7 @@ string OpRenderer::Signature() const { args_without_default_val.push_back(text); } } - std::vector arguments; + std::vector arguments; arguments.reserve(args_without_default_val.size() + args_with_default_val.size()); arguments.insert(arguments.end(), diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h index 3360e14e672e3a..1ea161f55bdad9 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h @@ -34,7 +34,7 @@ class OpRenderer : public Renderer { OpView op_; OpCommentRenderer comment_; - string Signature() const; + std::string Signature() const; }; } // namespace cpp diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc index a9efb94335c0a6..6a608d759a3753 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc @@ -34,21 +34,21 @@ Renderer& Renderer::BlankLine() { return *this; } -Renderer& Renderer::CodeLine(const string& text) { +Renderer& Renderer::CodeLine(const std::string& text) { context_.code.AddLineWithoutIndent(text); return *this; } -Renderer& Renderer::CodeLines(const string& text) { +Renderer& Renderer::CodeLines(const std::string& text) { absl::string_view trimmed_text(text); str_util::RemoveWhitespaceContext(&trimmed_text); - for (const string& line : str_util::Split(trimmed_text, '\n')) { + for (const std::string& line : str_util::Split(trimmed_text, 
'\n')) { context_.code.AddLineWithoutIndent(line); } return *this; } -Renderer& Renderer::Statement(const string& text) { +Renderer& Renderer::Statement(const std::string& text) { if (absl::EndsWith(text, ";")) { LOG(WARNING) << "Superfluous terminating ';' in '" << text << "'"; context_.code.AddLineWithIndent(text); @@ -58,22 +58,22 @@ Renderer& Renderer::Statement(const string& text) { return *this; } -Renderer& Renderer::TFStatement(const string& text) { +Renderer& Renderer::TFStatement(const std::string& text) { return Statement(absl::Substitute("TF_RETURN_IF_ERROR($0)", text)); } -Renderer& Renderer::CommentLine(const string& text) { +Renderer& Renderer::CommentLine(const std::string& text) { context_.code.AddLineWithIndent(absl::StrCat("// ", text)); return *this; } -Renderer& Renderer::BlockOpen(const string& text) { +Renderer& Renderer::BlockOpen(const std::string& text) { context_.code.AddLineWithIndent(absl::StrCat(text, " {")); context_.code.IncreaseIndent(); return *this; } -Renderer& Renderer::BlockClose(const string& text) { +Renderer& Renderer::BlockClose(const std::string& text) { context_.code.DecreaseIndent(); context_.code.AddLineWithIndent(absl::StrCat("}", text)); return *this; diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h index b6168b196b35b2..f41923651f44e2 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h @@ -34,7 +34,7 @@ class Renderer { // Append a line of source code, left-justified (not indented). // Use for preprocessors directives ("#include"), namespaces, etc. 
- Renderer &CodeLine(const string &text); + Renderer& CodeLine(const std::string& text); template Renderer CodeLine(absl::string_view text, const Args &...args) { return CodeLine(absl::Substitute(text, args...)); @@ -44,7 +44,7 @@ class Renderer { // Note: Trims leading/trailing whitespace including newlines, making this // method convenient for multiline raw strings. // Newlines ('\n') are allowed/expected. - Renderer &CodeLines(const string &text); + Renderer& CodeLines(const std::string& text); template Renderer CodeLines(absl::string_view text, const Args &...args) { return CodeLines(absl::Substitute(text, args...)); @@ -52,7 +52,7 @@ class Renderer { // Indent and append a C++ statement. // Note: do *not* include a trailing semicolon in the statement text. - Renderer &Statement(const string &text); + Renderer& Statement(const std::string& text); template Renderer Statement(absl::string_view text, const Args &...args) { return Statement(absl::Substitute(text, args...)); @@ -60,14 +60,14 @@ class Renderer { // Indent and append a call to a TF method returning a Status to check. // Note: do *not* include a trailing semicolon in the statement text. - Renderer &TFStatement(const string &text); + Renderer& TFStatement(const std::string& text); template Renderer TFStatement(absl::string_view text, const Args &...args) { return TFStatement(absl::Substitute(text, args...)); } // Indent and append a C++ single-line style comment (using '//'). - Renderer &CommentLine(const string &text = ""); + Renderer& CommentLine(const std::string& text = ""); template Renderer CommentLine(absl::string_view text, const Args &...args) { return CommentLine(absl::Substitute(text, args...)); @@ -75,7 +75,7 @@ class Renderer { // Append a line of code which starts a new block: trailing with '{') and // indenting. 
- Renderer &BlockOpen(const string &text); + Renderer& BlockOpen(const std::string& text); template Renderer BlockOpen(absl::string_view text, const Args &...args) { return BlockOpen(absl::Substitute(text, args...)); @@ -83,7 +83,7 @@ class Renderer { // Append a line of code ending a block: unindenting and adding '}'. // Note: optional trailing text is often a comment, e.g. '// namespace xyz'. - Renderer &BlockClose(const string &text = ""); + Renderer& BlockClose(const std::string& text = ""); template Renderer BlockClose(absl::string_view text, const Args &...args) { return BlockClose(absl::Substitute(text, args...)); diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_test.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_test.cc index eff654c5938160..6621d1aea2c217 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_test.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_test.cc @@ -57,7 +57,7 @@ TEST(Renderer, typical_usage) { SourceCode code; TestRenderer(code).Render(); - string expected = R"(// File level comment. + std::string expected = R"(// File level comment. #include "header.h" void TestFunction() { diff --git a/tensorflow/c/experimental/ops/gen/generate_cpp_main.cc b/tensorflow/c/experimental/ops/gen/generate_cpp_main.cc index 18a506942de5b7..cb922d0a06b7ae 100644 --- a/tensorflow/c/experimental/ops/gen/generate_cpp_main.cc +++ b/tensorflow/c/experimental/ops/gen/generate_cpp_main.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include #include #include "absl/log/check.h" diff --git a/tensorflow/c/experimental/saved_model/core/object_graph_traversal_test.cc b/tensorflow/c/experimental/saved_model/core/object_graph_traversal_test.cc index c2bf61d785e6b2..417a0f26d70b92 100644 --- a/tensorflow/c/experimental/saved_model/core/object_graph_traversal_test.cc +++ b/tensorflow/c/experimental/saved_model/core/object_graph_traversal_test.cc @@ -26,8 +26,7 @@ namespace { SavedObjectGraph ParseSavedObjectGraph(absl::string_view text_proto) { SavedObjectGraph value; - CHECK(tensorflow::protobuf::TextFormat::ParseFromString(string(text_proto), - &value)); + CHECK(tensorflow::protobuf::TextFormat::ParseFromString(text_proto, &value)); return value; } diff --git a/tensorflow/c/experimental/saved_model/core/ops/BUILD b/tensorflow/c/experimental/saved_model/core/ops/BUILD index 4214f76cee1cee..de027662df30cf 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/BUILD +++ b/tensorflow/c/experimental/saved_model/core/ops/BUILD @@ -82,6 +82,7 @@ tf_cc_test( "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:core", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", ], ) diff --git a/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc b/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc index 1d55dabcc9ab87..866dbaa94895d0 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc +++ b/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/test_utils.h" #include "tensorflow/c/tensor_interface.h" diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc index 2ac31f313230ac..673411a44456d1 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_context.h" diff --git a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc index 6250af6dba1359..1796c99dc79f17 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc +++ b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc @@ -178,8 +178,7 @@ tuple_value: { StructuredValue ParseStructuredValue(absl::string_view text_proto) { StructuredValue value; - CHECK(tensorflow::protobuf::TextFormat::ParseFromString(string(text_proto), - &value)); + CHECK(tensorflow::protobuf::TextFormat::ParseFromString(text_proto, &value)); return value; } diff --git a/tensorflow/c/kernels/bitcast_op_test.cc b/tensorflow/c/kernels/bitcast_op_test.cc index f2ff59a4c853e0..c44bc832547dab 100644 --- a/tensorflow/c/kernels/bitcast_op_test.cc +++ b/tensorflow/c/kernels/bitcast_op_test.cc @@ -60,7 +60,7 @@ void TestBitcastOp(Tensor* input_tensor, DataType out_type, 
(*def.mutable_attr())["type"] = outTypeAttr; def.add_input( - strings::StrCat("input1: ", DataTypeString(input_tensor->dtype()))); + absl::StrCat("input1: ", DataTypeString(input_tensor->dtype()))); std::unique_ptr kernel = CreateOpKernel(DeviceType(DEVICE_CPU), nullptr, nullptr, def, 1, &status); @@ -86,13 +86,13 @@ void TestBitcastOp(Tensor* input_tensor, DataType out_type, TEST(BitcastOpTest, TestUpcast) { Tensor int8_input(DT_UINT8, {8}); for (int i = 0; i < 8; i++) { - int8_input.vec()(i) = static_cast(1); + int8_input.vec()(i) = static_cast(1); } TestBitcastOp(&int8_input, DT_UINT64, TensorShape(), error::OK); } TEST(BitcastOpTest, TestDowncast) { - Tensor int64_input(static_cast(1)); + Tensor int64_input(static_cast(1)); TestBitcastOp(&int64_input, DT_UINT8, TensorShape({8}), error::OK); } diff --git a/tensorflow/c/kernels/histogram_summary_op.cc b/tensorflow/c/kernels/histogram_summary_op.cc index 7f34e5217c20ba..35340baa5749ce 100644 --- a/tensorflow/c/kernels/histogram_summary_op.cc +++ b/tensorflow/c/kernels/histogram_summary_op.cc @@ -151,13 +151,13 @@ void RegisterHistogramSummaryOpKernel() { TF_ATTRIBUTE_UNUSED static bool IsHistogramSummaryOpKernelRegistered = []() { if (SHOULD_REGISTER_OP_KERNEL("HistogramSummary")) { RegisterHistogramSummaryOpKernel(); - RegisterHistogramSummaryOpKernel(); - RegisterHistogramSummaryOpKernel(); - RegisterHistogramSummaryOpKernel(); - RegisterHistogramSummaryOpKernel(); - RegisterHistogramSummaryOpKernel(); - RegisterHistogramSummaryOpKernel(); - RegisterHistogramSummaryOpKernel(); + RegisterHistogramSummaryOpKernel(); + RegisterHistogramSummaryOpKernel(); + RegisterHistogramSummaryOpKernel(); + RegisterHistogramSummaryOpKernel(); + RegisterHistogramSummaryOpKernel(); + RegisterHistogramSummaryOpKernel(); + RegisterHistogramSummaryOpKernel(); RegisterHistogramSummaryOpKernel(); RegisterHistogramSummaryOpKernel(); RegisterHistogramSummaryOpKernel(); diff --git a/tensorflow/c/kernels/merge_summary_op.cc 
b/tensorflow/c/kernels/merge_summary_op.cc index 339267d094a554..ddbc3440d47dc1 100644 --- a/tensorflow/c/kernels/merge_summary_op.cc +++ b/tensorflow/c/kernels/merge_summary_op.cc @@ -50,7 +50,7 @@ void MergeSummaryOp_Delete(void* kernel) {} void MergeSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) { tensorflow::Summary s; - std::unordered_set tags; + std::unordered_set tags; Safe_TF_StatusPtr status(TF_NewStatus()); for (int input_num = 0; input_num < TF_NumInputs(ctx); ++input_num) { TF_Tensor* input; @@ -74,7 +74,7 @@ void MergeSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) { for (int v = 0; v < summary_in.value_size(); ++v) { // This tag is unused by the TensorSummary op, so no need to check for // duplicates. - const tensorflow::string& tag = summary_in.value(v).tag(); + const std::string& tag = summary_in.value(v).tag(); if ((!tag.empty()) && !tags.insert(tag).second) { std::ostringstream err; err << "Duplicate tag " << tag << " found in summary inputs "; diff --git a/tensorflow/c/kernels/summary_op.cc b/tensorflow/c/kernels/summary_op.cc index 486aea1af53b50..5688d00fa8fa7c 100644 --- a/tensorflow/c/kernels/summary_op.cc +++ b/tensorflow/c/kernels/summary_op.cc @@ -126,7 +126,7 @@ std::string SingleTag(TF_Tensor* tags) { if (TF_TensorElementCount(tags) == 1) { const char* single_tag = static_cast(TF_TensorData(tags))->c_str(); - return tensorflow::strings::StrCat(" (tag '", single_tag, "')"); + return absl::StrCat(" (tag '", single_tag, "')"); } else { return ""; } @@ -155,13 +155,13 @@ void RegisterScalarSummaryOpKernel() { TF_ATTRIBUTE_UNUSED bool IsScalarSummaryOpKernelRegistered = []() { if (SHOULD_REGISTER_OP_KERNEL("ScalarSummary")) { RegisterScalarSummaryOpKernel(); - RegisterScalarSummaryOpKernel(); - RegisterScalarSummaryOpKernel(); - RegisterScalarSummaryOpKernel(); - RegisterScalarSummaryOpKernel(); - RegisterScalarSummaryOpKernel(); - RegisterScalarSummaryOpKernel(); - RegisterScalarSummaryOpKernel(); + 
RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); RegisterScalarSummaryOpKernel(); RegisterScalarSummaryOpKernel(); RegisterScalarSummaryOpKernel(); diff --git a/tensorflow/c/kernels/summary_op_test.cc b/tensorflow/c/kernels/summary_op_test.cc index 11a7c06c1d2e30..43de49bc39419d 100644 --- a/tensorflow/c/kernels/summary_op_test.cc +++ b/tensorflow/c/kernels/summary_op_test.cc @@ -45,13 +45,15 @@ class DummyDevice : public DeviceBase { }; // Helper for comparing output and expected output -void ExpectSummaryMatches(const Summary& actual, const string& expected_str) { +void ExpectSummaryMatches(const Summary& actual, + const std::string& expected_str) { Summary expected; ASSERT_TRUE(protobuf::TextFormat::ParseFromString(expected_str, &expected)); EXPECT_EQ(expected.DebugString(), actual.DebugString()); } -void TestScalarSummaryOp(Tensor* tags, Tensor* values, string expected_output, +void TestScalarSummaryOp(Tensor* tags, Tensor* values, + std::string expected_output, error::Code expected_code) { // Initialize node used to fetch OpKernel absl::Status status; @@ -64,8 +66,8 @@ void TestScalarSummaryOp(Tensor* tags, Tensor* values, string expected_output, SetAttrValue(values->dtype(), &valuesTypeAttr); (*def.mutable_attr())["T"] = valuesTypeAttr; - def.add_input(strings::StrCat("input1: ", DataTypeString(tags->dtype()))); - def.add_input(strings::StrCat("input2: ", DataTypeString(values->dtype()))); + def.add_input(absl::StrCat("input1: ", DataTypeString(tags->dtype()))); + def.add_input(absl::StrCat("input2: ", DataTypeString(values->dtype()))); std::unique_ptr kernel = CreateOpKernel(DeviceType(DEVICE_CPU), nullptr, nullptr, def, 1, &status); diff --git a/tensorflow/c/kernels/tensor_shape_utils.cc b/tensorflow/c/kernels/tensor_shape_utils.cc index 
967330ccb93f87..ba54dc4eda4df9 100644 --- a/tensorflow/c/kernels/tensor_shape_utils.cc +++ b/tensorflow/c/kernels/tensor_shape_utils.cc @@ -26,15 +26,15 @@ namespace tensorflow { std::string ShapeDebugString(TF_Tensor* tensor) { // A TF_Tensor cannot have an unknown rank. CHECK_GE(TF_NumDims(tensor), 0); - tensorflow::string s = "["; + std::string s = "["; for (int i = 0; i < TF_NumDims(tensor); ++i) { - if (i > 0) tensorflow::strings::StrAppend(&s, ","); + if (i > 0) absl::StrAppend(&s, ","); int64_t dim = TF_Dim(tensor, i); // A TF_Tensor cannot have an unknown dimension. CHECK_GE(dim, 0); - tensorflow::strings::StrAppend(&s, dim); + absl::StrAppend(&s, dim); } - tensorflow::strings::StrAppend(&s, "]"); + absl::StrAppend(&s, "]"); return s; } } // namespace tensorflow diff --git a/tensorflow/c/logging.cc b/tensorflow/c/logging.cc deleted file mode 100644 index 13c9e6ac208a14..00000000000000 --- a/tensorflow/c/logging.cc +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/c/logging.h" - -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stringprintf.h" - -static ::tensorflow::string BuildMessage(const char* fmt, va_list args) { - ::tensorflow::string message; - ::tensorflow::strings::Appendv(&message, fmt, args); - return message; -} - -void TF_Log(TF_LogLevel level, const char* fmt, ...) { - if (level < TF_INFO || level > TF_FATAL) return; - va_list args; - va_start(args, fmt); - auto message = BuildMessage(fmt, args); - va_end(args); - switch (level) { - case TF_INFO: - LOG(INFO) << message; - break; - case TF_WARNING: - LOG(WARNING) << message; - break; - case TF_ERROR: - LOG(ERROR) << message; - break; - case TF_FATAL: - LOG(FATAL) << message; - break; - } -} - -void TF_VLog(int level, const char* fmt, ...) { - va_list args; - va_start(args, fmt); - auto message = BuildMessage(fmt, args); - va_end(args); - VLOG(level) << message; -} - -void TF_DVLog(int level, const char* fmt, ...) { - va_list args; - va_start(args, fmt); - auto message = BuildMessage(fmt, args); - va_end(args); - DVLOG(level) << message; -} diff --git a/tensorflow/c/logging.h b/tensorflow/c/logging.h deleted file mode 100644 index 9583777b661122..00000000000000 --- a/tensorflow/c/logging.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_C_LOGGING_H_ -#define TENSORFLOW_C_LOGGING_H_ - -#include "tensorflow/c/c_api_macros.h" - -// -------------------------------------------------------------------------- -// C API for tensorflow::Logging. - -#ifdef __cplusplus -extern "C" { -#endif - -typedef enum TF_LogLevel { - TF_INFO = 0, - TF_WARNING = 1, - TF_ERROR = 2, - TF_FATAL = 3, -} TF_LogLevel; - -TF_CAPI_EXPORT extern void TF_Log(TF_LogLevel level, const char* fmt, ...); -TF_CAPI_EXPORT extern void TF_VLog(int level, const char* fmt, ...); -TF_CAPI_EXPORT extern void TF_DVLog(int level, const char* fmt, ...); - -#ifdef __cplusplus -} -#endif - -#endif // TENSORFLOW_C_LOGGING_H_ diff --git a/tensorflow/c/tf_datatype.h b/tensorflow/c/tf_datatype.h index 02a38e9b164eb3..c991fc1f74f2e8 100644 --- a/tensorflow/c/tf_datatype.h +++ b/tensorflow/c/tf_datatype.h @@ -65,6 +65,7 @@ typedef enum TF_DataType { TF_UINT4 = 30, TF_INT2 = 31, TF_UINT2 = 32, + TF_FLOAT4_E2M1FN = 33 // 2 exponent bits, 1 mantissa bit, finite-only } TF_DataType; // TF_DataTypeSize returns the sizeof() for the underlying type corresponding diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc index 95748942f06390..f776fdd9612ecd 100644 --- a/tensorflow/cc/client/client_session.cc +++ b/tensorflow/cc/client/client_session.cc @@ -34,7 +34,7 @@ class ClientSession::Impl { Impl(Session* session, std::shared_ptr graph) : session_(session), graph_(std::move(graph)) {} - static SessionOptions MakeDefaultSessionOptions(const string& target); + static SessionOptions MakeDefaultSessionOptions(const std::string& target); absl::Status MaybeExtendGraph() const; std::unique_ptr session_; @@ -44,7 +44,7 @@ class ClientSession::Impl { mutable int last_num_graph_nodes_ TF_GUARDED_BY(mu_) = 0; }; -ClientSession::ClientSession(const Scope& scope, const string& target) +ClientSession::ClientSession(const Scope& scope, 
const std::string& target) : ClientSession(scope, Impl::MakeDefaultSessionOptions(target)) {} ClientSession::ClientSession(const Scope& scope) : ClientSession(scope, "") {} @@ -64,7 +64,7 @@ ClientSession::ClientSession(const Scope& scope, ClientSession::~ClientSession() {} SessionOptions ClientSession::Impl::MakeDefaultSessionOptions( - const string& target) { + const std::string& target) { SessionOptions options; options.env = Env::Default(); options.target = target; @@ -108,7 +108,7 @@ absl::Status ClientSession::Run(const RunOptions& run_options, const std::vector& run_outputs, std::vector* outputs, RunMetadata* run_metadata) const { - std::vector> feeds; + std::vector> feeds; feeds.reserve(inputs.size()); for (auto const& feed : inputs) { TF_RETURN_IF_ERROR(feed.second.status); @@ -117,12 +117,12 @@ absl::Status ClientSession::Run(const RunOptions& run_options, std::forward_as_tuple(feed.second.tensor)); } - std::vector output_tensor_names; + std::vector output_tensor_names; output_tensor_names.reserve(fetch_outputs.size()); for (auto const& output : fetch_outputs) { output_tensor_names.push_back(output.name()); } - std::vector target_node_names; + std::vector target_node_names; target_node_names.reserve(run_outputs.size()); for (auto const& output : run_outputs) { target_node_names.push_back(output.node()->name()); @@ -138,17 +138,17 @@ absl::Status ClientSession::Run( const std::vector& run_outputs, std::vector* outputs, RunMetadata* run_metadata, const thread::ThreadPoolOptions& threadpool_options) const { - std::vector> feeds; + std::vector> feeds; for (auto const& feed : inputs) { TF_RETURN_IF_ERROR(feed.second.status); feeds.emplace_back(feed.first.name(), feed.second.tensor); } - std::vector output_tensor_names; + std::vector output_tensor_names; output_tensor_names.reserve(fetch_outputs.size()); for (auto const& output : fetch_outputs) { output_tensor_names.push_back(output.name()); } - std::vector target_node_names; + std::vector target_node_names; 
target_node_names.reserve(run_outputs.size()); for (auto const& output : run_outputs) { target_node_names.push_back(output.node()->name()); diff --git a/tensorflow/cc/client/client_session.h b/tensorflow/cc/client/client_session.h index 9dc790d0171528..bf5cf8b2c6c371 100644 --- a/tensorflow/cc/client/client_session.h +++ b/tensorflow/cc/client/client_session.h @@ -65,7 +65,7 @@ class ClientSession { /// Create a new session to evaluate the graph contained in `scope` by /// connecting to the TensorFlow runtime specified by `target`. - ClientSession(const Scope& scope, const string& target); + ClientSession(const Scope& scope, const std::string& target); /// Same as above, but use the empty string ("") as the target specification. explicit ClientSession(const Scope& scope); diff --git a/tensorflow/cc/framework/cc_op_gen_util.cc b/tensorflow/cc/framework/cc_op_gen_util.cc index 45c88283a47a6c..048378e68f4525 100644 --- a/tensorflow/cc/framework/cc_op_gen_util.cc +++ b/tensorflow/cc/framework/cc_op_gen_util.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/cc/framework/cc_op_gen_util.h" -#include #include #include #include @@ -29,6 +28,7 @@ limitations under the License. 
#include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/statusor.h" +#include "absl/strings/ascii.h" #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -107,10 +107,10 @@ string ToGuard(absl::string_view path) { string guard; guard.reserve(path.size() + 1); // + 1 -> trailing _ for (const char c : path) { - if (c >= 'A' && c <= 'Z') { + if (absl::ascii_isupper(c)) { guard += c; - } else if (c >= 'a' && c <= 'z') { - guard += c + 'A' - 'a'; + } else if (absl::ascii_islower(c)) { + guard += absl::ascii_toupper(c); } else { guard += '_'; } @@ -306,7 +306,7 @@ string ToCamelCase(absl::string_view str) { } else if (c == joiner) { cap = true; } else if (cap) { - result += toupper(c); + result += absl::ascii_toupper(c); cap = false; } else { result += c; diff --git a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc index dcac1e4c0373bd..cd332ed1791849 100644 --- a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc +++ b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc @@ -42,7 +42,7 @@ namespace tensorflow { namespace cc_op { namespace { -string DefaultValue(OpDef_AttrDef attr) { +std::string DefaultValue(OpDef_AttrDef attr) { static const auto* attr_default_value_map = new absl::flat_hash_map{ @@ -80,19 +80,19 @@ string DefaultValue(OpDef_AttrDef attr) { return std::string(entry->second); } -string WriteClassFuzzDef(const OpInfo& op_info) { - string class_signature_str = absl::Substitute( +std::string WriteClassFuzzDef(const OpInfo& op_info) { + std::string class_signature_str = absl::Substitute( "class Fuzz$0 : public FuzzSession<$1> {\n", op_info.op_name, absl::StrJoin(op_info.graph_op_def.input_arg(), ", ", - [](string* out, const auto arg) { + [](std::string* out, const auto arg) { absl::StrAppend(out, "Tensor"); if (ArgIsList(arg)) absl::StrAppend(out, ", Tensor"); })); - string build_graph_body = absl::StrCat( + std::string 
build_graph_body = absl::StrCat( absl::StrJoin( op_info.graph_op_def.input_arg(), "", - [op_info](string* out, const OpDef_ArgDef arg) { + [op_info](std::string* out, const OpDef_ArgDef arg) { std::string type = "DT_UINT8"; if (arg.type() != DT_INVALID) { @@ -130,7 +130,7 @@ string WriteClassFuzzDef(const OpInfo& op_info) { } }), absl::StrJoin(op_info.graph_op_def.attr(), "", - [op_info](string* out, const OpDef_AttrDef attr) { + [op_info](std::string* out, const OpDef_AttrDef attr) { if (op_info.inferred_input_attrs.count(attr.name()) == 0 && !attr.has_default_value()) { @@ -139,22 +139,22 @@ string WriteClassFuzzDef(const OpInfo& op_info) { } })); - string constructor_call_str = absl::Substitute( + std::string constructor_call_str = absl::Substitute( " tensorflow::ops::$0(scope.WithOpName(\"output\")$1);\n", op_info.op_name, absl::StrCat( op_info.api_def.arg_order().empty() ? absl::StrJoin(op_info.api_def.in_arg(), "", - [](string* out, const auto api_def_arg) { + [](std::string* out, const auto api_def_arg) { strings::StrAppend(out, ", ", api_def_arg.name()); }) : absl::StrJoin(op_info.api_def.arg_order(), "", - [](string* out, const auto name) { + [](std::string* out, const auto name) { strings::StrAppend(out, ", ", name); }), absl::StrJoin(op_info.graph_op_def.attr(), "", - [op_info](string* out, const OpDef_AttrDef attr) { + [op_info](std::string* out, const OpDef_AttrDef attr) { if (op_info.inferred_input_attrs.count(attr.name()) == 0 && !attr.has_default_value()) { @@ -162,20 +162,20 @@ string WriteClassFuzzDef(const OpInfo& op_info) { } }))); - string fuzz_impl_signature_str = absl::Substitute( + std::string fuzz_impl_signature_str = absl::Substitute( " void FuzzImpl($0) final {\n", absl::StrJoin( op_info.graph_op_def.input_arg(), ", ", - [](string* out, const auto arg) { + [](std::string* out, const auto arg) { strings::StrAppend(out, "const Tensor& ", arg.name(), "_0"); if (ArgIsList(arg)) strings::StrAppend(out, ", const Tensor& ", arg.name(), "_1"); 
})); - string run_inputs_str = absl::Substitute( + std::string run_inputs_str = absl::Substitute( " RunInputs({$0});\n", absl::StrJoin(op_info.graph_op_def.input_arg(), ", ", - [](string* out, const auto arg) { + [](std::string* out, const auto arg) { if (ArgIsList(arg)) { strings::StrAppend( out, "{\"", arg.name(), "\", ", arg.name(), "_0}, ", @@ -186,7 +186,7 @@ string WriteClassFuzzDef(const OpInfo& op_info) { } })); - string fuzz_class_def = strings::StrCat( + std::string fuzz_class_def = strings::StrCat( class_signature_str, " void BuildGraph(const Scope& scope) override {\n", build_graph_body, constructor_call_str, " }\n", fuzz_impl_signature_str, run_inputs_str, " }\n", "};\n"); @@ -194,24 +194,24 @@ string WriteClassFuzzDef(const OpInfo& op_info) { return fuzz_class_def; } -string WriteFuzzTest(const OpInfo& op_info) { +std::string WriteFuzzTest(const OpInfo& op_info) { return absl::Substitute( "FUZZ_TEST_F(Fuzz$0, Fuzz).WithDomains($1);\n", op_info.op_name, absl::StrJoin(op_info.graph_op_def.input_arg(), ", ", - [](string* out, const auto arg) { + [](std::string* out, const auto arg) { absl::StrAppend(out, "AnyTensor()"); if (ArgIsList(arg)) absl::StrAppend(out, ", AnyTensor()"); })); } -string FuzzerFileStart() { - const string fuzz_namespace_begin = R"namespace( +std::string FuzzerFileStart() { + const std::string fuzz_namespace_begin = R"namespace( namespace tensorflow { namespace fuzzing { )namespace"; - const string fuzz_header = + const std::string fuzz_header = absl::StrCat(R"include(// This file is MACHINE GENERATED! Do not edit. 
#include "tensorflow/cc/ops/const_op.h" @@ -224,8 +224,8 @@ namespace fuzzing { return fuzz_header; } -string FuzzerFileEnd() { - const string fuzz_footer = R"footer( +std::string FuzzerFileEnd() { + const std::string fuzz_footer = R"footer( } // namespace fuzzing } // namespace tensorflow )footer"; @@ -258,7 +258,7 @@ bool OpFuzzingIsOk(const OpInfo& op_info) { } // TODO(unda) : zero input ops - std::set zero_input_ops = {"Placeholder", "ImmutableConst"}; + std::set zero_input_ops = {"Placeholder", "ImmutableConst"}; if (zero_input_ops.find(op_info.op_name) != zero_input_ops.end()) { std::cout << "NOT fuzzing: " << op_info.graph_op_def.name() << " takes zero inputs.\n"; @@ -266,19 +266,19 @@ bool OpFuzzingIsOk(const OpInfo& op_info) { } // TODO(unda, 253431636): constrained kernel - std::set constrained_kernel = {"Diag", - "DiagPart", - "GatherNd", - "GatherV2", - "QuantizeAndDequantizeV2", - "QuantizeAndDequantizeV3", - "QuantizeAndDequantizeV4", - "QuantizeAndDequantizeV4Grad", - "QuantizedConcat", - "QuantizedInstanceNorm", - "QuantizedReshape", - "ScatterNd", - "TensorScatterUpdate"}; + std::set constrained_kernel = {"Diag", + "DiagPart", + "GatherNd", + "GatherV2", + "QuantizeAndDequantizeV2", + "QuantizeAndDequantizeV3", + "QuantizeAndDequantizeV4", + "QuantizeAndDequantizeV4Grad", + "QuantizedConcat", + "QuantizedInstanceNorm", + "QuantizedReshape", + "ScatterNd", + "TensorScatterUpdate"}; // TODO(unda, b/253431636): constrained kernel if (constrained_kernel.find(op_info.op_name) != constrained_kernel.end()) { @@ -297,7 +297,7 @@ bool OpFuzzingIsOk(const OpInfo& op_info) { } } - std::set unhandled_attr_types = { + std::set unhandled_attr_types = { "list(type)", "func", "float", "bool", "tensor", "list(string)", "list(bool)", "list(shape)", "list(tensor)", "list(attr)"}; @@ -321,7 +321,7 @@ bool OpFuzzingIsOk(const OpInfo& op_info) { return true; } -string WriteSingleFuzzer(const OpInfo& op_info, bool is_fuzzable) { +std::string WriteSingleFuzzer(const 
OpInfo& op_info, bool is_fuzzable) { return absl::StrCat( FuzzerFileStart(), is_fuzzable ? WriteClassFuzzDef(op_info) : "", is_fuzzable ? WriteFuzzTest(op_info) : "", FuzzerFileEnd()); diff --git a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h index c11c9635d6d149..9dfee93e55e2e1 100644 --- a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h +++ b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h @@ -25,7 +25,7 @@ namespace tensorflow { namespace cc_op { // String with single fuzzer file content. -string WriteSingleFuzzer(const OpInfo& op_info, bool is_fuzzable); +std::string WriteSingleFuzzer(const OpInfo& op_info, bool is_fuzzable); // Do we have all we need to create a fuzzer bool OpFuzzingIsOk(const OpInfo& op_info); diff --git a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc index f4a1eb642557de..6da6e2af6c3445 100644 --- a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc +++ b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc @@ -39,8 +39,9 @@ namespace tensorflow { namespace cc_op { namespace { -void WriteAllFuzzers(string root_location, std::vector api_def_dirs, - std::vector op_names) { +void WriteAllFuzzers(std::string root_location, + std::vector api_def_dirs, + std::vector op_names) { OpList ops; absl::StatusOr api_def_map = LoadOpsAndApiDefs(ops, false, api_def_dirs); @@ -60,7 +61,7 @@ void WriteAllFuzzers(string root_location, std::vector api_def_dirs, continue; } - OpInfo op_info(op_def, *api_def, std::vector()); + OpInfo op_info(op_def, *api_def, std::vector()); status.Update(env->NewWritableFile( root_location + "/" + op_def.name() + "_fuzz.cc", &fuzz_file)); status.Update( @@ -87,9 +88,9 @@ int main(int argc, char* argv[]) { for (int i = 1; i < argc; ++i) { fprintf(stdout, "Arg %d = %s\n", i, argv[i]); } - std::vector api_def_srcs = tensorflow::str_util::Split( + std::vector api_def_srcs = tensorflow::str_util::Split( 
argv[2], ",", tensorflow::str_util::SkipEmpty()); - std::vector op_names = tensorflow::str_util::Split( + std::vector op_names = tensorflow::str_util::Split( argv[3], ",", tensorflow::str_util::SkipEmpty()); tensorflow::cc_op::WriteAllFuzzers(argv[1], api_def_srcs, op_names); return 0; diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index 357515a5dccb00..f3c3fd045a3d6f 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -218,9 +218,9 @@ REGISTER_GRADIENT_OP("GatherNd", GatherNdGrad); absl::Status CheckNumericsGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - string message; + std::string message; TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message)); - string err_msg = absl::StrCat( + std::string err_msg = absl::StrCat( "Not a number (NaN) or infinity (Inf) values detected in gradient. ", message); grad_outputs->push_back(CheckNumerics(scope, grad_inputs[0], err_msg)); @@ -411,7 +411,7 @@ REGISTER_GRADIENT_OP("DepthToSpace", DepthToSpaceGrad); absl::Status MirrorPadGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - string mode; + std::string mode; TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode)); grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad( scope, grad_inputs[0], op.input(1), mode)); @@ -424,7 +424,7 @@ REGISTER_GRADIENT_OP("MirrorPad", MirrorPadGrad); absl::Status MirrorPadGradGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - string mode; + std::string mode; TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode)); grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode)); grad_outputs->push_back(NoGradient()); diff --git a/tensorflow/cc/gradients/image_grad.cc b/tensorflow/cc/gradients/image_grad.cc index 
77e2a3bfc38476..deb90eec264ee7 100644 --- a/tensorflow/cc/gradients/image_grad.cc +++ b/tensorflow/cc/gradients/image_grad.cc @@ -95,7 +95,7 @@ absl::Status ScaleAndTranslateGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - string kernel_type; + std::string kernel_type; TF_RETURN_IF_ERROR( GetNodeAttr(op.node()->attrs(), "kernel_type", &kernel_type)); bool antialias; @@ -117,7 +117,7 @@ absl::Status CropAndResizeGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { DataType input_type; - string method; + std::string method; TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "method", &method)); TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "T", &input_type)); auto image_shape = Shape(scope, op.input(0)); diff --git a/tensorflow/cc/gradients/image_grad_test.cc b/tensorflow/cc/gradients/image_grad_test.cc index f7a39f39cfc42a..b77f5512237024 100644 --- a/tensorflow/cc/gradients/image_grad_test.cc +++ b/tensorflow/cc/gradients/image_grad_test.cc @@ -203,7 +203,7 @@ class ScaleAndTranslateGradTest : public ::testing::Test { template void MakeOp(const Tensor& x_data, const Input& y_shape, Input scale, - Input translation, const string& kernel_type, bool antialias, + Input translation, const std::string& kernel_type, bool antialias, Output* x, Output* y) { *x = Const(scope_, x_data); *y = ScaleAndTranslate(scope_, *x, y_shape, scale, translation, @@ -216,7 +216,7 @@ class ScaleAndTranslateGradTest : public ::testing::Test { template void TestScaleAndTranslate(const TensorShape x_shape, const int out_height, const int out_width, Input scale, - Input translation, const string& kernel_type, + Input translation, const std::string& kernel_type, bool antialias) { Tensor x_data = MakeData(x_shape); Output x, y; diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index bf6f509c21ee8a..c785af15f95447 100644 --- 
a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -1070,8 +1070,8 @@ absl::Status MatMulGradHelper(const Scope& scope, const bool is_batch, absl::Status MatMulGradCommon(const Scope& scope, const Operation& op, const bool is_batch, const std::vector& grad_inputs, - const string& attr_adj_x, - const string& attr_adj_y, + const std::string& attr_adj_x, + const std::string& attr_adj_y, std::vector* grad_outputs) { auto a = op.input(0); auto b = op.input(1); diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 34c0a8fd54b4c4..6309080492c1da 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -54,7 +54,7 @@ absl::Status SoftmaxGrad(const Scope& scope, const Operation& op, REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad); bool IsZero(const Scope& scope, const Output& grad) { - string op_type_name = grad.op().node()->type_string(); + std::string op_type_name = grad.op().node()->type_string(); if (op_type_name == "ZerosLike" || op_type_name == "Zeros") { return true; } @@ -204,7 +204,7 @@ REGISTER_GRADIENT_OP("L2Loss", L2LossGrad); absl::Status BiasAddGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - string data_format; + std::string data_format; TF_RETURN_IF_ERROR( GetNodeAttr(op.output(0).node()->attrs(), "data_format", &data_format)); auto dx_1 = @@ -218,9 +218,9 @@ REGISTER_GRADIENT_OP("BiasAdd", BiasAddGradHelper); absl::Status Conv2DGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - string data_format; - string padding; - std::vector strides; + std::string data_format; + std::string padding; + std::vector strides; bool use_cudnn_on_gpu; auto attrs = op.output(0).node()->attrs(); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); @@ -245,10 +245,10 @@ REGISTER_GRADIENT_OP("Conv2D", Conv2DGrad); absl::Status 
MaxPoolGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - string data_format; - string padding; - std::vector strides; - std::vector ksize; + std::string data_format; + std::string padding; + std::vector strides; + std::vector ksize; auto attrs = op.output(0).node()->attrs(); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); @@ -265,8 +265,8 @@ REGISTER_GRADIENT_OP("MaxPool", MaxPoolGradHelper); absl::Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - string data_format; - string padding; + std::string data_format; + std::string padding; auto attrs = op.output(0).node()->attrs(); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); @@ -283,10 +283,10 @@ REGISTER_GRADIENT_OP("MaxPoolV2", MaxPoolGradV2Helper); absl::Status MaxPool3DGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - std::vector ksize; - std::vector strides; - string padding; - string data_format; + std::vector ksize; + std::vector strides; + std::string padding; + std::string data_format; auto attrs = op.output(0).node()->attrs(); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); @@ -304,10 +304,10 @@ REGISTER_GRADIENT_OP("MaxPool3D", MaxPool3DGradHelper); absl::Status AvgPoolGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - std::vector ksize; - std::vector strides; - string padding; - string data_format; + std::vector ksize; + std::vector strides; + std::string padding; + std::string data_format; auto attrs = op.output(0).node()->attrs(); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); 
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); @@ -325,10 +325,10 @@ REGISTER_GRADIENT_OP("AvgPool", AvgPoolGradHelper); absl::Status AvgPool3DGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - std::vector ksize; - std::vector strides; - string padding; - string data_format; + std::vector ksize; + std::vector strides; + std::string padding; + std::string data_format; auto attrs = op.output(0).node()->attrs(); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc index 56ac37e86b7168..1d23f9d87e2d7d 100644 --- a/tensorflow/cc/training/queue_runner.cc +++ b/tensorflow/cc/training/queue_runner.cc @@ -17,7 +17,9 @@ limitations under the License. #include #include +#include #include +#include #include #include "absl/log/log.h" @@ -70,7 +72,7 @@ absl::Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) { queue_runner_def.enqueue_op_name().begin(), queue_runner_def.enqueue_op_name().end()); size_t op_names_size = enqueue_op_names_.size(); - if (op_names_size > kint32max) { + if (op_names_size > std::numeric_limits::max()) { return absl::Status(absl::StatusCode::kInvalidArgument, "Enqueue ops to run cannot exceed kint32max"); } diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index a8dedd0e40997a..1722da0d390915 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -51,8 +51,8 @@ cc_library( "@local_xla//xla:status_macros", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/backends/cpu/runtime:convolution_lib", - "@local_xla//xla/backends/cpu/runtime:dot_lib", + "@local_xla//xla/backends/cpu/runtime:convolution_dims", + "@local_xla//xla/backends/cpu/runtime:dot_dims", "@local_xla//xla/backends/cpu/runtime:thunk_proto_cc", 
"@local_xla//xla/service/cpu:executable_proto_cc", "@local_xla//xla/tsl/platform:statusor", @@ -96,6 +96,7 @@ cc_library( ":thunk_proto_execution_deserializer", "//tensorflow/compiler/tf2xla", "//tensorflow/compiler/tf2xla:allocator", + "//tensorflow/compiler/tf2xla:encoded_buffer_allocation_info", "//tensorflow/compiler/tf2xla:mlir_tf2xla", # fixdeps: keep "//tensorflow/compiler/tf2xla:tf2xla_proto_cc", "//tensorflow/compiler/tf2xla:tf2xla_util", @@ -119,12 +120,13 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//llvm:Target", - "@local_xla//xla:cpu_function_runtime", "@local_xla//xla:debug_options_flags", "@local_xla//xla:shape_util", "@local_xla//xla:status_macros", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/backends/cpu:buffer_allocation_info", + "@local_xla//xla/backends/cpu:buffer_allocation_info_util", "@local_xla//xla/backends/cpu/codegen:symbol_name_util", "@local_xla//xla/backends/cpu/runtime:thunk_proto_cc", "@local_xla//xla/backends/cpu/runtime:thunk_proto_serdes", @@ -132,7 +134,6 @@ cc_library( "@local_xla//xla/client:compile_only_client", "@local_xla//xla/hlo/builder:xla_computation", "@local_xla//xla/service:compiler", - "@local_xla//xla/service/cpu:buffer_info_util", "@local_xla//xla/service/cpu:cpu_aot_compilation_result", "@local_xla//xla/service/cpu:cpu_compiler", "@local_xla//xla/service/cpu:cpu_executable", @@ -155,7 +156,6 @@ tf_cc_test( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", # fixdeps: keep - "@local_xla//xla:cpu_function_runtime", "@local_xla//xla:shape_util", "@local_xla//xla/service/cpu:cpu_aot_compilation_result", ] + if_llvm_x86_available([ diff --git a/tensorflow/compiler/aot/aot_only_var_handle_op.cc b/tensorflow/compiler/aot/aot_only_var_handle_op.cc index 86666b073b0f71..f6293e0a2063bb 100644 --- a/tensorflow/compiler/aot/aot_only_var_handle_op.cc +++ 
b/tensorflow/compiler/aot/aot_only_var_handle_op.cc @@ -31,7 +31,7 @@ class XlaAotOnlyVarHandleOp : public XlaOpKernel { void Compile(XlaOpKernelContext* context) override; private: - string name_; + std::string name_; }; XlaAotOnlyVarHandleOp::XlaAotOnlyVarHandleOp(OpKernelConstruction* c) diff --git a/tensorflow/compiler/aot/benchmark.cc b/tensorflow/compiler/aot/benchmark.cc index 43b9c06418c2e1..ee4af4ca65a20f 100644 --- a/tensorflow/compiler/aot/benchmark.cc +++ b/tensorflow/compiler/aot/benchmark.cc @@ -37,10 +37,10 @@ namespace benchmark { // // TODO(b/33546473): Refactor tensorflow::Env::NowMicros() so that we can re-use // the implementation without pulling in all of the Env dependencies. -static uint64 NowMicros() { +static uint64_t NowMicros() { struct timeval tv; gettimeofday(&tv, nullptr); - return static_cast(tv.tv_sec) * 1000000 + tv.tv_usec; + return static_cast(tv.tv_sec) * 1000000 + tv.tv_usec; } void DumpStatsToStdout(const Stats& stats) { diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 054a7fdde77bd9..87cb051b75df63 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -42,13 +43,14 @@ limitations under the License. 
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h" #include "tensorflow/compiler/aot/thunk_proto_execution_deserializer.h" #include "tensorflow/compiler/tf2xla/allocator.h" +#include "tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "xla/backends/cpu/buffer_allocation_info.h" +#include "xla/backends/cpu/buffer_allocation_info_util.h" #include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/backends/cpu/runtime/thunk_proto_serdes.h" -#include "xla/cpu_function_runtime.h" #include "xla/debug_options_flags.h" -#include "xla/service/cpu/buffer_info_util.h" #include "xla/service/cpu/cpu_aot_compilation_result.h" #include "xla/service/cpu/cpu_executable.h" #include "xla/shape.h" @@ -65,43 +67,37 @@ namespace tfcompile { namespace { -using BufferInfo = xla::cpu_function_runtime::BufferInfo; - -bool IsAlpha(char c) { - return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z'); -} - -bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); } +using xla::cpu::BufferAllocationInfo; // Convert an XLA type into a C++ type. 
-absl::Status XLATypeToCpp(xla::PrimitiveType type, string* str) { +absl::Status XLATypeToCpp(xla::PrimitiveType type, std::string* str) { switch (type) { case xla::PRED: *str = "bool"; break; case xla::S8: - *str = "tensorflow::int8"; + *str = "int8_t"; break; case xla::S16: - *str = "tensorflow::int16"; + *str = "int16_t"; break; case xla::S32: - *str = "tensorflow::int32"; + *str = "int32_t"; break; case xla::S64: *str = "int64_t"; break; case xla::U8: - *str = "tensorflow::uint8"; + *str = "uint8_t"; break; case xla::U16: - *str = "tensorflow::uint16"; + *str = "uint16_t"; break; case xla::U32: - *str = "tensorflow::uint32"; + *str = "uint32_t"; break; case xla::U64: - *str = "tensorflow::uint64"; + *str = "uint64_t"; break; case xla::F32: *str = "float"; @@ -117,33 +113,36 @@ absl::Status XLATypeToCpp(xla::PrimitiveType type, string* str) { } // Returns the sum of the size of each buffer in `buffer_infos`. -size_t TotalBufferBytes(const std::vector& buffer_infos) { - return std::accumulate(buffer_infos.begin(), buffer_infos.end(), size_t{0}, - [](size_t size, const BufferInfo& buffer_info) { - return size + buffer_info.size(); - }); +size_t TotalBufferBytes(absl::Span buffer_infos) { + return std::accumulate( + buffer_infos.begin(), buffer_infos.end(), size_t{0}, + [](size_t size, const BufferAllocationInfo& buffer_info) { + return size + buffer_info.size(); + }); } -// Returns a vector of BufferInfo instances in `buffer_infos` that are entry -// parameter buffers. -std::vector ExtractEntryParamBufferInfos( - const std::vector& buffer_infos) { - std::vector result; +// Returns a vector of BufferAllocationInfo instances in `buffer_infos` that are +// entry parameter buffers. 
+std::vector ExtractEntryParamBufferAllocationInfos( + absl::Span buffer_infos) { + std::vector result; std::copy_if(buffer_infos.begin(), buffer_infos.end(), - std::back_inserter(result), [](const BufferInfo& buffer_info) { + std::back_inserter(result), + [](const BufferAllocationInfo& buffer_info) { return buffer_info.is_entry_parameter(); }); return result; } -// Returns a vector of BufferInfo instances in `buffer_infos` that are temp -// buffers. -std::vector ExtractTempBufferInfos( - const std::vector& buffer_infos) { - std::vector result; +// Returns a vector of BufferAllocationInfo instances in `buffer_infos` that are +// temp buffers. +std::vector ExtractTempBufferAllocationInfos( + absl::Span buffer_infos) { + std::vector result; std::copy_if(buffer_infos.begin(), buffer_infos.end(), - std::back_inserter(result), [](const BufferInfo& buffer_info) { - return buffer_info.is_temp_buffer(); + std::back_inserter(result), + [](const BufferAllocationInfo& buffer_info) { + return buffer_info.is_temp(); }); return result; } @@ -152,11 +151,11 @@ std::vector ExtractTempBufferInfos( // are used to generate methods for args and results. 
absl::Status AddRewritesForShape( int i, const xla::Shape& shape, - std::vector>* rewrites) { - string type; + std::vector>* rewrites) { + std::string type; TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type)); - std::vector dim_vars; - string dim_sizes, indices; + std::vector dim_vars; + std::string dim_sizes, indices; int count = 1; if (shape.dimensions().size() == 0 || (shape.dimensions().size() == 1 && shape.dimensions(0) == 1)) { @@ -165,8 +164,8 @@ absl::Status AddRewritesForShape( } else { for (int dim = 0; dim < shape.dimensions().size(); ++dim) { dim_vars.push_back(absl::StrCat("size_t dim", dim)); - dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]"); - indices += absl::StrCat("[dim", dim, "]"); + absl::StrAppend(&dim_sizes, "[", shape.dimensions(dim), "]"); + absl::StrAppend(&indices, "[dim", dim, "]"); count *= shape.dimensions(dim); } } @@ -187,8 +186,9 @@ absl::Status AddRewritesForShape( // TODO(toddw): If this becomes a problem, we should be able to change the // algorithm to O(N) by using a state machine, e.g. regexps or a real // text-templating mechanism. -string RewriteWithName(const string& name, string code, - const std::vector>& rewrites) { +std::string RewriteWithName( + const std::string& name, std::string code, + const std::vector>& rewrites) { absl::StrReplaceAll(rewrites, &code); absl::StrReplaceAll({{"{{NAME}}", name}}, &code); return code; @@ -198,7 +198,7 @@ string RewriteWithName(const string& name, string code, absl::Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShapeProto& ps, const CompileResult& compile_result, - string* methods) { + std::string* methods) { const int num_args = ps.parameters_size(); // feed_size() + variable_size() is the maximum number of args as an // implementation may not create an argument for an unused variable. 
@@ -208,11 +208,11 @@ absl::Status GenArgMethods(const tf2xla::Config& config, config.variable_size(), ") and num_args(", num_args, ")"); } for (int i = 0; i < config.feed_size(); ++i) { - std::vector> rewrites; + std::vector> rewrites; TF_ASSIGN_OR_RETURN(xla::Shape shape, xla::Shape::FromProto(ps.parameters(i))); TF_RETURN_IF_ERROR(AddRewritesForShape(i, shape, &rewrites)); - const string code = R"( + const std::string code = R"( void set_arg{{NAME}}_data(const void* data) { set_arg_data({{I}}, data); } @@ -248,7 +248,7 @@ absl::Status GenArgMethods(const tf2xla::Config& config, // Generate methods for results (outputs). absl::Status GenResultMethods(const tf2xla::Config& config, const xla::ProgramShapeProto& ps, - string* methods) { + std::string* methods) { if (ps.result().element_type() != xla::TUPLE) { // The XlaCompiler we use to build the xla computation always generates a // tuple result, and we rely on this to simplify code generation. @@ -267,11 +267,11 @@ absl::Status GenResultMethods(const tf2xla::Config& config, ps.result().tuple_shapes_size(), ")"); } for (int i = 0; i < config.fetch_size(); ++i) { - std::vector> rewrites; + std::vector> rewrites; TF_ASSIGN_OR_RETURN(xla::Shape shape, xla::Shape::FromProto(ps.result().tuple_shapes(i))); TF_RETURN_IF_ERROR(AddRewritesForShape(i, shape, &rewrites)); - string code = R"( + std::string code = R"( {{TYPE}}* result{{NAME}}_data() { return static_cast<{{TYPE}}*>(result_data({{I}})); } @@ -304,14 +304,14 @@ absl::Status GenResultMethods(const tf2xla::Config& config, // Generate methods for variables. 
absl::Status GenVariableMethods(const tf2xla::Config& config, const xla::ProgramShapeProto& ps, - string* methods) { + std::string* methods) { const int num_args = ps.parameters_size(); for (int i = config.feed_size(); i < num_args; ++i) { - std::vector> rewrites; + std::vector> rewrites; TF_ASSIGN_OR_RETURN(xla::Shape shape, xla::Shape::FromProto(ps.parameters(i))); TF_RETURN_IF_ERROR(AddRewritesForShape(i, shape, &rewrites)); - const string code = R"( + const std::string code = R"( void set_var_{{NAME}}_data({{MAYBE_CONST}}{{TYPE}}* data) { set_arg_data({{I}}, data); } @@ -345,7 +345,8 @@ absl::Status GenVariableMethods(const tf2xla::Config& config, } // Generate shape infos for args (inputs). -absl::Status GenArgShapeInfos(const xla::ProgramShapeProto& ps, string* infos) { +absl::Status GenArgShapeInfos(const xla::ProgramShapeProto& ps, + std::string* infos) { for (int i = 0; i < ps.parameters_size(); ++i) { const xla::ShapeProto& shape = ps.parameters(i); if (shape.element_type() == xla::TUPLE) { @@ -383,7 +384,7 @@ absl::Status GenArgShapeInfos(const xla::ProgramShapeProto& ps, string* infos) { // Generate shape infos for results. absl::Status GenResultShapeInfos(const xla::ProgramShapeProto& ps, - string* infos) { + std::string* infos) { if (ps.result().element_type() != xla::TUPLE) { return absl::InternalError("codegen requires the XLA result to be a tuple"); } @@ -417,7 +418,7 @@ absl::Status GenResultShapeInfos(const xla::ProgramShapeProto& ps, // tf2xla::{Feed,Fetch,Variable}. Each feed or fetch name results in a C-style // string literal in the array, with nullptr terminating the array. template -string GenNameToIndexCode(const T& entries, bool generate) { +std::string GenNameToIndexCode(const T& entries, bool generate) { // No need for a static array if we're not supposed to generate the data. 
if (!generate) { return "{\n return nullptr;\n }"; @@ -432,7 +433,7 @@ string GenNameToIndexCode(const T& entries, bool generate) { end = i; } // Emit string literals up to the last non-empty name. - string code = "{\n static const char* kNames[] = {"; + std::string code = "{\n static const char* kNames[] = {"; for (int i = 0; i < end; ++i) { if (i > 0) { code += ", "; @@ -471,25 +472,24 @@ absl::Status ValidateFeedFetchCppNames(const tf2xla::Config& config) { } // Returns a list of C++ expressions that, when executed, will construct the -// BufferInfo instances in `buffer_infos`. -std::vector BufferInfosToCppExpression( - const std::vector& buffer_infos) { - std::vector buffer_infos_as_strings; - std::transform(buffer_infos.begin(), buffer_infos.end(), - std::back_inserter(buffer_infos_as_strings), - [](const BufferInfo& buffer_info) { - xla::cpu_function_runtime::EncodedBufferInfo encoded = - buffer_info.Encode(); - auto param_to_str = [](uint32_t param) -> std::string { - return param == ~0U ? "~0U" : absl::StrCat(param, "U"); - }; - return absl::StrCat( - "::xla::cpu_function_runtime::BufferInfo(" - "::xla::cpu_function_runtime::EncodedBufferInfo{", - encoded.packed_kind_and_size, "ULL, ", - param_to_str(encoded.entry_param_number), ", ", - param_to_str(encoded.result_param_number), "})"); - }); +// BufferAllocationInfo instances in `buffer_infos`. +std::vector BufferAllocationInfosToCppExpression( + absl::Span buffer_infos) { + std::vector buffer_infos_as_strings; + absl::c_transform( + buffer_infos, std::back_inserter(buffer_infos_as_strings), + [](const BufferAllocationInfo& buffer_info) { + xla::cpu::EncodedBufferAllocationInfo encoded(buffer_info); + auto param_to_str = [](int32_t param) -> std::string { + return param == -1 ? 
"~0U" : absl::StrCat(param, "U"); + }; + return absl::StrCat( + "static_cast<::xla::cpu::BufferAllocationInfo>(" + "::xla::cpu::EncodedBufferAllocationInfo{", + encoded.packed_kind_and_size, "ULL, ", + param_to_str(encoded.entry_param_number), ", ", + param_to_str(encoded.result_number), "})"); + }); return buffer_infos_as_strings; } @@ -659,8 +659,8 @@ absl::Status ExtendRewrites( const std::string function_declarations_from_obj_files, GenFunctionDeclarations(absl::MakeSpan(entry_point_symbols))); - const int64_t buffer_infos_size = aot_thunks->buffer_infos().size(); - const std::optional temp_allocation_index = + int64_t buffer_infos_size = aot_thunks->buffer_allocation_infos().size(); + std::optional temp_allocation_index = aot_thunks->temp_allocation_index(); if (temp_allocation_index.has_value() && (*temp_allocation_index < 0 || @@ -670,45 +670,36 @@ absl::Status ExtendRewrites( " is outside the range of temp sizes: [0,", buffer_infos_size, ")")); } - const bool xla_cpu_multi_thread_eigen = - xla::GetDebugOptionsFromFlags().xla_cpu_multi_thread_eigen(); - std::vector runtime_specific_includes = {R"( #include "absl/log/check.h" +#include "absl/synchronization/blocking_counter.h" #include "xla/backends/cpu/runtime/kernel_c_api.h" #include "xla/types.h")"}; if (HasThunkKind(aot_thunks->proto().thunk_sequence(), xla::cpu::ThunkProto::kDotThunk)) { - if (xla_cpu_multi_thread_eigen) { - runtime_specific_includes.push_back( - R"(#include "xla/service/cpu/runtime_matmul.h")"); - } runtime_specific_includes.push_back( - R"(#include "xla/service/cpu/runtime_single_threaded_matmul.h")"); + R"(#include "xla/backends/cpu/runtime/dot_lib.h")"); } if (HasThunkKind(aot_thunks->proto().thunk_sequence(), xla::cpu::ThunkProto::kConvolutionThunk)) { - if (xla_cpu_multi_thread_eigen) { - runtime_specific_includes.push_back( - R"(#include "xla/service/cpu/runtime_conv2d.h")"); - } - runtime_specific_includes.push_back( - R"(#include 
"xla/service/cpu/runtime_single_threaded_conv2d.h")"); + R"(#include "absl/synchronization/notification.h")"); + runtime_specific_includes.push_back( + R"(#include "xla/backends/cpu/runtime/convolution_lib.h")"); } if (HasThunkKind(aot_thunks->proto().thunk_sequence(), xla::cpu::ThunkProto::kSortThunk)) { runtime_specific_includes.push_back( - R"(#include "xla/service/cpu/runtime_key_value_sort.h")"); + R"(#include "xla/backends/cpu/runtime/sort_lib.h")"); } if (HasThunkKind(aot_thunks->proto().thunk_sequence(), xla::cpu::ThunkProto::kTopKThunk)) { runtime_specific_includes.push_back( - R"(#include "xla/service/cpu/runtime_topk.h")"); + R"(#include "xla/backends/cpu/runtime/topk_lib.h")"); } TF_ASSIGN_OR_RETURN( @@ -834,31 +825,32 @@ absl::Status ExtendRewrites( absl::Status GenerateHeader( const CodegenOpts& opts, const tf2xla::Config& config, const CompileResult& compile_result, const MetadataResult& metadata_result, - const EmbeddedConstantBuffers& embedded_constant_buffers, string* header) { + const EmbeddedConstantBuffers& embedded_constant_buffers, + std::string* header) { TF_RETURN_IF_ERROR(ValidateConfig(config)); TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config)); - const std::vector& buffer_infos = - compile_result.aot->buffer_infos(); + absl::Span buffer_infos = + compile_result.aot->buffer_allocation_infos(); - const std::vector arg_index_table = - ::xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos); - const std::vector result_index_table = - ::xla::cpu::CreateResultIndexTableFromBufferInfos(buffer_infos); - std::vector buffer_infos_as_strings = - BufferInfosToCppExpression(buffer_infos); + const std::vector arg_index_table = + ::xla::cpu::CreateArgIndexTable(buffer_infos); + const std::vector result_index_table = + ::xla::cpu::CreateResultIndexTable(buffer_infos); + std::vector buffer_infos_as_strings = + BufferAllocationInfosToCppExpression(buffer_infos); // Compute sizes and generate methods. 
- std::vector buffer_infos_for_args = - ExtractEntryParamBufferInfos(buffer_infos); - std::vector buffer_infos_for_temps = - ExtractTempBufferInfos(buffer_infos); + std::vector buffer_infos_for_args = + ExtractEntryParamBufferAllocationInfos(buffer_infos); + std::vector buffer_infos_for_temps = + ExtractTempBufferAllocationInfos(buffer_infos); const xla::ProgramShapeProto& ps = compile_result.program_shape; - string methods_arg, methods_result, methods_variable; + std::string methods_arg, methods_result, methods_variable; TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg)); TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result)); TF_RETURN_IF_ERROR(GenVariableMethods(config, ps, &methods_variable)); - string arg_shape_infos, result_shape_infos; + std::string arg_shape_infos, result_shape_infos; TF_RETURN_IF_ERROR(GenArgShapeInfos(ps, &arg_shape_infos)); TF_RETURN_IF_ERROR( CheckEqual(ps.parameters_size(), arg_index_table.size(), @@ -868,29 +860,29 @@ absl::Status GenerateHeader( CheckEqual(ps.result().tuple_shapes_size(), result_index_table.size(), "Result number mismatch, proto vs. 
result_index_table")); TF_ASSIGN_OR_RETURN(auto program_shape, xla::ProgramShape::FromProto(ps)); - const size_t arg_bytes_aligned = tensorflow::AlignedBufferBytes( - buffer_infos_for_args.data(), buffer_infos_for_args.size(), - /*allocate_entry_params=*/true); + const size_t arg_bytes_aligned = + tensorflow::AlignedBufferBytes(buffer_infos_for_args, + /*allocate_entry_params=*/true); const size_t arg_bytes_total = TotalBufferBytes(buffer_infos_for_args); - const size_t temp_bytes_aligned = tensorflow::AlignedBufferBytes( - buffer_infos_for_temps.data(), buffer_infos_for_temps.size(), - /*allocate_entry_params=*/true); + const size_t temp_bytes_aligned = + tensorflow::AlignedBufferBytes(buffer_infos_for_temps, + /*allocate_entry_params=*/true); const size_t temp_bytes_total = TotalBufferBytes(buffer_infos_for_temps); // Create rewrite strings for namespace start and end. - string ns_start; - for (const string& n : opts.namespaces) { + std::string ns_start; + for (const std::string& n : opts.namespaces) { ns_start += absl::StrCat("namespace ", n, " {\n"); } ns_start += "\n"; - string ns_end("\n"); + std::string ns_end("\n"); for (int i = opts.namespaces.size() - 1; i >= 0; --i) { - const string& n = opts.namespaces[i]; + const std::string& n = opts.namespaces[i]; ns_end += absl::StrCat("} // end namespace ", n, "\n"); } // Generate metadata. - const string arg_names_code = + const std::string arg_names_code = GenNameToIndexCode(config.feed(), opts.gen_name_to_index); auto variable_copy = config.variable(); @@ -899,12 +891,12 @@ absl::Status GenerateHeader( var.set_name(var.node_name()); } } - const string variable_names_code = + const std::string variable_names_code = GenNameToIndexCode(variable_copy, opts.gen_name_to_index); - const string result_names_code = + const std::string result_names_code = GenNameToIndexCode(config.fetch(), opts.gen_name_to_index); - const string include_xla_data_proto = + const std::string include_xla_data_proto = opts.gen_program_shape ? 
R"(#include "xla/xla_data.pb.h")" : ""; @@ -980,7 +972,7 @@ class {{CLASS}} final : public tensorflow::{{COMPUTATION_CLASS_BASE}} { // Byte size of each argument buffer. There are kNumArgs entries. static const ::int64_t ArgSize(::tensorflow::int32 index) { - return BufferInfos()[ArgIndexToBufferIndex()[index]].size(); + return BufferAllocationInfos()[ArgIndexToBufferIndex()[index]].size(); } // Returns static data used to create an XlaCompiledCpuFunction. @@ -989,7 +981,7 @@ class {{CLASS}} final : public tensorflow::{{COMPUTATION_CLASS_BASE}} { XlaCompiledCpuFunction::StaticData* data = new XlaCompiledCpuFunction::StaticData; set_static_data_function_library_symbol_map(data, FunctionLibrarySymbolMap()); - set_static_data_buffer_infos(data, BufferInfos()); + set_static_data_buffer_infos(data, BufferAllocationInfos()); set_static_data_num_buffers(data, kNumBuffers); set_static_data_result_index_table(data, ResultIndexToBufferIndex()); set_static_data_num_results(data, kNumResults); @@ -1081,12 +1073,12 @@ class {{CLASS}} final : public tensorflow::{{COMPUTATION_CLASS_BASE}} { // Number of buffers for the compiled computation. static constexpr size_t kNumBuffers = {{NUM_BUFFERS}}; - static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() { - static const ::xla::cpu_function_runtime::BufferInfo - kBufferInfos[kNumBuffers] = { + static const ::xla::cpu::BufferAllocationInfo* BufferAllocationInfos() { + static const ::xla::cpu::BufferAllocationInfo + kBufferAllocationInfos[kNumBuffers] = { {{BUFFER_INFOS_AS_STRING}} }; - return kBufferInfos; + return kBufferAllocationInfos; } static const ::tensorflow::int32* ResultIndexToBufferIndex() { @@ -1153,7 +1145,7 @@ class {{CLASS}} final : public tensorflow::{{COMPUTATION_CLASS_BASE}} { } // The replacement strategy is naive, but good enough for our purposes. 
- std::vector> rewrites = { + std::vector> rewrites = { {"{{ARG_BYTES_ALIGNED}}", absl::StrCat(arg_bytes_aligned)}, {"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)}, {"{{ARG_NAMES_CODE}}", arg_names_code}, @@ -1192,10 +1184,10 @@ class {{CLASS}} final : public tensorflow::{{COMPUTATION_CLASS_BASE}} { return absl::OkStatus(); } -static string CreateUniqueIdentifier(const CodegenOpts& opts, - absl::string_view suffix) { - string result = "__tfcompile"; - for (const string& n : opts.namespaces) { +static std::string CreateUniqueIdentifier(const CodegenOpts& opts, + absl::string_view suffix) { + std::string result = "__tfcompile"; + for (const std::string& n : opts.namespaces) { absl::StrAppend(&result, "_", n); } @@ -1301,14 +1293,15 @@ absl::Status GenerateMetadata(const CodegenOpts& opts, return absl::OkStatus(); } -absl::Status ParseCppClass(const string& cpp_class, string* class_name, - std::vector* namespaces) { +absl::Status ParseCppClass(const std::string& cpp_class, + std::string* class_name, + std::vector* namespaces) { class_name->clear(); namespaces->clear(); if (cpp_class.empty()) { return errors::InvalidArgument("empty cpp_class: " + cpp_class); } - std::vector parts = absl::StrSplit(cpp_class, "::"); + std::vector parts = absl::StrSplit(cpp_class, "::"); if (parts.front().empty()) { // Allow a fully qualified name that starts with "::". parts.erase(parts.begin()); @@ -1341,11 +1334,11 @@ absl::Status ValidateCppIdent(absl::string_view ident, absl::string_view msg) { // implementation-defined characters`. We disallow those here to give // better error messages, at the expensive of being more restrictive than // the standard. 
- if (ident[0] != '_' && !IsAlpha(ident[0])) { + if (ident[0] != '_' && !absl::ascii_isalpha(ident[0])) { return errors::InvalidArgument("illegal leading char: ", msg); } for (size_t pos = 1; pos < ident.size(); ++pos) { - if (ident[pos] != '_' && !IsAlphaNum(ident[pos])) { + if (ident[pos] != '_' && !absl::ascii_isalnum(ident[pos])) { return errors::InvalidArgument("illegal char: ", msg); } } diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index 77300b0fde4e3d..ff7d96720b4eba 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -32,14 +32,14 @@ namespace tfcompile { // and the generated metadata object file. struct CodegenOpts { // The name of the generated C++ class, wrapping the generated function. - string class_name; + std::string class_name; // Target triple for the architecture we're targeting. - string target_triple; + std::string target_triple; // Namespaces specifies a list of C++ namespaces to add to the generated // header. If empty, all symbols will be in the global namespace. - std::vector namespaces; + std::vector namespaces; // If true, generate name-to-index data for Lookup{Arg,Result}Index methods. bool gen_name_to_index = false; @@ -62,27 +62,27 @@ struct CodegenOpts { struct MetadataResult { // These are top level "extern C" declarations that are expected to be visible // wherever program_shape_access_shim is emitted. - std::vector header_variable_decls; + std::vector header_variable_decls; // program_shape_access_shim is a C++ expression that constructs the // xla::ProgramShapeProto instance for the CompileResult passed to // GenerateMetadata. - string program_shape_access_shim; + std::string program_shape_access_shim; // hlo_profile_printer_data_access_shim is a C++ expression that constructs // the xla::HloProfilePrinterData instance for the CompileResult passed to // GenerateMetadata. 
If the xla::HloProfilePrinterData is null then this is a // C++ expression that evaluates to nullptr at runtime. // This is set only for AOT legacy. - string hlo_profile_printer_data_access_shim; + std::string hlo_profile_printer_data_access_shim; // cpu_executable_access_shim is a C++ expression that constructs // a protobuf required to construct a CpuExecutable. // This is set only for AOT thunks. - string cpu_executable_access_shim; + std::string cpu_executable_access_shim; // The contents of the object (".o") file. - string object_file_data; + std::string object_file_data; }; // Generates a set of constant buffers embedded into an object file. @@ -105,14 +105,16 @@ absl::Status GenerateMetadata(const CodegenOpts& opts, absl::Status GenerateHeader( const CodegenOpts& opts, const tf2xla::Config& config, const CompileResult& compile_result, const MetadataResult& metadata_result, - const EmbeddedConstantBuffers& embedded_constant_buffers, string* header); + const EmbeddedConstantBuffers& embedded_constant_buffers, + std::string* header); // ParseCppClass parses `cpp_class` into its `class_name` and `namespaces` // components. The syntax is [[::],...]. This // mirrors the C++ syntax for referring to a class, where multiple namespaces // may precede the class name, separated by double-colons. -absl::Status ParseCppClass(const string& cpp_class, string* class_name, - std::vector* namespaces); +absl::Status ParseCppClass(const std::string& cpp_class, + std::string* class_name, + std::vector* namespaces); // ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is // appended to error messages. diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index afa5a86af9ef47..ec0f336d87f716 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -24,7 +24,6 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "llvm/Support/TargetSelect.h" #include "tensorflow/compiler/aot/compile.h" -#include "xla/cpu_function_runtime.h" #include "xla/service/cpu/cpu_aot_compilation_result.h" #include "xla/shape_util.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -40,7 +39,7 @@ namespace tensorflow { namespace tfcompile { namespace { -using ::xla::cpu_function_runtime::BufferInfo; +using ::xla::cpu::BufferAllocationInfo; void ExpectErrorContains(const absl::Status& status, absl::string_view str) { EXPECT_NE(absl::OkStatus(), status); @@ -54,7 +53,7 @@ TEST(ValidateCppIdent, Simple) { TF_EXPECT_OK(ValidateCppIdent("_abc", "")); TF_EXPECT_OK(ValidateCppIdent("_abc123", "")); // Make sure we didn't skip a valid letter or digit - string ident; + std::string ident; for (char c = 'a'; c <= 'z'; c++) { ident.append(1, c); } @@ -79,18 +78,19 @@ TEST(ValidateCppIdent, Simple) { class ParseCppClassTest : public ::testing::Test { protected: - void ExpectOK(const string& cpp_class, const string& want_class_name, - const std::vector& want_namespaces) { - string class_name; - std::vector namespaces; + void ExpectOK(const std::string& cpp_class, + const std::string& want_class_name, + const std::vector& want_namespaces) { + std::string class_name; + std::vector namespaces; TF_EXPECT_OK(ParseCppClass(cpp_class, &class_name, &namespaces)); EXPECT_EQ(class_name, want_class_name); EXPECT_EQ(namespaces, want_namespaces); } - void ExpectFail(const string& cpp_class) { - string class_name; - std::vector namespaces; + void ExpectFail(const std::string& cpp_class) { + std::string class_name; + std::vector namespaces; EXPECT_NE(ParseCppClass(cpp_class, &class_name, &namespaces), absl::OkStatus()) << cpp_class; @@ -111,7 +111,7 @@ TEST_F(ParseCppClassTest, ParseOK) { ExpectOK("::_foo::MyClass", "MyClass", {"_foo"}); ExpectOK("::_foo::_MyClass", "_MyClass", {"_foo"}); // Make sure we didn't skip a valid letter or digit - string ident; + std::string 
ident; for (char c = 'a'; c <= 'z'; c++) { ident.append(1, c); } @@ -144,10 +144,10 @@ TEST_F(ParseCppClassTest, ParseFail) { } static void CompareWithGoldenFile( - const string& tensorflow_relative_golden_file_name, - const string& expected_contents, bool ignore_cr) { + const std::string& tensorflow_relative_golden_file_name, + const std::string& expected_contents, bool ignore_cr) { // Get rid of all CR characters, we may be running under windows. - string sanitized_expected_contents(expected_contents); + std::string sanitized_expected_contents(expected_contents); if (ignore_cr) { sanitized_expected_contents.erase( std::remove(sanitized_expected_contents.begin(), @@ -160,7 +160,7 @@ static void CompareWithGoldenFile( // blaz test --test_strategy=local \ // "third_party/tensorflow/compiler/aot:codegen_test" const bool update_golden = false; - string golden_file_name = + std::string golden_file_name = GetDataDependencyFilepath(tensorflow_relative_golden_file_name); if (update_golden) { @@ -168,7 +168,7 @@ static void CompareWithGoldenFile( WriteStringToFile(Env::Default(), golden_file_name, expected_contents)); } - string golden_file_contents; + std::string golden_file_contents; TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name, &golden_file_contents)); if (ignore_cr) { diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 7d0897829b98ca..48c92bf346926f 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -212,7 +212,7 @@ absl::Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config, return CompileXla(client, computation, aot_opts, compile_result); } -static absl::Status ReadProtoFile(const string& fname, +static absl::Status ReadProtoFile(const std::string& fname, protobuf::Message* proto) { if (absl::EndsWith(fname, ".pbtxt")) { return ReadTextProto(Env::Default(), fname, proto); @@ -297,7 +297,7 @@ absl::Status Main(const MainFlags& flags) { 
TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config)); TF_RETURN_IF_ERROR(ValidateConfig(config)); if (flags.dump_fetch_nodes) { - std::set nodes; + std::set nodes; for (const tf2xla::Fetch& fetch : config.fetch()) { nodes.insert(fetch.id().node_name()); } @@ -368,7 +368,7 @@ absl::Status Main(const MainFlags& flags) { GenerateMetadata(codegen_opts, compile_result, &metadata_result)); TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object, metadata_result.object_file_data)); - string header; + std::string header; TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result, metadata_result, embedded_constant_buffers, &header)); diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h index 303854f40ed88c..2a0418126b8aaf 100644 --- a/tensorflow/compiler/aot/compile.h +++ b/tensorflow/compiler/aot/compile.h @@ -38,7 +38,7 @@ struct CompileResult { // Contains object file and meta-info. std::unique_ptr aot; xla::ProgramShapeProto program_shape; // Static shape of args and results. - string entry_point; // Name of generated function. + std::string entry_point; // Name of generated function. int pointer_size = 0; // Size of a pointer in bytes. 
}; diff --git a/tensorflow/compiler/aot/embedded_constant_buffers.cc b/tensorflow/compiler/aot/embedded_constant_buffers.cc index 987dac62bca0fe..b56ca80e26e875 100644 --- a/tensorflow/compiler/aot/embedded_constant_buffers.cc +++ b/tensorflow/compiler/aot/embedded_constant_buffers.cc @@ -118,8 +118,8 @@ static absl::StatusOr CodegenModule( static absl::StatusOr> GetTargetMachineFromTriple(absl::string_view target_triple) { std::string error; - std::string normalized_triple = - llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple))); + llvm::Triple normalized_triple( + llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple)))); const llvm::Target* target = llvm::TargetRegistry::lookupTarget(normalized_triple, error); if (target == nullptr) { @@ -128,7 +128,7 @@ GetTargetMachineFromTriple(absl::string_view target_triple) { } return absl::WrapUnique(target->createTargetMachine( - llvm::Triple(normalized_triple), /*CPU=*/"", + normalized_triple, /*CPU=*/"", /*Features=*/"", llvm::TargetOptions(), std::nullopt)); } diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc index ae5b62ccec01a9..1626686ba465ad 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -41,9 +41,9 @@ using xla::llvm_ir::AsStringRef; static void AddEmbeddedProtocolBufferToLlvmModule( llvm::Module* module, const ::tensorflow::protobuf::MessageLite& proto, - absl::string_view unique_identifier, string* protobuf_array_symbol_name, - int64_t* protobuf_array_size) { - string protobuf_array_contents = proto.SerializeAsString(); + absl::string_view unique_identifier, + std::string* protobuf_array_symbol_name, int64_t* protobuf_array_size) { + std::string protobuf_array_contents = proto.SerializeAsString(); *protobuf_array_symbol_name = absl::StrCat(unique_identifier, "_protobuf_array_contents"); *protobuf_array_size = 
protobuf_array_contents.size(); @@ -58,10 +58,10 @@ static void AddEmbeddedProtocolBufferToLlvmModule( protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name)); } -static string CreateCPPShimExpression( +static std::string CreateCPPShimExpression( absl::string_view qualified_cpp_protobuf_name, absl::string_view protobuf_array_symbol_name, int64_t protobuf_array_size) { - string code = + std::string code = "[]() {\n" " {{PROTOBUF_NAME}}* proto = new {{PROTOBUF_NAME}};\n" " proto->ParseFromArray(&{{ARRAY_SYMBOL}}[0], {{ARRAY_SIZE}});\n" @@ -77,7 +77,7 @@ static string CreateCPPShimExpression( }); } -static absl::StatusOr CodegenModule( +static absl::StatusOr CodegenModule( llvm::TargetMachine* target_machine, std::unique_ptr module) { llvm::SmallVector stream_buffer; llvm::raw_svector_ostream ostream(stream_buffer); @@ -91,14 +91,14 @@ static absl::StatusOr CodegenModule( codegen_passes.run(*module); - return string(stream_buffer.begin(), stream_buffer.end()); + return std::string(stream_buffer.begin(), stream_buffer.end()); } static absl::StatusOr> GetTargetMachineFromTriple(absl::string_view target_triple) { std::string error; - std::string normalized_triple = - llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple))); + llvm::Triple normalized_triple( + llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple)))); const llvm::Target* target = llvm::TargetRegistry::lookupTarget(normalized_triple, error); if (target == nullptr) { @@ -107,7 +107,7 @@ GetTargetMachineFromTriple(absl::string_view target_triple) { } return absl::WrapUnique(target->createTargetMachine( - llvm::Triple(normalized_triple), /*CPU=*/"", + normalized_triple, /*CPU=*/"", /*Features=*/"", llvm::TargetOptions(), std::nullopt)); } @@ -124,9 +124,9 @@ absl::StatusOr CreateEmbeddedProtocolBuffers( EmbeddedProtocolBuffers result; for (const ProtobufToEmbed& protobuf_to_embed : protobufs_to_embed) { - string cpp_shim, cpp_variable_decl; + std::string cpp_shim, 
cpp_variable_decl; if (protobuf_to_embed.message) { - string protobuf_array_symbol_name; + std::string protobuf_array_symbol_name; int64_t protobuf_array_size; AddEmbeddedProtocolBufferToLlvmModule( diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h index 0af4d4a3362f8c..aa3553f3b6a85b 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.h +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -37,11 +37,11 @@ struct EmbeddedProtocolBuffers { struct CPPShim { // `expression` is a C++ expression that creates an instance of said // protocol buffer when executed. - string expression; + std::string expression; // `variable_decl` is an "extern C" array declaration that is used in // `expression`. It must be visible wherever `expression` is emitted. - string variable_decl; + std::string variable_decl; }; // Each cpp_shim corresponds to one embedded protocol buffer. @@ -50,20 +50,20 @@ struct EmbeddedProtocolBuffers { // The contents of the object (".o") file the protocol buffers are embbed in. // This needs to be linked in to any program that wants to execute any of the // expressions in `cpp_shims`. - string object_file_data; + std::string object_file_data; }; // Describes a protocol buffer to embed into an object file. struct ProtobufToEmbed { // `symbol_prefix` is prefix that is guaranteed to be unique across the binary // or DSO the generated object file will be linked into. - string symbol_prefix; + std::string symbol_prefix; // `qualified_cpp_protobuf_name` is a qualified ("qualified" as in C++ // namespace qualified) protocol buffer name. This is only used in // CPPShim::expression so relatively qualified names are fine as long as // they're valid wherever CPPShim::expression is emitted. - string qualified_cpp_protobuf_name; + std::string qualified_cpp_protobuf_name; // `message` is the protocol buffer to be embedded. 
It is allowed to be // nullptr, in which case the generated C++ shim expression is just `nullptr`, diff --git a/tensorflow/compiler/aot/flags.h b/tensorflow/compiler/aot/flags.h index 9a3f2900dbafe4..5d0f93f7d67b88 100644 --- a/tensorflow/compiler/aot/flags.h +++ b/tensorflow/compiler/aot/flags.h @@ -27,27 +27,27 @@ namespace tfcompile { // Flags for the tfcompile binary. See *.cc file for descriptions. struct MainFlags { - string graph; - string debug_info; - string debug_info_path_begin_marker; - string config; + std::string graph; + std::string debug_info; + std::string debug_info_path_begin_marker; + std::string config; bool dump_fetch_nodes = false; - string target_triple; - string target_cpu; - string target_features; - string entry_point; - string cpp_class; - string out_function_object; - string out_metadata_object; - string out_header; - string out_constant_buffers_object; - string out_session_module; - string mlir_components; + std::string target_triple; + std::string target_cpu; + std::string target_features; + std::string entry_point; + std::string cpp_class; + std::string out_function_object; + std::string out_metadata_object; + std::string out_header; + std::string out_constant_buffers_object; + std::string out_session_module; + std::string mlir_components; bool experimental_quantize = false; // Sanitizer pass options bool sanitize_dataflow = false; - string sanitize_abilists_dataflow; + std::string sanitize_abilists_dataflow; // C++ codegen options bool gen_name_to_index = false; diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 8caeec32b7bc5e..67fea2e6a022c1 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -63,6 +63,7 @@ def _tfcompile_model_library_rule_impl(ctx): "--xla_cpu_fast_math_honor_functions=false " + "--xla_cpu_fast_math_honor_division=false " + "--xla_cpu_enable_fast_min_max=true " + + "--xla_cpu_experimental_ynn_fusion_type= " + 
additional_xla_flags + " " + "$${XLA_FLAGS:-}' "), "CUDA_VISIBLE_DEVICES": "", @@ -321,10 +322,11 @@ def _tf_library( # include_standard_runtime_deps is False. Without them, the # generated code will fail to compile. "//third_party/absl/log:check", + "//third_party/absl/synchronization", + "//tensorflow/core:framework_lite", "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", "@local_xla//xla:types", "@local_xla//xla/backends/cpu/runtime:kernel_c_api", - "//tensorflow/core:framework_lite", "@local_xla//xla/backends/cpu/runtime:rng_state_lib", ] + (need_xla_data_proto and [ # If we're generating the program shape, we must depend on the @@ -335,12 +337,11 @@ def _tf_library( ] or []) + (include_standard_runtime_deps and [ # TODO(cwhipkey): only depend on kernel code that the model actually # needed. - "@local_xla//xla/service/cpu:runtime_conv2d", - "@local_xla//xla/service/cpu:runtime_custom_call_status", - "@local_xla//xla/service/cpu:runtime_key_value_sort", + "@local_xla//xla/backends/cpu/runtime:dot_lib", + "@local_xla//xla/backends/cpu/runtime:sort_lib", + "@local_xla//xla/backends/cpu/runtime:topk_lib", + "@local_xla//xla/backends/cpu/runtime:convolution_lib", "@local_xla//xla/service/cpu:runtime_matmul", - "@local_xla//xla/service/cpu:runtime_topk", - "@local_xla//xla/service/cpu:runtime_single_threaded_conv2d", "@local_xla//xla/service/cpu:runtime_single_threaded_matmul", "@eigen_archive//:eigen3", ] or []) + (use_xla_nanort_runtime and [ diff --git a/tensorflow/compiler/aot/thunk_proto_execution_deserializer.cc b/tensorflow/compiler/aot/thunk_proto_execution_deserializer.cc index d2ced20a8d5eec..485bfd36dfa0a5 100644 --- a/tensorflow/compiler/aot/thunk_proto_execution_deserializer.cc +++ b/tensorflow/compiler/aot/thunk_proto_execution_deserializer.cc @@ -28,8 +28,8 @@ limitations under the License. 
#include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" -#include "xla/backends/cpu/runtime/convolution_lib.h" -#include "xla/backends/cpu/runtime/dot_lib.h" +#include "xla/backends/cpu/runtime/convolution_dims.h" +#include "xla/backends/cpu/runtime/dot_dims.h" #include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/layout_util.h" #include "xla/service/cpu/executable.pb.h" @@ -127,32 +127,23 @@ ThunkProtoExecutionDeserializer::ThunkSpecificRunImplFromThunkSequence( } absl::StatusOr ThunkProtoExecutionDeserializer::GetMatmulFunction( - xla::PrimitiveType xla_type, bool is_single_threaded) { + xla::PrimitiveType xla_type) { switch (xla_type) { case xla::F16: - return is_single_threaded - ? "__xla_cpu_runtime_EigenSingleThreadedMatMulF16" - : "__xla_cpu_runtime_EigenMatMulF16"; + return "::xla::cpu::internal::TypedMatMul"; case xla::F32: - return is_single_threaded - ? "__xla_cpu_runtime_EigenSingleThreadedMatMulF32" - : "__xla_cpu_runtime_EigenMatMulF32"; + return "::xla::cpu::internal::TypedMatMul"; case xla::F64: - return is_single_threaded - ? "__xla_cpu_runtime_EigenSingleThreadedMatMulF64" - : "__xla_cpu_runtime_EigenMatMulF64"; + return "::xla::cpu::internal::TypedMatMul"; case xla::C64: - return is_single_threaded - ? "__xla_cpu_runtime_EigenSingleThreadedMatMulC64" - : "__xla_cpu_runtime_EigenMatMulC64"; + return "::xla::cpu::internal::TypedMatMul, " + "std::complex, std::complex"; case xla::C128: - return is_single_threaded - ? "__xla_cpu_runtime_EigenSingleThreadedMatMulC128" - : "__xla_cpu_runtime_EigenMatMulC128"; + return "::xla::cpu::internal::TypedMatMul, " + "std::complex, std::complex"; case xla::S32: - return is_single_threaded - ? 
"__xla_cpu_runtime_EigenSingleThreadedMatMulS32" - : "__xla_cpu_runtime_EigenMatMulS32"; + return "::xla::cpu::internal::TypedMatMul"; default: return xla::Internal("Unsupported xla type: %d", xla_type); } @@ -166,43 +157,23 @@ absl::StatusOr ThunkProtoExecutionDeserializer::GetDotThunkRunImpl( } const xla::cpu::DotThunkProto& dot_thunk = thunk.dot_thunk(); - absl::string_view dot_thunk_invocation_format = xla_cpu_multi_thread_eigen_ - ? R"( + absl::string_view dot_thunk_invocation_format = R"( // Dot Thunk { + absl::BlockingCounter done({{BATCH_SIZE}}); for (int64_t i = 0; i < {{BATCH_SIZE}}; ++i) { - if (run_options->intra_op_thread_pool() != nullptr) { - {{MATMUL_FUNCTION}}( - run_options, - {{OUTPUT_PTR}} + {{OUTPUT_STRIDE}} * i, - {{LHS_PTR}} + {{LHS_STRIDE}} * i, - {{RHS_PTR}} + {{RHS_STRIDE}} * i, - {{M}}, {{N}}, {{K}}, {{TRANSPOSE_LHS}}, {{TRANSPOSE_RHS}}); - } else { - {{SINGLE_THREADED_MATMUL_FUNCTION}}( - nullptr, - {{OUTPUT_PTR}} + {{OUTPUT_STRIDE}} * i, - {{LHS_PTR}} + {{LHS_STRIDE}} * i, - {{RHS_PTR}} + {{RHS_STRIDE}} * i, - {{M}}, {{N}}, {{K}}, {{TRANSPOSE_LHS}}, {{TRANSPOSE_RHS}}); - } + {{MATMUL_FUNCTION}}( + run_options->intra_op_thread_pool(), + {{OUTPUT_PTR}} + {{OUTPUT_STRIDE}} * i, + {{LHS_PTR}} + {{LHS_STRIDE}} * i, + {{RHS_PTR}} + {{RHS_STRIDE}} * i, + {{M}}, {{N}}, {{K}}, {{TRANSPOSE_LHS}}, {{TRANSPOSE_RHS}}, + [&done] { done.DecrementCount(); } + ); } + done.Wait(); } - )" - : - R"( - // Dot Thunk - { - for (int64_t i = 0; i < {{BATCH_SIZE}}; ++i) { - {{SINGLE_THREADED_MATMUL_FUNCTION}}( - nullptr, - {{OUTPUT_PTR}} + {{OUTPUT_STRIDE}} * i, - {{LHS_PTR}} + {{LHS_STRIDE}} * i, - {{RHS_PTR}} + {{RHS_STRIDE}} * i, - {{M}}, {{N}}, {{K}}, {{TRANSPOSE_LHS}}, {{TRANSPOSE_RHS}}); - } - } - )"; + )"; if (!(dot_thunk.lhs_buffer_shape().shape().element_type() == dot_thunk.rhs_buffer_shape().shape().element_type() && @@ -214,13 +185,7 @@ absl::StatusOr ThunkProtoExecutionDeserializer::GetDotThunkRunImpl( TF_ASSIGN_OR_RETURN( std::string 
matmul_function, - GetMatmulFunction(dot_thunk.lhs_buffer_shape().shape().element_type(), - /*is_single_threaded=*/false)); - - TF_ASSIGN_OR_RETURN( - std::string single_threaded_matmul_function, - GetMatmulFunction(dot_thunk.lhs_buffer_shape().shape().element_type(), - /*is_single_threaded=*/true)); + GetMatmulFunction(dot_thunk.lhs_buffer_shape().shape().element_type())); TF_ASSIGN_OR_RETURN(std::string data_type, CppDataTypeFromXlaType( @@ -280,7 +245,7 @@ absl::StatusOr ThunkProtoExecutionDeserializer::GetDotThunkRunImpl( int64_t out_stride = m * n; std::vector> rewrites = { - {"{{SINGLE_THREADED_MATMUL_FUNCTION}}", single_threaded_matmul_function}, + {"{{MATMUL_FUNCTION}}", matmul_function}, {"{{OUTPUT_PTR}}", output_ptr}, {"{{OUTPUT_STRIDE}}", absl::StrCat(out_stride)}, {"{{LHS_PTR}}", lhs_ptr}, @@ -294,25 +259,17 @@ absl::StatusOr ThunkProtoExecutionDeserializer::GetDotThunkRunImpl( {"{{TRANSPOSE_RHS}}", transpose_rhs ? "true" : "false"}, {"{{BATCH_SIZE}}", absl::StrCat(dot_shape.batch_size)}}; - if (xla_cpu_multi_thread_eigen_) { - rewrites.push_back({"{{MATMUL_FUNCTION}}", matmul_function}); - } - return absl::StrReplaceAll(dot_thunk_invocation_format, rewrites); }; absl::StatusOr ThunkProtoExecutionDeserializer::GetConvolutionFunction( - xla::PrimitiveType xla_type, bool is_single_threaded) { + xla::PrimitiveType xla_type) { switch (xla_type) { case xla::F16: - return is_single_threaded - ? "__xla_cpu_runtime_EigenSingleThreadedConv2DF16" - : "__xla_cpu_runtime_EigenConv2DF16"; + return "xla::cpu::internal::EigenConv2D"; case xla::F32: - return is_single_threaded - ? 
"__xla_cpu_runtime_EigenSingleThreadedConv2DF32" - : "__xla_cpu_runtime_EigenConv2DF32"; + return "xla::cpu::internal::EigenConv2D"; default: return xla::Internal("Unsupported xla type: %d", xla_type); } @@ -345,63 +302,28 @@ ThunkProtoExecutionDeserializer::GetConvolution2DRunImpl( TF_ASSIGN_OR_RETURN( std::string convolution_function, GetConvolutionFunction( - convolution_thunk.input_buffer_shape().shape().element_type(), - /*is_single_threaded=*/false)); - - TF_ASSIGN_OR_RETURN( - std::string single_threaded_convolution_function, - GetConvolutionFunction( - convolution_thunk.input_buffer_shape().shape().element_type(), - /*is_single_threaded=*/true)); + convolution_thunk.input_buffer_shape().shape().element_type())); - absl::string_view convolution_thunk_invocation_format = - xla_cpu_multi_thread_eigen_ ? R"( + absl::string_view convolution_thunk_invocation_format = R"( // Convolution Thunk { - if (run_options->intra_op_thread_pool() != nullptr) { - {{CONVOLUTION_FUNCTION}}( - run_options, - {{OUTPUT_PTR}}, {{LHS_PTR}}, {{RHS_PTR}}, {{INPUT_BATCH}}, - {{INPUT_ROWS}}, {{INPUT_COLS}}, {{INPUT_CHANNELS}}, {{KERNEL_ROWS}}, - {{KERNEL_COLS}}, {{KERNEL_CHANNELS}}, {{KERNEL_FILTERS}}, - {{OUTPUT_ROWS}}, {{OUTPUT_COLS}}, {{ROW_STRIDE}}, {{COL_STRIDE}}, - {{PADDING_TOP}}, {{PADDING_BOTTOM}}, {{PADDING_LEFT}}, - {{PADDING_RIGHT}}, {{LHS_ROW_DILATION}}, {{LHS_COL_DILATION}}, - {{RHS_ROW_DILATION}}, {{RHS_COL_DILATION}}, {{FEATURE_GROUP_COUNT}} - ); - } else { - {{SINGLE_THREADED_CONVOLUTION_FUNCTION}}( - nullptr, - {{OUTPUT_PTR}}, {{LHS_PTR}}, {{RHS_PTR}}, {{INPUT_BATCH}}, - {{INPUT_ROWS}}, {{INPUT_COLS}}, {{INPUT_CHANNELS}}, {{KERNEL_ROWS}}, - {{KERNEL_COLS}}, {{KERNEL_CHANNELS}}, {{KERNEL_FILTERS}}, - {{OUTPUT_ROWS}}, {{OUTPUT_COLS}}, {{ROW_STRIDE}}, {{COL_STRIDE}}, - {{PADDING_TOP}}, {{PADDING_BOTTOM}}, {{PADDING_LEFT}}, - {{PADDING_RIGHT}}, {{LHS_ROW_DILATION}}, {{LHS_COL_DILATION}}, - {{RHS_ROW_DILATION}}, {{RHS_COL_DILATION}}, {{FEATURE_GROUP_COUNT}} - ); - } - })" - 
: - R"( - // Convolution Thunk - { - {{SINGLE_THREADED_CONVOLUTION_FUNCTION}}( - nullptr, - {{OUTPUT_PTR}}, {{LHS_PTR}}, {{RHS_PTR}}, {{INPUT_BATCH}}, - {{INPUT_ROWS}}, {{INPUT_COLS}}, {{INPUT_CHANNELS}}, {{KERNEL_ROWS}}, - {{KERNEL_COLS}}, {{KERNEL_CHANNELS}}, {{KERNEL_FILTERS}}, - {{OUTPUT_ROWS}}, {{OUTPUT_COLS}}, {{ROW_STRIDE}}, {{COL_STRIDE}}, - {{PADDING_TOP}}, {{PADDING_BOTTOM}}, {{PADDING_LEFT}}, - {{PADDING_RIGHT}}, {{LHS_ROW_DILATION}}, {{LHS_COL_DILATION}}, - {{RHS_ROW_DILATION}}, {{RHS_COL_DILATION}}, {{FEATURE_GROUP_COUNT}} - ); - } - )"; + absl::Notification done; + {{CONVOLUTION_FUNCTION}}( + run_options->intra_op_thread_pool(), + {{OUTPUT_PTR}}, {{LHS_PTR}}, {{RHS_PTR}}, {{INPUT_BATCH}}, + {{INPUT_ROWS}}, {{INPUT_COLS}}, {{INPUT_CHANNELS}}, {{KERNEL_ROWS}}, + {{KERNEL_COLS}}, {{KERNEL_CHANNELS}}, {{KERNEL_FILTERS}}, + {{OUTPUT_ROWS}}, {{OUTPUT_COLS}}, {{ROW_STRIDE}}, {{COL_STRIDE}}, + {{PADDING_TOP}}, {{PADDING_BOTTOM}}, {{PADDING_LEFT}}, + {{PADDING_RIGHT}}, {{LHS_ROW_DILATION}}, {{LHS_COL_DILATION}}, + {{RHS_ROW_DILATION}}, {{RHS_COL_DILATION}}, {{FEATURE_GROUP_COUNT}}, + [&done] { done.Notify(); } + ); + done.WaitForNotification(); + })"; std::vector> rewrites = { - {"{{SINGLE_THREADED_CONVOLUTION_FUNCTION}}", - single_threaded_convolution_function}, + {"{{CONVOLUTION_FUNCTION}}", convolution_function}, {"{{OUTPUT_PTR}}", output_ptr}, {"{{LHS_PTR}}", lhs_ptr}, {"{{RHS_PTR}}", rhs_ptr}, @@ -428,10 +350,6 @@ ThunkProtoExecutionDeserializer::GetConvolution2DRunImpl( {"{{FEATURE_GROUP_COUNT}}", absl::StrCat(canonical_dims.feature_group_count)}}; - if (xla_cpu_multi_thread_eigen_) { - rewrites.push_back({"{{CONVOLUTION_FUNCTION}}", convolution_function}); - } - return absl::StrReplaceAll(convolution_thunk_invocation_format, rewrites); } @@ -594,35 +512,47 @@ ThunkProtoExecutionDeserializer::GetSortThunkRunImpl( std::vector buffers_to_sort; buffers_to_sort.reserve(sort_thunk.inputs_shapes_size()); - std::vector values_primitive_type_size_in_bytes; - 
values_primitive_type_size_in_bytes.reserve(sort_thunk.inputs_shapes_size()); + std::vector primitive_sizes; + primitive_sizes.reserve(sort_thunk.inputs_shapes_size()); for (const auto& buffer_proto : sort_thunk.inputs_shapes()) { buffers_to_sort.push_back( - absl::StrCat("reinterpret_cast(", + absl::StrCat("reinterpret_cast(", GetBufferAllocationString(buffer_proto.slice()), ")")); - values_primitive_type_size_in_bytes.push_back( - xla::ShapeUtil::ByteSizeOfPrimitiveType( - buffer_proto.shape().element_type())); + primitive_sizes.push_back(xla::ShapeUtil::ByteSizeOfPrimitiveType( + buffer_proto.shape().element_type())); } absl::string_view sort_thunk_invocation_format = R"( // Sort Thunk { - std::vector values = { + std::vector values = { {{BUFFERS_TO_SORT}} }; - std::vector values_primitive_type_size_in_bytes = { + std::vector primitive_sizes = { {{VALUES_PRIMITIVE_TYPE_SIZE_IN_BYTES}} }; - __xla_cpu_runtime_KeyValueSort( - {{HIGHER_DIMENSIONS}}, {{SORT_DIMENSION_ELEMENTS}}, {{LOWER_DIMENSIONS}}, - values.data(), - int32_t(values.size()), - values_primitive_type_size_in_bytes.data(), - /*is_stable=*/{{IS_STABLE}}, - reinterpret_cast(run_options), - /*prof_counters=*/nullptr, - reinterpret_cast({{SORT_FUNCTION_NAME}})); + // Type alias compatible with `FunctionLibrary::Comparator`. 
+ using Comparator = void(bool* result, const void* run_options, + const void** params, const void* buffer_table, + const void* status, const void* prof_counters); + Comparator* comparator = reinterpret_cast( + {{SORT_FUNCTION_NAME}}); + + absl::AnyInvocable less_than = + [comparator](const void** data) { + bool result; + (*comparator)(&result, nullptr, data, nullptr, nullptr, nullptr); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&result, sizeof(result)); + return result; + }; + + xla::cpu::internal::SortInplace( + { + {{HIGHER_DIMENSIONS}}, + {{SORT_DIMENSION_ELEMENTS}}, + {{LOWER_DIMENSIONS}} + }, + values, primitive_sizes, {{IS_STABLE}}, &less_than); })"; TF_ASSIGN_OR_RETURN( @@ -660,7 +590,7 @@ ThunkProtoExecutionDeserializer::GetSortThunkRunImpl( {"{{SORT_FUNCTION_NAME}}", sort_thunk.comparator_name()}, {"{{BUFFERS_TO_SORT}}", absl::StrJoin(buffers_to_sort, ", ")}, {"{{VALUES_PRIMITIVE_TYPE_SIZE_IN_BYTES}}", - absl::StrJoin(values_primitive_type_size_in_bytes, ", ")}, + absl::StrJoin(primitive_sizes, ", ")}, {"{{IS_STABLE}}", sort_thunk.is_stable() ? "true" : "false"}, }); } @@ -677,7 +607,7 @@ ThunkProtoExecutionDeserializer::GetTopKThunkRunImpl( absl::string_view topk_thunk_invocation_format = R"( // TopK Thunk { - __xla_cpu_runtime_TopKF32({{BATCH_SIZE}}, {{INPUT_SIZE}}, {{K}}, + ::xla::cpu::internal::TopK({{BATCH_SIZE}}, {{INPUT_SIZE}}, {{K}}, reinterpret_cast({{VALUES_PTR}}), reinterpret_cast({{OUTPUT_PTR}}), reinterpret_cast({{INDICES_PTR}})); diff --git a/tensorflow/compiler/aot/thunk_proto_execution_deserializer.h b/tensorflow/compiler/aot/thunk_proto_execution_deserializer.h index 1e5e47f140020e..a5f7ddcd5fa13b 100644 --- a/tensorflow/compiler/aot/thunk_proto_execution_deserializer.h +++ b/tensorflow/compiler/aot/thunk_proto_execution_deserializer.h @@ -20,7 +20,7 @@ limitations under the License. 
#include #include "absl/status/statusor.h" -#include "xla/backends/cpu/runtime/convolution_lib.h" +#include "xla/backends/cpu/runtime/convolution_dims.h" #include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/debug_options_flags.h" #include "xla/service/cpu/executable.pb.h" @@ -44,14 +44,13 @@ class ThunkProtoExecutionDeserializer { const xla::cpu::ThunkSequenceProto& thunk_sequence_proto); protected: - absl::StatusOr GetMatmulFunction(xla::PrimitiveType xla_type, - bool is_single_threaded); + absl::StatusOr GetMatmulFunction(xla::PrimitiveType xla_type); absl::StatusOr GetDotThunkRunImpl( const xla::cpu::ThunkProto& thunk); absl::StatusOr GetConvolutionFunction( - xla::PrimitiveType xla_type, bool is_single_threaded); + xla::PrimitiveType xla_type); absl::StatusOr GetConvolution2DRunImpl( const xla::cpu::ConvolutionThunkProto& convolution_thunk, diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index c65bb6c44b1079..7c1772c084750c 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -654,6 +654,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@local_xla//xla:future", "@local_xla//xla:shape_util", "@local_xla//xla:status_macros", "@local_xla//xla:util", @@ -662,7 +663,6 @@ cc_library( "@local_xla//xla/pjrt:pjrt_client", "@local_xla//xla/pjrt:pjrt_common", "@local_xla//xla/pjrt:pjrt_executable", - "@local_xla//xla/pjrt:pjrt_future", "@local_xla//xla/service:executable", "@local_xla//xla/service:maybe_owning_device_memory", "@local_xla//xla/service:shaped_buffer", diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index bed899bfed2f3e..31f1aeedd9850e 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -132,7 +132,7 @@ void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node, if 
(merged_output.node() == nullptr) { Output new_output(new_node, oidx); if (debugging_opts.print_outputs) { - string cpu_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + std::string cpu_device = "/job:localhost/replica:0/task:0/device:CPU:0"; ops::Print print_op(s.WithOpName("print_", oidx) .WithDevice(cpu_device) .WithAssignedDevice(cpu_device), @@ -298,7 +298,8 @@ absl::StatusOr ReplaceFunctionCallWithPartitionedCall( const GraphOptimizationPassOptions& options, const FunctionLibraryDefinition& flib_def, Node* n, Graph* g, const NameAttrList& func, const Scope& root) { - string config_string = options.session_options->config.SerializeAsString(); + std::string config_string = + options.session_options->config.SerializeAsString(); int input_count = absl::c_count_if( n->in_edges(), [](const Edge* e) { return !e->IsControlEdge(); }); @@ -346,7 +347,8 @@ absl::StatusOr ReplaceFunctionCallWithPartitionedCall( absl::StatusOr InferDeviceForCluster( jit::DeviceInfoCache* device_info_cache, Node* n, - const string& function_name, const FunctionLibraryDefinition& flib_def) { + const std::string& function_name, + const FunctionLibraryDefinition& flib_def) { const FunctionDef* func_def = flib_def.Find(function_name); TF_RET_CHECK(func_def) << "Could not find " << function_name; @@ -485,7 +487,8 @@ absl::Status ReplaceNodeWithXlaCompileAndXlaRun( requires_compilation = true; } - string device_name_str = string(device_info_cache->GetNameFor(device)); + std::string device_name_str = + std::string(device_info_cache->GetNameFor(device)); absl::Status status; Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr) diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc index c3b5ba5521ee65..6b90557df4b86f 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -85,8 +85,8 @@ absl::Status BuildXlaOps(const Scope& s, const 
FunctionDefLibrary& fdef_lib, return absl::OkStatus(); } -absl::Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, - const string& node_name, +absl::Status MakeXlaCompiledKernel(Graph* graph, const std::string& callee_name, + const std::string& node_name, int num_constant_args, int num_resource_args, Node** result) { NodeDef call_node; @@ -99,14 +99,16 @@ absl::Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, return absl::OkStatus(); } -absl::Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, - const string& node_name, Node** result) { +absl::Status MakeXlaCompiledKernel(Graph* graph, const std::string& callee_name, + const std::string& node_name, + Node** result) { return MakeXlaCompiledKernel(graph, callee_name, node_name, /*num_constant_args=*/0, /*num_resource_args=*/0, result); } -Node* MakeWrite(const Scope& scope, Output value_to_write, const string& id) { +Node* MakeWrite(const Scope& scope, Output value_to_write, + const std::string& id) { Output var_handle = ops::VarHandleOp(scope.WithOpName("Var_" + id), DT_FLOAT, TensorShape({})); ops::AssignVariableOp assign_op(scope.WithOpName("Assignee_" + id), @@ -114,12 +116,13 @@ Node* MakeWrite(const Scope& scope, Output value_to_write, const string& id) { return assign_op.operation.node(); } -Node* MakeWrite(const Scope& scope, const string& id) { +Node* MakeWrite(const Scope& scope, const std::string& id) { return MakeWrite( scope, ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f), id); } -FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { +FunctionDefLibrary CreateFunctionDefLibWithConstFunction( + const std::string& name) { FunctionDefLibrary fdef_lib; FunctionDef func = FunctionDefHelper::Create( /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"}, diff --git a/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc b/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc index 
bb8dce848cfbc9..4164efc65a8f4c 100644 --- a/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc +++ b/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc @@ -36,19 +36,21 @@ class CloneConstantsForBetterClusteringPassImpl { private: absl::Status CloneSmallConstantInputs( - const absl::flat_hash_set& name_set, Node* n); - string GenerateUniqueName(const absl::flat_hash_set& name_set, - absl::string_view prefix); - absl::StatusOr CloneNode(const absl::flat_hash_set& name_set, - Node* n); + const absl::flat_hash_set& name_set, Node* n); + std::string GenerateUniqueName( + const absl::flat_hash_set& name_set, + absl::string_view prefix); + absl::StatusOr CloneNode( + const absl::flat_hash_set& name_set, Node* n); Graph* graph_; int unique_name_counter_; }; -string CloneConstantsForBetterClusteringPassImpl::GenerateUniqueName( - const absl::flat_hash_set& name_set, absl::string_view prefix) { - string candidate; +std::string CloneConstantsForBetterClusteringPassImpl::GenerateUniqueName( + const absl::flat_hash_set& name_set, + absl::string_view prefix) { + std::string candidate; do { candidate = absl::StrCat(prefix, "/clone_", unique_name_counter_++); } while (name_set.contains(candidate)); @@ -56,7 +58,7 @@ string CloneConstantsForBetterClusteringPassImpl::GenerateUniqueName( } absl::StatusOr CloneConstantsForBetterClusteringPassImpl::CloneNode( - const absl::flat_hash_set& name_set, Node* n) { + const absl::flat_hash_set& name_set, Node* n) { NodeDef new_in_def = n->def(); new_in_def.clear_input(); new_in_def.set_name(GenerateUniqueName(name_set, new_in_def.name())); @@ -112,7 +114,7 @@ bool IsInPlaceOp(absl::string_view op_name) { absl::Status CloneConstantsForBetterClusteringPassImpl::CloneSmallConstantInputs( - const absl::flat_hash_set& name_set, Node* n) { + const absl::flat_hash_set& name_set, Node* n) { std::vector in_edges; // Get the edges and sort them so we clone in a deterministic order. 
absl::c_copy(n->in_edges(), std::back_inserter(in_edges)); @@ -142,7 +144,7 @@ CloneConstantsForBetterClusteringPassImpl::CloneSmallConstantInputs( } absl::Status CloneConstantsForBetterClusteringPassImpl::Run() { - absl::flat_hash_set name_set; + absl::flat_hash_set name_set; absl::c_transform(graph_->nodes(), std::inserter(name_set, name_set.begin()), [](Node* n) { return n->name(); }); std::vector nodes; diff --git a/tensorflow/compiler/jit/cluster_scoping_pass.cc b/tensorflow/compiler/jit/cluster_scoping_pass.cc index e70be48f0b7341..20a3d98be1d0f2 100644 --- a/tensorflow/compiler/jit/cluster_scoping_pass.cc +++ b/tensorflow/compiler/jit/cluster_scoping_pass.cc @@ -51,8 +51,8 @@ class ClusterScopingPassImpl { size_t unique_scope_id_; }; -std::optional GetXlaInternalScope(Node* node) { - string scope; +std::optional GetXlaInternalScope(Node* node) { + std::string scope; if (GetNodeAttr(node->attrs(), kXlaInternalScopeAttr, &scope).ok()) { return scope; } @@ -85,8 +85,8 @@ void SetXlaInternalScope(Node* node, absl::string_view scope) { // Node_X (scope "stage") -> Stage // void AddOrAppendXlaInternalScope(Node* node, absl::string_view suffix) { - string updated_scope; - std::optional cur_scope = GetXlaInternalScope(node); + std::string updated_scope; + std::optional cur_scope = GetXlaInternalScope(node); if (cur_scope == std::nullopt) { updated_scope = std::string(suffix); } else { @@ -96,7 +96,7 @@ void AddOrAppendXlaInternalScope(Node* node, absl::string_view suffix) { } void ClusterScopingPassImpl::AddScopeToAllTransitivePredecessors(Node* start) { - const string unique_suffix = absl::StrCat("_", GetUniqueScopeId()); + const std::string unique_suffix = absl::StrCat("_", GetUniqueScopeId()); std::vector starts; starts.push_back(start); @@ -106,7 +106,7 @@ void ClusterScopingPassImpl::AddScopeToAllTransitivePredecessors(Node* start) { } void ClusterScopingPassImpl::AddScopeToAllTransitiveSuccessors(Node* start) { - const string unique_suffix = absl::StrCat("_", 
GetUniqueScopeId()); + const std::string unique_suffix = absl::StrCat("_", GetUniqueScopeId()); std::vector starts; starts.push_back(start); diff --git a/tensorflow/compiler/jit/cluster_scoping_pass_test.cc b/tensorflow/compiler/jit/cluster_scoping_pass_test.cc index b09cb2c12fa297..66cc10775992a3 100644 --- a/tensorflow/compiler/jit/cluster_scoping_pass_test.cc +++ b/tensorflow/compiler/jit/cluster_scoping_pass_test.cc @@ -45,10 +45,11 @@ absl::Status ClusterScoping(std::unique_ptr* graph) { return pass.Run(opt_options); } -absl::flat_hash_map GetXlaInternalScopes(const Graph& graph) { - absl::flat_hash_map scopes; +absl::flat_hash_map GetXlaInternalScopes( + const Graph& graph) { + absl::flat_hash_map scopes; for (Node* node : graph.nodes()) { - string scope; + std::string scope; if (GetNodeAttr(node->attrs(), kXlaInternalScopeAttr, &scope).ok()) { scopes[node->name()] = scope; } @@ -63,7 +64,7 @@ absl::flat_hash_map GetXlaInternalScopes(const Graph& graph) { return scopes; } -Node* BuildStageNode(GraphDefBuilder& builder, string name, +Node* BuildStageNode(GraphDefBuilder& builder, std::string name, std::initializer_list dtypes, absl::Span values) { auto opts = builder.opts() diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 50b26371698877..6c77648817f808 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -172,7 +172,7 @@ RecursiveCompilabilityChecker::FindUncompilableNodes( } bool RecursiveCompilabilityChecker::HasXLAKernel( - const Node& node, string* uncompilable_reason) const { + const Node& node, std::string* uncompilable_reason) const { // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient // is really a kind of function call and will be handled by // IsCompilableCall(). 
@@ -424,7 +424,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( return false; } - string uncompilable_reason; + std::string uncompilable_reason; if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), node)) { if (!IsCompilableCall(node.def(), lib_runtime, stack_trace, encapsulating_function, uncompilable_nodes)) { diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 0d86c22de11a22..7d6741529ebd08 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -262,7 +262,7 @@ class RecursiveCompilabilityChecker { } bool HasXLAKernel(const Node& node, - string* uncompilable_reason = nullptr) const; + std::string* uncompilable_reason = nullptr) const; static void MaybeMarkUncompilableNode( const absl::string_view reason, diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 2b2db07642d1ab..fa546e3543e358 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -123,7 +123,7 @@ class Predicate { public: enum class Kind { kAnd, kOr, kNot, kAndRecurrence, kSymbol, kIntSymbol }; - virtual string ToString() const = 0; + virtual std::string ToString() const = 0; // An ID assigned to the Predicate at construction time. Conceptually like a // pointer, except that it is stable across runs. 
@@ -156,12 +156,12 @@ class AndPredicate : public Predicate { explicit AndPredicate(int64_t id, std::vector operands) : Predicate(id), operands_(std::move(operands)) {} - string ToString() const override { + std::string ToString() const override { if (operands().empty()) { return "#true"; } - std::vector operands_str; + std::vector operands_str; std::transform(operands().begin(), operands().end(), std::back_inserter(operands_str), [](Predicate* pred) { return pred->ToString(); }); @@ -186,12 +186,12 @@ class OrPredicate : public Predicate { explicit OrPredicate(int64_t id, std::vector operands) : Predicate(id), operands_(std::move(operands)) {} - string ToString() const override { + std::string ToString() const override { if (operands().empty()) { return "#false"; } - std::vector operands_str; + std::vector operands_str; std::transform(operands().begin(), operands().end(), std::back_inserter(operands_str), [](Predicate* pred) { return pred->ToString(); }); @@ -215,7 +215,7 @@ class NotPredicate : public Predicate { explicit NotPredicate(int64_t id, Predicate* operand) : Predicate(id), operands_({operand}) {} - string ToString() const override { + std::string ToString() const override { return absl::StrCat("~", operand()->ToString()); } @@ -251,14 +251,14 @@ class NotPredicate : public Predicate { class AndRecurrencePredicate : public Predicate { public: explicit AndRecurrencePredicate(int64_t id, Predicate* start, Predicate* step, - std::vector frame) + std::vector frame) : Predicate(id), operands_({start, step}), frame_(std::move(frame)) {} Predicate* start() const { return operands_[0]; } Predicate* step() const { return operands_[1]; } - absl::Span frame() const { return frame_; } + absl::Span frame() const { return frame_; } - string ToString() const override { + std::string ToString() const override { return absl::StrCat("{", start()->ToString(), ",&,", step()->ToString(), "}<", absl::StrJoin(frame(), ";"), ">"); } @@ -271,7 +271,7 @@ class 
AndRecurrencePredicate : public Predicate { private: std::array operands_; - std::vector frame_; + std::vector frame_; }; // Represents an uninterpreted symbol in a logical predicate. @@ -286,7 +286,7 @@ class SymbolPredicate : public Predicate { tensor_id_(std::move(tensor_id)), must_be_true_(must_be_true) {} - string ToString() const override { + std::string ToString() const override { return must_be_true() ? absl::StrCat("*", tensor_id_.ToString()) : tensor_id_.ToString(); } @@ -320,7 +320,7 @@ class IntSymbolPredicate : public Predicate { tensor_id_(std::move(tensor_id)), must_have_value_(must_have_value) {} - string ToString() const override { + std::string ToString() const override { return must_have_value().has_value() ? absl::StrCat(tensor_id_.ToString(), "=", *must_have_value_) : tensor_id_.ToString(); @@ -396,7 +396,7 @@ class PredicateFactory { } Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step, - std::vector frame) { + std::vector frame) { SignatureForAndRec signature(start, step, std::move(frame)); auto it = interned_and_rec_instances_.find(signature); if (it != interned_and_rec_instances_.end()) { @@ -463,8 +463,8 @@ class PredicateFactory { Tensor tensor(proto->dtype()); TF_RET_CHECK(tensor.FromProto(*proto)); - *predicate = tensor.scalar()() == *must_have_value ? MakeTrue() - : MakeFalse(); + *predicate = tensor.scalar()() == *must_have_value ? 
MakeTrue() + : MakeFalse(); return absl::OkStatus(); } SignatureForIntSymbol signature = {tensor_id, must_have_value}; @@ -559,9 +559,9 @@ class PredicateFactory { std::pair>; using SignatureForNot = Predicate*; using SignatureForAndRec = - std::tuple>; + std::tuple>; using SignatureForSymbol = std::pair; - using SignatureForIntSymbol = std::pair>; + using SignatureForIntSymbol = std::pair>; struct HashSignatureForAndOr { size_t operator()(const SignatureForAndOr& signature) const { @@ -586,7 +586,7 @@ class PredicateFactory { SafeTensorId::Hasher()(signature.first), Hash64Combine( ::tensorflow::hash()(signature.second.has_value()), - ::tensorflow::hash()( + ::tensorflow::hash()( signature.second.has_value() ? *signature.second : 0))); } }; @@ -830,8 +830,8 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { absl::StatusOr GetPredicateFor( Node* n, int oidx) const override; void Print() const override; - absl::flat_hash_map PredicateMapAsString() - const; + absl::flat_hash_map + PredicateMapAsString() const; private: enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly }; @@ -958,7 +958,7 @@ absl::Status DeadnessAnalysisImpl::HandleSwitch( for (int i = 0; i < n->num_outputs() - 1; i++) { TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( pred_edge->src(), pred_edge->src_output(), - /*must_have_value=*/std::optional(i), &branch_pred)); + /*must_have_value=*/std::optional(i), &branch_pred)); input_preds.push_back(branch_pred); SetPredicate(n, i, predicate_factory_.MakeAndPredicate(input_preds), should_revisit); @@ -982,7 +982,7 @@ absl::Status DeadnessAnalysisImpl::HandleSwitch( namespace { absl::Status CreateMultipleNextIterationInputsError(Node* merge) { - std::vector backedges; + std::vector backedges; for (const Edge* backedge : merge->in_edges()) { if (backedge->src()->IsNextIteration()) { backedges.push_back(absl::StrCat(" ", SummarizeNode(*backedge->src()))); @@ -1058,7 +1058,7 @@ Predicate* DeduceStepPredicate(PredicateFactory* 
predicate_factory, absl::Status GetFullFrame(const Node* n, absl::Span cfi_infos, - std::vector* frame) { + std::vector* frame) { int depth = 0; for (const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()]; !n->IsSource(); n = cfi_iter->parent_frame, cfi_iter = &cfi_infos[n->id()]) { @@ -1174,7 +1174,7 @@ absl::Status DeadnessAnalysisImpl::HandleMerge( Predicate* start = predicate_factory_.MakeOrPredicate(non_recurrent_inputs); - std::vector frame; + std::vector frame; TF_RETURN_IF_ERROR(GetFullFrame(n, control_flow_info_, &frame)); Predicate* and_rec = predicate_factory_.MakeAndRecurrencePredicate( start, step, std::move(frame)); @@ -1358,7 +1358,7 @@ absl::Status DeadnessAnalysisImpl::GetFrameBasedTopologicalOrder( // nested while, as there is no clean cut for separating them in the topological // order. absl::Status DeadnessAnalysisImpl::Populate(bool enable_optimistic) { - std::vector unreachable_nodes; + std::vector unreachable_nodes; // Compute the loop structure of the graph. TF_RETURN_IF_ERROR( BuildControlFlowInfo(&graph_, &control_flow_info_, &unreachable_nodes)); @@ -1582,9 +1582,9 @@ DeadnessAnalysis::~DeadnessAnalysis() {} return absl::OkStatus(); } -absl::flat_hash_map +absl::flat_hash_map DeadnessAnalysisImpl::PredicateMapAsString() const { - absl::flat_hash_map result; + absl::flat_hash_map result; for (const auto& kv_pair : predicate_map_) { CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second); } @@ -1603,7 +1603,7 @@ absl::Status ComputePredicates(const Graph& graph, } // namespace deadness_analysis_internal -string DeadnessAnalysis::DebugString(DeadnessPredicate predicate) const { +std::string DeadnessAnalysis::DebugString(DeadnessPredicate predicate) const { return static_cast(predicate.pred_)->ToString(); } diff --git a/tensorflow/compiler/jit/deadness_analysis.h b/tensorflow/compiler/jit/deadness_analysis.h index 80fa9a20faef41..1cd394154faf36 100644 --- a/tensorflow/compiler/jit/deadness_analysis.h +++ 
b/tensorflow/compiler/jit/deadness_analysis.h @@ -81,7 +81,7 @@ class DeadnessAnalysis { virtual void Print() const = 0; virtual ~DeadnessAnalysis(); - string DebugString(DeadnessPredicate predicate) const; + std::string DebugString(DeadnessPredicate predicate) const; // Run the deadness analysis over `graph` and returns an error or a populated // instance of DeadnessAnalysis in `result`. diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h index 0dc18d3e129d79..569cdeadae735e 100644 --- a/tensorflow/compiler/jit/deadness_analysis_internal.h +++ b/tensorflow/compiler/jit/deadness_analysis_internal.h @@ -24,7 +24,8 @@ namespace deadness_analysis_internal { // Returns a map describing the predicate each Tensor was mapped to. For // testing purposes only. -using PredicateMapTy = absl::flat_hash_map; +using PredicateMapTy = + absl::flat_hash_map; absl::Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map, bool enable_optimistic = true); diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 894ee659121e25..fd7d93b3772f5f 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -61,7 +61,7 @@ absl::Status AnalyzeDeadness(Graph* graph, return DeadnessAnalysis::Run(*graph, result); } -ops::Switch CreateSwitch(const Scope& root, const string& prefix) { +ops::Switch CreateSwitch(const Scope& root, const std::string& prefix) { Output value = ops::Placeholder(root.WithOpName(prefix + "/value"), DT_FLOAT); Output predicate = ops::Placeholder(root.WithOpName(prefix + "/pred"), DT_BOOL); @@ -76,7 +76,7 @@ void VLogGraphIfAsked(const Graph& graph) { if (VLOG_IS_ON(3)) { GraphDef graph_def; graph.ToGraphDef(&graph_def); - string serialized; + std::string serialized; ::tensorflow::protobuf::TextFormat::PrintToString(graph_def, &serialized); LOG(INFO) << 
serialized; } @@ -127,8 +127,8 @@ struct InductionVarInfo { // +-----> | Exit | // +---------------+ InductionVarInfo CreateInductionVariable(const Scope& root, - const string& prefix, - const string& frame_name, + const std::string& prefix, + const std::string& frame_name, const Output& initial_value) { Output enter_initial_value = ops::internal::Enter( root.WithOpName(prefix + "/enter"), initial_value, frame_name); @@ -158,8 +158,8 @@ InductionVarInfo CreateInductionVariable(const Scope& root, } InductionVarInfo CreateInductionVariable(const Scope& root, - const string& prefix, - const string& frame_name, + const std::string& prefix, + const std::string& frame_name, int32_t init) { return CreateInductionVariable( root, prefix, frame_name, @@ -201,7 +201,7 @@ struct DependentInductionVar { }; DependentInductionVar CreateDependentLoopInvariantValue( - const Scope& root, const string& prefix, const string& frame_name, + const Scope& root, const std::string& prefix, const std::string& frame_name, const Output& loop_cond, const Output& value) { Output enter_value = ops::internal::Enter(root.WithOpName(prefix + "/enter"), value, frame_name); @@ -218,7 +218,7 @@ DependentInductionVar CreateDependentLoopInvariantValue( } DependentInductionVar CreateDependentLoopInvariantValue( - const Scope& root, const string& prefix, const string& frame_name, + const Scope& root, const std::string& prefix, const std::string& frame_name, const Output& loop_cond, int32_t value) { return CreateDependentLoopInvariantValue( root, prefix, frame_name, loop_cond, diff --git a/tensorflow/compiler/jit/device_compilation_cluster_signature.cc b/tensorflow/compiler/jit/device_compilation_cluster_signature.cc index 9ec02d92d37cd6..8288b44e7f1c1d 100644 --- a/tensorflow/compiler/jit/device_compilation_cluster_signature.cc +++ b/tensorflow/compiler/jit/device_compilation_cluster_signature.cc @@ -65,9 +65,9 @@ struct SignatureNotEqual { // Functor that incrementally computes a Signature's hash given 
its current hash // and one of its args. struct SignatureHashCombiner { - explicit SignatureHashCombiner(const uint64 h) : h(h) {} - uint64 h; - uint64 operator()(const Tensor& arg) { + explicit SignatureHashCombiner(const uint64_t h) : h(h) {} + uint64_t h; + uint64_t operator()(const Tensor& arg) { h = Hash64Combine(h, std::hash()(static_cast(arg.dtype()))); h = Hash64Combine( h, Hash64(arg.tensor_data().data(), arg.tensor_data().size())); @@ -76,7 +76,7 @@ struct SignatureHashCombiner { } return h; } - uint64 operator()(const TensorTypeAndShape& arg) { + uint64_t operator()(const TensorTypeAndShape& arg) { h = Hash64Combine(h, std::hash()(static_cast(arg.first))); h = Hash64Combine(h, std::hash()(arg.second.size())); for (int dim : arg.second) { @@ -108,8 +108,8 @@ bool Signature::operator==(const Signature& other) const { return true; } -uint64 Signature::Hash::operator()(const Signature& signature) const { - uint64 h = std::hash()(signature.name); +uint64_t Signature::Hash::operator()(const Signature& signature) const { + uint64_t h = std::hash()(signature.name); for (const auto& arg : signature.args) { h = std::visit(SignatureHashCombiner(h), arg); } diff --git a/tensorflow/compiler/jit/device_compilation_cluster_signature.h b/tensorflow/compiler/jit/device_compilation_cluster_signature.h index b4c2840eedee59..721c1d3b78c50e 100644 --- a/tensorflow/compiler/jit/device_compilation_cluster_signature.h +++ b/tensorflow/compiler/jit/device_compilation_cluster_signature.h @@ -58,7 +58,8 @@ struct DeviceCompilationClusterSignature { bool operator==(const DeviceCompilationClusterSignature& other) const; struct Hash { - uint64 operator()(const DeviceCompilationClusterSignature& signature) const; + uint64_t operator()( + const DeviceCompilationClusterSignature& signature) const; }; // Returns a human-readable description of the signature. 
diff --git a/tensorflow/compiler/jit/device_compilation_profiler.cc b/tensorflow/compiler/jit/device_compilation_profiler.cc index 5e1b3b26e8ecb5..ec161293b7643d 100644 --- a/tensorflow/compiler/jit/device_compilation_profiler.cc +++ b/tensorflow/compiler/jit/device_compilation_profiler.cc @@ -107,7 +107,7 @@ absl::Status DeviceCompilationProfiler::RegisterCompilation( cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{}) .first; - const uint64 compile_time_s = compile_time_us / 1.0e6; + const uint64_t compile_time_s = compile_time_us / 1.0e6; it->second.compile_count++; it->second.cumulative_compile_time_us += compile_time_us; VLOG(1) << "Compiled " << function_name << " " << it->second.compile_count diff --git a/tensorflow/compiler/jit/device_compiler.h b/tensorflow/compiler/jit/device_compiler.h index 0fae07abd22897..a9f2418282c414 100644 --- a/tensorflow/compiler/jit/device_compiler.h +++ b/tensorflow/compiler/jit/device_compiler.h @@ -137,7 +137,7 @@ class DeviceCompiler : public ResourceBase { return compiler_client_.get(); } - string DebugString() const override; + std::string DebugString() const override; private: // Common implementation of Compile and CompileSingleOp. The `OpKernelContext` @@ -259,7 +259,7 @@ DeviceCompiler::~DeviceCompiler() { } template -string DeviceCompiler::DebugString() const { +std::string DeviceCompiler::DebugString() const { return "DeviceCompiler"; } @@ -331,7 +331,7 @@ DeviceCompiler::CompileStrict( CompileScope scope, OpKernelContext* ctx, DeviceCompilationProfiler* profiler, mutex* mu) { tensorflow::Env* env = tensorflow::Env::Default(); - const uint64 compile_start_us = env->NowMicros(); + const uint64_t compile_start_us = env->NowMicros(); TfGraphToHloCompiler compiler(options); cache_value.compile_state = DeviceCompileState::kCompiled; @@ -385,8 +385,8 @@ DeviceCompiler::CompileStrict( // Finalize the cache to release the XlaComputation after it was compiled. 
cache_->Finalize(); - const uint64 compile_end_us = env->NowMicros(); - const uint64 compile_time_us = compile_end_us - compile_start_us; + const uint64_t compile_end_us = env->NowMicros(); + const uint64_t compile_time_us = compile_end_us - compile_start_us; device_compiler_internal::LogOnceXlaCompiledFirstCluster(); TF_RETURN_IF_ERROR(profiler->RegisterCompilation( @@ -496,7 +496,7 @@ absl::Status DeviceCompiler::CompileImpl( profiler->RegisterExecution(function); - string human_signature; + std::string human_signature; if (VLOG_IS_ON(2)) { human_signature = VLOG_IS_ON(3) ? signature.HumanString() : function.name(); VLOG(2) << "DeviceCompilationClusterSignature: " << human_signature; diff --git a/tensorflow/compiler/jit/device_compiler_test.cc b/tensorflow/compiler/jit/device_compiler_test.cc index 64e286bff55b07..749110be186311 100644 --- a/tensorflow/compiler/jit/device_compiler_test.cc +++ b/tensorflow/compiler/jit/device_compiler_test.cc @@ -139,7 +139,7 @@ class MockXlaDeviceExecutablePersistor Config{testing::TmpDir(), false, "xla"}, DeviceType(DEVICE_CPU_XLA_JIT)) {} MOCK_METHOD(absl::Status, TryToPersistExecutable, - (uint64, const std::string&, const XlaCompiler::Options&, + (uint64_t, const std::string&, const XlaCompiler::Options&, const XlaCompiler::CompilationResult&, const xla::LocalExecutable&, (DeviceCompilerClient*)), @@ -425,7 +425,7 @@ TEST_F(DeviceCompilerTest, CompileFailedToLoadFromPersistentCache) { &xla_executable)); // Corrupt the file which contains the serialized executable. 
- std::vector files; + std::vector files; TF_ASSERT_OK(Env::Default()->GetChildren(testing::TmpDir(), &files)); std::string const* serialized_executable_filename = nullptr; for (const auto& file : files) { diff --git a/tensorflow/compiler/jit/device_context_test.cc b/tensorflow/compiler/jit/device_context_test.cc index 34a0c3d5ea067b..33bba30f3db3e1 100644 --- a/tensorflow/compiler/jit/device_context_test.cc +++ b/tensorflow/compiler/jit/device_context_test.cc @@ -38,7 +38,7 @@ static bool Initialized = [] { class DeviceContextTest : public ::testing::Test { public: - void SetDevice(const string& device_type) { + void SetDevice(const std::string& device_type) { auto& rollout_config = GetXlaOpsCommonFlags()->tf_xla_use_device_api; rollout_config.AllowForDeviceInXlaLaunch(DeviceType(device_type)); rollout_config.AllowForDeviceInXlaCompileOnDemand(DeviceType(device_type)); diff --git a/tensorflow/compiler/jit/device_executable_persistor.h b/tensorflow/compiler/jit/device_executable_persistor.h index 458441c86b5c43..5a64b078e1a93c 100644 --- a/tensorflow/compiler/jit/device_executable_persistor.h +++ b/tensorflow/compiler/jit/device_executable_persistor.h @@ -96,7 +96,7 @@ class DeviceExecutablePersistor { // TODO(b/255826209): Take in Signature instead of hash and string once cache // is refactored. std::optional>> TryToLoadExecutable( - uint64 signature_hash, const std::string& signature_str, + uint64_t signature_hash, const std::string& signature_str, const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& compilation_result, DeviceCompilerClient* client) const; @@ -107,7 +107,7 @@ class DeviceExecutablePersistor { // TODO(b/255826209): Take in Signature instead hash and string once cache // is refactored. 
virtual absl::Status TryToPersistExecutable( - uint64 signature_hash, const std::string& signature_str, + uint64_t signature_hash, const std::string& signature_str, const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& compilation_result, const ExecutableType& executable, @@ -123,15 +123,15 @@ class DeviceExecutablePersistor { // Returns a cache key proto that identifies an entry in the compilation // cache. XlaSerializedCacheKey BuildSerializedCacheKey( - uint64 signature_hash, const xla::HloModuleProto& hlo_module) const; + uint64_t signature_hash, const xla::HloModuleProto& hlo_module) const; XlaSerializedCacheKey BuildSerializedCacheKey( - uint64 signature_hash, const xla::HloModuleProto& hlo_module, + uint64_t signature_hash, const xla::HloModuleProto& hlo_module, bool compiled_using_pjrt) const; // Serializes the signature and its corresponding entry to a proto message. absl::StatusOr SerializeEntry( - uint64 signature_hash, const XlaCompiler::Options& options, + uint64_t signature_hash, const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& compilation_result, const ExecutableType& executable, DeviceCompilerClient* compiler_client) const; @@ -189,7 +189,7 @@ std::string DeviceExecutablePersistor::GetFilePath( template XlaSerializedCacheKey DeviceExecutablePersistor::BuildSerializedCacheKey( - uint64 signature_hash, const xla::HloModuleProto& hlo_module, + uint64_t signature_hash, const xla::HloModuleProto& hlo_module, bool compiled_using_pjrt) const { XlaSerializedCacheKey key; key.set_signature_fingerprint(signature_hash); @@ -203,7 +203,7 @@ DeviceExecutablePersistor::BuildSerializedCacheKey( template XlaSerializedCacheKey DeviceExecutablePersistor::BuildSerializedCacheKey( - uint64 signature_hash, const xla::HloModuleProto& hlo_module) const { + uint64_t signature_hash, const xla::HloModuleProto& hlo_module) const { return BuildSerializedCacheKey(signature_hash, hlo_module, false); } @@ -212,7 +212,7 @@ 
DeviceExecutablePersistor::BuildSerializedCacheKey( template <> inline XlaSerializedCacheKey DeviceExecutablePersistor:: - BuildSerializedCacheKey(uint64 signature_hash, + BuildSerializedCacheKey(uint64_t signature_hash, const xla::HloModuleProto& hlo_module) const { return BuildSerializedCacheKey(signature_hash, hlo_module, true); } @@ -305,7 +305,7 @@ DeviceExecutablePersistor::SaveSerializedEntry( template absl::StatusOr DeviceExecutablePersistor::SerializeEntry( - uint64 signature_hash, const XlaCompiler::Options& options, + uint64_t signature_hash, const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& compilation_result, const ExecutableType& executable, DeviceCompilerClient* compiler_client) const { @@ -340,7 +340,7 @@ DeviceExecutablePersistor::SerializeEntry( template std::optional>> DeviceExecutablePersistor::TryToLoadExecutable( - uint64 signature_hash, const std::string& signature_str, + uint64_t signature_hash, const std::string& signature_str, const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& compilation_result, DeviceCompilerClient* compiler_client) const { @@ -376,7 +376,7 @@ DeviceExecutablePersistor::TryToLoadExecutable( template absl::Status DeviceExecutablePersistor::TryToPersistExecutable( - uint64 signature_hash, const std::string& signature_str, + uint64_t signature_hash, const std::string& signature_str, const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& compilation_result, const ExecutableType& executable, diff --git a/tensorflow/compiler/jit/device_executable_persistor_test.cc b/tensorflow/compiler/jit/device_executable_persistor_test.cc index 7779f1112e7b9e..62cfd4c1b8e0b7 100644 --- a/tensorflow/compiler/jit/device_executable_persistor_test.cc +++ b/tensorflow/compiler/jit/device_executable_persistor_test.cc @@ -222,7 +222,7 @@ absl::StatusOr ReadCacheEntryFromFile( } XlaSerializedCacheKey CreateCacheKey( - uint64 signature_hash, + uint64_t signature_hash, const 
XlaCompiler::CompilationResult& compilation_result, const DeviceType& device_type, const std::string& persistence_prefix, bool compiled_using_pjrt = false) { diff --git a/tensorflow/compiler/jit/device_util.cc b/tensorflow/compiler/jit/device_util.cc index 828da0b08c2590..1979aec5bcf0c3 100644 --- a/tensorflow/compiler/jit/device_util.cc +++ b/tensorflow/compiler/jit/device_util.cc @@ -44,7 +44,7 @@ void DeviceSet::UnionWith(const DeviceSet& other) { } bool DeviceSet::IsEmpty() const { - return absl::c_all_of(storage_, [&](uint64 val) { return val == 0; }); + return absl::c_all_of(storage_, [&](uint64_t val) { return val == 0; }); } absl::StatusOr DeviceInfoCache::GetIdFor(absl::string_view name) { @@ -56,7 +56,7 @@ absl::StatusOr DeviceInfoCache::GetIdFor(absl::string_view name) { } int new_id = names_.size(); - names_.push_back(string(name)); + names_.push_back(std::string(name)); id_to_device_type_.push_back(std::make_unique("")); DeviceType* device_type = id_to_device_type_.back().get(); TF_RETURN_IF_ERROR(DeviceNameToDeviceType(names_.back(), device_type)); @@ -64,7 +64,7 @@ absl::StatusOr DeviceInfoCache::GetIdFor(absl::string_view name) { is_cpu_.push_back(device_type->type_string() == DEVICE_CPU); is_gpu_.push_back(device_type->type_string() == DEVICE_GPU); - name_to_id_.emplace(string(name), DeviceId(new_id)); + name_to_id_.emplace(std::string(name), DeviceId(new_id)); const XlaOpRegistry::DeviceRegistration* compilation_device; if (!XlaOpRegistry::GetCompilationDevice(device_type->type(), @@ -76,10 +76,10 @@ absl::StatusOr DeviceInfoCache::GetIdFor(absl::string_view name) { return DeviceId(new_id); } -string DeviceInfoCache::DebugString(const DeviceSet& device_set) const { - std::vector names; +std::string DeviceInfoCache::DebugString(const DeviceSet& device_set) const { + std::vector names; device_set.ForEach([&](DeviceId device_id) { - names.push_back(string(GetNameFor(device_id))); + names.push_back(std::string(GetNameFor(device_id))); return true; }); 
@@ -87,7 +87,7 @@ string DeviceInfoCache::DebugString(const DeviceSet& device_set) const { } } // namespace jit -absl::Status DeviceNameToDeviceType(const string& device, +absl::Status DeviceNameToDeviceType(const std::string& device, DeviceType* device_type) { DeviceNameUtils::ParsedName parsed; if (!DeviceNameUtils::ParseFullName(device, &parsed)) { diff --git a/tensorflow/compiler/jit/device_util.h b/tensorflow/compiler/jit/device_util.h index 745f87309501d8..fa862aac88c394 100644 --- a/tensorflow/compiler/jit/device_util.h +++ b/tensorflow/compiler/jit/device_util.h @@ -75,9 +75,9 @@ class DeviceSet { // iterator if this ends up being used widely. for (int word_index = 0, end = storage_.size(); word_index < end; word_index++) { - uint64 word = storage_[word_index]; + uint64_t word = storage_[word_index]; while (word != 0) { - uint64 only_lowest_bit_set = word & -word; + uint64_t only_lowest_bit_set = word & -word; // The number of trailing zeros in a non-zero word is the index of the // least significant 1. int bit_index = absl::countr_zero(word); @@ -90,7 +90,7 @@ class DeviceSet { } private: - absl::InlinedVector storage_; + absl::InlinedVector storage_; const int kWordSize = 64; }; @@ -131,17 +131,17 @@ class DeviceInfoCache { return std::cref(*id_to_device_type_[device_id.id()]); } - string DebugString(const DeviceSet& device_set) const; + std::string DebugString(const DeviceSet& device_set) const; private: - absl::flat_hash_map name_to_id_; + absl::flat_hash_map name_to_id_; // These fields are populated for a device in GetIdFor, *before* we give out a // DeviceId. std::vector id_to_compilation_device_; std::vector> id_to_device_type_; - std::vector names_; + std::vector names_; std::vector is_cpu_; std::vector is_gpu_; }; @@ -149,7 +149,7 @@ class DeviceInfoCache { } // namespace jit // Returns the DeviceType corresponding to 'device'. 
-absl::Status DeviceNameToDeviceType(const string& device, +absl::Status DeviceNameToDeviceType(const std::string& device, DeviceType* device_type); // Picks the device for which XLA should compile a cluster that contains diff --git a/tensorflow/compiler/jit/device_util_test.cc b/tensorflow/compiler/jit/device_util_test.cc index cef39df6283f2b..be58292f931686 100644 --- a/tensorflow/compiler/jit/device_util_test.cc +++ b/tensorflow/compiler/jit/device_util_test.cc @@ -23,7 +23,7 @@ namespace { absl::Status PickDeviceHelper(bool allow_mixing_unknown_and_cpu, absl::Span device_names, - string* result) { + std::string* result) { jit::DeviceInfoCache cache; jit::DeviceSet device_set; for (absl::string_view name : device_names) { @@ -34,14 +34,14 @@ absl::Status PickDeviceHelper(bool allow_mixing_unknown_and_cpu, TF_ASSIGN_OR_RETURN( jit::DeviceId result_id, PickDeviceForXla(cache, device_set, allow_mixing_unknown_and_cpu)); - *result = string(cache.GetNameFor(result_id)); + *result = std::string(cache.GetNameFor(result_id)); return absl::OkStatus(); } void CheckPickDeviceResult(absl::string_view expected_result, bool allow_mixing_unknown_and_cpu, absl::Span inputs) { - string result; + std::string result; TF_ASSERT_OK(PickDeviceHelper(allow_mixing_unknown_and_cpu, inputs, &result)) << "inputs = [" << absl::StrJoin(inputs, ", ") << "], allow_mixing_unknown_and_cpu=" << allow_mixing_unknown_and_cpu @@ -51,7 +51,7 @@ void CheckPickDeviceResult(absl::string_view expected_result, void CheckPickDeviceHasError(bool allow_mixing_unknown_and_cpu, absl::Span inputs) { - string result; + std::string result; EXPECT_FALSE( PickDeviceHelper(allow_mixing_unknown_and_cpu, inputs, &result).ok()); } @@ -110,10 +110,10 @@ void SimpleRoundTripTestForDeviceSet(int num_devices) { jit::DeviceSet device_set; jit::DeviceInfoCache device_info_cache; - std::vector expected_devices, actual_devices; + std::vector expected_devices, actual_devices; for (int i = 0; i < num_devices; i++) { - string 
device_name = + std::string device_name = absl::StrCat("/job:localhost/replica:0/task:0/device:XPU:", i); TF_ASSERT_OK_AND_ASSIGN(jit::DeviceId device_id, device_info_cache.GetIdFor(device_name)); @@ -122,7 +122,8 @@ void SimpleRoundTripTestForDeviceSet(int num_devices) { } device_set.ForEach([&](jit::DeviceId device_id) { - actual_devices.push_back(string(device_info_cache.GetNameFor(device_id))); + actual_devices.push_back( + std::string(device_info_cache.GetNameFor(device_id))); return true; }); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 3e8a43ce08ed58..6e7d16de16a4f6 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -115,7 +115,7 @@ void MarkGuaranteedConstants( } struct OutputInputTensorPairHasher { - uint64 operator()(std::pair const& s) const { + uint64_t operator()(std::pair const& s) const { return Hash64Combine(OutputTensor::Hash()(s.first), InputTensor::Hash()(s.second)); } @@ -128,7 +128,7 @@ static const char* const kRetValOp = "_Retval"; class Encapsulator { public: - Encapsulator(string group_attribute, Graph const* graph_in) + Encapsulator(std::string group_attribute, Graph const* graph_in) : group_attribute_(std::move(group_attribute)), graph_in_(graph_in) {} // Find subgraphs marked with 'group_attribute', and build a new @@ -182,7 +182,7 @@ class Encapsulator { // 'reuse_existing_functions' is set, use an existing function with the same // name, if any. If 'rewrite_subgraph_fn' is set, it is applied to the // subgraph before function conversion. 
- absl::Status BuildFunctionDef(const string& name_in, + absl::Status BuildFunctionDef(const std::string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, FunctionLibraryDefinition* library); @@ -226,7 +226,7 @@ class Encapsulator { const absl::flat_hash_map& node_images); // Creates the sequencer node if it doesn't exist, adding it to graph_out. - absl::Status MakeSequencingNode(const string& subgraph_name, + absl::Status MakeSequencingNode(const std::string& subgraph_name, Graph* graph_out); // If there is a sequencer node, adds a control edge from the sequencer to @@ -243,14 +243,14 @@ class Encapsulator { // Which device are these nodes on? Used to assign a device to the call // node. - string device_; + std::string device_; // NodeDef for the function call node. NodeDef call_node_def_; // Name that is used for the call node. This may not be // call_node_def_.name() if the client supplies a rewrite lambda. - string function_def_name_; + std::string function_def_name_; // Placeholder node simulating the host compute key in the output graph. // Not owned. @@ -275,7 +275,7 @@ class Encapsulator { // Set of node names that are the source of a control output of the // subgraph. We store strings here so that we can tolerate nodes being // removed from the graph. - absl::flat_hash_set control_output_nodes_; + absl::flat_hash_set control_output_nodes_; // NoOp node in the output graph that is sequenced after the call node. Node* sequencer_ = nullptr; @@ -283,7 +283,7 @@ class Encapsulator { // Returns the key attribute associated with a node in attr. Sets either // result to the empty string if the respective attribute is not found. - absl::Status GetFunctionNameAttr(Node const* node, string* attr) const; + absl::Status GetFunctionNameAttr(Node const* node, std::string* attr) const; // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to // subgraphs for data edges that cross subgraph boundaries. 
@@ -308,36 +308,35 @@ class Encapsulator { // a subgraph boundary it is the output of a call node, otherwise it is a node // in the output graph. absl::Status FindOutputImageOfEdgeSrc( - const string& src_func_id, const string& dst_func_id, + const std::string& src_func_id, const std::string& dst_func_id, const absl::flat_hash_map& node_images, const Node* original_src_node, Node** src_image); // Finds an edge source slot in the output graph. If the edge crosses a // subgraph boundary it is a slot on the output of a call node, otherwise it // is a slot on a node in the output graph. - int FindOutputSlotOfEdgeSrc(const string& src_func_id, - const string& dst_func_id, - const Edge* edge); + int FindOutputSlotOfEdgeSrc(const std::string& src_func_id, + const std::string& dst_func_id, const Edge* edge); // Finds the image of an edge destination in the output graph. If the edge // crosses a subgraph boundary it is the input of a call node, otherwise it is // a node in the output graph. absl::Status FindOutputImageOfEdgeDst( - const string& src_func_id, const string& dst_func_id, + const std::string& src_func_id, const std::string& dst_func_id, const absl::flat_hash_map& node_images, const Node* original_dst_node, Node** dst_image); // Finds an edge destination slot in the output graph. If the edge crosses a // subgraph boundary it is a slot on the input of a call node, otherwise it is // a slot on a node in the output graph. - int FindOutputSlotOfEdgeDst(const string& src_func_id, - const string& dst_func_id, - const Edge* edge); + int FindOutputSlotOfEdgeDst(const std::string& src_func_id, + const std::string& dst_func_id, const Edge* edge); // Copies a single edge to the output graph. The edge is either entirely // within the output graph, or crosses into or out of a compiled subgraph. 
absl::Status CopyEdgeToOutputGraph( - const Edge* edge, const string& src_func_id, const string& dst_func_id, + const Edge* edge, const std::string& src_func_id, + const std::string& dst_func_id, const absl::flat_hash_map& node_images, Graph* graph_out, absl::flat_hash_set, @@ -358,10 +357,10 @@ class Encapsulator { absl::flat_hash_map* node_images, FunctionLibraryDefinition* library); - const string group_attribute_; + const std::string group_attribute_; const Graph* graph_in_; - absl::flat_hash_map subgraphs_; + absl::flat_hash_map subgraphs_; Encapsulator(const Encapsulator&) = delete; void operator=(const Encapsulator&) = delete; @@ -374,19 +373,20 @@ namespace { // including clusters that are not present in the ancestors map. has_successors // is the set of clusters that are ancestors of some other cluster. void TopologicalClusterSort( - const absl::flat_hash_set& clusters, - const absl::flat_hash_set& has_successors, - const absl::flat_hash_map>& ancestors, - std::vector* sorted) { + const absl::flat_hash_set& clusters, + const absl::flat_hash_set& has_successors, + const absl::flat_hash_map>& + ancestors, + std::vector* sorted) { // The nodes are placed in 'sorted' in topological order. sorted->clear(); // We don't use the standard DFS because we are not operating on Node* // objects. struct Work { - string cluster; + std::string cluster; bool leave; }; - std::set visited; + std::set visited; std::vector stack; // Seed the processing list with clusters that have no successors. for (const auto& cluster : clusters) { @@ -523,7 +523,7 @@ absl::Status Encapsulator::Subgraph::RecordResult( } absl::Status Encapsulator::Subgraph::MakeSequencingNode( - const string& subgraph_name, Graph* graph_out) { + const std::string& subgraph_name, Graph* graph_out) { if (sequencer_ == nullptr) { NodeDef seq_def; // TODO(shikharagarwal): What source node should we use for errors? 
@@ -547,11 +547,11 @@ void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) { } absl::Status Encapsulator::Subgraph::BuildFunctionDef( - const string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn, + const std::string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, FunctionLibraryDefinition* library) { // name_in is copied here because name may be modified below if // rewrite_subgraph_fn is true. - string name = name_in; + std::string name = name_in; call_node_def_.set_op(name); call_node_def_.set_name(name); call_node_def_.set_device(device_); @@ -596,7 +596,7 @@ absl::Status Encapsulator::Subgraph::BuildFunctionDef( function_def_name_ = name; FunctionDef fdef; - auto lookup = [this](const Node* node) -> std::optional { + auto lookup = [this](const Node* node) -> std::optional { if (control_output_nodes_.contains(node->name())) { return std::make_optional(node->name()); } @@ -625,7 +625,7 @@ absl::Status Encapsulator::Subgraph::BuildFunctionDef( absl::Status Encapsulator::Subgraph::ReplaceFunctionDef( FunctionLibraryDefinition* library) { - const string& name = function_def_name_; + const std::string& name = function_def_name_; FunctionDef fdef; TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); @@ -654,7 +654,7 @@ absl::Status Encapsulator::Subgraph::AddFunctionCallNode( } absl::Status Encapsulator::GetFunctionNameAttr(Node const* node, - string* attr) const { + std::string* attr) const { AttrSlice attrs = node->attrs(); attr->clear(); for (const auto& node_attr : attrs) { @@ -667,12 +667,12 @@ absl::Status Encapsulator::GetFunctionNameAttr(Node const* node, return absl::OkStatus(); } -bool IsInSubgraph(const string& func_id) { return !func_id.empty(); } +bool IsInSubgraph(const std::string& func_id) { return !func_id.empty(); } absl::Status Encapsulator::CopySubgraphNodes( absl::flat_hash_map* node_images) { for (Node* node : graph_in_->op_nodes()) { - string func_id; + std::string 
func_id; TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id)); if (!IsInSubgraph(func_id)) continue; @@ -688,9 +688,9 @@ absl::Status Encapsulator::CopySubgraphEdges( const absl::flat_hash_map& node_images, std::vector>* src_arg_pairs) { for (const Edge* edge : graph_in_->edges()) { - string src_func_id; + std::string src_func_id; TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id)); - string dst_func_id; + std::string dst_func_id; TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id)); Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr); Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr); @@ -793,7 +793,7 @@ absl::Status Encapsulator::BuildFunctionDefs( const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, FunctionLibraryDefinition* library) { for (auto& subgraph_entry : subgraphs_) { - string name = subgraph_entry.first; + std::string name = subgraph_entry.first; Subgraph& subgraph = subgraph_entry.second; TF_RETURN_IF_ERROR(subgraph.BuildFunctionDef( name, rewrite_subgraph_fn, reuse_existing_functions, library)); @@ -804,7 +804,7 @@ absl::Status Encapsulator::BuildFunctionDefs( absl::Status Encapsulator::CopyNodesToOutputGraph( Graph* graph_out, absl::flat_hash_map* node_images) { for (Node* node : graph_in_->op_nodes()) { - string func_id; + std::string func_id; TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id)); // Don't copy nodes that are going to be encapsulated. 
@@ -829,7 +829,7 @@ absl::Status Encapsulator::AddFunctionCallNodes( } absl::Status Encapsulator::FindOutputImageOfEdgeSrc( - const string& src_func_id, const string& dst_func_id, + const std::string& src_func_id, const std::string& dst_func_id, const absl::flat_hash_map& node_images, const Node* original_src_node, Node** src_image) { if (IsInSubgraph(src_func_id)) { @@ -844,8 +844,8 @@ absl::Status Encapsulator::FindOutputImageOfEdgeSrc( return absl::OkStatus(); } -int Encapsulator::FindOutputSlotOfEdgeSrc(const string& src_func_id, - const string& dst_func_id, +int Encapsulator::FindOutputSlotOfEdgeSrc(const std::string& src_func_id, + const std::string& dst_func_id, const Edge* edge) { if (IsInSubgraph(src_func_id)) { const Subgraph& src_subgraph = subgraphs_.at(src_func_id); @@ -860,7 +860,7 @@ int Encapsulator::FindOutputSlotOfEdgeSrc(const string& src_func_id, } absl::Status Encapsulator::FindOutputImageOfEdgeDst( - const string& src_func_id, const string& dst_func_id, + const std::string& src_func_id, const std::string& dst_func_id, const absl::flat_hash_map& node_images, const Node* original_dst_node, Node** dst_image) { if (IsInSubgraph(dst_func_id)) { @@ -875,8 +875,8 @@ absl::Status Encapsulator::FindOutputImageOfEdgeDst( return absl::OkStatus(); } -int Encapsulator::FindOutputSlotOfEdgeDst(const string& src_func_id, - const string& dst_func_id, +int Encapsulator::FindOutputSlotOfEdgeDst(const std::string& src_func_id, + const std::string& dst_func_id, const Edge* edge) { if (IsInSubgraph(dst_func_id)) { const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id); @@ -891,7 +891,8 @@ int Encapsulator::FindOutputSlotOfEdgeDst(const string& src_func_id, } absl::Status Encapsulator::CopyEdgeToOutputGraph( - const Edge* edge, const string& src_func_id, const string& dst_func_id, + const Edge* edge, const std::string& src_func_id, + const std::string& dst_func_id, const absl::flat_hash_map& node_images, Graph* graph_out, absl::flat_hash_set, @@ -943,9 +944,9 @@ 
absl::Status Encapsulator::AddEdgesToOutputGraph( edges_added; for (const Edge* edge : graph_in_->edges()) { - string src_func_id; + std::string src_func_id; TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id)); - string dst_func_id; + std::string dst_func_id; TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id)); // Ignore edges that are strictly contained within one subgraph, unless @@ -1091,7 +1092,7 @@ absl::Status Encapsulator::BuildOutputGraph( } // anonymous namespace absl::Status EncapsulateSubgraphsInFunctions( - string group_attribute, const Graph& graph_in, + std::string group_attribute, const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, std::unique_ptr* graph_out, FunctionLibraryDefinition* library) { Encapsulator encapsulator(std::move(group_attribute), diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 0c7729f67349b5..ed2c9ef45a2c16 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -73,7 +73,7 @@ typedef std::function* graph_out, FunctionLibraryDefinition* library); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 1e05ad067def7f..94b136a02b99cf 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -46,7 +46,7 @@ const char* const kXlaHostTransferSequencerAttr = "_xla_host_transfer_sequencer"; absl::Status AddGraphDefToFunctionLibrary( - const GraphDefBuilder& graphdef_builder, const string& name_suffix, + const GraphDefBuilder& graphdef_builder, const std::string& name_suffix, FunctionDefLibrary* library) { GraphDef graphdef; TF_RETURN_IF_ERROR(graphdef_builder.ToGraphDef(&graphdef)); @@ -64,13 +64,14 @@ absl::Status AddGraphDefToFunctionLibrary( } 
template -bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, - const ::tensorflow::protobuf::Map& b, - const std::function& key_to_string, - const std::function& value_to_string, - const std::function& compare, - const string& map_name, string* diff) { +bool EqualProtoMap( + const ::tensorflow::protobuf::Map& a, + const ::tensorflow::protobuf::Map& b, + const std::function& key_to_string, + const std::function& value_to_string, + const std::function& + compare, + const std::string& map_name, std::string* diff) { for (const auto& elt_a : a) { const auto iter = b.find(elt_a.first); if (iter == b.end()) { @@ -106,7 +107,7 @@ bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, } bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, - const string& diff_preamble, string* diff) { + const std::string& diff_preamble, std::string* diff) { if (a.op() != b.op()) { if (diff) { *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), @@ -131,8 +132,8 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, } return false; } - std::unordered_set control_input_a; - std::unordered_set control_input_b; + std::unordered_set control_input_a; + std::unordered_set control_input_b; for (int i = 0; i < a.input_size(); ++i) { if (absl::StartsWith(a.input(i), "^")) { if (!absl::StartsWith(b.input(i), "^")) { @@ -164,17 +165,17 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, } return false; } - return EqualProtoMap( - a.attr(), b.attr(), [](const string& s) { return s; }, + return EqualProtoMap( + a.attr(), b.attr(), [](const std::string& s) { return s; }, [](const AttrValue& v) { return v.DebugString(); }, - [](const string& key, const AttrValue& av, const AttrValue& bv) { + [](const std::string& key, const AttrValue& av, const AttrValue& bv) { if (key == "ancestors") { // The ancestors are added from a set so the order is unpredictable; // just compare set equality not list equality. 
- std::unordered_set a_set(av.list().s().begin(), - av.list().s().end()); - std::unordered_set b_set(bv.list().s().begin(), - bv.list().s().end()); + std::unordered_set a_set(av.list().s().begin(), + av.list().s().end()); + std::unordered_set b_set(bv.list().s().begin(), + bv.list().s().end()); return a_set == b_set; } else { return av.DebugString() == bv.DebugString(); @@ -184,7 +185,7 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, } bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, - string* diff) { + std::string* diff) { if (a.signature().DebugString() != b.signature().DebugString()) { if (diff) { *diff = @@ -194,22 +195,21 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, } return false; } - if (!EqualProtoMap( - a.attr(), b.attr(), [](const string& s) { return s; }, + if (!EqualProtoMap( + a.attr(), b.attr(), [](const std::string& s) { return s; }, [](const AttrValue& v) { return v.DebugString(); }, - [](const string& key, const AttrValue& av, const AttrValue& bv) { + [](const std::string& key, const AttrValue& av, const AttrValue& bv) { return av.DebugString() == bv.DebugString(); }, absl::StrCat("attr mismatch for function ", a.signature().name()), diff)) { return false; } - if (!EqualProtoMap( - a.ret(), b.ret(), [](const string& s) { return s; }, - [](const string& s) { return s; }, - [](const string& key, const string& av, const string& bv) { - return av == bv; - }, + if (!EqualProtoMap( + a.ret(), b.ret(), [](const std::string& s) { return s; }, + [](const std::string& s) { return s; }, + [](const std::string& key, const std::string& av, + const std::string& bv) { return av == bv; }, absl::StrCat("ret mismatch for function ", a.signature().name()), diff)) { return false; @@ -257,8 +257,9 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, } bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, - const FunctionDefLibrary& actual, string* diff) { - std::unordered_map 
actual_index; + const FunctionDefLibrary& actual, + std::string* diff) { + std::unordered_map actual_index; for (const FunctionDef& function : actual.function()) { actual_index[function.signature().name()] = &function; } @@ -343,7 +344,7 @@ REGISTER_OP("AddNLikeTest") .SetIsAggregate(); Node* Sequencer(const GraphDefBuilder::Options& opts, - const string& call_node_name) { + const std::string& call_node_name) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("NoOp"), "NoOp", opts.op_registry()); @@ -383,7 +384,7 @@ Node* KeyPlaceholderShape(const GraphDefBuilder::Options& opts) { return KnownShapeBase(DT_STRING, {2}, opts); } -Node* KeyPlaceholder(const string& call_node, +Node* KeyPlaceholder(const std::string& call_node, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(absl::StrCat(call_node, "_key_placeholder"), @@ -396,15 +397,16 @@ Node* KeyPlaceholder(const string& call_node, .FinalizeBuilder(&node_builder); } -Node* RecvAtHost(ops::NodeOut key_input, const string& cluster, - const string& new_func_name, const string& oc_cluster, +Node* RecvAtHost(ops::NodeOut key_input, const std::string& cluster, + const std::string& new_func_name, + const std::string& oc_cluster, absl::Span dtypes, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - string key = absl::StrCat("host_compute_channel_", cluster, "_", - new_func_name, "_", oc_cluster); - string name = absl::StrCat("outside_compilation_", cluster, "_", - new_func_name, "_", oc_cluster, "_recv"); + std::string key = absl::StrCat("host_compute_channel_", cluster, "_", + new_func_name, "_", oc_cluster); + std::string name = absl::StrCat("outside_compilation_", cluster, "_", + new_func_name, "_", oc_cluster, "_recv"); NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaRecvAtHost"), "_XlaRecvAtHost", opts.op_registry()); node_builder.Input(std::move(key_input)); @@ -416,15 +418,16 @@ Node* 
RecvAtHost(ops::NodeOut key_input, const string& cluster, .FinalizeBuilder(&node_builder); } -Node* SendFromHost(ops::NodeOut key_input, const string& cluster, - const string& new_func_name, const string& oc_cluster, +Node* SendFromHost(ops::NodeOut key_input, const std::string& cluster, + const std::string& new_func_name, + const std::string& oc_cluster, const std::vector& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - string key = absl::StrCat("host_compute_channel_", cluster, "_", - new_func_name, "_", oc_cluster); - string name = absl::StrCat("outside_compilation_", cluster, "_", - new_func_name, "_", oc_cluster, "_send"); + std::string key = absl::StrCat("host_compute_channel_", cluster, "_", + new_func_name, "_", oc_cluster); + std::string name = absl::StrCat("outside_compilation_", cluster, "_", + new_func_name, "_", oc_cluster, "_send"); NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaSendFromHost"), "_XlaSendFromHost", opts.op_registry()); node_builder.Input(inputs); @@ -477,8 +480,9 @@ Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) { return opts.FinalizeBuilder(&node_builder); } -absl::Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, - const std::vector& encapsulated_functions) { +absl::Status Encapsulate( + GraphDef* graphdef, FunctionDefLibrary* library, + const std::vector& encapsulated_functions) { absl::Status s; // Convert the GraphDef to a Graph std::unique_ptr lib_def( @@ -512,7 +516,7 @@ absl::Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, &graph_out, lib_def.get()); if (!s.ok()) return s; - std::unordered_map clusters; + std::unordered_map clusters; for (const auto& func : encapsulated_functions) { Node* xla_computation_node; for (Node* n : graph_out->nodes()) { @@ -527,7 +531,7 @@ absl::Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, func_name_attrs.set_name(func); clusters.emplace(func, 
XlaClusterInfo{func, func_name_attrs, xla_computation_node, - std::map{}}); + std::map{}}); } bool modified; s = ExtractOutsideCompilation("_encapsulate", "_outside", clusters, @@ -551,7 +555,7 @@ absl::Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, } absl::Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { - std::vector encapsulated_functions; + std::vector encapsulated_functions; return Encapsulate(graphdef, library, encapsulated_functions); } @@ -698,8 +702,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctions) { } // Returns a vector of node names in 'graph', sorted by name. -std::vector GraphNodes(const Graph& graph) { - std::vector nodes; +std::vector GraphNodes(const Graph& graph) { + std::vector nodes; for (const auto& node : graph.nodes()) { if (!node->IsSource() && !node->IsSink()) { nodes.push_back(node->name()); @@ -710,8 +714,9 @@ std::vector GraphNodes(const Graph& graph) { } // Returns a sorted vector of (src, dst) edges in 'graph'. -std::vector> GraphEdges(const Graph& graph) { - std::vector> edges; +std::vector> GraphEdges( + const Graph& graph) { + std::vector> edges; for (const Edge* edge : graph.edges()) { if (edge->src()->IsSource() || edge->dst()->IsSink()) continue; edges.emplace_back( @@ -742,10 +747,11 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { /*rewrite_subgraph_fn=*/{}, /*reuse_existing_functions=*/false, &graph, &library)); - std::vector expected_nodes = {"cluster1", "cluster2", "mul", "x"}; + std::vector expected_nodes = {"cluster1", "cluster2", "mul", + "x"}; EXPECT_EQ(expected_nodes, GraphNodes(*graph)); - std::vector> expected_edges = { + std::vector> expected_edges = { {"cluster1:0", "cluster2:0"}, {"cluster1:0", "mul:0"}, {"cluster2:0", "mul:1"}, @@ -753,7 +759,7 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { EXPECT_EQ(expected_edges, GraphEdges(*graph)); } -const Node* FindNodeByName(const Graph& graph, const string& name) { +const Node* FindNodeByName(const Graph& graph, 
const std::string& name) { for (const Node* node : graph.nodes()) { if (node->name() == name) return node; } @@ -889,7 +895,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -931,7 +937,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {"C:o:0", "c:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -941,7 +947,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}, {"c"}}, @@ -1025,7 +1031,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -1102,7 +1108,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"F:o:0", "D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT, DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O2"}, {"send_key", ""}, {"recv_key", ""}, @@ -1112,8 +1118,9 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node", - "outside_compilation_O1_host_compute"})}, + 
absl::Span( + {"_xla_token_arg_node", + "outside_compilation_O1_host_compute"})}, {"_xla_original_oc_node_name", "outside_compilation_O2_host_compute"}}, {"F", "outside_compilation_O1_host_compute"}}, @@ -1122,7 +1129,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"C:o:0", "D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -1132,7 +1139,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}, {"D"}}, @@ -1235,7 +1242,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1", "F2"}; + std::vector encapsulated_functions{"F1", "F2"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -1262,7 +1269,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"C:o:0", "D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -1273,7 +1280,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}, {"D"}}, @@ -1295,7 +1302,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"d_0_arg", "G:o:0"}, {{"Tinputs", 
absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_F2_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -1306,7 +1313,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}}, }, @@ -1409,7 +1416,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1", "F2"}; + std::vector encapsulated_functions{"F1", "F2"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -1432,7 +1439,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"C:o:0", "D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -1443,7 +1450,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}, {"D"}}, @@ -1462,7 +1469,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"G:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_F2_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -1473,7 +1480,7 @@ 
TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}}, }, @@ -1556,7 +1563,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -1578,7 +1585,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { {"a_0_arg"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -1589,7 +1596,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}}, }, @@ -1652,7 +1659,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -1674,7 +1681,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {"a_0_arg"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -1685,7 +1692,7 @@ 
TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}, {"D"}}, @@ -1748,7 +1755,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -1785,7 +1792,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -1795,7 +1802,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}}, }, @@ -1858,7 +1865,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -1899,7 +1906,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -1909,7 +1916,7 @@ 
TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}}, }, @@ -1978,7 +1985,7 @@ TEST(EncapsulateSubgraphsTest, TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -2037,7 +2044,7 @@ TEST(EncapsulateSubgraphsTest, {"a_0_arg"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -2047,7 +2054,7 @@ TEST(EncapsulateSubgraphsTest, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}}, {{"outside_compilation_O2_host_compute"}, @@ -2055,7 +2062,7 @@ TEST(EncapsulateSubgraphsTest, {"F:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O2"}, {"send_key", ""}, {"recv_key", ""}, @@ -2065,8 +2072,9 @@ TEST(EncapsulateSubgraphsTest, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node", - "outside_compilation_O1_host_compute"})}, + absl::Span( + {"_xla_token_arg_node", + "outside_compilation_O1_host_compute"})}, {"_xla_original_oc_node_name", "outside_compilation_O2_host_compute"}}, {"outside_compilation_O1_host_compute"}}, @@ -2149,7 +2157,7 @@ 
TEST(EncapsulateSubgraphsTest, TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -2189,7 +2197,7 @@ TEST(EncapsulateSubgraphsTest, {"a_0_arg"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O2"}, {"send_key", ""}, {"recv_key", ""}, @@ -2199,8 +2207,9 @@ TEST(EncapsulateSubgraphsTest, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node", - "outside_compilation_O1_host_compute"})}, + absl::Span( + {"_xla_token_arg_node", + "outside_compilation_O1_host_compute"})}, {"_xla_original_oc_node_name", "outside_compilation_O2_host_compute"}}, {"outside_compilation_O1_host_compute"}}, @@ -2209,7 +2218,7 @@ TEST(EncapsulateSubgraphsTest, {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -2219,7 +2228,7 @@ TEST(EncapsulateSubgraphsTest, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}}, }, @@ -2303,7 +2312,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -2340,7 +2349,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"D:o:0"}, 
{{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -2350,7 +2359,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}}, {{"outside_compilation_O2_host_compute"}, @@ -2358,7 +2367,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O2"}, {"send_key", ""}, {"recv_key", ""}, @@ -2368,7 +2377,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}, {"_xla_token_input_nodes", - absl::Span( + absl::Span( {"_xla_token_arg_node", "outside_compilation_O1_host_compute"})}, {"_xla_original_oc_node_name", "outside_compilation_O2_host_compute"}}, {"outside_compilation_O1_host_compute"}}, @@ -2377,7 +2386,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O3"}, {"send_key", ""}, {"recv_key", ""}, @@ -2387,9 +2396,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O3"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node", - "outside_compilation_O1_host_compute", - "outside_compilation_O2_host_compute"})}, + absl::Span( + {"_xla_token_arg_node", 
"outside_compilation_O1_host_compute", + "outside_compilation_O2_host_compute"})}, {"_xla_original_oc_node_name", "outside_compilation_O3_host_compute"}}, {"outside_compilation_O1_host_compute", "outside_compilation_O2_host_compute"}}}, @@ -2470,7 +2479,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -2507,7 +2516,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { {"a_0_arg"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -2517,7 +2526,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}}, }, @@ -2586,7 +2595,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -2627,7 +2636,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { {"c_0_arg", "c:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -2637,7 +2646,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { 
{"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}, {"c"}}, diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc index fa94a341bbabc6..445ca63c05ad66 100644 --- a/tensorflow/compiler/jit/encapsulate_util.cc +++ b/tensorflow/compiler/jit/encapsulate_util.cc @@ -36,7 +36,8 @@ namespace { // Returns string attribute value for the node if the attribute is present, // otherwise returns empty optional value. -std::optional GetStringAttr(const Node& n, const string& attr_name) { +std::optional GetStringAttr(const Node& n, + const std::string& attr_name) { auto attr = n.attrs().Find(attr_name); if (!attr) { return std::nullopt; @@ -47,8 +48,8 @@ std::optional GetStringAttr(const Node& n, const string& attr_name) { // Adds a value to the node's list attribute. template -absl::Status AppendToListAttr(Node* n, const string& attr_name, - const string& value) { +absl::Status AppendToListAttr(Node* n, const std::string& attr_name, + const std::string& value) { std::vector attr_value; absl::Status s = GetNodeAttr(n->attrs(), attr_name, &attr_value); if (!s.ok() && s.code() != error::NOT_FOUND) { @@ -63,7 +64,7 @@ absl::Status AppendToListAttr(Node* n, const string& attr_name, // Replaces attribute value. template -void ReplaceAttr(Node* n, const string& attr_name, const T& value) { +void ReplaceAttr(Node* n, const std::string& attr_name, const T& value) { n->ClearAttr(attr_name); n->AddAttr(attr_name, value); } @@ -71,7 +72,7 @@ void ReplaceAttr(Node* n, const string& attr_name, const T& value) { // Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of // `PreprocessEdgesBetweenOutsideCompilations` for details. 
absl::Status PreprocessControlEdgesBetweenOutsideCompilations( - Graph* g, const string& outside_compilation_attr_name) { + Graph* g, const std::string& outside_compilation_attr_name) { // Gather edges to remove. We should not remove the edge while iterating. std::vector edges_to_remove; for (const Edge* e : g->edges()) { @@ -89,7 +90,7 @@ absl::Status PreprocessControlEdgesBetweenOutsideCompilations( // Case 1a: outside compilation to outside compilation control edge. edges_to_remove.push_back(e); - TF_RETURN_IF_ERROR(AppendToListAttr( + TF_RETURN_IF_ERROR(AppendToListAttr( e->dst(), kXlaControlDependenciesWithinXlaClusterAttrName, e->src()->name())); } @@ -111,7 +112,7 @@ absl::Status PreprocessControlEdgesBetweenOutsideCompilations( // Step 2 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of // `PreprocessEdgesBetweenOutsideCompilations` for details. absl::Status PreprocessDataEdgesBetweenOutsideCompilations( - Graph* g, const string& outside_compilation_attr_name) { + Graph* g, const std::string& outside_compilation_attr_name) { // Gather edges between outside compilation and host computation. Notice that // we do not store `Edge*` directly because we remove some nodes while adding // Identity nodes, and those Edge pointers might be invalidated. @@ -138,7 +139,7 @@ absl::Status PreprocessDataEdgesBetweenOutsideCompilations( // Remove the edge from host to outside compilation. Add a placeholder as // outside compilation node input. - std::map, Node*> placeholders; + std::map, Node*> placeholders; for (int i = 0, end = edges.size(); i < end; i++) { Node* dst = g->FindNodeId(edges[i].dst_node_id); const Edge* e; @@ -148,7 +149,7 @@ absl::Status PreprocessDataEdgesBetweenOutsideCompilations( g->RemoveEdge(e); // Find or create placeholder node. 
- string new_name = + std::string new_name = absl::StrCat(src->name(), "_oc_to_oc_placeholder_", src_output); auto placeholder_index = std::make_pair(src->name(), src_output); auto iter = placeholders.find(placeholder_index); @@ -156,7 +157,7 @@ absl::Status PreprocessDataEdgesBetweenOutsideCompilations( if (iter == placeholders.end()) { NodeDefBuilder placeholder_builder(new_name, "Placeholder"); placeholder_builder.Attr("dtype", src->output_type(src_output)); - string outside_compilation_attr; + std::string outside_compilation_attr; TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), outside_compilation_attr_name, &outside_compilation_attr)); @@ -195,7 +196,7 @@ absl::Status PreprocessDataEdgesBetweenOutsideCompilations( // Step 1 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of // `PostprocessEdgesBetweenOutsideCompilations` for details. absl::Status PostprocessDataEdgesBetweenOutsideCompilations( - Graph* g, const string& outside_compilation_attr_name) { + Graph* g, const std::string& outside_compilation_attr_name) { // Gather all outside compilation to outside compilation nodes. std::vector placeholder_nodes; for (Node* n : g->nodes()) { @@ -208,7 +209,7 @@ absl::Status PostprocessDataEdgesBetweenOutsideCompilations( // Remove the placeholder nodes, and reconnect original edge. auto node_name_index = g->BuildNodeNameIndex(); for (auto n : placeholder_nodes) { - string node_name; + std::string node_name; int node_src_output; TF_RETURN_IF_ERROR(GetNodeAttr( n->attrs(), kOutsideCompilationOriginalNodeAttrName, &node_name)); @@ -271,12 +272,12 @@ absl::Status PostprocessDataEdgesBetweenOutsideCompilations( // Step 2 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of // `PostprocessEdgesBetweenOutsideCompilations` for details. 
absl::Status PostprocessControlEdgesBetweenOutsideCompilations( - Graph* g, const string& outside_compilation_attr_name) { + Graph* g, const std::string& outside_compilation_attr_name) { auto node_name_index = g->BuildNodeNameIndex(); // Reconnect outside compilation to outside compilation control edge. for (Node* n : g->nodes()) { - std::vector control_deps; + std::vector control_deps; absl::Status s = GetNodeAttr(n->attrs(), kXlaControlDependenciesWithinXlaClusterAttrName, &control_deps); @@ -288,7 +289,7 @@ absl::Status PostprocessControlEdgesBetweenOutsideCompilations( } } else { n->ClearAttr(kXlaControlDependenciesWithinXlaClusterAttrName); - for (const string& control_input : control_deps) { + for (const std::string& control_input : control_deps) { auto iter = node_name_index.find(control_input); if (iter == node_name_index.end()) { return errors::Internal("Cannot find original node for ", @@ -342,11 +343,11 @@ absl::Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) { } absl::StatusOr< - std::unique_ptr>>> + std::unique_ptr>>> OutsideCompilationClusterDependencies( - const Graph* g, const string& outside_compilation_attr_name) { + const Graph* g, const std::string& outside_compilation_attr_name) { auto cluster_deps = std::make_unique< - absl::flat_hash_map>>(); + absl::flat_hash_map>>(); for (const Edge* e : g->edges()) { auto src_outside_compilation = @@ -360,18 +361,18 @@ OutsideCompilationClusterDependencies( if (dst_deps_it == cluster_deps->end()) { cluster_deps->insert(std::make_pair( *dst_outside_compilation, - absl::flat_hash_set({*src_outside_compilation}))); + absl::flat_hash_set({*src_outside_compilation}))); } else { dst_deps_it->second.insert(*src_outside_compilation); } } } - auto cluster_deps_ordered = - std::make_unique>>(); + auto cluster_deps_ordered = std::make_unique< + absl::flat_hash_map>>(); for (auto it = cluster_deps->begin(); it != cluster_deps->end(); it++) { - std::vector ordered_deps(it->second.begin(), 
it->second.end()); + std::vector ordered_deps(it->second.begin(), it->second.end()); std::sort(ordered_deps.begin(), ordered_deps.end()); cluster_deps_ordered->insert(std::make_pair(it->first, ordered_deps)); } @@ -380,7 +381,7 @@ OutsideCompilationClusterDependencies( } absl::Status PreprocessEdgesBetweenOutsideCompilations( - Graph* g, const string& outside_compilation_attr_name) { + Graph* g, const std::string& outside_compilation_attr_name) { // Remove edges from source node to outside compilation nodes, and edges // from outside compilation nodes to sink node. std::vector edges_to_remove; @@ -406,7 +407,7 @@ absl::Status PreprocessEdgesBetweenOutsideCompilations( } absl::Status PostprocessEdgesBetweenOutsideCompilations( - Graph* g, const string& outside_compilation_attr_name) { + Graph* g, const std::string& outside_compilation_attr_name) { TF_RETURN_IF_ERROR(PostprocessDataEdgesBetweenOutsideCompilations( g, outside_compilation_attr_name)); TF_RETURN_IF_ERROR(PostprocessControlEdgesBetweenOutsideCompilations( diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h index 7c99763c770728..81ab31c79dcda2 100644 --- a/tensorflow/compiler/jit/encapsulate_util.h +++ b/tensorflow/compiler/jit/encapsulate_util.h @@ -95,21 +95,21 @@ struct XlaClusterInfo { // without losing aggregate initialization, which allows us to get rid of // the constructor definitions again. XlaClusterInfo() {} - XlaClusterInfo(const string& cluster_name, + XlaClusterInfo(const std::string& cluster_name, const NameAttrList& func_name_attrs, Node* node, - const std::map& host_compute_core) + const std::map& host_compute_core) : cluster_name(cluster_name), func_name_attrs(func_name_attrs), node(node), host_compute_core(host_compute_core) {} // XLA cluster name. It might be different from `func_name`. - const string cluster_name; + const std::string cluster_name; // Name and attributes of XLA computation function. 
const NameAttrList func_name_attrs; // The XLA computation node in the graph. Node* node; // A mapping from outside compilation cluster name to its device assignment. - const std::map host_compute_core; + const std::map host_compute_core; }; // Finds dependencies between outside compilation clusters, including both data @@ -117,9 +117,9 @@ struct XlaClusterInfo { // outside compilation cluster to a set of names of outside compilation clusters // that it depends on. absl::StatusOr< - std::unique_ptr>>> + std::unique_ptr>>> OutsideCompilationClusterDependencies( - const Graph* g, const string& outside_compilation_attr_name); + const Graph* g, const std::string& outside_compilation_attr_name); // Preprocesses edges within the same XLA cluster. It will perform the following // operations in order: @@ -135,7 +135,7 @@ OutsideCompilationClusterDependencies( // 2. For data edges between different outside compilations, remove the edge // and create a Placeholder node as dst node's input. absl::Status PreprocessEdgesBetweenOutsideCompilations( - Graph* g, const string& outside_compilation_attr_name); + Graph* g, const std::string& outside_compilation_attr_name); // Postprocesses edges within the same XLA cluster. This function reverts what // `PreprocessEdgesBetweenOutsideCompilations` did. It will perform the @@ -149,7 +149,7 @@ absl::Status PreprocessEdgesBetweenOutsideCompilations( // `PreprocessEdgesBetweenOutsideCompilations` step 1b are not handled here. // They are handled in `RewriteOutsideCompilationSubgraphFn`. 
absl::Status PostprocessEdgesBetweenOutsideCompilations( - Graph* g, const string& outside_compilation_attr_name); + Graph* g, const std::string& outside_compilation_attr_name); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 0e59bf0c19d93e..8ba11404010363 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -46,7 +46,7 @@ const char* const kXlaClusterOutput = "XlaClusterOutput"; bool IsCpuGpuCompile(const Graph* graph) { for (Node* n : graph->nodes()) { - string name; + std::string name; // Only consider nodes being compiled. if (!TryGetNodeAttr(n->attrs(), kXlaClusterIdAttr, &name)) continue; // Early return for any node with a device that is not a CPU or GPU. @@ -185,7 +185,7 @@ absl::Status RewriteSubgraph( // Uniquify the function name by computing a fingerprint of the function. // Nondeterminism in serialization would not lead to incorrect results, but // may cause spurious cache misses. 
- TF_ASSIGN_OR_RETURN(uint64 fingerprint, FingerprintGraph(*graph)); + TF_ASSIGN_OR_RETURN(uint64_t fingerprint, FingerprintGraph(*graph)); VLOG(1) << "Subgraph fingerprint:" << fingerprint; call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint)); return absl::OkStatus(); @@ -360,7 +360,8 @@ absl::Status RewriteSubgraph( /*static*/ absl::Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps( Graph* graph) { const auto is_xla_launch_node = [](const Node& node) -> absl::StatusOr { - const string& name = GetNodeAttrString(node.attrs(), kXlaClusterIdAttr); + const std::string& name = + GetNodeAttrString(node.attrs(), kXlaClusterIdAttr); return !name.empty(); }; diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc index 16a17c3c2a03a6..acd5319cf8ed16 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc @@ -34,7 +34,7 @@ limitations under the License. namespace tensorflow { static std::unique_ptr MakeOuterGraph( - const FunctionLibraryDefinition& flib_def, const string& function) { + const FunctionLibraryDefinition& flib_def, const std::string& function) { Scope scope = Scope::NewRootScope().ExitOnError(); TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto())); @@ -143,7 +143,7 @@ TEST(EncapsulateXlaComputations, DeterministicEncapsulate) { // Test that control edge insertion order doesn't affect the cache key // (cluster name) generated by TPU encapsulate pass. 
auto get_serialized_graph = [](bool control_input_reversed, - bool operand_reversed) -> string { + bool operand_reversed) -> std::string { FunctionLibraryDefinition flib_def(OpRegistry::Global(), FunctionDefLibrary()); std::unique_ptr graph(new Graph(&flib_def)); @@ -250,8 +250,8 @@ TEST(EncapsulateXlaComputations, Encapsulate) { TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def)); - std::unordered_map index = graph->BuildNodeNameIndex(); - string function = index.at("launch0")->type_string(); + std::unordered_map index = graph->BuildNodeNameIndex(); + std::string function = index.at("launch0")->type_string(); // Tests the outer graph is as expected. { @@ -285,9 +285,9 @@ TEST(EncapsulateXlaComputations, Encapsulate) { // function. Encapsulation should be deterministic to avoid recompilation. TF_ASSERT_OK( EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def)); - std::unordered_map index_copy = + std::unordered_map index_copy = graph_copy->BuildNodeNameIndex(); - string function_copy = index_copy.at("launch0")->type_string(); + std::string function_copy = index_copy.at("launch0")->type_string(); EXPECT_EQ(function, function_copy); } diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index 140c47dbcac804..05514f00bd29d5 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -42,7 +42,7 @@ namespace { // Control return mapping function for outside compilation host graphs. // All nodes with kXlaHasHostTransfer attribute are control outputs. -std::optional HostGraphControlRetMapping(const Node* n) { +std::optional HostGraphControlRetMapping(const Node* n) { if (HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) { return n->name(); } @@ -52,7 +52,7 @@ std::optional HostGraphControlRetMapping(const Node* n) { // Add a key placeholder node to the graph. 
The key placeholder node will be // used as input for XlaRecvAtHost/XlaSendFromHost nodes. absl::StatusOr AddHostComputeKeyPlaceholder( - const string& xla_cluster_name, Graph* g) { + const std::string& xla_cluster_name, Graph* g) { NodeDef key_def; NodeDefBuilder builder(absl::StrCat(xla_cluster_name, "_key_placeholder"), "Placeholder"); @@ -74,7 +74,8 @@ bool IsKeyPlaceholderNode(const Node& n) { } // Returns nodes with given type. -std::vector GatherNodesWithType(const Graph& g, const string& type) { +std::vector GatherNodesWithType(const Graph& g, + const std::string& type) { std::vector result; for (Node* n : g.nodes()) { if (n->type_string() == type) { @@ -105,7 +106,7 @@ absl::Status GetArgDataTypes(const std::vector& arg_nodes, // Builds XlaRecvAtHost node. absl::StatusOr BuildRecvAtHostNode( - Graph* g, const string& oc_cluster_name, + Graph* g, const std::string& oc_cluster_name, const std::vector& recv_at_host_dtypes, Node* key_placeholder) { NodeDefBuilder recv_at_host_builder( absl::StrCat("outside_compilation_", oc_cluster_name, "_recv"), @@ -128,7 +129,7 @@ absl::StatusOr BuildRecvAtHostNode( // Builds XlaRecvAtHost node, and replaces all _Arg nodes with it. absl::StatusOr ReplaceArgNodesWithRecvAtHostNode( - Graph* g, const string& oc_cluster_name, + Graph* g, const std::string& oc_cluster_name, std::vector* recv_at_host_dtypes, Node* key_placeholder) { // TODO(b/77601805): use out nodes for source node, instead of traversing all // nodes. @@ -205,7 +206,7 @@ absl::Status GetRetDataTypes(const std::vector& ret_nodes, // Builds XlaSendFromHost node. absl::StatusOr BuildSendFromHostNode( - Graph* g, const string& oc_cluster_name, + Graph* g, const std::string& oc_cluster_name, const std::vector& ret_nodes, const std::vector& send_from_host_dtypes, Node* key_placeholder) { NodeDefBuilder send_from_host_builder( @@ -245,7 +246,7 @@ absl::StatusOr BuildSendFromHostNode( // Builds XlaSendFromHost node, and replaces all _Retval nodes with it. 
absl::StatusOr ReplaceRetNodesWithSendFromHostNode( - Graph* g, const string& oc_cluster_name, + Graph* g, const std::string& oc_cluster_name, std::vector* send_from_host_dtypes, Node* key_placeholder) { // TODO(b/77601805): use in nodes for sink node, instead of traversing all // nodes. @@ -299,16 +300,17 @@ std::optional> GetInferredInputShapes( return results; } -string host_compute_node_name(const string& original_oc_name) { +std::string host_compute_node_name(const std::string& original_oc_name) { return absl::StrCat("outside_compilation_", original_oc_name, "_host_compute"); } // Builds XlaHostCompute NodeDef from the outside compilation call node. absl::StatusOr BuildXlaHostComputeNodeDef( - const Node* call_node, const std::map& host_compute_core, - const absl::flat_hash_map>& cluster_deps) { - string original_oc_name; + const Node* call_node, const std::map& host_compute_core, + const absl::flat_hash_map>& + cluster_deps) { + std::string original_oc_name; TF_RETURN_IF_ERROR(GetNodeAttr( call_node->attrs(), "_outside_compilation_subgraph", &original_oc_name)); NodeDefBuilder host_compute_builder(host_compute_node_name(original_oc_name), @@ -341,7 +343,7 @@ absl::StatusOr BuildXlaHostComputeNodeDef( // according to their host-side graph dependency. This can cause deadlock. // Therefore, we hint XLA what the correct ordering of these clusters should // be to avoid deadlocks. - std::vector xla_token_input_nodes; + std::vector xla_token_input_nodes; xla_token_input_nodes.emplace_back(kXlaTokenArgNodeName); auto cluster_deps_it = cluster_deps.find(original_oc_name); if (cluster_deps_it != cluster_deps.end()) { @@ -376,8 +378,10 @@ absl::StatusOr BuildXlaHostComputeNodeDef( // Replace outside compilation function call node with XlaHostCompute node. 
TF_ATTRIBUTE_NOINLINE absl::StatusOr ReplaceOutsideCompilationCallNode( - Graph* g, Node* call_node, const std::map& host_compute_core, - const absl::flat_hash_map>& cluster_deps) { + Graph* g, Node* call_node, + const std::map& host_compute_core, + const absl::flat_hash_map>& + cluster_deps) { // Build XlaHostCompute NodeDef. TF_ASSIGN_OR_RETURN( NodeDef node_def, @@ -405,8 +409,8 @@ absl::Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) { n->ClearAttr("device_ordinal"); n->AddAttr("device_ordinal", device_ordinal_value); } else if (n->IsIfNode()) { - for (const string& attr_name : - std::vector{"then_branch", "else_branch"}) { + for (const std::string& attr_name : + std::vector{"then_branch", "else_branch"}) { NameAttrList branch_func; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func)); (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value; @@ -414,7 +418,8 @@ absl::Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) { n->AddAttr(attr_name, branch_func); } } else if (n->IsWhileNode()) { - for (const string& attr_name : std::vector{"cond", "body"}) { + for (const std::string& attr_name : + std::vector{"cond", "body"}) { NameAttrList branch_func; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func)); (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value; @@ -448,11 +453,12 @@ bool HasLiftedArgs(const FunctionDef& function_def) { absl::StatusOr>> LiftedArgsAndOutsideCompilationNodesInFunctionBody( const FunctionBody& function_body, - const std::unordered_map& outside_compilation_attr_to_node) { + const std::unordered_map& + outside_compilation_attr_to_node) { std::vector> lifted_arg_nodes_and_outside_compilation_nodes; for (Node* n : function_body.graph->op_nodes()) { - string oc_cluster; + std::string oc_cluster; if (n->type_string() == "Placeholder" && GetNodeAttr(n->def(), kXlaLiftedArgOutsideCompilationAttrName, &oc_cluster) @@ -471,7 +477,7 @@ 
LiftedArgsAndOutsideCompilationNodesInFunctionBody( absl::StatusOr> UpdateTypesAttribute( const std::vector>& lifted_arg_nodes_and_outside_compilation_nodes, - const string& type_attr_name, Node* n) { + const std::string& type_attr_name, Node* n) { std::vector data_types; data_types.reserve(lifted_arg_nodes_and_outside_compilation_nodes.size()); TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), type_attr_name, &data_types)); @@ -578,7 +584,8 @@ absl::Status AddFunctionWithNewName(const std::string& new_name, // Reconnect outside compilation lifted arguments in a functional While node to // its outside compilation tensor sources. absl::Status PostprocessLiftedArgsForWhile( - const std::unordered_map& outside_compilation_attr_to_node, + const std::unordered_map& + outside_compilation_attr_to_node, Graph* g, Node* n, FunctionLibraryDefinition* fld) { TF_RET_CHECK(n->IsWhileNode()); @@ -687,7 +694,8 @@ absl::Status PostprocessLiftedArgsForWhile( } absl::Status PostprocessLiftedArgsForIf( - const std::unordered_map& outside_compilation_attr_to_node, + const std::unordered_map& + outside_compilation_attr_to_node, Graph* g, Node* n, FunctionLibraryDefinition* fld) { TF_RET_CHECK(n->IsIfNode()); @@ -826,7 +834,8 @@ absl::Status PostprocessLiftedArgsForIf( } absl::Status PostprocessLiftedArgsForCall( - const std::unordered_map& outside_compilation_attr_to_node, + const std::unordered_map& + outside_compilation_attr_to_node, Graph* g, Node* n, FunctionLibraryDefinition* fld) { const FunctionDef* fdef = fld->Find(n->type_string()); TF_RET_CHECK(fdef); @@ -924,12 +933,12 @@ absl::Status PostprocessLiftedArgsForCall( // Creates a mapping from outside compilation cluster name to lifted argument // placeholder. 
-absl::StatusOr> OutsideCompilationAttrToNode( - const Graph& g) { - std::unordered_map outside_compilation_attr_to_node; +absl::StatusOr> +OutsideCompilationAttrToNode(const Graph& g) { + std::unordered_map outside_compilation_attr_to_node; for (Node* n : g.op_nodes()) { bool is_lifted_arg; - string outside_compilation_attr; + std::string outside_compilation_attr; if (TryGetNodeAttr(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg) && TryGetNodeAttr(n->def(), "_xla_outside_compilation", &outside_compilation_attr)) { @@ -988,8 +997,9 @@ absl::Status PostprocessLiftedArgs(Graph* g, FunctionLibraryDefinition* fld) { // replace this node with compilation result node. // 3) all outside compilation graphs. absl::Status ConstructHostGraph( - const string& xla_cluster_name, const string& outside_compilation_attr_name, - const std::vector& outside_compilation_host_graphs, + const std::string& xla_cluster_name, + const std::string& outside_compilation_attr_name, + const std::vector& outside_compilation_host_graphs, FunctionLibraryDefinition* fld, std::unique_ptr* host_graph) { host_graph->reset(new Graph(fld)); @@ -1013,7 +1023,7 @@ absl::Status ConstructHostGraph( // XlaSendFromHost, If/While nodes containing // XlaRecvAtHost/XlaSendFromHost) to sequencer node. // c) Clear node_def.device(), so device placer won't get confused. - for (const string& host_func : outside_compilation_host_graphs) { + for (const std::string& host_func : outside_compilation_host_graphs) { VLOG(4) << "Expanding host graph " << host_func; // Temporarily use "0" as "_device_ordinal". It will be reset to placeholder // value after we expanded all host graphs. We cannot just use placeholder @@ -1021,7 +1031,7 @@ absl::Status ConstructHostGraph( // value for attributes. 
AttrValue device_ordinal_attr; device_ordinal_attr.set_i(0); - protobuf::Map attrs; + protobuf::Map attrs; attrs["_device_ordinal"] = device_ordinal_attr; std::unique_ptr host_fbody; const FunctionDef* host_fdef = fld->Find(host_func); @@ -1123,18 +1133,17 @@ absl::Status ConstructHostGraph( // Expand XLA computation's outside compilation host side graph into main graph. // Add a control edge between sequencer node and the XLA computation node. -absl::Status ExpandHostGraphIntoMainGraph(Graph* main_graph, - FunctionLibraryDefinition* fld, - const string& host_graph_func_name, - Node* xla_computation_node, - Node* pivot_node) { +absl::Status ExpandHostGraphIntoMainGraph( + Graph* main_graph, FunctionLibraryDefinition* fld, + const std::string& host_graph_func_name, Node* xla_computation_node, + Node* pivot_node) { // Temporarily use "0" as "_device_ordinal". It will be rewritten with the // correct value in a later pass. We cannot just use placeholder value here // because FunctionDef instantiation does not allow placeholder value for // attributes. AttrValue device_ordinal_attr; device_ordinal_attr.set_i(0); - protobuf::Map attrs; + protobuf::Map attrs; attrs["_device_ordinal"] = device_ordinal_attr; std::unique_ptr fbody; const FunctionDef* host_graph_func = fld->Find(host_graph_func_name); @@ -1207,12 +1216,12 @@ absl::Status ExpandHostGraphIntoMainGraph(Graph* main_graph, // 2) Remove control edges. // 3) Prune nodes that are not useful for shape inference. absl::Status RewriteShapeInferenceGraph( - const string& shape_inference_graph_name, Graph* host_graph, + const std::string& shape_inference_graph_name, Graph* host_graph, Node* pivot_node, FunctionLibraryDefinition* fld) { // Use "0" as "_device_ordinal". It does not matter for shape inference. 
AttrValue device_ordinal_attr; device_ordinal_attr.set_i(0); - protobuf::Map attrs; + protobuf::Map attrs; attrs["_device_ordinal"] = device_ordinal_attr; std::unique_ptr fbody; const FunctionDef* shape_inference_graph = @@ -1338,13 +1347,13 @@ void SetMaximalSharding(NodeDefBuilder& node_builder) { // Builds XlaSendToHost node which sends cond predicate to host. TF_ATTRIBUTE_NOINLINE absl::StatusOr BuildSendIfPredNode( - const string& name, const string& host_transfer_key, Node* pred_node, - Graph* g) { + const std::string& name, const std::string& host_transfer_key, + Node* pred_node, Graph* g) { NodeDefBuilder send_pred_builder(name, "XlaSendToHost"); send_pred_builder.Attr("Tinput", DT_BOOL); send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0")); send_pred_builder.Attr(kXlaTokenInputNodesAttrName, - std::vector{kXlaTokenArgNodeName}); + std::vector{kXlaTokenArgNodeName}); send_pred_builder.Attr(kXlaOriginalOutsideCompilationNodeName, name); SetMaximalSharding(send_pred_builder); send_pred_builder.Input(pred_node->name(), 0, DT_BOOL); @@ -1356,14 +1365,14 @@ TF_ATTRIBUTE_NOINLINE absl::StatusOr BuildSendIfPredNode( } // Replaces key placeholder node with an _Arg node. -absl::Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name, - const string& func_name, - FunctionLibraryDefinition* fld) { +absl::Status ReplaceKeyPlaceholderWithArgNode( + const std::string& xla_cluster_name, const std::string& func_name, + FunctionLibraryDefinition* fld) { // Temporarily use "0" as "_device_ordinal". It will be reset to placeholder // value after rewriting. AttrValue device_ordinal_attr; device_ordinal_attr.set_i(0); - protobuf::Map attrs; + protobuf::Map attrs; attrs["_device_ordinal"] = device_ordinal_attr; std::unique_ptr fbody; const FunctionDef* func = fld->Find(func_name); @@ -1404,14 +1413,15 @@ absl::Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name, // Builds host side graph for If node. 
TF_ATTRIBUTE_NOINLINE absl::Status BuildHostGraphForIfNode( - const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, const string& xla_cluster_name, - const string& if_node_name, const string& host_transfer_key, - const string& host_graph_func_name, FunctionLibraryDefinition* fld, - const string& then_branch_host_func_name, - const string& else_branch_host_func_name) { + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::string& xla_cluster_name, const std::string& if_node_name, + const std::string& host_transfer_key, + const std::string& host_graph_func_name, FunctionLibraryDefinition* fld, + const std::string& then_branch_host_func_name, + const std::string& else_branch_host_func_name) { Graph host_graph(fld); - string outside_compilation_name = absl::StrCat("oc_if_", if_node_name); + std::string outside_compilation_name = absl::StrCat("oc_if_", if_node_name); AttrValue device_ordinal_value; device_ordinal_value.set_placeholder("_device_ordinal"); @@ -1484,7 +1494,7 @@ TF_ATTRIBUTE_NOINLINE absl::Status BuildHostGraphForIfNode( // Rewrites loop cond to add a node which sends loop cond to host. TF_ATTRIBUTE_NOINLINE absl::Status AddSendLoopPredToLoopCond( - const string& cond_xla_func_name, const string& host_transfer_key, + const std::string& cond_xla_func_name, const std::string& host_transfer_key, NameAttrList* loop_cond_func, FunctionLibraryDefinition* fld, Node* while_node) { // Instantiate the loop cond function. 
@@ -1523,7 +1533,7 @@ TF_ATTRIBUTE_NOINLINE absl::Status AddSendLoopPredToLoopCond( send_loop_cond_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0")); send_loop_cond_builder.Attr(kXlaTokenInputNodesAttrName, - std::vector{kXlaTokenArgNodeName}); + std::vector{kXlaTokenArgNodeName}); send_loop_cond_builder.Attr(kXlaOriginalOutsideCompilationNodeName, send_loop_cond_builder.node_name()); SetMaximalSharding(send_loop_cond_builder); @@ -1560,10 +1570,13 @@ TF_ATTRIBUTE_NOINLINE absl::Status AddSendLoopPredToLoopCond( // Rewrites while loop cond function for host. absl::Status RewriteHostWhileLoopCond( - const string& cond_host_func_name, const string& while_node_name, - const string& host_transfer_key, const string& xla_cluster_attr_name, - const string& xla_cluster_name, const string& outside_compilation_attr_name, - const string& outside_compilation_name, FunctionLibraryDefinition* fld) { + const std::string& cond_host_func_name, const std::string& while_node_name, + const std::string& host_transfer_key, + const std::string& xla_cluster_attr_name, + const std::string& xla_cluster_name, + const std::string& outside_compilation_attr_name, + const std::string& outside_compilation_name, + FunctionLibraryDefinition* fld) { // Replace key placeholder node with _Arg node. TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( xla_cluster_name, cond_host_func_name, fld)); @@ -1571,7 +1584,7 @@ absl::Status RewriteHostWhileLoopCond( // Instantiate cond function. AttrValue device_ordinal_temp_value; device_ordinal_temp_value.set_i(0); - protobuf::Map attrs; + protobuf::Map attrs; attrs["_device_ordinal"] = device_ordinal_temp_value; std::unique_ptr cond_fbody; const FunctionDef* cond_host_func = fld->Find(cond_host_func_name); @@ -1634,10 +1647,13 @@ absl::Status RewriteHostWhileLoopCond( // Rewrites while loop body function for host. 
absl::Status RewriteHostWhileLoopBody( - const string& body_host_func_name, const string& while_node_name, - const string& host_transfer_key, const string& xla_cluster_attr_name, - const string& xla_cluster_name, const string& outside_compilation_attr_name, - const string& outside_compilation_name, FunctionLibraryDefinition* fld) { + const std::string& body_host_func_name, const std::string& while_node_name, + const std::string& host_transfer_key, + const std::string& xla_cluster_attr_name, + const std::string& xla_cluster_name, + const std::string& outside_compilation_attr_name, + const std::string& outside_compilation_name, + FunctionLibraryDefinition* fld) { // Replace key placeholder node with _Arg node. TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( xla_cluster_name, body_host_func_name, fld)); @@ -1645,7 +1661,7 @@ absl::Status RewriteHostWhileLoopBody( // Instantiate body function. AttrValue device_ordinal_temp_value; device_ordinal_temp_value.set_i(0); - protobuf::Map attrs; + protobuf::Map attrs; attrs["_device_ordinal"] = device_ordinal_temp_value; std::unique_ptr body_fbody; const FunctionDef* body_host_func = fld->Find(body_host_func_name); @@ -1692,13 +1708,16 @@ absl::Status RewriteHostWhileLoopBody( // Builds host side graph for while node. 
TF_ATTRIBUTE_NOINLINE absl::Status BuildHostGraphForWhileNode( - const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, const string& xla_cluster_name, - const string& while_node_name, const string& host_transfer_key, - const string& host_graph_func_name, FunctionLibraryDefinition* fld, - const string& cond_host_func_name, const string& body_host_func_name) { + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::string& xla_cluster_name, const std::string& while_node_name, + const std::string& host_transfer_key, + const std::string& host_graph_func_name, FunctionLibraryDefinition* fld, + const std::string& cond_host_func_name, + const std::string& body_host_func_name) { Graph host_graph(fld); - string outside_compilation_name = absl::StrCat("oc_while_", while_node_name); + std::string outside_compilation_name = + absl::StrCat("oc_while_", while_node_name); // Step 1: add key placeholder node. TF_ASSIGN_OR_RETURN( @@ -1759,10 +1778,12 @@ TF_ATTRIBUTE_NOINLINE absl::Status BuildHostGraphForWhileNode( // Builds host graph for func call nodes. 
absl::Status BuildHostGraphForFuncCallNode( - const string& xla_cluster_attr_name, const string& xla_cluster_name, - const string& outside_compilation_attr_name, - const string& func_call_node_name, const string& func_call_host_func_name, - const string& host_graph_func_name, FunctionLibraryDefinition* fld) { + const std::string& xla_cluster_attr_name, + const std::string& xla_cluster_name, + const std::string& outside_compilation_attr_name, + const std::string& func_call_node_name, + const std::string& func_call_host_func_name, + const std::string& host_graph_func_name, FunctionLibraryDefinition* fld) { Graph host_graph(fld); AttrValue device_ordinal_value; device_ordinal_value.set_placeholder("_device_ordinal"); @@ -1807,18 +1828,19 @@ absl::Status BuildHostGraphForFuncCallNode( } TF_ATTRIBUTE_NOINLINE absl::Status ExtractOutsideCompilationForFuncCallNode( - const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, const string& xla_cluster_name, - const std::map& host_compute_core, Graph* g, Node* n, + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::string& xla_cluster_name, + const std::map& host_compute_core, Graph* g, Node* n, FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, - std::vector* host_graphs, - std::vector* shape_inference_graphs, + std::vector* host_graphs, + std::vector* shape_inference_graphs, bool* has_outside_compilation) { bool func_has_outside_compilation = false; NameAttrList func; if (fld->Contains(n->type_string())) { func.set_name(n->type_string()); - typedef protobuf::Map AttrMap; + typedef protobuf::Map AttrMap; *func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end()); } else if (n->IsPartitionedCall()) { TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &func)); @@ -1827,7 +1849,7 @@ TF_ATTRIBUTE_NOINLINE absl::Status ExtractOutsideCompilationForFuncCallNode( func.set_name(FunctionLibraryDefinition::kGradientOp); 
*func.mutable_attr() = n->def().attr(); } - string canonical_func_name; + std::string canonical_func_name; if (func.name() == FunctionLibraryDefinition::kGradientOp) { NameAttrList forward_func; TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &forward_func)); @@ -1835,8 +1857,8 @@ TF_ATTRIBUTE_NOINLINE absl::Status ExtractOutsideCompilationForFuncCallNode( } else { canonical_func_name = func.name(); } - string new_func_name = absl::StrCat(canonical_func_name, "_oc"); - string host_func_name = + std::string new_func_name = absl::StrCat(canonical_func_name, "_oc"); + std::string host_func_name = absl::StrCat("oc_func_call_host_", canonical_func_name); TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, @@ -1876,11 +1898,11 @@ TF_ATTRIBUTE_NOINLINE absl::Status ExtractOutsideCompilationForFuncCallNode( TF_RETURN_IF_ERROR(replace_builder->Finalize(replace_def.get())); TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, *replace_def)); replace->AddAttr(kXlaTokenInputNodesAttrName, - std::vector{kXlaTokenArgNodeName}); + std::vector{kXlaTokenArgNodeName}); replace->AddAttr(kXlaOriginalOutsideCompilationNodeName, replace->name()); // Build host side graph for the function call. 
- string oc_host_graph_name = + std::string oc_host_graph_name = absl::StrCat("oc_func_host_graph_", replace->name()); TF_RETURN_IF_ERROR(BuildHostGraphForFuncCallNode( xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name, @@ -1893,12 +1915,13 @@ TF_ATTRIBUTE_NOINLINE absl::Status ExtractOutsideCompilationForFuncCallNode( } absl::Status ExtractOutsideCompilationForIfNode( - const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, const string& xla_cluster_name, - const std::map& host_compute_core, Graph* g, Node* n, + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::string& xla_cluster_name, + const std::map& host_compute_core, Graph* g, Node* n, FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, - std::vector* host_graphs, - std::vector* shape_inference_graphs, + std::vector* host_graphs, + std::vector* shape_inference_graphs, bool* has_outside_compilation) { // Instantiate "then_branch" and "else_branch". NameAttrList then_branch, else_branch; @@ -1908,12 +1931,14 @@ absl::Status ExtractOutsideCompilationForIfNode( // Extract outside compilation for then_branch and else_branch. 
bool then_branch_has_outside_compilation = false; bool else_branch_has_outside_compilation = false; - string then_branch_host_func_name = - absl::StrCat("oc_then_branch_host_if_", then_branch.name()), - else_branch_host_func_name = - absl::StrCat("oc_else_branch_host_if_", else_branch.name()); - string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"), - else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc"); + std::string then_branch_host_func_name = + absl::StrCat("oc_then_branch_host_if_", then_branch.name()), + else_branch_host_func_name = + absl::StrCat("oc_else_branch_host_if_", else_branch.name()); + std::string then_branch_xla_func_name = + absl::StrCat(then_branch.name(), "_oc"), + else_branch_xla_func_name = + absl::StrCat(else_branch.name(), "_oc"); TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, then_branch, then_branch_xla_func_name, then_branch_host_func_name, @@ -1946,7 +1971,7 @@ absl::Status ExtractOutsideCompilationForIfNode( } n->AddAttr(kXlaOriginalOutsideCompilationNodeName, n->name()); - string host_transfer_key = absl::StrCat("oc_if_pred_", n->name()); + std::string host_transfer_key = absl::StrCat("oc_if_pred_", n->name()); // XLA computation: add a SendToHost node to send cond predicate. Node* pred_node; @@ -1956,7 +1981,7 @@ absl::Status ExtractOutsideCompilationForIfNode( BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()), host_transfer_key, pred_node, g)); n->AddAttr(kXlaTokenInputNodesAttrName, - std::vector{send_pred_node->name()}); + std::vector{send_pred_node->name()}); // Add a control edge from `send_pred_node` to If node, so XlaCompiler will // visit If node after `send_pred_node`, thus the token output for @@ -1969,7 +1994,7 @@ absl::Status ExtractOutsideCompilationForIfNode( // we need to create a no-op host graph. 
if (!then_branch_has_outside_compilation) { std::unique_ptr then_branch_host_graph(new Graph(fld)); - std::vector then_branch_host_graphs; + std::vector then_branch_host_graphs; TF_RETURN_IF_ERROR(ConstructHostGraph( xla_cluster_name, outside_compilation_attr_name, then_branch_host_graphs, fld, &then_branch_host_graph)); @@ -1986,7 +2011,7 @@ absl::Status ExtractOutsideCompilationForIfNode( } if (!else_branch_has_outside_compilation) { std::unique_ptr else_branch_host_graph(new Graph(fld)); - std::vector else_branch_host_graphs; + std::vector else_branch_host_graphs; TF_RETURN_IF_ERROR(ConstructHostGraph( xla_cluster_name, outside_compilation_attr_name, else_branch_host_graphs, fld, &else_branch_host_graph)); @@ -2001,7 +2026,7 @@ absl::Status ExtractOutsideCompilationForIfNode( TF_RETURN_IF_ERROR(fld->AddFunctionDef(else_branch_host_fdef)); } } - string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name()); + std::string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name()); TF_RETURN_IF_ERROR(BuildHostGraphForIfNode( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, n->name(), host_transfer_key, oc_host_graph_name, fld, @@ -2012,12 +2037,13 @@ absl::Status ExtractOutsideCompilationForIfNode( } absl::Status ExtractOutsideCompilationForWhileNode( - const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, const string& xla_cluster_name, - const std::map& host_compute_core, Graph* g, Node* n, + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::string& xla_cluster_name, + const std::map& host_compute_core, Graph* g, Node* n, FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, - std::vector* host_graphs, - std::vector* shape_inference_graphs, + std::vector* host_graphs, + std::vector* shape_inference_graphs, bool* has_outside_compilation) { // Instantiate "cond" and "body". 
NameAttrList cond, body; @@ -2027,10 +2053,12 @@ absl::Status ExtractOutsideCompilationForWhileNode( // Extract outside compilation for cond and body. bool cond_has_outside_compilation = false; bool body_has_outside_compilation = false; - string cond_host_func_name = absl::StrCat("oc_cond_host_while_", cond.name()), - body_host_func_name = absl::StrCat("oc_body_host_while_", body.name()); - string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"), - body_xla_func_name = absl::StrCat(body.name(), "_oc"); + std::string cond_host_func_name = + absl::StrCat("oc_cond_host_while_", cond.name()), + body_host_func_name = + absl::StrCat("oc_body_host_while_", body.name()); + std::string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"), + body_xla_func_name = absl::StrCat(body.name(), "_oc"); TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr, @@ -2060,19 +2088,19 @@ absl::Status ExtractOutsideCompilationForWhileNode( } n->AddAttr(kXlaOriginalOutsideCompilationNodeName, n->name()); - string host_transfer_key = absl::StrCat("oc_while_pred_", n->name()); + std::string host_transfer_key = absl::StrCat("oc_while_pred_", n->name()); // XLA computation: rewrite cond function to add a SendToHost node to send // loop predicate. TF_RETURN_IF_ERROR(AddSendLoopPredToLoopCond( cond_xla_func_name, host_transfer_key, &cond, fld, n)); n->AddAttr(kXlaTokenInputNodesAttrName, - std::vector{kXlaTokenArgNodeName}); + std::vector{kXlaTokenArgNodeName}); // Build host side graph for the "While" node. 
if (!cond_has_outside_compilation) { std::unique_ptr cond_host_graph(new Graph(fld)); - std::vector host_graphs; + std::vector host_graphs; TF_RETURN_IF_ERROR(ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name, host_graphs, fld, &cond_host_graph)); @@ -2088,7 +2116,7 @@ absl::Status ExtractOutsideCompilationForWhileNode( } if (!body_has_outside_compilation) { std::unique_ptr body_host_graph(new Graph(fld)); - std::vector host_graphs; + std::vector host_graphs; TF_RETURN_IF_ERROR(ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name, host_graphs, fld, &body_host_graph)); @@ -2102,7 +2130,8 @@ absl::Status ExtractOutsideCompilationForWhileNode( TF_RETURN_IF_ERROR(fld->AddFunctionDef(body_host_fdef)); } } - string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name()); + std::string oc_host_graph_name = + absl::StrCat("oc_while_host_graph_", n->name()); TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, n->name(), host_transfer_key, oc_host_graph_name, fld, @@ -2113,11 +2142,13 @@ absl::Status ExtractOutsideCompilationForWhileNode( } absl::Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( - Graph* g, const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, const string& xla_cluster_name, - const std::map& host_compute_core, FunctionLibraryRuntime* flr, - FunctionLibraryDefinition* fld, std::vector* host_graphs, - std::vector* shape_inference_graphs, + Graph* g, const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::string& xla_cluster_name, + const std::map& host_compute_core, + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, + std::vector* host_graphs, + std::vector* shape_inference_graphs, bool* has_outside_compilation) { std::vector if_nodes, while_nodes, func_call_nodes; for (Node* n : g->nodes()) { @@ -2155,7 +2186,7 @@ absl::Status 
ExtractOutsideCompilationForNodesWithAssociatedFunctions( } absl::Status CopyOutsideCompilationConstNodes( - Graph* g, const string& outside_compilation_attr_name) { + Graph* g, const std::string& outside_compilation_attr_name) { for (Node* n : g->op_nodes()) { if (!n->IsConstant() || !HasNodeAttr(n->def(), outside_compilation_attr_name)) { @@ -2205,8 +2236,8 @@ absl::Status RewriteOutsideCompilationSubgraphFn::operator()( const std::vector& arg_source_tensors, std::unique_ptr* graph, std::vector* input_permutation, std::vector* output_permutation, NodeDef* node_def) { - string old_name = node_def->op(); - string new_name = + std::string old_name = node_def->op(); + std::string new_name = absl::StrCat(xla_cluster_name_, "_", new_function_name_, "_", old_name); node_def->set_op(new_name); node_def->set_name(new_name); @@ -2290,14 +2321,14 @@ absl::Status RewriteOutsideCompilationSubgraphFn::operator()( AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def); AddNodeAttr("shapes", *shapes, node_def); } else { - string shape_inference_func_name = + std::string shape_inference_func_name = absl::StrCat("_outside_compilation_shape_inference_", new_name); NameAttrList shape_inference_graph; shape_inference_graph.set_name(shape_inference_func_name); AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def); AddNodeAttr("shapes", std::vector{}, node_def); } - AddNodeAttr("ancestors", std::vector{}, node_def); + AddNodeAttr("ancestors", std::vector{}, node_def); AddNodeAttr("Tinputs", recv_at_host_dtypes, node_def); AddNodeAttr("Toutputs", send_from_host_dtypes, node_def); AddNodeAttr("key", absl::StrCat("host_compute_channel_", new_name), node_def); @@ -2306,15 +2337,16 @@ absl::Status RewriteOutsideCompilationSubgraphFn::operator()( } absl::Status ExtractOutsideCompilationForFunction( - const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, const string& xla_cluster_name, - const NameAttrList& func_name_attrs, const 
string& new_func_name, - const string& host_graph_func_name, - const std::map& host_compute_core, FunctionLibraryRuntime* flr, - FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs, + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::string& xla_cluster_name, const NameAttrList& func_name_attrs, + const std::string& new_func_name, const std::string& host_graph_func_name, + const std::map& host_compute_core, + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, + std::vector* shape_inference_graphs, bool* has_outside_compilation) { // Convert the function to graph. - const string& func_name = func_name_attrs.name(); + const std::string& func_name = func_name_attrs.name(); FunctionLibraryRuntime::Handle handle; TF_RETURN_IF_ERROR( flr->Instantiate(func_name, AttrSlice(&func_name_attrs.attr()), &handle)); @@ -2345,8 +2377,8 @@ absl::Status ExtractOutsideCompilationForFunction( } std::unique_ptr graph_out; - std::vector outside_compilation_host_graphs; - std::vector shape_inference_graphs_to_rewrite; + std::vector outside_compilation_host_graphs; + std::vector shape_inference_graphs_to_rewrite; if (*has_outside_compilation) { // Copy outside compilation Const nodes with non outside compilation users. TF_RETURN_IF_ERROR(CopyOutsideCompilationConstNodes( @@ -2404,7 +2436,7 @@ absl::Status ExtractOutsideCompilationForFunction( } } } - std::map host_compute_nodes; + std::map host_compute_nodes; for (Node* n : outside_compilation_nodes) { auto host_compute_node_or = ReplaceOutsideCompilationCallNode( graph_out.get(), n, host_compute_core, *cluster_deps); @@ -2416,11 +2448,11 @@ absl::Status ExtractOutsideCompilationForFunction( // them so XlaCompiler can handle them in correct order. 
for (const auto& iter : host_compute_nodes) { Node* host_compute_node = iter.second; - std::vector token_input_node_names; + std::vector token_input_node_names; TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(), kXlaTokenInputNodesAttrName, &token_input_node_names)); - for (const string& node_name : token_input_node_names) { + for (const std::string& node_name : token_input_node_names) { if (node_name == kXlaTokenArgNodeName) { continue; } @@ -2459,7 +2491,7 @@ absl::Status ExtractOutsideCompilationForFunction( // Shape inference graphs might contain Placeholder nodes for outside // compilation to outside compilation edges. Rewrite shape inference graphs // to remove such nodes. - for (const string& shape_inference_graph : + for (const std::string& shape_inference_graph : shape_inference_graphs_to_rewrite) { TF_RETURN_IF_ERROR( RewriteShapeInferenceGraph(shape_inference_graph, host_graph.get(), @@ -2467,7 +2499,7 @@ absl::Status ExtractOutsideCompilationForFunction( } // Remove the outside compilation graphs from function library. 
- for (const string& func : outside_compilation_host_graphs) { + for (const std::string& func : outside_compilation_host_graphs) { TF_RETURN_IF_ERROR(fld->RemoveFunction(func)); } @@ -2499,9 +2531,9 @@ absl::Status ExtractOutsideCompilationForFunction( } absl::Status ExtractOutsideCompilation( - const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, - const std::unordered_map& clusters, Graph* g, + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::unordered_map& clusters, Graph* g, FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, bool* modified) { if (VLOG_IS_ON(4)) { @@ -2511,14 +2543,14 @@ absl::Status ExtractOutsideCompilation( *modified = false; auto node_name_index = g->BuildNodeNameIndex(); for (auto& iter : clusters) { - string xla_cluster_name = iter.first; + std::string xla_cluster_name = iter.first; Node* n = iter.second.node; auto const& func_name_attrs = iter.second.func_name_attrs; auto const& host_compute_core = iter.second.host_compute_core; - std::vector shape_inference_graphs; + std::vector shape_inference_graphs; bool has_outside_compilation; - string host_graph_func_name = + std::string host_graph_func_name = absl::StrCat("oc_host_graph_", xla_cluster_name); TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, @@ -2528,7 +2560,7 @@ absl::Status ExtractOutsideCompilation( *modified |= has_outside_compilation; if (has_outside_compilation) { - string pivot_name = absl::StrCat(xla_cluster_name, "/pivot"); + std::string pivot_name = absl::StrCat(xla_cluster_name, "/pivot"); Node* pivot_node = node_name_index[pivot_name]; TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph( g, fld, host_graph_func_name, n, pivot_node)); diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.h b/tensorflow/compiler/jit/extract_outside_compilation_pass.h index 
7631ccd0bc6ab0..c1697fcb4cde0d 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.h +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.h @@ -44,9 +44,9 @@ namespace tensorflow { class RewriteOutsideCompilationSubgraphFn { public: RewriteOutsideCompilationSubgraphFn( - const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, - const string& xla_cluster_name, const string& new_function_name) + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::string& xla_cluster_name, const std::string& new_function_name) : xla_cluster_attr_name_(xla_cluster_attr_name), outside_compilation_attr_name_(outside_compilation_attr_name), xla_cluster_name_(xla_cluster_name), @@ -59,10 +59,10 @@ class RewriteOutsideCompilationSubgraphFn { NodeDef* node_def); private: - string xla_cluster_attr_name_; - string outside_compilation_attr_name_; - string xla_cluster_name_; - string new_function_name_; + std::string xla_cluster_attr_name_; + std::string outside_compilation_attr_name_; + std::string xla_cluster_name_; + std::string new_function_name_; }; // For an XLA computation function, replace all outside compilations with @@ -88,12 +88,13 @@ class RewriteOutsideCompilationSubgraphFn { // has_outside_compilation: a bool indicating whether this function has any // outside compilation nodes. 
absl::Status ExtractOutsideCompilationForFunction( - const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, const string& xla_cluster_name, - const NameAttrList& func_name_attrs, const string& new_func_name, - const string& host_graph_func_name, - const std::map& host_compute_core, FunctionLibraryRuntime* flr, - FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs, + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::string& xla_cluster_name, const NameAttrList& func_name_attrs, + const std::string& new_func_name, const std::string& host_graph_func_name, + const std::map& host_compute_core, + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, + std::vector* shape_inference_graphs, bool* has_outside_compilation); // Rewrites XLA computation in `clusters` to replace outside compilation nodes @@ -101,9 +102,9 @@ absl::Status ExtractOutsideCompilationForFunction( // of outside compilation outputs cannot be determined now, we will store shape // inference graph into `fld`. 
absl::Status ExtractOutsideCompilation( - const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, - const std::unordered_map& clusters, Graph* g, + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::unordered_map& clusters, Graph* g, FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, bool* modified); diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc index 4d007d07504939..1a6441a80726a0 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc @@ -236,14 +236,14 @@ class ExtractOutsideCompilationForFunctionTest : public ::testing::Test { } absl::Status ExtractOutsideCompilationTest( - const string &xla_cluster_attr_name, - const string &outside_compilation_attr_name, - const string &xla_cluster_name, const NameAttrList &func_name_attrs, - const string &new_func_name, const string &host_graph_func_name, - const std::map &host_compute_core, - FunctionLibraryDefinition *fld, - std::vector *shape_inference_graphs, - bool *has_outside_compilation) { + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::string& xla_cluster_name, const NameAttrList& func_name_attrs, + const std::string& new_func_name, const std::string& host_graph_func_name, + const std::map& host_compute_core, + FunctionLibraryDefinition* fld, + std::vector* shape_inference_graphs, + bool* has_outside_compilation) { OptimizerOptions opts; pflr_ = std::make_unique( device_mgr_.get(), Env::Default(), /*config=*/nullptr, @@ -288,9 +288,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, Basic) { } FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); - protobuf::Map attrs; - std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::vector shape_inference_graphs; + 
protobuf::Map attrs; + std::map host_compute_core = {{"0", 1}, {"1", 0}}; + std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); @@ -342,7 +342,7 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, Basic) { std::unique_ptr host_fbody; AttrValue device_ordinal_temp_value; device_ordinal_temp_value.set_i(0); - protobuf::Map host_func_attrs; + protobuf::Map host_func_attrs; host_func_attrs["_device_ordinal"] = device_ordinal_temp_value; TF_CHECK_OK(FunctionDefToBodyHelper( *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, &host_fbody)); @@ -406,9 +406,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { } FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); - protobuf::Map attrs; - std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::vector shape_inference_graphs; + protobuf::Map attrs; + std::map host_compute_core = {{"0", 1}, {"1", 0}}; + std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); @@ -481,9 +481,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { } FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); - protobuf::Map attrs; - std::map host_compute_core; - std::vector shape_inference_graphs; + protobuf::Map attrs; + std::map host_compute_core; + std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); @@ -498,7 +498,7 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { std::unique_ptr host_fbody; AttrValue device_ordinal_temp_value; device_ordinal_temp_value.set_i(0); - protobuf::Map host_func_attrs; + protobuf::Map host_func_attrs; host_func_attrs["_device_ordinal"] = device_ordinal_temp_value; TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, @@ -568,7 +568,7 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, 
OutsideCompilationInIf) { // _xla_token_input_nodes. Node *if_node = node_name_index["if"]; EXPECT_NE(if_node, nullptr); - std::vector token_inputs; + std::vector token_inputs; TF_CHECK_OK( GetNodeAttr(if_node->def(), "_xla_token_input_nodes", &token_inputs)); EXPECT_THAT(token_inputs, ::testing::ElementsAre("send_oc_if_pred_if")); @@ -631,9 +631,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) { } FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); - protobuf::Map attrs; - std::map host_compute_core; - std::vector shape_inference_graphs; + protobuf::Map attrs; + std::map host_compute_core; + std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); @@ -648,7 +648,7 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) { std::unique_ptr host_fbody; AttrValue device_ordinal_temp_value; device_ordinal_temp_value.set_i(0); - protobuf::Map host_func_attrs; + protobuf::Map host_func_attrs; host_func_attrs["_device_ordinal"] = device_ordinal_temp_value; TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, @@ -767,9 +767,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) { TF_CHECK_OK(fld.AddFunctionDef(*xla_fdef)); } - protobuf::Map attrs; - std::map host_compute_core; - std::vector shape_inference_graphs; + protobuf::Map attrs; + std::map host_compute_core; + std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); @@ -784,7 +784,7 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) { std::unique_ptr host_fbody; AttrValue device_ordinal_temp_value; device_ordinal_temp_value.set_i(0); - protobuf::Map host_func_attrs; + protobuf::Map host_func_attrs; host_func_attrs["_device_ordinal"] = device_ordinal_temp_value; TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("host_graph"), 
AttrSlice(&host_func_attrs), &fld, @@ -873,9 +873,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, } FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); - protobuf::Map attrs; - std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::vector shape_inference_graphs; + protobuf::Map attrs; + std::map host_compute_core = {{"0", 1}, {"1", 0}}; + std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); @@ -898,14 +898,15 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, EXPECT_NE(host_compute_1, nullptr); // Check XlaHostCompute nodes' "_xla_token_input_nodes" attr. - std::vector token_input_nodes; + std::vector token_input_nodes; TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_0->attrs()), "_xla_token_input_nodes", &token_input_nodes)); - std::vector expected_token_input_nodes_0({"_xla_token_arg_node"}); + std::vector expected_token_input_nodes_0( + {"_xla_token_arg_node"}); EXPECT_EQ(token_input_nodes, expected_token_input_nodes_0); token_input_nodes.clear(); - std::vector expected_token_input_nodes_1( + std::vector expected_token_input_nodes_1( {"_xla_token_arg_node", "outside_compilation_0_host_compute"}); TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()), "_xla_token_input_nodes", &token_input_nodes)); @@ -955,9 +956,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, } FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); - protobuf::Map attrs; - std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::vector shape_inference_graphs; + protobuf::Map attrs; + std::map host_compute_core = {{"0", 1}, {"1", 0}}; + std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); @@ -980,14 +981,15 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, EXPECT_NE(host_compute_1, nullptr); // Check XlaHostCompute nodes' "_xla_token_input_nodes" attr. 
- std::vector token_input_nodes; + std::vector token_input_nodes; TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_0->attrs()), "_xla_token_input_nodes", &token_input_nodes)); - std::vector expected_token_input_nodes_0({"_xla_token_arg_node"}); + std::vector expected_token_input_nodes_0( + {"_xla_token_arg_node"}); EXPECT_EQ(token_input_nodes, expected_token_input_nodes_0); token_input_nodes.clear(); - std::vector expected_token_input_nodes_1( + std::vector expected_token_input_nodes_1( {"_xla_token_arg_node", "outside_compilation_0_host_compute"}); TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()), "_xla_token_input_nodes", &token_input_nodes)); diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index e7a375231accdf..a0a0d45736f1e8 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -46,7 +46,7 @@ std::vector* jitrt_flag_list; std::vector* flag_list; absl::once_flag flags_init; -bool SetterForXlaAutoJitFlag(const string& value) { +bool SetterForXlaAutoJitFlag(const std::string& value) { int32_t opt_level; // We need to use the mark_for_compilation_flags directly here instead of // going via GetMarkForCompilationPassFlags() to avoid infinite recursion. 
The @@ -81,7 +81,7 @@ bool SetterForXlaAutoJitFlag(const string& value) { return true; } -bool SetterForXlaCallModuleDisabledChecks(const string& value) { +bool SetterForXlaCallModuleDisabledChecks(const std::string& value) { auto directives = absl::StrSplit(value, ',', absl::SkipEmpty()); call_module_flags->disabled_checks.insert(directives.begin(), directives.end()); @@ -231,7 +231,7 @@ void AllocateAndParseFlags() { mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general = 0; mark_for_compilation_flags->tf_xla_min_cluster_size = 4; mark_for_compilation_flags->tf_xla_max_cluster_size = - std::numeric_limits::max(); + std::numeric_limits::max(); mark_for_compilation_flags->tf_xla_clustering_debug = false; mark_for_compilation_flags->tf_xla_cpu_global_jit = false; mark_for_compilation_flags->tf_xla_clustering_fuel = @@ -291,6 +291,7 @@ void AllocateAndParseFlags() { // Dump graphs in TFG dialect. bool use_tfg_graph_dumper = false; bool enable_tpu_variable_runtime_reformatting_pass = true; + bool enable_serialize_mlir_to_compressed_bytecode = false; flag_list = new std::vector( {Flag("tf_xla_enable_lazy_compilation", @@ -405,7 +406,10 @@ void AllocateAndParseFlags() { &enable_tpu_variable_runtime_reformatting_pass, "Enables TPUVariableRuntimeReformatting pass for MLIR-Based " "TensorFlow Compiler Bridge. 
This enables weight update sharding " - "and creates TPUReshardVariables ops.")}); + "and creates TPUReshardVariables ops."), + Flag("tf_serialize_mlir_to_compressed_bytecode", + &enable_serialize_mlir_to_compressed_bytecode, + "If true, serialize MLIR to compressed bytecode.")}); AppendMarkForCompilationPassFlagsInternal(flag_list); xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list); @@ -434,6 +438,8 @@ void AllocateAndParseFlags() { enable_mlir_multiple_local_cpu_devices; mlir_flags->tf_mlir_enable_debug_info_serialization = enable_mlir_debug_info_serialization; + mlir_flags->tf_serialize_mlir_to_compressed_bytecode = + enable_serialize_mlir_to_compressed_bytecode; if (use_tfg_graph_dumper) { UseMlirForGraphDump(MlirDumpConfig{}.elide_large_attributes().emit_dialect( @@ -457,7 +463,7 @@ void ResetFlags() { } // namespace -bool SetXlaAutoJitFlagFromFlagString(const string& value) { +bool SetXlaAutoJitFlagFromFlagString(const std::string& value) { absl::call_once(flags_init, &AllocateAndParseFlags); return SetterForXlaAutoJitFlag(value); } diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index b355c79364cb1b..96154b892ae5b0 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -41,15 +41,15 @@ struct XlaAutoJitFlag { // `optimization_level_general` applies. // // Experimental. - int32 optimization_level_single_gpu; - int32 optimization_level_general; + int32_t optimization_level_single_gpu; + int32_t optimization_level_general; }; // Sets the xla_auto_jit_flag based on the given flag string. Supported syntax // is: // : sets general and single_gpu setting to the provided number. // single-gpu(): sets the single_gpu setting to the provided number. -bool SetXlaAutoJitFlagFromFlagString(const string& value); +bool SetXlaAutoJitFlagFromFlagString(const std::string& value); // Flags associated with the XLA bridge's mark_for_compilation_pass module. 
struct MarkForCompilationPassFlags { @@ -57,16 +57,16 @@ struct MarkForCompilationPassFlags { // Minimum number of operators in an XLA compilation. Ignored for operators // placed on an XLA device or operators explicitly marked for compilation. - int32 tf_xla_min_cluster_size; + int32_t tf_xla_min_cluster_size; // Maximum number of operators in an XLA compilation. - int32 tf_xla_max_cluster_size; + int32_t tf_xla_max_cluster_size; // If non-empty, limit XLA clustering to the following TF operations. - string tf_xla_ops_to_cluster; + std::string tf_xla_ops_to_cluster; // If non-empty, remove following operations from XLA clustering excludelist. - string tf_xla_cluster_exclude_ops; + std::string tf_xla_cluster_exclude_ops; // Dump graphs during XLA compilation. bool tf_xla_clustering_debug; @@ -110,7 +110,7 @@ struct MarkForCompilationPassFlags { bool tf_xla_disable_strict_signature_checks; // Specifies the persistance cache prefix. Default is "xla_compile_cache" - string tf_xla_persistent_cache_prefix; + std::string tf_xla_persistent_cache_prefix; }; // Flags associated with XLA Sparse Core. @@ -299,6 +299,7 @@ struct MlirCommonFlags { // with different local CPU devices settings. bool tf_mlir_enable_multiple_local_cpu_devices; bool tf_mlir_enable_debug_info_serialization; + bool tf_serialize_mlir_to_compressed_bytecode; }; // Flags for the JitRt pipeline -- see tf_jitrt_pipeline.h for details. 
diff --git a/tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc b/tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc index 75bd1d7310a295..1b0239c3550970 100644 --- a/tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc +++ b/tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc @@ -95,7 +95,7 @@ TEST(ForceXlaConstantsOnHostPassTest, Simple) { if (CanCreateXlaKernel(node->def())) { EXPECT_FALSE(found); found = true; - std::vector hostmem_attr; + std::vector hostmem_attr; EXPECT_TRUE(TryGetNodeAttr(node->def(), "_input_hostmem", &hostmem_attr)); EXPECT_EQ(hostmem_attr.size(), 1); EXPECT_EQ(hostmem_attr[0], 1); diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc index 8317d222928200..03a7d1081b8b53 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc @@ -93,7 +93,7 @@ std::vector IntTensorAsVector(const Tensor& t) { result.reserve(t.NumElements()); for (int i = 0; i < t.NumElements(); i++) { int64_t element = t.dtype() == DT_INT32 - ? static_cast(t.flat()(i)) + ? 
static_cast(t.flat()(i)) : t.flat()(i); result.push_back(element); } @@ -251,14 +251,14 @@ absl::Status ComputeSliceSize(const Scope& host_scope, absl::Status ConvertTensorFlowSliceToStaticShapedSlice( Graph* g, Node* slice, const SliceInputs& slice_inputs, absl::string_view cluster_name, Node** result) { - string host_name; + std::string host_name; TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName( slice->assigned_device_name(), &host_name)); absl::Status status; Scope main_scope = NewInternalScope(g, &status, /*refiner=*/nullptr) - .WithXlaCluster(string(cluster_name)) + .WithXlaCluster(std::string(cluster_name)) .NewSubScope(absl::StrCat(slice->name(), "/static_shaped_slice")); Scope host_scope = main_scope.WithAssignedDevice(host_name); @@ -286,7 +286,7 @@ absl::Status ConvertTensorFlowSliceToStaticShapedSlice( TF_RETURN_IF_ERROR(main_scope.status()); - std::vector compile_time_const_inputs; + std::vector compile_time_const_inputs; compile_time_const_inputs.push_back("size"); (*result)->AddAttr(kXlaCompileTimeConstantInputsAttr, compile_time_const_inputs); diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc index 411f761995483a..6a8523a7d4c893 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc @@ -66,7 +66,8 @@ class FakeDevice : public Device { Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; } - static std::unique_ptr Make(const string& name, const string& type) { + static std::unique_ptr Make(const std::string& name, + const std::string& type) { DeviceAttributes device_attributes; device_attributes.set_name(name); device_attributes.set_device_type(DeviceType(type).type()); @@ -100,7 +101,7 @@ absl::Status IncreaseDynamismForAutoJit(const Scope& s, // Scope::ToGraph seems to drop assigned devices, probably because it 
goes // through a GraphDef. So explicitly maintain the device assignment. - std::unordered_map assigned_device_names; + std::unordered_map assigned_device_names; for (Node* n : s.graph()->nodes()) { assigned_device_names[n->name()] = n->assigned_device_name(); } @@ -149,7 +150,7 @@ TEST(SliceToDynamicSliceRewriteTest, Basic) { Inputs(m_slice_size_0, Const(static_cast(500)), Const(zero_32)))); - std::vector compile_time_constant_inputs; + std::vector compile_time_constant_inputs; compile_time_constant_inputs.push_back("size"); auto m_dynamic_slice = NodeWith( Op("Slice"), AssignedDevice(kDeviceName), diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index c3a24f3e0f7163..340cdbe8032c63 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -151,7 +151,7 @@ class MarkForCompilationPassImpl { std::optional resource_op_device, std::optional resource_var_operation_node_id, std::optional deadness_predicate, - bool is_xla_compile_attr_true, std::optional xla_scope) + bool is_xla_compile_attr_true, std::optional xla_scope) : cycles_graph_node_id_(tf_graph_node_id), effective_cluster_size_(effective_cluster_size), has_functional_control_flow_(has_functional_control_flow), @@ -220,7 +220,7 @@ class MarkForCompilationPassImpl { // If not nullopt then the all nodes in the cluster either do not have the // XlaScope attribute set or have it set to the value returned. - const std::optional& xla_scope() const { return xla_scope_; } + const std::optional& xla_scope() const { return xla_scope_; } // Returns the TF graph node IDs for the resource variable operations in // this cluster. 
@@ -228,7 +228,7 @@ class MarkForCompilationPassImpl { return resource_var_operation_node_ids_; } - string DebugString(const Graph& graph) const { + std::string DebugString(const Graph& graph) const { Node* node = graph.FindNodeId(cycles_graph_node_id()); if (!node) { // This should never happen but we try to be resilient because this is a @@ -254,7 +254,7 @@ class MarkForCompilationPassImpl { std::optional resource_op_device_; std::optional deadness_predicate_; bool is_xla_compile_attr_true_; - std::optional xla_scope_; + std::optional xla_scope_; std::vector resource_var_operation_node_ids_; Cluster(const Cluster&) = delete; @@ -365,7 +365,7 @@ class MarkForCompilationPassImpl { std::optional resource_var_operation_node_id, std::optional deadness_predicate, bool is_xla_compile_attr_true, - std::optional xla_scope) { + std::optional xla_scope) { cluster_storage_.push_back(std::make_unique( cycles_graph_node_id, effective_cluster_size, has_functional_control_flow, device_set, resource_op_device, @@ -374,7 +374,7 @@ class MarkForCompilationPassImpl { return cluster_storage_.back().get(); } - std::optional GetXlaScope(Node* n); + std::optional GetXlaScope(Node* n); // Returns the cluster for node `n`. If two nodes, N1 and N2, are placed in // the same cluster by the clustering algorithm then this function will return @@ -417,7 +417,8 @@ class MarkForCompilationPassImpl { // Returns a string representing `cycles_graph_node_id`. If the node is // unclusterable (either it is a phatom "frame" node or is not a compilation // candidate) then set `*found_unclustered` to true. - string DebugStringForCyclesGraphNode(int node_id, bool* found_unclustered); + std::string DebugStringForCyclesGraphNode(int node_id, + bool* found_unclustered); // We could not contract the edge from `from` to `to`. 
Return a string // describing an alternate path from `from` to `to` (besides the direct edge @@ -429,7 +430,7 @@ class MarkForCompilationPassImpl { // contracted because of the path [P,Q,R]" where P, Q and R are all clusters // since in that case a natural question is why we could not form a {A, P, Q, // R, B} cluster. - string DescribePotentialCycle(int from, int to); + std::string DescribePotentialCycle(int from, int to); // Merge the clusters `cluster_from` and `cluster_to`. After this step the // larger combined cluster is represented by `cluster_from`, but can have @@ -459,8 +460,8 @@ class MarkForCompilationPassImpl { return true; } - string EdgeContractionFailureMsg(Cluster* from, Cluster* to, - absl::string_view reason) { + std::string EdgeContractionFailureMsg(Cluster* from, Cluster* to, + absl::string_view reason) { return absl::StrCat("Could not contract ", from->DebugString(*graph_), " -> ", to->DebugString(*graph_), " because ", reason, "."); @@ -468,7 +469,7 @@ class MarkForCompilationPassImpl { DebugOptions debug_options_; Graph* graph_; - uint64 graph_fingerprint_; + uint64_t graph_fingerprint_; FunctionLibraryDefinition* flib_def_; Env* env_; OptimizerOptions::GlobalJitLevel global_jit_level_; @@ -547,7 +548,7 @@ std::vector MarkForCompilationPassImpl::FindAlternatePathForDebugging( return path; } -string MarkForCompilationPassImpl::DebugStringForCyclesGraphNode( +std::string MarkForCompilationPassImpl::DebugStringForCyclesGraphNode( int cycles_graph_node_id, bool* found_unclustered) { Cluster* cluster = GetClusterForCyclesGraphNode(cycles_graph_node_id); if (cluster) { @@ -567,8 +568,9 @@ string MarkForCompilationPassImpl::DebugStringForCyclesGraphNode( return node->name(); } -string MarkForCompilationPassImpl::DescribePotentialCycle(int from, int to) { - std::vector path_str; +std::string MarkForCompilationPassImpl::DescribePotentialCycle(int from, + int to) { + std::vector path_str; bool found_unclustered = false; 
absl::c_transform(FindAlternatePathForDebugging(from, to), std::back_inserter(path_str), [&](int node_id) { @@ -701,7 +703,7 @@ absl::StatusOr MarkForCompilationPassImpl::ForEachEdgeInPostOrder( // Make a copy of the set of successors because we may modify the graph in // TryToContractEdge. - std::vector successors_copy = + std::vector successors_copy = cycles_graph_.SuccessorsCopy(cluster_from->cycles_graph_node_id()); for (int to : successors_copy) { @@ -974,7 +976,7 @@ class ClusterSequenceNumberGenerator { sequence_numbers_.clear(); } - int64 GetNext(uint64 key) { + int64_t GetNext(uint64_t key) { mutex_lock lock(mu_); return sequence_numbers_[key]++; } @@ -987,13 +989,13 @@ class ClusterSequenceNumberGenerator { private: mutex mu_; - absl::flat_hash_map sequence_numbers_; + absl::flat_hash_map sequence_numbers_; }; // Get a monotonic sequence numbers for a graph identified by its `fingerprint`. // The sequence number is necessary to disambiguate clusters extracted from the // same graph and when duplicate graphs exist within the same process. -int64_t GetNextClusterSequenceNumber(uint64 fingerprint) { +int64_t GetNextClusterSequenceNumber(uint64_t fingerprint) { return ClusterSequenceNumberGenerator::Global().GetNext(fingerprint); } @@ -1002,7 +1004,7 @@ absl::Status MarkForCompilationPassImpl::CreateClusters() { clusters_created_ = true; // Names for each cluster. 
- std::unordered_map cluster_names; + std::unordered_map cluster_names; if (debug_options_.dump_graphs) { DumpGraphToFile("before_mark_for_compilation", *graph_, flib_def_); @@ -1030,7 +1032,7 @@ absl::Status MarkForCompilationPassImpl::CreateClusters() { if (cluster->effective_cluster_size() >= debug_options_.min_cluster_size || cluster->has_functional_control_flow() || cluster->is_xla_compile_attr_true()) { - string& name = cluster_names[cluster->cycles_graph_node_id()]; + std::string& name = cluster_names[cluster->cycles_graph_node_id()]; if (name.empty()) { if (!cluster_name_prefix_.empty()) { @@ -1099,7 +1101,7 @@ MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency( return false; } -std::optional MarkForCompilationPassImpl::GetXlaScope(Node* node) { +std::optional MarkForCompilationPassImpl::GetXlaScope(Node* node) { // Look for either _XlaScope or _XlaInternalScope on both nodes to guide // clustering. If both nodes have a scope and the scopes do not match, do // not cluster along this edge. If even one of the nodes lacks a scope @@ -1118,14 +1120,14 @@ std::optional MarkForCompilationPassImpl::GetXlaScope(Node* node) { if (global_jit_level_ != OptimizerOptions::OFF) { // If global_jit_level_ is ON, respect only _XlaInternalScope. - const string& scope = + const std::string& scope = GetNodeAttrString(node->attrs(), kXlaInternalScopeAttr); if (!scope.empty()) { return scope; } } else { // If global_jit_level_ is OFF, respect only _XlaScope. - const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr); + const std::string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr); if (!scope.empty()) { return scope; } @@ -1186,9 +1188,9 @@ absl::Status MarkForCompilationPassImpl::BuildInitialClusterSet() { deadness_analysis_->GetPredicateFor(node, Graph::kControlSlot)); } - const string& device_name_str = !node->assigned_device_name().empty() - ? 
node->assigned_device_name() - : node->requested_device(); + const std::string& device_name_str = !node->assigned_device_name().empty() + ? node->assigned_device_name() + : node->requested_device(); TF_ASSIGN_OR_RETURN(DeviceId device, device_info_cache_.GetIdFor(device_name_str)); @@ -1258,16 +1260,17 @@ absl::StatusOr IsIdentityDrivingConstsInLoop(Node* node) { return true; } -absl::flat_hash_set CreateClusterExcludeList() { +absl::flat_hash_set CreateClusterExcludeList() { MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); - absl::flat_hash_set excludelist; + absl::flat_hash_set excludelist; for (auto s : absl::StrSplit(flags->tf_xla_cluster_exclude_ops, ',')) { if (!s.empty()) { - excludelist.insert(string(s)); + excludelist.insert(std::string(s)); } } if (VLOG_IS_ON(2) && !excludelist.empty()) { - std::vector vexcludelist(excludelist.begin(), excludelist.end()); + std::vector vexcludelist(excludelist.begin(), + excludelist.end()); absl::c_sort(vexcludelist); VLOG(2) << "XLA clustering will exclude following TF operations from auto " "clustering: " @@ -1276,11 +1279,11 @@ absl::flat_hash_set CreateClusterExcludeList() { return excludelist; } -absl::flat_hash_set GetOrCreateAllowlist() { - absl::flat_hash_map>* allowlist_table = +absl::flat_hash_set GetOrCreateAllowlist() { + absl::flat_hash_map>* allowlist_table = tensorflow::GetAllowlistTable(); MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); - absl::flat_hash_set allowlist; + absl::flat_hash_set allowlist; for (auto s : absl::StrSplit(flags->tf_xla_ops_to_cluster, ',')) { if (s == "FUSIBLE") { @@ -1292,12 +1295,12 @@ absl::flat_hash_set GetOrCreateAllowlist() { allowlist.insert(v.begin(), v.end()); } else if (!s.empty()) { // Should be a user provided TF operation. 
- allowlist.insert(string(s)); + allowlist.insert(std::string(s)); } } if (VLOG_IS_ON(2) && !allowlist.empty()) { - std::vector vallowlist(allowlist.begin(), allowlist.end()); + std::vector vallowlist(allowlist.begin(), allowlist.end()); absl::c_sort(vallowlist); VLOG(2) << "XLA clustering will only consider the following TF operations: " << absl::StrJoin(vallowlist, " "); @@ -1338,8 +1341,8 @@ absl::Status MarkForCompilationPassImpl::FindCompilationCandidates() { auto allowlist = GetOrCreateAllowlist(); - std::vector vall_ops = XlaOpRegistry::GetAllRegisteredOps(); - absl::flat_hash_set all_ops(vall_ops.begin(), vall_ops.end()); + std::vector vall_ops = XlaOpRegistry::GetAllRegisteredOps(); + absl::flat_hash_set all_ops(vall_ops.begin(), vall_ops.end()); // Check that user's provided TF operation really exists. for (const auto& s : allowlist) { if (!all_ops.contains(s)) { @@ -1674,7 +1677,7 @@ void MarkForCompilationPassImpl::DumpPostClusteringGraphs() { DumpGraphToFile("mark_for_compilation_annotated", new_graph, flib_def_); } -string RatioToString(int numerator, int denominator) { +std::string RatioToString(int numerator, int denominator) { return absl::StrFormat("%d / %d (%.2f%%)", numerator, denominator, (100.0 * numerator) / denominator); } @@ -1985,10 +1988,11 @@ absl::Status MarkForCompilationPass::RunForTest( return MarkForCompilation(options, debug_options); } -absl::flat_hash_map>* GetAllowlistTable() { +absl::flat_hash_map>* +GetAllowlistTable() { // Table format: category name: {list of TF operations in that category} - static absl::flat_hash_map>* result = - new absl::flat_hash_map>{ + static absl::flat_hash_map>* result = + new absl::flat_hash_map>{ // Unary {"PW", {"ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin", @@ -2056,8 +2060,8 @@ void ResetClusterSequenceNumber() { ClusterSequenceNumberGenerator::Global().Reset(); } -absl::flat_hash_set GetKnownXLAAllowlistOp() { - absl::flat_hash_set result{ +absl::flat_hash_set 
GetKnownXLAAllowlistOp() { + absl::flat_hash_set result{ "AdjustContrastv2", "AdjustHue", "AdjustSaturation", diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index 558912f2eee2e0..d6a2814ed33982 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -47,7 +47,7 @@ class MarkForCompilationPass : public GraphOptimizationPass { friend class MarkForCompilationPassTestHelper; }; -absl::flat_hash_map>* GetAllowlistTable(); +absl::flat_hash_map>* GetAllowlistTable(); namespace testing { // DO NOT USE IN PRODUCTION. @@ -56,7 +56,7 @@ namespace testing { void ResetClusterSequenceNumber(); // Return a list of operation that we choose not to put into the allowlist. -absl::flat_hash_set GetKnownXLAAllowlistOp(); +absl::flat_hash_set GetKnownXLAAllowlistOp(); } // namespace testing } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 1a120791206369..1d4031a4ffc926 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -67,10 +67,10 @@ static bool Initialized = [] { REGISTER_OP("UncompilableNullary").Output("o: float"); REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float"); -std::unordered_map GetClusters(const Graph& graph) { - std::unordered_map ids; +std::unordered_map GetClusters(const Graph& graph) { + std::unordered_map ids; for (Node* node : graph.nodes()) { - string cluster; + std::string cluster; if (TryGetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster)) { CHECK(!cluster.empty()); ids[node->name()] = cluster; @@ -86,10 +86,10 @@ std::unordered_map GetClusters(const Graph& graph) { return ids; } -std::set GetClusterNames(const Graph& graph) { - std::set names; +std::set GetClusterNames(const Graph& graph) { + std::set 
names; for (Node* node : graph.nodes()) { - string cluster; + std::string cluster; if (TryGetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster)) { CHECK(!cluster.empty()); names.insert(cluster); @@ -98,10 +98,10 @@ std::set GetClusterNames(const Graph& graph) { return names; } -absl::flat_hash_map> GetClusterSets( - const Graph& g, std::vector* cluster_names = nullptr) { +absl::flat_hash_map> GetClusterSets( + const Graph& g, std::vector* cluster_names = nullptr) { CHECK(cluster_names == nullptr || cluster_names->empty()); - absl::flat_hash_map> cluster_sets; + absl::flat_hash_map> cluster_sets; for (const auto& p : GetClusters(g)) { cluster_sets[p.second].push_back(p.first); } @@ -357,7 +357,7 @@ TEST(XlaCompilationTest, CallXlaDeviceFuncWithResourceOp) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + std::string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; testing::FindNodeByName(graph.get(), "A") ->set_assigned_device_name(xla_cpu_device); testing::FindNodeByName(graph.get(), "tanh0") @@ -694,7 +694,7 @@ TEST(XlaCompilationTest, ClusterNodesWithMismatchingInputDeadness) { } namespace { -Node* MakeRead(const Scope& scope, const string& id, +Node* MakeRead(const Scope& scope, const std::string& id, Node** var_handle_op = nullptr) { Output var_handle = ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); @@ -706,7 +706,7 @@ Node* MakeRead(const Scope& scope, const string& id, return read.node(); } -Node* MakeWrite(const Scope& scope, const string& id) { +Node* MakeWrite(const Scope& scope, const std::string& id) { Output var_handle = ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); Output value_to_write = @@ -716,7 +716,7 @@ Node* MakeWrite(const Scope& scope, const string& id) { return assign_op.operation.node(); } -Node* MakeNeutral(const Scope& scope, const string& id) { +Node* MakeNeutral(const Scope& 
scope, const std::string& id) { return ops::Const(scope.WithOpName("Const" + id), 42.0f).node(); } } // namespace @@ -733,11 +733,11 @@ TEST(XlaCompilationTest, ResourcesClusteringAllowed) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_EXPECT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - absl::flat_hash_map> cluster_sets = + absl::flat_hash_map> cluster_sets = GetClusterSets(*graph); ASSERT_EQ(cluster_sets.size(), 1); - std::vector expected_clustered_nodes = {"AssignmentW", "ReadR", - "ValueToAssignW"}; + std::vector expected_clustered_nodes = {"AssignmentW", "ReadR", + "ValueToAssignW"}; ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes); } @@ -753,7 +753,7 @@ TEST(XlaCompilationTest, ResourcesClusteringDisallowed) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_EXPECT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - absl::flat_hash_map> cluster_sets = + absl::flat_hash_map> cluster_sets = GetClusterSets(*graph); ASSERT_EQ(cluster_sets.size(), 0); } @@ -779,13 +779,13 @@ TEST(XlaCompilationTest, ChainOfOps) { TF_EXPECT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::vector cluster_names; - absl::flat_hash_map> cluster_sets = + std::vector cluster_names; + absl::flat_hash_map> cluster_sets = GetClusterSets(*graph, &cluster_names); ASSERT_EQ(cluster_sets.size(), 1); - std::vector expected_clustered_nodes_a = { + std::vector expected_clustered_nodes_a = { "AssignmentW1", "ConstN0", "ReadR0", "ValueToAssignW1"}; ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a); } @@ -881,7 +881,7 @@ TEST(XlaCompilationTest, ConstOp) { { std::unique_ptr graph(new Graph(OpRegistry::Global())); Scope root = Scope::NewRootScope().ExitOnError(); - auto c = ops::Const(root.WithOpName("const"), string("string")); + auto c = 
ops::Const(root.WithOpName("const"), std::string("string")); c.node()->AddAttr(kXlaCompileAttr, true); TF_ASSERT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); @@ -901,12 +901,12 @@ TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) { TF_ASSERT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); ASSERT_FALSE(clusters.empty()); - string cluster_name = clusters.begin()->second; + std::string cluster_name = clusters.begin()->second; - std::unordered_map expected_clusters( + std::unordered_map expected_clusters( {{"negate", cluster_name}, {"add", cluster_name}}); EXPECT_EQ(clusters, expected_clusters); } @@ -924,12 +924,12 @@ TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) { TF_ASSERT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); ASSERT_FALSE(clusters.empty()); - string cluster_name = clusters.begin()->second; + std::string cluster_name = clusters.begin()->second; - std::unordered_map expected_clusters( + std::unordered_map expected_clusters( {{"negate", cluster_name}, {"identity", cluster_name}, {"add", cluster_name}}); @@ -956,7 +956,7 @@ TEST(XlaCompilationTest, ClusterControlTrigger) { TF_ASSERT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); // TODO(b/118970344): ctrl_trigger_a has inputs with mismatching deadness so // it won't be clustered. 
ctrl_trigger_b is okay to cluster but we don't @@ -982,7 +982,7 @@ TEST(XlaCompilationTest, RandomShape) { TF_ASSERT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_EQ(clusters["shape"], ""); } @@ -1028,7 +1028,7 @@ TEST(XlaCompilationTest, RandomShapeWithFunc) { TF_ASSERT_OK( MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get())); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_EQ(clusters["fn_call"], ""); } @@ -1054,12 +1054,12 @@ TEST(XlaCompilationTest, RandomShapeOnXlaDevice) { for (Node* n : graph->nodes()) { if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { - n->set_assigned_device_name(string(xla_gpu_device)); + n->set_assigned_device_name(std::string(xla_gpu_device)); } } TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_EQ(clusters["test/shape_rng"], ""); EXPECT_EQ(clusters["test/reshape"], ""); } @@ -1087,12 +1087,12 @@ TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) { for (Node* n : graph->nodes()) { if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { - n->set_assigned_device_name(string(xla_gpu_device)); + n->set_assigned_device_name(std::string(xla_gpu_device)); } } TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_NE(clusters["test/read"], ""); EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]); } @@ -1133,15 +1133,15 @@ TEST(XlaCompilationTest, DontClusterMergingNodes) { for (Node* n : graph->nodes()) { if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) { - 
n->set_assigned_device_name(string(xla_gpu_dev0)); + n->set_assigned_device_name(std::string(xla_gpu_dev0)); } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) { - n->set_assigned_device_name(string(xla_gpu_dev1)); + n->set_assigned_device_name(std::string(xla_gpu_dev1)); } } TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); // Each of the MatMuls should be in a separate cluster. - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]); EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul0_dev0"]); EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul1_dev1"]); @@ -1170,17 +1170,17 @@ TEST(XlaCompilationTest, DontClusterMergingNodesOnCPU) { for (Node* n : graph->nodes()) { if (absl::EndsWith(n->name(), /*suffix=*/"cpu")) { - n->set_assigned_device_name(string(xla_cpu_dev0)); + n->set_assigned_device_name(std::string(xla_cpu_dev0)); } else if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) { - n->set_assigned_device_name(string(xla_gpu_dev0)); + n->set_assigned_device_name(std::string(xla_gpu_dev0)); } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) { - n->set_assigned_device_name(string(xla_gpu_dev1)); + n->set_assigned_device_name(std::string(xla_gpu_dev1)); } } TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); // Each of the MatMuls should be in a separate cluster. 
- std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]); EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul0_dev0"]); EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul1_dev1"]); @@ -1223,14 +1223,14 @@ TEST(XlaCompilationTest, NOT_DontClusterSpreadingNodes) { TF_ASSERT_OK(root.ToGraph(graph.get())); for (Node* n : graph->nodes()) { if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) { - n->set_assigned_device_name(string(xla_gpu_dev0)); + n->set_assigned_device_name(std::string(xla_gpu_dev0)); } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) { - n->set_assigned_device_name(string(xla_gpu_dev1)); + n->set_assigned_device_name(std::string(xla_gpu_dev1)); } } TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_EQ(clusters["A_dev0"], clusters["MatMulSource_dev0"]); EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]); EXPECT_NE(clusters["MatMulSource_dev0"], clusters["MatMul1_dev1"]); @@ -1254,12 +1254,12 @@ TEST(XlaCompilationTest, ClusterStatefulRandomOpOnXlaDevice) { for (Node* n : graph->nodes()) { if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { - n->set_assigned_device_name(string(xla_cpu_device)); + n->set_assigned_device_name(std::string(xla_cpu_device)); } } TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_NE(clusters["test/a"], ""); EXPECT_NE(clusters["test/b"], ""); EXPECT_NE(clusters["test/c"], ""); @@ -1277,7 +1277,7 @@ TEST(XlaCompilationTest, DontAutoClusterStatefulRandomOp) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = 
GetClusters(*graph); EXPECT_EQ(clusters["test/a"], ""); EXPECT_EQ(clusters["test/b"], ""); } @@ -1299,12 +1299,12 @@ TEST(XlaCompilationTest, ClusterDummyOpsOnXlaDevice) { for (Node* n : graph->nodes()) { if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { - n->set_assigned_device_name(string(xla_cpu_device)); + n->set_assigned_device_name(std::string(xla_cpu_device)); } } TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_NE(clusters["test/check"], ""); EXPECT_NE(clusters["test/greaterequal"], ""); EXPECT_NE(clusters["test/assert"], ""); @@ -1324,7 +1324,7 @@ TEST(XlaCompilationTest, DontAutoClusterDummyOps) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_EQ(clusters["test/assert"], ""); EXPECT_EQ(clusters["test/check"], ""); } @@ -1345,7 +1345,7 @@ TEST(XlaCompilationTest, DontAutoClusterOpsProducingVariant) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_EQ(clusters["test/tensor_list_reserve"], ""); } @@ -1373,7 +1373,7 @@ TEST(XlaCompilationTest, DontAutoClusterOpsConsumingVariant) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_EQ(clusters["test/tensor_list_element_shape"], ""); } @@ -1391,7 +1391,7 @@ TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(root.ToGraph(graph.get())); - string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + std::string xla_cpu_device = 
"/job:worker/replica:0/task:0/device:XLA_CPU:0"; for (Node* n : graph->nodes()) { if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { n->set_assigned_device_name(xla_cpu_device); @@ -1400,7 +1400,7 @@ TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_NE(clusters["test/tensor_list_reserve"], ""); } @@ -1427,7 +1427,7 @@ TEST(XlaCompilationTest, CreateCombinedCpuGpuClusters) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_NE(clusters["test/x"], ""); @@ -1451,7 +1451,7 @@ TEST(XlaCompilationTest, DontCreateGpu0AndGpu1Clusters) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_EQ(clusters["test/x"], ""); EXPECT_EQ(clusters["test/y"], ""); @@ -1473,7 +1473,7 @@ TEST(XlaCompilationTest, DontCreateCombinedCpuUnknownClusters) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_EQ(clusters["test/x"], ""); EXPECT_EQ(clusters["test/y"], ""); @@ -1486,8 +1486,8 @@ TEST(XlaCompilationTest, ClusterResourceOpsWhenSafe) { Node* resource_read = MakeRead(root, "read", &var_handle); Output b = ops::Add(root.WithOpName("test/b"), Output(resource_read, 0), a); - string resource_read_name = resource_read->name(); - string var_handle_name = var_handle->name(); + std::string resource_read_name = resource_read->name(); + std::string var_handle_name = var_handle->name(); std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(root.ToGraph(graph.get())); @@ 
-1499,7 +1499,7 @@ TEST(XlaCompilationTest, ClusterResourceOpsWhenSafe) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_NE(clusters["test/b"], ""); EXPECT_EQ(clusters["test/b"], clusters[resource_read_name]); @@ -1512,8 +1512,8 @@ TEST(XlaCompilationTest, DontClusterResourceOpsWhenUnsafe) { Node* resource_read = MakeRead(root, "read", &var_handle); Output b = ops::Add(root.WithOpName("test/b"), Output(resource_read, 0), a); - string resource_read_name = resource_read->name(); - string var_handle_name = var_handle->name(); + std::string resource_read_name = resource_read->name(); + std::string var_handle_name = var_handle->name(); std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(root.ToGraph(graph.get())); @@ -1525,7 +1525,7 @@ TEST(XlaCompilationTest, DontClusterResourceOpsWhenUnsafe) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_EQ(clusters["test/b"], ""); EXPECT_EQ(clusters[resource_read_name], ""); @@ -1555,7 +1555,7 @@ TEST(XlaCompilationTest, DontClusterNodesWithScopedAllocatorAttr) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_EQ(clusters["test/z"], ""); } @@ -1580,7 +1580,7 @@ TEST(XlaCompilationTest, DontClusterNodesWithForwardFromAttr) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_EQ(clusters["test/z"], ""); } @@ -1610,7 +1610,7 @@ TEST(XlaCompilationTest, ClusterShapeConsumerWithProducer) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - 
std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_NE(clusters["test/y"], ""); EXPECT_EQ(clusters["test/x"], clusters["test/y"]); @@ -1632,7 +1632,7 @@ TEST(XlaCompilationTest, ClusterShapeConsumerWithProducerAndConsumer) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_NE(clusters["test/y"], ""); EXPECT_EQ(clusters["test/y"], clusters["test/x"]); @@ -1705,7 +1705,7 @@ TEST(XlaCompilationTest, IterationIncrementAndGroupDeps) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_NE(clusters["some_ctrl_input"], ""); EXPECT_EQ(clusters["some_ctrl_input"], clusters["weights_0_update"]); @@ -1875,19 +1875,19 @@ TEST(XlaCompilationTest, ClusterSessionName) { TF_ASSERT_OK( MarkForCompilationPassTestHelper::MarkForCompilation(&graph, options)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); ASSERT_FALSE(clusters.empty()); - string cluster_name = clusters.begin()->second; + std::string cluster_name = clusters.begin()->second; - std::unordered_map expected_clusters( + std::unordered_map expected_clusters( {{"negate", cluster_name}, {"add", cluster_name}}); EXPECT_EQ(clusters, expected_clusters); EXPECT_THAT(cluster_name, ::testing::StartsWith("test_session_name")); } namespace { -Node* MakeStageNode(GraphDefBuilder& builder, string name, +Node* MakeStageNode(GraphDefBuilder& builder, std::string name, std::initializer_list dtypes, absl::Span values) { auto opts = builder.opts() @@ -1949,7 +1949,7 @@ TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) { &graph, MarkForCompilationPassTestHelper::Options().WithNoClusterScoping())); - std::unordered_map clusters = 
GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_EQ(clusters["add0"], clusters["add1"]); EXPECT_EQ(clusters["add0"], clusters["relu1"]); EXPECT_EQ(clusters["relu0"], clusters["add1"]); @@ -1964,7 +1964,7 @@ TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map clusters = GetClusters(*graph); + std::unordered_map clusters = GetClusters(*graph); EXPECT_NE(clusters["add0"], clusters["add1"]); EXPECT_NE(clusters["add0"], clusters["relu1"]); EXPECT_NE(clusters["relu0"], clusters["add1"]); @@ -1973,9 +1973,9 @@ TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) { } TEST(XlaCompilationTest, XLALiteAllowlist) { auto* allowlist_table = tensorflow::GetAllowlistTable(); - absl::flat_hash_set hallowlist; - std::vector vall_ops = XlaOpRegistry::GetAllRegisteredOps(); - absl::flat_hash_set all_ops(vall_ops.begin(), vall_ops.end()); + absl::flat_hash_set hallowlist; + std::vector vall_ops = XlaOpRegistry::GetAllRegisteredOps(); + absl::flat_hash_set all_ops(vall_ops.begin(), vall_ops.end()); // Check that all the operations in the table are existing TF operations for (auto pair : *allowlist_table) { @@ -1988,10 +1988,10 @@ TEST(XlaCompilationTest, XLALiteAllowlist) { // Check that all registered XLA operation are in the allowlist // table or are known to not be in it. 
- absl::flat_hash_set known_not_in_list = + absl::flat_hash_set known_not_in_list = tensorflow::testing::GetKnownXLAAllowlistOp(); - std::vector unknow_op; - for (string op : vall_ops) { + std::vector unknow_op; + for (std::string op : vall_ops) { if (!hallowlist.contains(op) && !known_not_in_list.contains(op)) { unknow_op.push_back(op); } diff --git a/tensorflow/compiler/jit/node_matchers.cc b/tensorflow/compiler/jit/node_matchers.cc index ce1f2cd5bcd671..db158fc84a0173 100644 --- a/tensorflow/compiler/jit/node_matchers.cc +++ b/tensorflow/compiler/jit/node_matchers.cc @@ -35,7 +35,7 @@ namespace { using impl::NodeMatcherProperties; using impl::OutEdge; -string IndentAllButFirstLine(absl::string_view text) { +std::string IndentAllButFirstLine(absl::string_view text) { std::vector lines = absl::StrSplit(text, '\n'); for (int i = 1; i < lines.size(); i++) { lines[i].insert(0, " "); @@ -86,21 +86,21 @@ bool MatchAndExplainTensor(const Tensor& tensor, const Tensor& expected_tensor, case DT_DOUBLE: return CompareTensor(tensor, expected_tensor, listener); case DT_INT8: - return CompareTensor(tensor, expected_tensor, listener); + return CompareTensor(tensor, expected_tensor, listener); case DT_INT16: - return CompareTensor(tensor, expected_tensor, listener); + return CompareTensor(tensor, expected_tensor, listener); case DT_INT32: - return CompareTensor(tensor, expected_tensor, listener); + return CompareTensor(tensor, expected_tensor, listener); case DT_INT64: return CompareTensor(tensor, expected_tensor, listener); case DT_UINT8: - return CompareTensor(tensor, expected_tensor, listener); + return CompareTensor(tensor, expected_tensor, listener); case DT_UINT16: - return CompareTensor(tensor, expected_tensor, listener); + return CompareTensor(tensor, expected_tensor, listener); case DT_UINT32: - return CompareTensor(tensor, expected_tensor, listener); + return CompareTensor(tensor, expected_tensor, listener); case DT_UINT64: - return CompareTensor(tensor, 
expected_tensor, listener); + return CompareTensor(tensor, expected_tensor, listener); default: LOG(FATAL) << "Unsupported dtype " // Crash ok: testonly. << DataType_Name(tensor.dtype()); @@ -188,7 +188,7 @@ struct NodeMatcher : public ::testing::MatcherInterface { if (control_dep_set && !control_dep_set->MatchAndExplain(control_deps, &inner_listener)) { if (listener->IsInterested()) { - string explanation = inner_listener.str(); + std::string explanation = inner_listener.str(); if (!explanation.empty()) { explanation = absl::StrCat(", ", explanation, ","); } @@ -225,7 +225,7 @@ struct NodeMatcher : public ::testing::MatcherInterface { } void DescribeTo(::std::ostream* os) const override { - std::vector predicates; + std::vector predicates; if (name) { predicates.push_back(absl::StrCat("name: ", *name)); @@ -282,10 +282,11 @@ struct NodeMatcher : public ::testing::MatcherInterface { if (!attrs.empty()) { printed_something = true; - std::vector attrs_str; + std::vector attrs_str; absl::c_transform( attrs, std::back_inserter(attrs_str), - [](const std::pair>& attr_kv_pair) { + [](const std::pair>& + attr_kv_pair) { return absl::StrCat(attr_kv_pair.first, "->", attr_kv_pair.second ? 
SummarizeAttrValue(*attr_kv_pair.second) @@ -319,7 +320,7 @@ struct NodeMatcher : public ::testing::MatcherInterface { if (listener->IsInterested()) { *listener << "\ninput " << input_idx << " does not match expected:\n"; (*input_matchers)[input_idx].DescribeTo(listener->stream()); - string explanation = inner_listener.str(); + std::string explanation = inner_listener.str(); if (!explanation.empty()) { *listener << ", " << explanation; } @@ -327,14 +328,14 @@ struct NodeMatcher : public ::testing::MatcherInterface { return false; } - std::optional op; - std::optional name; - std::optional assigned_device; + std::optional op; + std::optional name; + std::optional assigned_device; std::optional constant_value; std::optional>> input_matchers; std::optional<::testing::Matcher>> control_dep_set; - std::map> attrs; + std::map> attrs; }; // Matches a dst and dst_output on an input edge. Today we only use this with @@ -352,7 +353,7 @@ class OutEdgeMatcher : public ::testing::MatcherInterface { if (listener->IsInterested()) { *listener << "\nsource does not match expected "; src_matcher_.DescribeTo(listener->stream()); - string explanation = inner_listener.str(); + std::string explanation = inner_listener.str(); if (!explanation.empty()) { *listener << "\n\t" << explanation; } @@ -432,21 +433,21 @@ ::testing::Matcher impl::NodeWith( return ::testing::MakeMatcher(matcher); } -impl::NodeMatcherProperties Name(string name) { +impl::NodeMatcherProperties Name(std::string name) { impl::NodeMatcherProperties props; props.set_name(std::move(name)); return props; } // Matches a node with op `op`. -impl::NodeMatcherProperties Op(string op) { +impl::NodeMatcherProperties Op(std::string op) { impl::NodeMatcherProperties props; props.set_op(std::move(op)); return props; } // Matches a node with assigned device `assigned_device`. 
-impl::NodeMatcherProperties AssignedDevice(string assigned_device) { +impl::NodeMatcherProperties AssignedDevice(std::string assigned_device) { impl::NodeMatcherProperties props; props.set_assigned_device(std::move(assigned_device)); return props; @@ -472,15 +473,15 @@ impl::NodeMatcherProperties impl::CtrlDeps( return props; } -std::pair impl::AttrLiteralHelper( - const std::pair& bool_attr) { +std::pair impl::AttrLiteralHelper( + const std::pair& bool_attr) { AttrValue attr_value; attr_value.set_b(bool_attr.second); return {bool_attr.first, attr_value}; } -std::pair impl::AttrLiteralHelper( - const std::pair>& int_list_attr) { +std::pair impl::AttrLiteralHelper( + const std::pair>& int_list_attr) { AttrValue attr_value; AttrValue::ListValue* list = attr_value.mutable_list(); for (int i : int_list_attr.second) { @@ -489,23 +490,24 @@ std::pair impl::AttrLiteralHelper( return {int_list_attr.first, attr_value}; } -std::pair impl::AttrLiteralHelper( - const std::pair>& string_list_attr) { +std::pair impl::AttrLiteralHelper( + const std::pair>& + string_list_attr) { AttrValue attr_value; AttrValue::ListValue* list = attr_value.mutable_list(); - for (const string& s : string_list_attr.second) { + for (const std::string& s : string_list_attr.second) { list->add_s(s); } return {string_list_attr.first, attr_value}; } -impl::NodeMatcherProperties impl::Attr(std::pair attr) { +impl::NodeMatcherProperties impl::Attr(std::pair attr) { impl::NodeMatcherProperties props; props.set_attr(std::move(attr)); return props; } -impl::NodeMatcherProperties impl::Attr(string name) { +impl::NodeMatcherProperties impl::Attr(std::string name) { impl::NodeMatcherProperties props; props.set_attr({std::move(name), std::nullopt}); return props; diff --git a/tensorflow/compiler/jit/node_matchers.h b/tensorflow/compiler/jit/node_matchers.h index bb2c1875306185..1391df3743bd4c 100644 --- a/tensorflow/compiler/jit/node_matchers.h +++ b/tensorflow/compiler/jit/node_matchers.h @@ -84,11 +84,11 @@ 
class NodeMatcherProperties { public: using NodeSeqMatcher = std::vector<::testing::Matcher>; using InputSeqMatcher = std::vector<::testing::Matcher>; - using AttrKeyValuePair = std::pair>; + using AttrKeyValuePair = std::pair>; - const std::optional& name() const { return name_; } - const std::optional& op() const { return op_; } - const std::optional& assigned_device() const { + const std::optional& name() const { return name_; } + const std::optional& op() const { return op_; } + const std::optional& assigned_device() const { return assigned_device_; } const std::optional& constant_value() const { @@ -102,17 +102,17 @@ class NodeMatcherProperties { } const std::optional& attr() const { return attr_; } - void set_name(string name) { + void set_name(std::string name) { DCHECK(IsEmpty()); name_ = std::move(name); } - void set_op(string op) { + void set_op(std::string op) { DCHECK(IsEmpty()); op_ = std::move(op); } - void set_assigned_device(string assigned_device) { + void set_assigned_device(std::string assigned_device) { DCHECK(IsEmpty()); assigned_device_ = std::move(assigned_device); } @@ -144,9 +144,9 @@ class NodeMatcherProperties { } private: - std::optional name_; - std::optional op_; - std::optional assigned_device_; + std::optional name_; + std::optional op_; + std::optional assigned_device_; std::optional constant_value_; std::optional input_matchers_; std::optional control_deps_; @@ -162,39 +162,40 @@ impl::NodeMatcherProperties Inputs( impl::NodeMatcherProperties CtrlDeps( absl::Span> control_deps); -impl::NodeMatcherProperties Attr(std::pair attrs); -impl::NodeMatcherProperties Attr(string name); +impl::NodeMatcherProperties Attr(std::pair attrs); +impl::NodeMatcherProperties Attr(std::string name); -std::pair AttrLiteralHelper( - const std::pair& bool_attr); +std::pair AttrLiteralHelper( + const std::pair& bool_attr); -std::pair AttrLiteralHelper( - const std::pair>& int_list_attr); +std::pair AttrLiteralHelper( + const std::pair>& int_list_attr); 
-std::pair AttrLiteralHelper( - const std::pair>& string_list_attr); +std::pair AttrLiteralHelper( + const std::pair>& + string_list_attr); } // namespace impl // ----------------------------------------------------------------------------- // Public interface. // Matches a node with name `name`. -impl::NodeMatcherProperties Name(string name); +impl::NodeMatcherProperties Name(std::string name); // Matches a node with op `op`. -impl::NodeMatcherProperties Op(string op); +impl::NodeMatcherProperties Op(std::string op); // Matches a node with assigned device `assigned_device`. -impl::NodeMatcherProperties AssignedDevice(string assigned_device); +impl::NodeMatcherProperties AssignedDevice(std::string assigned_device); // Matches a node with a boolean typed attribute named `name` and with value // `value`. template -impl::NodeMatcherProperties Attr(const string& name, ValueTy value) { +impl::NodeMatcherProperties Attr(const std::string& name, ValueTy value) { return impl::Attr({impl::AttrLiteralHelper({name, value})}); } -inline impl::NodeMatcherProperties Attr(const string& name) { +inline impl::NodeMatcherProperties Attr(const std::string& name) { return impl::Attr(name); } diff --git a/tensorflow/compiler/jit/node_matchers_test.cc b/tensorflow/compiler/jit/node_matchers_test.cc index 6f37d5617b6ce6..ac1d9ce3468df1 100644 --- a/tensorflow/compiler/jit/node_matchers_test.cc +++ b/tensorflow/compiler/jit/node_matchers_test.cc @@ -41,7 +41,7 @@ using testing::matchers::Op; using testing::matchers::Out; template -string Explain(const T& t, const M& m) { +std::string Explain(const T& t, const M& m) { ::testing::StringMatchResultListener listener; EXPECT_THAT(t, ::testing::Not(m)); // For the error message. 
EXPECT_FALSE(m.MatchAndExplain(t, &listener)); diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index c8bbcee20e3829..9539a14d060f42 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -100,7 +100,7 @@ absl::Status PartiallyDecluster(std::unique_ptr* graph) { return pass.Run(opt_options); } -Node* FindNodeByName(const Graph& graph, const string& name) { +Node* FindNodeByName(const Graph& graph, const std::string& name) { for (Node* node : graph.nodes()) { if (node->name() == name) { return node; @@ -109,7 +109,7 @@ Node* FindNodeByName(const Graph& graph, const string& name) { return nullptr; } -bool GetInputsForNode(const Graph& graph, const string& node_name, +bool GetInputsForNode(const Graph& graph, const std::string& node_name, std::vector* inputs) { const Node* node = FindNodeByName(graph, node_name); if (node == nullptr) { @@ -292,7 +292,7 @@ TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) { void AddToCluster(absl::Span nodes, absl::string_view cluster_name) { for (Node* n : nodes) { - n->AddAttr(kXlaClusterAttr, string(cluster_name)); + n->AddAttr(kXlaClusterAttr, std::string(cluster_name)); } } diff --git a/tensorflow/compiler/jit/pjrt_base_device.cc b/tensorflow/compiler/jit/pjrt_base_device.cc index ce7ed954575040..d25d77d6cff22b 100644 --- a/tensorflow/compiler/jit/pjrt_base_device.cc +++ b/tensorflow/compiler/jit/pjrt_base_device.cc @@ -17,8 +17,8 @@ limitations under the License. 
namespace tensorflow { namespace { -DeviceAttributes BuildPjRtBaseDeviceAttributes(const string& name_prefix, - const string& device_name, +DeviceAttributes BuildPjRtBaseDeviceAttributes(const std::string& name_prefix, + const std::string& device_name, int device_ordinal) { return Device::BuildDeviceAttributes( absl::StrCat(name_prefix, "/device:", device_name, ":", device_ordinal), diff --git a/tensorflow/compiler/jit/pjrt_device_context.cc b/tensorflow/compiler/jit/pjrt_device_context.cc index e4d88f5816ec87..0bbad40fe5f25c 100644 --- a/tensorflow/compiler/jit/pjrt_device_context.cc +++ b/tensorflow/compiler/jit/pjrt_device_context.cc @@ -139,7 +139,7 @@ void PjRtDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, return; } - xla::PjRtFuture<> future = device_buffer->ToLiteral(literal.get()); + tsl::Future future = device_buffer->ToLiteral(literal.get()); future.OnReady([literal = std::move(literal), done = std::move(done)]( const absl::Status& status) { done(status); }); } diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc index 2fee2b0b898890..33f09704d7c72b 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -143,7 +143,7 @@ bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) { using ResourceOp = std::pair; -string ResourceOpToString(const ResourceOp& resource_op) { +std::string ResourceOpToString(const ResourceOp& resource_op) { return absl::StrCat( resource_op.first, ": ", XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second)); @@ -233,14 +233,14 @@ class ResourceOpSet { void operator=(const ResourceOpSet&) = delete; }; -string ResourceOpSetToString(const ResourceOpSet& resource_op_set) { - std::vector elements_debug_string; +std::string ResourceOpSetToString(const ResourceOpSet& resource_op_set) { + std::vector elements_debug_string; 
std::transform(resource_op_set.begin(), resource_op_set.end(), std::back_inserter(elements_debug_string), ResourceOpToString); return absl::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}"); } -string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) { +std::string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) { return absl::StrCat( "[", n.name(), ": ", n.type_string(), "(", XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]"); diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc index 8a80b8ae9b3497..6b038c992f1715 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc @@ -38,7 +38,7 @@ limitations under the License. namespace tensorflow { namespace { -Node* MakeRead(const Scope& scope, const string& id) { +Node* MakeRead(const Scope& scope, const std::string& id) { Output var_handle = ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); Output read = @@ -46,7 +46,7 @@ Node* MakeRead(const Scope& scope, const string& id) { return read.node(); } -Node* MakeWrite(const Scope& scope, const string& id) { +Node* MakeWrite(const Scope& scope, const std::string& id) { Output var_handle = ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); Output value_to_write = @@ -56,7 +56,7 @@ Node* MakeWrite(const Scope& scope, const string& id) { return assign_op.operation.node(); } -Node* MakeModify(const Scope& scope, const string& id) { +Node* MakeModify(const Scope& scope, const std::string& id) { Output var_handle = ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); Output value_to_write = ops::Const(scope.WithOpName("Increment" + id), 1.0f); @@ -65,7 +65,7 @@ Node* MakeModify(const Scope& scope, const string& id) { return assign_add_op.operation.node(); } 
-Node* MakeNeutral(const Scope& scope, const string& id) { +Node* MakeNeutral(const Scope& scope, const std::string& id) { return ops::Const(scope.WithOpName("Const" + id), 42.0f).node(); } @@ -238,7 +238,8 @@ TEST(ResourceOperationSafetyAnalysisTest, WriteReadModify) { EXPECT_EQ(incompatible_pairs[1], write_modify_pair); } -FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { +FunctionDefLibrary CreateFunctionDefLibWithConstFunction( + const std::string& name) { FunctionDefLibrary flib_def; FunctionDef func = FunctionDefHelper::Create( /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"}, @@ -249,8 +250,8 @@ FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { return flib_def; } -Node* MakeCall(Graph* graph, const string& callee_name, const string& node_name, - absl::Status* status) { +Node* MakeCall(Graph* graph, const std::string& callee_name, + const std::string& node_name, absl::Status* status) { NodeDef call_node; call_node.set_name(node_name); call_node.set_op(callee_name); diff --git a/tensorflow/compiler/jit/shape_inference.h b/tensorflow/compiler/jit/shape_inference.h index 467ecb83a74aae..b1469d2d699bf1 100644 --- a/tensorflow/compiler/jit/shape_inference.h +++ b/tensorflow/compiler/jit/shape_inference.h @@ -35,7 +35,8 @@ struct InferredShape { DataType handle_type = DT_INVALID; PartialTensorShape handle_shape; }; -typedef std::unordered_map> GraphShapeInfo; +typedef std::unordered_map> + GraphShapeInfo; // Infer shapes for all Tensors in a graph, and save them in a map. The vector // for a Node contains the information about each of its outputs. 
diff --git a/tensorflow/compiler/jit/shape_inference_test.cc b/tensorflow/compiler/jit/shape_inference_test.cc index eaabf18c79603c..599d442de4b092 100644 --- a/tensorflow/compiler/jit/shape_inference_test.cc +++ b/tensorflow/compiler/jit/shape_inference_test.cc @@ -61,7 +61,7 @@ TEST(ShapeInferenceTest, Basics) { TF_ASSERT_OK(InferShapes(graph.get(), /*arg_shapes=*/{}, /*fnlib_def=*/nullptr, &shape_info)); - std::map> expected = { + std::map> expected = { {"A", {PartialTensorShape({2, 3})}}, {"B", {PartialTensorShape({3})}}, {"C", {PartialTensorShape()}}, {"D", {PartialTensorShape({2, 3})}}, {"E", {PartialTensorShape()}}, {"F", {PartialTensorShape()}}, @@ -94,7 +94,7 @@ TEST(ShapeInferenceTest, UseArgShapesForVariableBatchSize) { TF_ASSERT_OK(InferShapes(graph.get(), arg_shapes, /*fnlib_def=*/nullptr, &shape_info)); - std::map> expected = { + std::map> expected = { {"A", {PartialTensorShape({2, 3})}}, {"B", {PartialTensorShape({2, 3})}}, {"C", {PartialTensorShape({2, 3})}}, @@ -127,7 +127,7 @@ TEST(ShapeInferenceTest, UseArgShapesForVariableBatchSizeIncompleteUserArgs) { TF_ASSERT_OK(InferShapes(graph.get(), arg_shapes, /*fnlib_def=*/nullptr, &shape_info)); - std::map> expected = { + std::map> expected = { {"A", {PartialTensorShape({2, 3})}}, {"B", {PartialTensorShape({2, 3})}}, {"C", {PartialTensorShape({2, 3})}}, @@ -156,7 +156,7 @@ TEST(ShapeInferenceTest, WhileLoop) { ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop"); auto merge = ops::Merge(scope.WithOpName("while/Merge"), std::initializer_list{enter, dummy}); - auto ten = ops::Const( + auto ten = ops::Const( scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), 10); auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); @@ -168,11 +168,11 @@ TEST(ShapeInferenceTest, WhileLoop) { auto identity = ops::Identity(scope.WithOpName("while/Identity"), switch_node.output_true); auto identity_shape = - ops::Const(scope.WithOpName("while/Identity/shape"), {}); 
+ ops::Const(scope.WithOpName("while/Identity/shape"), {}); auto identity_reshaped = ops::Reshape( scope.WithOpName("while/Identity/reshaped"), identity, identity_shape); - auto one = ops::Const( + auto one = ops::Const( scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); auto add = ops::Add(scope.WithOpName("while/add"), identity_reshaped, one); auto next_iteration = @@ -190,7 +190,7 @@ TEST(ShapeInferenceTest, WhileLoop) { GraphShapeInfo shape_info; TF_ASSERT_OK(InferShapes(&graph, /*arg_shapes=*/{}, /*fnlib_def=*/nullptr, &shape_info)); - std::map> expected = { + std::map> expected = { {"while/Identity", {PartialTensorShape()}}, {"while/add", {PartialTensorShape({})}}, }; diff --git a/tensorflow/compiler/jit/test_util.cc b/tensorflow/compiler/jit/test_util.cc index 81ab1d8d05f96e..30a9ab51faf105 100644 --- a/tensorflow/compiler/jit/test_util.cc +++ b/tensorflow/compiler/jit/test_util.cc @@ -29,7 +29,7 @@ namespace tensorflow { absl::Status ShapeAnnotationsMatch( const Graph& graph, const GraphShapeInfo& shape_info, - std::map> expected_shapes) { + std::map> expected_shapes) { for (Node* node : graph.op_nodes()) { auto sit = shape_info.find(node->name()); TF_RET_CHECK(sit != shape_info.end()) @@ -50,7 +50,7 @@ absl::Status ShapeAnnotationsMatch( } } if (!expected_shapes.empty()) { - std::vector missing; + std::vector missing; missing.reserve(expected_shapes.size()); for (const auto& entry : expected_shapes) { missing.push_back(entry.first); @@ -88,12 +88,12 @@ void DeviceSetup::AddDevicesAndSetUp( flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); } -Device* DeviceSetup::GetDevice(const string& device_name) { +Device* DeviceSetup::GetDevice(const std::string& device_name) { if (device_mgr_ == nullptr) { return nullptr; } - string full_device_name = absl::StrCat( + std::string full_device_name = absl::StrCat( "/job:localhost/replica:0/task:0/device:", device_name, ":0"); Device* device; 
TF_CHECK_OK(device_mgr_->LookupDevice(full_device_name, &device)); diff --git a/tensorflow/compiler/jit/test_util.h b/tensorflow/compiler/jit/test_util.h index ec694662297399..ba7d2533ef7c74 100644 --- a/tensorflow/compiler/jit/test_util.h +++ b/tensorflow/compiler/jit/test_util.h @@ -44,7 +44,7 @@ namespace tensorflow { // `expected_shapes` entries. absl::Status ShapeAnnotationsMatch( const Graph& graph, const GraphShapeInfo& shape_info, - std::map> expected_shapes); + std::map> expected_shapes); // A helper object to create GraphOptimizationPassOptions. struct GraphOptimizationPassWrapper { @@ -74,7 +74,7 @@ class DeviceSetup { void AddDevicesAndSetUp( const std::vector& device_names, const std::optional& fdef = std::nullopt); - Device* GetDevice(const string& device_name); + Device* GetDevice(const std::string& device_name); FunctionLibraryRuntime* flr() { return flr_; } private: diff --git a/tensorflow/compiler/jit/tests/auto_clustering_test.cc b/tensorflow/compiler/jit/tests/auto_clustering_test.cc index 90e73c23d210d7..d108bc51b5ee33 100644 --- a/tensorflow/compiler/jit/tests/auto_clustering_test.cc +++ b/tensorflow/compiler/jit/tests/auto_clustering_test.cc @@ -23,7 +23,7 @@ class AutoClusteringTestImpl : public AutoClusteringTest { protected: // Test auto-clustering with a proto text file ${key}.pbtxt. absl::Status RunAutoClusteringTestWithPbtxt(absl::string_view key) { - string file_name_without_extension = + std::string file_name_without_extension = absl::StrCat(testing::TensorFlowSrcRoot(), "/compiler/jit/tests/", key); return AutoClusteringTest::RunAutoClusteringTestWithPbtxt( @@ -33,7 +33,7 @@ class AutoClusteringTestImpl : public AutoClusteringTest { // Test auto-clustering with a gzipped proto text file ${key}.pbtxt.gz. 
absl::Status RunAutoClusteringTestWithGzippedPbtxt(absl::string_view key) { - string file_name_without_extension = + std::string file_name_without_extension = absl::StrCat(testing::TensorFlowSrcRoot(), "/compiler/jit/tests/", key); return AutoClusteringTest::RunAutoClusteringTestWithGzippedPbtxt( diff --git a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc index dee77ac750ee54..258449e91120e1 100644 --- a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc +++ b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc @@ -33,7 +33,7 @@ limitations under the License. namespace tensorflow { namespace { -absl::StatusOr SummarizeClustering( +absl::StatusOr SummarizeClustering( const GraphDef& auto_clustered_graph_def) { testing::ResetClusterSequenceNumber(); Graph graph(OpRegistry::Global()); @@ -45,7 +45,7 @@ absl::StatusOr SummarizeClustering( // cluster_id -> (operation name -> # of operations) const int kNoCluster = -1; - std::map> clusters; + std::map> clusters; std::map cluster_size; int clustered_nodes = 0; for (Node* n : graph.op_nodes()) { @@ -60,7 +60,7 @@ absl::StatusOr SummarizeClustering( cluster_size[cluster]++; } - string result = + std::string result = absl::StrCat("Clustered nodes: ", clustered_nodes, "\nUnclustered nodes: ", cluster_size[kNoCluster], "\nNumber of clusters: ", clusters.size() - 1, "\n\n"); @@ -99,7 +99,7 @@ absl::Status AssertGraphDefIsUnclustered(const GraphDef& graphdef) { return absl::OkStatus(); } -absl::Status ReadTextProtoFromString(Env* env, const string& data, +absl::Status ReadTextProtoFromString(Env* env, const std::string& data, ::tensorflow::protobuf::Message* proto) { if (!::tensorflow::protobuf::TextFormat::ParseFromString(data, proto)) { return errors::DataLoss("Can't parse input data as text proto"); @@ -141,7 +141,8 @@ absl::Status AutoClusteringTest::RunAutoClusteringTestImpl( graphdef = std::move(next); } - 
TF_ASSIGN_OR_RETURN(string clustering_summary, SummarizeClustering(graphdef)); + TF_ASSIGN_OR_RETURN(std::string clustering_summary, + SummarizeClustering(graphdef)); // To update golden files flip this to true and run // @@ -149,13 +150,15 @@ absl::Status AutoClusteringTest::RunAutoClusteringTestImpl( // tensorflow/compiler/jit/tests:auto_clustering_test bool update_golden = false; if (update_golden) { - TF_RETURN_IF_ERROR(WriteStringToFile( - Env::Default(), string(golden_summary_file_path), clustering_summary)); + TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), + std::string(golden_summary_file_path), + clustering_summary)); } - string golden_file_contents; - TF_RETURN_IF_ERROR(ReadFileToString( - Env::Default(), string(golden_summary_file_path), &golden_file_contents)); + std::string golden_file_contents; + TF_RETURN_IF_ERROR(ReadFileToString(Env::Default(), + std::string(golden_summary_file_path), + &golden_file_contents)); EXPECT_EQ(golden_file_contents, clustering_summary); @@ -167,7 +170,7 @@ absl::Status AutoClusteringTest::RunAutoClusteringTestWithPbtxt( absl::string_view golden_summary_file_path) { GraphDef graphdef; TF_RETURN_IF_ERROR( - ReadTextProto(Env::Default(), string(pbtxt_file_path), &graphdef)); + ReadTextProto(Env::Default(), std::string(pbtxt_file_path), &graphdef)); return RunAutoClusteringTestImpl(std::move(graphdef), golden_summary_file_path); } @@ -177,8 +180,8 @@ absl::Status AutoClusteringTest::RunAutoClusteringTestWithGzippedPbtxt( absl::string_view golden_summary_file_path) { Env* env = Env::Default(); std::unique_ptr file_reader; - TF_RETURN_IF_ERROR( - env->NewRandomAccessFile(string(gzipped_pbtxt_file_path), &file_reader)); + TF_RETURN_IF_ERROR(env->NewRandomAccessFile( + std::string(gzipped_pbtxt_file_path), &file_reader)); std::unique_ptr input_stream( new io::RandomAccessInputStream(file_reader.get())); constexpr int k_buffer_size = 256 << 10; // 256kb @@ -206,7 +209,7 @@ absl::Status 
BenchmarkMarkForCompilation(absl::string_view graph_def_path, benchmark::State& state) { GraphDef graph_def; TF_RETURN_IF_ERROR( - ReadTextProto(Env::Default(), string(graph_def_path), &graph_def)); + ReadTextProto(Env::Default(), std::string(graph_def_path), &graph_def)); OptimizationPassRunner runner; TF_RETURN_IF_ERROR(runner.SetJitLevel(tensorflow::OptimizerOptions::ON_2)); diff --git a/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc b/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc index e4be1a1f641656..33e2daf941eafb 100644 --- a/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc +++ b/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc @@ -29,7 +29,7 @@ namespace { // Creates a float tensor of linearly increasing values, starting from offset. Tensor CreateInputTensor(const TensorShape& shape, float offset) { Tensor tensor(DT_FLOAT, shape); - for (int64 i = 0; i < tensor.flat().size(); ++i) { + for (int64_t i = 0; i < tensor.flat().size(); ++i) { tensor.flat()(i) = offset + i; } return tensor; @@ -127,7 +127,7 @@ absl::Status DeviceCompilerSerializeTest::ExecuteWithBatch( } Tensor f32_input(DT_FLOAT, shape); - for (int64 i = 0; i < f32_input.NumElements(); ++i) { + for (int64_t i = 0; i < f32_input.NumElements(); ++i) { EXPECT_NEAR(golden_output_tensors[0].flat()(i), output_tensors[0].flat()(i), 1e-3); } @@ -139,7 +139,7 @@ DeviceCompilerSerializeTest::AlterPersistentCacheEntryHloModuleNames( absl::string_view persistent_cache_dir_path, absl::string_view file_prefix) { Env* env = Env::Default(); - std::vector file_names; + std::vector file_names; TF_RETURN_IF_ERROR( env->GetChildren(tensorflow::testing::TmpDir(), &file_names)); diff --git a/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc b/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc index bec124f1866689..62089beed8224f 100644 --- a/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc +++ 
b/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc @@ -34,8 +34,12 @@ namespace { class XlaHostSendRecvDeviceContextTest : public ::testing::Test { public: - void SetDevice(const string& device_type) { + absl::Status SetDevice(const string& device_type) { auto device_factory = DeviceFactory::GetFactory(device_type); + if (device_factory == nullptr) { + return absl::NotFoundError( + "Failed to get DeviceFactory for device_type: " + device_type); + } SessionOptions options; std::vector> devices; Status s = device_factory->CreateDevices( @@ -49,6 +53,7 @@ class XlaHostSendRecvDeviceContextTest : public ::testing::Test { AllocatorAttributes device_alloc_attr; device_alloc_attr.set_on_host(false); device_allocator_ = device_->GetAllocator(device_alloc_attr); + return absl::OkStatus(); } protected: @@ -58,7 +63,7 @@ class XlaHostSendRecvDeviceContextTest : public ::testing::Test { }; TEST_F(XlaHostSendRecvDeviceContextTest, CopyDeviceTensorToCPU) { - SetDevice("GPU"); + TF_ASSERT_OK(SetDevice("GPU")); Tensor origin_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2})); test::FillValues(&origin_cpu_tensor, {1.2, 2.3, 3.4, 4.5}); Tensor device_tensor(device_allocator_, DT_FLOAT, TensorShape({2, 2})); @@ -93,7 +98,7 @@ TEST_F(XlaHostSendRecvDeviceContextTest, CopyDeviceTensorToCPU) { } TEST_F(XlaHostSendRecvDeviceContextTest, CopyCPUTensorToDevice) { - SetDevice("GPU"); + TF_ASSERT_OK(SetDevice("GPU")); Tensor origin_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2})); test::FillValues(&origin_cpu_tensor, {1.2, 2.3, 3.4, 4.5}); Tensor device_tensor(device_allocator_, DT_FLOAT, TensorShape({2, 2})); @@ -127,7 +132,7 @@ TEST_F(XlaHostSendRecvDeviceContextTest, CopyCPUTensorToDevice) { } TEST_F(XlaHostSendRecvDeviceContextTest, RoundTrip) { - SetDevice("GPU"); + TF_ASSERT_OK(SetDevice("GPU")); Tensor origin_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2})); test::FillValues(&origin_cpu_tensor, {1.2, 2.3, 3.4, 4.5}); Tensor 
device_tensor(device_allocator_, DT_FLOAT, TensorShape({2, 2})); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index f26fcd34df7583..8ccb236897ce39 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -45,11 +45,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" #include "xla/client/local_client.h" +#include "xla/future.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_executable.h" -#include "xla/pjrt/pjrt_future.h" #include "xla/service/executable.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/service/shaped_buffer.h" @@ -809,8 +809,6 @@ xla::ExecuteOptions GetPjRtExecuteOptions( const DeviceType& device_type, absl::flat_hash_set non_donatable_input_indices) { xla::ExecuteOptions options; - options.arguments_are_tupled = false; - options.untuple_result = true; // Hardcode run id to always be one: TF distributed strategy // differentiates between subsequent runs using dependency edges. 
This // is safe, as only TF dist-strat can produce distributed ops, and we @@ -925,7 +923,7 @@ absl::StatusOr>> RunPjRtExecutable( &executable_args, &owned_executable_args, &non_donatable_input_indices)); std::vector> execute_outputs; - std::optional> future; + std::optional> future; if (executable->num_replicas() != 1 || executable->num_partitions() != 1) { TF_ASSIGN_OR_RETURN( execute_outputs, diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc index 9e71286dc95df8..d8ed5feac79f12 100644 --- a/tensorflow/compiler/jit/xla_launch_util_test.cc +++ b/tensorflow/compiler/jit/xla_launch_util_test.cc @@ -207,8 +207,6 @@ class PjRtExecutionUtilTest : public OpsTestBase { &executable_args, /*owned_args=*/{}, &non_donatable_input_indices)); xla::ExecuteOptions exe_options; - exe_options.arguments_are_tupled = false; - exe_options.untuple_result = true; // TODO(b/257548614): currently PJRT is compiled as portable (num_replica = // 1 and num_partition = 1). Support multiple partitions case. 
@@ -520,8 +518,6 @@ TEST_F(PjRtExecutionUtilTest, PopulateCtxOutputsResourceUpdates) { TEST(XlaLaunchUtilTest, GetPjRtExecuteOptions) { xla::ExecuteOptions options = GetPjRtExecuteOptions(DeviceType(DEVICE_GPU), {}); - EXPECT_FALSE(options.arguments_are_tupled); - EXPECT_TRUE(options.untuple_result); EXPECT_FALSE(options.strict_shape_checking); EXPECT_TRUE(options.use_major_to_minor_data_layout_for_callbacks); } diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 7f200aa186a466..ab6c5abeca86f0 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1990,7 +1990,6 @@ cc_library( ":tf_tfl_passes", "//tensorflow/cc/saved_model:loader", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", - "//tensorflow/compiler/mlir/lite/core:macros", "//tensorflow/compiler/mlir/lite/debug", "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", "//tensorflow/compiler/mlir/lite/metrics:converter_error_data_proto_cc", @@ -2212,10 +2211,8 @@ tf_proto_library( srcs = ["converter_flags.proto"], make_default_target_header_only = True, protodeps = [ - "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_options_proto", - "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto", - "//tensorflow/compiler/mlir/lite/debug:debug_options_proto", ":types_proto", + "//tensorflow/compiler/mlir/lite/debug:debug_options_proto", ], visibility = ["//visibility:public"], ) diff --git a/tensorflow/compiler/mlir/lite/converter_flags.proto b/tensorflow/compiler/mlir/lite/converter_flags.proto index 1c1a1ad00aea74..49795ad8337d9a 100644 --- a/tensorflow/compiler/mlir/lite/converter_flags.proto +++ b/tensorflow/compiler/mlir/lite/converter_flags.proto @@ -17,8 +17,6 @@ package tflite; import "tensorflow/compiler/mlir/lite/debug/debug_options.proto"; import "tensorflow/compiler/mlir/lite/types.proto"; -import "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto"; 
-import "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.proto"; // Supported I/O file formats. Some formats may be input-only or output-only. enum FileFormat { @@ -43,6 +41,8 @@ enum FileFormat { // // Next ID to use: 69. message ConverterFlags { + reserved 54, 61; + // Input file format optional FileFormat input_format = 1; @@ -312,12 +312,6 @@ message ConverterFlags { // If true, disable folding mul->fc as in layer norm during optimize pass. optional bool disable_fuse_mul_and_fc = 53 [default = false]; - // Indicates the quantization specs. Quantization spec can be set to either - // a preset method or a custom method. - // Note: This is deprecated; use `quantization_config` instead. - optional stablehlo.quantization.QuantizationOptions quantization_options = 54 - [deprecated = true]; - // Flag to enable hlo to tf conversion. // This is useful to exercise StableHLO -> HLO -> TF -> TFLite path. optional bool enable_hlo_to_tf_conversion = 55 @@ -346,11 +340,6 @@ message ConverterFlags { // WARNING: Experimental interface, subject to change. optional string qdq_conversion_mode = 60 [default = "NONE"]; - // Configures quantization behavior. This config is fed to the StableHLO - // Quantizer integrated in the converter. - // WARNING: Experimental interface, subject to change. - optional stablehlo.quantization.QuantizationConfig quantization_config = 61; - // Disables per channel weights quantization for Dense layers and enables // legacy per tensor quantization. The legacy quantization for Dense layers is // inconsistent with Conv 1x1 which always performs per channel quantization. 
diff --git a/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.cc b/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.cc index 0e8210b97e315b..7facd69ecca298 100644 --- a/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.cc @@ -341,6 +341,7 @@ using tflite::TensorType_FLOAT16; using tflite::TensorType_FLOAT32; using tflite::TensorType_FLOAT64; using tflite::TensorType_INT16; +using tflite::TensorType_INT2; using tflite::TensorType_INT32; using tflite::TensorType_INT4; using tflite::TensorType_INT64; @@ -1400,6 +1401,9 @@ absl::Status ConvertTensorType(TensorType tensor_type, TfLiteType* type) { case TensorType_INT4: *type = kTfLiteInt4; return OkStatus(); + case TensorType_INT2: + *type = kTfLiteInt2; + return OkStatus(); default: *type = kTfLiteNoType; auto error_message = diff --git a/tensorflow/compiler/mlir/lite/core/c/tflite_types.h b/tensorflow/compiler/mlir/lite/core/c/tflite_types.h index 068facb10761c7..f09923dda5fc7c 100644 --- a/tensorflow/compiler/mlir/lite/core/c/tflite_types.h +++ b/tensorflow/compiler/mlir/lite/core/c/tflite_types.h @@ -64,6 +64,7 @@ typedef enum { kTfLiteUInt16 = 17, kTfLiteInt4 = 18, kTfLiteBFloat16 = 19, + kTfLiteInt2 = 20, } TfLiteType; // LINT.ThenChange(//tensorflow/lite/profiling/proto/model_runtime_info.proto:EdgeDataType) diff --git a/tensorflow/compiler/mlir/lite/debug/debug_test.cc b/tensorflow/compiler/mlir/lite/debug/debug_test.cc index 6c26865757950a..b82d5725182745 100644 --- a/tensorflow/compiler/mlir/lite/debug/debug_test.cc +++ b/tensorflow/compiler/mlir/lite/debug/debug_test.cc @@ -120,7 +120,7 @@ class InitPassManagerTest : public testing::Test { } absl::Status GetDumpDir(std::string* dump_dir) { - std::vector files; + std::vector files; if (auto status = tsl::Env::Default()->GetChildren(path_, &files); !status.ok()) { return status; diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc 
b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 93879c1d6254b9..41dffc228a6b2c 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -217,6 +217,13 @@ static StatusOr GetTFLiteType(Type type, switch (itype.getWidth()) { case 1: return tflite::TensorType_BOOL; + case 2: + if (itype.isUnsigned()) { + return Status(absl::StatusCode::kInvalidArgument, + "Unsupported 2bit unsigned int type"); + } else { + return tflite::TensorType_INT2; + } case 4: if (itype.isUnsigned()) { return Status(absl::StatusCode::kInvalidArgument, @@ -879,7 +886,7 @@ class Translator { std::vector>> string_buffers_to_delete_; std::vector>> - packed_int4_buffers_to_delete_; + packed_low_bit_buffers_to_delete_; // Maps custom options data to corresponding node // Key is set to be the list of input tensor indices and list of output tensor @@ -1027,18 +1034,21 @@ std::optional> Translator::BuildBuffer( auto type = mlir::cast(value.getType()); tflite::TensorType tflite_element_type = GetTFLiteType(type.getElementType()).value(); - if (tflite_element_type == tflite::TensorType_INT4) { + if (tflite_element_type == tflite::TensorType_INT4 || + tflite_element_type == tflite::TensorType_INT2) { std::vector data; for (mlir::APInt v : attr.getValues()) { data.emplace_back(static_cast(*(v.getRawData()))); } - auto packed_buffer = std::make_unique>( - tflite::PackInt4ValuesDensely(data)); + auto packed_buffer = + std::make_unique>(tflite::PackLowBitValuesDensely( + data, /*bit_width=*/( + tflite_element_type == tflite::TensorType_INT4 ? 
4 : 2))); if (use_buffer_offset_) { buffer_data_map_[index] = absl::string_view(reinterpret_cast(packed_buffer->data()), packed_buffer->size()); - packed_int4_buffers_to_delete_.emplace_back(std::move(packed_buffer)); + packed_low_bit_buffers_to_delete_.emplace_back(std::move(packed_buffer)); return tflite::CreateBuffer(builder_, 0, 1, 1); } else { if (IsModelBiggerThan2GB(packed_buffer->size())) { @@ -4239,10 +4249,10 @@ std::optional Translator::TranslateInternal() { // Free all the buffers/tensors, etc. that were created but were kept around // to copy into the flatbuffer. - for (auto& packed_int4_buffer : packed_int4_buffers_to_delete_) { - packed_int4_buffer.reset(); + for (auto& packed_low_bit_buffer : packed_low_bit_buffers_to_delete_) { + packed_low_bit_buffer.reset(); } - packed_int4_buffers_to_delete_.clear(); + packed_low_bit_buffers_to_delete_.clear(); for (auto& str_buffer : string_buffers_to_delete_) { str_buffer.reset(); diff --git a/tensorflow/compiler/mlir/lite/integrations/BUILD b/tensorflow/compiler/mlir/lite/integrations/BUILD index cae74c9c3ac7b2..64baf0b2fb7731 100644 --- a/tensorflow/compiler/mlir/lite/integrations/BUILD +++ b/tensorflow/compiler/mlir/lite/integrations/BUILD @@ -40,6 +40,7 @@ pybind_extension( "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/python/lib/core:ndarray_tensor", "//tensorflow/python/lib/core:py_func_lib", + "@com_google_absl//absl/strings:string_view", "@llvm-project//llvm:Support", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:ArithDialect", @@ -48,15 +49,14 @@ pybind_extension( "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:FuncTransforms", "@llvm-project//mlir:IR", - "@llvm-project//mlir:MLIRBindingsPythonHeaders", - "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", + "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", "@llvm-project//mlir:MlirOptLib", "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Support", 
"@llvm-project//mlir:Transforms", "@local_xla//third_party/python_runtime:headers", - "@pybind11", + "@nanobind", "@stablehlo//:register", "@stablehlo//:stablehlo_ops", "@stablehlo//:vhlo_ops", diff --git a/tensorflow/compiler/mlir/lite/integrations/model_utils_core_pybind.cc b/tensorflow/compiler/mlir/lite/integrations/model_utils_core_pybind.cc index 480901785329e7..80975abd3e9a7a 100644 --- a/tensorflow/compiler/mlir/lite/integrations/model_utils_core_pybind.cc +++ b/tensorflow/compiler/mlir/lite/integrations/model_utils_core_pybind.cc @@ -21,9 +21,10 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "absl/strings/string_view.h" #include "llvm/Support/Casting.h" #include "mlir-c/IR.h" // from @llvm-project -#include "mlir/Bindings/Python/PybindAdaptors.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Bindings/Python/NanobindAdaptors.h" // from @llvm-project // IWYU pragma: keep #include "mlir/CAPI/IR.h" // from @llvm-project #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project @@ -40,9 +41,10 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project -#include "pybind11/cast.h" // from @pybind11 -#include "pybind11/pybind11.h" // from @pybind11 -#include "pybind11/pytypes.h" // from @pybind11 +#include "nanobind/nanobind.h" // from @nanobind +#include "nanobind/stl/string.h" // from @nanobind +#include "nanobind/stl/string_view.h" // from @nanobind +#include "nanobind/stl/vector.h" // from @nanobind #include "stablehlo/dialect/Register.h" // from @stablehlo #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "stablehlo/dialect/VhloOps.h" // from @stablehlo @@ -57,7 +59,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/python/lib/core/ndarray_tensor.h" -namespace py = pybind11; +namespace nb = nanobind; // ----------------------------------------------------------------------------- // Module initialization. @@ -70,7 +72,7 @@ class MlirPythonPass mlir::OperationPass> { public: explicit MlirPythonPass(std::string name, std::string description, - py::object pyfunc) + nb::object pyfunc) : name_(name), description_(description), pyfunc_(pyfunc) { pyfunc.inc_ref(); } @@ -85,8 +87,8 @@ class MlirPythonPass auto module_clone = getOperation().clone(); MlirModule c_module = wrap(module_clone); - auto py_module = py::cast(c_module); - auto py_args = py::make_tuple(py_module); + auto py_module = nb::cast(c_module); + auto py_args = nb::make_tuple(py_module); PyObject* py_pass_ret = PyObject_CallObject(pyfunc_.ptr(), py_args.ptr()); if (py_pass_ret == nullptr || PyErr_Occurred()) { @@ -95,8 +97,8 @@ class MlirPythonPass signalPassFailure(); return; } - auto py_new_module_op = py::cast(py_pass_ret); - auto c_new_module_op = py::cast(py_new_module_op); + auto py_new_module_op = nb::steal(py_pass_ret); + auto c_new_module_op = nb::cast(py_new_module_op); mlir::Operation* new_module_op = unwrap(c_new_module_op); // TODO: Copy attributes from new_module @@ -108,7 +110,7 @@ class MlirPythonPass private: std::string name_; std::string description_; - py::object pyfunc_; + nb::object pyfunc_; }; inline void RegisterDialects(mlir::DialectRegistry& registry) { @@ -131,7 +133,7 @@ inline void RegisterPasses() { []() { return mlir::TFL::CreateOptimizePass(); }); } -PYBIND11_MODULE(model_utils_core_pybind, m) { +NB_MODULE(model_utils_core_pybind, m) { Py_Initialize(); m.doc() = "LiteRT ModelUtils Core Pybinds"; @@ -142,7 +144,7 @@ PYBIND11_MODULE(model_utils_core_pybind, m) { m.def("mlir_opt_main", [](std::vector argv, std::vector pass_names, std::vector pass_descriptions, - std::vector pass_fns) { + std::vector 
pass_fns) { std::vector c_argv_vec; c_argv_vec.reserve(argv.size()); for (size_t i = 0; i < argv.size(); ++i) @@ -178,14 +180,15 @@ PYBIND11_MODULE(model_utils_core_pybind, m) { }); m.def("flatbuffer_to_mlir", - [](py::bytes buffer, MlirContext context) -> MlirModule { + [](nb::bytes buffer, MlirContext context) -> MlirModule { mlir::DialectRegistry registry; RegisterDialects(registry); unwrap(context)->appendDialectRegistry(registry); unwrap(context)->loadAllAvailableDialects(); auto module_op = tflite::FlatBufferToMlir( - buffer, unwrap(context), mlir::UnknownLoc::get(unwrap(context))); + absl::string_view(buffer.c_str(), buffer.size()), unwrap(context), + mlir::UnknownLoc::get(unwrap(context))); return wrap(module_op.release()); }); @@ -197,7 +200,7 @@ PYBIND11_MODULE(model_utils_core_pybind, m) { std::string result; tflite::MlirToFlatBufferTranslateFunction(module_op, options, &result, true); - return py::bytes(result); + return nb::bytes(result.data(), result.size()); }); m.def("get_operation_attribute_names", [](MlirOperation c_op) { @@ -227,7 +230,7 @@ PYBIND11_MODULE(model_utils_core_pybind, m) { PyObject* np_array = Py_None; status = tensorflow::TensorToNdarray(tensor, &np_array); - return py::reinterpret_steal(np_array); + return nb::steal(np_array); }); } diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index e12fc16c56a49e..08c37384741ca4 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -247,16 +247,26 @@ bool ShouldFoldOperation(Operation* inst) { return size; }; - int64_t results_size = get_size(inst->getResultTypes()); - int64_t operands_size = get_size(inst->getOperandTypes()); + int64_t inputs_size = get_size(inst->getOperandTypes()); + int64_t outputs_size = get_size(inst->getResultTypes()); - constexpr int kSizeFactor = 2; - constexpr int64_t kResultsSizeThreshold = (1 << 19); // 64 KiB - constexpr int64_t kOperandsSizeThreshold = 
200L * 1024 * 1024 * 8; // 200 MiB + constexpr int64_t kInputsSizeThreshold = 200L * 1024 * 1024 * 8; // 200 MiB + constexpr int64_t kOutputsSizeThreshold = + 2 * kInputsSizeThreshold; // 400 MiB - return (operands_size <= kOperandsSizeThreshold) && - ((results_size <= kResultsSizeThreshold) || - (results_size <= kSizeFactor * operands_size)); + auto output_size_is_smaller_than_inputs = outputs_size <= inputs_size; + + auto inputs_and_outputs_smaller_than_arbitrary_thresholds = + (inputs_size <= kInputsSizeThreshold) && + (outputs_size <= kOutputsSizeThreshold); + + // Folding rules are: + // 1. if the size of the resulting outputs are smaller than the inputs then + // just do the fold. The model size will be smaller as a result. + // 2. if the inputs and outputs sizes are smaller than certain thresholds, do + // the fold regardless of their impact on model size. + return output_size_is_smaller_than_inputs || + inputs_and_outputs_smaller_than_arbitrary_thresholds; } // Returns dimension index for the given axis that supports negative @@ -4374,10 +4384,23 @@ OpFoldResult CastFloatToFloat(DenseFPElementsAttr data, FloatType in_type, return DenseFPElementsAttr::get(result_type, MapStaticCast(data)); } + + if (in_type.isF32() && out_type.isF16()) { + return data.mapValues(out_type, [&](const APFloat& old_value) { + APFloat value(old_value); + bool unused_loses_info; + value.convert(out_type.getFloatSemantics(), APFloat::rmNearestTiesToEven, + &unused_loses_info); + return value.bitcastToAPInt(); + }); + } return {}; } OpFoldResult CastOp::fold(FoldAdaptor adaptor) { + auto in_type = getInput().getType().getElementType(); + auto out_type = getType().getElementType(); + if (!ShouldFoldOperation(this->getOperation())) return {}; auto operands = adaptor.getOperands(); @@ -4390,9 +4413,6 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) { auto input = operands[0]; - auto in_type = getInput().getType().getElementType(); - auto out_type = getType().getElementType(); - if 
(auto int_in_type = llvm::dyn_cast_or_null(in_type)) { auto in_data = llvm::dyn_cast_or_null(input); if (!in_data) { @@ -4962,7 +4982,7 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl& regions) { // The `then` and the `else` region branch back to the parent operation. if (!point.isParent()) { - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } @@ -5233,6 +5253,22 @@ int64_t SoftmaxOp::GetArithmeticCount(Operation* op) { // TanhOp //===----------------------------------------------------------------------===// +OpFoldResult TanhOp::fold(FoldAdaptor adaptor) { + if (!ShouldFoldOperation(this->getOperation())) return {}; + + auto operands = adaptor.getOperands(); + Type result_type = getType(); + // Only constant fold for tensor of f32 is implemented. + if (!IsF32ShapedType(result_type)) return nullptr; + + auto compute = [](APFloat value) -> APFloat { + float f = value.convertToFloat(); + float result = std::tanh(f); + return APFloat(result); + }; + return ConstFoldUnaryOp(result_type, operands[0], compute); +} + int64_t TanhOp::GetArithmeticCount(Operation* op) { int64_t count; // As a very rough ballpark, the cost of evaluating a math function @@ -5719,6 +5755,10 @@ static FailureOr> parseI32Array(AsmParser& parser) { } // namespace TFL } // namespace mlir +using namespace mlir; // NOLINT +using mlir::TFL::ControlType; +using mlir::TFL::LSTMKernelTypeAttr; + #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_dialect.cc.inc" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_enums.cc.inc" #define GET_ATTRDEF_CLASSES diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 44370d1cfdeb96..c90859cd6accfe 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -112,6 +112,7 @@ class TFL_VariadicTensorOf allowedRuntimeTypes, Variadic>, TFL_RuntimeType>>; +def 
TFL_I2 : I<2>; def TFL_I4 : I<4>; def TFL_Int32Or64 : SignlessIntOfWidths<[32, 64]>; @@ -1099,7 +1100,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [ let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$input, - TFL_TensorOf<[F32, QI4, QI8, QUI8, QI16]>:$filter, + TFL_TensorOf<[F32, QI2, QI4, QI8, QUI8, QI16]>:$filter, TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias, TFL_AFAttr:$fused_activation_function, @@ -2476,13 +2477,13 @@ equivalent to setting: }]; let arguments = (ins - TFL_TensorOf<[F32, I32, I64, I8, UI8, UI32, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$input, + TFL_TensorOf<[F32, I32, I64, QI4, I8, UI8, UI32, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$input, TFL_I32OrI64Tensor:$begin, TFL_I32OrI64Tensor:$size ); let results = (outs - TFL_TensorOf<[F32, I32, I64, I8, UI8, UI32, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$output + TFL_TensorOf<[F32, I32, I64, QI4, I8, UI8, UI32, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$output ); let hasVerifier = 1; @@ -3574,6 +3575,8 @@ def TFL_TanhOp: TFL_Op<"tanh", [ /*scale=*/1.0 / (1<<(bit_width-1)), /*zero_point=*/0); } }]; + + let hasFolder = 1; } def TFL_TileOp: TFL_Op<"tile", [ @@ -4072,13 +4075,10 @@ def TFL_CastOp : TFL_Op<"cast", [ }]; let arguments = (ins - TFL_TensorOf<[F16, BF16, F32, F64, I1, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex>]>:$input + TFL_TensorOf<[F16, BF16, F32, F64, I1, TFL_I2, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex>]>:$input ); - // TODO(b/393644251): Temporary support for INT4 TFL_CastOp. Runtime - // probably already supports INT4. We should remove the INT4 support here or - // make sure the runtime supports is there, as part of closing the bug. 
- let results = (outs TFL_TensorOf<[F16, BF16, F32, F64, I1, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex>]>:$output); + let results = (outs TFL_TensorOf<[F16, BF16, F32, F64, I1, TFL_I2, TFL_I4, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex>]>:$output); // TFLite's cast op does not utilize CastOptions, instead derives types // from the TfLiteTensors. @@ -4281,7 +4281,7 @@ def TFL_DequantizeOp: TFL_Op<"dequantize", [NoMemoryEffect]> { quantization parameters. }]; - let arguments = (ins TFL_TensorOf<[QI4, QI8, QUI8, QI16, F16]>:$input); + let arguments = (ins TFL_TensorOf<[QI2, QI4, QI8, QUI8, QI16, F16]>:$input); let results = (outs TFL_FpTensor:$output); diff --git a/tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape_test.cc b/tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape_test.cc index a3ae7f73b24f24..b5a3319ba13362 100644 --- a/tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape_test.cc +++ b/tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape_test.cc @@ -19,9 +19,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape.h" #include -#include #include -#include #include #include diff --git a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h index aa700dc166e046..29ed664e7ae78f 100644 --- a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h @@ -31,7 +31,7 @@ namespace tensorflow { // error status if it fails to convert the input. 
absl::Status ConvertJaxToTFLiteFlatBuffer( const std::string& input, const tflite::ModelFlags& model_flags, - tflite::ConverterFlags& converter_flags, string* result); + tflite::ConverterFlags& converter_flags, std::string* result); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index fa94cd3b5b8120..c334f24442b491 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -140,8 +140,8 @@ absl::Status ConvertSavedModelToTFLiteFlatBuffer( mlir::TFL::QuantizationSpecs quant_specs; // Parse input arrays. - std::vector node_names; - std::vector node_dtypes; + std::vector node_names; + std::vector node_dtypes; std::vector>> node_shapes; std::vector> node_mins; std::vector> node_maxs; @@ -210,8 +210,6 @@ absl::Status ConvertSavedModelToTFLiteFlatBuffer( converter_flags.convert_to_stablehlo(); pass_config.legalize_custom_tensor_list_ops = converter_flags.legalize_custom_tensor_list_ops(); - pass_config.enable_stablehlo_quantizer = - converter_flags.has_quantization_config(); pass_config.enable_composite_direct_lowering = converter_flags.enable_composite_direct_lowering(); pass_config.model_origin_framework = converter_flags.model_origin_framework(); diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h index 33b9bacf2dfdeb..446652ccb8da05 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h @@ -32,7 +32,7 @@ namespace tensorflow { // error status if it fails to convert the input. 
absl::Status ConvertSavedModelToTFLiteFlatBuffer( const tflite::ModelFlags& model_flags, - tflite::ConverterFlags& converter_flags, string* result, + tflite::ConverterFlags& converter_flags, std::string* result, const quantization::PyFunctionLibrary* quantization_py_function_lib); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h index f837a6f0140e7b..de75080ab5da82 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -46,8 +46,8 @@ absl::Status RegisterAllCustomOps( absl::Status PopulateQuantizationSpecs( const tflite::ModelFlags& model_flags, tflite::ConverterFlags& converter_flags, - mlir::TFL::QuantizationSpecs* quant_specs, std::vector* node_names, - std::vector* node_dtypes, + mlir::TFL::QuantizationSpecs* quant_specs, + std::vector* node_names, std::vector* node_dtypes, std::vector>>* node_shapes, std::vector>* node_mins, std::vector>* node_maxs); @@ -60,7 +60,8 @@ absl::Status ConvertMLIRToTFLiteFlatBuffer( std::unique_ptr&& context, mlir::OwningOpRef module, const mlir::TFL::PassConfig& pass_config, - const std::unordered_set& saved_model_tags, string* result, + const std::unordered_set& saved_model_tags, + std::string* result, const quantization::PyFunctionLibrary* quantization_py_function_lib); // Give a warning for any unused flags that have been specified. diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td index 26bcf0ee0022d3..1653eb8a737482 100644 --- a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td @@ -56,6 +56,7 @@ class Int8UniformQuantizedType // General uniform quantized types. 
The definitions can be used to specify // operand's tensor types. +def QI2 : QuantizedType<"Uniform", [2], 1>; def QI4 : QuantizedType<"Uniform", [4], 1>; def QUI8 : QuantizedType<"Uniform", [8], 0>; def QI8 : QuantizedType<"Uniform", [8], 1>; diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc index ae3b6233f8e959..1e1f79af16cbd6 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc @@ -93,7 +93,7 @@ std::vector GetAsVector(const flatbuffers::Vector* vec) { class QuantizeWeightsTest : public testing::Test { protected: - QuantizeWeightsTest() {} + QuantizeWeightsTest() = default; void LoadBasicModel() { input_model_ = ReadTestModel(); diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD index 8d57263800e2b4..d70688fc488a6a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD @@ -14,10 +14,7 @@ cc_library( name = "portable_tensor_utils", srcs = ["portable_tensor_utils.cc"], hdrs = ["portable_tensor_utils.h"], - visibility = [ - "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:__pkg__", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:__pkg__", - ], + visibility = ["//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:__pkg__"], ) cc_library( diff --git a/tensorflow/compiler/mlir/lite/schema/schema.fbs b/tensorflow/compiler/mlir/lite/schema/schema.fbs index d74477be913d34..6cd1c51fb0cf9e 100644 --- a/tensorflow/compiler/mlir/lite/schema/schema.fbs +++ b/tensorflow/compiler/mlir/lite/schema/schema.fbs @@ -24,6 +24,8 @@ // Version 3c: Move constant tensor buffers & custom op buffers outside from // Flatbuffers. 
Has backward compatibility with version 3, 3a and // 3b. +// Version 3d: Add ExternalBuffer tables and tensor.external_buffer field for +// referencing immutable data stored in external files. namespace tflite; @@ -59,6 +61,7 @@ enum TensorType : byte { UINT16 = 16, INT4 = 17, BFLOAT16 = 18, + INT2 = 19, } // Custom quantization parameters for experimenting with new quantization @@ -262,6 +265,11 @@ table Tensor { // Currently only 1 subtype is supported. The field is defined as an array for // flexibility of supporting multiple subtypes in the future. variant_tensors:[VariantSubType]; + + // Optional reference to an ExternalBuffer entry that stores constant tensor + // data outside of the FlatBuffer. A value of 0 indicates that the tensor uses + // the traditional embedded buffer field instead. + external_buffer:uint; } // A list of builtin operators. Builtin operators are slightly faster than custom @@ -1612,6 +1620,22 @@ table Buffer { size: ulong; } +// Groups external buffers by file/URI. +table ExternalBufferGroup { + name:string; +} + +// Describes an immutable data slice stored in an external file. +table ExternalBuffer { + // Unique identifier for this external buffer. + id:uint; + // Index into the external_buffer_groups array. + group:uint; + offset:ulong; + length:ulong; + packing:string; +} + table Metadata { // A human readable string to uniquely identify a Metadata. name:string; @@ -1679,6 +1703,12 @@ table Model { // Optional SignatureDefs for the model. signature_defs:[SignatureDef]; + + // Optional groups for external weight buffers. + external_buffer_groups:[ExternalBufferGroup]; + + // Optional list of external weight buffers referenced by tensors. 
+ external_buffers:[ExternalBuffer]; } root_type Model; diff --git a/tensorflow/compiler/mlir/lite/schema/schema_generated.h b/tensorflow/compiler/mlir/lite/schema/schema_generated.h index 43d51c40f01b8a..2b1701a8b9c0b9 100755 --- a/tensorflow/compiler/mlir/lite/schema/schema_generated.h +++ b/tensorflow/compiler/mlir/lite/schema/schema_generated.h @@ -681,6 +681,14 @@ struct Buffer; struct BufferBuilder; struct BufferT; +struct ExternalBufferGroup; +struct ExternalBufferGroupBuilder; +struct ExternalBufferGroupT; + +struct ExternalBuffer; +struct ExternalBufferBuilder; +struct ExternalBufferT; + struct Metadata; struct MetadataBuilder; struct MetadataT; @@ -717,11 +725,12 @@ enum TensorType : int8_t { TensorType_UINT16 = 16, TensorType_INT4 = 17, TensorType_BFLOAT16 = 18, + TensorType_INT2 = 19, TensorType_MIN = TensorType_FLOAT32, - TensorType_MAX = TensorType_BFLOAT16 + TensorType_MAX = TensorType_INT2 }; -inline const TensorType (&EnumValuesTensorType())[19] { +inline const TensorType (&EnumValuesTensorType())[20] { static const TensorType values[] = { TensorType_FLOAT32, TensorType_FLOAT16, @@ -741,13 +750,14 @@ inline const TensorType (&EnumValuesTensorType())[19] { TensorType_UINT32, TensorType_UINT16, TensorType_INT4, - TensorType_BFLOAT16 + TensorType_BFLOAT16, + TensorType_INT2 }; return values; } inline const char * const *EnumNamesTensorType() { - static const char * const names[20] = { + static const char * const names[21] = { "FLOAT32", "FLOAT16", "INT32", @@ -767,13 +777,14 @@ inline const char * const *EnumNamesTensorType() { "UINT16", "INT4", "BFLOAT16", + "INT2", nullptr }; return names; } inline const char *EnumNameTensorType(TensorType e) { - if (::flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_BFLOAT16)) return ""; + if (::flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_INT2)) return ""; const size_t index = static_cast(e); return EnumNamesTensorType()[index]; } @@ -5949,6 +5960,7 @@ struct TensorT : public 
::flatbuffers::NativeTable { std::vector shape_signature{}; bool has_rank = false; std::vector> variant_tensors{}; + uint32_t external_buffer = 0; TensorT() = default; TensorT(const TensorT &o); TensorT(TensorT&&) FLATBUFFERS_NOEXCEPT = default; @@ -5968,7 +5980,8 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_SPARSITY = 16, VT_SHAPE_SIGNATURE = 18, VT_HAS_RANK = 20, - VT_VARIANT_TENSORS = 22 + VT_VARIANT_TENSORS = 22, + VT_EXTERNAL_BUFFER = 24 }; const ::flatbuffers::Vector *shape() const { return GetPointer *>(VT_SHAPE); @@ -6000,6 +6013,9 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { const ::flatbuffers::Vector<::flatbuffers::Offset> *variant_tensors() const { return GetPointer> *>(VT_VARIANT_TENSORS); } + uint32_t external_buffer() const { + return GetField(VT_EXTERNAL_BUFFER, 0); + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SHAPE) && @@ -6019,6 +6035,7 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VerifyOffset(verifier, VT_VARIANT_TENSORS) && verifier.VerifyVector(variant_tensors()) && verifier.VerifyVectorOfTables(variant_tensors()) && + VerifyField(verifier, VT_EXTERNAL_BUFFER, 4) && verifier.EndTable(); } TensorT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -6060,6 +6077,9 @@ struct TensorBuilder { void add_variant_tensors(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> variant_tensors) { fbb_.AddOffset(Tensor::VT_VARIANT_TENSORS, variant_tensors); } + void add_external_buffer(uint32_t external_buffer) { + fbb_.AddElement(Tensor::VT_EXTERNAL_BUFFER, external_buffer, 0); + } explicit TensorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -6082,8 +6102,10 @@ inline ::flatbuffers::Offset CreateTensor( ::flatbuffers::Offset sparsity = 0, ::flatbuffers::Offset<::flatbuffers::Vector> shape_signature = 0, 
bool has_rank = false, - ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> variant_tensors = 0) { + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> variant_tensors = 0, + uint32_t external_buffer = 0) { TensorBuilder builder_(_fbb); + builder_.add_external_buffer(external_buffer); builder_.add_variant_tensors(variant_tensors); builder_.add_shape_signature(shape_signature); builder_.add_sparsity(sparsity); @@ -6108,7 +6130,8 @@ inline ::flatbuffers::Offset CreateTensorDirect( ::flatbuffers::Offset sparsity = 0, const std::vector *shape_signature = nullptr, bool has_rank = false, - const std::vector<::flatbuffers::Offset> *variant_tensors = nullptr) { + const std::vector<::flatbuffers::Offset> *variant_tensors = nullptr, + uint32_t external_buffer = 0) { auto shape__ = shape ? _fbb.CreateVector(*shape) : 0; auto name__ = name ? _fbb.CreateString(name) : 0; auto shape_signature__ = shape_signature ? _fbb.CreateVector(*shape_signature) : 0; @@ -6124,7 +6147,8 @@ inline ::flatbuffers::Offset CreateTensorDirect( sparsity, shape_signature__, has_rank, - variant_tensors__); + variant_tensors__, + external_buffer); } ::flatbuffers::Offset CreateTensor(::flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -16528,6 +16552,182 @@ inline ::flatbuffers::Offset CreateBufferDirect( ::flatbuffers::Offset CreateBuffer(::flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct ExternalBufferGroupT : public ::flatbuffers::NativeTable { + typedef ExternalBufferGroup TableType; + std::string name{}; +}; + +struct ExternalBufferGroup FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ExternalBufferGroupT NativeTableType; + typedef ExternalBufferGroupBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4 + }; + const ::flatbuffers::String *name() 
const { + return GetPointer(VT_NAME); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + verifier.EndTable(); + } + ExternalBufferGroupT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ExternalBufferGroupT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ExternalBufferGroupT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ExternalBufferGroupBuilder { + typedef ExternalBufferGroup Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_name(::flatbuffers::Offset<::flatbuffers::String> name) { + fbb_.AddOffset(ExternalBufferGroup::VT_NAME, name); + } + explicit ExternalBufferGroupBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateExternalBufferGroup( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::String> name = 0) { + ExternalBufferGroupBuilder builder_(_fbb); + builder_.add_name(name); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateExternalBufferGroupDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr) { + auto name__ = name ? 
_fbb.CreateString(name) : 0; + return tflite::CreateExternalBufferGroup( + _fbb, + name__); +} + +::flatbuffers::Offset CreateExternalBufferGroup(::flatbuffers::FlatBufferBuilder &_fbb, const ExternalBufferGroupT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ExternalBufferT : public ::flatbuffers::NativeTable { + typedef ExternalBuffer TableType; + uint32_t id = 0; + uint32_t group = 0; + uint64_t offset = 0; + uint64_t length = 0; + std::string packing{}; +}; + +struct ExternalBuffer FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ExternalBufferT NativeTableType; + typedef ExternalBufferBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ID = 4, + VT_GROUP = 6, + VT_OFFSET = 8, + VT_LENGTH = 10, + VT_PACKING = 12 + }; + uint32_t id() const { + return GetField(VT_ID, 0); + } + uint32_t group() const { + return GetField(VT_GROUP, 0); + } + uint64_t offset() const { + return GetField(VT_OFFSET, 0); + } + uint64_t length() const { + return GetField(VT_LENGTH, 0); + } + const ::flatbuffers::String *packing() const { + return GetPointer(VT_PACKING); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_ID, 4) && + VerifyField(verifier, VT_GROUP, 4) && + VerifyField(verifier, VT_OFFSET, 8) && + VerifyField(verifier, VT_LENGTH, 8) && + VerifyOffset(verifier, VT_PACKING) && + verifier.VerifyString(packing()) && + verifier.EndTable(); + } + ExternalBufferT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ExternalBufferT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ExternalBufferT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ExternalBufferBuilder { + typedef ExternalBuffer Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + 
::flatbuffers::uoffset_t start_; + void add_id(uint32_t id) { + fbb_.AddElement(ExternalBuffer::VT_ID, id, 0); + } + void add_group(uint32_t group) { + fbb_.AddElement(ExternalBuffer::VT_GROUP, group, 0); + } + void add_offset(uint64_t offset) { + fbb_.AddElement(ExternalBuffer::VT_OFFSET, offset, 0); + } + void add_length(uint64_t length) { + fbb_.AddElement(ExternalBuffer::VT_LENGTH, length, 0); + } + void add_packing(::flatbuffers::Offset<::flatbuffers::String> packing) { + fbb_.AddOffset(ExternalBuffer::VT_PACKING, packing); + } + explicit ExternalBufferBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateExternalBuffer( + ::flatbuffers::FlatBufferBuilder &_fbb, + uint32_t id = 0, + uint32_t group = 0, + uint64_t offset = 0, + uint64_t length = 0, + ::flatbuffers::Offset<::flatbuffers::String> packing = 0) { + ExternalBufferBuilder builder_(_fbb); + builder_.add_length(length); + builder_.add_offset(offset); + builder_.add_packing(packing); + builder_.add_group(group); + builder_.add_id(id); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateExternalBufferDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + uint32_t id = 0, + uint32_t group = 0, + uint64_t offset = 0, + uint64_t length = 0, + const char *packing = nullptr) { + auto packing__ = packing ? 
_fbb.CreateString(packing) : 0; + return tflite::CreateExternalBuffer( + _fbb, + id, + group, + offset, + length, + packing__); +} + +::flatbuffers::Offset CreateExternalBuffer(::flatbuffers::FlatBufferBuilder &_fbb, const ExternalBufferT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct MetadataT : public ::flatbuffers::NativeTable { typedef Metadata TableType; std::string name{}; @@ -16799,6 +16999,8 @@ struct ModelT : public ::flatbuffers::NativeTable { std::vector metadata_buffer{}; std::vector> metadata{}; std::vector> signature_defs{}; + std::vector> external_buffer_groups{}; + std::vector> external_buffers{}; ModelT() = default; ModelT(const ModelT &o); ModelT(ModelT&&) FLATBUFFERS_NOEXCEPT = default; @@ -16816,7 +17018,9 @@ struct Model FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_BUFFERS = 12, VT_METADATA_BUFFER = 14, VT_METADATA = 16, - VT_SIGNATURE_DEFS = 18 + VT_SIGNATURE_DEFS = 18, + VT_EXTERNAL_BUFFER_GROUPS = 20, + VT_EXTERNAL_BUFFERS = 22 }; uint32_t version() const { return GetField(VT_VERSION, 0); @@ -16842,6 +17046,12 @@ struct Model FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { const ::flatbuffers::Vector<::flatbuffers::Offset> *signature_defs() const { return GetPointer> *>(VT_SIGNATURE_DEFS); } + const ::flatbuffers::Vector<::flatbuffers::Offset> *external_buffer_groups() const { + return GetPointer> *>(VT_EXTERNAL_BUFFER_GROUPS); + } + const ::flatbuffers::Vector<::flatbuffers::Offset> *external_buffers() const { + return GetPointer> *>(VT_EXTERNAL_BUFFERS); + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_VERSION, 4) && @@ -16864,6 +17074,12 @@ struct Model FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VerifyOffset(verifier, VT_SIGNATURE_DEFS) && verifier.VerifyVector(signature_defs()) && verifier.VerifyVectorOfTables(signature_defs()) && + VerifyOffset(verifier, VT_EXTERNAL_BUFFER_GROUPS) && + 
verifier.VerifyVector(external_buffer_groups()) && + verifier.VerifyVectorOfTables(external_buffer_groups()) && + VerifyOffset(verifier, VT_EXTERNAL_BUFFERS) && + verifier.VerifyVector(external_buffers()) && + verifier.VerifyVectorOfTables(external_buffers()) && verifier.EndTable(); } ModelT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -16899,6 +17115,12 @@ struct ModelBuilder { void add_signature_defs(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> signature_defs) { fbb_.AddOffset(Model::VT_SIGNATURE_DEFS, signature_defs); } + void add_external_buffer_groups(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> external_buffer_groups) { + fbb_.AddOffset(Model::VT_EXTERNAL_BUFFER_GROUPS, external_buffer_groups); + } + void add_external_buffers(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> external_buffers) { + fbb_.AddOffset(Model::VT_EXTERNAL_BUFFERS, external_buffers); + } explicit ModelBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -16919,8 +17141,12 @@ inline ::flatbuffers::Offset CreateModel( ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> buffers = 0, ::flatbuffers::Offset<::flatbuffers::Vector> metadata_buffer = 0, ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> metadata = 0, - ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> signature_defs = 0) { + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> signature_defs = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> external_buffer_groups = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> external_buffers = 0) { ModelBuilder builder_(_fbb); + builder_.add_external_buffers(external_buffers); + builder_.add_external_buffer_groups(external_buffer_groups); builder_.add_signature_defs(signature_defs); builder_.add_metadata(metadata); 
builder_.add_metadata_buffer(metadata_buffer); @@ -16941,7 +17167,9 @@ inline ::flatbuffers::Offset CreateModelDirect( const std::vector<::flatbuffers::Offset> *buffers = nullptr, const std::vector *metadata_buffer = nullptr, const std::vector<::flatbuffers::Offset> *metadata = nullptr, - const std::vector<::flatbuffers::Offset> *signature_defs = nullptr) { + const std::vector<::flatbuffers::Offset> *signature_defs = nullptr, + const std::vector<::flatbuffers::Offset> *external_buffer_groups = nullptr, + const std::vector<::flatbuffers::Offset> *external_buffers = nullptr) { auto operator_codes__ = operator_codes ? _fbb.CreateVector<::flatbuffers::Offset>(*operator_codes) : 0; auto subgraphs__ = subgraphs ? _fbb.CreateVector<::flatbuffers::Offset>(*subgraphs) : 0; auto description__ = description ? _fbb.CreateString(description) : 0; @@ -16949,6 +17177,8 @@ inline ::flatbuffers::Offset CreateModelDirect( auto metadata_buffer__ = metadata_buffer ? _fbb.CreateVector(*metadata_buffer) : 0; auto metadata__ = metadata ? _fbb.CreateVector<::flatbuffers::Offset>(*metadata) : 0; auto signature_defs__ = signature_defs ? _fbb.CreateVector<::flatbuffers::Offset>(*signature_defs) : 0; + auto external_buffer_groups__ = external_buffer_groups ? _fbb.CreateVector<::flatbuffers::Offset>(*external_buffer_groups) : 0; + auto external_buffers__ = external_buffers ? 
_fbb.CreateVector<::flatbuffers::Offset>(*external_buffers) : 0; return tflite::CreateModel( _fbb, version, @@ -16958,7 +17188,9 @@ inline ::flatbuffers::Offset CreateModelDirect( buffers__, metadata_buffer__, metadata__, - signature_defs__); + signature_defs__, + external_buffer_groups__, + external_buffers__); } ::flatbuffers::Offset CreateModel(::flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -17212,7 +17444,7 @@ inline void SparsityParameters::UnPackTo(SparsityParametersT *_o, const ::flatbu (void)_resolver; { auto _e = traversal_order(); if (_e) { _o->traversal_order.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->traversal_order[_i] = _e->Get(_i); } } else { _o->traversal_order.resize(0); } } { auto _e = block_map(); if (_e) { _o->block_map.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->block_map[_i] = _e->Get(_i); } } else { _o->block_map.resize(0); } } - { auto _e = dim_metadata(); if (_e) { _o->dim_metadata.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->dim_metadata[_i]) { _e->Get(_i)->UnPackTo(_o->dim_metadata[_i].get(), _resolver); } else { _o->dim_metadata[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->dim_metadata.resize(0); } } + { auto _e = dim_metadata(); if (_e) { _o->dim_metadata.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->dim_metadata[_i]) { _e->Get(_i)->UnPackTo(_o->dim_metadata[_i].get(), _resolver); } else { _o->dim_metadata[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->dim_metadata.resize(0); } } } inline ::flatbuffers::Offset SparsityParameters::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SparsityParametersT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { @@ -17274,7 +17506,8 @@ inline TensorT::TensorT(const TensorT &o) 
is_variable(o.is_variable), sparsity((o.sparsity) ? new tflite::SparsityParametersT(*o.sparsity) : nullptr), shape_signature(o.shape_signature), - has_rank(o.has_rank) { + has_rank(o.has_rank), + external_buffer(o.external_buffer) { variant_tensors.reserve(o.variant_tensors.size()); for (const auto &variant_tensors_ : o.variant_tensors) { variant_tensors.emplace_back((variant_tensors_) ? new tflite::VariantSubTypeT(*variant_tensors_) : nullptr); } } @@ -17290,6 +17523,7 @@ inline TensorT &TensorT::operator=(TensorT o) FLATBUFFERS_NOEXCEPT { std::swap(shape_signature, o.shape_signature); std::swap(has_rank, o.has_rank); std::swap(variant_tensors, o.variant_tensors); + std::swap(external_buffer, o.external_buffer); return *this; } @@ -17311,7 +17545,8 @@ inline void Tensor::UnPackTo(TensorT *_o, const ::flatbuffers::resolver_function { auto _e = sparsity(); if (_e) { if(_o->sparsity) { _e->UnPackTo(_o->sparsity.get(), _resolver); } else { _o->sparsity = std::unique_ptr(_e->UnPack(_resolver)); } } else if (_o->sparsity) { _o->sparsity.reset(); } } { auto _e = shape_signature(); if (_e) { _o->shape_signature.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->shape_signature[_i] = _e->Get(_i); } } else { _o->shape_signature.resize(0); } } { auto _e = has_rank(); _o->has_rank = _e; } - { auto _e = variant_tensors(); if (_e) { _o->variant_tensors.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->variant_tensors[_i]) { _e->Get(_i)->UnPackTo(_o->variant_tensors[_i].get(), _resolver); } else { _o->variant_tensors[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->variant_tensors.resize(0); } } + { auto _e = variant_tensors(); if (_e) { _o->variant_tensors.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->variant_tensors[_i]) { _e->Get(_i)->UnPackTo(_o->variant_tensors[_i].get(), _resolver); } else { _o->variant_tensors[_i] = 
std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->variant_tensors.resize(0); } } + { auto _e = external_buffer(); _o->external_buffer = _e; } } inline ::flatbuffers::Offset Tensor::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { @@ -17332,6 +17567,7 @@ inline ::flatbuffers::Offset CreateTensor(::flatbuffers::FlatBufferBuild auto _shape_signature = _o->shape_signature.size() ? _fbb.CreateVector(_o->shape_signature) : 0; auto _has_rank = _o->has_rank; auto _variant_tensors = _o->variant_tensors.size() ? _fbb.CreateVector<::flatbuffers::Offset> (_o->variant_tensors.size(), [](size_t i, _VectorArgs *__va) { return CreateVariantSubType(*__va->__fbb, __va->__o->variant_tensors[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _external_buffer = _o->external_buffer; return tflite::CreateTensor( _fbb, _shape, @@ -17343,7 +17579,8 @@ inline ::flatbuffers::Offset CreateTensor(::flatbuffers::FlatBufferBuild _sparsity, _shape_signature, _has_rank, - _variant_tensors); + _variant_tensors, + _external_buffer); } inline StablehloGatherOptionsT *StablehloGatherOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { @@ -21572,10 +21809,10 @@ inline SubGraphT *SubGraph::UnPack(const ::flatbuffers::resolver_function_t *_re inline void SubGraph::UnPackTo(SubGraphT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = tensors(); if (_e) { _o->tensors.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->tensors[_i]) { _e->Get(_i)->UnPackTo(_o->tensors[_i].get(), _resolver); } else { _o->tensors[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->tensors.resize(0); } } + { auto _e = tensors(); if (_e) { _o->tensors.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->tensors[_i]) { _e->Get(_i)->UnPackTo(_o->tensors[_i].get(), 
_resolver); } else { _o->tensors[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->tensors.resize(0); } } { auto _e = inputs(); if (_e) { _o->inputs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->inputs[_i] = _e->Get(_i); } } else { _o->inputs.resize(0); } } { auto _e = outputs(); if (_e) { _o->outputs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->outputs[_i] = _e->Get(_i); } } else { _o->outputs.resize(0); } } - { auto _e = operators(); if (_e) { _o->operators.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->operators[_i]) { _e->Get(_i)->UnPackTo(_o->operators[_i].get(), _resolver); } else { _o->operators[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->operators.resize(0); } } + { auto _e = operators(); if (_e) { _o->operators.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->operators[_i]) { _e->Get(_i)->UnPackTo(_o->operators[_i].get(), _resolver); } else { _o->operators[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->operators.resize(0); } } { auto _e = name(); if (_e) _o->name = _e->str(); } { auto _e = debug_metadata_index(); _o->debug_metadata_index = _e; } } @@ -21637,6 +21874,70 @@ inline ::flatbuffers::Offset CreateBuffer(::flatbuffers::FlatBufferBuild _size); } +inline ExternalBufferGroupT *ExternalBufferGroup::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ExternalBufferGroupT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ExternalBufferGroup::UnPackTo(ExternalBufferGroupT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = name(); if (_e) _o->name = _e->str(); } +} + +inline ::flatbuffers::Offset ExternalBufferGroup::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const 
ExternalBufferGroupT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateExternalBufferGroup(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateExternalBufferGroup(::flatbuffers::FlatBufferBuilder &_fbb, const ExternalBufferGroupT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ExternalBufferGroupT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); + return tflite::CreateExternalBufferGroup( + _fbb, + _name); +} + +inline ExternalBufferT *ExternalBuffer::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ExternalBufferT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ExternalBuffer::UnPackTo(ExternalBufferT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = id(); _o->id = _e; } + { auto _e = group(); _o->group = _e; } + { auto _e = offset(); _o->offset = _e; } + { auto _e = length(); _o->length = _e; } + { auto _e = packing(); if (_e) _o->packing = _e->str(); } +} + +inline ::flatbuffers::Offset ExternalBuffer::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ExternalBufferT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateExternalBuffer(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateExternalBuffer(::flatbuffers::FlatBufferBuilder &_fbb, const ExternalBufferT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ExternalBufferT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _id = _o->id; + auto _group = _o->group; + auto _offset = _o->offset; + auto 
_length = _o->length; + auto _packing = _o->packing.empty() ? 0 : _fbb.CreateString(_o->packing); + return tflite::CreateExternalBuffer( + _fbb, + _id, + _group, + _offset, + _length, + _packing); +} + inline MetadataT *Metadata::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { auto _o = std::unique_ptr(new MetadataT()); UnPackTo(_o.get(), _resolver); @@ -21721,8 +22022,8 @@ inline SignatureDefT *SignatureDef::UnPack(const ::flatbuffers::resolver_functio inline void SignatureDef::UnPackTo(SignatureDefT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = inputs(); if (_e) { _o->inputs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->inputs[_i]) { _e->Get(_i)->UnPackTo(_o->inputs[_i].get(), _resolver); } else { _o->inputs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->inputs.resize(0); } } - { auto _e = outputs(); if (_e) { _o->outputs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->outputs[_i]) { _e->Get(_i)->UnPackTo(_o->outputs[_i].get(), _resolver); } else { _o->outputs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->outputs.resize(0); } } + { auto _e = inputs(); if (_e) { _o->inputs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->inputs[_i]) { _e->Get(_i)->UnPackTo(_o->inputs[_i].get(), _resolver); } else { _o->inputs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->inputs.resize(0); } } + { auto _e = outputs(); if (_e) { _o->outputs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->outputs[_i]) { _e->Get(_i)->UnPackTo(_o->outputs[_i].get(), _resolver); } else { _o->outputs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->outputs.resize(0); } } { auto _e = signature_key(); if (_e) _o->signature_key = _e->str(); } { auto _e = 
subgraph_index(); _o->subgraph_index = _e; } } @@ -21761,6 +22062,10 @@ inline ModelT::ModelT(const ModelT &o) for (const auto &metadata_ : o.metadata) { metadata.emplace_back((metadata_) ? new tflite::MetadataT(*metadata_) : nullptr); } signature_defs.reserve(o.signature_defs.size()); for (const auto &signature_defs_ : o.signature_defs) { signature_defs.emplace_back((signature_defs_) ? new tflite::SignatureDefT(*signature_defs_) : nullptr); } + external_buffer_groups.reserve(o.external_buffer_groups.size()); + for (const auto &external_buffer_groups_ : o.external_buffer_groups) { external_buffer_groups.emplace_back((external_buffer_groups_) ? new tflite::ExternalBufferGroupT(*external_buffer_groups_) : nullptr); } + external_buffers.reserve(o.external_buffers.size()); + for (const auto &external_buffers_ : o.external_buffers) { external_buffers.emplace_back((external_buffers_) ? new tflite::ExternalBufferT(*external_buffers_) : nullptr); } } inline ModelT &ModelT::operator=(ModelT o) FLATBUFFERS_NOEXCEPT { @@ -21772,6 +22077,8 @@ inline ModelT &ModelT::operator=(ModelT o) FLATBUFFERS_NOEXCEPT { std::swap(metadata_buffer, o.metadata_buffer); std::swap(metadata, o.metadata); std::swap(signature_defs, o.signature_defs); + std::swap(external_buffer_groups, o.external_buffer_groups); + std::swap(external_buffers, o.external_buffers); return *this; } @@ -21785,13 +22092,15 @@ inline void Model::UnPackTo(ModelT *_o, const ::flatbuffers::resolver_function_t (void)_o; (void)_resolver; { auto _e = version(); _o->version = _e; } - { auto _e = operator_codes(); if (_e) { _o->operator_codes.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->operator_codes[_i]) { _e->Get(_i)->UnPackTo(_o->operator_codes[_i].get(), _resolver); } else { _o->operator_codes[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->operator_codes.resize(0); } } - { auto _e = subgraphs(); if (_e) { _o->subgraphs.resize(_e->size()); for 
(::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->subgraphs[_i]) { _e->Get(_i)->UnPackTo(_o->subgraphs[_i].get(), _resolver); } else { _o->subgraphs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->subgraphs.resize(0); } } + { auto _e = operator_codes(); if (_e) { _o->operator_codes.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->operator_codes[_i]) { _e->Get(_i)->UnPackTo(_o->operator_codes[_i].get(), _resolver); } else { _o->operator_codes[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->operator_codes.resize(0); } } + { auto _e = subgraphs(); if (_e) { _o->subgraphs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->subgraphs[_i]) { _e->Get(_i)->UnPackTo(_o->subgraphs[_i].get(), _resolver); } else { _o->subgraphs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->subgraphs.resize(0); } } { auto _e = description(); if (_e) _o->description = _e->str(); } - { auto _e = buffers(); if (_e) { _o->buffers.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->buffers[_i]) { _e->Get(_i)->UnPackTo(_o->buffers[_i].get(), _resolver); } else { _o->buffers[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->buffers.resize(0); } } + { auto _e = buffers(); if (_e) { _o->buffers.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->buffers[_i]) { _e->Get(_i)->UnPackTo(_o->buffers[_i].get(), _resolver); } else { _o->buffers[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->buffers.resize(0); } } { auto _e = metadata_buffer(); if (_e) { _o->metadata_buffer.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->metadata_buffer[_i] = _e->Get(_i); } } else { _o->metadata_buffer.resize(0); } } - { auto _e = metadata(); if (_e) { _o->metadata.resize(_e->size()); for 
(::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->metadata[_i]) { _e->Get(_i)->UnPackTo(_o->metadata[_i].get(), _resolver); } else { _o->metadata[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->metadata.resize(0); } } - { auto _e = signature_defs(); if (_e) { _o->signature_defs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->signature_defs[_i]) { _e->Get(_i)->UnPackTo(_o->signature_defs[_i].get(), _resolver); } else { _o->signature_defs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->signature_defs.resize(0); } } + { auto _e = metadata(); if (_e) { _o->metadata.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->metadata[_i]) { _e->Get(_i)->UnPackTo(_o->metadata[_i].get(), _resolver); } else { _o->metadata[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->metadata.resize(0); } } + { auto _e = signature_defs(); if (_e) { _o->signature_defs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->signature_defs[_i]) { _e->Get(_i)->UnPackTo(_o->signature_defs[_i].get(), _resolver); } else { _o->signature_defs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->signature_defs.resize(0); } } + { auto _e = external_buffer_groups(); if (_e) { _o->external_buffer_groups.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->external_buffer_groups[_i]) { _e->Get(_i)->UnPackTo(_o->external_buffer_groups[_i].get(), _resolver); } else { _o->external_buffer_groups[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->external_buffer_groups.resize(0); } } + { auto _e = external_buffers(); if (_e) { _o->external_buffers.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->external_buffers[_i]) { _e->Get(_i)->UnPackTo(_o->external_buffers[_i].get(), _resolver); } else { 
_o->external_buffers[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->external_buffers.resize(0); } } } inline ::flatbuffers::Offset Model::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { @@ -21810,6 +22119,8 @@ inline ::flatbuffers::Offset CreateModel(::flatbuffers::FlatBufferBuilder auto _metadata_buffer = _o->metadata_buffer.size() ? _fbb.CreateVector(_o->metadata_buffer) : 0; auto _metadata = _o->metadata.size() ? _fbb.CreateVector<::flatbuffers::Offset> (_o->metadata.size(), [](size_t i, _VectorArgs *__va) { return CreateMetadata(*__va->__fbb, __va->__o->metadata[i].get(), __va->__rehasher); }, &_va ) : 0; auto _signature_defs = _o->signature_defs.size() ? _fbb.CreateVector<::flatbuffers::Offset> (_o->signature_defs.size(), [](size_t i, _VectorArgs *__va) { return CreateSignatureDef(*__va->__fbb, __va->__o->signature_defs[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _external_buffer_groups = _o->external_buffer_groups.size() ? _fbb.CreateVector<::flatbuffers::Offset> (_o->external_buffer_groups.size(), [](size_t i, _VectorArgs *__va) { return CreateExternalBufferGroup(*__va->__fbb, __va->__o->external_buffer_groups[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _external_buffers = _o->external_buffers.size() ? 
_fbb.CreateVector<::flatbuffers::Offset> (_o->external_buffers.size(), [](size_t i, _VectorArgs *__va) { return CreateExternalBuffer(*__va->__fbb, __va->__o->external_buffers[i].get(), __va->__rehasher); }, &_va ) : 0; return tflite::CreateModel( _fbb, _version, @@ -21819,7 +22130,9 @@ inline ::flatbuffers::Offset CreateModel(::flatbuffers::FlatBufferBuilder _buffers, _metadata_buffer, _metadata, - _signature_defs); + _signature_defs, + _external_buffer_groups, + _external_buffers); } inline bool VerifyQuantizationDetails(::flatbuffers::Verifier &verifier, const void *obj, QuantizationDetails type) { diff --git a/tensorflow/compiler/mlir/lite/schema/schema_utils.cc b/tensorflow/compiler/mlir/lite/schema/schema_utils.cc index a173380940d600..cb61ce6243f3ad 100644 --- a/tensorflow/compiler/mlir/lite/schema/schema_utils.cc +++ b/tensorflow/compiler/mlir/lite/schema/schema_utils.cc @@ -15,8 +15,12 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" #include +#include +#include +#include #include "tensorflow/compiler/mlir/lite/kernels/internal/compatibility_macros.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" namespace tflite { @@ -59,4 +63,51 @@ BuiltinOperator GetBuiltinCode(const OperatorCodeT* op_code) { op_code->deprecated_builtin_code)); } +size_t TensorTypeGetSize(::tflite::TensorType data_type) { + switch (data_type) { + case ::tflite::TensorType_FLOAT32: + static_assert(sizeof(float) == 4, ""); + return 4; + case ::tflite::TensorType_FLOAT16: + static_assert(sizeof(int16_t) == 2, ""); + return 2; + case ::tflite::TensorType_INT32: + static_assert(sizeof(int32_t) == 4, ""); + return 4; + case ::tflite::TensorType_UINT8: + static_assert(sizeof(uint8_t) == 1, ""); + return 1; + case ::tflite::TensorType_INT64: + static_assert(sizeof(int64_t) == 8, ""); + return 8; + case ::tflite::TensorType_BOOL: + return sizeof(bool); + case ::tflite::TensorType_INT16: + static_assert(sizeof(int16_t) == 2, 
""); + return 2; + case ::tflite::TensorType_COMPLEX64: + static_assert(sizeof(std::complex) == 8, ""); + return 8; + case ::tflite::TensorType_INT8: + static_assert(sizeof(int8_t) == 1, ""); + return 1; + case ::tflite::TensorType_FLOAT64: + static_assert(sizeof(double) == 8, ""); + return 8; + case ::tflite::TensorType_COMPLEX128: + static_assert(sizeof(std::complex) == 16, ""); + return 16; + case ::tflite::TensorType_UINT64: + static_assert(sizeof(uint64_t) == 8, ""); + return 8; + case ::tflite::TensorType_UINT32: + static_assert(sizeof(uint32_t) == 4, ""); + return 4; + case ::tflite::TensorType_UINT16: + static_assert(sizeof(uint16_t) == 2, ""); + return 2; + default: + return 0; + } +} } // namespace tflite diff --git a/tensorflow/compiler/mlir/lite/schema/schema_utils.h b/tensorflow/compiler/mlir/lite/schema/schema_utils.h index 7498aa02ebe5c2..9c32680b85117f 100644 --- a/tensorflow/compiler/mlir/lite/schema/schema_utils.h +++ b/tensorflow/compiler/mlir/lite/schema/schema_utils.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_ +#include + #include "flatbuffers/flatbuffers.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" @@ -28,6 +30,11 @@ BuiltinOperator GetBuiltinCode(const OperatorCode *op_code); BuiltinOperator GetBuiltinCode(const OperatorCodeT *op_code); +// Returns the size of the given TensorType in bytes, or 0 if the TensorType is +// not supported, this function should be aligned with TfLiteTypeGetSize in +// lite/kernels/kernel_util.h. 
+size_t TensorTypeGetSize(::tflite::TensorType data_type); + } // namespace tflite #endif // TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index 43fada7b0d0b62..cd553040786c72 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -539,6 +539,7 @@ cc_library( ":passes_inc_gen", ":unfold_splat_constant_pass", "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:case", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:conv", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:custom_call", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:dot_general", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir index ae672381bacafd..9a0a185443ebc0 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir @@ -3073,6 +3073,13 @@ func.func @convert_iota_ui64() -> tensor<123xui64> { func.return %0 : tensor<123xui64> } +// CHECK-LABEL: func @no_convert_iota_ui8 +func.func @no_convert_iota_ui8() -> tensor<123xui8> { + // CHECK: "mhlo.iota" + %0 = "mhlo.iota"() <{ iota_dimension = 0 : i64 }> : () -> tensor<123xui8> + func.return %0 : tensor<123xui8> +} + // CHECK-LABEL: func @convert_avgpool_valid( // CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> { // CHECK: %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) <{data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir 
b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir index a77d02e78c1dce..1d8a63130ac1d9 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir @@ -3721,14 +3721,43 @@ func.func @dynamic_broadcast_in_dim_general_case_expand_back_dims(%arg0: tensor< // CHECK: %2 = "tfl.broadcast_to"(%1, %arg1) : (tensor, tensor<4xi32>) -> tensor +// ----- + +//===----------------------------------------------------------------------===// +// mhlo.case +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: case_func +func.func @case_func(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor) { + %0 = "mhlo.case"(%arg0) ({ + %2 = mhlo.add %arg1, %arg2 : tensor + "mhlo.return"(%2) : (tensor) -> () + }, { + %2 = mhlo.multiply %arg1, %arg1 : tensor + "mhlo.return"(%2) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0: tensor +} + +// CHECK: %[[CST:.*]] = arith.constant dense<0> : tensor +// CHECK: %[[PRED:.*]] = tfl.not_equal(%arg0, %[[CST]]) : (tensor, tensor) -> tensor +// CHECK: %[[IF:.*]] = "tfl.if"(%[[PRED]]) ({ +// CHECK: %[[MUL:.*]] = tfl.mul %arg1, %arg1 {fused_activation_function = "NONE"} : tensor +// CHECK: "tfl.yield"(%[[MUL]]) : (tensor) -> () +// CHECK: }, { +// CHECK: %[[ADD:.*]] = tfl.add %arg1, %arg2 {fused_activation_function = "NONE"} : tensor +// CHECK: "tfl.yield"(%[[ADD]]) : (tensor) -> () +// CHECK: }) : (tensor) -> tensor +// CHECK: return %[[IF]] : tensor + // ----- //===----------------------------------------------------------------------===// // mhlo.if //===----------------------------------------------------------------------===// -// CHECK-LABEL: if -func.func @if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor) { +// CHECK-LABEL: if_label +func.func @if_label(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor) { %0 = mhlo.add %arg1, %arg2 : tensor %1 = "mhlo.if"(%arg0) 
({ "mhlo.return"(%0) : (tensor) -> () diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc index 3891d0f3fe4db3..7608ff985f1eb9 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc @@ -2081,8 +2081,10 @@ class ConvertIotaOpToTfRange : public OpConversionPattern { ConversionPatternRewriter& rewriter) const final { RankedTensorType type = mlir::dyn_cast_or_null(iota_op.getType()); - // TF::RangeOp doesn't support UI16. - if (!type || type.getElementType().isUnsignedInteger(16)) return failure(); + // TF::RangeOp doesn't support UI16 and UI8. + if (!type || type.getElementType().isUnsignedInteger(16) || + type.getElementType().isUnsignedInteger(8)) + return failure(); const uint64_t dimension = iota_op.getIotaDimension(); Type element_type = type.getElementType(); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD index 9e2f1cf33f495f..16c194df28f591 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD @@ -320,6 +320,21 @@ cc_library( ], ) +cc_library( + name = "case", + srcs = ["case.cc"], + hdrs = ["case.h"], + deps = [ + ":util", + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@local_xla//xla/mlir_hlo", + ], +) + cc_library( name = "if", srcs = ["if.cc"], diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/case.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/case.cc new file mode 100644 index 
00000000000000..b50a5e7fd83195 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/case.cc @@ -0,0 +1,100 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/case.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir::odml { +namespace { + +// Legalizes mhlo.case op to tfl.if op. +// This pattern only supports mhlo.case ops with exactly two branches. +class LegalizeCaseOp : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::CaseOp case_op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final { + // mhlo.case can have N branches, but tfl.if only supports two. 
+ if (case_op.getBranches().size() != 2) { + return rewriter.notifyMatchFailure( + case_op, "can only convert mhlo.case with 2 branches"); + } + + // `mhlo.case` takes an index, `tfl.if` takes a boolean predicate. + // For a 2-branch `mhlo.case` (branch 0 and branch 1), we need to map + // the index to a boolean. + // According to the mhlo.case spec, an out-of-bounds index defaults to the + // index of the last branch, which is 1 in this case. + // So, index 0 maps to branch 0, and any other index (1, or out of bounds) + // maps to branch 1. + // This can be expressed as a predicate `index != 0` for branch 1. + + auto loc = case_op->getLoc(); + auto index = case_op.getIndex(); + auto index_type = mlir::cast(index.getType()); + + // Create a constant tensor of the same shape as the index, filled with + // zeros. + auto const_zero = arith::ConstantOp::create( + rewriter, loc, rewriter.getZeroAttr(index_type)); + + // Create the predicate `index != 0`. + auto pred_type = index_type.clone(rewriter.getI1Type()); + auto pred = mhlo::CompareOp::create( + rewriter, loc, pred_type, index, const_zero, + mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), + mhlo::ComparisonDirection::NE), + mhlo::ComparisonTypeAttr{}); // Default comparison type is fine for + // integers. + + // Create the tfl.if op. + auto tfl_if = + TFL::IfOp::create(rewriter, loc, case_op.getResultTypes(), pred); + + // Branch 1 of mhlo.case becomes the `then_region` of tfl.if. + tfl_if.getThenRegion().takeBody(case_op.getBranches()[1]); + ReplaceTerminatorWithYield(tfl_if.getThenRegion(), rewriter); + + // Branch 0 of mhlo.case becomes the `else_region` of tfl.if. 
+ tfl_if.getElseRegion().takeBody(case_op.getBranches()[0]); + ReplaceTerminatorWithYield(tfl_if.getElseRegion(), rewriter); + + rewriter.replaceOp(case_op, tfl_if.getResults()); + return success(); + } +}; + +} // namespace + +void PopulateCasePatterns(MLIRContext* context, RewritePatternSet& patterns, + ConversionTarget& target) { + patterns.add(context); + // Mark mhlo.case as dynamically legal: it's legal if it does NOT have + // exactly 2 branches, as those are the ones we want to convert. + target.addDynamicallyLegalOp( + [](mhlo::CaseOp op) { return op.getBranches().size() != 2; }); +} + +} // namespace mlir::odml diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/case.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/case.h new file mode 100644 index 00000000000000..11c470a1492630 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/case.h @@ -0,0 +1,31 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CASE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CASE_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir { +namespace odml { + +void PopulateCasePatterns(MLIRContext* context, RewritePatternSet& patterns, + ConversionTarget& target); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CASE_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc index 9518b960f17442..0c43a5c4047a64 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc @@ -38,6 +38,7 @@ limitations under the License. 
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/case.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.h" // IWYU pragma: keep @@ -479,6 +480,7 @@ void LegalizeHloToTfLitePass::runOnOperation() { PopulateWhilePatterns(context, patterns, target); PopulateGetDimensionSizePatterns(context, patterns, target); PopulateIfPatterns(context, patterns, target); + PopulateCasePatterns(context, patterns, target); PopulateLegalizeFftPatterns(context, patterns, target); PopulateCustomCallPatterns(context, patterns, target); @@ -493,7 +495,6 @@ void LegalizeHloToTfLitePass::runOnOperation() { } // namespace - // Creates an instance of the pass. 
std::unique_ptr> CreateLegalizeHloToTfLitePass() { return std::make_unique(); diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir index 6043e26cb757d8..5bc6bef17fe360 100644 --- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir +++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir @@ -261,7 +261,7 @@ func.func @mul_one_quant(%arg0: tensor<32x!quant.uniform>) -> tenso // CHECK-LABEL: @elementwise_unary_ops -func.func @elementwise_unary_ops() -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor) { +func.func @elementwise_unary_ops() -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) { %0 = arith.constant dense<-1.0> : tensor %1 = arith.constant dense<1.0> : tensor %2 = arith.constant dense<1.0> : tensor @@ -269,6 +269,7 @@ func.func @elementwise_unary_ops() -> (tensor, tensor, tensor, te %4 = arith.constant dense<4.0> : tensor %5 = arith.constant dense<4.0> : tensor %6 = arith.constant dense<2.0> : tensor + %one = arith.constant dense<1.0> : tensor // CHECK-DAG: [[cst0:%.*]] = arith.constant dense<1.000000e+00> : tensor // CHECK-DAG: [[cst1:%.*]] = arith.constant dense<0.841470957> : tensor @@ -277,7 +278,8 @@ func.func @elementwise_unary_ops() -> (tensor, tensor, tensor, te // CHECK-DAG: [[cst4:%.*]] = arith.constant dense<2.000000e+00> : tensor // CHECK-DAG: [[cst5:%.*]] = arith.constant dense<5.000000e-01> : tensor // CHECK-DAG: [[cst6:%.*]] = arith.constant dense<4.000000e+00> : tensor - // CHECK: return [[cst0]], [[cst1]], [[cst2]], [[cst3]], [[cst4]], [[cst5]], [[cst6]] + // CHECK-DAG: [[cst7:%.*]] = arith.constant dense<0.761594176> : tensor + // CHECK: return [[cst0]], [[cst1]], [[cst2]], [[cst3]], [[cst4]], [[cst5]], [[cst6]], [[cst7]] %7 = "tfl.abs"(%0) : (tensor) -> tensor %8 = "tfl.sin"(%1) : (tensor) -> tensor @@ -286,8 +288,9 @@ func.func @elementwise_unary_ops() -> (tensor, tensor, tensor, te %11 = "tfl.sqrt"(%4) : (tensor) -> tensor %12 = "tfl.rsqrt"(%5) 
: (tensor) -> tensor %13 = "tfl.square"(%6) : (tensor) -> tensor + %14 = "tfl.tanh"(%one) : (tensor) -> tensor - func.return %7, %8, %9, %10, %11, %12, %13 : tensor, tensor, tensor, tensor, tensor, tensor, tensor + func.return %7, %8, %9, %10, %11, %12, %13, %14 : tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor } // CHECK-LABEL: @max_with_neg_f32_max_val @@ -1126,6 +1129,15 @@ func.func @cast_f32_to_f64() -> tensor<4xf64> { // CHECK: %cst = arith.constant dense<[-1.000000e+00, 0.000000e+00, 1.500000e+00, 1.000000e+02]> : tensor<4xf64> +// CHECK-LABEL: @cast_f32_to_f16 +func.func @cast_f32_to_f16() -> tensor<4xf16> { + %cst = arith.constant dense<[-1.0, 0.0, 1.5, 100.0]> : tensor<4xf32> + %0 = "tfl.cast"(%cst) : (tensor<4xf32>) -> tensor<4xf16> + func.return %0 : tensor<4xf16> +} + +// CHECK: %cst = arith.constant dense<[-1.000000e+00, 0.000000e+00, 1.500000e+00, 1.000000e+02]> : tensor<4xf16> + // CHECK-LABEL: @ConstantFoldFullyConnectedSmall func.func @ConstantFoldFullyConnectedSmall() -> tensor<3xf32> { %cst_input = arith.constant dense<[2.0, 3.0]> : tensor<2xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/lower_quant_annotations.mlir b/tensorflow/compiler/mlir/lite/tests/lower_quant_annotations.mlir index 915db8f6867550..62434d956f8609 100644 --- a/tensorflow/compiler/mlir/lite/tests/lower_quant_annotations.mlir +++ b/tensorflow/compiler/mlir/lite/tests/lower_quant_annotations.mlir @@ -3,6 +3,8 @@ func.func private @XlaCallModule_quant.fake_quant.impl_0(tensor<1x28x28x3xf32>) -> tensor<1x28x28x3xf32> func.func private @XlaCallModule_quant.fake_quant.impl_5_0(tensor<2x1x1x1xf32>) -> tensor<2x1x1x1xf32> func.func private @XlaCallModule_quant.fake_quant.impl_17_0(tensor<1x30x30x2xf32>) -> tensor<1x30x30x2xf32> +func.func private @XlaCallModule_quant.fake_quant.impl_i2_0(tensor<1x4xf32>) -> tensor<1x4xf32> +func.func private @XlaCallModule_quant.fake_quant.impl_i2_1(tensor<1x4xf32>) -> tensor<1x4xf32> // CHECK-LABEL: func.func @serving_default 
func.func @serving_default(%arg0: tensor<1x28x28x3xf32>) -> (tensor<1x30x30x2xf32>) { %cst = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32> @@ -22,4 +24,15 @@ func.func @serving_default(%arg0: tensor<1x28x28x3xf32>) -> (tensor<1x30x30x2xf3 // CHECK-OFF: %[[DEQUANT2:.+]] = "tfl.dequantize"(%[[QUANT2]]) : (tensor<1x30x30x2x!quant.uniform>) -> tensor<1x30x30x2xf32> %5 = stablehlo.composite "quant.fake_quant" %4 {composite_attributes = {dtype = "i8", narrow_range = false, scale = dense<0.0180494692> : tensor<1xf32>, zero_point = dense<8> : tensor<1xi32>}, decomposition = @XlaCallModule_quant.fake_quant.impl_17_0} : (tensor<1x30x30x2xf32>) -> tensor<1x30x30x2xf32> return %5 : tensor<1x30x30x2xf32> +} + +// CHECK-LABEL: func.func @i2_test +func.func @i2_test(%arg0: tensor<1x4xf32>) -> (tensor<1x4xf32>) { + // CHECK: %[[QUANT0:.+]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x4x!quant.uniform>}> : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: %[[DEQUANT0:.+]] = "tfl.dequantize"(%[[QUANT0]]) : (tensor<1x4x!quant.uniform>) -> tensor<1x4xf32> + %0 = stablehlo.composite "quant.fake_quant" %arg0 {composite_attributes = {dtype = "i2", narrow_range = false, scale = dense<1.0> : tensor<1xf32>, zero_point = dense<0> : tensor<1xi32>}, decomposition = @XlaCallModule_quant.fake_quant.impl_i2_0} : (tensor<1x4xf32>) -> tensor<1x4xf32> + // CHECK: %[[QUANT1:.+]] = "tfl.quantize"(%[[DEQUANT0]]) <{qtype = tensor<1x4x!quant.uniform:f32:1, {1.000000e+00,2.000000e+00,3.000000e+00,4.000000e+00}>>}> : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform:f32:1, {1.000000e+00,2.000000e+00,3.000000e+00,4.000000e+00}>> + // CHECK: %[[DEQUANT1:.+]] = "tfl.dequantize"(%[[QUANT1]]) : (tensor<1x4x!quant.uniform:f32:1, {1.000000e+00,2.000000e+00,3.000000e+00,4.000000e+00}>>) -> tensor<1x4xf32> + %1 = stablehlo.composite "quant.fake_quant" %0 {composite_attributes = {dtype = "i2", narrow_range = true, quantization_dimension = 1 : i32, scale = dense<[1.0, 2.0, 3.0, 4.0]> : 
tensor<4xf32>}, decomposition = @XlaCallModule_quant.fake_quant.impl_i2_1} : (tensor<1x4xf32>) -> tensor<1x4xf32> + return %1 : tensor<1x4xf32> } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 035f210a73a7f4..063e25944da6fe 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -4802,23 +4802,6 @@ func.func @RealDivWithConstDivisor(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: return %0 : tensor<2x3xf32> } -// When the const tensor cst is very large, `1 / cst` div introduced by -// div->mul conversion may not be folded and the `1 / cst` div may trigger -// the div->mul conversion again. -// This test checks the div->mul conversion will not be done infinitively. -// -// CHECK-LABEL: @RealDivWithLargeSizeConstDivisor -func.func @RealDivWithLargeSizeConstDivisor(%arg0: tensor<1x16x4096x4096xf32>) -> tensor<1x16x4096x4096xf32> { - %cst = arith.constant dense<5.000000e+01> : tensor<1x16x4096x4096xf32> - %1 = tfl.div %arg0, %cst {fused_activation_function = "NONE"} : tensor<1x16x4096x4096xf32> - func.return %1 : tensor<1x16x4096x4096xf32> - // CHECK-NEXT: %[[CST0:.*]] = arith.constant dense<1.000000e+00> : tensor - // CHECK-NEXT: %[[CST1:.*]] = arith.constant dense<5.000000e+01> : tensor<1x16x4096x4096xf32> - // CHECK-NEXT: %[[DIV:.*]] = tfl.div(%[[CST0]], %[[CST1]]) <{fused_activation_function = "NONE"}> : (tensor, tensor<1x16x4096x4096xf32>) -> tensor<1x16x4096x4096xf32> - // CHECK-NEXT: %[[MUL:.*]] = tfl.mul %arg0, %[[DIV]] {fused_activation_function = "NONE"} : tensor<1x16x4096x4096xf32> - // CHECK-NEXT: return %[[MUL]] : tensor<1x16x4096x4096xf32> -} - //CHECK-LABEL: @PushTransposeThroughSqueezeNoDims func.func @PushTransposeThroughSqueezeNoDims(%arg0: tensor<1x1x2x3xf32>) -> (tensor<3x2xf32>) { %cst = arith.constant dense<[0, 3, 1, 2]> : tensor<4xi32> diff --git 
a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index e950a5d91b9876..2ce933112a0a43 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -323,21 +323,19 @@ absl::Status ConvertTFExecutorToStablehloFlatbuffer( // TODO: b/264218457 - Refactor the component below once StableHLO Quantizer // can run DRQ. Temporarily using TF Quantization for StableHLO DRQ. - if (!converter_flags.has_quantization_options()) { - // The default minimum number of elements a weights array must have to be - // quantized by this transformation. - const int kWeightsMinNumElementsDefault = 1024; - - quantization::QuantizationOptions quantization_options; - - quantization_options.mutable_quantization_method()->set_preset_method( - quantization::QuantizationMethod::METHOD_DYNAMIC_RANGE_INT8); - quantization_options.set_op_set(quantization::UNIFORM_QUANTIZED); - quantization_options.set_min_num_elements_for_weights( - kWeightsMinNumElementsDefault); - quantization::AddQuantizePtqDynamicRangePasses(pass_manager, - quantization_options); - } + // The default minimum number of elements a weights array must have to be + // quantized by this transformation. 
+ const int kWeightsMinNumElementsDefault = 1024; + + quantization::QuantizationOptions quantization_options; + + quantization_options.mutable_quantization_method()->set_preset_method( + quantization::QuantizationMethod::METHOD_DYNAMIC_RANGE_INT8); + quantization_options.set_op_set(quantization::UNIFORM_QUANTIZED); + quantization_options.set_min_num_elements_for_weights( + kWeightsMinNumElementsDefault); + quantization::AddQuantizePtqDynamicRangePasses(pass_manager, + quantization_options); if (failed(pass_manager.run(module))) { return status_handler.ConsumeStatus(); } @@ -350,10 +348,6 @@ absl::Status ConvertTFExecutorToStablehloFlatbuffer( pass_manager.addPass(mlir::odml::createPrintOpStatsPass( mlir::odml::GetAcceptedStableHLODialects())); mlir::odml::AddStablehloOptimizationPasses(pass_manager); - if (converter_flags.has_quantization_options()) { - stablehlo::quantization::AddQuantizationPasses( - pass_manager, converter_flags.quantization_options()); - } if (failed(pass_manager.run(module))) { return status_handler.ConsumeStatus(); } diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/op_signature.cc b/tensorflow/compiler/mlir/lite/tools/versioning/op_signature.cc index f8438cd2231ad5..77b8fed82ab939 100644 --- a/tensorflow/compiler/mlir/lite/tools/versioning/op_signature.cc +++ b/tensorflow/compiler/mlir/lite/tools/versioning/op_signature.cc @@ -73,7 +73,7 @@ std::vector GetOpSignatureTensorSpecs( // Check if the tensor is a constant tensor. 
if (buffer_idx != 0 && buffer_idx < model->buffers()->size()) { auto* buffer = model->buffers()->Get(buffer_idx); - if (buffer->data() && buffer->data()->size() != 0) { + if (buffer->data() && !buffer->data()->empty()) { tensor_spec.is_const = true; } } @@ -143,8 +143,8 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, const QuantizationParameters* weight_quant = weight_tensor->quantization(); if (weight_quant && weight_quant->scale() && - weight_quant->scale()->size() && weight_tensor->shape() && - weight_tensor->shape()->size()) { + !weight_quant->scale()->empty() && weight_tensor->shape() && + !weight_tensor->shape()->empty()) { op_sig.ext_options.fully_connected.is_per_channel_quantized = IsTensorSizeEqual(weight_quant->scale()->size(), weight_tensor->shape()->Get(0)); @@ -152,7 +152,7 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, } break; case BuiltinOperator_MUL: { - if (op->inputs()->size() < 2 || op->outputs()->size() < 1) { + if (op->inputs()->size() < 2 || op->outputs()->empty()) { break; } const Tensor* input1_tensor = @@ -167,10 +167,10 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, const QuantizationParameters* output_quant = output_tensor->quantization(); if (input1_quant && input1_quant->scale() && - input1_quant->scale()->size() && input2_qunt && - input2_qunt->scale() && input2_qunt->scale()->size() && + !input1_quant->scale()->empty() && input2_qunt && + input2_qunt->scale() && !input2_qunt->scale()->empty() && output_quant && output_quant->scale() && - output_quant->scale()->size()) { + !output_quant->scale()->empty()) { op_sig.ext_options.mul.input1_scale = input1_quant->scale()->Get(0); op_sig.ext_options.mul.input2_scale = input2_qunt->scale()->Get(0); op_sig.ext_options.mul.output_scale = output_quant->scale()->Get(0); @@ -192,7 +192,7 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, filter_quant->scale()->size() == 
static_cast(num_filters)) { op_sig.ext_options.conv_2d.is_per_channel_quantized = true; } - if (input_tensor->shape() && input_tensor->shape()->size()) { + if (input_tensor->shape() && !input_tensor->shape()->empty()) { int num_input_channels = input_tensor->shape()->Get(3); int num_filter_input_channels = filter_tensor->shape()->Get(3); op_sig.ext_options.conv_2d.is_grouped_convolution = @@ -249,8 +249,9 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, const Tensor* table_tensor = subgraph->tensors()->Get(op->inputs()->Get(1)); const QuantizationParameters* table_quant = table_tensor->quantization(); - if (table_quant && table_quant->scale() && table_quant->scale()->size() && - table_tensor->shape() && table_tensor->shape()->size()) { + if (table_quant && table_quant->scale() && + !table_quant->scale()->empty() && table_tensor->shape() && + !table_tensor->shape()->empty()) { op_sig.ext_options.embedding_lookup.is_per_channel_quantized = table_quant->scale()->size() > 1 && IsTensorSizeEqual(table_quant->scale()->size(), diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/op_version.cc b/tensorflow/compiler/mlir/lite/tools/versioning/op_version.cc index 30c564c41c503c..9ccda1d0c95e69 100644 --- a/tensorflow/compiler/mlir/lite/tools/versioning/op_version.cc +++ b/tensorflow/compiler/mlir/lite/tools/versioning/op_version.cc @@ -177,6 +177,10 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { reinterpret_cast(op_sig.builtin_data); TFLITE_DCHECK(fully_connected_params != nullptr); + if (op_sig.inputs.at(1).type == kTfLiteInt2) { + return 14; + } + if (op_sig.inputs.at(0).type == kTfLiteInt16 && op_sig.inputs.at(1).type == kTfLiteInt4 && op_sig.outputs.at(0).type == kTfLiteInt16) { @@ -464,6 +468,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { return 1; case BuiltinOperator_SLICE: + if (op_sig.inputs.at(0).type == kTfLiteInt4) { + return 7; + } if (op_sig.inputs.at(0).type == kTfLiteUInt32) { return 6; } @@ 
-473,7 +480,6 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { if (op_sig.inputs.at(0).type == kTfLiteInt16) { return 4; } - // Version 3 supports string input types. if (op_sig.inputs.at(0).type == kTfLiteString) { return 3; } @@ -499,6 +505,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { return 1; case BuiltinOperator_DEQUANTIZE: + if (op_sig.inputs.at(0).type == kTfLiteInt2) { + return 7; + } if (op_sig.inputs.at(0).type == kTfLiteInt4) { return 6; } @@ -1073,8 +1082,11 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { } return 2; case BuiltinOperator_CAST: - if (op_sig.inputs.at(0).type == kTfLiteBFloat16 || - op_sig.outputs.at(0).type == kTfLiteBFloat16) { + if (op_sig.inputs.at(0).type == kTfLiteInt2 || + op_sig.outputs.at(0).type == kTfLiteInt2) { + return 8; + } else if (op_sig.inputs.at(0).type == kTfLiteBFloat16 || + op_sig.outputs.at(0).type == kTfLiteBFloat16) { return 7; } else if (op_sig.inputs.at(0).type == kTfLiteInt4 && op_sig.outputs.at(0).type == kTfLiteFloat32) { diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/op_version_test.cc b/tensorflow/compiler/mlir/lite/tools/versioning/op_version_test.cc index aaf335682a0358..87313665d1811f 100644 --- a/tensorflow/compiler/mlir/lite/tools/versioning/op_version_test.cc +++ b/tensorflow/compiler/mlir/lite/tools/versioning/op_version_test.cc @@ -733,6 +733,15 @@ TEST(OpVersionTest, VersioningFullyConnectedTest) { }; fake_op_sig.ext_options.fully_connected.is_per_channel_quantized = true; EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 12); + + fake_op_sig = { + .op = BuiltinOperator_FULLY_CONNECTED, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt8, kTfLiteInt2}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + .builtin_data = reinterpret_cast(&fully_connected_params), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 14); } TEST(OpVersionTest, VersioningDequantizeTest) { @@ -757,6 +766,12 @@ TEST(OpVersionTest, 
VersioningDequantizeTest) { fake_op_sig.ext_options.dequantize.is_per_channel_quantized = true; EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); + fake_op_sig = { + .op = BuiltinOperator_DEQUANTIZE, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt2), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 7); + fake_op_sig = { .op = BuiltinOperator_DEQUANTIZE, .inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), @@ -1467,4 +1482,72 @@ TEST(OpVersionTest, VersioningSqrtTest) { fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt16); EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); } + +TEST(OpVersionTest, VersioningCastTest) { + OpSignature fake_op_sig = {}; + fake_op_sig.op = BuiltinOperator_CAST; + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt2); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 8); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt2); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 8); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteBFloat16); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 7); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteBFloat16); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 7); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt4); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat64); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + fake_op_sig.outputs = 
CreateOpSignatureTensorSpecs(kTfLiteFloat64); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat16); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteFloat16); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt16); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt16); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteUInt32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); + + fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + fake_op_sig.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt32); + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1); +} } // namespace tflite diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.cc 
b/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.cc index 4f4dc835c91d6c..aca1b463878966 100644 --- a/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.cc @@ -112,6 +112,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_CAST, 5}, "2.12.0"}, {{BuiltinOperator_CAST, 6}, "2.15.0"}, {{BuiltinOperator_CAST, 7}, "2.17.0"}, + {{BuiltinOperator_CAST, 8}, "2.21.0"}, {{BuiltinOperator_CONCATENATION, 1}, "1.5.0"}, {{BuiltinOperator_CONCATENATION, 2}, "1.14.0"}, {{BuiltinOperator_CONCATENATION, 3}, "2.3.0"}, @@ -138,6 +139,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_FULLY_CONNECTED, 11}, "2.15.0"}, {{BuiltinOperator_FULLY_CONNECTED, 12}, "2.17.0"}, {{BuiltinOperator_FULLY_CONNECTED, 13}, "2.18.0"}, + {{BuiltinOperator_FULLY_CONNECTED, 14}, "2.21.0"}, {{BuiltinOperator_GATHER, 1}, "1.6.0"}, {{BuiltinOperator_GATHER, 2}, "1.14.0"}, {{BuiltinOperator_GATHER, 3}, "1.15.0"}, @@ -293,6 +295,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_SLICE, 4}, "2.4.0"}, {{BuiltinOperator_SLICE, 5}, "2.5.0"}, {{BuiltinOperator_SLICE, 6}, "2.14.0"}, + {{BuiltinOperator_SLICE, 7}, "2.21.0"}, {{BuiltinOperator_TANH, 1}, "1.14.0"}, {{BuiltinOperator_TANH, 2}, "1.14.0"}, {{BuiltinOperator_TANH, 3}, "2.3.0"}, @@ -325,6 +328,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_DEQUANTIZE, 4}, "2.2.0"}, {{BuiltinOperator_DEQUANTIZE, 5}, "2.7.0"}, {{BuiltinOperator_DEQUANTIZE, 6}, "2.18.0"}, + {{BuiltinOperator_DEQUANTIZE, 7}, "2.21.0"}, {{BuiltinOperator_REVERSE_SEQUENCE, 1}, "1.14.0"}, {{BuiltinOperator_EQUAL, 1}, "1.14.0"}, {{BuiltinOperator_EQUAL, 2}, "1.14.0"}, diff --git a/tensorflow/compiler/mlir/lite/transforms/decompose_hybrid_quantization.cc 
b/tensorflow/compiler/mlir/lite/transforms/decompose_hybrid_quantization.cc index 4886f09dd5c4bc..6b92b5f63ee66f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/decompose_hybrid_quantization.cc +++ b/tensorflow/compiler/mlir/lite/transforms/decompose_hybrid_quantization.cc @@ -49,7 +49,7 @@ class DecomposeHybridQuantizationPass : public impl::DecomposeHybridQuantizationPassBase< DecomposeHybridQuantizationPass> { public: - explicit DecomposeHybridQuantizationPass() {} + explicit DecomposeHybridQuantizationPass() = default; void runOnOperation() override; }; diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index 77b112d28a5098..0564dd56961b35 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -54,7 +54,7 @@ namespace { class DefaultQuantParamsPass : public impl::DefaultQuantParamsPassBase { public: - DefaultQuantParamsPass() {} + DefaultQuantParamsPass() = default; explicit DefaultQuantParamsPass(double default_min, double default_max, bool is_signed) { diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_helper.cc b/tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_helper.cc index 2959f6764354d0..6caa2107799844 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_helper.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_helper.cc @@ -71,12 +71,15 @@ LogicalResult FillCompositeParams(stablehlo::CompositeOp op, return failure(); } std::string dtype = dtype_attr.getValue().str(); - if (dtype == "i8") { - num_bits = 8; + if (dtype == "i2") { + num_bits = 2; is_signed = true; } else if (dtype == "i4") { num_bits = 4; is_signed = true; + } else if (dtype == "i8") { + num_bits = 8; + is_signed = true; } else { return failure(); } @@ -110,7 +113,16 @@ LogicalResult 
GetStorageParams(unsigned num_bits, bool narrow_range, bool is_signed, MLIRContext* ctx, Type& storage_type, int64_t& qmin, int64_t& qmax) { - if (num_bits <= 4) { + if (num_bits == 2) { + storage_type = IntegerType::get(ctx, 2); + if (is_signed) { + qmin = -2; + qmax = 1; + } else { + qmin = 0; + qmax = 3; + } + } else if (num_bits <= 4) { storage_type = IntegerType::get(ctx, 4); if (is_signed) { qmin = -8; diff --git a/tensorflow/compiler/mlir/lite/transforms/modify_io_nodes.cc b/tensorflow/compiler/mlir/lite/transforms/modify_io_nodes.cc index feec5f23ca015b..29bb4e7134b598 100644 --- a/tensorflow/compiler/mlir/lite/transforms/modify_io_nodes.cc +++ b/tensorflow/compiler/mlir/lite/transforms/modify_io_nodes.cc @@ -42,7 +42,7 @@ struct ModifyIONodesPass public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ModifyIONodesPass) - explicit ModifyIONodesPass() {} + explicit ModifyIONodesPass() = default; explicit ModifyIONodesPass(mlir::Type input_type, mlir::Type output_type) { this->input_type = input_type; this->output_type = output_type; diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 3e9cc005dafe01..c3d28495a31fde 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -2155,18 +2155,13 @@ def ReorderGatherAndCast : Pat< // Replace division by a constant with a multiplication by a reciprocal of that // constant. Floating point division can be ~10x more expensive than a // multiplication. -// Only do the replacement when arg0 is not a constant, otherwise the newly -// generated div will be converted to mul again if the const div is not -// folded (that could happen when const tensor is very large), and that will -// cause infinite recursion. 
def RealDivWithF32ConstDivisor : Pat< (TFL_DivOp:$src $arg0, (Arith_ConstantOp FloatElementsAttr<32>:$value), $activation), (TFL_MulOp:$dest1 $arg0, (TFL_DivOp (Arith_ConstantOp (GetScalarOfType<1> (Arith_ConstantOp $value))), (Arith_ConstantOp $value), TFL_AF_None), - $activation), - [(NotConstantLike $arg0)]>; + $activation)>; // Replace casting a boolean tensor to a numeric type, followed by comparing // with zero. Note it doesn't matter what type we're casting to. HasSameType diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index 402674f7cbcf95..81a1a4e286f174 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -170,7 +170,7 @@ class PrepareCompositeFunctionsPass public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrepareCompositeFunctionsPass) - explicit PrepareCompositeFunctionsPass() {} + explicit PrepareCompositeFunctionsPass() = default; private: // TODO(b/160915525): Consolidate FuncAttr and StringAttr into one. diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc b/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc index 96412f20633f6a..7453ed54975a5a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc @@ -43,7 +43,7 @@ limitations under the License. 
namespace mlir { namespace TFL { namespace { -#define GEN_PASS_CLASSES +#define GEN_PASS_DEF_QUANTIZEVARIABLESPASS #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc" using ResourceIdMap = @@ -80,7 +80,7 @@ Type GetDequantizedTypeFromAssigneVariableOp(VarHandleOp var_handle_op) { } class QuantizeVariablesPass - : public QuantizeVariablesPassBase { + : public impl::QuantizeVariablesPassBase { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QuantizeVariablesPass) explicit QuantizeVariablesPass() = default; diff --git a/tensorflow/compiler/mlir/lite/transforms/raise_custom_ops.cc b/tensorflow/compiler/mlir/lite/transforms/raise_custom_ops.cc index 25d9b15fec858a..80e0986209e8d0 100644 --- a/tensorflow/compiler/mlir/lite/transforms/raise_custom_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/raise_custom_ops.cc @@ -42,7 +42,7 @@ struct RaiseCustomOpsPass public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RaiseCustomOpsPass) - explicit RaiseCustomOpsPass() {} + explicit RaiseCustomOpsPass() = default; explicit RaiseCustomOpsPass(const std::vector &target_ops) { this->target_ops_ = target_ops; } diff --git a/tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc b/tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc index a814f35a385c29..c735517dd2f1f3 100644 --- a/tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc +++ b/tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc @@ -31,7 +31,7 @@ class RuntimeVerifyPass public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RuntimeVerifyPass) - explicit RuntimeVerifyPass() {} + explicit RuntimeVerifyPass() = default; private: void runOnOperation() override; diff --git a/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/while_loop_outline_pass.cc b/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/while_loop_outline_pass.cc index a8ef6ac3b0d711..29576e8e06676a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/while_loop_outline_pass.cc +++ 
b/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/while_loop_outline_pass.cc @@ -59,10 +59,10 @@ bool IsCompatibleTypeWithTFLCastOp(Type type) { elemType.isF64()) return true; - // I1, I4, I8, I16, I32, I64 types are allowed. - if (elemType.isInteger(1) || elemType.isInteger(4) || elemType.isInteger(8) || - elemType.isInteger(16) || elemType.isInteger(32) || - elemType.isInteger(64)) + // I1, I2, I4, I8, I16, I32, I64 types are allowed. + if (elemType.isInteger(1) || elemType.isInteger(2) || elemType.isInteger(4) || + elemType.isInteger(8) || elemType.isInteger(16) || + elemType.isInteger(32) || elemType.isInteger(64)) return true; // Complex> is allowed. diff --git a/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc index f5699502eb134f..f88fc74b017555 100644 --- a/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc @@ -44,7 +44,7 @@ class TrimFunctionsPass public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TrimFunctionsPass) - explicit TrimFunctionsPass() {} + explicit TrimFunctionsPass() = default; explicit TrimFunctionsPass(llvm::ArrayRef trim_funcs_allowlist) { this->trim_funcs_allowlist_ = trim_funcs_allowlist; } diff --git a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc index d464edc4078618..1b82ca5b0e61dc 100644 --- a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc @@ -67,14 +67,14 @@ template llvm::SmallVector ReadAsHostEndian(ArrayRef bytes) { llvm::SmallVector ret; size_t read_size = sizeof(T); - int bytes_len = bytes.size(); + size_t bytes_len = bytes.size(); assert(bytes_len % read_size == 0); - int elem_count = bytes_len / read_size; + size_t elem_count = bytes_len / read_size; ret.reserve(elem_count); const char* data_ptr = 
reinterpret_cast(bytes.data()); - for (int i = 0; i < elem_count; i++) { + for (size_t i = 0; i < elem_count; i++) { T val = llvm::support::endian::readNext(data_ptr); ret.push_back(mlir::APInt(sizeof(T) * 8, val)); @@ -301,9 +301,17 @@ StatusOr ConvertIntBuffer( return mlir::ElementsAttr( DenseElementsAttr::get(shaped_type, ArrayRef(boolValues))); } + case 2: { + auto i2Values = tflite::UnpackDenseLowBitIntoInt8( + buffer, shaped_type.getNumElements(), /*bit_width=*/2); + // Use `getFromRawBuffer()` instead of `get()` to bypass a templated size + // check which doesn't work with int2 because int2_t doesn't exist. + return mlir::ElementsAttr(DenseElementsAttr::getFromRawBuffer( + shaped_type, ArrayRef(i2Values))); + } case 4: { - auto i4Values = - tflite::UnpackDenseInt4IntoInt8(buffer, shaped_type.getNumElements()); + auto i4Values = tflite::UnpackDenseLowBitIntoInt8( + buffer, shaped_type.getNumElements(), /*bit_width=*/4); // Use `getFromRawBuffer()` instead of `get()` to bypass a templated size // check which doesn't work with int4 because int4_t doesn't exist. return mlir::ElementsAttr(DenseElementsAttr::getFromRawBuffer( @@ -354,7 +362,7 @@ StatusOr ConvertFloatBuffer( assert(bytes_len % 2 == 0); // Supports both BF16 and F16. 
assert(elem_type.isF16() || elem_type.isBF16()); - int elem_count = bytes_len / 2; + size_t elem_count = bytes_len / 2; if (elem_type.isF16()) { std::vector values; @@ -362,7 +370,7 @@ StatusOr ConvertFloatBuffer( const char* data = reinterpret_cast(buffer.data()); - for (int i = 0; i < elem_count; i++) { + for (size_t i = 0; i < elem_count; i++) { uint16_t bit_repr = llvm::support::endian::readNext< uint16_t, llvm::endianness::native, llvm::support::unaligned>( data); @@ -377,7 +385,7 @@ StatusOr ConvertFloatBuffer( const char* data = reinterpret_cast(buffer.data()); - for (int i = 0; i < elem_count; i++) { + for (size_t i = 0; i < elem_count; i++) { uint16_t bit_repr = llvm::support::endian::readNext< uint16_t, llvm::endianness::native, llvm::support::unaligned>( data); @@ -390,13 +398,13 @@ StatusOr ConvertFloatBuffer( } case 32: { assert(bytes_len % 4 == 0); - int elem_count = bytes_len / 4; + size_t elem_count = bytes_len / 4; std::vector values; values.reserve(elem_count); const char* data = reinterpret_cast(buffer.data()); - for (int i = 0; i < elem_count; i++) { + for (size_t i = 0; i < elem_count; i++) { uint32_t bit_repr = llvm::support::endian::readNext(data); @@ -407,13 +415,13 @@ StatusOr ConvertFloatBuffer( } case 64: { assert(bytes_len % 8 == 0); - int elem_count = bytes_len / 8; + size_t elem_count = bytes_len / 8; std::vector values; values.reserve(elem_count); const char* data = reinterpret_cast(buffer.data()); - for (int i = 0; i < elem_count; i++) { + for (size_t i = 0; i < elem_count; i++) { uint64_t bit_repr = llvm::support::endian::readNext(data); diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.cc b/tensorflow/compiler/mlir/lite/utils/convert_type.cc index b118bab483048a..d774055fd2928a 100644 --- a/tensorflow/compiler/mlir/lite/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.cc @@ -114,6 +114,8 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) { return 
mlir::ComplexType::get(builder.getF32Type()); case tflite::TensorType_COMPLEX128: return mlir::ComplexType::get(builder.getF64Type()); + case tflite::TensorType_INT2: + return builder.getIntegerType(2); case tflite::TensorType_INT4: return builder.getIntegerType(4); case tflite::TensorType_INT8: @@ -143,7 +145,9 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) { return tensorflow::DT_FLOAT; case tflite::TensorType_FLOAT64: return tensorflow::DT_DOUBLE; - // TODO(b/246806634): Tensorflow DT_INT4 type doesn't exist yet + // TODO(b/246806634): Tensorflow DT_INT2/4 type doesn't exist yet + case tflite::TensorType_INT2: + return tensorflow::DT_INT8; case tflite::TensorType_INT4: return tensorflow::DT_INT8; case tflite::TensorType_INT8: diff --git a/tensorflow/compiler/mlir/lite/utils/low_bit_utils.cc b/tensorflow/compiler/mlir/lite/utils/low_bit_utils.cc index aa2e9697595b89..d0710f8b4d49d8 100644 --- a/tensorflow/compiler/mlir/lite/utils/low_bit_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/low_bit_utils.cc @@ -21,39 +21,41 @@ limitations under the License. 
namespace tflite { -std::vector PackInt4ValuesDensely(std::vector src_buffer) { +std::vector PackLowBitValuesDensely(std::vector src_buffer, + int bit_width) { auto num_elements = src_buffer.size(); - auto packed_size = (num_elements + 1) / 2; - std::vector packed_buffer((num_elements + 1) / 2); + const int elements_per_byte = 8 / bit_width; + auto packed_size = (num_elements + elements_per_byte - 1) / elements_per_byte; + std::vector packed_buffer(packed_size, 0); + const uint8_t mask = (1 << bit_width) - 1; - for (int i = 0; i < num_elements - 1; i += 2) { - packed_buffer[i / 2] = src_buffer[i] & 0x0F; - packed_buffer[i / 2] |= src_buffer[i + 1] << 4; - } - - // Copy the final nibble if the buffer is odd-lengthed - if (num_elements % 2 != 0) { - packed_buffer[packed_size - 1] = src_buffer[num_elements - 1] & 0x0F; + for (int i = 0; i < num_elements; ++i) { + int byte_index = i / elements_per_byte; + int bit_offset = (i % elements_per_byte) * bit_width; + packed_buffer[byte_index] |= (src_buffer[i] & mask) << bit_offset; } return packed_buffer; } -std::vector UnpackDenseInt4IntoInt8( - const std::vector& src_buffer, int64_t num_elements) { +std::vector UnpackDenseLowBitIntoInt8( + const std::vector& src_buffer, int64_t num_elements, + int bit_width) { std::vector unpacked_buffer; unpacked_buffer.reserve(num_elements); + const int elements_per_byte = 8 / bit_width; + const int sign_bit_shift = 8 - bit_width; for (uint8_t value : src_buffer) { - // Cast to signed before right-shifting to ensure correct sign extension - unpacked_buffer.push_back(static_cast(value << 4) >> 4); - unpacked_buffer.push_back(static_cast(value) >> 4); - } - - // The last element might be a padded zero, so check and pop if needed - if (unpacked_buffer.size() > num_elements) { - assert(unpacked_buffer.size() == num_elements + 1); - unpacked_buffer.pop_back(); + for (int i = 0; i < elements_per_byte; ++i) { + if (unpacked_buffer.size() == num_elements) break; + int bit_offset = i * bit_width; 
+ uint8_t extracted_value = (value >> bit_offset); + // Sign extend + unpacked_buffer.push_back( + static_cast(extracted_value << sign_bit_shift) >> + sign_bit_shift); + } } return unpacked_buffer; diff --git a/tensorflow/compiler/mlir/lite/utils/low_bit_utils.h b/tensorflow/compiler/mlir/lite/utils/low_bit_utils.h index fa9bd851eab284..f0633410a45c66 100644 --- a/tensorflow/compiler/mlir/lite/utils/low_bit_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/low_bit_utils.h @@ -20,17 +20,18 @@ limitations under the License. #include namespace tflite { -// Assumes that `src_tensor` is a buffer where each element is a 4-bit value -// stored in 8-bit. -// Returns a new buffer that is packed densely with 2 4-bit values in a byte. -// The packing format is low-bits-first, i.e. the lower nibble of a byte is -// filled first, followed by the upper nibble. -std::vector PackInt4ValuesDensely(std::vector src_buffer); - -// Assumes `src_buffer` contains 2 4-bit elements packed in 8-bit. -// Returns a vector where each int8 element contains a int4 sign-extended value. -std::vector UnpackDenseInt4IntoInt8( - const std::vector& src_buffer, int64_t num_elements); +// Assumes that `src_tensor` is a buffer where each element is a low bit value +// (e.g. 2 or 4-bit) stored in 8-bit. +// Returns a new buffer that is packed densely. +// The packing format is low-bits-first. +std::vector PackLowBitValuesDensely(std::vector src_buffer, + int bit_width); + +// Assumes `src_buffer` contains densely packed low bit elements. +// Returns a vector where each int8 element contains a sign-extended value. 
+std::vector UnpackDenseLowBitIntoInt8( + const std::vector& src_buffer, int64_t num_elements, + int bit_width); } // namespace tflite #endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LOW_BIT_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc index 2acb4dccb88a18..0ae1247e2a156a 100644 --- a/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc @@ -43,13 +43,13 @@ void Register(const std::string& op_name, OpRegistry* registry) { } // namespace TEST(TfTextUtilsTest, TestTfTextRegistered) { - std::unique_ptr registry(new OpRegistry); + std::unique_ptr registry = std::make_unique(); Register("WhitespaceTokenizeWithOffsets", registry.get()); EXPECT_TRUE(IsTFTextRegistered(registry.get())); } TEST(TfTextUtilsTest, TestTfTextNotRegistered) { - std::unique_ptr registry(new OpRegistry); + std::unique_ptr registry = std::make_unique(); Register("Test", registry.get()); EXPECT_FALSE(IsTFTextRegistered(registry.get())); } diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc index 9d7e689f3b6a3c..0c6a636d38b822 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc @@ -124,7 +124,7 @@ class ModifyMlirModulePass : public MlirOptimizationPass { }; FunctionDef XTimesTwo() { - const Tensor kTwo = test::AsScalar(2); + const Tensor kTwo = test::AsScalar(2); return FunctionDefHelper::Define( // Name "XTimesTwo", diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index 5eaf5d736262ca..4f2384347a7802 100644 --- a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -251,7 +251,7 @@ std::string ExperimentalConvertSavedModelToMlir( // Convert the SavedModelV2Bundle to an MLIR module. 
- std::vector exported_names = + std::vector exported_names = absl::StrSplit(exported_names_str, ',', absl::SkipEmpty()); mlir::DialectRegistry registry; mlir::func::registerAllExtensions(registry); @@ -270,10 +270,10 @@ std::string ExperimentalConvertSavedModelV1ToMlirLite( const std::string& saved_model_path, const std::string& exported_names_str, const std::string& tags, bool upgrade_legacy, bool show_debug_info, TF_Status* status) { - std::unordered_set tag_set = + std::unordered_set tag_set = absl::StrSplit(tags, ',', absl::SkipEmpty()); - std::vector exported_names = + std::vector exported_names = absl::StrSplit(exported_names_str, ',', absl::SkipEmpty()); mlir::DialectRegistry registry; mlir::func::registerAllExtensions(registry); @@ -299,7 +299,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir( bool show_debug_info, TF_Status* status) { // Load the saved model into a SavedModelBundle. - std::unordered_set tag_set = + std::unordered_set tag_set = absl::StrSplit(tags, ',', absl::SkipEmpty()); tensorflow::SavedModelBundle bundle; @@ -311,7 +311,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir( } // Convert the SavedModelBundle to an MLIR module. 
- std::vector exported_names = + std::vector exported_names = absl::StrSplit(exported_names_str, ',', absl::SkipEmpty()); mlir::DialectRegistry registry; mlir::func::registerAllExtensions(registry); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 5c0de51a4f059a..969e84996acb4d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -125,6 +125,7 @@ cc_library( "@local_tsl//tsl/platform:regexp", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:mhlo_passes", + "@shardy//shardy/dialect/sdy/ir:register", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_passes", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset_test.cc index 9a82ea7194614e..5d6d36ed3a6c7d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset_test.cc @@ -36,8 +36,6 @@ using ::testing::HasSubstr; using ::testing::Key; using ::testing::SizeIs; using ::testing::StrEq; -using ::tsl::testing::IsOk; -using ::tsl::testing::StatusIs; TEST(CreateRepresentativeDatasetFileMapTest, ConfigWithoutExplicitSignatureKeyMappedToServingDefault) { @@ -52,7 +50,7 @@ TEST(CreateRepresentativeDatasetFileMapTest, representative_dataset_file_map = CreateRepresentativeDatasetFileMap(representative_dataset_configs); - ASSERT_THAT(representative_dataset_file_map, IsOk()); + ASSERT_THAT(representative_dataset_file_map, absl_testing::IsOk()); ASSERT_THAT(*representative_dataset_file_map, SizeIs(1)); EXPECT_THAT(*representative_dataset_file_map, Contains(Key("serving_default"))); @@ -74,7 +72,7 @@ TEST(CreateRepresentativeDatasetFileMapTest, 
ConfigWithExplicitSignatureKey) { representative_dataset_file_map = CreateRepresentativeDatasetFileMap(representative_dataset_configs); - ASSERT_THAT(representative_dataset_file_map, IsOk()); + ASSERT_THAT(representative_dataset_file_map, absl_testing::IsOk()); ASSERT_THAT(*representative_dataset_file_map, SizeIs(1)); EXPECT_THAT(*representative_dataset_file_map, Contains(Key(StrEq("test_signature_key")))); @@ -103,8 +101,9 @@ TEST(CreateRepresentativeDatasetFileMapTest, CreateRepresentativeDatasetFileMap(representative_dataset_configs); EXPECT_THAT(representative_dataset_file_map, - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("duplicate signature key: serving_default"))); + absl_testing::StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("duplicate signature key: serving_default"))); } } // namespace diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc index a3a09bdb35daaa..2fb8f11a4e4349 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc @@ -49,23 +49,23 @@ class TestEnvBrokenFileSystem : public tsl::Env { public: TestEnvBrokenFileSystem() = default; - bool MatchPath(const tsl::string& path, const tsl::string& pattern) override { + bool MatchPath(const std::string& path, const std::string& pattern) override { return false; } void SleepForMicroseconds(int64_t micros) override {} - tsl::string GetRunfilesDir() override { return tsl::string("dummy_path"); } + std::string GetRunfilesDir() override { return std::string("dummy_path"); } int64_t GetCurrentThreadId() override { return 0; } tsl::Thread* StartThread(const tsl::ThreadOptions& thread_options, - const tsl::string& name, + const std::string& name, absl::AnyInvocable fn) override { return nullptr; } - bool GetCurrentThreadName(tsl::string* name) override { return false; } + bool GetCurrentThreadName(std::string* 
name) override { return false; } void SchedClosure(absl::AnyInvocable closure) override {} @@ -82,9 +82,9 @@ class TestEnvBrokenFileSystem : public tsl::Env { return absl::OkStatus(); } - tsl::string FormatLibraryFileName(const tsl::string& name, - const tsl::string& version) override { - return tsl::string("dummy_path"); + std::string FormatLibraryFileName(const std::string& name, + const std::string& version) override { + return std::string("dummy_path"); } // This is the part that would break the `CreateTmpDir` function because it @@ -95,7 +95,7 @@ class TestEnvBrokenFileSystem : public tsl::Env { } private: - void GetLocalTempDirectories(std::vector* list) override { + void GetLocalTempDirectories(std::vector* list) override { list->push_back("/tmp"); } }; @@ -107,7 +107,7 @@ class TestEnvBrokenFileSystemAndNoLocalTempDirs private: // This is the part that essentially breaks the `GetLocalTmpFileName` function // because it doesn't provide any available temp dirs. - void GetLocalTempDirectories(std::vector* list) override {} + void GetLocalTempDirectories(std::vector* list) override {} }; TEST(IoTest, GetLocalTmpFileNameGivesValidFileName) { diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc index babda33245a7c8..0818c8013e534e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc @@ -246,8 +246,8 @@ class ConvertTfQuantToMhloIntTest : public Test { // Convert to double for comparison. This is needed for comparing integers // since it LiteralTestUtil asserts different integers even if it is within // error_spec. 
- TF_ASSERT_OK_AND_ASSIGN(auto expected_double, expected->Convert(xla::F64)) - TF_ASSERT_OK_AND_ASSIGN(auto result_double, result->Convert(xla::F64)) + TF_ASSERT_OK_AND_ASSIGN(auto expected_double, expected->Convert(xla::F64)); + TF_ASSERT_OK_AND_ASSIGN(auto result_double, result->Convert(xla::F64)); EXPECT_TRUE(xla::LiteralTestUtil::Near(expected_double, result_double, xla::ErrorSpec(error_tolerance))); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_xla_call_module_op_to_bfloat16.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_xla_call_module_op_to_bfloat16.cc index b55bf3f5d18558..7ee6bbd98f61e6 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_xla_call_module_op_to_bfloat16.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_xla_call_module_op_to_bfloat16.cc @@ -36,6 +36,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "shardy/dialect/sdy/ir/register.h" // from @shardy #include "stablehlo/dialect/Serialization.h" // from @stablehlo #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" // IWYU pragma: keep @@ -54,6 +55,7 @@ absl::StatusOr ConvertSerializedStableHloModuleToBfloat16( } MLIRContext context; + mlir::sdy::loadAllRequiredDialects(&context); OwningOpRef stablehlo_module_op = mlir::stablehlo::deserializePortableArtifact(serialized_stablehlo_module, &context); @@ -77,7 +79,8 @@ absl::StatusOr ConvertSerializedStableHloModuleToBfloat16( std::string bytecode; llvm::raw_string_ostream os(bytecode); if (failed(mlir::stablehlo::serializePortableArtifact( - stablehlo_module_op.get(), version.value().toString(), os))) { + stablehlo_module_op.get(), version.value().toString(), os, + 
/*allowOtherDialects=*/true))) { return absl::InternalError("Failed to serialize StableHLO module."); } return bytecode; diff --git a/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf.cc index c2e91c5da16e93..1f6464d85f5ef4 100644 --- a/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf.cc @@ -4822,7 +4822,7 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { dilations_attr.template getValues().begin(), dilations_attr.template getValues().end()}; auto strides_attr = GetI64ElementsAttr(op.getStrides()); - std::vector strides{ + std::vector strides{ strides_attr.template getValues().begin(), strides_attr.template getValues().end()}; @@ -5022,7 +5022,7 @@ class ConvertConvBackpropFilterOp : public OpRewritePattern { dilations_attr.template getValues().begin(), dilations_attr.template getValues().end()}; auto strides_attr = GetI64ElementsAttr(op.getStrides()); - std::vector strides{ + std::vector strides{ strides_attr.template getValues().begin(), strides_attr.template getValues().end()}; diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index cc178e762ecadd..cbd6bc3b283504 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -962,12 +962,18 @@ cc_library( hdrs = ["utils/deserialize_mlir_module_utils.h"], deps = [ ":error_util", - "//tensorflow/core/platform:status", + "//tensorflow/core:lib", + "//tensorflow/core:lib_headers_for_pybind", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@local_xla//xla:status_macros", + "@local_xla//xla/tsl/lib/io:inputstream_interface", + "@local_xla//xla/tsl/lib/io:zlib_compression_options", + "@local_xla//xla/tsl/lib/io:zlib_inputstream", ], 
) @@ -976,10 +982,20 @@ cc_library( srcs = ["utils/serialize_mlir_module_utils.cc"], hdrs = ["utils/serialize_mlir_module_utils.h"], deps = [ - "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:flags_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeWriter", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_xla//xla/tsl/lib/io:zlib_compression_options", + "@local_xla//xla/tsl/lib/io:zlib_outputbuffer", + "@local_xla//xla/tsl/platform:env", + "@local_xla//xla/tsl/platform:errors", ], ) @@ -987,6 +1003,7 @@ tf_cc_test( name = "serialize_mlir_module_utils_test", srcs = ["utils/serialize_mlir_module_utils_test.cc"], deps = [ + ":deserialize_mlir_module_utils", ":serialize_mlir_module_utils", "//tensorflow/compiler/jit:flags", "//tensorflow/core:test", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 8a641e06d93519..e8d0ea525943fd 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -1107,5 +1107,8 @@ LogicalResult IslandOp::fold(FoldAdaptor, // TableGen'd op method definitions //===----------------------------------------------------------------------===// +using mlir::tf_executor::ControlType; +using mlir::tf_executor::TokenType; + #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc" diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index e23f510182259f..4104cf412acfd8 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -335,7 +335,6 @@ def TF_IfRegionOp : TF_Op<"IfRegion", "areTypesCompatible", 
"getEntrySuccessorOperands", "getRegionInvocationBounds", - "getSuccessorRegions" ]> ]> { let summary = "output = cond ? then_branch output : else_branch output"; @@ -395,7 +394,6 @@ def TF_GeneratorDatasetRegionOp : TF_Op<"GeneratorDatasetRegion", "areTypesCompatible", "getEntrySuccessorOperands", "getRegionInvocationBounds", - "getSuccessorRegions" ]>, SingleBlockImplicitTerminator<"YieldOp">, TF_GeneratorOpSideEffect, diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index ee65668078ca59..6382f325a47505 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -3003,14 +3003,14 @@ void GeneratorDatasetRegionOp::getRegionInvocationBounds( } OperandRange GeneratorDatasetRegionOp::getEntrySuccessorOperands( - RegionBranchPoint point) { + RegionSuccessor successor) { auto end = this->getOperation()->operand_end(); - if (point.isParent()) { + if (successor.isParent()) { // The op itself doesn't branch back to itself. return ::mlir::OperandRange(end, end); - } else if (point.getRegionOrNull() == &getInit()) { + } else if (successor.getSuccessor() == &getInit()) { return getInitFuncOtherArgs(); - } else if (point.getRegionOrNull() == &getNext()) { + } else if (successor.getSuccessor() == &getNext()) { return getNextFuncOtherArgs(); } else /* finalize region */ { return getFinalizeFuncOtherArgs(); @@ -3024,13 +3024,15 @@ void GeneratorDatasetRegionOp::getSuccessorRegions( // The op itself branches to `init` first. regions.push_back( RegionSuccessor(&getInit(), getInit().front().getArguments())); - } else if (point.getRegionOrNull() == &getInit()) { + } else if (point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getInit()) { // `init` branches to `next`, passing along the arguments given to `init`'s // yield. Said arguments precede the "other args". 
n = getInitFuncOtherArgs().size(); regions.push_back(RegionSuccessor( &getNext(), getNext().front().getArguments().drop_back(n))); - } else if (point.getRegionOrNull() == &getNext()) { + } else if (point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getNext()) { // `next` branches to itself, or to `finalize`, passing all arguments given // to `next`s yield. @@ -3045,7 +3047,8 @@ void GeneratorDatasetRegionOp::getSuccessorRegions( &getFinalize(), getFinalize().front().getArguments().slice(0, num))); } else { // `finalize` branches back to the op itself, not passing any arguments. - regions.push_back(RegionSuccessor()); + regions.push_back(RegionSuccessor( + point.getTerminatorPredecessorOrNull()->getParentRegion())); } } @@ -3261,11 +3264,12 @@ void IfRegionOp::getRegionInvocationBounds( invocationBounds.assign(2, {0, 1}); } -OperandRange IfRegionOp::getEntrySuccessorOperands(RegionBranchPoint point) { +OperandRange IfRegionOp::getEntrySuccessorOperands(RegionSuccessor successor) { // IfRegionOp currently only allows one op (the condition), so there are no // remaining operands for the successor. - assert((point.isParent() || - (point == (*this)->getRegion(0) || point == (*this)->getRegion(1))) && + assert((successor.isParent() || + (successor.getSuccessor() == &(*this)->getRegion(0) || + successor.getSuccessor() == &(*this)->getRegion(1))) && "Invalid IfRegionOp region index."); auto end = this->getOperation()->operand_end(); return ::mlir::OperandRange(end, end); @@ -3275,16 +3279,20 @@ void IfRegionOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl& regions) { if (!point.isParent()) { // The `then` and the `else` region branch back to the parent operation. - regions.push_back(RegionSuccessor(getResults())); + regions.push_back( + RegionSuccessor(point.getTerminatorPredecessorOrNull(), getResults())); return; } else { // The parent can branch to either `then` or `else`. 
- regions.push_back(RegionSuccessor(&getThenBranch())); + regions.push_back( + RegionSuccessor(&getThenBranch(), getThenBranch().getArguments())); Region* elseRegion = &this->getElseBranch(); if (!elseRegion->empty()) - regions.push_back(RegionSuccessor(elseRegion)); + regions.push_back( + RegionSuccessor(elseRegion, elseRegion->getArguments())); else - regions.push_back(RegionSuccessor()); + regions.push_back(RegionSuccessor( + point.getTerminatorPredecessorOrNull()->getParentRegion())); } } @@ -3727,5 +3735,7 @@ LogicalResult BitcastOp::verify() { // TableGen'd op method definitions //===----------------------------------------------------------------------===// +using namespace mlir; // NOLINT + #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc.inc" diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index e4100657db7081..23683673fe189a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -3611,8 +3611,8 @@ SmallVector WhileRegionOp::getLoopRegions() { return {&getBody()}; } //===----------------------------------------------------------------------===// OperandRange WhileRegionOp::getEntrySuccessorOperands( - RegionBranchPoint point) { - if (point.isParent()) { + RegionSuccessor successor) { + if (successor.isParent()) { // WhileRegionOp branches to the condition, which branches to the body. But // the op itself doesn't branch back to itself. So this range is empty. 
auto end = this->getOperation()->operand_end(); @@ -3628,21 +3628,28 @@ OperandRange WhileRegionOp::getEntrySuccessorOperands( void WhileRegionOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { - if (!point.isParent() && point == (*this)->getRegion(0)) { + if (!point.isParent() && + (point.getTerminatorPredecessorOrNull() && + point.getTerminatorPredecessorOrNull()->getParentRegion() == + &(*this)->getRegion(0))) { // 'cond' branches to the body or returns. Operation *yield = getCond().front().getTerminator(); if (yield->getOperands().size() == 1 + this->getOperation()->getOperands().size()) { regions.push_back( RegionSuccessor(&getBody(), getBody().front().getArguments())); - regions.push_back(getResults()); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } else { // For compatibility with older code, we allow the "yield" in a condition // to only yield a single boolean. In that case we can't forward any args. regions.push_back(RegionSuccessor(&getBody())); - regions.push_back(RegionSuccessor()); // branch back to parent, no args + regions.push_back( + RegionSuccessor(getOperation(), getResults().take_front(0))); } - } else if (!point.isParent() && point == (*this)->getRegion(1)) { + } else if (!point.isParent() && + (point.getTerminatorPredecessorOrNull() && + point.getTerminatorPredecessorOrNull()->getParentRegion() == + &(*this)->getRegion(1))) { // 'body' branches back to 'cond'. 
regions.push_back( RegionSuccessor(&getCond(), getCond().front().getArguments())); @@ -4510,7 +4517,7 @@ LogicalResult UniformQuantizedClipByValueOp::verify() { //===----------------------------------------------------------------------===// MutableOperandRange YieldOp::getMutableSuccessorOperands( - RegionBranchPoint point) { + RegionSuccessor successor) { if (auto whileOp = llvm::dyn_cast(this->getOperation()->getParentOp())) { if (&whileOp.getCond() == this->getOperation()->getParentRegion()) { @@ -4538,5 +4545,7 @@ MutableOperandRange YieldOp::getMutableSuccessorOperands( // TableGen'd op method definitions //===----------------------------------------------------------------------===// +using namespace mlir; // NOLINT + #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc.inc" diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc index b5ce10d1500be8..7419149074fb8a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc @@ -188,5 +188,6 @@ std::optional _SendOp::GetResourceInstanceStr() { // TableGen'd op method definitions //===----------------------------------------------------------------------===// +using namespace mlir; // NOLINT #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc.inc" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 765ed1171a8449..a3305eef8a0819 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -1317,7 +1317,7 @@ func.func @testIfRegionElseTerminator(%arg0: tensor, %arg1: tensor<2xf32>) - // tf.Region yield number of results should match op number of results func.func @testIfRegionThenResultCount(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> 
{ - // expected-error @+1 {{'tf.IfRegion' op region control flow edge from Region #0 to parent results: source has 2 operands, but target successor needs 1}} + // expected-error @+1 {{'tf.IfRegion' op region control flow edge from Operation tf.Yield to parent results: source has 2 operands, but target successor needs 1}} %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t, %t) : (tensor<2xf32>, tensor<2xf32>) -> () @@ -1332,7 +1332,7 @@ func.func @testIfRegionThenResultCount(%arg0: tensor, %arg1: tensor<2xf32>) // ----- func.func @testIfRegionElseResultCount(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // expected-error @+1 {{'tf.IfRegion' op region control flow edge from Region #1 to parent results: source has 2 operands, but target successor needs 1}} + // expected-error @+1 {{'tf.IfRegion' op region control flow edge from Operation tf.Yield to parent results: source has 2 operands, but target successor needs 1}} %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t) : (tensor<2xf32>) -> () diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py index 118d7f38ebf959..1087c6e3a679bd 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py @@ -36,7 +36,8 @@ class TestModule(tf.Module): def __init__(self): super(TestModule, self).__init__() self.v42 = tf.Variable(42.0) - self.c43 = tf.constant(43.0) + # Use convert_to_tensor to avoid forcing eager `.numpy()` in graph/XLA mode. + self.c43 = tf.convert_to_tensor(43.0, dtype=tf.float32) # During serialization, the constants are given internal (non-user-accessible, non-semantically-load-bearing) exported names. 
# CHECK: "tf_saved_model.global_tensor"() <{sym_name = "[[CONST:[a-zA-Z_0-9.]+]]", type = tensor, value = dense<4.300000e+01> : tensor}> {tf_saved_model.exported_names = [{{.*}}]} : () -> () diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc index bc4487a4e3fd7d..954c318b416150 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" #include -#include #include #include #include @@ -29,6 +28,7 @@ limitations under the License. #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" @@ -230,7 +230,7 @@ std::optional> EquationToMap( llvm::StringRef equation) { llvm::SmallDenseMap map; for (int64_t i = 0; i < equation.size(); ++i) { - if (!std::isalpha(equation[i])) { + if (!llvm::isAlpha(equation[i])) { // Unsupported character in the equation. 
return std::nullopt; } @@ -263,7 +263,7 @@ std::optional> GetAvailableLabels( const int lhs_size = lhs.size(); for (int i = 0; i < lhs_size; ++i) { const char label = lhs[i]; - if (std::isalpha(label)) { + if (llvm::isAlpha(label)) { labels.remove(label); ++lhs_count; } else if (label == '.') { @@ -280,7 +280,7 @@ std::optional> GetAvailableLabels( const int rhs_size = rhs.size(); for (int i = 0; i < rhs_size; ++i) { const char label = rhs[i]; - if (std::isalpha(label)) { + if (llvm::isAlpha(label)) { labels.remove(label); ++rhs_count; } else if (label == '.') { @@ -318,7 +318,7 @@ std::tuple FlattenEllipsis( std::string new_lhs; for (int i = 0; i < lhs.size(); ++i) { const char label = lhs[i]; - if (std::isalpha(label)) { + if (llvm::isAlpha(label)) { new_lhs.push_back(label); } else { // Encounter ellipsis: generate unnamed labels then insert to the new @@ -333,7 +333,7 @@ std::tuple FlattenEllipsis( std::string new_rhs, new_rhs_labels; for (int i = 0; i < rhs.size(); ++i) { const char label = rhs[i]; - if (std::isalpha(label)) { + if (llvm::isAlpha(label)) { new_rhs.push_back(label); } else { // Encounter ellipsis: generate unnamed labels then insert to the new @@ -352,7 +352,7 @@ std::tuple FlattenEllipsis( std::string new_output; for (int i = 0; i < out.size(); ++i) { const char label = out[i]; - if (std::isalpha(label)) { + if (llvm::isAlpha(label)) { new_output.push_back(label); } else { // Encounter ellipsis: we will just insert the generated labels to the new diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import.cc b/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import.cc index a2c4a7031ed14b..0cdb563a45eed7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import.cc @@ -49,7 +49,7 @@ static constexpr int kTextFileIndex_LineNumber = -1; class InitTextFileToImportPass : public 
impl::InitTextFileToImportPassBase { public: - InitTextFileToImportPass() {} + InitTextFileToImportPass() = default; InitTextFileToImportPass(const InitTextFileToImportPass&) {} explicit InitTextFileToImportPass(std::string saved_model_dir) { saved_model_dir_ = saved_model_dir; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import_test_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import_test_pass.cc index a985cdc11611b4..41c5cd4234f1cc 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import_test_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import_test_pass.cc @@ -46,7 +46,7 @@ class InitTextFileToImportTestPass : public impl::InitTextFileToImportTestPassBase< InitTextFileToImportTestPass> { public: - explicit InitTextFileToImportTestPass() {} + explicit InitTextFileToImportTestPass() = default; StringRef getArgument() const final { return "tf-init-text-file-to-import-test"; @@ -115,7 +115,7 @@ class InitTextFileToImportSavedModelTestPass : public impl::InitTextFileToImportSavedModelTestPassBase< InitTextFileToImportSavedModelTestPass> { public: - explicit InitTextFileToImportSavedModelTestPass() {} + explicit InitTextFileToImportSavedModelTestPass() = default; private: void runOnOperation() override; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 2e023e3e057096..57a41f538f277f 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -36,6 +35,7 @@ limitations under the License. 
#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" @@ -289,8 +289,10 @@ ObjectNames::ObjectNames(const SavedObjectGraph& object_graph, // - `model.variables.0` // - `model.keras_api.layers.1.keras_api.trainable_variables.0` // - ... 10 more long aliases ending in digits ... - return std::make_tuple(isdigit(a.back()), a.size(), a) < - std::make_tuple(isdigit(b.back()), b.size(), b); + return std::make_tuple(absl::ascii_isdigit(a.back()), a.size(), + a) < + std::make_tuple(absl::ascii_isdigit(b.back()), b.size(), + b); }); for (const std::string& name : kv.second) { if (IsExported(name)) { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.cc b/tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.cc index c48f52576df4e3..0288006ee4d105 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.cc @@ -39,13 +39,13 @@ limitations under the License. 
namespace tensorflow { absl::Status ParseOutputArrayInfo(absl::string_view array_names, - std::vector* outputs) { + std::vector* outputs) { TF_RETURN_IF_ERROR(ParseNodeNames(array_names, *outputs)); return absl::OkStatus(); } -absl::Status ParseOutputArrayInfo(const std::vector& output_names, - std::vector* outputs) { +absl::Status ParseOutputArrayInfo(const std::vector& output_names, + std::vector* outputs) { for (auto& output_name : output_names) { if (output_name.empty()) continue; outputs->push_back(output_name); @@ -57,8 +57,8 @@ absl::Status ParseInputArrayInfo(absl::string_view array_names, absl::string_view data_types, absl::string_view shapes, GraphImportConfig::InputArrays* inputs) { - std::vector node_names; - std::vector node_dtypes; + std::vector node_names; + std::vector node_dtypes; std::vector>> node_shapes; TF_RETURN_IF_ERROR(ParseNodeNames(array_names, node_names)); TF_RETURN_IF_ERROR(ParseNodeDataTypes(data_types, node_dtypes)); @@ -113,8 +113,8 @@ static absl::Status HandleSubtype(absl::string_view subtype, } absl::Status ParseInputArrayInfo( - const std::vector& node_names, - const std::vector& node_dtypes, + const std::vector& node_names, + const std::vector& node_dtypes, const std::vector>>& node_shapes, GraphImportConfig::InputArrays* inputs) { std::vector used_node_dtypes; @@ -148,7 +148,7 @@ absl::Status ParseInputArrayInfo( // StringMap doesn't support reserve else reserve input map size here. 
for (int i = 0, end = node_names.size(); i < end; i++) { auto& name = node_names[i]; - const string& type = used_node_dtypes[i]; + const std::string& type = used_node_dtypes[i]; if (name.empty()) continue; auto it_inserted_pair = inputs->insert({name, {}}); @@ -193,7 +193,7 @@ absl::Status ParseNodeShapes( std::vector>>& shapes_vector) { shapes_vector.clear(); if (!shapes_str.empty()) { - std::vector node_shapes_str = absl::StrSplit(shapes_str, ':'); + std::vector node_shapes_str = absl::StrSplit(shapes_str, ':'); for (int i = 0; i < node_shapes_str.size(); i++) { if (node_shapes_str[i] == "*") { shapes_vector.push_back(std::nullopt); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.h b/tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.h index 1119d4e2b33c4f..176773da45fcbc 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.h @@ -35,10 +35,10 @@ namespace tensorflow { // Parses the command line flag strings to the specification of nodes in // the Graph. absl::Status ParseOutputArrayInfo(absl::string_view array_names, - std::vector* outputs); + std::vector* outputs); -absl::Status ParseOutputArrayInfo(const std::vector& output_names, - std::vector* outputs); +absl::Status ParseOutputArrayInfo(const std::vector& output_names, + std::vector* outputs); // Parses the command line flag strings to the specification of nodes in // the Graph. `data_types` input string can be empty since the flag is optional. 
@@ -48,8 +48,8 @@ absl::Status ParseInputArrayInfo(absl::string_view array_names, GraphImportConfig::InputArrays* inputs); absl::Status ParseInputArrayInfo( - const std::vector& node_names, - const std::vector& node_dtypes, + const std::vector& node_names, + const std::vector& node_dtypes, const std::vector>>& node_shapes, GraphImportConfig::InputArrays* inputs); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc index 858c70a54a58d6..3706b8afe34d78 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc @@ -17,12 +17,12 @@ limitations under the License. #include #include -#include #include #include #include #include "absl/log/log.h" +#include "absl/strings/ascii.h" #include "absl/strings/str_split.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" @@ -99,8 +99,7 @@ std::vector BridgeLoggerConfig::GetFilter( bool BridgeLoggerConfig::ShouldOnlyDumpTopLevelPasses() { const char* env_var = getenv(kEnableOnlyTopLevelPassesEnvVar); - std::string value(env_var); - std::transform(value.begin(), value.end(), value.begin(), ::tolower); + std::string value = absl::AsciiStrToLower(env_var); // Return true if value is "1" or "true"; otherwise, false. 
return value == "1" || value == "true"; } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index b0ad4e265633d8..550ab547498f45 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -249,14 +249,14 @@ absl::StatusOr ConvertTensor(const Tensor& input_tensor, CONVERT_FLAT(DT_BOOL, bool) CONVERT_FLAT(DT_FLOAT, float) CONVERT_FLAT(DT_DOUBLE, double) - CONVERT_FLAT(DT_INT8, int8) - CONVERT_FLAT(DT_INT16, int16) - CONVERT_FLAT(DT_INT32, int32) + CONVERT_FLAT(DT_INT8, int8_t) + CONVERT_FLAT(DT_INT16, int16_t) + CONVERT_FLAT(DT_INT32, int32_t) CONVERT_FLAT(DT_INT64, int64_t) - CONVERT_FLAT(DT_UINT8, uint8) - CONVERT_FLAT(DT_UINT16, uint16) - CONVERT_FLAT(DT_UINT32, uint32) - CONVERT_FLAT(DT_UINT64, uint64) + CONVERT_FLAT(DT_UINT8, uint8_t) + CONVERT_FLAT(DT_UINT16, uint16_t) + CONVERT_FLAT(DT_UINT32, uint32_t) + CONVERT_FLAT(DT_UINT64, uint64_t) CONVERT_FLAT(DT_COMPLEX64, std::complex) CONVERT_FLAT(DT_COMPLEX128, std::complex) diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc index a34553623408d8..b120b6c786edb6 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc @@ -162,11 +162,11 @@ TEST_F(ConvertTensorTest, Simple) { ASSERT_NO_FATAL_FAILURE(VerifyConversion( {static_cast(1), static_cast(-1)}, DT_INT4, mlir::IntegerType::get(&context, 4))); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, -1}, DT_INT8, mlir::IntegerType::get(&context, 8))); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, -1}, DT_INT16, mlir::IntegerType::get(&context, 16))); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + 
ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, -1}, DT_INT32, mlir::IntegerType::get(&context, 32))); ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, -1}, DT_INT64, mlir::IntegerType::get(&context, 64))); @@ -175,19 +175,19 @@ TEST_F(ConvertTensorTest, Simple) { {static_cast(1), static_cast(2)}, DT_UINT4, mlir::IntegerType::get( &context, 4, mlir::IntegerType::SignednessSemantics::Unsigned))); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, 2}, DT_UINT8, mlir::IntegerType::get( &context, 8, mlir::IntegerType::SignednessSemantics::Unsigned))); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, 2}, DT_UINT16, mlir::IntegerType::get( &context, 16, mlir::IntegerType::SignednessSemantics::Unsigned))); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, 2}, DT_UINT32, mlir::IntegerType::get( &context, 32, mlir::IntegerType::SignednessSemantics::Unsigned))); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, 2}, DT_UINT64, mlir::IntegerType::get( &context, 64, mlir::IntegerType::SignednessSemantics::Unsigned))); @@ -222,11 +222,11 @@ TEST_F(ConvertTensorTest, SimpleDenseResourceElements) { ASSERT_NO_FATAL_FAILURE(VerifyConversion( {static_cast(1), static_cast(-1)}, DT_INT4, mlir::IntegerType::get(&context, 4), true)); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, -1}, DT_INT8, mlir::IntegerType::get(&context, 8), true)); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, -1}, DT_INT16, mlir::IntegerType::get(&context, 16), true)); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, -1}, DT_INT32, mlir::IntegerType::get(&context, 32), true)); ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, -1}, DT_INT64, mlir::IntegerType::get(&context, 64), true)); @@ -236,22 +236,22 @@ 
TEST_F(ConvertTensorTest, SimpleDenseResourceElements) { mlir::IntegerType::get(&context, 4, mlir::IntegerType::SignednessSemantics::Unsigned), true)); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, 2}, DT_UINT8, mlir::IntegerType::get(&context, 8, mlir::IntegerType::SignednessSemantics::Unsigned), true)); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, 2}, DT_UINT16, mlir::IntegerType::get(&context, 16, mlir::IntegerType::SignednessSemantics::Unsigned), true)); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, 2}, DT_UINT32, mlir::IntegerType::get(&context, 32, mlir::IntegerType::SignednessSemantics::Unsigned), true)); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, 2}, DT_UINT64, mlir::IntegerType::get(&context, 64, mlir::IntegerType::SignednessSemantics::Unsigned), diff --git a/tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config_test.cc index 09a76102557c4f..a4f2861276a9bd 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config_test.cc @@ -59,9 +59,9 @@ TEST(DataDumperLoggerConfig, TestPassFilter) { 1); setenv("TF_DUMP_GRAPH_PREFIX", "sponge", 1); - const string kTestFilename = "test.txt"; + const std::string kTestFilename = "test.txt"; int print_callback_count = 0; - auto get_filename_fn = [](const string &filename, mlir::Operation *op) { + auto get_filename_fn = [](const std::string& filename, mlir::Operation* op) { return filename; }; auto print_callback = [&](llvm::raw_ostream &out) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/deserialize_mlir_module_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/deserialize_mlir_module_utils.cc index da7917c9c21a4c..bcd3164cd10f7c 100644 
--- a/tensorflow/compiler/mlir/tensorflow/utils/deserialize_mlir_module_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/deserialize_mlir_module_utils.cc @@ -15,7 +15,16 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/deserialize_mlir_module_utils.h" +#include +#include +#include +#include +#include +#include + +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/StringRef.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -23,8 +32,54 @@ limitations under the License. #include "mlir/Parser/Parser.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "xla/status_macros.h" +#include "tensorflow/core/lib/io/inputstream_interface.h" +#include "tensorflow/core/lib/io/zlib_compression_options.h" +#include "tensorflow/core/lib/io/zlib_inputstream.h" +#include "tensorflow/core/platform/tstring.h" namespace tensorflow { +namespace { +// Wrap memory buffer into InputStreamInterface +class MemoryInputStream : public tensorflow::io::InputStreamInterface { + public: + explicit MemoryInputStream(const char* buffer, size_t length) + : buf_(buffer), len_(length), pos_(0) {} + + ~MemoryInputStream() override = default; + + absl::Status ReadNBytes(int64_t bytes_to_read, tstring* result) override { + result->clear(); + if (bytes_to_read < 0) { + return absl::InvalidArgumentError(absl::StrCat( + "Can't read a negative number of bytes: ", bytes_to_read)); + } + absl::Status status = absl::OkStatus(); + int64_t bytes = bytes_to_read; + if (pos_ + bytes_to_read > len_) { + bytes = len_ - pos_; + status = absl::OutOfRangeError("Reached end of file"); + } + if (bytes > 0) { + result->resize(bytes); + memcpy(&(*result)[0], &buf_[pos_], bytes); + pos_ += bytes; + } + return status; + } + + int64_t Tell() const override { return pos_; } + + absl::Status Reset() override { + pos_ 
= 0; + return absl::OkStatus(); + } + + private: + const char* buf_; // Not owned. + int64_t len_; + int64_t pos_ = 0; // Tracks where we are in the file. +}; +} // namespace absl::Status DeserializeMlirModule( llvm::StringRef serialized_mlir_module, mlir::MLIRContext* mlir_context, @@ -37,13 +92,44 @@ absl::Status DeserializeMlirModule( // error reporting system. mlir::StatusScopedDiagnosticHandler error_handler(mlir_context); - // Parse the module. - *mlir_module = mlir::parseSourceString(serialized_mlir_module, - mlir_context); - if (!*mlir_module) - return error_handler.Combine( - absl::InvalidArgumentError("could not parse MLIR module")); - + // Look for the GZIP magic number to check if this is a compressed bytecode. + if (serialized_mlir_module.starts_with("\x1f\x8b")) { + // Try to uncompress the and parse the bytecode. + auto input_stream = std::make_unique( + serialized_mlir_module.data(), serialized_mlir_module.size()); + io::ZlibCompressionOptions options = io::ZlibCompressionOptions::GZIP(); + auto zlib_stream = std::make_unique( + input_stream.get(), options.input_buffer_size, + options.output_buffer_size, options); + tstring uncompressed_bytecode; + absl::Status s = zlib_stream->ReadNBytes(/*bytes_to_read=*/INT_MAX, + &uncompressed_bytecode); + // OK status means the decompression is successful. + // OutOfRange error means the decompression is successful but end of input + // was reached before *bytes_to_read* bytes were read. + if (!s.ok() && !absl::IsOutOfRange(s)) { + // Failed to uncompress the bytecode and it is not the end of the input. + return error_handler.Combine(absl::InvalidArgumentError( + absl::StrCat("Failed to uncompress MLIR module", s.message()))); + } + // Parse the uncompressed bytecode. 
+ auto uncompressed_bytecode_str = + std::string(uncompressed_bytecode.data(), uncompressed_bytecode.size()); + *mlir_module = mlir::parseSourceString( + uncompressed_bytecode_str, mlir_context); + if (!*mlir_module) { + // Uncompressing was successful but the parsed MLIR module is invalid. + return error_handler.Combine(absl::InvalidArgumentError( + "Failed to parse MLIR module after uncompressing")); + } + } else { + *mlir_module = mlir::parseSourceString( + serialized_mlir_module, mlir_context); + if (!*mlir_module) { + return error_handler.Combine( + absl::InvalidArgumentError("could not parse MLIR module")); + } + } return absl::OkStatus(); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc index d9249d472b334c..3329bff4c02737 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc @@ -126,7 +126,8 @@ void AddDevicesToOp(mlir::Operation* op, const DeviceSet* device_set) { // For device that do not have any metadata, or if we failed to parse metadata // from the DeviceSet, we add a unit attribute to the `tf.devices` attribute. 
for (Device* device : device_set->devices()) { - string name = DeviceNameUtils::ParsedNameToString(device->parsed_name()); + std::string name = + DeviceNameUtils::ParsedNameToString(device->parsed_name()); if (device->device_type() == DEVICE_GPU) { auto metadata = ParseGpuDeviceMetadata(*device, &builder); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc index c3e7ae75022348..abf357873a6153 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc @@ -52,8 +52,8 @@ class FakeDevice : public Device { return errors::Unimplemented("FakeDevice::Sync()"); } - static std::unique_ptr Make(const string& name, - const string& desc = "") { + static std::unique_ptr Make(const std::string& name, + const std::string& desc = "") { DeviceNameUtils::ParsedName parsed_name; DeviceNameUtils::ParseFullName(name, &parsed_name); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph_test.cc index 7e92860e5ff03e..9d9780d231523f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph_test.cc @@ -26,12 +26,12 @@ limitations under the License. 
namespace tensorflow { namespace { -void ExpectHasSubstr(const string& s, const string& expected) { +void ExpectHasSubstr(const std::string& s, const std::string& expected) { EXPECT_TRUE(absl::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } -void ExpectHasNoSubstr(const string& s, const string& expected) { +void ExpectHasNoSubstr(const std::string& s, const std::string& expected) { EXPECT_FALSE(absl::StrContains(s, expected)) << "'" << s << "' should not contain '" << expected << "'"; } @@ -39,7 +39,7 @@ void ExpectHasNoSubstr(const string& s, const string& expected) { // WritableFile that simply concats into string. class StringWritableFile : public WritableFile { public: - explicit StringWritableFile(string* str) : str_(*str) {} + explicit StringWritableFile(std::string* str) : str_(*str) {} absl::Status Append(absl::string_view data) override { absl::StrAppend(&str_, data); @@ -62,7 +62,7 @@ class StringWritableFile : public WritableFile { } private: - string& str_; + std::string& str_; }; TEST(Dump, TextualIrToFileSuccess) { @@ -72,10 +72,10 @@ TEST(Dump, TextualIrToFileSuccess) { setenv("TF_DUMP_GRAPH_PREFIX", testing::TmpDir().c_str(), 1); UseMlirForGraphDump(MlirDumpConfig()); - string ret = DumpGraphToFile("tir", graph); + std::string ret = DumpGraphToFile("tir", graph); ASSERT_EQ(ret, io::JoinPath(testing::TmpDir(), "tir.mlir")); - string actual; + std::string actual; TF_ASSERT_OK(ReadFileToString(Env::Default(), ret, &actual)); } @@ -86,12 +86,12 @@ TEST(Dump, TextualIrWithOptions) { .Attr("dtype", DT_FLOAT) .Finalize(&graph, &node)); - string actual; + std::string actual; StringWritableFile file(&actual); TF_ASSERT_OK(DumpTextualIRToFile(MlirDumpConfig().emit_location_information(), graph, /*flib_def=*/nullptr, &file)); - string expected_substr = R"(loc(#loc))"; + std::string expected_substr = R"(loc(#loc))"; ExpectHasSubstr(actual, expected_substr); } @@ -100,17 +100,17 @@ TEST(Dump, DumpToTFG) { Node* node; 
TF_CHECK_OK(NodeBuilder("A", "NoOp").Finalize(&graph, &node)); - string actual; + std::string actual; StringWritableFile file(&actual); TF_ASSERT_OK(DumpTextualIRToFile( MlirDumpConfig().emit_dialect(MlirDumpConfig::Dialect::kTFG), graph, /*flib_def=*/nullptr, &file)); - string expected_substr("tfg.graph"); + std::string expected_substr("tfg.graph"); ExpectHasSubstr(actual, expected_substr); - string not_expected_substr("tf_executor.island"); + std::string not_expected_substr("tf_executor.island"); ExpectHasNoSubstr(actual, not_expected_substr); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index b970ca84b326cf..138e13e3719328 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -44,7 +44,7 @@ struct NameCounts { llvm::StringMap counts; }; -std::string MakeUniqueFilename(string name) { +std::string MakeUniqueFilename(std::string name) { static NameCounts& instance = *new NameCounts; // Remove illegal characters from `name`. @@ -274,7 +274,7 @@ void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path) { // Output dirs "sponge" (case-insensitive) have a special meaning: Dump into // the directory specified by the environment variable // TEST_UNDECLARED_OUTPUTS_DIR. 
- string lower_path = absl::AsciiStrToLower(path); + std::string lower_path = absl::AsciiStrToLower(path); if (lower_path == "sponge") { if (!tensorflow::io::GetTestUndeclaredOutputsDir(&path)) { LOG(ERROR) << "MLIR crash reproducer is set to '" << dir_path.str() diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index 9ec1b9970ae777..9e07ece4e0999e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -400,12 +400,12 @@ absl::Status ConvertAttributes( if (auto symbol_ref = mlir::dyn_cast(attr)) { TF_RETURN_IF_ERROR(ConvertAttribute( mlir::cast(symbol_ref), &value)); - func_call_attrs[string(name)] = std::move(value); + func_call_attrs[std::string(name)] = std::move(value); continue; } if (auto func_attr = mlir::dyn_cast(attr)) { TF_RETURN_IF_ERROR(ConvertAttribute(func_attr, remove_ref_type, &value)); - func_call_attrs[string(name)] = std::move(value); + func_call_attrs[std::string(name)] = std::move(value); continue; } if (mlir::isa(attr)) { @@ -434,12 +434,12 @@ absl::Status ConvertAttributes( // input TensorFlow GraphDef shouldn't contain '.'. If it does appear in // the attribute from MLIR, it is treated as an attribute from function // calls. 
- std::vector name_tokens = + std::vector name_tokens = absl::StrSplit(name, '.', absl::SkipEmpty()); TF_RET_CHECK(name_tokens.size() <= 2); auto it = func_call_attrs.find(name_tokens[0]); if (it == func_call_attrs.end()) { - (*values)[string(name)] = std::move(value); + (*values)[std::string(name)] = std::move(value); } else { (*it->second.mutable_func()->mutable_attr())[name_tokens[1]] = std::move(value); @@ -457,7 +457,7 @@ absl::Status SetShapeAttribute(absl::string_view name, AttrValue value; SetTensorShapeProto(shaped_type, value.mutable_list()->add_shape()); - auto result = values->insert({string(name), value}); + auto result = values->insert({std::string(name), value}); if (!result.second) { // This should be extremely rare as it means we are adding the same // attribute multiple times/have some redundancy in representing this diff --git a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc index 50306edb28b067..fa2ff3c8a281fa 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc @@ -59,7 +59,7 @@ absl::Status LoadProtoFromFileImpl(absl::string_view input_filename, T* proto) { if (std::error_code error = file_or_err.getError()) { return errors::InvalidArgument( "Could not open input file ", - string(input_filename.data(), input_filename.size()).c_str()); + std::string(input_filename.data(), input_filename.size()).c_str()); } const auto& input_file = *file_or_err; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc index a189cc14555143..fbcdc9e894fbd9 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc @@ -41,7 +41,7 @@ const char kTensorPrefix[] = "tftensor$"; } // namespace -string MangleAttributeName(absl::string_view str) { +std::string 
MangleAttributeName(absl::string_view str) { return absl::StrCat(kAttributePrefix, str); } @@ -66,7 +66,7 @@ MangledKind GetMangledKind(absl::string_view str) { } } -string MangleShape(const TensorShapeProto& shape) { +std::string MangleShape(const TensorShapeProto& shape) { return absl::StrCat(kTensorShapePrefix, PrintShortTextProto(shape)); } @@ -74,7 +74,7 @@ absl::Status DemangleShape(absl::string_view str, TensorShapeProto* proto) { return ParseTextProto(str, kTensorShapePrefix, proto); } -string MangleTensor(const TensorProto& tensor) { +std::string MangleTensor(const TensorProto& tensor) { return absl::StrCat(kTensorPrefix, PrintShortTextProto(tensor)); } @@ -82,7 +82,7 @@ absl::Status DemangleTensor(absl::string_view str, TensorProto* proto) { return ParseTextProto(str, kTensorPrefix, proto); } -string MangleDataType(const DataType& dtype) { +std::string MangleDataType(const DataType& dtype) { return absl::StrCat(kDataTypePrefix, DataType_Name(dtype)); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h b/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h index a0c14f27b5b38f..7e95a27f0290f9 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h @@ -28,7 +28,7 @@ namespace mangling_util { enum class MangledKind { kUnknown, kDataType, kTensorShape, kTensor }; // Mangles an attribute name, marking the attribute as a TensorFlow attribute. -string MangleAttributeName(absl::string_view str); +std::string MangleAttributeName(absl::string_view str); // Returns true if 'str' was mangled with MangleAttributeName. bool IsMangledAttributeName(absl::string_view str); @@ -41,17 +41,17 @@ absl::string_view DemangleAttributeName(absl::string_view str); MangledKind GetMangledKind(absl::string_view str); // Return a TensorShapeProto mangled as a string. 
-string MangleShape(const TensorShapeProto& shape); +std::string MangleShape(const TensorShapeProto& shape); // Demangle a string mangled with MangleShape. absl::Status DemangleShape(absl::string_view str, TensorShapeProto* proto); // Return a TensorProto mangled as a string. -string MangleTensor(const TensorProto& tensor); +std::string MangleTensor(const TensorProto& tensor); // Demangle a string mangled with MangleTensor. absl::Status DemangleTensor(absl::string_view str, TensorProto* proto); // Return a DataType mangled as a string. -string MangleDataType(const DataType& dtype); +std::string MangleDataType(const DataType& dtype); // Demangle a string mangled with MangleDataType. absl::Status DemangleDataType(absl::string_view str, DataType* proto); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.cc index abea6d6602b862..e960de8acc494e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.cc @@ -18,12 +18,77 @@ limitations under the License. 
#include #include +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Bytecode/BytecodeWriter.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/jit/flags.h" +#include "xla/tsl/lib/io/zlib_compression_options.h" +#include "xla/tsl/lib/io/zlib_outputbuffer.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/file_system.h" + namespace tensorflow { +namespace { +class WritableStringFile : public tsl::WritableFile { + public: + explicit WritableStringFile(std::string* data) : data_(data) {}; + ~WritableStringFile() override = default; + + absl::Status Append(absl::string_view data) override { + absl::StrAppend(data_, data); + return absl::OkStatus(); + } + + absl::Status Close() override { return absl::OkStatus(); } + absl::Status Flush() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } + + private: + std::string* data_; +}; +} // namespace + +absl::StatusOr SerializeMlirModuleToCompressedBytecode( + mlir::ModuleOp module_op) { + std::string bytecode; + llvm::raw_string_ostream os(bytecode); + mlir::BytecodeWriterConfig config; + if (mlir::failed(mlir::writeBytecodeToFile(module_op, os, config))) { + return absl::InternalError("Failed to serialize MLIR module to bytecode."); + } + std::string compressed_bytecode; + WritableStringFile f(&compressed_bytecode); + + tsl::io::ZlibCompressionOptions options = + tsl::io::ZlibCompressionOptions::GZIP(); + tsl::io::ZlibOutputBuffer buffer(&f, options.input_buffer_size, + options.output_buffer_size, options); + TF_RETURN_IF_ERROR(buffer.Init()); + TF_RETURN_IF_ERROR(buffer.Append(bytecode)); + 
TF_RETURN_IF_ERROR(buffer.Close()); + return compressed_bytecode; +} std::string SerializeMlirModule(mlir::ModuleOp module_op) { + if (GetMlirCommonFlags()->tf_serialize_mlir_to_compressed_bytecode) { + auto compressed_bytecode = + SerializeMlirModuleToCompressedBytecode(module_op); + if (compressed_bytecode.ok()) { + return compressed_bytecode.value(); + } + LOG_IF(ERROR, !compressed_bytecode.ok()) + << "Failed to serialize MLIR module to " + "compressed bytecode." + << compressed_bytecode.status(); + return ""; + } std::string serialized_mlir_module; llvm::raw_string_ostream os(serialized_mlir_module); mlir::OpPrintingFlags print_flags; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h index 78c7fb6c3857b3..4e264c5f566a9c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h @@ -18,10 +18,13 @@ limitations under the License. #include +#include "absl/status/statusor.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project namespace tensorflow { - +// Serializes a MLIR module `module_op` to a compressed bytecode string. +absl::StatusOr SerializeMlirModuleToCompressedBytecode( + mlir::ModuleOp module_op); // Prints a MLIR module `module_op` and returns it as a string. std::string SerializeMlirModule(mlir::ModuleOp module_op); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils_test.cc index d4f7eb11f81dff..d373e38cbaacbf 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils_test.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/deserialize_mlir_module_utils.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -42,5 +43,17 @@ TEST(SerializeMlirModuleUtilsTest, DebugInfoSerialization) { EXPECT_FALSE(absl::StrContains(serialized_module, "loc(")); } +TEST(SerializeMlirModuleUtilsTest, CompressedBytecodeSerializationRoundTrip) { + GetMlirCommonFlags()->tf_serialize_mlir_to_compressed_bytecode = true; + mlir::MLIRContext context; + mlir::OwningOpRef module_ref = + mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); + std::string mlir_module_str = tensorflow::SerializeMlirModule(*module_ref); + mlir::OwningOpRef deserialized_module; + EXPECT_TRUE(tensorflow::DeserializeMlirModule(mlir_module_str, &context, + &deserialized_module) + .ok()); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc index c9a6f6e85c9d4d..c1479fead3a595 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc @@ -133,7 +133,7 @@ absl::Status SetTypeAttribute(absl::string_view name, ContainerT types, type_list.add_type(dtype); } - auto result = values->insert({string(name), value}); + auto result = values->insert({std::string(name), value}); assert(result.second && "cannot have multiple attributes with the same name"); (void)result; @@ -164,7 +164,7 @@ void SetShapeAttribute(absl::string_view name, ContainerT shapes, // If shape is already set, override it. This can happen if we import // without shape inference enabled and so couldn't be removed on import and // are not explicitly dropped later. 
- (*values)[string(name)] = value; + (*values)[std::string(name)] = value; } // Collects all the unregistered attributes for an TF dialect operation. diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc index 8cb797a9a9b214..b13e099fde3557 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc @@ -214,7 +214,7 @@ absl::StatusOr> BuildConstOpGraphWithOutputShapes() { std::initializer_list dims = {2, 3, 4, 5}; Tensor tensor(data_type, TensorShape(dims)); for (int i = 0; i < 2 * 3 * 4 * 5; ++i) { - tensor.flat()(i) = i; + tensor.flat()(i) = i; } NodeDef node; diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc index 46f7f5de1d0856..74b7304b745033 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc @@ -106,9 +106,9 @@ namespace { // Time the execution of kernels (in CPU cycles). Meant to be used as RAII. 
struct CompilationTimer { - uint64 start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle(); + uint64_t start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle(); - uint64 ElapsedCycles() { + uint64_t ElapsedCycles() { return profile_utils::CpuUtils::GetCurrentClockCycle() - start_cycles; } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc index 243f4333a88525..2ab0c3c619b292 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc @@ -4864,7 +4864,7 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { dilations_attr.template getValues().begin(), dilations_attr.template getValues().end()}; auto strides_attr = GetI64ElementsAttr(op.getStrides()); - std::vector strides{ + std::vector strides{ strides_attr.template getValues().begin(), strides_attr.template getValues().end()}; @@ -5064,7 +5064,7 @@ class ConvertConvBackpropFilterOp : public OpRewritePattern { dilations_attr.template getValues().begin(), dilations_attr.template getValues().end()}; auto strides_attr = GetI64ElementsAttr(op.getStrides()); - std::vector strides{ + std::vector strides{ strides_attr.template getValues().begin(), strides_attr.template getValues().end()}; diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD index 159bc8b17bc36b..a6ee4c3e1ffbd0 100644 --- a/tensorflow/compiler/mlir/tfr/BUILD +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -308,7 +308,7 @@ py_strict_library( "//tensorflow/python/framework:op_def_registry", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/util:tf_inspect", - "@pypi_gast//:pkg", + "@pypi//gast", ], ) @@ -339,7 +339,7 @@ py_strict_library( "//tensorflow/python/autograph/pyct:transpiler", "//tensorflow/python/framework:op_def_registry", "//tensorflow/python/util:tf_inspect", - "@pypi_gast//:pkg", + "@pypi//gast", ], ) diff --git 
a/tensorflow/compiler/mlir/tfr/utils/utils.cc b/tensorflow/compiler/mlir/tfr/utils/utils.cc index f9e70b228c0b71..ddff766c789450 100644 --- a/tensorflow/compiler/mlir/tfr/utils/utils.cc +++ b/tensorflow/compiler/mlir/tfr/utils/utils.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" #include "mlir/IR/Block.h" // from @llvm-project @@ -92,9 +93,9 @@ std::string GetComposeFuncName(StringRef tf_op_name) { } if (tf_op_name[i] == '.') { compose_func_name.push_back('_'); - } else if (tf_op_name[i] >= 'A' && tf_op_name[i] <= 'Z') { + } else if (llvm::isUpper(tf_op_name[i])) { compose_func_name.push_back('_'); - compose_func_name.push_back(tf_op_name[i] + 'a' - 'A'); + compose_func_name.push_back(llvm::toLower(tf_op_name[i])); } else { compose_func_name.push_back(tf_op_name[i]); } @@ -106,13 +107,13 @@ std::string GetTFOpName(StringRef compose_func_name) { std::string tf_op_name; bool after_underscore = false; for (int i = 0; i < compose_func_name.size(); ++i) { - if (compose_func_name[i] >= 'A' && compose_func_name[i] <= 'Z') { + if (llvm::isUpper(compose_func_name[i])) { // The field name must not contain uppercase letters. return {}; } if (after_underscore) { - if (compose_func_name[i] >= 'a' && compose_func_name[i] <= 'z') { - tf_op_name.push_back(compose_func_name[i] + 'A' - 'a'); + if (llvm::isLower(compose_func_name[i])) { + tf_op_name.push_back(llvm::toUpper(compose_func_name[i])); after_underscore = false; } else { // The character after a "_" must be a lowercase letter. 
diff --git a/tensorflow/compiler/mlir/tfrt/tests/sink_in_invariant_ops.mlir b/tensorflow/compiler/mlir/tfrt/tests/sink_in_invariant_ops.mlir index 42e2e7ccb5086a..d6a4b0c3fcbf97 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/sink_in_invariant_ops.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/sink_in_invariant_ops.mlir @@ -195,6 +195,28 @@ func.func @sink_in_stateful_call(%arg0: tensor {tf_saved_model.index_path = func.return %2 : tensor } +// Test VarHandleOp getting sinked when it is used by the called function and returned by the called function. + +// CHECK: func private @func_use_and_return_varhandle([[arg0:.+]]: tensor>>) +func.func private @func_use_and_return_varhandle(%arg0: tensor>>) -> (tensor, tensor>>) { + // CHECK: tf.VarHandleOp + // CHECK-NEXT: tf.ReadVariableOp + %0 = "tf.ReadVariableOp"(%arg0) {device = "cpu"} : (tensor>>) -> tensor + + func.return %0, %arg0 : tensor, tensor>> +} + +// CHECK-LABEL: func @sink_in_stateful_call_varhandle_return +func.func @sink_in_stateful_call_varhandle_return(%arg0: tensor {tf_saved_model.index_path = ["input"]}) -> (tensor {tf_saved_model.index_path = ["r"]}) + attributes {tf_saved_model.exported_names = ["test_sink_in_stateful_call_varhandle_return"]} { + // CHECK: tf.VarHandleOp + %0 = "tf.VarHandleOp"() {container = "", shared_name = "x"} : () -> tensor>> + // CHECK: "tf.StatefulPartitionedCall"(%0) + %1:2 = "tf.StatefulPartitionedCall"(%0) {device = "/CPU:0", config = "", config_proto = "", executor_type = "", f = @func_use_and_return_varhandle} : (tensor>>) -> (tensor, tensor>>) + %2 = "tf.AddV2"(%arg0, %1#0) {device = "/CPU:0"} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + // CHECK-LABEL: func @sink_in_if func.func @sink_in_if(%arg0: tensor {tf_saved_model.index_path = ["input"]}) -> (tensor {tf_saved_model.index_path = ["r"]}) attributes {tf_saved_model.exported_names = ["test_sink_in_if"]} { @@ -374,3 +396,54 @@ func.func @nested_sink_in_if(%arg: tensor {tf_saved_model.index_path = ["in } 
} + +// ----- + +module attributes {tf_saved_model.semantics} { + +// Test sinks crossing nested tf.While and BatchFunction, while the sinkable ops are only copied at the target. + +// CHECK-LABEL: func private @batched_function +func.func private @batched_function(%arg0: tensor>>) -> tensor + attributes {tf._input_shapes = [#tf_type.shape<1x3>, #tf_type.shape<*>], tf.signature.is_stateful} { + // CHECK: tf.VarHandleOp + // CHECK-NEXT: tf.ReadVariableOp + %1 = "tf.ReadVariableOp"(%arg0) {device = "/device:CPU:0"} : (tensor>>) -> tensor + %2 = "tf.Identity"(%1) {device = "/device:CPU:0"} : (tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: func private @while_cond_func +func.func private @while_cond_func( + %arg0: tensor, + %arg1: tensor, + %arg: tensor>>) -> tensor { + // CHECK: [[handle:%.*]] = "tf.VarHandleOp"() + // CHECK: "tf.ReadVariableOp"([[handle]]) + %0 = "tf.ReadVariableOp"(%arg) {device = "cpu"} : (tensor>>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: func private @while_body_func +func.func private @while_body_func( + %arg0: tensor, + %arg1: tensor, + %arg2: tensor>>) -> (tensor, tensor, tensor>>) { + // CHECK: "tf.BatchFunction"(%arg2) + %0 = "tf.BatchFunction"(%arg2) {allowed_batch_sizes = [6], batch_timeout_micros = 100000 : i64, batching_queue = "", container = "", device = "/device:CPU:0", enable_large_batch_splitting = false, f = @batched_function, max_batch_size = 6 : i64, max_enqueued_batches = 10 : i64, num_batch_threads = 1 : i64, operandSegmentSizes = array, shared_name = "batch/"} : (tensor>>) -> tensor + func.return %0, %arg0, %arg2 : tensor, tensor, tensor>> +} + +// CHECK-LABEL: func @nested_sink_in_while_and_batch_functions +func.func @nested_sink_in_while_and_batch_functions(%arg: tensor {tf_saved_model.index_path = ["input"]}) -> (tensor {tf_saved_model.index_path = ["r"]}) + attributes {tf_saved_model.exported_names = ["test_sink_in_while_and_batch_functions"]} { + // CHECK: [[handle:%.*]] = 
"tf.VarHandleOp"() + %handle = "tf.VarHandleOp"() {container = "", shared_name = "x"} : () -> tensor>> + // CHECK: [[cond:%.*]] = "tf.Const"() + %cond = "tf.Const"() {device = "/CPU:0", value = dense<0> : tensor} : () -> tensor + // CHECK: "tf.While"([[cond]], [[cond]], [[handle]]) + %x:3 = "tf.While"(%cond, %cond, %handle) {body = @while_body_func, cond = @while_cond_func, is_stateless = false, parallel_iterations = 10 : i64, shape_invariant} : (tensor, tensor, tensor>>) -> (tensor, tensor, tensor>>) + func.return %x#0 : tensor +} + +} diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc index 32898953f8973e..5340015658621a 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc @@ -81,10 +81,10 @@ class IfrtBackendCompilerTest : public ::testing::Test { } void verifyModules() { - absl::MutexLock l(&ServingExecutableRegistry::mu_); + absl::MutexLock l(ServingExecutableRegistry::mu_); for (const auto& [_, executable] : *ServingExecutableRegistry::executables_) { - absl::MutexLock l(&executable->mutex_); + absl::MutexLock l(executable->mutex_); executable->module_->walk([](mlir::func::FuncOp func) { ASSERT_FALSE(func->hasAttr("tfrt_ifrt_serving.program_id")); }); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h index c64672cdb10e69..9d0efd51791b87 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h @@ -25,6 +25,10 @@ namespace ifrt_serving { struct DtypeAndShape { tensorflow::DataType dtype; tensorflow::TensorShape shape; + + bool operator==(const DtypeAndShape& other) const { + return dtype == other.dtype && shape == other.shape; + } }; } // namespace ifrt_serving diff --git 
a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h index 2cb92cb8baac1f..6ff373d7ce0b43 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h @@ -77,6 +77,8 @@ class TfToHloCompiler { virtual absl::StatusOr Key(const Tf2HloArg& arg); virtual absl::StatusOr CompileTfToHlo(Tf2HloArg& arg); + + virtual bool IsXlaCompilationDisabled() const { return false; } }; } // namespace ifrt_serving diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc index cc59c9150da769..7f4a602b1330a6 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc @@ -906,10 +906,6 @@ void CreateFallbackInitializationFunction( builder.create( func_op.getLoc(), /*resultTypes=*/mlir::TypeRange{}, /*operands=*/mlir::ValueRange{}, op->getAttrs()); - } else { - // TODO: b/381849919 - Remove this log once the bug is fixed. - LOG_FIRST_N(WARNING, 100) - << "Skip creation of fallback kernel for op index " << op_index; } } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc index ddff1b2bde43f9..990b3da433c327 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc @@ -152,7 +152,6 @@ void CreateTFExecutorToTFPreInvariantOptimizationPipelineHelper( pm.addPass(mlir::createInlinerPass()); pm.addNestedPass( mlir::TF::CreateRemoveUnusedWhileResultsPass()); - pm.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional()); // Apply standard optimization after optimizing control flow ops. 
pm.addPass(mlir::createInlinerPass()); @@ -163,6 +162,7 @@ void CreateTFExecutorToTFPreInvariantOptimizationPipelineHelper( // by performing shape inference again after reference variable to resource // variable conversion. We should remove this after b/187876545 is fixed. pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + pm.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional()); pm.addNestedPass( mlir::TFDevice::CreateLaunchToDeviceAttributePass()); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc b/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc index 4615c521edb059..fddb217d4c57ee 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc @@ -49,15 +49,28 @@ bool IsSinkCandidate(mlir::Operation *op) { // Check if the op is allowed to be sinked. We are being conservative here to // whilelist very limited set of ops here. struct AllowSinkHelper { - explicit AllowSinkHelper(mlir::Operation *op, int arg_index) { + explicit AllowSinkHelper(mlir::Operation* sinked_op, mlir::Operation* user, + int arg_index) { if (llvm::isa(op)) { + mlir::TF::StatefulPartitionedCallOp>(user)) { allow_sink_to = true; callee_arg_index = arg_index; return; } - if (llvm::isa(op) && arg_index > 0) { + // We tend to limit this support on WhileOp to only VarHandleOp to satisfy + // IFRT lowering requirements. + // Sinking other invariants like ConstOp is error-prone because it requires + // non-trivial effort to avoid sinking Consts when they are used by cond + // function and we don't need such support. 
+ if (llvm::isa(sinked_op) && + llvm::isa(user)) { + allow_sink_to = true; + callee_arg_index = arg_index; + return; + } + + if (llvm::isa(user) && arg_index > 0) { + allow_sink_to = true; + callee_arg_index = arg_index - 1; + return; @@ -107,7 +120,8 @@ void FindSinkTarget( for (mlir::OpOperand &use : value.getUses()) { auto *user = use.getOwner(); - AllowSinkHelper helper(user, use.getOperandNumber()); + AllowSinkHelper helper(original.getDefiningOp(), user, + use.getOperandNumber()); if (helper.allow_sink_to) { auto values = FindValueInCallees(symbol_table, symbol_users, user, @@ -116,6 +130,14 @@ void FindSinkTarget( FindSinkTarget(symbol_table, symbol_users, original, value, targets); } } else if (value != original) { + // If the sinked op is directly used by ReturnOp, we don't sink it. + // One example is for tf.WhileOp, the input and output of the cond + // function and the body function must be the same. If the cond function + // has an input of type tf.VarHandleOp and it just returns the VarHandleOp, + // we don't need to sink it.
+ if (llvm::isa(user)) { + continue; + } targets[&use].insert(original); } } diff --git a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc index e8004f17a24b47..d6d93d9f2d6f34 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc @@ -204,11 +204,18 @@ absl::Status ConvertTfMlirToRuntimeExecutable( tensorflow::tf2xla::v2::RunFunctionTf2xlaClusteringBridge( module, /*is_supported_by_replicated_brige*/ true, /*is_in_fallback_enabled_mode=*/VLOG_IS_ON(1))); + if (VLOG_IS_ON(1)) { + tensorflow::DumpMlirOpToFile("after_tf2xla_clustering_bridge", module); + } TF_RETURN_IF_ERROR( tensorflow::tfrt_compiler::RunLowerClusterToRuntimeOpsPassPipeline( module, tsl::DeviceType(DEVICE_TPU_XLA_JIT))); + if (VLOG_IS_ON(1)) { + tensorflow::DumpMlirOpToFile("after_lower_cluster_to_runtime_ops", + module); + } TF_RETURN_IF_ERROR( tensorflow::tf2xla::v2::ExportFromTensorflowDialectToExecutor(module)); } else if (options.device_target == TfrtDeviceInfraTarget::kTfFallback) { diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD index b1a4a4b96f3b72..527a724c491b96 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD @@ -19,6 +19,7 @@ cc_library( srcs = ["mlir_to_bytecode.cc"], hdrs = ["mlir_to_bytecode.h"], deps = [ + "//tensorflow/compiler/mlir/tfrt/ir/mlrt:mlrt_ops", "//tensorflow/core/tfrt/mlrt/bytecode", "//tensorflow/core/tfrt/mlrt/bytecode:executable", "//tensorflow/core/tfrt/mlrt/bytecode:function", @@ -43,6 +44,7 @@ tf_cc_test( data = glob(["testdata/**"]), deps = [ ":mlir_to_bytecode", + "//tensorflow/compiler/mlir/tfrt/ir/mlrt:mlrt_ops", "//tensorflow/core/tfrt/mlrt/bytecode", "//tensorflow/core/tfrt/mlrt/bytecode:executable", "//tensorflow/core/tfrt/mlrt/interpreter:attribute_span", diff --git 
a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc index 52b1826f4a1f65..2324f958f19266 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc @@ -41,6 +41,7 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h" #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" #include "tensorflow/core/tfrt/mlrt/bytecode/executable.h" #include "tensorflow/core/tfrt/mlrt/bytecode/function.h" @@ -169,19 +170,26 @@ struct FunctionEmitterContext { struct RegInfo { int num_uses = 0; int id = -1; + bool persistent = false; // True if the register should not be freed }; int next_reg_id = 0; llvm::DenseMap register_table; std::vector free_regs; - int AssignRegId() { - if (free_regs.empty()) { + int AssignRegId(bool is_persistent) { + if (is_persistent) { + // Persistent types ALWAYS get a brand new ID. return next_reg_id++; } - int id = free_regs.back(); - free_regs.pop_back(); - return id; + + // Non-persistent types can reuse from free_regs. 
+ if (!free_regs.empty()) { + int id = free_regs.back(); + free_regs.pop_back(); + return id; + } + return next_reg_id++; } void FreeRegId(int id) { free_regs.push_back(id); } @@ -202,7 +210,7 @@ void EmitKernel(FunctionEmitterContext& function_context, auto iter = function_context.register_table.find(result); CHECK(iter != function_context.register_table.end()); // Crash Ok CHECK_EQ(iter->second.id, -1); // Crash Ok - iter->second.id = function_context.AssignRegId(); + iter->second.id = function_context.AssignRegId(iter->second.persistent); results.push_back(iter->second.id); } constructor.construct_results(results.size()) @@ -218,9 +226,12 @@ void EmitKernel(FunctionEmitterContext& function_context, int id = iter->second.id; CHECK_NE(id, -1); // Crash Ok last_uses.push_back(0); - if (--iter->second.num_uses == 0) { - function_context.FreeRegId(id); - last_uses.back() = 1; + auto& reg_info = iter->second; + if (!reg_info.persistent) { + if (--reg_info.num_uses == 0) { + function_context.FreeRegId(id); + last_uses.back() = 1; + } } arguments.push_back(id); } @@ -282,18 +293,23 @@ void EmitFunction(const ModuleEmitterContext& module_context, std::vector input_regs; input_regs.reserve(block.getNumArguments()); for (auto arg : block.getArguments()) { - int id = function_context.AssignRegId(); + bool persistent = mlir::isa(arg.getType()); + int id = function_context.AssignRegId(persistent); input_regs.push_back(id); register_table[arg] = {static_cast(std::distance(arg.getUses().begin(), arg.getUses().end())), - id}; + id, persistent}; } constructor.construct_input_regs(input_regs); for (auto& op : block) { for (auto result : op.getResults()) { - register_table[result] = {static_cast( - std::distance(result.getUses().begin(), result.getUses().end()))}; + bool persistent = + mlir::isa(result.getType()); + register_table[result] = { + static_cast( + std::distance(result.getUses().begin(), result.getUses().end())), + -1, persistent}; } } diff --git 
a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc index b69f53cc9c7a2c..53f2e7591c8a9a 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h" #include "xla/tsl/platform/resource_loader.h" #include "xla/tsl/platform/status_matchers.h" #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" @@ -45,8 +46,6 @@ namespace { using ::testing::ElementsAreArray; using ::testing::FloatEq; using ::testing::IsEmpty; -using ::tsl::testing::IsOkAndHolds; -using ::tsl::testing::StatusIs; TEST(MlirToByteCodeTest, Basic) { constexpr char kBasicMlir[] = @@ -147,16 +146,20 @@ TEST(MlirToByteCodeTest, BasicAttributes) { EXPECT_EQ(*attr_iter, "ts"); ++attr_iter; - EXPECT_THAT(DecodeAttribute(*attr_iter), IsOkAndHolds(100)); + EXPECT_THAT(DecodeAttribute(*attr_iter), + absl_testing::IsOkAndHolds(100)); ++attr_iter; - EXPECT_THAT(DecodeAttribute(*attr_iter), IsOkAndHolds(200)); + EXPECT_THAT(DecodeAttribute(*attr_iter), + absl_testing::IsOkAndHolds(200)); ++attr_iter; - EXPECT_THAT(DecodeAttribute(*attr_iter), IsOkAndHolds(FloatEq(3.0))); + EXPECT_THAT(DecodeAttribute(*attr_iter), + absl_testing::IsOkAndHolds(FloatEq(3.0))); ++attr_iter; - EXPECT_THAT(DecodeAttribute(*attr_iter), IsOkAndHolds(0)); + EXPECT_THAT(DecodeAttribute(*attr_iter), + absl_testing::IsOkAndHolds(0)); ++attr_iter; bc::Vector list_of_i64((*attr_iter).data()); @@ -171,7 +174,8 @@ TEST(MlirToByteCodeTest, BasicAttributes) { EXPECT_THAT(list_of_str, ElementsAreArray({"string 0", "string 1"})); ++attr_iter; - EXPECT_THAT(DecodeAttribute(*attr_iter), 
IsOkAndHolds(1)); + EXPECT_THAT(DecodeAttribute(*attr_iter), + absl_testing::IsOkAndHolds(1)); EXPECT_EQ(executable.functions()[1].name().Get(), "callee"); ++attr_iter; @@ -272,9 +276,10 @@ TEST(MlirToByteCodeTest, UnsupportedAttributes) { &mlir_context); AttributeEncoderRegistry attribute_encoder_registry; - EXPECT_THAT(EmitExecutable(attribute_encoder_registry, mlir_module.get()), - StatusIs(absl::StatusCode::kInvalidArgument, - "Try to encode unsupported attribute: unit")); + EXPECT_THAT( + EmitExecutable(attribute_encoder_registry, mlir_module.get()), + absl_testing::StatusIs(absl::StatusCode::kInvalidArgument, + "Try to encode unsupported attribute: unit")); } class CustomDense { @@ -378,5 +383,129 @@ TEST(MlirToByteCodeTest, CustomDense) { } } +TEST(MlirToByteCodeTest, AsyncNotFreed) { + constexpr char kAsyncMlir[] = + "tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/async.mlir"; + + mlir::DialectRegistry registry; + registry.insert(); + mlir::MLIRContext mlir_context(registry); + mlir_context.allowUnregisteredDialects(); + auto mlir_module = mlir::parseSourceFile( + tsl::GetDataDependencyFilepath(kAsyncMlir), &mlir_context); + + AttributeEncoderRegistry attribute_encoder_registry; + bc::Buffer buffer = + EmitExecutable(attribute_encoder_registry, mlir_module.get()).value(); + + bc::Executable executable(buffer.data()); + + auto kernel_names = executable.kernel_names(); + EXPECT_THAT(kernel_names, + ElementsAreArray({"test_mlbc.add.i32", "return", "mlrt.async", + "mlrt.await_handle"})); + + auto functions = executable.functions(); + ASSERT_EQ(functions.size(), 2); + + auto function = functions[1]; + EXPECT_EQ(function.name().str(), "main"); + EXPECT_EQ(function.num_regs(), 4); + EXPECT_THAT(function.input_regs(), ElementsAreArray({0, 1})); + EXPECT_THAT(function.output_regs(), ElementsAreArray({1})); + EXPECT_THAT(function.output_last_uses(), ElementsAreArray({true})); + + auto kernels = function.kernels(); + ASSERT_EQ(kernels.size(), 5); + + 
EXPECT_EQ(kernels[0].code(), 2); // mlrt.async + EXPECT_THAT(kernels[0].arguments(), ElementsAreArray({0, 1})); + // The returned handle is in register 2, which is never used by other kernels. + EXPECT_THAT(kernels[0].results(), ElementsAreArray({2})); + EXPECT_THAT(kernels[0].last_uses(), ElementsAreArray({false, false})); + + EXPECT_EQ(kernels[1].code(), 3); // mlrt.await_handle + EXPECT_THAT(kernels[1].arguments(), ElementsAreArray({2})); + EXPECT_THAT(kernels[1].results(), IsEmpty()); + + EXPECT_EQ(kernels[2].code(), 0); // test_mlbc.add.i32 + EXPECT_THAT(kernels[2].arguments(), ElementsAreArray({0, 1})); + EXPECT_THAT(kernels[2].results(), ElementsAreArray({3})); + EXPECT_THAT(kernels[2].last_uses(), ElementsAreArray({true, true})); + + EXPECT_EQ(kernels[3].code(), 0); // test_mlbc.add.i32 + EXPECT_THAT(kernels[3].arguments(), ElementsAreArray({3, 3})); + EXPECT_THAT(kernels[3].results(), ElementsAreArray({1})); + EXPECT_THAT(kernels[3].last_uses(), ElementsAreArray({false, true})); + + EXPECT_EQ(kernels[4].code(), 1); // return + EXPECT_THAT(kernels[4].arguments(), ElementsAreArray({1})); + EXPECT_THAT(kernels[4].results(), IsEmpty()); + EXPECT_THAT(kernels[4].last_uses(), ElementsAreArray({true})); +} + +TEST(MlirToByteCodeTest, AsyncUseNewId) { + constexpr char kAsyncMlir[] = + "tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/async2.mlir"; + + mlir::DialectRegistry registry; + registry.insert(); + mlir::MLIRContext mlir_context(registry); + mlir_context.allowUnregisteredDialects(); + auto mlir_module = mlir::parseSourceFile( + tsl::GetDataDependencyFilepath(kAsyncMlir), &mlir_context); + + AttributeEncoderRegistry attribute_encoder_registry; + bc::Buffer buffer = + EmitExecutable(attribute_encoder_registry, mlir_module.get()).value(); + + bc::Executable executable(buffer.data()); + + auto kernel_names = executable.kernel_names(); + EXPECT_THAT(kernel_names, + ElementsAreArray({"test_mlbc.add.i32", "return", "mlrt.async", + "mlrt.await_handle"})); + + 
auto functions = executable.functions(); + ASSERT_EQ(functions.size(), 2); + + auto function = functions[1]; + EXPECT_EQ(function.name().str(), "main"); + EXPECT_EQ(function.num_regs(), 4); + EXPECT_THAT(function.input_regs(), ElementsAreArray({0, 1})); + EXPECT_THAT(function.output_regs(), ElementsAreArray({1})); + EXPECT_THAT(function.output_last_uses(), ElementsAreArray({true})); + + auto kernels = function.kernels(); + ASSERT_EQ(kernels.size(), 5); + + EXPECT_EQ(kernels[0].code(), 0); // test_mlbc.add.i32 + EXPECT_THAT(kernels[0].arguments(), ElementsAreArray({0, 1})); + EXPECT_THAT(kernels[0].results(), ElementsAreArray({2})); + EXPECT_THAT(kernels[0].last_uses(), ElementsAreArray({true, true})); + + EXPECT_EQ(kernels[1].code(), 2); // mlrt.async + EXPECT_THAT(kernels[1].arguments(), ElementsAreArray({2, 2})); + // The returned handle is in register 3, which is never used by other kernels. + EXPECT_THAT(kernels[1].results(), ElementsAreArray({3})); + EXPECT_THAT(kernels[1].last_uses(), ElementsAreArray({false, false})); + + EXPECT_EQ(kernels[2].code(), 3); // mlrt.await_handle + EXPECT_THAT(kernels[2].arguments(), ElementsAreArray({3})); + EXPECT_THAT(kernels[2].results(), IsEmpty()); + EXPECT_THAT(kernels[2].last_uses(), ElementsAreArray({false})); + + EXPECT_EQ(kernels[3].code(), 0); // test_mlbc.add.i32 + EXPECT_THAT(kernels[3].arguments(), ElementsAreArray({2, 2})); + // AsyncHandle does not free its register. So this can only use 1. 
+ EXPECT_THAT(kernels[3].results(), ElementsAreArray({1})); + EXPECT_THAT(kernels[3].last_uses(), ElementsAreArray({false, true})); + + EXPECT_EQ(kernels[4].code(), 1); // return + EXPECT_THAT(kernels[4].arguments(), ElementsAreArray({1})); + EXPECT_THAT(kernels[4].results(), IsEmpty()); + EXPECT_THAT(kernels[4].last_uses(), ElementsAreArray({true})); +} + } // namespace } // namespace mlrt diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/async.mlir b/tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/async.mlir new file mode 100644 index 00000000000000..f3816531218c81 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/async.mlir @@ -0,0 +1,15 @@ +func.func @add_i32(%arg0: i32, %arg1: i32) -> i32 { + %0 = "test_mlbc.add.i32"(%arg0, %arg1) : (i32, i32) -> i32 + func.return %0 : i32 +} + +func.func @main(%arg0: i32, %arg1: i32) -> i32 { + %handle = "mlrt.async"(%arg0, %arg1) {callee = @add_i32} : (i32, i32) -> !mlrt.async_handle + + "mlrt.await_handle"(%handle) : (!mlrt.async_handle) -> () + + %c1 = "test_mlbc.add.i32"(%arg0, %arg1) : (i32, i32) -> i32 + %c2 = "test_mlbc.add.i32"(%c1, %c1) : (i32, i32) -> i32 + + func.return %c2 : i32 +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/async2.mlir b/tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/async2.mlir new file mode 100644 index 00000000000000..c960fedd2adc25 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/async2.mlir @@ -0,0 +1,16 @@ +func.func @add_i32(%arg0: i32, %arg1: i32) -> i32 { + %0 = "test_mlbc.add.i32"(%arg0, %arg1) : (i32, i32) -> i32 + func.return %0 : i32 +} + +func.func @main(%arg0: i32, %arg1: i32) -> i32 { + %c1 = "test_mlbc.add.i32"(%arg0, %arg1) : (i32, i32) -> i32 + + %handle = "mlrt.async"(%c1, %c1) {callee = @add_i32} : (i32, i32) -> !mlrt.async_handle + + "mlrt.await_handle"(%handle) : (!mlrt.async_handle) -> () + + %c2 = "test_mlbc.add.i32"(%c1, %c1) : (i32, i32) -> 
i32 + + func.return %c2 : i32 +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc index 7b3625103efc1f..079500b8cd1ccf 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc @@ -163,9 +163,11 @@ class GpuKernelToBlobPass target->Options.AllowFPOpFusion = llvm::FPOpFusion::FPOpFusionMode::Fast; }; - TF_ASSIGN_OR_RETURN(std::string ptx, xla::gpu::nvptx::CompileToPtx( - llvm_module_copy.get(), cc, - options, enable_fusion)); + TF_ASSIGN_OR_RETURN( + std::string ptx, + xla::gpu::nvptx::CompileToPtx( + llvm_module_copy.get(), stream_executor::GpuComputeCapability(cc), + options, enable_fusion)); if (print_ptx_) { llvm::dbgs() << "Generated PTX code for module '" << gpu_module.getName() << "' on architecture sm_" << arch diff --git a/tensorflow/compiler/mlir/tosa/tfl_passes.h b/tensorflow/compiler/mlir/tosa/tfl_passes.h index 96d3cabf0c1f1f..02bd007f6fa36c 100644 --- a/tensorflow/compiler/mlir/tosa/tfl_passes.h +++ b/tensorflow/compiler/mlir/tosa/tfl_passes.h @@ -42,8 +42,8 @@ struct TOSATFLLegalizationPipelineOptions llvm::cl::desc("Dequantize the TFLite softmax"), llvm::cl::init(false)}; TOSATFLLegalizationPipelineOptions() { - disabled_patterns = std::nullopt; - enabled_patterns = std::nullopt; + disabled_patterns = {}; + enabled_patterns = {}; } }; diff --git a/tensorflow/compiler/mlir/tosa/transforms/passes.h b/tensorflow/compiler/mlir/tosa/transforms/passes.h index de0872b660d4ec..0475d46a37a091 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/passes.h +++ b/tensorflow/compiler/mlir/tosa/transforms/passes.h @@ -53,8 +53,8 @@ std::unique_ptr> createFuseBiasTFPass(); // `enabledPatterns` is a set of labels used to filter out input patterns that // do not have one of the labels in this set. 
std::unique_ptr> createLegalizeTFLPass( - ArrayRef disabled_patterns = std::nullopt, - ArrayRef enabled_patterns = std::nullopt); + ArrayRef disabled_patterns = {}, + ArrayRef enabled_patterns = {}); std::unique_ptr> createRetainCallOnceFuncsPass(); std::unique_ptr> createStripModuleMetadataPass(); diff --git a/tensorflow/compiler/mlir/utils/name_utils.cc b/tensorflow/compiler/mlir/utils/name_utils.cc index fd50116ba7d1a7..fb5bb77644c211 100644 --- a/tensorflow/compiler/mlir/utils/name_utils.cc +++ b/tensorflow/compiler/mlir/utils/name_utils.cc @@ -31,8 +31,8 @@ namespace { // Checks if a character is legal for a TensorFlow node name, with special // handling if a character is at the beginning. bool IsLegalChar(char c, bool first_char) { - if (isalpha(c)) return true; - if (isdigit(c)) return true; + if (llvm::isAlpha(c)) return true; + if (llvm::isDigit(c)) return true; if (c == '.') return true; if (c == '_') return true; diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index f957ec4b08e322..995ae2b5740ae7 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -2,7 +2,14 @@ load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_cuda_cc_test") load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow/compiler/tests:build_combined_defs.bzl", "tf_xla_combined_py_test") -load("//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites", "tf_xla_py_strict_test") +load( + "//tensorflow/compiler/tests:build_defs.bzl", + "generate_backend_suites", + "tf_xla_py_strict_test", + # copybara:uncomment_begin(google-only) + # "tpu_backends", + # copybara:uncomment_end +) load( "//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags", @@ -93,6 +100,7 @@ py_strict_test( tf_xla_combined_py_test( name = "combined_ops_test_a", size = "medium", + timeout = "long", package = 
"tensorflow.compiler.tests", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -213,9 +221,8 @@ tf_xla_combined_py_test( name = "combined_ops_test_f", size = "medium", timeout = "long", - # copybara:uncomment_begin - # #TODO(b/286470564): Remove once the bug is fixed. - # disable_tpu_tfrt = True, + # copybara:uncomment_begin(google-only) + # disabled_backends = tpu_backends(), # copybara:uncomment_end exec_properties = { "cpp_link.mem": "16g", @@ -340,10 +347,6 @@ tf_xla_py_strict_test( name = "add_n_test", size = "small", srcs = ["add_n_test.py"], - # copybara:uncomment_begin - # #TODO(b/286470564): Remove once the bug is fixed. - # disable_tpu_tfrt = True, - # copybara:uncomment_end tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "notap", @@ -496,10 +499,6 @@ tf_xla_py_strict_test( name = "cond_test", size = "small", srcs = ["cond_test.py"], - # copybara:uncomment_begin - # #TODO(b/286470564): Remove once the bug is fixed. - # disable_tpu_tfrt = True, - # copybara:uncomment_end tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "notap", @@ -1746,12 +1745,8 @@ tf_xla_py_strict_test( name = "tensor_list_ops_test", size = "small", srcs = ["tensor_list_ops_test.py"], - # copybara:uncomment_begin - # #TODO(b/286470564): Remove once the bug is fixed. - # disable_tpu_tfrt = True, - # copybara:uncomment_end - # TensorList ops are not implemented in the on-demand compilation model yet. - disabled_backends = ["cpu_ondemand"], + # TensorList ops are only implemented on CPU. + enabled_backends = ["cpu"], tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], @@ -1906,10 +1901,6 @@ tf_xla_py_strict_test( name = "while_test", size = "small", srcs = ["while_test.py"], - # copybara:uncomment_begin - # #TODO(b/291130193): Remove once the bug is fixed. 
- # disable_tpu_tfrt = True, - # copybara:uncomment_end tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "notap", @@ -2082,7 +2073,7 @@ tf_xla_py_strict_test( tf_xla_py_strict_test( name = "xla_device_test", - size = "small", + size = "medium", srcs = ["xla_device_test.py"], tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -2166,7 +2157,6 @@ tf_xla_py_strict_test( "gpu_a100", "gpu_h100", ], - env = {"XLA_FLAGS": "--xla_backend_extra_options=xla_cpu_disable_new_fusion_emitters=true"}, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], @@ -2430,9 +2420,6 @@ tf_xla_py_strict_test( name = "where_op_tpu_test", size = "small", srcs = ["where_op_test.py"], - args = [ - "--tpu_use_tfrt=true", - ], disabled_backends = [ "cpu", "cpu_ondemand", diff --git a/tensorflow/compiler/tests/cast_test.py b/tensorflow/compiler/tests/cast_test.py index bc35db4e05f7d5..453cbeb1312648 100644 --- a/tensorflow/compiler/tests/cast_test.py +++ b/tensorflow/compiler/tests/cast_test.py @@ -35,9 +35,10 @@ def test_cast(self): dtypes.uint32, dtypes.uint64, } - for src_type in types: - for dst_type in types: - self._test_cast(src_type, dst_type) + with self.session() as session: + for src_type in types: + for dst_type in types: + self._test_cast(src_type, dst_type, session) def test_cast_fp8(self): if platform.system() == "Darwin": @@ -61,12 +62,13 @@ def test_cast_fp8(self): dtypes.uint32, dtypes.uint64, } - for fp8_type in fp8_types: - for other_type in other_types | fp8_types: - self._test_cast(fp8_type, other_type) - self._test_cast(other_type, fp8_type) + with self.session() as session: + for fp8_type in fp8_types: + for other_type in other_types | fp8_types: + self._test_cast(fp8_type, other_type, session) + self._test_cast(other_type, fp8_type, session) - def _test_cast(self, src_type, dst_type): + def _test_cast(self, src_type, dst_type, session): with 
self.subTest(src_type=src_type, dst_type=dst_type): shapes = [[], [4], [2, 3], [2, 0, 4]] src_np_dtype = src_type.as_numpy_dtype @@ -83,6 +85,7 @@ def _test_cast(self, src_type, dst_type): lambda x, dst_type=dst_type: math_ops.cast(x, dst_type), src, expected=dst, + local_session=session, ) # Check special values. @@ -112,6 +115,7 @@ def _test_cast(self, src_type, dst_type): lambda x, dst_type=dst_type: math_ops.cast(x, dst_type), src, expected=dst, + local_session=session, ) def test_give_me_a_name(self): diff --git a/tensorflow/compiler/tests/float_ops_test.py b/tensorflow/compiler/tests/float_ops_test.py index d8743016c20756..67a1ecc967f24c 100644 --- a/tensorflow/compiler/tests/float_ops_test.py +++ b/tensorflow/compiler/tests/float_ops_test.py @@ -23,449 +23,522 @@ class FloatOpsTest(xla_test.XLATestCase): def test_float_ops(self): - for dtype in self.float_types: - x = np.arange(-0.90, 0.90, 0.25) - self.assert_op_output_matches_expected( - math_ops.acos, x.astype(dtype), expected=np.arccos(x).astype(dtype) - ) - self.assert_op_output_matches_expected( - math_ops.asin, x.astype(dtype), expected=np.arcsin(x).astype(dtype) - ) - x = np.arange(-3, 3).reshape(1, 3, 2) - self.assert_op_output_matches_expected( - math_ops.atan, x.astype(dtype), expected=np.arctan(x).astype(dtype) - ) - - self.assert_op_output_matches_expected( - math_ops.acosh, - np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array( - [0, 1.3169579, 1.76274717, 2.06343707], dtype=dtype - ), - ) - - self.assert_op_output_matches_expected( - math_ops.asinh, - np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array( - [0.88137359, 1.44363548, 1.81844646, 2.09471255], dtype=dtype - ), - ) - - self.assert_op_output_matches_expected( - math_ops.atanh, - np.array([0.1, 0.2, 0.3, 0.4], dtype=dtype), - expected=np.array( - [0.10033535, 0.20273255, 0.3095196, 0.42364893], dtype=dtype - ), - ) - - self.assert_op_output_matches_expected( - math_ops.ceil, - np.array([[-1.7, 1.2]], dtype=dtype), - 
expected=np.array([[-1, 2]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - math_ops.cosh, - np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array( - [1.54308063, 3.76219569, 10.067662, 27.30823284], dtype=dtype - ), - ) - - # Disable float16 testing for now - if dtype != np.float16: - x = np.arange(-10, 10, 1).astype(dtype) - with self.session() as session: + with self.session() as session: + for dtype in self.float_types: + x = np.arange(-0.90, 0.90, 0.25) + self.assert_op_output_matches_expected( + math_ops.acos, + x.astype(dtype), + expected=np.arccos(x).astype(dtype), + local_session=session, + ) + self.assert_op_output_matches_expected( + math_ops.asin, + x.astype(dtype), + expected=np.arcsin(x).astype(dtype), + local_session=session, + ) + x = np.arange(-3, 3).reshape(1, 3, 2) + self.assert_op_output_matches_expected( + math_ops.atan, + x.astype(dtype), + expected=np.arctan(x).astype(dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.acosh, + np.array([1, 2, 3, 4], dtype=dtype), + expected=np.array( + [0, 1.3169579, 1.76274717, 2.06343707], dtype=dtype + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.asinh, + np.array([1, 2, 3, 4], dtype=dtype), + expected=np.array( + [0.88137359, 1.44363548, 1.81844646, 2.09471255], dtype=dtype + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.atanh, + np.array([0.1, 0.2, 0.3, 0.4], dtype=dtype), + expected=np.array( + [0.10033535, 0.20273255, 0.3095196, 0.42364893], dtype=dtype + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.ceil, + np.array([[-1.7, 1.2]], dtype=dtype), + expected=np.array([[-1, 2]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.cosh, + np.array([1, 2, 3, 4], dtype=dtype), + expected=np.array( + [1.54308063, 3.76219569, 10.067662, 27.30823284], dtype=dtype + ), + 
local_session=session, + ) + + # Disable float16 testing for now + if dtype != np.float16: + x = np.arange(-10, 10, 1).astype(dtype) erf_x = session.run(math_ops.erf(x)) erfc_x = session.run(math_ops.erfc(x)) - self.assert_op_output_matches_expected(math_ops.erf, x, expected=erf_x) - self.assert_op_output_matches_expected( - math_ops.erfc, x, expected=erfc_x - ) - - self.assert_op_output_matches_expected( - math_ops.exp, - np.array([[-1, 1]], dtype=dtype), - expected=np.array([[0.36787945, 2.7182817]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - math_ops.expm1, - np.array([[-1, 1]], dtype=dtype), - expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype), - rtol=1e-5, - ) - - self.assert_op_output_matches_expected( - math_ops.floor, - np.array([[-1.7, 1.2]], dtype=dtype), - expected=np.array([[-2, 1]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - math_ops.is_finite, - np.array( - [[-np.inf, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype - ), - expected=np.array([[0, 1, 1, 1, 1, 1, 1, 0, 0]], dtype=np.bool_), - ) - - # Tests for tf.nn ops. 
- self.assert_op_output_matches_expected( - nn_ops.l2_loss, np.array([[[]]], dtype=dtype), expected=dtype(0) - ) - - self.assert_op_output_matches_expected(nn_ops.l2_loss, dtype(4), dtype(8)) - - self.assert_op_output_matches_expected( - nn_ops.l2_loss, np.array([[-2, 4]], dtype=dtype), expected=dtype(10) - ) - - self.assert_op_output_matches_expected( - math_ops.reciprocal, - np.array([[1, 2]], dtype=dtype), - expected=np.array([[1, 0.5]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - math_ops.log, - np.array([[1, 2]], dtype=dtype), - expected=np.array([[0, 0.69314718]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - math_ops.sin, - np.array([[1, 2]], dtype=dtype), - expected=np.array([[0.841478, 0.909302]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - math_ops.cos, - np.array([[1, 2]], dtype=dtype), - expected=np.array([[0.540297, -0.41614]], dtype=dtype), - ) - - # Confirm that log1p will remain precise across a range of small values. 
- self.assert_op_output_matches_expected( - math_ops.log1p, - np.array( - [[1e-14, 1e-15, 0.6, 2] + [x * 1e-5 for x in range(1, 20)]], - dtype=dtype, - ), - expected=np.log1p( - np.array( - [[1e-14, 1e-15, 0.6, 2] + [x * 1e-5 for x in range(1, 20)]], - dtype=dtype, - ) - ).astype(dtype), - rtol=1e-15 if dtype == np.float64 else 1e-4, - atol=1e-15 if dtype == np.float64 else 1e-4, - ) - - self.assert_op_output_matches_expected( - math_ops.rint, - np.array( - [ - [-1.7, 1.2, 4.0, 0.0], - [-3.5, -2.5, -1.5, -0.5], - [0.5, 1.5, 2.5, 3.5], - ], - dtype=dtype, - ), - expected=np.array( - [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype - ), - ) - self.assert_op_output_matches_expected( - math_ops.round, - np.array( - [ - [-1.7, 1.2, 4.0, 0.0], - [-3.5, -2.5, -1.5, -0.5], - [0.5, 1.5, 2.5, 3.5], - ], - dtype=dtype, - ), - expected=np.array( - [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype - ), - ) - - self.assert_op_output_matches_expected( - math_ops.rsqrt, - np.array([[4, 16]], dtype=dtype), - expected=np.array([[0.5, 0.25]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - math_ops.sigmoid, - np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), - expected=np.array( - [ - [0.7310586, 0.7310586, 0.7310586, 0.7310586], - [0.7310586, 0.880797, 0.95257413, 0.98201376], - ], - dtype=dtype, - ), - ) - - self.assert_op_output_matches_expected( - math_ops.sigmoid, - np.array([-300, -150, 0, 150, 300], dtype=dtype), - expected=np.array([0, 0, 0.5, 1, 1], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - math_ops.sinh, - np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array( - [1.17520119, 3.62686041, 10.01787493, 27.2899172], dtype=dtype - ), - ) - - self.assert_op_output_matches_expected( - math_ops.sqrt, - np.array([[4, 9]], dtype=dtype), - expected=np.array([[2, 3]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - math_ops.tan, - np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array( - [1.55740772, 
-2.18503986, -0.14254654, 1.15782128], dtype=dtype - ), - ) - - self.assert_op_output_matches_expected( - math_ops.tanh, - np.array( - [[1, 2, 3, 4], [np.inf, -np.inf, np.nan, 20], [19, -19, 22, -22]], - dtype=dtype, - ), - expected=np.array( - [ - [0.76159418, 0.96402758, 0.99505478, 0.99932933], - [1.0, -1.0, np.nan, 1.0], - [1.0, -1.0, 1.0, -1.0], - ], - dtype=dtype, - ), - ) - - self.assert_op_output_matches_expected( - nn_ops.log_softmax, - np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), - expected=np.array( - [ - [-1.3862944, -1.3862944, -1.3862944, -1.3862944], - [-3.4401896, -2.4401896, -1.4401897, -0.44018969], - ], - dtype=dtype, - ), - ) - - self.assert_op_output_matches_expected( - nn_ops.elu, - np.array([[-1, 0, 1, -1e-6]], dtype=dtype), - expected=np.array([[-0.63212056, 0, 1, -9.999995e-07]], dtype=dtype), - rtol=1e-5, - atol=1e-6, - ) - - self.assert_op_output_matches_expected( - nn_ops.selu, - np.array([[-1, 0, 1, -1e-5]], dtype=dtype), - expected=np.array( - [[-1.11133074, 0.0, 1.05070099, -1.758090550379974e-05]], - dtype=dtype, - ), - rtol=1e-5, - atol=1e-6, - ) - - self.assert_op_output_matches_expected( - nn_ops.relu, - np.array([[-1, 1]], dtype=dtype), - expected=np.array([[0, 1]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - nn_ops.relu6, - np.array([[-0.05, 6.05, 5]], dtype=dtype), - expected=np.array([[0, 6, 5]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - nn_ops.leaky_relu, - np.array([[-2, -1, 0, 1, 2]], dtype=dtype), - expected=np.array([[-0.4, -0.2, 0.0, 1.0, 2.0]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - nn_ops.softmax, - np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array( - [0.032058604, 0.087144323, 0.23688284, 0.64391428], dtype=dtype - ), - ) - - self.assert_op_output_matches_expected( - nn_ops.softmax, - np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), - expected=np.array( - [ - [0.25, 0.25, 0.25, 0.25], - [0.032058604, 0.087144323, 0.23688284, 
0.64391428], - ], - dtype=dtype, - ), - ) - - self.assert_op_output_matches_expected( - nn_ops.softmax, - np.array([[[1, 1], [1, 1]], [[1, 2], [3, 4]]], dtype=dtype), - expected=np.array( - [ - [[0.5, 0.5], [0.5, 0.5]], - [[0.26894142, 0.73105858], [0.26894142, 0.73105858]], - ], - dtype=dtype, - ), - ) - - self.assert_op_output_matches_expected( - nn_ops.softsign, - np.array([[-2, -1, 0, 1, 2]], dtype=dtype), - expected=np.array( - [[-0.66666669, -0.5, 0, 0.5, 0.66666669]], dtype=dtype - ), - ) - - self.assert_op_output_matches_expected( - math_ops.sign, - np.array( - [[-2.0, -1.0, -0.0, +0.0, 1.0, 2.0, float("nan")]], dtype=dtype - ), - expected=np.array( - [[-1.0, -1.0, -0.0, +0.0, 1.0, 1.0, float("nan")]], dtype=dtype - ), - ) - - self.assert_op_output_matches_expected( - math_ops.is_finite, - np.array( - [[42, float("inf"), -123], [float("nan"), 0, -0.0]], dtype=dtype - ), - expected=np.array( - [[True, False, True], [False, True, True]], dtype=np.bool_ - ), - ) - - self.assert_op_output_matches_expected( - math_ops.lgamma, - np.array(0.5, dtype=dtype), - expected=np.array(np.log(np.pi) / 2, dtype=dtype), - ) - - self.assert_op_output_matches_expected( - math_ops.lgamma, - np.array( - [ - [1, 2, 3], - [4, 5, 6], - [1 / 2, 3 / 2, 5 / 2], - [-3 / 2, -7 / 2, -11 / 2], - ], - dtype=dtype, - ), - expected=np.array( - [ - [0, 0, np.log(2.0)], - [np.log(6.0), np.log(24.0), np.log(120)], - [ - np.log(np.pi) / 2, - np.log(np.pi) / 2 - np.log(2), - np.log(np.pi) / 2 - np.log(4) + np.log(3), - ], - [ - np.log(np.pi) / 2 - np.log(3) + np.log(4), - np.log(np.pi) / 2 - np.log(105) + np.log(16), - np.log(np.pi) / 2 - np.log(10395) + np.log(64), - ], - ], - dtype=dtype, - ), - ) - - # The actual result is complex. Take the real part. 
- self.assert_op_output_matches_expected( - math_ops.lgamma, - np.array([-1 / 2, -5 / 2, -9 / 2], dtype=dtype), - expected=np.array( - [ - np.log(np.pi) / 2 + np.log(2), - np.log(np.pi) / 2 - np.log(15) + np.log(8), - np.log(np.pi) / 2 - np.log(945) + np.log(32), - ], - dtype=dtype, - ), - atol=1e-4, - ) - - self.assert_op_output_matches_expected( - math_ops.digamma, - np.array( - [ - [1.0, 0.5, 1 / 3.0], - [0.25, 1 / 6.0, 0.125], - [2.0, 3.0, 4.0], - [6.0, 8.0, 9.0], - ], - dtype=dtype, - ), - expected=np.array( - [ - [ - -np.euler_gamma, - -2 * np.log(2) - np.euler_gamma, - -np.pi / 2 / np.sqrt(3) - - 3 * np.log(3) / 2 - - np.euler_gamma, - ], - [ - -np.pi / 2 - 3 * np.log(2) - np.euler_gamma, - -np.pi * np.sqrt(3) / 2 - - 2 * np.log(2) - - 3 * np.log(3) / 2 - - np.euler_gamma, - -np.pi / 2 - - 4 * np.log(2) - - ( - np.pi - + np.log(2 + np.sqrt(2)) - - np.log(2 - np.sqrt(2)) - ) - / np.sqrt(2) - - np.euler_gamma, - ], - [ - 1 - np.euler_gamma, - 1.5 - np.euler_gamma, - 11 / 6.0 - np.euler_gamma, - ], - [ - 137 / 60.0 - np.euler_gamma, - 363 / 140.0 - np.euler_gamma, - 761 / 280.0 - np.euler_gamma, - ], - ], - dtype=dtype, - ), - ) + self.assert_op_output_matches_expected( + math_ops.erf, + x, + expected=erf_x, + local_session=session, + ) + self.assert_op_output_matches_expected( + math_ops.erfc, + x, + expected=erfc_x, + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.exp, + np.array([[-1, 1]], dtype=dtype), + expected=np.array([[0.36787945, 2.7182817]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.expm1, + np.array([[-1, 1]], dtype=dtype), + expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype), + local_session=session, + rtol=1e-5, + ) + + self.assert_op_output_matches_expected( + math_ops.floor, + np.array([[-1.7, 1.2]], dtype=dtype), + expected=np.array([[-2, 1]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + 
math_ops.is_finite, + np.array( + [[-np.inf, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype + ), + expected=np.array([[0, 1, 1, 1, 1, 1, 1, 0, 0]], dtype=np.bool_), + local_session=session, + ) + + # Tests for tf.nn ops. + self.assert_op_output_matches_expected( + nn_ops.l2_loss, + np.array([[[]]], dtype=dtype), + expected=dtype(0), + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.l2_loss, + dtype(4), + dtype(8), + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.l2_loss, + np.array([[-2, 4]], dtype=dtype), + expected=dtype(10), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.reciprocal, + np.array([[1, 2]], dtype=dtype), + expected=np.array([[1, 0.5]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.log, + np.array([[1, 2]], dtype=dtype), + expected=np.array([[0, 0.69314718]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.sin, + np.array([[1, 2]], dtype=dtype), + expected=np.array([[0.841478, 0.909302]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.cos, + np.array([[1, 2]], dtype=dtype), + expected=np.array([[0.540297, -0.41614]], dtype=dtype), + local_session=session, + ) + + # Confirm that log1p will remain precise across a range of small values. 
+ self.assert_op_output_matches_expected( + math_ops.log1p, + np.array( + [[1e-14, 1e-15, 0.6, 2] + [x * 1e-5 for x in range(1, 20)]], + dtype=dtype, + ), + expected=np.log1p( + np.array( + [[1e-14, 1e-15, 0.6, 2] + [x * 1e-5 for x in range(1, 20)]], + dtype=dtype, + ) + ).astype(dtype), + local_session=session, + rtol=1e-15 if dtype == np.float64 else 1e-4, + atol=1e-15 if dtype == np.float64 else 1e-4, + ) + + self.assert_op_output_matches_expected( + math_ops.rint, + np.array( + [ + [-1.7, 1.2, 4.0, 0.0], + [-3.5, -2.5, -1.5, -0.5], + [0.5, 1.5, 2.5, 3.5], + ], + dtype=dtype, + ), + expected=np.array( + [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype + ), + local_session=session, + ) + self.assert_op_output_matches_expected( + math_ops.round, + np.array( + [ + [-1.7, 1.2, 4.0, 0.0], + [-3.5, -2.5, -1.5, -0.5], + [0.5, 1.5, 2.5, 3.5], + ], + dtype=dtype, + ), + expected=np.array( + [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.rsqrt, + np.array([[4, 16]], dtype=dtype), + expected=np.array([[0.5, 0.25]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.sigmoid, + np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), + expected=np.array( + [ + [0.7310586, 0.7310586, 0.7310586, 0.7310586], + [0.7310586, 0.880797, 0.95257413, 0.98201376], + ], + dtype=dtype, + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.sigmoid, + np.array([-300, -150, 0, 150, 300], dtype=dtype), + expected=np.array([0, 0, 0.5, 1, 1], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.sinh, + np.array([1, 2, 3, 4], dtype=dtype), + expected=np.array( + [1.17520119, 3.62686041, 10.01787493, 27.2899172], dtype=dtype + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.sqrt, + np.array([[4, 9]], dtype=dtype), + 
expected=np.array([[2, 3]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.tan, + np.array([1, 2, 3, 4], dtype=dtype), + expected=np.array( + [1.55740772, -2.18503986, -0.14254654, 1.15782128], dtype=dtype + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.tanh, + np.array( + [ + [1, 2, 3, 4], + [np.inf, -np.inf, np.nan, 20], + [19, -19, 22, -22], + ], + dtype=dtype, + ), + expected=np.array( + [ + [0.76159418, 0.96402758, 0.99505478, 0.99932933], + [1.0, -1.0, np.nan, 1.0], + [1.0, -1.0, 1.0, -1.0], + ], + dtype=dtype, + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.log_softmax, + np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), + expected=np.array( + [ + [-1.3862944, -1.3862944, -1.3862944, -1.3862944], + [-3.4401896, -2.4401896, -1.4401897, -0.44018969], + ], + dtype=dtype, + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.elu, + np.array([[-1, 0, 1, -1e-6]], dtype=dtype), + expected=np.array( + [[-0.63212056, 0, 1, -9.999995e-07]], dtype=dtype + ), + rtol=1e-5, + atol=1e-6, + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.selu, + np.array([[-1, 0, 1, -1e-5]], dtype=dtype), + expected=np.array( + [[-1.11133074, 0.0, 1.05070099, -1.758090550379974e-05]], + dtype=dtype, + ), + rtol=1e-5, + atol=1e-6, + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.relu, + np.array([[-1, 1]], dtype=dtype), + expected=np.array([[0, 1]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.relu6, + np.array([[-0.05, 6.05, 5]], dtype=dtype), + expected=np.array([[0, 6, 5]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.leaky_relu, + np.array([[-2, -1, 0, 1, 2]], dtype=dtype), + expected=np.array([[-0.4, -0.2, 0.0, 1.0, 2.0]], dtype=dtype), + 
local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.softmax, + np.array([1, 2, 3, 4], dtype=dtype), + expected=np.array( + [0.032058604, 0.087144323, 0.23688284, 0.64391428], dtype=dtype + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.softmax, + np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), + expected=np.array( + [ + [0.25, 0.25, 0.25, 0.25], + [0.032058604, 0.087144323, 0.23688284, 0.64391428], + ], + dtype=dtype, + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.softmax, + np.array([[[1, 1], [1, 1]], [[1, 2], [3, 4]]], dtype=dtype), + expected=np.array( + [ + [[0.5, 0.5], [0.5, 0.5]], + [[0.26894142, 0.73105858], [0.26894142, 0.73105858]], + ], + dtype=dtype, + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.softsign, + np.array([[-2, -1, 0, 1, 2]], dtype=dtype), + expected=np.array( + [[-0.66666669, -0.5, 0, 0.5, 0.66666669]], dtype=dtype + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.sign, + np.array( + [[-2.0, -1.0, -0.0, +0.0, 1.0, 2.0, float("nan")]], dtype=dtype + ), + expected=np.array( + [[-1.0, -1.0, -0.0, +0.0, 1.0, 1.0, float("nan")]], dtype=dtype + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.is_finite, + np.array( + [[42, float("inf"), -123], [float("nan"), 0, -0.0]], dtype=dtype + ), + expected=np.array( + [[True, False, True], [False, True, True]], dtype=np.bool_ + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.lgamma, + np.array(0.5, dtype=dtype), + expected=np.array(np.log(np.pi) / 2, dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.lgamma, + np.array( + [ + [1, 2, 3], + [4, 5, 6], + [1 / 2, 3 / 2, 5 / 2], + [-3 / 2, -7 / 2, -11 / 2], + ], + dtype=dtype, + ), + expected=np.array( + [ + [0, 0, np.log(2.0)], + [np.log(6.0), 
np.log(24.0), np.log(120)], + [ + np.log(np.pi) / 2, + np.log(np.pi) / 2 - np.log(2), + np.log(np.pi) / 2 - np.log(4) + np.log(3), + ], + [ + np.log(np.pi) / 2 - np.log(3) + np.log(4), + np.log(np.pi) / 2 - np.log(105) + np.log(16), + np.log(np.pi) / 2 - np.log(10395) + np.log(64), + ], + ], + dtype=dtype, + ), + local_session=session, + ) + + # The actual result is complex. Take the real part. + self.assert_op_output_matches_expected( + math_ops.lgamma, + np.array([-1 / 2, -5 / 2, -9 / 2], dtype=dtype), + expected=np.array( + [ + np.log(np.pi) / 2 + np.log(2), + np.log(np.pi) / 2 - np.log(15) + np.log(8), + np.log(np.pi) / 2 - np.log(945) + np.log(32), + ], + dtype=dtype, + ), + local_session=session, + atol=1e-4, + ) + + self.assert_op_output_matches_expected( + math_ops.digamma, + np.array( + [ + [1.0, 0.5, 1 / 3.0], + [0.25, 1 / 6.0, 0.125], + [2.0, 3.0, 4.0], + [6.0, 8.0, 9.0], + ], + dtype=dtype, + ), + expected=np.array( + [ + [ + -np.euler_gamma, + -2 * np.log(2) - np.euler_gamma, + -np.pi / 2 / np.sqrt(3) + - 3 * np.log(3) / 2 + - np.euler_gamma, + ], + [ + -np.pi / 2 - 3 * np.log(2) - np.euler_gamma, + -np.pi * np.sqrt(3) / 2 + - 2 * np.log(2) + - 3 * np.log(3) / 2 + - np.euler_gamma, + -np.pi / 2 + - 4 * np.log(2) + - ( + np.pi + + np.log(2 + np.sqrt(2)) + - np.log(2 - np.sqrt(2)) + ) + / np.sqrt(2) + - np.euler_gamma, + ], + [ + 1 - np.euler_gamma, + 1.5 - np.euler_gamma, + 11 / 6.0 - np.euler_gamma, + ], + [ + 137 / 60.0 - np.euler_gamma, + 363 / 140.0 - np.euler_gamma, + 761 / 280.0 - np.euler_gamma, + ], + ], + dtype=dtype, + ), + local_session=session, + ) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 3061b37aaa354c..88b379331b32ef 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -110,12 +110,12 @@ namespace { int64_t tf_xla_random_seed = 0; int32_t tf_xla_test_repetitions = 20; int64_t 
tf_xla_max_tensor_size = 10000LL; -string* tf_xla_test_device_ptr; // initial value set in main() -string* tf_xla_reference_device_ptr; // initial value set in main() +std::string* tf_xla_test_device_ptr; // initial value set in main() +std::string* tf_xla_reference_device_ptr; // initial value set in main() bool tf_xla_test_use_jit = true; bool tf_xla_test_use_mlir = false; -string LocalDeviceToFullDeviceName(const string& device) { +std::string LocalDeviceToFullDeviceName(const std::string& device) { return absl::StrCat("/job:localhost/replica:0/task:0/device:", device); } @@ -129,7 +129,7 @@ constexpr std::array kAllNumberTypes = { // operator. class OpTestBuilder { public: - explicit OpTestBuilder(const string& op_name); + explicit OpTestBuilder(const std::string& op_name); // Adds an input 'tensor' as a Placeholder node. OpTestBuilder& Input(const Tensor& tensor); @@ -161,10 +161,11 @@ class OpTestBuilder { // sets it to the NodeDef of the operator under test. Fills 'inputs' and // 'outputs' with the names of the input placeholder nodes and the output // identity nodes, respectively. 
- absl::Status BuildGraph(const string& name_prefix, const string& device, - bool use_jit, GraphDef* graphdef, - NodeDef** test_node_def, std::vector* inputs, - std::vector* outputs) const; + absl::Status BuildGraph(const std::string& name_prefix, + const std::string& device, bool use_jit, + GraphDef* graphdef, NodeDef** test_node_def, + std::vector* inputs, + std::vector* outputs) const; struct InputDescription { Tensor tensor; @@ -182,7 +183,7 @@ class OpTestBuilder { std::vector inputs_; }; -OpTestBuilder::OpTestBuilder(const string& op_name) { +OpTestBuilder::OpTestBuilder(const std::string& op_name) { node_def_.set_op(op_name); } @@ -247,12 +248,10 @@ OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name, return *this; } -absl::Status OpTestBuilder::BuildGraph(const string& name_prefix, - const string& device, bool use_jit, - GraphDef* graphdef, - NodeDef** test_node_def, - std::vector* inputs, - std::vector* outputs) const { +absl::Status OpTestBuilder::BuildGraph( + const std::string& name_prefix, const std::string& device, bool use_jit, + GraphDef* graphdef, NodeDef** test_node_def, + std::vector* inputs, std::vector* outputs) const { OpRegistryInterface* op_registry = OpRegistry::Global(); const OpDef* op_def; @@ -275,7 +274,7 @@ absl::Status OpTestBuilder::BuildGraph(const string& name_prefix, // Build feed and fetch nodes. 
for (int i = 0; i < input_types.size(); ++i) { NodeDef* def = graphdef->add_node(); - string name = absl::StrCat(name_prefix, "_input_", i); + std::string name = absl::StrCat(name_prefix, "_input_", i); TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Placeholder") .Device(device) .Attr("dtype", input_types[i]) @@ -286,7 +285,7 @@ absl::Status OpTestBuilder::BuildGraph(const string& name_prefix, for (int i = 0; i < output_types.size(); ++i) { NodeDef* def = graphdef->add_node(); - string name = absl::StrCat(name_prefix, "_output_", i); + std::string name = absl::StrCat(name_prefix, "_output_", i); TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Identity") .Device(device) .Attr("T", output_types[i]) @@ -494,7 +493,7 @@ class OpTest : public ::testing::Test { const std::vector& spatial_dims); // Converts an int64 vector to an int32 vector. - std::vector AsInt32s(const std::vector& int64s); + std::vector AsInt32s(const std::vector& int64s); std::mt19937& generator() { return *generator_; } @@ -664,16 +663,16 @@ class TensorGeneratorComplex64 : public TensorGenerator { } }; -class TensorGeneratorInt32 : public TensorGenerator { +class TensorGeneratorInt32 : public TensorGenerator { public: explicit TensorGeneratorInt32(OpTest& test) : TensorGenerator(test) {} DataType dtype() override { return DT_INT32; } - void RandomVals(std::optional lo, std::optional hi, + void RandomVals(std::optional lo, std::optional hi, bool needs_unique_values, - absl::FixedArray& vals) override { - absl::flat_hash_set already_generated; - std::uniform_int_distribution distribution(lo.value_or(-(1 << 20)), - hi.value_or(1 << 20)); + absl::FixedArray& vals) override { + absl::flat_hash_set already_generated; + std::uniform_int_distribution distribution(lo.value_or(-(1 << 20)), + hi.value_or(1 << 20)); for (int64_t i = 0; i < vals.size(); ++i) { int32_t generated; do { @@ -685,13 +684,13 @@ class TensorGeneratorInt32 : public TensorGenerator { } }; -class TensorGeneratorInt64 : public TensorGenerator { +class 
TensorGeneratorInt64 : public TensorGenerator { public: explicit TensorGeneratorInt64(OpTest& test) : TensorGenerator(test) {} DataType dtype() override { return DT_INT64; } - void RandomVals(std::optional lo, std::optional hi, + void RandomVals(std::optional lo, std::optional hi, bool needs_unique_values, - absl::FixedArray& vals) override { + absl::FixedArray& vals) override { absl::flat_hash_set already_generated; std::uniform_int_distribution distribution( lo.value_or(-(1LL << 40)), hi.value_or(1LL << 40)); @@ -928,18 +927,19 @@ Tensor OpTest::RandomBoundedTensor(DataType dtype, Tensor lo, Tensor hi) { break; } case DT_INT32: { - auto lo_flat = lo.flat(); - auto hi_flat = hi.flat(); - test::FillFn(&tensor, [this, &lo_flat, &hi_flat](int i) -> int32 { - std::uniform_int_distribution distribution(lo_flat(i), - hi_flat(i)); - return distribution(generator()); - }); + auto lo_flat = lo.flat(); + auto hi_flat = hi.flat(); + test::FillFn( + &tensor, [this, &lo_flat, &hi_flat](int i) -> int32_t { + std::uniform_int_distribution distribution(lo_flat(i), + hi_flat(i)); + return distribution(generator()); + }); break; } case DT_INT64: { - auto lo_flat = lo.flat(); - auto hi_flat = hi.flat(); + auto lo_flat = lo.flat(); + auto hi_flat = hi.flat(); test::FillFn( &tensor, [this, &lo_flat, &hi_flat](int i) -> int64_t { std::uniform_int_distribution distribution(lo_flat(i), @@ -1021,21 +1021,21 @@ OpTest::BroadcastableDims() { Tensor OpTest::RandomReductionIndices(int rank) { std::bernoulli_distribution random_bool; - std::vector indices; + std::vector indices; for (int i = 0; i < rank; ++i) { if (random_bool(generator())) { indices.push_back(i); } } - return test::AsTensor(indices); + return test::AsTensor(indices); } // Helper that converts 'values' to an int32 or int64 Tensor. 
static Tensor AsIntTensor(DataType dtype, const std::vector& values) { switch (dtype) { case DT_INT32: { - std::vector values32(values.begin(), values.end()); - return test::AsTensor(values32); + std::vector values32(values.begin(), values.end()); + return test::AsTensor(values32); } case DT_INT64: return test::AsTensor(values); @@ -1092,9 +1092,9 @@ OpTest::ConcatArguments OpTest::ChooseConcatArguments(bool int64_idx_allowed) { std::vector dims = RandomDims(1, 4, 0, 64); int axis = - std::uniform_int_distribution(0, dims.size() - 1)(generator()); - a.axis = - use_int64_idx ? test::AsScalar(axis) : test::AsScalar(axis); + std::uniform_int_distribution(0, dims.size() - 1)(generator()); + a.axis = use_int64_idx ? test::AsScalar(axis) + : test::AsScalar(axis); for (int i = 0; i < a.n; ++i) { std::vector shape = dims; @@ -1113,7 +1113,7 @@ OpTest::EinsumArguments OpTest::ChooseEinsumArguments() { switch (op_kind) { case matmul: case batchmatmul: { - std::vector dims; + std::vector dims; if (op_kind == matmul) { a.equation = "ij,jk->ik"; dims = RandomDims(2, 2); @@ -1131,7 +1131,7 @@ OpTest::EinsumArguments OpTest::ChooseEinsumArguments() { } case dot: { a.equation = "i,i->"; - std::vector dims = RandomDims(1, 1); + std::vector dims = RandomDims(1, 1); a.lhs_dims = dims; a.rhs_dims = dims; break; @@ -1166,11 +1166,11 @@ OpTest::GatherArguments OpTest::ChooseGatherArguments(bool axis_0) { a.batch_dims, kDefaultMaxRank - 1); axis = axis_distribution(generator()); } - a.axis = test::AsScalar((int32)axis); + a.axis = test::AsScalar((int32_t)axis); a.params_shape = RandomDims(axis + 1, kDefaultMaxRank, 1, 16); std::vector indices_shape = RandomDims(0, 3, 0, 16); - a.indices = RandomBoundedTensor(DT_INT32, 0, a.params_shape[axis] - 1, - false, indices_shape); + a.indices = RandomBoundedTensor( + DT_INT32, 0, a.params_shape[axis] - 1, false, indices_shape); return a; } @@ -1209,7 +1209,7 @@ OpTest::ScatterArguments OpTest::ChooseScatterArguments() { a.indices_type = DT_INT32; 
a.shape = RandomDims(1, kDefaultMaxRank, 1); int rank = a.shape.size(); - std::uniform_int_distribution index_len_dist(1, rank); + std::uniform_int_distribution index_len_dist(1, rank); int index_len = index_len_dist(generator()); std::vector indices_first = RandomDims(1, kDefaultMaxRank - 1, 1); std::vector indices_shape(indices_first); @@ -1219,9 +1219,9 @@ OpTest::ScatterArguments OpTest::ChooseScatterArguments() { updates_shape.push_back(a.shape[index_len + i]); } Tensor indices_lo(a.indices_type, TensorShape(indices_shape)); - test::FillFn(&indices_lo, [](int i) -> int32 { return 0; }); + test::FillFn(&indices_lo, [](int i) -> int32_t { return 0; }); Tensor indices_hi(a.indices_type, TensorShape(indices_shape)); - test::FillFn(&indices_hi, [index_len, &a](int i) -> int32 { + test::FillFn(&indices_hi, [index_len, &a](int i) -> int32_t { int idx_dim = i % index_len; return a.shape[idx_dim] - 1; }); @@ -1239,16 +1239,16 @@ OpTest::SliceArguments OpTest::ChooseSliceArguments(bool neg_one_size) { a.shape = RandomDims(); int rank = a.shape.size(); - std::vector indices(rank); + std::vector indices(rank); a.size.resize(rank); for (int i = 0; i < rank; ++i) { indices[i] = - std::uniform_int_distribution(0, a.shape[i])(generator()); + std::uniform_int_distribution(0, a.shape[i])(generator()); int64_t low = neg_one_size ? -1 : 0; a.size[i] = std::uniform_int_distribution( low, a.shape[i] - indices[i])(generator()); } - a.indices = test::AsTensor(indices); + a.indices = test::AsTensor(indices); return a; } @@ -1341,8 +1341,8 @@ std::vector OpTest::ImageDims( return dims; } -std::vector OpTest::AsInt32s(const std::vector& int64s) { - return std::vector(int64s.begin(), int64s.end()); +std::vector OpTest::AsInt32s(const std::vector& int64s) { + return std::vector(int64s.begin(), int64s.end()); } // Functions for comparing tensors. 
@@ -1382,11 +1382,11 @@ bool IsClose(const complex64& x, const complex64& y, double atol, } template -string Str(T x) { +std::string Str(T x) { return absl::StrCat(x); } template <> -string Str(complex64 x) { +std::string Str(complex64 x) { return absl::StrCat("(", x.real(), ", ", x.imag(), ")"); } @@ -1460,7 +1460,7 @@ absl::Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol, case DT_COMPLEX64: return TensorsAreCloseImpl(a, b, atol, rtol); case DT_INT32: - return TensorsAreEqualImpl(a, b); + return TensorsAreEqualImpl(a, b); case DT_INT64: return TensorsAreEqualImpl(a, b); case DT_BOOL: @@ -1499,9 +1499,10 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( VLOG(1) << "Input: " << input_tensors.back().DebugString(); } - string reference_device = + std::string reference_device = LocalDeviceToFullDeviceName(*tf_xla_reference_device_ptr); - string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr); + std::string test_device = + LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr); DeviceNameUtils::ParsedName parsed_name; if (!DeviceNameUtils::ParseLocalName(*tf_xla_test_device_ptr, &parsed_name)) { @@ -1512,8 +1513,8 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( ++num_tests_; GraphDef graph; - std::vector expected_inputs, test_inputs; - std::vector expected_fetches, test_fetches; + std::vector expected_inputs, test_inputs; + std::vector expected_fetches, test_fetches; absl::Status status = builder.BuildGraph( absl::StrCat("test", num_tests_, "_expected"), reference_device, /*use_jit=*/false, &graph, /*test_node_def=*/nullptr, &expected_inputs, @@ -1550,8 +1551,9 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( return kFatalError; } - std::vector> expected_feeds(expected_inputs.size()); - std::vector> test_feeds(test_inputs.size()); + std::vector> expected_feeds( + expected_inputs.size()); + std::vector> test_feeds(test_inputs.size()); CHECK_EQ(input_tensors.size(), expected_inputs.size()); 
CHECK_EQ(input_tensors.size(), test_inputs.size()); @@ -1707,12 +1709,12 @@ TEST_F(OpTest, ArgMax) { auto type = Choose({DT_BOOL, DT_FLOAT}); std::vector dims = RandomDims(1, 5, 1); int num_dims = dims.size(); - int reduce_dim = - std::uniform_int_distribution(-num_dims, num_dims)(generator()); + int reduce_dim = std::uniform_int_distribution( + -num_dims, num_dims)(generator()); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("ArgMax") .RandomInput(type, dims) - .Input(test::AsScalar(reduce_dim)) + .Input(test::AsScalar(reduce_dim)) .Attr("T", type) .Attr("Tidx", DT_INT32) .Attr("output_type", DT_INT32)); @@ -1724,12 +1726,12 @@ TEST_F(OpTest, ArgMin) { auto type = Choose({DT_BOOL, DT_FLOAT}); std::vector dims = RandomDims(1, 5, 1); int num_dims = dims.size(); - int reduce_dim = - std::uniform_int_distribution(-num_dims, num_dims)(generator()); + int reduce_dim = std::uniform_int_distribution( + -num_dims, num_dims)(generator()); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("ArgMin") .RandomInput(type, dims) - .Input(test::AsScalar(reduce_dim)) + .Input(test::AsScalar(reduce_dim)) .Attr("T", type) .Attr("Tidx", DT_INT32) .Attr("output_type", DT_INT32)); @@ -1786,7 +1788,7 @@ TEST_F(OpTest, AvgPool) { std::uniform_int_distribution(1, dims[2])(generator()); int stride_rows = random_int(generator()), stride_cols = random_int(generator()); - string padding = Choose({"SAME", "VALID"}); + std::string padding = Choose({"SAME", "VALID"}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("AvgPool") .RandomInput(DT_FLOAT, dims) @@ -1817,7 +1819,7 @@ TEST_F(OpTest, AvgPool3D) { int64_t batch = dims[3]; int64_t feature = dims[4]; - string padding = Choose({"SAME", "VALID"}); + std::string padding = Choose({"SAME", "VALID"}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("AvgPool3D") .RandomInput(DT_FLOAT, @@ -1837,13 +1839,13 @@ TEST_F(OpTest, AvgPoolGrad) { Repeatedly([this]() { int batch = RandomDim(1), features = RandomDim(1); WindowedSpatialDims d = 
ChooseWindowedSpatialDims(2); - std::vector input_dims = + std::vector input_dims = AsInt32s(ImageDims(FORMAT_NHWC, batch, features, d.input_dims)); std::vector output_dims = ImageDims(FORMAT_NHWC, batch, features, d.output_dims); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("AvgPoolGrad") - .Input(test::AsTensor(input_dims)) + .Input(test::AsTensor(input_dims)) .RandomInput(DT_FLOAT, output_dims) .Attr("T", DT_FLOAT) .Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, d.kernel_dims)) @@ -1859,13 +1861,13 @@ TEST_F(OpTest, AvgPool3DGrad) { Repeatedly([this]() { int batch = RandomDim(1), features = RandomDim(1); WindowedSpatialDims d = ChooseWindowedSpatialDims(3); - std::vector input_dims = + std::vector input_dims = AsInt32s(ImageDims(FORMAT_NHWC, batch, features, d.input_dims)); std::vector output_dims = ImageDims(FORMAT_NHWC, batch, features, d.output_dims); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("AvgPool3DGrad") - .Input(test::AsTensor(input_dims)) + .Input(test::AsTensor(input_dims)) .RandomInput(DT_FLOAT, output_dims) .Attr("T", DT_FLOAT) .Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, d.kernel_dims)) @@ -1976,8 +1978,8 @@ TEST_F(OpTest, BatchToSpaceND) { return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("BatchToSpaceND") .RandomInput(type, input_dims) - .Input(test::AsTensor( - std::vector(block_dims.begin(), block_dims.end()))) + .Input(test::AsTensor( + std::vector(block_dims.begin(), block_dims.end()))) .Input(crops) .Attr("T", type)); }); @@ -2198,15 +2200,15 @@ TEST_F(OpTest, ConcatOffset) { std::vector dims = RandomDims(1); int concat_dim = - std::uniform_int_distribution(0, dims.size() - 1)(generator()); + std::uniform_int_distribution(0, dims.size() - 1)(generator()); OpTestBuilder builder("ConcatOffset"); - builder.Input(test::AsScalar(concat_dim)); + builder.Input(test::AsScalar(concat_dim)); builder.Attr("N", n); for (int i = 0; i < n; ++i) { - std::vector shape(dims.begin(), dims.end()); + std::vector shape(dims.begin(), dims.end()); 
shape[concat_dim] = RandomDim(); - builder.Input(test::AsTensor(shape)); + builder.Input(test::AsTensor(shape)); } return ExpectTfAndXlaOutputsAreClose(builder); }); @@ -2280,7 +2282,8 @@ TEST_F(OpTest, IFFT3D) { TEST_F(OpTest, RFFT) { Repeatedly([this]() { std::vector dims = RandomDims(1, kDefaultMaxRank, 3); - Tensor fft_shape = test::AsTensor(AsInt32s({dims[dims.size() - 1]})); + Tensor fft_shape = + test::AsTensor(AsInt32s({dims[dims.size() - 1]})); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("RFFT").RandomInput(DT_FLOAT, dims).Input(fft_shape)); }); @@ -2289,7 +2292,7 @@ TEST_F(OpTest, RFFT) { TEST_F(OpTest, RFFT2D) { Repeatedly([this]() { std::vector dims = RandomDims(2, kDefaultMaxRank, 3); - Tensor fft_shape = test::AsTensor( + Tensor fft_shape = test::AsTensor( AsInt32s({dims[dims.size() - 2], dims[dims.size() - 1]})); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("RFFT2D").RandomInput(DT_FLOAT, dims).Input(fft_shape)); @@ -2299,7 +2302,7 @@ TEST_F(OpTest, RFFT2D) { TEST_F(OpTest, RFFT3D) { Repeatedly([this]() { std::vector dims = RandomDims(3, kDefaultMaxRank, 3); - Tensor fft_shape = test::AsTensor(AsInt32s( + Tensor fft_shape = test::AsTensor(AsInt32s( {dims[dims.size() - 3], dims[dims.size() - 2], dims[dims.size() - 1]})); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("RFFT3D").RandomInput(DT_FLOAT, dims).Input(fft_shape)); @@ -2311,7 +2314,7 @@ TEST_F(OpTest, IRFFT) { std::vector dims = RandomDims(1, kDefaultMaxRank, 3); int64_t orig_size = dims[dims.size() - 1]; dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1; - Tensor fft_shape = test::AsTensor(AsInt32s({orig_size})); + Tensor fft_shape = test::AsTensor(AsInt32s({orig_size})); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT") .RandomInput(DT_COMPLEX64, dims) .Input(fft_shape)); @@ -2324,7 +2327,7 @@ TEST_F(OpTest, IRFFT2D) { std::vector orig_size = {dims[dims.size() - 2], dims[dims.size() - 1]}; dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1; - Tensor 
fft_shape = test::AsTensor(AsInt32s({orig_size})); + Tensor fft_shape = test::AsTensor(AsInt32s({orig_size})); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT2D") .RandomInput(DT_COMPLEX64, dims) .Input(fft_shape)); @@ -2337,7 +2340,7 @@ TEST_F(OpTest, IRFFT3D) { std::vector orig_size = { dims[dims.size() - 3], dims[dims.size() - 2], dims[dims.size() - 1]}; dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1; - Tensor fft_shape = test::AsTensor(AsInt32s({orig_size})); + Tensor fft_shape = test::AsTensor(AsInt32s({orig_size})); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT3D") .RandomInput(DT_COMPLEX64, dims) .Input(fft_shape)); @@ -2383,7 +2386,7 @@ TEST_F(OpTest, Conv2DBackpropFilter) { ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims); std::vector backprop = ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); - Tensor kernel_shape = test::AsTensor(AsInt32s( + Tensor kernel_shape = test::AsTensor(AsInt32s( {d.kernel_dims[0], d.kernel_dims[1], features_in, features_out})); DataType type = DT_FLOAT; return ExpectTfAndXlaOutputsAreClose( @@ -2405,7 +2408,7 @@ TEST_F(OpTest, Conv2DBackpropInput) { int features_in = random_int(generator()); int features_out = random_int(generator()); int32_t batch = RandomDim(); - Tensor in_shape = test::AsTensor( + Tensor in_shape = test::AsTensor( AsInt32s(ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims))); std::vector backprop = ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); @@ -2461,7 +2464,7 @@ TEST_F(OpTest, Conv3DBackpropFilter) { ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims); std::vector backprop = ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); - Tensor kernel_shape = test::AsTensor( + Tensor kernel_shape = test::AsTensor( AsInt32s({d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2], features_in, features_out})); DataType type = DT_FLOAT; @@ -2485,7 +2488,7 @@ TEST_F(OpTest, Conv3DBackpropInput) { int features_in = random_int(generator()); int 
features_out = random_int(generator()); int32_t batch = RandomDim(1); - Tensor in_shape = test::AsTensor( + Tensor in_shape = test::AsTensor( AsInt32s(ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims))); std::vector backprop = ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); @@ -2583,7 +2586,7 @@ TEST_F(OpTest, DepthwiseConv2DNativeBackpropFilter) { ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims); std::vector backprop = ImageDims( FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims); - Tensor kernel_shape = test::AsTensor(AsInt32s( + Tensor kernel_shape = test::AsTensor(AsInt32s( {d.kernel_dims[0], d.kernel_dims[1], features_in, depth_multiplier})); std::vector strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims); strides[2] = strides[1]; // Current impl only supports equal strides @@ -2608,7 +2611,7 @@ TEST_F(OpTest, DepthwiseConv2DBackpropInput) { int features_in = random_int(generator()); int depth_multiplier = random_int(generator()); int32_t batch = RandomDim(); - Tensor in_shape = test::AsTensor( + Tensor in_shape = test::AsTensor( AsInt32s(ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims))); std::vector backprop = ImageDims( FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims); @@ -2713,15 +2716,15 @@ TEST_F(OpTest, DynamicStitch) { // implementation does so require. However, the native TF implementation // leaves undefined values if we don't cover everything, so we can't // really test that case anyway. 
- std::vector indices(size); + std::vector indices(size); std::iota(indices.begin(), indices.end(), 0); std::shuffle(indices.begin(), indices.end(), generator()); int pos = 0; for (int i = 0; i < n; ++i) { TensorShape shape(index_dims[i]); - Tensor t = test::AsTensor( - absl::Span(indices).subspan(pos, shape.num_elements()), + Tensor t = test::AsTensor( + absl::Span(indices).subspan(pos, shape.num_elements()), shape); builder.Input(t); pos += t.NumElements(); @@ -2781,8 +2784,8 @@ TEST_F(OpTest, EluGrad) { TEST_F(OpTest, ScatterNd) { Repeatedly([this]() { auto a = ChooseScatterArguments(); - auto shape = test::AsTensor( - std::vector(a.shape.begin(), a.shape.end())); + auto shape = test::AsTensor( + std::vector(a.shape.begin(), a.shape.end())); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ScatterNd") .Input(a.indices) .Input(a.updates) @@ -2855,8 +2858,9 @@ TEST_F(OpTest, ExpandDims) { auto type = Choose(kAllXlaTypes); std::vector in_dims = RandomDims(); Tensor dim(DT_INT32, TensorShape()); - std::uniform_int_distribution d(-1 - in_dims.size(), in_dims.size()); - dim.scalar()() = d(generator()); + std::uniform_int_distribution d(-1 - in_dims.size(), + in_dims.size()); + dim.scalar()() = d(generator()); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ExpandDims") .RandomInput(type, in_dims) .Input(dim) @@ -2868,10 +2872,10 @@ TEST_F(OpTest, Fill) { Repeatedly([this]() { auto type = Choose(kAllXlaTypes); std::vector dims = RandomDims(); - std::vector shape(dims.begin(), dims.end()); + std::vector shape(dims.begin(), dims.end()); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Fill") - .Input(test::AsTensor(shape)) + .Input(test::AsTensor(shape)) .RandomInput(type, {}) .Attr("T", type)); }); @@ -2949,9 +2953,9 @@ TEST_F(OpTest, GatherNd) { std::vector output_shape(output_outer_shape); output_shape.push_back(index_len); Tensor lo(indices_type, TensorShape(output_shape)); - test::FillFn(&lo, [](int i) -> int32 { return 0; }); + test::FillFn(&lo, [](int 
i) -> int32_t { return 0; }); Tensor hi(indices_type, TensorShape(output_shape)); - test::FillFn(&hi, [index_len, ¶ms_shape](int i) -> int32 { + test::FillFn(&hi, [index_len, ¶ms_shape](int i) -> int32_t { int idx_dim = i % index_len; return params_shape[idx_dim] - 1; }); @@ -3016,7 +3020,7 @@ TEST_F(OpTest, InplaceUpdate) { x_dims.insert(x_dims.end(), common_dims.begin(), common_dims.end()); std::vector i_shape{v_dims[0]}; Tensor i = - RandomBoundedTensor(DT_INT32, 0, x_dims[0] - 1, true, i_shape); + RandomBoundedTensor(DT_INT32, 0, x_dims[0] - 1, true, i_shape); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("InplaceUpdate") .RandomInput(type, x_dims) .Input(i) @@ -3046,7 +3050,7 @@ TEST_F(OpTest, InvertPermutation) { // TODO(b/211012712): Once needs_unique_values case is linear instead of // quadratic time, use default Dim max instead of 8. int64_t len = RandomDim(0, 8); - Tensor x = RandomBoundedTensor(DT_INT32, 0, len - 1, true, {len}); + Tensor x = RandomBoundedTensor(DT_INT32, 0, len - 1, true, {len}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("InvertPermutation").Input(x).Attr("T", DT_INT32)); }); @@ -3151,7 +3155,7 @@ TEST_F(OpTest, Lgamma) { TEST_F(OpTest, LinSpace) { Repeatedly([this]() { auto ToScalar = [](DataType type, int x) { - if (type == DT_INT32) return test::AsScalar(x); + if (type == DT_INT32) return test::AsScalar(x); return test::AsScalar(x); }; std::uniform_int_distribution distribution(-50, 50); @@ -3290,11 +3294,11 @@ TEST_F(OpTest, MatrixBandPart) { auto type = Choose(kAllXlaTypes); auto index_type = Choose({DT_INT32, DT_INT64}); auto num_lower = - RandomBoundedTensor(index_type, -2 * kDefaultMaxDimensionSize, - 2 * kDefaultMaxDimensionSize, false, {}); + RandomBoundedTensor(index_type, -2 * kDefaultMaxDimensionSize, + 2 * kDefaultMaxDimensionSize, false, {}); auto num_upper = - RandomBoundedTensor(index_type, -2 * kDefaultMaxDimensionSize, - 2 * kDefaultMaxDimensionSize, false, {}); + RandomBoundedTensor(index_type, -2 * 
kDefaultMaxDimensionSize, + 2 * kDefaultMaxDimensionSize, false, {}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixBandPart") .RandomInput(type) .Input(num_lower) @@ -3330,12 +3334,12 @@ TEST_F(OpTest, MatrixDiagPartV3) { auto type = Choose(kAllXlaTypes); auto align = Choose( {"LEFT_RIGHT", "RIGHT_LEFT", "LEFT_LEFT", "RIGHT_RIGHT"}); - auto k0 = std::uniform_int_distribution( + auto k0 = std::uniform_int_distribution( -2 * kDefaultMaxDimensionSize, 2 * kDefaultMaxDimensionSize)(generator()); - auto k1 = std::uniform_int_distribution( + auto k1 = std::uniform_int_distribution( k0, 2 * kDefaultMaxDimensionSize)(generator()); - auto k = test::AsTensor({k0, k1}); + auto k = test::AsTensor({k0, k1}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiagPartV3") .RandomInput(type) .Input(k) @@ -3369,10 +3373,10 @@ TEST_F(OpTest, MatrixSetDiagV2) { int64_t max_num_diags = shape[rank - 2] + shape[rank - 1] - 1; int64_t num_diags = std::uniform_int_distribution(2, max_num_diags)(generator()); - int32 k0 = std::uniform_int_distribution( + int32_t k0 = std::uniform_int_distribution( -shape[rank - 2] + 1, shape[rank - 1] - num_diags)(generator()); - int32 k1 = k0 + num_diags - 1; - Tensor k = test::AsTensor({k0, k1}); + int32_t k1 = k0 + num_diags - 1; + Tensor k = test::AsTensor({k0, k1}); int64_t max_diag_len = std::min(shape[rank - 2] + std::min(k1, 0), shape[rank - 1] + std::min(-k0, 0)); std::vector diagonal_shape(shape); @@ -3424,7 +3428,7 @@ TEST_F(OpTest, MaxPool) { int stride_rows = random_int(generator()), stride_cols = random_int(generator()); - string padding = Choose({"SAME", "VALID"}); + std::string padding = Choose({"SAME", "VALID"}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("MaxPool") .RandomInput(DT_FLOAT, dims) @@ -3458,7 +3462,7 @@ TEST_F(OpTest, MaxPool3D) { int64_t batch = dims[3]; int64_t feature = dims[4]; - string padding = Choose({"SAME", "VALID"}); + std::string padding = Choose({"SAME", "VALID"}); return 
ExpectTfAndXlaOutputsAreClose( OpTestBuilder("MaxPool3D") .RandomInput(DT_FLOAT, @@ -3585,20 +3589,20 @@ TEST_F(OpTest, OneHot) { int32_t depth = RandomDim(); Tensor indices(DT_INT32, TensorShape(dims)); - std::uniform_int_distribution distribution(-depth * 2, depth * 2); - test::FillFn(&indices, [this, &distribution](int i) -> int32 { + std::uniform_int_distribution distribution(-depth * 2, depth * 2); + test::FillFn(&indices, [this, &distribution](int i) -> int32_t { return distribution(generator()); }); - int axis = std::uniform_int_distribution(-num_dims - 5, - num_dims + 5)(generator()); + int axis = std::uniform_int_distribution( + -num_dims - 5, num_dims + 5)(generator()); OpTestBuilder builder("OneHot"); builder.Attr("T", type); builder.Attr("TI", DT_INT32); builder.Attr("axis", axis); builder.Input(indices); - builder.Input(test::AsScalar(depth)); + builder.Input(test::AsScalar(depth)); builder.RandomInput(type, {}); builder.RandomInput(type, {}); return ExpectTfAndXlaOutputsAreClose(builder); @@ -3621,8 +3625,8 @@ TEST_F(OpTest, Pack) { std::vector dims = RandomDims(); int num_dims = dims.size(); - int axis = std::uniform_int_distribution(-num_dims - 1, - num_dims)(generator()); + int axis = std::uniform_int_distribution(-num_dims - 1, + num_dims)(generator()); OpTestBuilder builder("Pack"); builder.Attr("T", type); @@ -3764,7 +3768,7 @@ TEST_F(OpTest, RandomUniform) { TEST_F(OpTest, Range) { Repeatedly([this]() { auto ToScalar = [](DataType type, int x) { - if (type == DT_INT32) return test::AsScalar(x); + if (type == DT_INT32) return test::AsScalar(x); if (type == DT_INT64) return test::AsScalar(x); if (type == DT_FLOAT) return test::AsScalar(x); if (type == DT_DOUBLE) return test::AsScalar(x); @@ -3881,8 +3885,8 @@ TEST_F(OpTest, Reshape) { return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Reshape") .RandomInput(type, dims_before) - .Input(test::AsTensor( - std::vector(dims_after.begin(), dims_after.end()))) + .Input(test::AsTensor( + 
std::vector(dims_after.begin(), dims_after.end()))) .Attr("T", type)); }); } @@ -3908,8 +3912,8 @@ TEST_F(OpTest, ResizeBilinear) { return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("ResizeBilinear") .RandomInput(DT_FLOAT, in_dims) - .Input(test::AsTensor( - std::vector(out_dims.begin(), out_dims.end()))) + .Input(test::AsTensor( + std::vector(out_dims.begin(), out_dims.end()))) .Attr("T", DT_FLOAT) .Attr("align_corners", true)); }); @@ -3961,14 +3965,14 @@ TEST_F(OpTest, ReverseSequence) { int batch_size = dims[batch_dim]; int max_seq_len = dims[seq_dim]; - std::vector seq_lens(batch_size); - std::uniform_int_distribution d(0, max_seq_len); + std::vector seq_lens(batch_size); + std::uniform_int_distribution d(0, max_seq_len); absl::c_generate(seq_lens, [&]() { return d(generator()); }); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("ReverseSequence") .RandomInput(type, dims) - .Input(test::AsTensor(seq_lens)) + .Input(test::AsTensor(seq_lens)) .Attr("seq_dim", seq_dim) .Attr("batch_dim", batch_dim) .Attr("T", type) @@ -4157,14 +4161,15 @@ TEST_F(OpTest, Size) { TEST_F(OpTest, Slice) { Repeatedly([this]() { SliceArguments a = ChooseSliceArguments(true); - std::vector size; + std::vector size; size.insert(size.end(), a.size.begin(), a.size.end()); - return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Slice") - .RandomInput(a.type, a.shape) - .Input(a.indices) - .Input(test::AsTensor(size)) - .Attr("T", a.type) - .Attr("Index", a.indices_type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Slice") + .RandomInput(a.type, a.shape) + .Input(a.indices) + .Input(test::AsTensor(size)) + .Attr("T", a.type) + .Attr("Index", a.indices_type)); }); } @@ -4298,8 +4303,8 @@ TEST_F(OpTest, SpaceToBatchND) { return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("SpaceToBatchND") .RandomInput(type, input_dims) - .Input(test::AsTensor( - std::vector(block_dims.begin(), block_dims.end()))) + .Input(test::AsTensor( + std::vector(block_dims.begin(), block_dims.end()))) 
.Input(paddings) .Attr("T", type)); }); @@ -4356,16 +4361,16 @@ TEST_F(OpTest, SparseSoftmaxCrossEntropyWithLogits) { int64_t batch_size = dims[0]; int64_t num_classes = dims[1]; - std::vector indices(batch_size); + std::vector indices(batch_size); for (int64_t i = 0; i < batch_size; ++i) { - indices[i] = - std::uniform_int_distribution(0, num_classes - 1)(generator()); + indices[i] = std::uniform_int_distribution( + 0, num_classes - 1)(generator()); } return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("SparseSoftmaxCrossEntropyWithLogits") .RandomInput(DT_FLOAT, dims) - .Input(test::AsTensor(indices)) + .Input(test::AsTensor(indices)) .Attr("T", DT_FLOAT) .Attr("Tlabels", DT_INT32)); }); @@ -4379,18 +4384,19 @@ TEST_F(OpTest, Split) { auto type = Choose(kAllXlaTypes); std::vector dims = RandomDims(1); std::uniform_int_distribution ud; - int32_t dim = std::uniform_int_distribution( - -static_cast(dims.size()), - static_cast(dims.size()) - 1)(generator()); + int32_t dim = std::uniform_int_distribution( + -static_cast(dims.size()), + static_cast(dims.size()) - 1)(generator()); int n = std::uniform_int_distribution(1, 5)(generator()); // Ensure 'dim' is evenly divisible by 'n'. 
dims[dim] /= n; dims[dim] *= n; - return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Split") - .Input(test::AsScalar(dim)) - .RandomInput(type, dims) - .Attr("T", type) - .Attr("num_split", n)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Split") + .Input(test::AsScalar(dim)) + .RandomInput(type, dims) + .Attr("T", type) + .Attr("num_split", n)); }); } @@ -4401,12 +4407,12 @@ TEST_F(OpTest, SplitV) { Repeatedly([this]() { // NOLINT: due to GTEST_SKIP auto type = Choose(kAllXlaTypes); std::vector dims = RandomDims(1, kDefaultMaxRank, 1); - int32_t dim = std::uniform_int_distribution( - -static_cast(dims.size()), - static_cast(dims.size()) - 1)(generator()); + int32_t dim = std::uniform_int_distribution( + -static_cast(dims.size()), + static_cast(dims.size()) - 1)(generator()); int n = std::uniform_int_distribution( 1, std::min(5, static_cast(dims[dim])))(generator()); - std::vector size_splits(n); + std::vector size_splits(n); for (int i = 0; i < n - 1; ++i) { size_splits.push_back(dims[dim] / n); } @@ -4414,8 +4420,8 @@ TEST_F(OpTest, SplitV) { return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("SplitV") .RandomInput(type, dims) - .Input(test::AsTensor(size_splits)) - .Input(test::AsScalar(dim)) + .Input(test::AsTensor(size_splits)) + .Input(test::AsScalar(dim)) .Attr("T", type) .Attr("num_split", n) .Attr("Tlen", DT_INT32)); @@ -4515,12 +4521,12 @@ TEST_F(OpTest, StridedSlice) { Repeatedly([this]() { auto type = Choose(kAllXlaTypes); std::vector data_dims = RandomDims(); - std::vector begin(data_dims.size()), end(data_dims.size()); - std::vector strides(data_dims.size()); + std::vector begin(data_dims.size()), end(data_dims.size()); + std::vector strides(data_dims.size()); for (int i = 0; i < data_dims.size(); ++i) { - begin[i] = std::uniform_int_distribution( + begin[i] = std::uniform_int_distribution( -2 * data_dims[i], 2 * data_dims[i])(generator()); - end[i] = std::uniform_int_distribution( + end[i] = std::uniform_int_distribution( -2 * 
data_dims[i], 2 * data_dims[i])(generator()); // TODO(b/31360685): support strides other than 1 or -1 strides[i] = std::bernoulli_distribution()(generator()) ? 1 : -1; @@ -4543,9 +4549,9 @@ TEST_F(OpTest, StridedSlice) { return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("StridedSlice") .RandomInput(type, data_dims) - .Input(test::AsTensor(begin)) - .Input(test::AsTensor(end)) - .Input(test::AsTensor(strides)) + .Input(test::AsTensor(begin)) + .Input(test::AsTensor(end)) + .Input(test::AsTensor(strides)) .Attr("T", type) .Attr("Index", DT_INT32) .Attr("begin_mask", begin_mask) @@ -4656,14 +4662,14 @@ TEST_F(OpTest, Tile) { Repeatedly([this]() { auto type = Choose(kAllXlaTypes); std::vector t_dims = RandomDims(1); - std::vector multiples(t_dims.size()); + std::vector multiples(t_dims.size()); for (int i = 0; i < t_dims.size(); ++i) { multiples[i] = std::uniform_int_distribution(1, 3)(generator()); } return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Tile") .RandomInput(type, t_dims) - .Input(test::AsTensor(multiples)) + .Input(test::AsTensor(multiples)) .Attr("T", type)); }); } @@ -4674,10 +4680,11 @@ TEST_F(OpTest, TopKV2) { Repeatedly([this]() { // NOLINT: due to GTEST_SKIP auto type = Choose({DT_INT32, DT_FLOAT, DT_INT64}); auto shape = RandomDims(1); - int32 k = std::uniform_int_distribution(1, shape[0])(generator()); + int32_t k = + std::uniform_int_distribution(1, shape[0])(generator()); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TopKV2") .RandomInput(type, shape) - .Input(test::AsScalar(k)) + .Input(test::AsScalar(k)) .Attr("sorted", RandomBool()) .Attr("T", type)); }); @@ -4687,13 +4694,14 @@ TEST_F(OpTest, Transpose) { Repeatedly([this]() { auto type = Choose(kAllXlaTypes); std::vector data_dims = RandomDims(); - std::vector perm(data_dims.size()); + std::vector perm(data_dims.size()); std::iota(perm.begin(), perm.end(), 0); std::shuffle(perm.begin(), perm.end(), generator()); - return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Transpose") - 
.RandomInput(type, data_dims) - .Input(test::AsTensor(perm)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Transpose") + .RandomInput(type, data_dims) + .Input(test::AsTensor(perm)) + .Attr("T", type)); }); } @@ -4883,8 +4891,8 @@ TEST_F(OpTest, FusedBatchNormTraining) { } // namespace tensorflow int main(int argc, char** argv) { - tensorflow::tf_xla_test_device_ptr = new tensorflow::string("GPU:0"); - tensorflow::tf_xla_reference_device_ptr = new tensorflow::string("CPU:0"); + tensorflow::tf_xla_test_device_ptr = new std::string("GPU:0"); + tensorflow::tf_xla_reference_device_ptr = new std::string("CPU:0"); std::vector flag_list = { tensorflow::Flag( "tf_xla_random_seed", &tensorflow::tf_xla_random_seed, @@ -4909,7 +4917,7 @@ int main(int argc, char** argv) { "tf_xla_test_use_mlir", &tensorflow::tf_xla_test_use_mlir, "Use MLIR legalization kernels for the operator under test"), }; - tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + std::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { LOG(ERROR) << "\n" << usage; diff --git a/tensorflow/compiler/tests/unary_ops_composition_test.cc b/tensorflow/compiler/tests/unary_ops_composition_test.cc index 641af606bb24d1..c27b8070bbb450 100644 --- a/tensorflow/compiler/tests/unary_ops_composition_test.cc +++ b/tensorflow/compiler/tests/unary_ops_composition_test.cc @@ -48,9 +48,9 @@ static bool Initialized = [] { class UnaryOpsCompositionTest : public OpsTestBase { protected: template - void RunComposedOp(const std::vector op_names, T input_scalar_value, - T expected_scalar_value) { - string xla_device_name = + void RunComposedOp(const std::vector op_names, + T input_scalar_value, T expected_scalar_value) { + std::string xla_device_name = tensorflow::IsGoogleCudaEnabled() ? 
DEVICE_XLA_GPU : DEVICE_XLA_CPU; SetDevice(DeviceType(xla_device_name), std::unique_ptr(DeviceFactory::NewDevice( diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index 864b64c349e798..7f84349bffd15b 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -31,9 +31,9 @@ def testCopies(self): """Tests that copies onto and off XLA devices work.""" shapes = [[0], [1], [1, 0], [1024, 0], [1024, 1], [3, 777], [777, 3], [16384, 1], [1, 16384], [1, 20000, 1, 1]] - for dtype in self.numeric_types: - for shape in shapes: - with self.session() as sess: + with self.session() as sess: + for dtype in self.numeric_types: + for shape in shapes: with ops.device("CPU"): x = array_ops.placeholder(dtype, shape) with self.test_scope(): @@ -53,8 +53,8 @@ def testCopiesOfUnsupportedTypesFailGracefully(self): dtypes.bfloat16.as_numpy_dtype ]) shape = (10, 10) - for unsupported_dtype in test_types - self.all_types: - with self.session() as sess: + with self.session() as sess: + for unsupported_dtype in test_types - self.all_types: with ops.device("CPU"): x = array_ops.placeholder(unsupported_dtype, shape) with self.test_scope(): diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index 20f93d86adfad1..d642418a44c2f5 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -308,7 +308,8 @@ def device_scope(self): yield def assert_op_output_matches_expected( - self, op, inp, expected, equality_test=None, rtol=1e-3, atol=1e-5 + self, op, inp, expected, local_session, + equality_test=None, rtol=1e-3, atol=1e-5 ): """Verifies that 'op' produces 'expected' when fed input 'inp' . @@ -316,25 +317,25 @@ def assert_op_output_matches_expected( op: operator to test inp: numpy input array to use as input to 'op'. expected: numpy array representing the expected output of 'op'. 
+ local_session: The session to use for the test. equality_test: either None, or a function that tests two numpy arrays for equality. If None, self.assertAllClose is used. rtol: relative tolerance for equality test. atol: absolute tolerance for equality test. """ - with self.session() as local_session: - with self.test_scope(): - pinp = array_ops.placeholder( - dtypes.as_dtype(inp.dtype), inp.shape, name='a' - ) - output = op(pinp) - result = local_session.run(output, {pinp: inp}) - if equality_test is None: - self.assertEqual(output.dtype, expected.dtype) - self.assertAllCloseAccordingToType( - expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03 - ) - else: - equality_test(result, expected, rtol=rtol, atol=atol) + with self.test_scope(): + pinp = array_ops.placeholder( + dtypes.as_dtype(inp.dtype), inp.shape, name='a' + ) + output = op(pinp) + result = local_session.run(output, {pinp: inp}) + if equality_test is None: + self.assertEqual(output.dtype, expected.dtype) + self.assertAllCloseAccordingToType( + expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03 + ) + else: + equality_test(result, expected, rtol=rtol, atol=atol) def test_scope(self): """Deprecated alias of `device_scope`. diff --git a/tensorflow/compiler/tf2tensorrt/common/datavec.h b/tensorflow/compiler/tf2tensorrt/common/datavec.h index eff32f1f521af4..34b419d1d20d62 100644 --- a/tensorflow/compiler/tf2tensorrt/common/datavec.h +++ b/tensorflow/compiler/tf2tensorrt/common/datavec.h @@ -27,7 +27,7 @@ namespace tensorrt { // Input/output data format for OpConverterTest::BuildAndRun(). 
struct InputOutputData { size_t TotalBytes() const { return tensor.TotalBytes(); } - string name; + std::string name; Tensor tensor; }; diff --git a/tensorflow/compiler/tf2tensorrt/convert/ops/einsum.cc b/tensorflow/compiler/tf2tensorrt/convert/ops/einsum.cc index c8eb3db2e0b9e4..b4c3052953c677 100755 --- a/tensorflow/compiler/tf2tensorrt/convert/ops/einsum.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/ops/einsum.cc @@ -739,16 +739,16 @@ class ReIndexer { // Initializes the index map with existing lowercase labels. ReIndexer(std::string eq) { for (char c : eq) { - if (islower(c)) { + if (absl::ascii_islower(c)) { idx_map_[c] = c; } } } // Finds new character for uppercase character c. char operator()(char c) { - if (!std::isupper(c)) return c; + if (!absl::ascii_isupper(c)) return c; if (idx_map_.count(c) > 0) return idx_map_[c]; - char new_idx = std::tolower(c); + char new_idx = absl::ascii_tolower(c); // If lower(c) is not used in the equation, use it to replace c. if (idx_map_.count(new_idx) == 0) { diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_parameters.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_parameters.cc index faedcf3de8c427..000c32df25d253 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/trt_parameters.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_parameters.cc @@ -81,9 +81,7 @@ string ProfileStrategyToName(const ProfileStrategy strategy) { } Status ProfileStrategyFromName(const string& name, ProfileStrategy* strategy) { - string name_lowercase(name); - std::transform(name.begin(), name.end(), name_lowercase.begin(), - [](unsigned char c) { return std::tolower(c); }); + std::string name_lowercase = absl::AsciiStrToLower(name); if (name_lowercase == "range") { *strategy = ProfileStrategy::kRange; } else if (name_lowercase == "optimal") { diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc index 30aff91a76d3b1..d1bf00a53d1cc3 100644 --- 
a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc @@ -99,7 +99,7 @@ string TRTEngineCacheResource::DebugString() const { EngineContext* TRTEngineCacheResource::GetEngineContext( const std::vector& input_shapes) { EngineContext* engine_context = nullptr; - int64 min_matched_batch_size = kint64max; + int64 min_matched_batch_size = std::numeric_limits::max(); for (const auto& pair : cache_) { const std::vector& cached_input_shapes = pair.first; // This should not happen, but just for safety. diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 416f1a37179736..e5545445817ec2 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -138,6 +138,25 @@ cc_library( ], ) +cc_library( + name = "encoded_buffer_allocation_info", + hdrs = ["encoded_buffer_allocation_info.h"], + visibility = [":friends"], + deps = [ + "@local_xla//xla/backends/cpu:buffer_allocation_info", + ], +) + +tf_cc_test( + name = "encoded_buffer_allocation_info_test", + srcs = ["encoded_buffer_allocation_info_test.cc"], + deps = [ + ":encoded_buffer_allocation_info", + "@com_google_googletest//:gtest_main", + "@local_xla//xla/backends/cpu:buffer_allocation_info", + ], +) + cc_library( name = "tf2xla", srcs = ["tf2xla.cc"], @@ -218,6 +237,7 @@ filegroup( name = "xla_compiled_cpu_runtime_hdrs", srcs = [ "allocator.h", + "encoded_buffer_allocation_info.h", "xla_compiled_cpu_function.h", "//tensorflow/core/kernels:xla_cpu_runtime_hdrs", "//tensorflow/core/platform:xla_cpu_runtime_srcs", @@ -355,6 +375,7 @@ cc_library( # "@local_tsl//tsl/platform:context", # "@local_tsl//tsl/platform:cord", # "@local_tsl//tsl/platform:env_time", +# "@local_tsl//tsl/platform:refcount", # "@local_tsl//tsl/platform:ml_dtypes", # "@local_tsl//tsl/platform:logging", # "@local_tsl//tsl/platform:macros", @@ -406,8 +427,22 @@ cc_library( visibility = ["//visibility:public"], deps = [ 
"@com_google_absl//absl/base:dynamic_annotations", - "@local_xla//xla:cpu_function_runtime", + "@com_google_absl//absl/types:span", + "@local_xla//xla/backends/cpu:alignment", + "@local_xla//xla/backends/cpu:buffer_allocation_info", + ], +) + +tf_cc_test( + name = "allocator_test", + srcs = ["allocator_test.cc"], + deps = [ + ":allocator", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", "@local_xla//xla/backends/cpu:alignment", + "@local_xla//xla/backends/cpu:buffer_allocation_info", ], ) @@ -418,14 +453,16 @@ cc_library( compatible_with = get_compatible_with_portable(), visibility = ["//visibility:public"], deps = [ + # Keep dependencies to a minimum here; this library is used in every AOT + # binary produced by tfcompile. ":allocator", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", + ":encoded_buffer_allocation_info", "@local_xla//xla/service:custom_call_status_internal", - # Keep dependencies to a minimum here; this library is used in every AOT - # binary produced by tfcompile. 
"@local_xla//xla/backends/cpu/runtime:rng_state_lib", "@local_xla//xla/backends/cpu:alignment", - "@local_xla//xla:cpu_function_runtime", + "@local_xla//xla/backends/cpu:buffer_allocation_info", "@local_xla//xla:executable_run_options", "//tensorflow/core/platform:types", "@com_google_absl//absl/container:flat_hash_map", @@ -481,25 +518,13 @@ cc_library( alwayslink = 1, ) -tf_cc_test( - name = "cpu_function_runtime_test", - srcs = ["cpu_function_runtime_test.cc"], - deps = [ - ":allocator", - "//tensorflow/core:framework", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@local_xla//xla:cpu_function_runtime", - "@local_xla//xla/backends/cpu:alignment", - ], -) - cc_library( name = "xla_jit_compiled_cpu_function", srcs = ["xla_jit_compiled_cpu_function.cc"], hdrs = ["xla_jit_compiled_cpu_function.h"], visibility = ["//visibility:public"], deps = [ + ":encoded_buffer_allocation_info", ":tf2xla", ":tf2xla_proto_cc", ":xla_compiled_cpu_function", @@ -513,9 +538,10 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:casts", - "@local_xla//xla:cpu_function_runtime", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/backends/cpu:buffer_allocation_info", + "@local_xla//xla/backends/cpu:buffer_allocation_info_util", "@local_xla//xla/backends/cpu/codegen:compiled_function_library", "@local_xla//xla/client:client_library", "@local_xla//xla/client:executable_build_options", @@ -528,7 +554,6 @@ cc_library( ] + if_libtpu( if_false = [ "@local_xla//xla/service:cpu_plugin", - "@local_xla//xla/service/cpu:buffer_info_util", "@local_xla//xla/service/cpu:cpu_executable", ], if_true = [], diff --git a/tensorflow/compiler/tf2xla/allocator.cc b/tensorflow/compiler/tf2xla/allocator.cc index 7f7c3a351bbe87..08db8bb0261bc6 100644 --- a/tensorflow/compiler/tf2xla/allocator.cc +++ b/tensorflow/compiler/tf2xla/allocator.cc @@ -20,8 +20,9 @@ limitations under the License. 
#include #include "absl/base/dynamic_annotations.h" +#include "absl/types/span.h" #include "xla/backends/cpu/alignment.h" -#include "xla/cpu_function_runtime.h" +#include "xla/backends/cpu/buffer_allocation_info.h" namespace tensorflow { @@ -64,26 +65,26 @@ size_t align_to(size_t n, size_t align) { } // namespace size_t AlignedBufferBytes( - const xla::cpu_function_runtime::BufferInfo* buffer_infos, size_t n, + absl::Span buffers, bool allocate_entry_params) { size_t total = 0; - for (size_t i = 0; i < n; ++i) { + for (size_t i = 0; i < buffers.size(); ++i) { bool should_allocate = - buffer_infos[i].is_temp_buffer() || - (buffer_infos[i].is_entry_parameter() && allocate_entry_params); + buffers[i].is_temp() || buffers[i].is_result() || + (buffers[i].is_entry_parameter() && allocate_entry_params); if (should_allocate) { - total += align_to(buffer_infos[i].size(), xla::cpu::Align()); + total += align_to(buffers[i].size(), xla::cpu::Align()); } } return total; } void* MallocContiguousBuffers( - const xla::cpu_function_runtime::BufferInfo* buffer_infos, size_t n, + absl::Span buffers, bool allocate_entry_params, void** bufs, bool annotate_initialized) { const size_t total = - tensorflow::AlignedBufferBytes(buffer_infos, n, allocate_entry_params); + tensorflow::AlignedBufferBytes(buffers, allocate_entry_params); void* contiguous = nullptr; if (total > 0) { contiguous = aligned_malloc(total, xla::cpu::Align()); @@ -94,13 +95,13 @@ void* MallocContiguousBuffers( } } uintptr_t pos = reinterpret_cast(contiguous); - for (size_t i = 0; i < n; ++i) { + for (size_t i = 0; i < buffers.size(); ++i) { bool should_allocate = - buffer_infos[i].is_temp_buffer() || - (buffer_infos[i].is_entry_parameter() && allocate_entry_params); + buffers[i].is_temp() || buffers[i].is_result() || + (buffers[i].is_entry_parameter() && allocate_entry_params); if (should_allocate) { bufs[i] = reinterpret_cast(pos); - pos += align_to(buffer_infos[i].size(), xla::cpu::Align()); + pos += 
align_to(buffers[i].size(), xla::cpu::Align()); } else { bufs[i] = nullptr; } diff --git a/tensorflow/compiler/tf2xla/allocator.h b/tensorflow/compiler/tf2xla/allocator.h index 4ed60e4cb65535..b9d181ff60ba06 100644 --- a/tensorflow/compiler/tf2xla/allocator.h +++ b/tensorflow/compiler/tf2xla/allocator.h @@ -18,7 +18,8 @@ limitations under the License. #include -#include "xla/cpu_function_runtime.h" +#include "absl/types/span.h" +#include "xla/backends/cpu/buffer_allocation_info.h" namespace tensorflow { @@ -27,7 +28,7 @@ namespace tensorflow { // allocate_entry_params is false, entry parameters. There are `n` entries in // `buffer_infos`. Each buffer is aligned to Align() byte boundaries. size_t AlignedBufferBytes( - const xla::cpu_function_runtime::BufferInfo* buffer_infos, size_t n, + absl::Span buffers, bool allocate_entry_params); // MallocContiguousBuffers allocates buffers for use by the entry point @@ -43,7 +44,7 @@ size_t AlignedBufferBytes( // the head of the allocated contiguous block, which should be passed to // FreeContiguous when the buffers are no longer in use. void* MallocContiguousBuffers( - const xla::cpu_function_runtime::BufferInfo* buffer_infos, size_t n, + absl::Span buffers, bool allocate_entry_params, void** bufs, bool annotate_initialized); // FreeContiguous frees the contiguous block of memory allocated by diff --git a/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc b/tensorflow/compiler/tf2xla/allocator_test.cc similarity index 71% rename from tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc rename to tensorflow/compiler/tf2xla/allocator_test.cc index 6904c58489f861..d5b9158c1fcb3f 100644 --- a/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc +++ b/tensorflow/compiler/tf2xla/allocator_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/cpu_function_runtime.h" +#include "tensorflow/compiler/tf2xla/allocator.h" #include #include @@ -21,17 +21,17 @@ limitations under the License. #include #include -#include "tensorflow/compiler/tf2xla/allocator.h" #include "xla/backends/cpu/alignment.h" +#include "xla/backends/cpu/buffer_allocation_info.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { namespace { -using ::xla::cpu_function_runtime::BufferInfo; +using ::xla::cpu::BufferAllocationInfo; -TEST(XlaCompiledCpuFunctionTest, AlignmentValue) { +TEST(AllocatorTest, AlignmentValue) { // We've chosen 64 byte alignment for the tfcompile runtime to mimic the // regular tensorflow allocator, which was chosen to play nicely with Eigen. // The tfcompile runtime also has a requirement that comes from the xla @@ -41,38 +41,41 @@ TEST(XlaCompiledCpuFunctionTest, AlignmentValue) { EXPECT_LE(xla::cpu::MinAlign(), Allocator::kAllocatorAlignment); } -std::vector SizesToBufferInfos(const intptr_t* sizes, size_t n) { - std::vector buffer_infos; - std::transform(sizes, sizes + n, std::back_inserter(buffer_infos), - [&](intptr_t size) { - if (size == -1) { - // Use a dummy on-stack buffer allocation to indicate the - // the current slot does not need an allocation. 
- int64_t on_stack_buffer_size = 4; - return BufferInfo::MakeOnStackBuffer(on_stack_buffer_size); - } - return BufferInfo::MakeTempBuffer(size); - }); +std::vector SizesToBufferAllocationInfos( + const intptr_t* sizes, size_t n) { + std::vector buffer_infos; + std::transform( + sizes, sizes + n, std::back_inserter(buffer_infos), [&](intptr_t size) { + if (size == -1) { + // Use a dummy on-stack buffer allocation to indicate that + // the current slot does not need an allocation. + int64_t on_stack_buffer_size = 4; + return BufferAllocationInfo::ThreadLocal(on_stack_buffer_size); + } + return BufferAllocationInfo::Temp(size); + }); return buffer_infos; } // Simple wrappers to make writing tests more ergonomic. size_t AlignedBufferBytesFromSizes(const intptr_t* sizes, size_t n) { - std::vector buffer_infos = SizesToBufferInfos(sizes, n); - return tensorflow::AlignedBufferBytes(buffer_infos.data(), n, + std::vector buffer_infos = + SizesToBufferAllocationInfos(sizes, n); + return tensorflow::AlignedBufferBytes(buffer_infos, /*allocate_entry_params=*/false); } void* MallocContiguousBuffersFromSizes(const intptr_t* sizes, size_t n, void** bufs, bool annotate_initialized) { - std::vector buffer_infos = SizesToBufferInfos(sizes, n); - return tensorflow::MallocContiguousBuffers(buffer_infos.data(), n, + std::vector buffer_infos = + SizesToBufferAllocationInfos(sizes, n); + return tensorflow::MallocContiguousBuffers(buffer_infos, /*allocate_entry_params=*/false, bufs, annotate_initialized); } -TEST(XlaCompiledCpuFunctionTest, AlignedBufferBytes) { +TEST(AllocatorTest, AlignedBufferBytes) { EXPECT_EQ(AlignedBufferBytesFromSizes(nullptr, 0), 0); static constexpr intptr_t sizesA[1] = {-1}; @@ -96,7 +99,7 @@ void* add_ptr(void* base, uintptr_t delta) { // expected nullptrs, and write to each byte of allocated memory. We rely on // the leak checker to tell us if there's an inconsistency between malloc and // free. We also check the contiguous property.
-TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) { +TEST(AllocatorTest, MallocFreeContiguousBuffers) { // Test empty sizes. void* base = MallocContiguousBuffersFromSizes(nullptr, 0, nullptr, false); EXPECT_EQ(base, nullptr); @@ -158,23 +161,5 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) { FreeContiguous(base); } -void CheckRoundTripIsOk(const BufferInfo& buffer_info) { - BufferInfo round_trip(buffer_info.Encode()); - ASSERT_EQ(round_trip, buffer_info); -} - -TEST(XlaCompiledCpuFunctionTest, BufferInfoTest) { - CheckRoundTripIsOk(BufferInfo::MakeTempBuffer(0)); - CheckRoundTripIsOk(BufferInfo::MakeTempBuffer(4)); - CheckRoundTripIsOk(BufferInfo::MakeOnStackBuffer(0)); - CheckRoundTripIsOk(BufferInfo::MakeOnStackBuffer(4)); - CheckRoundTripIsOk(BufferInfo::MakeConstant(0)); - CheckRoundTripIsOk(BufferInfo::MakeConstant(4)); - CheckRoundTripIsOk( - BufferInfo::MakeEntryParameter(/*size=*/0, /*param_number=*/4)); - CheckRoundTripIsOk( - BufferInfo::MakeEntryParameter(/*size=*/4, /*param_number=*/0)); -} - } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/const_analysis_test.cc b/tensorflow/compiler/tf2xla/const_analysis_test.cc index c7c8702b49b774..d9f6927c09ecd6 100644 --- a/tensorflow/compiler/tf2xla/const_analysis_test.cc +++ b/tensorflow/compiler/tf2xla/const_analysis_test.cc @@ -180,7 +180,7 @@ TEST(ConstAnalysisTest, RespectExplicitAttr_0) { // not need to be a constant. Output reshape = ops::Reshape(root, arg1, add); reshape.node()->AddAttr(kXlaCompileTimeConstantInputsAttr, - std::vector()); + std::vector()); Graph graph(OpRegistry::Global()); TF_ASSERT_OK(root.ToGraph(&graph)); @@ -203,7 +203,7 @@ TEST(ConstAnalysisTest, RespectExplicitAttr_1) { // Force const analysis to pretend that the first argument to `add` needs to // be a constant. 
- std::vector add_constant_inputs; + std::vector add_constant_inputs; add_constant_inputs.push_back("x"); add.node()->AddAttr(kXlaCompileTimeConstantInputsAttr, add_constant_inputs); diff --git a/tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h b/tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h new file mode 100644 index 00000000000000..5981751259967a --- /dev/null +++ b/tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h @@ -0,0 +1,99 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_ENCODED_BUFFER_ALLOCATION_INFO_H_ +#define TENSORFLOW_COMPILER_TF2XLA_ENCODED_BUFFER_ALLOCATION_INFO_H_ + +#include + +#include "xla/backends/cpu/buffer_allocation_info.h" + +namespace xla { +namespace cpu { + +// Encoded version of `BufferAllocationInfo`, which can be used to reconstruct +// the `BufferAllocationInfo` later. It's used in the AOT compiler, to +// represent buffer allocation info as a lightweight struct. 
+struct EncodedBufferAllocationInfo { + EncodedBufferAllocationInfo(uint64_t packed_kind_and_size, + uint32_t entry_param_number, + uint32_t result_number) + : packed_kind_and_size(packed_kind_and_size), + entry_param_number(entry_param_number), + result_number(result_number) {} + + // Encodes BufferAllocationInfo into the struct that can be used to + // reconstruct the BufferAllocationInfo later using the constructor. We need + // this because we use BufferAllocationInfo in places where using protocol + // buffers would negatively impact binary size. + explicit EncodedBufferAllocationInfo( + const BufferAllocationInfo& buffer_info) { + packed_kind_and_size = Pack(buffer_info.kind(), buffer_info.size()); + entry_param_number = buffer_info.is_entry_parameter() + ? buffer_info.entry_parameter_number() + : -1; + result_number = buffer_info.is_result() ? buffer_info.result_number() : -1; + } + + explicit operator BufferAllocationInfo() const { + auto kind = UnpackKind(packed_kind_and_size); + auto size = UnpackSize(packed_kind_and_size); + int32_t entry_param_number = static_cast(this->entry_param_number); + int32_t result_number = static_cast(this->result_number); + + switch (kind) { + case BufferAllocationInfo::Kind::kConstant: + return BufferAllocationInfo::Constant(size); + case BufferAllocationInfo::Kind::kTemp: + return BufferAllocationInfo::Temp(size); + case BufferAllocationInfo::Kind::kParameter: + if (entry_param_number >= 0 && result_number >= 0) { + return BufferAllocationInfo::InOutParameter(size, entry_param_number, + result_number); + } + if (entry_param_number >= 0) { + return BufferAllocationInfo::EntryParameter(size, entry_param_number); + } + return BufferAllocationInfo::Result(size, result_number); + case BufferAllocationInfo::Kind::kThreadLocal: + return BufferAllocationInfo::ThreadLocal(size); + } + } + + static uint64_t Pack(BufferAllocationInfo::Kind kind, uint64_t size) { + return (static_cast(size) << 2) | static_cast(kind); + } + + static 
constexpr BufferAllocationInfo::Kind UnpackKind(uint64_t packed) { + return static_cast((packed << 62) >> 62); + } + + static constexpr uint64_t UnpackSize(uint64_t packed) { return packed >> 2; } + + uint64_t packed_kind_and_size = 0; + uint32_t entry_param_number = -1; + uint32_t result_number = -1; +}; +} // namespace cpu + +// TODO(ezhulenev): This is a temporary hack to keep `tfcompile` code working. +namespace cpu_function_runtime { +using BufferInfo = ::xla::cpu::BufferAllocationInfo; +using EncodedBufferInfo = ::xla::cpu::EncodedBufferAllocationInfo; +} // namespace cpu_function_runtime + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_TF2XLA_ENCODED_BUFFER_ALLOCATION_INFO_H_ diff --git a/third_party/xla/xla/backends/cpu/buffer_allocation_info_test.cc b/tensorflow/compiler/tf2xla/encoded_buffer_allocation_info_test.cc similarity index 81% rename from third_party/xla/xla/backends/cpu/buffer_allocation_info_test.cc rename to tensorflow/compiler/tf2xla/encoded_buffer_allocation_info_test.cc index b0b5bd57035fa2..c9fc52100abb33 100644 --- a/third_party/xla/xla/backends/cpu/buffer_allocation_info_test.cc +++ b/tensorflow/compiler/tf2xla/encoded_buffer_allocation_info_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2025 The OpenXLA Authors. +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,16 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "xla/backends/cpu/buffer_allocation_info.h" +#include "tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h" #include +#include "xla/backends/cpu/buffer_allocation_info.h" namespace xla::cpu { namespace { -TEST(BufferAllocationInfoTest, RoundTrip) { +TEST(EncodedBufferAllocationInfoTest, RoundTrip) { auto round_trip = [](const BufferAllocationInfo& buffer_info) { - BufferAllocationInfo round_trip(buffer_info.Encode()); + EncodedBufferAllocationInfo encoded(buffer_info); + BufferAllocationInfo round_trip(encoded); ASSERT_EQ(round_trip, buffer_info); }; diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index ba297127eae117..2adc83512c6617 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -83,11 +83,11 @@ struct ClusterTupleLessThan { }; // TODO(jpienaar): Move to OutputTensor. 
-string DebugString(const OutputTensor& tensor) { +std::string DebugString(const OutputTensor& tensor) { return absl::StrCat(tensor.node->name(), ":", tensor.index); } -string Branch_Name(BranchType b) { +std::string Branch_Name(BranchType b) { switch (b) { case BranchType::kElseBranch: return "else"; @@ -100,13 +100,13 @@ string Branch_Name(BranchType b) { } } -string DebugString(StateMap::CondId cond_state) { +std::string DebugString(StateMap::CondId cond_state) { if (cond_state == nullptr || cond_state->empty()) return "{}"; using value_type = StateMap::CondState::value_type; return absl::StrCat( "{", absl::StrJoin(*cond_state, ", ", - [](string* output, const value_type& pred_branch) { + [](std::string* output, const value_type& pred_branch) { const OutputTensor& pred = pred_branch.first; const BranchType& branch = pred_branch.second; if (branch == BranchType::kNeither) @@ -200,7 +200,7 @@ struct CondArgNode { explicit CondArgNode(Node* src, int src_output) : src(src), src_output(src_output) {} - string ToString() const { + std::string ToString() const { return absl::StrCat("src=", src->name(), ":", src_output, " switches=", NodesToString(switches)); } @@ -212,11 +212,11 @@ struct CondArgNode { }; using CondArgNodes = std::vector; -string DebugString(const CondArgNodes& nodes) { +std::string DebugString(const CondArgNodes& nodes) { return absl::StrCat( "[", absl::StrJoin(nodes, ", ", - [](string* output, const CondArgNode& node) { + [](std::string* output, const CondArgNode& node) { absl::StrAppend(output, node.ToString()); }), "]"); @@ -263,20 +263,20 @@ void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) { void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); } -string StateMap::CondStateToString(const Node* node) const { +std::string StateMap::CondStateToString(const Node* node) const { return CondStateToString(LookupCondId(node)); } -string StateMap::CondStateToString(StateMap::CondId id) const { +std::string 
StateMap::CondStateToString(StateMap::CondId id) const { return DebugString(id); } -string StateMap::AncestorStateToString(const Node* node) const { +std::string StateMap::AncestorStateToString(const Node* node) const { if (auto id = LookupAncestorId(node)) { return absl::StrCat( "{", absl::StrJoin(*id, ",", - [](string* output, const AncestorNode& ancestor) { + [](std::string* output, const AncestorNode& ancestor) { absl::StrAppend(output, ancestor.output_tensor.node->name(), ":", ancestor.output_tensor.index); @@ -340,7 +340,7 @@ class Conditional { // Internal name of conditional. The name is based on the first merge node // added. - string name() const; + std::string name() const; // The FunctionalizeCond instance that created this. FunctionalizeCond* parent_; @@ -751,7 +751,7 @@ absl::Status Conditional::BuildIfNode(Graph* graph, VLOG(2) << "Build cond function for " << name(); NodeDebugInfo debug_info((*merges_.begin())->def()); NodeDefBuilder builder(name(), "If", library, &debug_info); - const string branch_name[] = {"else_branch", "then_branch"}; + const std::string branch_name[] = {"else_branch", "then_branch"}; for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { int branch_index = static_cast(branch); @@ -817,7 +817,7 @@ absl::Status Conditional::BuildIfNode(Graph* graph, builder.Attr("Tcond", DT_BOOL); // Add some internal attributes which need to be propagated. 
for (absl::string_view attr_name : kAttrsToPropagate) { - string attr_val; + std::string attr_val; if (GetNodeAttr(predicate_.node->def(), attr_name, &attr_val).ok()) { builder.Attr(attr_name, attr_val); } @@ -949,7 +949,7 @@ absl::Status Conditional::BuildAndReplace( return absl::OkStatus(); } -string Conditional::name() const { +std::string Conditional::name() const { CHECK(!merges_.empty()); return absl::StrCat((*merges_.begin())->name(), "_if"); } @@ -958,7 +958,7 @@ absl::Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node, int port) { NodeBuilder id_builder(replacee->name(), "Identity"); id_builder.Input(if_node, port); - string outside_compilation; + std::string outside_compilation; if (GetNodeAttr(if_node->def(), kXlaOutsideCompilationAttr, &outside_compilation) .ok()) { @@ -1580,7 +1580,7 @@ absl::Status FunctionalizeCond::FunctionalizeInternal() { return absl::OkStatus(); } -void FunctionalizeCond::DumpGraphWithCondState(const string& name) { +void FunctionalizeCond::DumpGraphWithCondState(const std::string& name) { const char* const kCondGroupDebugAttr = "_XlaFunctionalizeCondGroup"; for (Node* n : graph_->nodes()) { diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h index e37555b053d7ed..25d773ad50a105 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.h +++ b/tensorflow/compiler/tf2xla/functionalize_cond.h @@ -136,11 +136,11 @@ class StateMap { BranchType FindBranchOf(CondId id, OutputTensor predicate) const; // Returns textual representation of node's CondState. - string CondStateToString(const Node* node) const; - string CondStateToString(CondId id) const; + std::string CondStateToString(const Node* node) const; + std::string CondStateToString(CondId id) const; // Returns textual representation of node's AncestorState. 
- string AncestorStateToString(const Node* node) const; + std::string AncestorStateToString(const Node* node) const; // Returns whether the cond state is the dead state. bool IsDead(CondId id) const; @@ -201,7 +201,7 @@ class FunctionalizeCond { absl::Status PropagateUpdatedState(const Node* replacee); // Dump graph with the CondState annotated. - void DumpGraphWithCondState(const string& name); + void DumpGraphWithCondState(const std::string& name); // Adds `switch_id` to the list of Switch node ids. void AddSwitchId(int switch_id); diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc index 50bd47ad73e77e..edb2a7e0ea1b33 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc @@ -48,7 +48,7 @@ class FunctionalizeCondTest : public ::testing::Test { return fc_->state_map_.GetCondId(state); } - string GetString(const StateMap::StateMap::CondId id) { + std::string GetString(const StateMap::StateMap::CondId id) { return fc_->state_map_.CondStateToString(id); } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index ac38725269bfd9..22b9b9187ecd7d 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -51,8 +51,9 @@ namespace tensorflow { // Maps function name to // - new function name, if the function body was functionalized // - std::nullopt, if not -using FuncMap = std::map>; -using FuncMapIter = std::map>::const_iterator; +using FuncMap = std::map>; +using FuncMapIter = + std::map>::const_iterator; // Returns whether function has been processed before. bool FunctionHasBeenProcessed(FuncMapIter func_iter, const FuncMap* func_map) { @@ -65,8 +66,8 @@ bool FunctionHasBeenModified(FuncMapIter func_iter) { } // Returns a name for the new functionalized version of a function. 
-string GetNewFunctionName( - const string& func_name, Node* n, +std::string GetNewFunctionName( + const std::string& func_name, Node* n, AssociatedFunctionInfo::AssociatedFunctionType func_type, FunctionLibraryDefinition* fld) { // For SymbolicGradient, `func_name` is always "SymbolicGradient" which @@ -79,14 +80,15 @@ string GetNewFunctionName( } // Returns name to which a modified function has been mapped. -const string& GetMappedFunctionName(FuncMapIter func_iter) { +const std::string& GetMappedFunctionName(FuncMapIter func_iter) { DCHECK(func_iter->second.has_value()); return func_iter->second.value(); } // Updates `func_map` with function given by `canonicalized_name`. -void UpdateFunctionMap(FuncMap* func_map, const string& canonicalized_name, - const string& new_func_name, bool function_modified) { +void UpdateFunctionMap(FuncMap* func_map, const std::string& canonicalized_name, + const std::string& new_func_name, + bool function_modified) { // If function was modified store its new name, otherwise add empty entry to // record that function has been processed and does not need to be rewritten. (*func_map)[canonicalized_name] = @@ -95,8 +97,9 @@ void UpdateFunctionMap(FuncMap* func_map, const string& canonicalized_name, // Adds new function def to graph's function library if necessary. absl::Status AddFunctionDefToGraphLibrary( - const string& func_name, const AssociatedFunctionInfo& associated_function, - Graph* graph, FunctionLibraryDefinition* fld) { + const std::string& func_name, + const AssociatedFunctionInfo& associated_function, Graph* graph, + FunctionLibraryDefinition* fld) { const OpRegistrationData* op_reg_data; // We have to be careful with adding the function def since there are three // different `OpRegistryInterface`s involved here: @@ -129,8 +132,8 @@ absl::Status AddFunctionDefToGraphLibrary( // Functionalizes function given by `func_name`. Update `func_map` accordingly. 
absl::Status FunctionalizeControlFlowForFunction( - const string& func_name, const string& new_func_name, - const protobuf::Map& attrs, + const std::string& func_name, const std::string& new_func_name, + const protobuf::Map& attrs, FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, FuncMap* func_map, bool* function_modified, const NodeFilter& node_filter = {}); @@ -165,11 +168,11 @@ absl::Status FunctionalizeControlFlowForNodeAssociatedFunctions( associated_functions.size() == 1); // Process one node-function-pair. - string func_name = associated_function.func_name(); - string canonicalized_name = + std::string func_name = associated_function.func_name(); + std::string canonicalized_name = Canonicalize(func_name, AttrSlice(&associated_function.attrs())); auto func_iter = func_map->find(canonicalized_name); - string new_func_name; + std::string new_func_name; if (FunctionHasBeenProcessed(func_iter, func_map)) { if (FunctionHasBeenModified(func_iter)) { *any_function_modified = true; @@ -202,8 +205,8 @@ absl::Status FunctionalizeControlFlowForNodeAssociatedFunctions( } absl::Status FunctionalizeControlFlowForFunction( - const string& func_name, const string& new_func_name, - const protobuf::Map& attrs, + const std::string& func_name, const std::string& new_func_name, + const protobuf::Map& attrs, FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, FuncMap* func_map, bool* function_modified, const NodeFilter& node_filter) { *function_modified = false; @@ -341,8 +344,8 @@ absl::Status FunctionalizeControlFlowForXlaPass::Run( // Find XLA compile ops and its corresponding FunctionDef. // TPUCompile op is not in the map because graph rewriting might happen // multiple times, and we want to avoid functionalize it again. - static std::map* kNodeTypeToFunctionAttrMapping = - new std::map{ + static std::map* kNodeTypeToFunctionAttrMapping = + new std::map{ // _TPUReplicate ops are generated by EncapsulateTPUComputationsPass. 
{"_TPUReplicate", "computation"}, // XlaLaunch ops are generated by EncapsulateXlaComputationsPass. @@ -355,12 +358,12 @@ absl::Status FunctionalizeControlFlowForXlaPass::Run( if (it == kNodeTypeToFunctionAttrMapping->end()) { continue; } - const string func_attr = it->second; + const std::string func_attr = it->second; NameAttrList func; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); VLOG(2) << "Graph has node " << n->type_string() << ". Corresponding function: " << func.name(); - string new_func_name = options.flib_def->UniqueFunctionName( + std::string new_func_name = options.flib_def->UniqueFunctionName( absl::StrCat(func.name(), "_f15n_")); bool modified; TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 7727853a8c4233..24fe7f5e13e7e0 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -46,7 +46,7 @@ namespace { // Returns the names of the "then" and "else" functions for the If node in a // graph. -absl::Status FindIfThenAndElse(const GraphDef& graph, string* op_name, +absl::Status FindIfThenAndElse(const GraphDef& graph, std::string* op_name, NameAttrList* then_fn, NameAttrList* else_fn) { for (const NodeDef& node : graph.node()) { if (node.op() == "If") { @@ -97,7 +97,7 @@ INSTANTIATE_TEST_SUITE_P( info) { bool restrict_to_tpu_nodes = std::get<0>(info.param); bool wrap_cond_in_function = std::get<1>(info.param); - string name = + std::string name = absl::StrCat(restrict_to_tpu_nodes ? "with_filter" : "without_filter", wrap_cond_in_function ? 
"_in_function" : "_in_graph"); return name; @@ -114,7 +114,7 @@ void ConditionalTestFixture::BuildCondGraph(Graph* cond_graph) { auto identity_t = ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_true); - auto seventeen = ops::Const( + auto seventeen = ops::Const( scope.WithOpName("cond").WithControlDependencies(identity_t), 17); auto switch_2 = ops::Switch(scope.WithOpName("cond/Switch"), y, less); auto mul = ops::Multiply(scope.WithOpName("cond/Mul"), switch_2.output_true, @@ -122,7 +122,7 @@ void ConditionalTestFixture::BuildCondGraph(Graph* cond_graph) { auto identity_f = ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_false); - auto twenty_three = ops::Const( + auto twenty_three = ops::Const( scope.WithOpName("cond").WithControlDependencies(identity_f), 23); auto switch_3 = ops::Switch(scope.WithOpName("cond/Switch"), x, less); auto add = ops::Add(scope.WithOpName("cond/false/add"), @@ -146,7 +146,7 @@ void ConditionalTestFixture::BuildCondGraph(Graph* cond_graph) { void ConditionalTestFixture::CheckGraphDef( const GraphDef& graph_def, const FunctionLibraryDefinition& library) { - string op_name; + std::string op_name; NameAttrList then_fn; NameAttrList else_fn; TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); @@ -285,7 +285,7 @@ void ConditionalTestFixture::RunTest() { FunctionLibraryRuntime::Handle handle; // Functionalized function name is the type string of `cond_node`. 
- string func_name; + std::string func_name; for (Node* n : graph.nodes()) { if (n->name() == "cond_node") { func_name = n->type_string(); @@ -341,7 +341,7 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop"); auto merge = ops::Merge(scope.WithOpName("while/Merge"), std::initializer_list{enter, dummy}); - auto ten = ops::Const( + auto ten = ops::Const( scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), 10); auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); @@ -352,7 +352,7 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { switch_.output_false); auto identity = ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true); - auto one = ops::Const( + auto one = ops::Const( scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); auto add = ops::Add(scope.WithOpName("while/add"), identity, one); auto next_iteration = @@ -405,7 +405,7 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { { Scope scope = Scope::NewRootScope().ExitOnError(); auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0); - auto ten = ops::Const( + auto ten = ops::Const( scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0); @@ -427,7 +427,7 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { Scope scope = Scope::NewRootScope().ExitOnError(); auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0); auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); - auto one = ops::Const( + auto one = ops::Const( scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); auto add = ops::Add(scope.WithOpName("while/add"), identity, one); auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0); @@ -463,7 +463,8 @@ FunctionDef GetNoinlineFunctionDef() { // return [x + 1] // Define the above 
function, and add it to the given graph. It's used as the // while loop body in NoinlineLoopBody test. -absl::Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { +absl::Status AddNoinlineFunctionToGraph(const std::string& node_name, + Graph* graph) { FunctionDefLibrary fdef_lib; *(fdef_lib.add_function()) = GetNoinlineFunctionDef(); TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib)); @@ -481,7 +482,7 @@ absl::Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { // x = array_ops.placeholder(dtypes.int32) // y = control_flow_ops.while_loop(lambda i: i < 10, increment_fn, [x]) TEST(FunctionalizeControlFlow, NoinlineLoopBody) { - const string& noinline_node_name = "while/increment_fn"; + const std::string& noinline_node_name = "while/increment_fn"; Graph graph(OpRegistry::Global()); { Scope scope = Scope::NewRootScope().ExitOnError(); @@ -491,7 +492,7 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) { "while/while_context"); auto merge = ops::Merge(scope.WithOpName("while/Merge"), std::initializer_list{enter, dummy}); - auto ten = ops::Const( + auto ten = ops::Const( scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), 10); auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); @@ -585,7 +586,7 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) { } TEST(FunctionalizeControlFlow, MissingFunctionDefInLibrary) { - const string& noinline_node_name = "while/increment_fn"; + const std::string& noinline_node_name = "while/increment_fn"; Graph graph(OpRegistry::Global()); { Scope scope = Scope::NewRootScope().ExitOnError(); @@ -622,7 +623,7 @@ TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) { ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop"); auto merge = ops::Merge(scope.WithOpName("while/Merge"), std::initializer_list{enter, dummy}); - auto ten = ops::Const( + auto ten = ops::Const( scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), 10); 
auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); @@ -631,7 +632,7 @@ TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) { ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond); auto identity = ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true); - auto one = ops::Const( + auto one = ops::Const( scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); auto add = ops::Add(scope.WithOpName("while/add"), identity, one); auto next_iteration = @@ -673,7 +674,7 @@ TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) { { Scope scope = Scope::NewRootScope().ExitOnError(); auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0); - auto ten = ops::Const( + auto ten = ops::Const( scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0); @@ -695,7 +696,7 @@ TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) { Scope scope = Scope::NewRootScope().ExitOnError(); auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0); auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); - auto one = ops::Const( + auto one = ops::Const( scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); auto add = ops::Add(scope.WithOpName("while/add"), identity, one); auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0); @@ -739,14 +740,15 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) { std::initializer_list{enter_y, dummy}); // Loop condition - auto three = ops::Const(scope.WithOpName("while/cond/three") - .WithControlDependencies(merge_x.output), - 3); + auto three = + ops::Const(scope.WithOpName("while/cond/three") + .WithControlDependencies(merge_x.output), + 3); auto cond_add = ops::Add(scope.WithOpName("while/cond/Add"), merge_x.output, three); - auto ten = ops::Const(scope.WithOpName("while/cond/ten") - 
.WithControlDependencies(merge_x.output), - 10); + auto ten = ops::Const(scope.WithOpName("while/cond/ten") + .WithControlDependencies(merge_x.output), + 10); auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten); auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); @@ -765,10 +767,10 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) { auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), switch_y.output_true); - auto one = ops::Const( + auto one = ops::Const( scope.WithOpName("while/add/one").WithControlDependencies(identity_x), 1); - auto two = ops::Const( + auto two = ops::Const( scope.WithOpName("while/mul/two").WithControlDependencies(identity_x), 2); @@ -825,14 +827,15 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) { Scope scope = Scope::NewRootScope().ExitOnError(); auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0); auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); - auto three = ops::Const(scope.WithOpName("while/cond/three") - .WithControlDependencies(arg0.output), - 3); + auto three = + ops::Const(scope.WithOpName("while/cond/three") + .WithControlDependencies(arg0.output), + 3); auto cond_add = ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three); - auto ten = ops::Const(scope.WithOpName("while/cond/ten") - .WithControlDependencies(arg0.output), - 10); + auto ten = ops::Const(scope.WithOpName("while/cond/ten") + .WithControlDependencies(arg0.output), + 10); auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten); auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0); @@ -859,10 +862,10 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) { auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), arg1); - auto one = ops::Const( + auto one = ops::Const( scope.WithOpName("while/add/one").WithControlDependencies(identity_x), 1); - auto two = ops::Const( + auto two = ops::Const( 
scope.WithOpName("while/mul/two").WithControlDependencies(identity_x), 2); @@ -922,7 +925,7 @@ INSTANTIATE_TEST_SUITE_P( bool mark_inner_loop_tpu = std::get<1>(info.param); bool mark_outer_loop_tpu = std::get<2>(info.param); - string node_string; + std::string node_string; if (mark_inner_loop_tpu && mark_outer_loop_tpu) node_string = "both_loops_tpu"; else if (!mark_inner_loop_tpu && !mark_outer_loop_tpu) @@ -930,7 +933,7 @@ INSTANTIATE_TEST_SUITE_P( else node_string = mark_inner_loop_tpu ? "inner_loop_tpu" : "outer_loop_tpu"; - string name = absl::StrCat( + std::string name = absl::StrCat( restrict_to_tpu_nodes ? "restricted_" : "unrestricted_", node_string); return name; }); @@ -961,21 +964,21 @@ void ComplexTestFixture::RunTest() { auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); - auto three = ops::Const(scope.WithOpName("three"), 3); + auto three = ops::Const(scope.WithOpName("three"), 3); auto y = ops::Add(scope.WithOpName("y"), x, three); auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, TensorShape({})); // Outer loop - auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); + auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); auto enter_i = ops::internal::Enter(scope.WithOpName("outer/Enter_i"), zero, "outer"); auto merge_i = ops::Merge(scope.WithOpName("outer/Merge_i"), std::initializer_list{enter_i, dummy}); - auto ten = ops::Const(scope.WithOpName("outer/Less/y") - .WithControlDependencies(merge_i.output), - 10); + auto ten = ops::Const(scope.WithOpName("outer/Less/y") + .WithControlDependencies(merge_i.output), + 10); auto less_i = ops::Less(scope.WithOpName("outer/Less_i"), merge_i.output, ten); auto outer_loop_cond = @@ -998,7 +1001,7 @@ void ComplexTestFixture::RunTest() { ops::internal::Enter::Attrs().IsConstant(true)); // Inner loop - auto one_j = ops::Const( + auto one_j = ops::Const( 
scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); auto enter_j = ops::internal::Enter(scope.WithOpName("outer/inner/Enter_j"), one_j, "inner"); @@ -1018,9 +1021,10 @@ void ComplexTestFixture::RunTest() { auto merge_k = ops::Merge(scope.WithOpName("outer/inner/Merge_k"), std::initializer_list{enter_k, dummy}); - auto five = ops::Const(scope.WithOpName("outer/inner/Five") - .WithControlDependencies(merge_j.output), - 5); + auto five = + ops::Const(scope.WithOpName("outer/inner/Five") + .WithControlDependencies(merge_j.output), + 5); auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), merge_j.output, five); auto loop_cond = @@ -1047,7 +1051,7 @@ void ComplexTestFixture::RunTest() { auto assign = ops::AssignAddVariableOp( scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx); - auto one = ops::Const( + auto one = ops::Const( scope.WithOpName("outer/inner/One") .WithControlDependencies( absl::Span{assign.operation}), @@ -1061,7 +1065,7 @@ void ComplexTestFixture::RunTest() { scope.WithOpName("outer/inner/NextIteration_k"), identity_k); // Body and backedge for outer loop. - auto one_outer = ops::Const( + auto one_outer = ops::Const( scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); auto add_i = ops::Add(scope.WithOpName("outer/add") @@ -1086,9 +1090,10 @@ void ComplexTestFixture::RunTest() { } // Add '_tpu_replicate' attributes as specified. 
for (Node* n : graph.nodes()) { - string name = n->name(); - bool is_inner_node = name.find("outer/inner/") != string::npos; - bool is_outer_node = !is_inner_node && name.find("outer/") != string::npos; + std::string name = n->name(); + bool is_inner_node = name.find("outer/inner/") != std::string::npos; + bool is_outer_node = + !is_inner_node && name.find("outer/") != std::string::npos; if ((is_inner_node && mark_inner_loop_tpu_) || (is_outer_node && mark_outer_loop_tpu_)) { n->AddAttr("_tpu_replicate", "cluster"); @@ -1159,13 +1164,13 @@ void ComplexTestFixture::CheckOuterNodesFunctionalized( { Scope scope = Scope::NewRootScope().ExitOnError(); auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); - auto three = ops::Const(scope.WithOpName("three"), 3); + auto three = ops::Const(scope.WithOpName("three"), 3); auto y = ops::Add(scope.WithOpName("y"), x, three); auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, TensorShape({})); - auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); + auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); auto while_op = ops::While(scope.WithOpName("outer/LoopCond"), std::initializer_list{zero, y, x, var}, @@ -1184,7 +1189,7 @@ void ComplexTestFixture::CheckOuterNodesFunctionalized( auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3); - auto ten = ops::Const( + auto ten = ops::Const( scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output), 10); auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten); @@ -1220,14 +1225,14 @@ void ComplexTestFixture::CheckOuterNodesFunctionalized( auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3); auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0); - auto one_j = ops::Const( + auto one_j = ops::Const( scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); auto while_op = 
ops::While(scope.WithOpName("outer/inner/LoopCond"), std::initializer_list{one_j, arg1, arg2, arg3}, inner_cond_fn, inner_body_fn); - auto one_outer = ops::Const( + auto one_outer = ops::Const( scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); auto add_i = ops::Add(scope.WithOpName("outer/add") @@ -1262,7 +1267,7 @@ void ComplexTestFixture::CheckInnerNodesFunctionalized( auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3); - auto five = ops::Const( + auto five = ops::Const( scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5); auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five); auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less_j, 0); @@ -1299,7 +1304,7 @@ void ComplexTestFixture::CheckInnerNodesFunctionalized( auto assign = ops::AssignAddVariableOp( scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx); - auto one = ops::Const( + auto one = ops::Const( scope.WithOpName("outer/inner/One") .WithControlDependencies( absl::Span{assign.operation}), diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc index cf3413154b8baa..d8558e7fb2b5fe 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc @@ -42,7 +42,7 @@ absl::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { absl::Status ExtractWhileLoopFrames( const std::vector& cf_info, const Graph* graph, - std::unordered_map* frames, + std::unordered_map* frames, const NodeFilter& node_filter) { for (Node* node : graph->op_nodes()) { const ControlFlowInfo& cf = cf_info[node->id()]; diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h index 970f62daa42af3..90c50f75e36387 100644 --- 
a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h @@ -47,7 +47,7 @@ struct WhileLoopArg { // Information about a loop frame. struct WhileLoopFrame { - string name; + std::string name; // Pointer to the parent frame. The root frame has a pointer to itself. WhileLoopFrame* parent = nullptr; @@ -76,7 +76,7 @@ struct WhileLoopFrame { // `FunctionalizeControlFlow` for more details about node filters). absl::Status ExtractWhileLoopFrames( const std::vector& cf_info, const Graph* graph, - std::unordered_map* frames, + std::unordered_map* frames, const NodeFilter& node_filter = {}); // Check that the graph has no cycle containing the given node. @@ -97,10 +97,10 @@ absl::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index); // Returns a textual representation of the names of the nodes in the input. template -string NodesToString(const T& nodes) { +std::string NodesToString(const T& nodes) { return absl::StrCat("{", absl::StrJoin(nodes, ",", - [](string* output, const Node* node) { + [](std::string* output, const Node* node) { absl::StrAppend(output, node->name()); }), "}"); diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index 2c02379c36cd45..b8183afd59481a 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -438,7 +438,7 @@ absl::Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame, builder.Attr("body", body_name); // Add some internal attributes which need to be propagated. for (absl::string_view attr_name : kAttrsToPropagate) { - string attr_val; + std::string attr_val; if (GetNodeAttr(frame->loop_cond->def(), attr_name, &attr_val).ok()) { builder.Attr(attr_name, attr_val); } @@ -513,7 +513,7 @@ absl::Status FunctionalizeWhileLoop(Graph* graph, // connected to all source nodes in the graph. Many graphs violate this // invariant. 
std::vector cf_info; - std::vector unreachable_nodes; + std::vector unreachable_nodes; TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes)); if (!unreachable_nodes.empty()) { return errors::InvalidArgument( @@ -522,7 +522,7 @@ absl::Status FunctionalizeWhileLoop(Graph* graph, } // Builds Frames, indexed by name. - std::unordered_map frames; + std::unordered_map frames; TF_RETURN_IF_ERROR( ExtractWhileLoopFrames(cf_info, graph, &frames, node_filter)); diff --git a/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc b/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc index 2759ad8384cd81..b331272a2c9504 100644 --- a/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc +++ b/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc @@ -42,7 +42,7 @@ limitations under the License. namespace tensorflow { namespace { -absl::Status GetTestDevice(Session* session, string* test_device) { +absl::Status GetTestDevice(Session* session, std::string* test_device) { std::vector devices; TF_RETURN_IF_ERROR(session->ListDevices(&devices)); @@ -85,7 +85,7 @@ TEST(FusedBatchnormReserveSpaceTest, Test) { std::unique_ptr session( tensorflow::NewSession(tensorflow::SessionOptions{})); - string test_device; + std::string test_device; TF_ASSERT_OK(GetTestDevice(session.get(), &test_device)); Scope root = tensorflow::Scope::NewRootScope(); @@ -108,8 +108,8 @@ TEST(FusedBatchnormReserveSpaceTest, Test) { Output variance = Const(root.WithOpName("variance"), Input::Initializer(variance_data)); - string tf_device = absl::StrCat("/device:", test_device, ":0"); - string xla_device = absl::StrCat("/device:XLA_", test_device, ":0"); + std::string tf_device = absl::StrCat("/device:", test_device, ":0"); + std::string xla_device = absl::StrCat("/device:XLA_", test_device, ":0"); FusedBatchNorm fused_batch_norm_tf( root.WithOpName("fused_batch_norm_tf").WithDevice(tf_device), input, diff --git 
a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index f23c423fbb2632..5f794005b7c7c0 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -292,12 +292,12 @@ absl::Status GraphCompiler::CompileFunctionalNode(Node* n, } } if (add_token_input_output) { - std::vector token_input_nodes; + std::vector token_input_nodes; TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(&func.attr()), kXlaTokenInputNodesAttrName, &token_input_nodes)); std::vector token_inputs; - for (const string& node_name : token_input_nodes) { + for (const std::string& node_name : token_input_nodes) { auto token_or = compiler->GetNodeToken(node_name); TF_RETURN_IF_ERROR(token_or.status()); token_inputs.push_back(std::move(token_or).value()); diff --git a/tensorflow/compiler/tf2xla/graph_compiler_test.cc b/tensorflow/compiler/tf2xla/graph_compiler_test.cc index 3010ac7f0b026b..2dcb2ea0b52d45 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler_test.cc @@ -104,8 +104,8 @@ class GraphCompilerTest : public ::testing::Test { core::ScopedUnref context_unref(xla_context); xla_context->Ref(); - auto step_container = - std::make_unique(0, [this](const string& name) { + auto step_container = std::make_unique( + 0, [this](const std::string& name) { absl::Status status = this->device_->resource_manager()->Cleanup(name); }); diff --git a/tensorflow/compiler/tf2xla/graph_compiler_util.cc b/tensorflow/compiler/tf2xla/graph_compiler_util.cc index d1c984e26f390a..116c1e68f66fe6 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler_util.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler_util.cc @@ -44,7 +44,7 @@ const char* const kFetchIdAttr = "_fetch_id"; const char* const kShapeAttr = "_shape"; const char* const kDebugNameAttr = "_debug_name"; -typedef std::unordered_map NodeMap; +typedef std::unordered_map NodeMap; // Each feed id identifies the positional output 
of some node, which may consist // of multiple edges. AddPlaceholdersForFeeds has already replaced each fed @@ -54,14 +54,14 @@ typedef std::unordered_map NodeMap; absl::Status AddArgNodes( Graph* graph, const NodeMap& node_map, const protobuf::RepeatedPtrField& feeds, - const std::unordered_map& feed_remapping, + const std::unordered_map& feed_remapping, std::unordered_set* arg_nodes) { for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) { const tf2xla::Feed& feed = feeds[arg_index]; // All feeds have been replaced by placeholders. const int output_index = 0; - const string key = TensorIdToString(feed.id()); + const std::string key = TensorIdToString(feed.id()); const auto remap_it = feed_remapping.find(key); auto node_it = node_map.find(remap_it->second); if (node_it == node_map.end()) { @@ -149,7 +149,7 @@ absl::Status AddRetvalNodes( // execution to know the input and output args for the generated function. absl::Status RewriteAndPruneGraph( Graph* graph, const tf2xla::Config& config, - const std::unordered_map& feed_remapping) { + const std::unordered_map& feed_remapping) { NodeMap node_map; for (Node* n : graph->nodes()) { node_map[n->name()] = n; @@ -164,7 +164,7 @@ absl::Status RewriteAndPruneGraph( FixupSourceAndSinkEdges(graph); VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph); // Sanity-check, to make sure the feeds and fetches still exist post-pruning. 
- std::set missing_feeds, missing_fetches; + std::set missing_feeds, missing_fetches; for (const tf2xla::Feed& feed : config.feed()) { missing_feeds.insert(TensorIdToString(feed.id())); } @@ -173,14 +173,14 @@ absl::Status RewriteAndPruneGraph( } for (const Node* n : graph->op_nodes()) { if (n->type_string() == FunctionLibraryDefinition::kArgOp) { - string feed_id; + std::string feed_id; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id)); if (missing_feeds.erase(feed_id) == 0) { return errors::Aborted(FunctionLibraryDefinition::kArgOp, " node found with unknown feed id: ", feed_id); } } else if (n->type_string() == FunctionLibraryDefinition::kRetOp) { - string fetch_id; + std::string fetch_id; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id)); if (missing_fetches.erase(fetch_id) == 0) { return errors::Aborted(FunctionLibraryDefinition::kRetOp, @@ -277,7 +277,7 @@ absl::Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, GraphDef first_copy_def = graph_def; // Maps from name:port of a feed to the name:port of the placeholder to use. 
- std::unordered_map feed_remapping; + std::unordered_map feed_remapping; TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(), &feed_remapping, &first_copy_def)); diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 5079ddd4389bd8..bb50d530484b10 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -371,7 +371,6 @@ cc_library( "@local_xla//xla/hlo/translate:stablehlo", "@local_xla//xla/mlir/utils:error_util", "@local_xla//xla/mlir/utils:type_util", - "@local_xla//xla/mlir_hlo:mhlo_passes", "@local_xla//xla/python:refine_polymorphic_shapes", "@local_xla//xla/service:hlo_proto_cc", "@local_xla//xla/service/spmd/shardy/sdy_round_trip:pipelines", @@ -382,6 +381,7 @@ cc_library( "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_passes", + "@stablehlo//:stablehlo_passes_optimization", "@stablehlo//:stablehlo_serialization", "@stablehlo//:vhlo_ops", ], diff --git a/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc b/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc index a6ddbfd3a01fef..74c888d37de784 100644 --- a/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc @@ -94,9 +94,9 @@ class CollectiveReduceV2Op : public XlaOpKernel { private: DataType dtype_ = DT_INVALID; - string merge_op_name_; - string final_op_name_; - string communication_hint_; + std::string merge_op_name_; + std::string final_op_name_; + std::string communication_hint_; CollectiveReduceV2Op(const CollectiveReduceV2Op&) = delete; void operator=(const CollectiveReduceV2Op&) = delete; diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 0dd528e3dea173..240a099f075aa2 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -48,7 +48,7 @@ class FusedBatchNormOp 
: public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_)); OP_REQUIRES_OK( ctx, ctx->GetAttr("exponential_avg_factor", &exponential_avg_factor_)); - string data_format_str; + std::string data_format_str; OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); OP_REQUIRES( ctx, FormatFromString(data_format_str, &data_format_), @@ -61,7 +61,7 @@ class FusedBatchNormOp : public XlaOpKernel { errors::InvalidArgument( "FusedBatchNormEx supports at most 1 side input.")); add_side_input_ = (num_side_inputs == 1); - string activation_mode; + std::string activation_mode; OP_REQUIRES_OK(ctx, ctx->GetAttr("activation_mode", &activation_mode)); OP_REQUIRES(ctx, activation_mode == "Identity" || activation_mode == "Relu", @@ -249,7 +249,7 @@ class FusedBatchNormGradOp : public XlaOpKernel { explicit FusedBatchNormGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_)); - string data_format_str; + std::string data_format_str; OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); OP_REQUIRES( ctx, FormatFromString(data_format_str, &data_format_), diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index 7c89720292b0a7..94486a104152ea 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -66,9 +66,11 @@ class BCastArgsOp : public XlaOpKernel { Tensor output(val_type, TensorShape({len})); for (int64_t i = 0; i < len; ++i) { if (val_type == DT_INT32) { - output.flat()(i) = static_cast(bcast.output_shape()[i]); + output.flat()(i) = + static_cast(bcast.output_shape()[i]); } else { - output.flat()(i) = static_cast(bcast.output_shape()[i]); + output.flat()(i) = + static_cast(bcast.output_shape()[i]); } } ctx->SetConstantOutput(0, output); @@ -129,9 +131,9 @@ class BCastGradArgsOp : public 
XlaOpKernel { Tensor constant(val_type, TensorShape({len})); for (int64_t i = 0; i < len; ++i) { if (val_type == DT_INT32) { - constant.flat()(i) = static_cast(v[i]); + constant.flat()(i) = static_cast(v[i]); } else { - constant.flat()(i) = static_cast(v[i]); + constant.flat()(i) = static_cast(v[i]); } } ctx->SetConstantOutput(idx, constant); diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index 2bf4ab52c8b59e..bf428711664d76 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -28,7 +28,7 @@ namespace { class BiasOp : public XlaOpKernel { public: explicit BiasOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - string data_format; + std::string data_format; if (ctx->GetAttr("data_format", &data_format).ok()) { OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), errors::InvalidArgument("Invalid data format")); diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc index 510d5225d6f04b..7d323b16d8856e 100644 --- a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc @@ -55,7 +55,7 @@ class BucketizeOp : public XlaOpKernel { /*broadcast_dimensions=*/{0}), xla::S32); xla::XlaOp buckets = xla::Reduce( - comparison, /*init_value=*/xla::ConstantR0(builder, 0), + comparison, /*init_value=*/xla::ConstantR0(builder, 0), /*computation=*/xla::CreateScalarAddComputation(xla::S32, builder), /*dimensions_to_reduce=*/{0}); context->SetOutput(0, buckets); diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.cc b/tensorflow/compiler/tf2xla/kernels/case_op.cc index cead6d10c2a0eb..da40d84e73f063 100644 --- a/tensorflow/compiler/tf2xla/kernels/case_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/case_op.cc @@ -66,7 +66,7 @@ XlaCaseOp::GetPrunedBranchesAndIndex(XlaOpKernelContext* ctx) { return {unpruned_branches_, ctx->Input(0)}; } - 
int32_t branch_index = branch_index_literal.Get({}); + int32_t branch_index = branch_index_literal.Get({}); if (branch_index < 0 || branch_index >= unpruned_branches_.size()) { branch_index = unpruned_branches_.size() - 1; } @@ -187,7 +187,8 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { // Add any TensorArray gradients touched by the then/else computation to // the enclosing graph. - for (const string& grad_source : update.tensor_array_gradients_accessed) { + for (const std::string& grad_source : + update.tensor_array_gradients_accessed) { VLOG(5) << "TensorArray " << resource->name() << " accessed gradient " << grad_source; XlaResource* gradient; @@ -289,7 +290,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { // Set token input for this "case" op. std::vector token_inputs; token_inputs.reserve(token_input_nodes_.size()); - for (const string& node_name : token_input_nodes_) { + for (const std::string& node_name : token_input_nodes_) { auto token_or = compiler->GetNodeToken(node_name); OP_REQUIRES_OK(ctx, token_or.status()); token_inputs.push_back(token_or.value()); diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.h b/tensorflow/compiler/tf2xla/kernels/case_op.h index a4c01bea65a04d..6574fb4aac4c5e 100644 --- a/tensorflow/compiler/tf2xla/kernels/case_op.h +++ b/tensorflow/compiler/tf2xla/kernels/case_op.h @@ -65,8 +65,8 @@ class XlaCaseOp : public XlaOpKernel { DataTypeVector input_types_; DataTypeVector output_types_; bool has_token_input_output_; - std::vector token_input_nodes_; - string original_node_name_; + std::vector token_input_nodes_; + std::string original_node_name_; // Whether to propagate compile time consts into the cond branches. // This is not supported by default now since it may cause HBM memory // overheads. 
diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index e8c804791299a7..2c69974d8373dc 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -185,7 +185,7 @@ class StatelessCategoricalOp : public CategoricalOp { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessCategoricalOp(const StatelessCategoricalOp&) = delete; void operator=(const StatelessCategoricalOp&) = delete; diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index d2463a9974b1bb..7ab53f7ad89e75 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -38,7 +38,7 @@ template ::value>::type* = nullptr> DstT CastTo(int32_t src) { - return absl::bit_cast(static_cast(src)); + return absl::bit_cast(static_cast(src)); } // Returns scalar constant with the value in the tensor, if the given proto has diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 3fe22dcb4441e7..59f72e630c0f75 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -163,8 +163,8 @@ absl::Status CheckConvAttrs(const ConvOpAttrs& attrs) { absl::Status ConvBackpropComputeDimensionsV2XlaShapes( absl::string_view label, int num_spatial_dims, const xla::Shape& input_shape, const xla::Shape& filter_shape, - const xla::Shape& out_backprop_shape, absl::Span dilations, - const std::vector& strides, Padding padding, + const xla::Shape& out_backprop_shape, absl::Span dilations, + const std::vector& strides, Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims, absl::Span explicit_paddings) { TensorShape input_tensor_shape, filter_tensor_shape, @@ -203,7 +203,7 @@ absl::StatusOr 
ConvOpAttrs::Create(int num_spatial_dims, ctx->GetAttr("explicit_paddings", &attrs.explicit_paddings)); } - string data_format; + std::string data_format; TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format)); if (!FormatFromString(data_format, &attrs.data_format)) { return errors::InvalidArgument("Invalid data format: ", data_format); @@ -231,7 +231,7 @@ absl::StatusOr ConvNDOpAttrs::Create(OpKernelConstruction* ctx) { ctx->GetAttr("explicit_paddings", &attrs.explicit_paddings)); } - string data_format_str; + std::string data_format_str; TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format_str)); if (!(data_format_str == "CHANNELS_LAST" || data_format_str == "CHANNELS_FIRST")) { diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h index 94e454df205df2..e64cebe3970cd8 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h @@ -54,8 +54,8 @@ struct ConvOpAttrs { bool depthwise; int num_spatial_dims; - std::vector dilations; - std::vector strides; + std::vector dilations; + std::vector strides; Padding padding; std::vector explicit_paddings; TensorFormat data_format; @@ -68,8 +68,8 @@ struct ConvNDOpAttrs { int groups; int batch_dims; - std::vector dilations; - std::vector strides; + std::vector dilations; + std::vector strides; Padding padding; std::vector explicit_paddings; TensorFormat data_format; diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index b1da0acd61608f..82fdf8ea577e39 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -92,9 +92,9 @@ class ConvNDOp : public XlaOpKernel { ConvOpAttrs forward_attrs; forward_attrs.depthwise = false; forward_attrs.num_spatial_dims = num_spatial_dims; - forward_attrs.dilations = attrs_.dilations.empty() - ? 
std::vector(num_spatial_dims + 2, 1) - : attrs_.dilations; + forward_attrs.dilations = + attrs_.dilations.empty() ? std::vector(num_spatial_dims + 2, 1) + : attrs_.dilations; forward_attrs.strides = attrs_.strides; forward_attrs.padding = attrs_.padding; forward_attrs.explicit_paddings = attrs_.explicit_paddings; diff --git a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc index 226d6248bd00d8..27818415169dbe 100644 --- a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc @@ -36,9 +36,9 @@ class DataFormatDimMapOp : public XlaOpKernel { public: explicit DataFormatDimMapOp(OpKernelConstruction* context) : XlaOpKernel(context) { - string src_format; + std::string src_format; OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format)); - string dst_format; + std::string dst_format; OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format)); OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5, errors::InvalidArgument( @@ -69,9 +69,9 @@ class DataFormatDimMapOp : public XlaOpKernel { void Compile(XlaOpKernelContext* context) override { auto builder = context->builder(); xla::XlaOp dst_indices = - xla::ConstantR1(builder, absl::Span(dst_idx_)); + xla::ConstantR1(builder, absl::Span(dst_idx_)); const int dims = dst_idx_.size(); - xla::XlaOp rank = xla::ConstantR0(builder, dims); + xla::XlaOp rank = xla::ConstantR0(builder, dims); xla::XlaOp src_indices = (xla::ConvertElementType(context->Input(0), xla::S32) + rank) % rank; xla::XlaOp output = @@ -81,7 +81,7 @@ class DataFormatDimMapOp : public XlaOpKernel { } private: - std::vector dst_idx_; + std::vector dst_idx_; DataFormatDimMapOp(const DataFormatDimMapOp&) = delete; void operator=(const DataFormatDimMapOp&) = delete; @@ -146,13 +146,13 @@ class DataFormatVecPermuteOp : public XlaOpKernel { input_tensor_shape.DebugString())); } - string src_format_str = 
src_format_; - string dst_format_str = dst_format_; + std::string src_format_str = src_format_; + std::string dst_format_str = dst_format_; if (input_tensor_shape.dim_size(0) == spatial_dim_count) { // If the input is a vector of size spatial_dim_count, treat the elements // as spatial dimensions. auto keep_only_spatial_dimensions = - [spatial_dim_count](string* format_str) -> void { + [spatial_dim_count](std::string* format_str) -> void { auto new_end = std::remove_if(format_str->begin(), format_str->end(), [spatial_dim_count](const char dim) { @@ -164,7 +164,7 @@ class DataFormatVecPermuteOp : public XlaOpKernel { keep_only_spatial_dimensions(&src_format_str); keep_only_spatial_dimensions(&dst_format_str); } - std::vector dst_indices(dim0); + std::vector dst_indices(dim0); for (int i = 0; i < dim0; ++i) { for (int j = 0; j < dim0; ++j) { if (src_format_str[i] == dst_format_str[j]) { @@ -174,14 +174,14 @@ class DataFormatVecPermuteOp : public XlaOpKernel { } } xla::XlaOp indices = - xla::ConstantR1(builder, absl::Span(dst_indices)); + xla::ConstantR1(builder, absl::Span(dst_indices)); xla::XlaOp output = xla::TorchIndexSelect(ctx->Input(0), indices, 0); ctx->SetOutput(0, output); } private: - string src_format_; - string dst_format_; + std::string src_format_; + std::string dst_format_; DataFormatVecPermuteOp(const DataFormatVecPermuteOp&) = delete; void operator=(const DataFormatVecPermuteOp&) = delete; diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index e8e2babffd529c..7e93ed9c32e126 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -31,7 +31,7 @@ namespace { class DepthToSpaceOp : public XlaOpKernel { public: explicit DepthToSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - string data_format_str; + std::string data_format_str; OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); 
OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format")); diff --git a/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc index d383c7d0ab4aa3..bc03e14556f9cb 100644 --- a/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc @@ -42,7 +42,7 @@ float get_fullrange() { class DequantizeOp : public XlaOpKernel { public: explicit DequantizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - string mode_string; + std::string mode_string; int axis; bool narrow_range; diff --git a/tensorflow/compiler/tf2xla/kernels/device_index_op.cc b/tensorflow/compiler/tf2xla/kernels/device_index_op.cc index 141415bcd0d8c0..a5665baa6e3dc5 100644 --- a/tensorflow/compiler/tf2xla/kernels/device_index_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/device_index_op.cc @@ -39,11 +39,11 @@ class DeviceIndexOp : public XlaOpKernel { // When compiling we are not executing on any physical device, so we return // a sentinel value (size of the list of devices). 
ctx->SetOutput( - 0, xla::ConstantR0(ctx->builder(), device_names_.size())); + 0, xla::ConstantR0(ctx->builder(), device_names_.size())); } private: - std::vector device_names_; + std::vector device_names_; }; REGISTER_XLA_OP(Name("DeviceIndex"), DeviceIndexOp); diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc index ceeea010ee7858..ae7488ad1e1cbd 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc @@ -54,8 +54,8 @@ class DynamicPartitionOp : public XlaOpKernel { xla::XlaOp CountS32(XlaOpKernelContext* ctx, xla::XlaOp input, int64_t target) { xla::XlaOp equal_dim = - xla::Compare(input, xla::ConstantR0(ctx->builder(), target), {}, - xla::ComparisonDirection::kEq); + xla::Compare(input, xla::ConstantR0(ctx->builder(), target), + {}, xla::ComparisonDirection::kEq); xla::XlaOp casted = xla::ConvertElementType(equal_dim, xla::S32); return xla::ReduceAll( casted, xla::Zero(ctx->builder(), xla::S32), @@ -178,8 +178,9 @@ class DynamicPartitionOp : public XlaOpKernel { } else { xla::XlaOp length; if (count_diff != 0) { - length = xla::Div(partition_length[i], - xla::ConstantR0(ctx->builder(), count_diff)); + length = + xla::Div(partition_length[i], + xla::ConstantR0(ctx->builder(), count_diff)); } else { length = CountS32(ctx, ctx->Input(1), /*target=*/i); } diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index cb7e4f6f96437e..edf9afb5ae14fb 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -145,8 +145,8 @@ class DynamicStitchOp : public XlaOpKernel { // Construct the reverse mapping, for each index, of which slice of which // input it comes from. 
- std::vector src_input_vector(number_of_indices); - std::vector src_slice_vector(number_of_indices); + std::vector src_input_vector(number_of_indices); + std::vector src_slice_vector(number_of_indices); std::vector src_index_used(number_of_indices); int index_used_count = 0; for (int input_num = 0; input_num < indices.size(); input_num++) { diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index 4a1de78d9371b3..b9ca65cfbd6371 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -179,9 +179,9 @@ class ExtractImagePatchesOp : public XlaOpKernel { } protected: - std::vector ksizes_; - std::vector dilations_; - std::vector strides_; + std::vector ksizes_; + std::vector dilations_; + std::vector strides_; Padding padding_; private: diff --git a/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc index b2b1eb3343e698..8075982c766a97 100644 --- a/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc @@ -154,7 +154,7 @@ class FusedConv2DInt8Op : public XlaOpKernel { // Un-vectorize NCHW_VECT_C to NCHW. TensorFormat orig_data_format = conv_attrs_.data_format; - int64 vect_width = -1; + int64_t vect_width = -1; switch (conv_attrs_.data_format) { case FORMAT_NCHW_VECT_C: vect_width = conv_input_shape.dimensions(4); diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 2783951e1b6b0f..e94f74d1fed8ef 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -275,7 +275,7 @@ class GatherOp : public XlaOpKernel { // The number of batch dimensions, as passed in the batch_dims attribute. // It must be less than or equal to rank(indices). 
- int32 batch_dims_ = 0; + int32_t batch_dims_ = 0; }; REGISTER_XLA_OP(Name("Gather"), MlirXlaOpKernel); diff --git a/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc b/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc index 033144e9f308e4..2aec21a6db5888 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc @@ -28,7 +28,7 @@ namespace { class GatherOp : public XlaOpKernel { public: explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) { - string dnums_attr; + std::string dnums_attr; OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr)); OP_REQUIRES( context, dnums_.ParsePartialFromString(dnums_attr), @@ -60,7 +60,7 @@ class ScatterOp : public XlaOpKernel { explicit ScatterOp(OpKernelConstruction* context) : XlaOpKernel(context) { OP_REQUIRES_OK( context, context->GetAttr("update_computation", &update_computation_)); - string dnums_attr; + std::string dnums_attr; OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr)); OP_REQUIRES( context, dnums_.ParsePartialFromString(dnums_attr), diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 17db09722ba954..56c86d3d597227 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -84,7 +84,8 @@ static absl::StatusOr PopulateTensorArrayGradients( // Add any TensorArray gradients touched by the then/else computation to // the enclosing graph. - for (const string& grad_source : update.tensor_array_gradients_accessed) { + for (const std::string& grad_source : + update.tensor_array_gradients_accessed) { VLOG(5) << "TensorArray " << resource->name() << " accessed gradient " << grad_source; XlaResource* gradient; @@ -318,7 +319,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { if (has_token_input_output_ && i == num_inputs - 1) { // Set token input for this "if" op. 
std::vector token_inputs; - for (const string& node_name : token_input_nodes_) { + for (const std::string& node_name : token_input_nodes_) { auto token_or = compiler->GetNodeToken(node_name); OP_REQUIRES_OK(ctx, token_or.status()); token_inputs.push_back(token_or.value()); diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h index fc6dd2e08bf41f..c11cfcb08e0b09 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.h +++ b/tensorflow/compiler/tf2xla/kernels/if_op.h @@ -61,8 +61,8 @@ class XlaIfOp : public XlaOpKernel { DataTypeVector output_types_; std::vector output_shapes_; bool has_token_input_output_; - std::vector token_input_nodes_; - string original_node_name_; + std::vector token_input_nodes_; + std::string original_node_name_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index a8eb7bbf794268..a2676e095b91b7 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -352,10 +352,11 @@ struct WhileCondFn { xla::XlaBuilder* cond_builder) const { xla::XlaOp row_idx = values[0]; xla::XlaOp row_in_bounds = - xla::Lt(row_idx, xla::ConstantR0(cond_builder, num_boxes)); + xla::Lt(row_idx, xla::ConstantR0(cond_builder, num_boxes)); xla::XlaOp num_outputs_so_far = values[1]; - xla::XlaOp results_not_full = xla::Lt( - num_outputs_so_far, xla::ConstantR0(cond_builder, output_size)); + xla::XlaOp results_not_full = + xla::Lt(num_outputs_so_far, + xla::ConstantR0(cond_builder, output_size)); return xla::And(row_in_bounds, results_not_full); } }; @@ -375,7 +376,7 @@ struct SuppressBodyFn { auto num_outputs_so_far = values[1]; auto iou_mask = values[2]; auto included_iou = values[3]; - auto zero = xla::ConstantR0(builder, 0); + auto zero = xla::ConstantR0(builder, 0); // Determine if current elem is active using a slice. 
// TODO(b/118437727): The only reason we need an explicit vector is because // some old GCCs can't deduce the right type for MakeConstSpan, and @@ -386,7 +387,7 @@ struct SuppressBodyFn { active_elem = xla::Reshape(active_elem, {}); // Increment output count iff current elem is not suppressed. num_outputs_so_far = xla::Select( - active_elem, num_outputs_so_far + xla::ConstantR0(builder, 1), + active_elem, num_outputs_so_far + xla::ConstantR0(builder, 1), num_outputs_so_far); // Slice out the row_idx. auto row_iou = xla::DynamicSlice(iou_mask, {row_idx, zero}, {1, num_boxes}); @@ -412,7 +413,7 @@ struct SuppressBodyFn { } included_iou = xla::Select(cond, xla::And(included_iou, supp_mask), included_iou); - row_idx = row_idx + xla::ConstantR0(builder, 1); + row_idx = row_idx + xla::ConstantR0(builder, 1); return std::vector{row_idx, num_outputs_so_far, iou_mask, included_iou}; } @@ -456,7 +457,7 @@ class NonMaxSuppressionOp : public XlaOpKernel { errors::InvalidArgument( "scores size ", std::to_string(scores_shape.dim_size(0)), " must equal number of boxes ", std::to_string(num_boxes))); - OP_REQUIRES(context, num_boxes <= kint32max, + OP_REQUIRES(context, num_boxes <= std::numeric_limits::max(), errors::InvalidArgument("XLA compilation requires number of " "boxes to be <= kint32max, got ", num_boxes)); @@ -477,7 +478,7 @@ class NonMaxSuppressionOp : public XlaOpKernel { OP_REQUIRES( context, output_size >= 0, errors::InvalidArgument("Need output_size >= 0, got ", output_size)); - OP_REQUIRES(context, output_size <= kint32max, + OP_REQUIRES(context, output_size <= std::numeric_limits::max(), errors::InvalidArgument("Need output_size <= kint32Max, got ", output_size)); const xla::XlaOp score_thresh = context->Input("score_threshold"); @@ -564,8 +565,8 @@ class NonMaxSuppressionOp : public XlaOpKernel { std::vector init_values; init_values.reserve(4); - init_values.push_back(xla::ConstantR0(builder, 0)); // col_idx - init_values.push_back(xla::ConstantR0(builder, 0)); 
// num_outputs + init_values.push_back(xla::ConstantR0(builder, 0)); // col_idx + init_values.push_back(xla::ConstantR0(builder, 0)); // num_outputs init_values.push_back(iou_thresh_mask); init_values.push_back(included_iou); @@ -595,8 +596,8 @@ class NonMaxSuppressionOp : public XlaOpKernel { // can be suppressed by score threshold. xla::XlaOp ones_included = xla::Select( included, - xla::Broadcast(xla::ConstantR0(builder, 1), {num_boxes}), - xla::Broadcast(xla::ConstantR0(builder, 0), {num_boxes})); + xla::Broadcast(xla::ConstantR0(builder, 1), {num_boxes}), + xla::Broadcast(xla::ConstantR0(builder, 0), {num_boxes})); // num_valid is scalar. Value should be bound by output_size. xla::XlaOp num_valid_total = xla::Reduce( @@ -604,8 +605,8 @@ class NonMaxSuppressionOp : public XlaOpKernel { /*init_value=*/xla::ConstantR0(builder, 0), /*computation=*/CreateScalarAddComputation(xla::S32, builder), /*dimensions_to_reduce=*/{0}); - xla::XlaOp num_valid = - xla::Min(num_valid_total, xla::ConstantR0(builder, output_size)); + xla::XlaOp num_valid = xla::Min( + num_valid_total, xla::ConstantR0(builder, output_size)); // Re-index into the original scores input tensor, using a Gather. // Boxes were suppressed in the sorted domain. diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 58811c10744131..9959f8d4e44be6 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -120,8 +120,8 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters( const int64_t out_size_factor = align_corners ? 
out_size[i] - 1 : out_size[i]; - int64_t gcd = MathUtil::GCD(static_cast(in_size_factor), - static_cast(out_size_factor)); + int64_t gcd = MathUtil::GCD(static_cast(in_size_factor), + static_cast(out_size_factor)); dims.stride[i] = in_size_factor / gcd; dims.kernel_size[i] = out_size_factor / gcd; } diff --git a/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc b/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc index f357262a39c35b..5b730cc0a9076d 100644 --- a/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc @@ -96,7 +96,7 @@ class InTopKOp : public XlaOpKernel { xla::CreateScalarAddComputation(xla::S32, xla_builder), {1}); xla::XlaOp result = - xla::And(xla::Lt(num_gt_r1, xla::ConstantR0(xla_builder, k)), + xla::And(xla::Lt(num_gt_r1, xla::ConstantR0(xla_builder, k)), xla::IsFinite(targets_values_r1)); context->SetOutput(0, result); diff --git a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc index 85d70705c83837..390bc09c33057d 100644 --- a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc +++ b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc @@ -464,7 +464,7 @@ class TfCallbackDevice : public DeviceBase { set_tensorflow_accelerator_device_info(&accelerator_device_info_); } - const string& name() const override { return name_; } + const std::string& name() const override { return name_; } PerOpGpuDevice* MakeGpuDevice() override { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc index dfe8a36005b837..aabbd8d3b0514e 100644 --- a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -60,7 +60,7 @@ class ListDiffOp : public XlaOpKernel { absl::Status status; switch (val_type) { case DT_INT32: - status = ListDiffWithIndexType(context, idx_type); + 
status = ListDiffWithIndexType(context, idx_type); break; case DT_INT64: status = ListDiffWithIndexType(context, idx_type); @@ -111,7 +111,7 @@ class ListDiffOp : public XlaOpKernel { DataType idx_type) { switch (idx_type) { case DT_INT32: - return ListDiff(context); + return ListDiff(context); case DT_INT64: return ListDiff(context); default: diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc index 48e8f976cc67bb..8e7c966bdf35fc 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc @@ -57,7 +57,7 @@ static inline bool IsLeftAligned(int diag_index, bool left_align_superdiagonal, void ReadAlignment(OpKernelConstruction* context, bool* left_align_superdiagonal, bool* left_align_subdiagonal) { - string align; + std::string align; OP_REQUIRES_OK(context, context->GetAttr("align", &align)); *left_align_superdiagonal = align == "LEFT_LEFT" || align == "LEFT_RIGHT"; diff --git a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc index 82dbfb3839312c..215de2bc5067e4 100644 --- a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc @@ -78,7 +78,7 @@ class OneHotOp : public XlaOpKernel { } private: - int32 axis_; + int32_t axis_; OneHotOp(const OneHotOp&) = delete; void operator=(const OneHotOp&) = delete; diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 1758451faf469f..15b2b5f9d2ebbb 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -113,7 +113,7 @@ class PadOp : public XlaOpKernel { high_pad_size = xla::Reshape(high_pad_size, {}); high_pad_size = xla::ConvertElementType(high_pad_size, xla::S32); // Low pad has to be static. 
- xla::XlaOp low_pad_size = xla::ConstantR0( + xla::XlaOp low_pad_size = xla::ConstantR0( ctx->builder(), pad_literal.Get({i, 0})); xla::XlaOp input_size = xla::GetDimensionSize(input, i); xla::XlaOp total_size = low_pad_size + input_size + high_pad_size; @@ -122,7 +122,7 @@ class PadOp : public XlaOpKernel { total_size, xla::ValueInferenceMode::kUpperBound); OP_REQUIRES_OK(ctx, size_upper_bound_status_or.status()); auto size_upper_bound = - size_upper_bound_status_or.value().Get({}); + size_upper_bound_status_or.value().Get({}); OP_REQUIRES( ctx, size_upper_bound.has_value(), errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index aa7c78b8b8f97a..77db609d997614 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -88,8 +88,8 @@ class PoolingOp : public XlaOpKernel { num_spatial_dims_(num_spatial_dims), reduction_type_(reduction_type) { if (ctx->num_inputs() == 1) { - std::vector ksize_int; - std::vector stride_int; + std::vector ksize_int; + std::vector stride_int; OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int)); OP_REQUIRES(ctx, ksize_int.size() == num_dims(), errors::InvalidArgument("Sliding window ksize field must " @@ -255,15 +255,15 @@ class MaxPoolOp : public PoolingOp { ctx->builder()->GetShape(pooling); OP_REQUIRES_OK(ctx, result_shape.status()); - int64 num_channels = result_shape->dimensions(1); + int64_t num_channels = result_shape->dimensions(1); OP_REQUIRES( ctx, num_channels % *vect_width == 0, errors::FailedPrecondition("Result of NCHW_VECT_C op must have " "channels multiple of ", *vect_width, ", but was ", num_channels)); - absl::InlinedVector new_dims(result_shape->dimensions().begin(), - result_shape->dimensions().end()); + absl::InlinedVector new_dims( + result_shape->dimensions().begin(), result_shape->dimensions().end()); new_dims[1] /= *vect_width; 
new_dims.insert(new_dims.begin() + 2, *vect_width); pooling = @@ -298,7 +298,7 @@ class AvgPoolOp : public PoolingOp { : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, /*reduction_type=*/ XlaHelpers::SumAccumulationType(ctx->input_type(0))) { - string data_format_str; + std::string data_format_str; OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format")); @@ -466,7 +466,7 @@ class MaxPool2DGradOp : public MaxPoolGradOp { public: explicit MaxPool2DGradOp(OpKernelConstruction* ctx) : MaxPoolGradOp(ctx, /*num_spatial_dims=*/2) { - string data_format; + std::string data_format; OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), errors::InvalidArgument("Invalid data format")); @@ -505,7 +505,7 @@ class AvgPoolGradOp : public XlaOpKernel { errors::Unimplemented( "Pooling is not yet supported on the batch dimension.")); - string data_format; + std::string data_format; OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), errors::InvalidArgument("Invalid data format")); @@ -561,7 +561,7 @@ class AvgPoolGradOp : public XlaOpKernel { protected: const int num_spatial_dims_; std::vector ksize_; - std::vector stride_; + std::vector stride_; Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; }; @@ -677,7 +677,7 @@ class MaxPoolGradGradOp : public XlaOpKernel { auto b = ctx->builder(); - auto sixteen = xla::ConstantR0(b, 16); + auto sixteen = xla::ConstantR0(b, 16); // in (f32) -> round to 7 mantissa bits (bf16)-> 16-high-bit u32. 
// // NOTE: Use a ReducePrecision operation instead of a cast to BF16 and back @@ -702,7 +702,7 @@ class MaxPoolGradGradOp : public XlaOpKernel { const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {}); auto lhs = xla::Parameter(rb.get(), 0, scalar, "lhs"); auto rhs = xla::Parameter(rb.get(), 1, scalar, "rhs"); - auto sixteen = xla::ConstantR0(rb.get(), 16); + auto sixteen = xla::ConstantR0(rb.get(), 16); auto lhs_criteria = xla::ShiftLeft(xla::ShiftRightLogical( xla::BitcastConvertType(lhs, xla::S32), sixteen), @@ -749,7 +749,7 @@ class MaxPool2DGradGradOp : public MaxPoolGradGradOp { public: explicit MaxPool2DGradGradOp(OpKernelConstruction* ctx) : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/2) { - string data_format; + std::string data_format; OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), errors::InvalidArgument("Invalid data format")); @@ -767,7 +767,7 @@ class MaxPool3DGradGradOp : public MaxPoolGradGradOp { public: explicit MaxPool3DGradGradOp(OpKernelConstruction* ctx) : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/3) { - string data_format; + std::string data_format; OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), errors::InvalidArgument("Invalid data format")); diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index cac9f8a68f234e..961fce9caa7728 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -113,7 +113,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { errors::Internal("Expected 4 inputs to QuantizeAndDequantize")); num_bits = ctx->Input(3); } else { - num_bits = xla::ConstantR0(b, num_bits_); + num_bits = xla::ConstantR0(b, num_bits_); } const xla::XlaOp zero = XlaHelpers::Zero(b, 
data_type); @@ -129,17 +129,17 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { xla::XlaOp min_quantized, max_quantized; if (signed_input_) { if (narrow_range_) { - min_quantized = - -Pow(two, ConvertElementType( - num_bits - xla::ConstantR0(b, 1), xla_type)) + - one; + min_quantized = -Pow(two, ConvertElementType( + num_bits - xla::ConstantR0(b, 1), + xla_type)) + + one; } else { min_quantized = -Pow(two, ConvertElementType( - num_bits - xla::ConstantR0(b, 1), xla_type)); + num_bits - xla::ConstantR0(b, 1), xla_type)); } max_quantized = - Pow(two, ConvertElementType(num_bits - xla::ConstantR0(b, 1), + Pow(two, ConvertElementType(num_bits - xla::ConstantR0(b, 1), xla_type)) - one; } else { @@ -222,7 +222,7 @@ class QuantizeAndDequantizeV2Op : public QuantizeAndDequantizeOp { OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63), errors::InvalidArgument("num_bits is out of range: ", num_bits_, " with signed_input_ ", signed_input_)); - string round_mode_string; + std::string round_mode_string; OP_REQUIRES_OK(ctx, ctx->GetAttr("round_mode", &round_mode_string)); OP_REQUIRES( ctx, diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops_util.cc b/tensorflow/compiler/tf2xla/kernels/random_ops_util.cc index 8f2350f26861c4..dea3ecf85af7b8 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops_util.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops_util.cc @@ -140,7 +140,7 @@ absl::StatusOr GetAlgId(XlaOpKernelContext* ctx, int alg_input_idx) { if (alg_dtype == DT_INT32) { return alg_literal.Get({}); } else { - return alg_literal.Get({}); + return alg_literal.Get({}); } } @@ -172,7 +172,7 @@ DataType MaybeConvertBF16ToF32(DataType const& dtype) { } absl::StatusOr BuildUniformRandoms( - XlaOpKernelContext* ctx, DataType dtype, string device_type_string, + XlaOpKernelContext* ctx, DataType dtype, std::string device_type_string, TensorShape shape, std::function lo_fn, std::function hi_fn) { @@ -190,7 +190,7 @@ absl::StatusOr 
BuildUniformRandoms( absl::StatusOr BuildUniformRandoms(XlaOpKernelContext* ctx, DataType dtype, - string device_type_string, + std::string device_type_string, xla::Shape xla_shape, xla::XlaOp lo, xla::XlaOp hi) { xla::XlaOp key = ctx->Input(kRandomKeyInputIdx); diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops_util.h b/tensorflow/compiler/tf2xla/kernels/random_ops_util.h index 11ff44602f1900..5fb7aa4822834c 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops_util.h +++ b/tensorflow/compiler/tf2xla/kernels/random_ops_util.h @@ -73,7 +73,7 @@ DataType MaybeConvertBF16ToF32(DataType const& dtype); // type, in the given low and high range, where low and high are expressed in // XLA functions. absl::StatusOr BuildUniformRandoms( - XlaOpKernelContext* ctx, DataType dtype, string device_type_string, + XlaOpKernelContext* ctx, DataType dtype, std::string device_type_string, TensorShape shape, std::function lo, std::function hi); @@ -82,7 +82,7 @@ absl::StatusOr BuildUniformRandoms( // ops. absl::StatusOr BuildUniformRandoms(XlaOpKernelContext* ctx, DataType dtype, - string device_type_string, + std::string device_type_string, xla::Shape xla_shape, xla::XlaOp lo, xla::XlaOp hi); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 6a8a98342c1123..3bfe9e384405b2 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -119,7 +119,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { } } - string desc = ctx->op_kernel().name(); + std::string desc = ctx->op_kernel().name(); xla::XlaBuilder* const b = ctx->builder(); // Construct the builder for the reduction lambda. 
diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc index c54c4613d29e44..a1dd0164e73fc7 100644 --- a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc @@ -311,7 +311,7 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, xla::Pad(grad_data, xla::Zero(ctx->builder(), warp_type), xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}})); - auto shifting_value = xla::ConstantR1( + auto shifting_value = xla::ConstantR1( ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1}); auto shifted_gather_indices = xla::Add(gather_indices, shifting_value, {last_warp_dim}); @@ -384,7 +384,7 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, xla::Pad(data, xla::Zero(ctx->builder(), data_type), xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}})); - auto shifting_value = xla::ConstantR1( + auto shifting_value = xla::ConstantR1( ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1}); auto shifted_gather_indices = xla::Add(gather_indices, shifting_value, {last_warp_dim}); diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 5cecbf37706283..5c77a4dfe29934 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -134,8 +134,8 @@ class ReverseSequenceOp : public XlaOpKernel { } private: - int32 batch_dim_; - int32 seq_dim_; + int32_t batch_dim_; + int32_t seq_dim_; }; REGISTER_XLA_OP(Name("ReverseSequence"), ReverseSequenceOp); diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index e1e93d614286a3..32b75c26c70212 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -35,7 
+35,7 @@ class SendOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override; private: - string tensor_name_; + std::string tensor_name_; SendOp(const SendOp&) = delete; void operator=(const SendOp&) = delete; @@ -60,7 +60,7 @@ class RecvOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override; private: - string tensor_name_; + std::string tensor_name_; xla::Shape shape_; RecvOp(const RecvOp&) = delete; diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 108bf3848aae93..d24d1688d188a6 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -104,7 +104,8 @@ class RangeOp : public XlaOpKernel { absl::StatusOr output; switch (type) { case DT_INT32: - output = CreateRangeTensor(start, limit, delta, ctx->builder()); + output = + CreateRangeTensor(start, limit, delta, ctx->builder()); break; case DT_INT64: output = diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 7e8889cb2ccee6..07bf81e9d76b58 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -109,7 +109,7 @@ class XlaSetBoundOp : public XlaOpKernel { bound_shape.DebugString())); int64_t bound; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("bound", &bound)); - xla::Literal bound_literal = xla::LiteralUtil::CreateR0(bound); + xla::Literal bound_literal = xla::LiteralUtil::CreateR0(bound); xla::XlaOp result = xla::CustomCall( ctx->builder(), "SetBound", {ctx->Input("input")}, ctx->InputXlaShape("input").value(), "", false, {}, &bound_literal); diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.cc b/tensorflow/compiler/tf2xla/kernels/shape_util.cc index 57825657b205ab..beb38ce9a273ea 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_util.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.cc @@ -33,15 +33,15 
@@ absl::Status TensorShapeToConstant(const TensorShape& input_shape, Tensor* shape_constant) { const int dims = input_shape.dims(); if (shape_constant->dtype() == DT_INT32) { - auto vec = shape_constant->vec(); + auto vec = shape_constant->vec(); for (int i = 0; i < dims; ++i) { int64_t dim_size = input_shape.dim_size(i); - if (!FastBoundsCheck(dim_size, std::numeric_limits::max())) { + if (!FastBoundsCheck(dim_size, std::numeric_limits::max())) { return errors::InvalidArgument( "Shape with out_type=int32 does not support tensors > int32max", " but dim ", i, " is ", dim_size); } - vec(i) = static_cast(dim_size); + vec(i) = static_cast(dim_size); } } else { auto vec = shape_constant->vec(); diff --git a/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc b/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc index 74e04e035ef3be..0ee9173cda69e3 100644 --- a/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc @@ -101,8 +101,8 @@ absl::Status GetAndValidateAttributes(OpKernelConstruction* ctx, return absl::OkStatus(); } -std::vector GetSliceIndices(absl::Span num_partitions, - absl::Span slice_shape, +std::vector GetSliceIndices(absl::Span num_partitions, + absl::Span slice_shape, const int index) { DCHECK_EQ(num_partitions.size(), slice_shape.size()); @@ -213,7 +213,7 @@ class XlaSplitNDBaseOp : public XlaOpKernel { // Calculate paddings necessary for slice instead of padding input and // slicing subsequently to reduce temporary memory allocation. for (int dim = 0; dim < rank; ++dim) { - const int64 dim_size = input_shape.dim_size(dim); + const int64_t dim_size = input_shape.dim_size(dim); if (slice_start_indices[dim] >= dim_size) { // Complete padding. 
slice_start_indices[dim] = dim_size; @@ -387,9 +387,9 @@ class XlaConcatNDBaseOp : public XlaOpKernel { std::vector update_slice_start_indices; update_slice_start_indices.reserve(rank); - for (int64 start_index : slice_start_indices) { + for (int64_t start_index : slice_start_indices) { update_slice_start_indices.push_back( - xla::ConstantR0(ctx->builder(), start_index)); + xla::ConstantR0(ctx->builder(), start_index)); } output = xla::DynamicUpdateSlice(output, input_slice, update_slice_start_indices); diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 844a31f97990fc..b0e337cec20c33 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -180,8 +180,8 @@ class SliceOp : public XlaOpKernel { xla::Reshape(xla::Slice(ctx->Input(2), {i}, {i + 1}, {1}), {}); if (constant_size_is_minus_one && size[i] == -1) { // size = input_.dim_size(i) - begin[i] - dynamic_size = xla::ConstantR0(ctx->builder(), - input_shape.dim_size(i)) - + dynamic_size = xla::ConstantR0(ctx->builder(), + input_shape.dim_size(i)) - begin_indices[i]; } auto constant_size = ctx->value_inference().AnalyzeConstant( @@ -192,7 +192,7 @@ class SliceOp : public XlaOpKernel { // triggered when some dimensions's slice sizes are constant while // some are dynamic. sliced = xla::SliceInDim( - sliced, 0, constant_size->Get({}).value(), 1, i); + sliced, 0, constant_size->Get({}).value(), 1, i); } else { // We gave a generous bound (same as input) to the output, try reset // the bound if a tighter one can be found. 
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index ac33e0877200dc..180ba322f0fdd0 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -34,7 +34,7 @@ namespace { class SpaceToDepthOp : public XlaOpKernel { public: explicit SpaceToDepthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - string data_format_str; + std::string data_format_str; OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format")); diff --git a/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc b/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc index 124e36557f1429..f6d468131ac94e 100644 --- a/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc @@ -69,8 +69,8 @@ class XlaSpmdFullToShardShapeOp : public XlaOpKernel { } private: - string manual_sharding_str_; - int32 single_dim_; + std::string manual_sharding_str_; + int32_t single_dim_; std::vector unspecified_dims_; XlaSpmdFullToShardShapeOp(const XlaSpmdFullToShardShapeOp&) = delete; void operator=(const XlaSpmdFullToShardShapeOp&) = delete; @@ -120,8 +120,8 @@ class XlaSpmdShardToFullShapeOp : public XlaOpKernel { private: TensorShape full_shape_; - string manual_sharding_str_; - int32 single_dim_; + std::string manual_sharding_str_; + int32_t single_dim_; std::vector unspecified_dims_; XlaSpmdShardToFullShapeOp(const XlaSpmdShardToFullShapeOp&) = delete; void operator=(const XlaSpmdShardToFullShapeOp&) = delete; diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 3c99ad63565266..4672477be3534b 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ 
b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -120,7 +120,7 @@ class StackOp : public XlaOpKernel { private: DataType dtype_; - string stack_name_; + std::string stack_name_; StackOp(const StackOp&) = delete; void operator=(const StackOp&) = delete; @@ -152,7 +152,7 @@ class StackPushOp : public XlaOpKernel { // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. std::vector start_indices(elem_shape.dims() + 1, - xla::ConstantR0(b, 0)); + xla::ConstantR0(b, 0)); start_indices[0] = index; TensorShape slice_shape = elem_shape; @@ -164,7 +164,7 @@ class StackPushOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, resource->SetValue(xla::Tuple( b, {xla::DynamicUpdateSlice(ta, update, start_indices), - xla::Add(index, xla::ConstantR0(b, 1))}))); + xla::Add(index, xla::ConstantR0(b, 1))}))); ctx->SetOutput(0, value); } @@ -204,12 +204,12 @@ class StackPopOp : public XlaOpKernel { xla::XlaOp ta = xla::GetTupleElement(state, 0); xla::XlaOp index = xla::GetTupleElement(state, 1); - index = Sub(index, xla::ConstantR0(b, 1)); + index = Sub(index, xla::ConstantR0(b, 1)); OP_REQUIRES_OK(ctx, resource->SetValue(xla::Tuple(b, {ta, index}))); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. 
std::vector start_indices(stack_shape.dims(), - xla::ConstantR0(b, 0)); + xla::ConstantR0(b, 0)); start_indices[0] = index; auto slice_shape = stack_shape.dim_sizes(); diff --git a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc index e7ff8194b96ce8..80047c5f17cc98 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc @@ -511,7 +511,7 @@ class RngSkipOp : public XlaOpKernel { REGISTER_XLA_OP(Name("RngSkip").CompileTimeConstantInput("algorithm"), RngSkipOp<>); -using RngReadAndSkipOp = RngSkipOp; +using RngReadAndSkipOp = RngSkipOp; REGISTER_XLA_OP(Name("RngReadAndSkip").CompileTimeConstantInput("alg"), RngReadAndSkipOp); diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index aa71c5c34d2e1a..246981c3465ef1 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -76,7 +76,7 @@ xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) { // `BitcastConvertType(ConvertElementType(u32, U16), BF16)`, to avoid the // unclear `ConvertElementType(f32, BF16)` behavior. 
xla::XlaOp output = xla::BitcastConvertType(input, xla::U32) & - xla::ConstantR0(builder, 0xFFFF0000); + xla::ConstantR0(builder, 0xFFFF0000); return xla::ConvertElementType(xla::BitcastConvertType(output, xla::F32), xla::BF16); } else { @@ -184,7 +184,7 @@ class StatelessRandomUniformOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessRandomUniformOp(const StatelessRandomUniformOp&) = delete; void operator=(const StatelessRandomUniformOp&) = delete; @@ -240,7 +240,7 @@ class StatelessRandomUniformIntOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessRandomUniformIntOp(const StatelessRandomUniformIntOp&) = delete; void operator=(const StatelessRandomUniformIntOp&) = delete; @@ -283,7 +283,7 @@ class StatelessRandomUniformFullIntOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessRandomUniformFullIntOp(const StatelessRandomUniformFullIntOp&) = delete; @@ -336,7 +336,7 @@ class StatelessRandomNormalOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessRandomNormalOp(const StatelessRandomNormalOp&) = delete; void operator=(const StatelessRandomNormalOp&) = delete; @@ -384,7 +384,7 @@ class StatelessTruncatedNormalOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessTruncatedNormalOp(const StatelessTruncatedNormalOp&) = delete; void operator=(const StatelessTruncatedNormalOp&) = delete; @@ -449,7 +449,7 @@ class StatelessParameterizedTruncatedNormalOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessParameterizedTruncatedNormalOp( const StatelessParameterizedTruncatedNormalOp&) = delete; diff --git 
a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc index ce1fee91ae6a51..689e6ca3f7bf41 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc @@ -128,7 +128,7 @@ class StatelessRandomUniformOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessRandomUniformOp(const StatelessRandomUniformOp&) = delete; void operator=(const StatelessRandomUniformOp&) = delete; @@ -177,7 +177,7 @@ class StatelessRandomUniformIntOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessRandomUniformIntOp(const StatelessRandomUniformIntOp&) = delete; void operator=(const StatelessRandomUniformIntOp&) = delete; @@ -225,7 +225,7 @@ class StatelessRandomUniformFullIntOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessRandomUniformFullIntOp(const StatelessRandomUniformFullIntOp&) = delete; @@ -295,7 +295,7 @@ class StatelessRandomNormalOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessRandomNormalOp(const StatelessRandomNormalOp&) = delete; void operator=(const StatelessRandomNormalOp&) = delete; @@ -330,7 +330,7 @@ class StatelessTruncatedNormalOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessTruncatedNormalOp(const StatelessTruncatedNormalOp&) = delete; void operator=(const StatelessTruncatedNormalOp&) = delete; @@ -369,7 +369,7 @@ class GetKeyCounterOp : public XlaOpKernel { } private: - string device_type_string_; + std::string device_type_string_; GetKeyCounterOp(const GetKeyCounterOp&) = delete; void operator=(const GetKeyCounterOp&) = delete; @@ -392,7 +392,7 
@@ class GetAlgOp : public XlaOpKernel { } private: - string device_type_string_; + std::string device_type_string_; GetAlgOp(const GetAlgOp&) = delete; void operator=(const GetAlgOp&) = delete; @@ -430,7 +430,7 @@ class GetKeyCounterAlgOp : public XlaOpKernel { } private: - string device_type_string_; + std::string device_type_string_; GetKeyCounterAlgOp(const GetKeyCounterAlgOp&) = delete; void operator=(const GetKeyCounterAlgOp&) = delete; diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index e15196bd756462..1b44d1e07c4bd8 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -142,7 +142,7 @@ class StridedSliceOp : public XlaOpKernel { // Pad input to 2x to avoid OOB access. slice = xla::Pad(slice, xla::Zero(ctx->builder(), ctx->input_xla_type(0)), padding_config); - for (int64 i = 0; i < result_dims_are_dynamic.size(); ++i) { + for (int64_t i = 0; i < result_dims_are_dynamic.size(); ++i) { if (result_dims_are_dynamic[i]) { slice = xla::RemoveDynamicDimension(slice, i); } @@ -178,7 +178,7 @@ class StridedSliceOp : public XlaOpKernel { // Can't infer a lower bound. return false; } - return lower_bound->Get({}) >= 0; + return lower_bound->Get({}) >= 0; }; if (begin_mask) { begin_index = zero; @@ -220,7 +220,7 @@ class StridedSliceOp : public XlaOpKernel { // size 1 dims of a shape. slice = xla::Reshape(slice, final_shape.dim_sizes()); for (int64_t i = 0; i < final_shape.dims(); ++i) { - int64 processing_shape_dim = shape_spec.output_to_processing_mapping[i]; + int64_t processing_shape_dim = shape_spec.output_to_processing_mapping[i]; // If processing_shape_dim is -1, it means the output dimension was newly // added by new_axis_mask_, which doesn't show up in input. 
if (processing_shape_dim != -1) { @@ -341,9 +341,9 @@ class StridedSliceOp : public XlaOpKernel { int64_t sparse_index = shape_spec.output_to_sparse_mapping[i]; bool end_is_dynamic = sparse_index == -1 ? false : ends_are_dynamic[sparse_index]; - bool backward_slice = sparse_index == -1 - ? false - : end_literal.Get({sparse_index}) < 0; + bool backward_slice = + sparse_index == -1 ? false + : end_literal.Get({sparse_index}) < 0; if (input_is_dynamic || end_is_dynamic) { OP_REQUIRES( ctx, strides[input_index] == 1, @@ -363,8 +363,8 @@ class StridedSliceOp : public XlaOpKernel { "sized slice with dynamic negative index %lld. ")); operand_size = xla::Add( operand_size, - xla::ConstantR0(ctx->builder(), - end_literal.Get({sparse_index}))); + xla::ConstantR0( + ctx->builder(), end_literal.Get({sparse_index}))); } else { // The end of slice with dynamic slice size is the min of operand // shape and slice size. E.g., t[:end_size], result size is @@ -376,13 +376,13 @@ class StridedSliceOp : public XlaOpKernel { {}); } else { end_size = - xla::ConstantR0(ctx->builder(), end[input_index]); + xla::ConstantR0(ctx->builder(), end[input_index]); } operand_size = xla::Min(operand_size, end_size); } slice = xla::SetDimensionSize( slice, - xla::Sub(operand_size, xla::ConstantR0( + xla::Sub(operand_size, xla::ConstantR0( ctx->builder(), begin[input_index])), i); } @@ -397,8 +397,8 @@ class StridedSliceOp : public XlaOpKernel { } private: - int32 begin_mask_, end_mask_; - int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; + int32_t begin_mask_, end_mask_; + int32_t ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; DataType index_type_; }; @@ -634,8 +634,8 @@ class StridedSliceGradOp : public XlaOpKernel { } private: - int32 begin_mask_, end_mask_; - int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; + int32_t begin_mask_, end_mask_; + int32_t ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; DataType index_type_; }; @@ -751,8 +751,8 @@ class StridedSliceAssignOp : public 
XlaOpKernel { } private: - int32 begin_mask_, end_mask_; - int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; + int32_t begin_mask_, end_mask_; + int32_t ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; DataType index_type_; DataType dtype_; }; diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 888908e30b2331..e89c3e3b4f837b 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -94,7 +94,7 @@ absl::Status MaybeInitializeTensorArray(xla::XlaBuilder* builder, // Checks that the TensorArray 'resource' has been initialized, and has type // 'dtype'. Sets 'shape' to the shape -absl::Status CheckTensorArrayIsInitialized(const string& op_name, +absl::Status CheckTensorArrayIsInitialized(const std::string& op_name, const XlaResource* resource, DataType dtype) { if (resource->kind() != XlaResource::kTensorArray) { @@ -184,7 +184,7 @@ class TensorArrayOp : public XlaOpKernel { private: PartialTensorShape element_shape_; DataType dtype_; - string tensor_array_name_; + std::string tensor_array_name_; TensorArrayOp(const TensorArrayOp&) = delete; void operator=(const TensorArrayOp&) = delete; @@ -218,7 +218,7 @@ class TensorArrayWriteOp : public XlaOpKernel { // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. std::vector start_indices(elem_shape.dims() + 1, - xla::ConstantR0(b, 0)); + xla::ConstantR0(b, 0)); start_indices[0] = index; TensorShape slice_shape = elem_shape; @@ -270,7 +270,7 @@ class TensorArrayReadOp : public XlaOpKernel { // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. std::vector start_indices(ta_shape.dims(), - xla::ConstantR0(b, 0)); + xla::ConstantR0(b, 0)); start_indices[0] = index; auto slice_shape = ta_shape.dim_sizes(); @@ -430,7 +430,7 @@ class TensorArrayScatterOp : public XlaOpKernel { // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. 
auto index = xla::Reshape(xla::Slice(indices, {i}, {i + 1}, {1}), {}); std::vector start_indices(elem_shape.dims() + 1, - xla::ConstantR0(b, 0)); + xla::ConstantR0(b, 0)); start_indices[0] = index; ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices, dtype_); } @@ -570,7 +570,8 @@ class TensorArraySizeOp : public XlaOpKernel { XlaResource* var; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var)); Tensor size_tensor(DT_INT32, {}); - size_tensor.scalar()() = static_cast(var->max_array_size()); + size_tensor.scalar()() = + static_cast(var->max_array_size()); ctx->SetConstantOutput(0, size_tensor); } @@ -609,7 +610,7 @@ class TensorArrayGradOp : public XlaOpKernel { } private: - string source_; + std::string source_; TensorArrayGradOp(const TensorArrayGradOp&) = delete; void operator=(const TensorArrayGradOp&) = delete; diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index a1f58d5ae9b40e..f128c96c570e6c 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -70,7 +70,7 @@ absl::StatusOr>> GetTensorListDynamicDims( dynamic_dims.push_back(ctx->Input(1)); } else { dynamic_dims.push_back( - xla::ConstantR0(ctx->builder(), num_elements)); + xla::ConstantR0(ctx->builder(), num_elements)); } for (int64_t dim = 0; dim < element_shape.dimensions().size(); ++dim) { if (dims_are_dynamic[dim]) { @@ -80,7 +80,7 @@ absl::StatusOr>> GetTensorListDynamicDims( dynamic_dims.push_back(dynamic_dim_size); } else { dynamic_dims.push_back( - xla::ConstantR0(ctx->builder(), dynamic_sizes[dim])); + xla::ConstantR0(ctx->builder(), dynamic_sizes[dim])); } } list_dynamic_dims.push_back(std::move(dynamic_dims)); @@ -191,7 +191,7 @@ class TensorListReserveOp : public XlaOpKernel { OP_REQUIRES_OK( ctx, SetTensorListPushIndex( - new_list, xla::ConstantR0(ctx->builder(), num_elements), + new_list, xla::ConstantR0(ctx->builder(), num_elements), 
&result)); ctx->SetTensorListOutput(0, result); return; @@ -324,13 +324,13 @@ class TensorListElementShapeOp : public XlaOpKernel { ctx->SetOutput(0, xla::ConstantR1(b, list_shape.dimensions())); break; case DT_INT32: { - std::vector size; + std::vector size; const auto& dimensions = list_shape.dimensions(); size.reserve(dimensions.size()); for (int64_t s : dimensions) { size.push_back(s); } - ctx->SetOutput(0, xla::ConstantR1(b, size)); + ctx->SetOutput(0, xla::ConstantR1(b, size)); break; } default: diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc index 683dc4737e6dab..0a7297456fce8d 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc @@ -393,7 +393,7 @@ absl::Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, std::vector start_indices( element_part_shape.dimensions().size() + 1, - xla::ConstantR0(b, 0)); + xla::ConstantR0(b, 0)); start_indices[0] = push_index; xla::XlaOp list_part = xla::GetTupleElement(list, i); @@ -409,7 +409,7 @@ absl::Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, xla::XlaOp update = xla::Reshape(element, element_dims); std::vector start_indices(element_shape.dimensions().size() + 1, - xla::ConstantR0(b, 0)); + xla::ConstantR0(b, 0)); start_indices[0] = push_index; xla::XlaOp list_part = xla::GetTupleElement(list, 0); @@ -418,7 +418,7 @@ absl::Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, result_parts.push_back(updated_list_part); } - xla::XlaOp updated_push_index = push_index + xla::ConstantR0(b, 1); + xla::XlaOp updated_push_index = push_index + xla::ConstantR0(b, 1); result_parts.push_back(updated_push_index); *result = xla::Tuple(b, result_parts); @@ -441,14 +441,14 @@ absl::Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result, TF_ASSIGN_OR_RETURN(xla::Shape list_shape, b->GetShape(list)); int 
list_tuple_size = xla::ShapeUtil::TupleElementCount(list_shape); xla::XlaOp push_index = xla::GetTupleElement(list, list_tuple_size - 1); - push_index = push_index - xla::ConstantR0(b, 1); + push_index = push_index - xla::ConstantR0(b, 1); std::vector list_result_parts, element_result_parts; for (int i = 0; i < list_tuple_size - 1; i++) { const xla::Shape& list_part_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, i); std::vector start_indices(list_part_shape.dimensions().size(), - xla::ConstantR0(b, 0)); + xla::ConstantR0(b, 0)); start_indices[0] = push_index; std::vector slice_shape = @@ -496,7 +496,7 @@ absl::Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index, xla::XlaOp update = xla::Reshape(element, element_dims); std::vector start_indices(element_shape.dimensions().size() + 1, - xla::ConstantR0(b, 0)); + xla::ConstantR0(b, 0)); start_indices[0] = index; xla::XlaOp list_part = xla::GetTupleElement(list, 0); @@ -550,7 +550,7 @@ absl::Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, const xla::Shape& buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0); std::vector start_indices(buffer_shape.dimensions().size(), - xla::ConstantR0(b, 0)); + xla::ConstantR0(b, 0)); start_indices[0] = index; std::vector slice_shape = @@ -585,7 +585,7 @@ absl::Status ExecuteTensorListFromTensor(int push_index, xla::XlaOp tensor, } std::vector result_parts{tensor, - xla::ConstantR0(b, push_index)}; + xla::ConstantR0(b, push_index)}; *result = xla::Tuple(b, result_parts); return absl::OkStatus(); } diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 039320573f4558..9c4e0b63490205 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -137,7 +137,7 @@ class InvertPermutationOp : public XlaOpKernel { absl::Status status; switch (dtype) { case DT_INT32: - InvertPermutation(ctx); + 
InvertPermutation(ctx); break; case DT_INT64: InvertPermutation(ctx); diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 96f44d14e42ef4..90a022f5111e9a 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -56,7 +56,7 @@ REGISTER_XLA_OP(Name("Abs"), MlirXlaOpKernel); REGISTER_XLA_OP(Name("Acos"), MlirXlaOpKernel); REGISTER_XLA_OP(Name("Acosh"), MlirXlaOpKernel); REGISTER_XLA_OP(Name("Asin"), MlirXlaOpKernel); -XLAJIT_MAKE_UNARY(Asinh, xla::Asinh(x)); +REGISTER_XLA_OP(Name("Asinh"), MlirXlaOpKernel); REGISTER_XLA_OP(Name("Atan"), MlirXlaOpKernel); REGISTER_XLA_OP(Name("Atanh"), MlirXlaOpKernel); REGISTER_XLA_OP(Name("Ceil"), MlirXlaOpKernel); diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc index dbd6cda9d950d0..1d487f70d09d21 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc @@ -36,7 +36,7 @@ namespace tensorflow { namespace { using XlaUnaryOpGenerator = std::function; -using XlaOpGeneratorMap = absl::flat_hash_map; +using XlaOpGeneratorMap = absl::flat_hash_map; void PopulateXlaOpGeneratorMap(XlaOpGeneratorMap* op_generator_map) { auto add_xla_op_generator = [&](std::string name, @@ -120,7 +120,7 @@ class UnaryOpsCompositionOp : public XlaOpKernel { } private: - std::vector op_names_; + std::vector op_names_; }; REGISTER_XLA_OP(Name("_UnaryOpsComposition"), UnaryOpsCompositionOp); diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index a7a1a438f95b9e..c9ddab9efb6e22 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -165,7 +165,7 @@ class ResourceGatherOp : public XlaOpKernel { } private: - int32 batch_dims_; + int32_t 
batch_dims_; }; REGISTER_XLA_OP(Name("ResourceGather"), ResourceGatherOp); diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 415f465f0b5088..57821f74e97024 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -449,7 +449,8 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { // Add any TensorArray gradients touched by the body to the enclosing // graph. - for (const string& grad_source : update.tensor_array_gradients_accessed) { + for (const std::string& grad_source : + update.tensor_array_gradients_accessed) { VLOG(4) << "TensorArray " << resource->name() << " accessed gradient " << grad_source; XlaResource* gradient; @@ -553,7 +554,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { // Set token input for this "while" op. std::vector token_inputs; token_inputs.reserve(token_input_nodes_.size()); - for (const string& node_name : token_input_nodes_) { + for (const std::string& node_name : token_input_nodes_) { auto token_or = compiler->GetNodeToken(node_name); OP_REQUIRES_OK(ctx, token_or.status()); token_inputs.push_back(token_or.value()); @@ -590,7 +591,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { } else { int32_t dim_size = shape.dimensions(0); dynamic_dims.push_back( - xla::ConstantR0(ctx->builder(), dim_size)); + xla::ConstantR0(ctx->builder(), dim_size)); } // Set dynamic dimension size to 0 for element value. 
Inside the while diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h index 8e9f317ac4f3fe..b1937c14f0bebc 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.h +++ b/tensorflow/compiler/tf2xla/kernels/while_op.h @@ -61,8 +61,8 @@ class XlaWhileOp : public XlaOpKernel { NameAttrList cond_name_attr_; NameAttrList body_name_attr_; bool has_token_input_output_; - std::vector token_input_nodes_; - string original_node_name_; + std::vector token_input_nodes_; + std::string original_node_name_; // Whether to propagate compile time consts into the loop body. // This is not supported by default now since it may cause HBM memory // overheads. diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc index faa8b30bcf9dc6..1ac01a4c172cfe 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -61,13 +61,13 @@ limitations under the License. 
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "stablehlo/dialect/VhloOps.h" // from @stablehlo #include "stablehlo/transforms/StablehloRefineShapes.h" // from @stablehlo +#include "stablehlo/transforms/optimization/Passes.h" // from @stablehlo #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/translate/stablehlo.h" #include "xla/mlir/utils/type_util.h" -#include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/python/refine_polymorphic_shapes.h" #include "xla/service/hlo.pb.h" #include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h" @@ -121,7 +121,7 @@ bool IsTokenType(mlir::Type type) { } absl::StatusOr> -XlaCallModuleLoader::Create(mlir::MLIRContext *context, int version, +XlaCallModuleLoader::Create(mlir::MLIRContext* context, int version, mlir::StringRef module_str, std::vector disabled_checks, std::vector platforms, @@ -165,7 +165,7 @@ absl::Status XlaCallModuleLoader::SetPlatformIndex( if (platform_index < 0) return absl::OkStatus(); VLOG(3) << "XlaCallModule setting the platform_index to " << platform_index << " for platform " << compilation_platform << "."; - mlir::Block &main_body = main_.front(); + mlir::Block& main_body = main_.front(); if (main_.getNumArguments() < 1) { return absl::InvalidArgumentError(absl::StrCat( @@ -241,19 +241,19 @@ absl::Status XlaCallModuleLoader::RefineDynamicShapes( " non-token and non-platform-index arguments. 
The input ", "shapes are (", absl::StrJoin(input_shapes, ", ", - [](std::string *out, const xla::Shape &s) { + [](std::string* out, const xla::Shape& s) { absl::StrAppend(out, s.ToString()); }), ") and the main function argument types are ", absl::StrJoin(InputTypes(), ", ", - [](std::string *out, const mlir::Type &t) { + [](std::string* out, const mlir::Type& t) { absl::StrAppend(out, mlir::debugString(t)); }), ")")); } // Derive static input types to use for main. - mlir::Block &main_body = main_.front(); + mlir::Block& main_body = main_.front(); mlir::Builder builder(module_->getContext()); std::vector static_array_input_types(nr_inputs); int next_actual_input = 0; @@ -272,7 +272,7 @@ absl::Status XlaCallModuleLoader::RefineDynamicShapes( } // Get static MLIR Type from xla Shape. - const xla::Shape &xla_shape = input_shapes[next_actual_input++]; + const xla::Shape& xla_shape = input_shapes[next_actual_input++]; std::vector xla_dimensions; if (xla_shape.IsArray()) { xla_dimensions = std::vector(xla_shape.dimensions().begin(), @@ -370,7 +370,7 @@ absl::Status XlaCallModuleLoader::RefineDynamicShapes( } absl::Status XlaCallModuleLoader::LoadModule( - mlir::MLIRContext *context, int version, mlir::StringRef module_str, + mlir::MLIRContext* context, int version, mlir::StringRef module_str, std::vector disabled_checks, std::vector platforms, int num_invocation_args, bool main_has_token_input_output, bool use_shardy_partitioner) { @@ -457,7 +457,7 @@ absl::Status XlaCallModuleLoader::LoadModule( return absl::InvalidArgumentError("Cannot find 'main' in module"); } - mlir::Block &main_body = main_.front(); + mlir::Block& main_body = main_.front(); int nr_token_arguments = llvm::count_if(InputTypes(), IsTokenType); if (version < kVersionStartSupportEffects) { @@ -489,7 +489,7 @@ absl::Status XlaCallModuleLoader::ValidateXlaCallModuleInvariants() { mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); bool moduleValidationFailed = false; - 
module_->walk([&](mlir::Operation *op) { + module_->walk([&](mlir::Operation* op) { // StableHLO programs created by jax2tf only contain operations // from Builtin, Func, StableHLO, Shardy dialects. if (!llvm::isagetContext()); - // TODO (b/410057228): Replace MHLO canonicalization with StableHLO. - // This code requires MHLO CaseOp canonicalization to remove unreachable - // branches, else `tf.call_tf_function` inlining can fail. mlir::PassManager pm(module_->getContext()); - pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); - pm.addNestedPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + pm.addNestedPass( + mlir::stablehlo::createStablehloTargetIndependentOptimizationPass()); if (use_shardy_partitioner_) { // We need to export shardings because the lowering path go directly to // HLO but not the MLIR to HLO path that invokes SdyRoundTripExport. @@ -543,7 +539,7 @@ absl::Status XlaCallModuleLoader::PrepareStablehloForLowering() { if (failed(pm.run(*module_))) { return absl::InternalError( - absl::StrCat("MHLO->HLO lowering passes failed: ", + absl::StrCat("StableHLO->HLO lowering passes failed: ", diag_handler.ConsumeStatus().ToString())); } diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc index 9a2a00c58732f3..e06c0b09ba9938 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc @@ -166,13 +166,13 @@ class XlaCallModuleOp : public XlaOpKernel { explicit XlaCallModuleOp(OpKernelConstruction *ctx) : XlaOpKernel(ctx) { int version; OP_REQUIRES_OK(ctx, ctx->GetAttr("version", &version)); - string module_str; + std::string module_str; OP_REQUIRES_OK(ctx, ctx->GetAttr("module", &module_str)); std::vector expected_output_shapes; OP_REQUIRES_OK(ctx, ctx->GetAttr("Sout", &expected_output_shapes)); std::vector expected_output_dtypes; OP_REQUIRES_OK(ctx, 
ctx->GetAttr("Tout", &expected_output_dtypes)); - std::vector dim_args_spec; + std::vector dim_args_spec; OP_REQUIRES_OK(ctx, ctx->GetAttr("dim_args_spec", &dim_args_spec)); OP_REQUIRES(ctx, dim_args_spec.empty(), absl::UnimplementedError( @@ -183,9 +183,9 @@ class XlaCallModuleOp : public XlaOpKernel { "The size of Sout (", expected_output_shapes.size(), ") must match the size of Tout (", expected_output_dtypes.size(), ")"))); - std::vector disabled_checks; + std::vector disabled_checks; OP_REQUIRES_OK(ctx, ctx->GetAttr("disabled_checks", &disabled_checks)); - std::vector platforms; + std::vector platforms; OP_REQUIRES_OK(ctx, ctx->GetAttr("platforms", &platforms)); // TODO(necula): change this to OP_REQUIRES_OK when 6 months have passed // since we added the function_list and has_token_input_output @@ -222,7 +222,7 @@ class XlaCallModuleOp : public XlaOpKernel { }) << "])"; } - string compilation_device_type = ctx->device_type().type_string(); + std::string compilation_device_type = ctx->device_type().type_string(); compilation_platform_ = ""; if (compilation_device_type == DEVICE_CPU_XLA_JIT) { compilation_platform_ = "CPU"; @@ -293,7 +293,7 @@ class XlaCallModuleOp : public XlaOpKernel { xla::XlaOp token_input; if (!op_token_input_nodes_.empty()) { std::vector token_inputs; - for (const string &node_name : op_token_input_nodes_) { + for (const std::string& node_name : op_token_input_nodes_) { auto token = compiler->GetNodeToken(node_name); OP_REQUIRES_OK(ctx, token.status()); token_inputs.push_back(token.value()); diff --git a/tensorflow/compiler/tf2xla/kernels/xla_custom_call_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_custom_call_op.cc index 139ac17b35c637..99a0ec6d9e38dd 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_custom_call_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_custom_call_op.cc @@ -55,8 +55,8 @@ class XlaCustomCallOp : public XlaOpKernel { } private: - string target_name_; - string backend_config_; + std::string target_name_; + 
std::string backend_config_; DataType output_type_; TensorShape output_shape_; }; diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc index 7b0ea597c63488..6889c093a11201 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc @@ -42,7 +42,7 @@ class XlaDequantizeOp : public XlaOpKernel { xla::QuantizedRange range(min_range_, max_range_); xla::XlaOp output = - xla::Dequantize(input, range, mode_, transpose_output_); + xla::Dequantize(input, range, mode_, transpose_output_); context->SetOutput(0, output); } @@ -50,7 +50,7 @@ class XlaDequantizeOp : public XlaOpKernel { float min_range_; float max_range_; bool transpose_output_; - string mode_; + std::string mode_; XlaDequantizeOp(const XlaDequantizeOp&) = delete; void operator=(const XlaDequantizeOp&) = delete; }; diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc index 8236e67eeded01..f77cb46c44de8c 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc @@ -34,12 +34,12 @@ namespace { class XlaDotOp : public XlaOpKernel { public: explicit XlaDotOp(OpKernelConstruction* context) : XlaOpKernel(context) { - string dnums_attr; + std::string dnums_attr; OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr)); OP_REQUIRES( context, dnums_.ParsePartialFromString(dnums_attr), errors::InvalidArgument("Error parsing convolution dimension numbers")); - string precision_config_attr; + std::string precision_config_attr; OP_REQUIRES_OK( context, context->GetAttr("precision_config", &precision_config_attr)); OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc index 0cfd247bdd1de6..7765de131e865c 100644 --- 
a/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc @@ -41,7 +41,7 @@ class XlaSelfAdjointEigOp : public XlaOpKernel { private: bool lower_; - int32 max_iter_; + int32_t max_iter_; float epsilon_; }; diff --git a/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc index f3bd088ced826a..6639c8003e1a15 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc @@ -37,7 +37,7 @@ class XlaSvdOp : public XlaOpKernel { explicit XlaSvdOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("max_iter", &max_iter_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); - string precision_config_attr; + std::string precision_config_attr; OP_REQUIRES_OK(ctx, ctx->GetAttr("precision_config", &precision_config_attr)); OP_REQUIRES(ctx, @@ -57,7 +57,7 @@ class XlaSvdOp : public XlaOpKernel { } private: - int32 max_iter_; + int32_t max_iter_; float epsilon_; xla::PrecisionConfig precision_config_; }; diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 6a67cfa237af70..0028f8e61cbd11 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -222,7 +222,7 @@ static absl::Status XlaDotShapeFunction(shape_inference::InferenceContext* c) { return shape_inference::UnknownShape(c); } - string dimension_numbers_string; + std::string dimension_numbers_string; TF_RETURN_IF_ERROR( c->GetAttr("dimension_numbers", &dimension_numbers_string)); @@ -1027,7 +1027,7 @@ REGISTER_OP("XlaEinsum") .Attr("equation: string") .Attr("T: {complex64, bfloat16, float}") .SetShapeFn([](shape_inference::InferenceContext* context) { - string equation; + std::string equation; TF_RETURN_IF_ERROR(context->GetAttr("equation", &equation)); // XlaEinsum supports only two-input einsum equations. 
if (!absl::StrContains(equation, ",")) { @@ -1057,9 +1057,9 @@ REGISTER_OP("XlaSpmdFullToShardShape") if (!c->RankKnown(input_handle)) { return shape_inference::UnknownShape(c); } - string sharding_attr; + std::string sharding_attr; TF_RETURN_IF_ERROR(c->GetAttr("manual_sharding", &sharding_attr)); - int32 single_dim; + int32_t single_dim; TF_RETURN_IF_ERROR(c->GetAttr("dim", &single_dim)); xla::OpSharding sharding; sharding.ParseFromString(sharding_attr); diff --git a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc index 84ed56a468df8e..47e76f81a0328c 100644 --- a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc +++ b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc @@ -304,7 +304,7 @@ absl::Status MaybeRewriteWhileNode( resource_input_count, index_mapping)); // Modify cond and body functions. - for (auto const& attr_name : std::vector{"cond", "body"}) { + for (auto const& attr_name : std::vector{"cond", "body"}) { NameAttrList attr_value; TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &attr_value)); const FunctionBody* fbody; @@ -363,7 +363,7 @@ absl::Status MaybeRewriteWhileNode( // Save the new FunctionDef. FunctionDef new_fdef; - string new_name = + std::string new_name = fld->UniqueFunctionName(absl::StrCat(attr_value.name(), "_rearrange_")); TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, new_name, &new_fdef)); @@ -435,7 +435,7 @@ absl::Status MaybeRewriteIfNode( std::map resource_retval_to_arg, retval_index_mapping; for (auto const& attr_name : - std::vector{"then_branch", "else_branch"}) { + std::vector{"then_branch", "else_branch"}) { NameAttrList f; TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &f)); const FunctionBody* fbody; @@ -459,7 +459,7 @@ absl::Status MaybeRewriteIfNode( // Save the new FunctionDef. 
FunctionDef new_fdef; - string new_name = + std::string new_name = fld->UniqueFunctionName(absl::StrCat(f.name(), "_rearrange_")); TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, new_name, &new_fdef)); diff --git a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc index 956f597301d28d..39efe2d682eb12 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc @@ -34,15 +34,16 @@ bool HasResourceInputOrOutput(const OpDef& op_def) { } TEST(ResourceOperationTableTest, HaveAllResourceOps) { - absl::flat_hash_map known_resource_ops; + absl::flat_hash_map known_resource_ops; for (absl::string_view known_resource_op : resource_op_table_internal::GetKnownResourceOps()) { ASSERT_TRUE( - known_resource_ops.insert({string(known_resource_op), false}).second); + known_resource_ops.insert({std::string(known_resource_op), false}) + .second); } - std::vector xla_op_names = XlaOpRegistry::GetAllRegisteredOps(); - for (const string& xla_op_name : xla_op_names) { + std::vector xla_op_names = XlaOpRegistry::GetAllRegisteredOps(); + for (const std::string& xla_op_name : xla_op_names) { const OpDef* op_def; TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef(xla_op_name, &op_def)); if (HasResourceInputOrOutput(*op_def)) { @@ -52,7 +53,7 @@ TEST(ResourceOperationTableTest, HaveAllResourceOps) { } } - std::vector unnecessary_resource_ops; + std::vector unnecessary_resource_ops; for (const auto& pair : known_resource_ops) { if (!pair.second) { unnecessary_resource_ops.push_back(pair.first); diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index 7e0b70e4df270a..4b285078f94d21 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -50,7 +50,8 @@ xla::OpMetadata CreateOpMetadata(const std::string& op_type, } void 
AssignOpMetadataToSharding(xla::OpSharding& sharding, - const string& op_type, const string& op_name) { + const std::string& op_type, + const std::string& op_name) { auto metadata = CreateOpMetadata(op_type, op_name); if (sharding.type() == xla::OpSharding::TUPLE) { for (auto& sharding_element : *sharding.mutable_tuple_shardings()) { @@ -69,7 +70,7 @@ absl::Status CoreOutOfRangeError(int core, int num_cores_per_replica) { } // namespace absl::StatusOr> ParseShardingFromDevice( - const string& device_name, int num_cores_per_replica, + const std::string& device_name, int num_cores_per_replica, std::optional explicit_sharding, std::optional metadata) { if (device_name.empty()) { @@ -102,7 +103,7 @@ absl::StatusOr> ParseShardingFromDevice( absl::StatusOr> ParseShardingFromDevice( const NodeDef& node_def, int num_cores_per_replica, bool add_metadata) { - const string& device_name = node_def.device(); + const std::string& device_name = node_def.device(); TF_ASSIGN_OR_RETURN(std::optional sharding, GetShardingFromNodeDef(node_def, add_metadata)); return ParseShardingFromDevice( @@ -114,7 +115,7 @@ absl::StatusOr> ParseShardingFromDevice( absl::StatusOr> ParseShardingFromDevice( const Node& node, int num_cores_per_replica, bool add_metadata) { - string device_name = node.assigned_device_name(); + std::string device_name = node.assigned_device_name(); if (device_name.empty()) { device_name = node.requested_device(); } @@ -152,7 +153,7 @@ absl::StatusOr> ParseShardingFromEdgeSource( } void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) { - string device_name = src.assigned_device_name(); + std::string device_name = src.assigned_device_name(); if (device_name.empty()) { device_name = src.requested_device(); } @@ -169,7 +170,7 @@ absl::StatusOr> GetShardingFromNodeDefInternal( if (!HasNodeAttr(node_def, attribute)) { return std::optional(); } - string value; + std::string value; xla::OpSharding sharding; TF_RETURN_IF_ERROR(GetNodeAttr(node_def, attribute, 
&value)); if (tensorflow::DecodeShardingAttribute(value, sharding).failed()) { diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h index e579f3ee0ff397..85259e0c729883 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.h +++ b/tensorflow/compiler/tf2xla/sharding_util.h @@ -36,7 +36,7 @@ namespace tensorflow { // - a non-value if there is no assigned core or // - a sharding set as per xla::sharding_builder::AssignDevice. absl::StatusOr> ParseShardingFromDevice( - const string& device_name, int num_cores_per_replica, + const std::string& device_name, int num_cores_per_replica, std::optional explicit_sharding = std::nullopt, std::optional metadata = std::nullopt); diff --git a/tensorflow/compiler/tf2xla/sharding_util_test.cc b/tensorflow/compiler/tf2xla/sharding_util_test.cc index 585e3887fe686c..c987e8f167422f 100644 --- a/tensorflow/compiler/tf2xla/sharding_util_test.cc +++ b/tensorflow/compiler/tf2xla/sharding_util_test.cc @@ -33,7 +33,7 @@ TEST(CoreUtilTest, ParseShardingFromDevice) { Graph graph(OpRegistry::Global()); auto core_from_sharding = - [](std::optional sharding) -> int64 { + [](std::optional sharding) -> int64_t { if (sharding.has_value() && sharding.value().type() == xla::OpSharding::MAXIMAL) { return sharding.value().tile_assignment_devices(0); diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc index afe82e0de40f62..e8b2a56cdf64d2 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.cc +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -48,8 +48,8 @@ absl::Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) { } else if (node->IsIfNode()) { AttrValue device_ordinal_value; device_ordinal_value.set_i(device_ordinal); - for (const string& attr_name : - std::vector{"then_branch", "else_branch"}) { + for (const std::string& attr_name : + std::vector{"then_branch", "else_branch"}) { NameAttrList branch_func; 
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), attr_name, &branch_func)); (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value; @@ -59,7 +59,8 @@ absl::Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) { } else if (node->IsWhileNode()) { AttrValue device_ordinal_value; device_ordinal_value.set_i(device_ordinal); - for (const string& attr_name : std::vector{"cond", "body"}) { + for (const std::string& attr_name : + std::vector{"cond", "body"}) { NameAttrList branch_func; TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), attr_name, &branch_func)); (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value; @@ -80,39 +81,40 @@ absl::Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) { std::set CalculateTokenInputsForOutputToken(const Graph& g) { std::set results; Node* first_side_effecting_node_on_path = nullptr; - ReverseDFS(g, - [&](Node* n) { - std::vector token_input_nodes; - if (!GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, - &token_input_nodes) - .ok() || - token_input_nodes.empty()) { - return; - } - - if (first_side_effecting_node_on_path != nullptr) { - return; - } - - first_side_effecting_node_on_path = n; - string original_node_name; - TF_CHECK_OK(GetNodeAttr(n->def(), - kXlaOriginalOutsideCompilationNodeName, - &original_node_name)); - results.insert(original_node_name); - }, - [&](Node* n) { - if (first_side_effecting_node_on_path == n) { - first_side_effecting_node_on_path = nullptr; - } - }, - NodeComparatorName()); + ReverseDFS( + g, + [&](Node* n) { + std::vector token_input_nodes; + if (!GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, + &token_input_nodes) + .ok() || + token_input_nodes.empty()) { + return; + } + + if (first_side_effecting_node_on_path != nullptr) { + return; + } + + first_side_effecting_node_on_path = n; + std::string original_node_name; + TF_CHECK_OK(GetNodeAttr(n->def(), + kXlaOriginalOutsideCompilationNodeName, + &original_node_name)); + 
results.insert(original_node_name); + }, + [&](Node* n) { + if (first_side_effecting_node_on_path == n) { + first_side_effecting_node_on_path = nullptr; + } + }, + NodeComparatorName()); return results; } bool HasSideEffectingNodes(const Graph& g) { for (Node* n : g.nodes()) { - std::vector token_input_nodes; + std::vector token_input_nodes; if (GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, &token_input_nodes) .ok() && !token_input_nodes.empty()) { @@ -123,10 +125,10 @@ bool HasSideEffectingNodes(const Graph& g) { } absl::Status ParseHostComputeCoreList( - absl::Span list_from_attr, - std::map* host_compute_core) { + absl::Span list_from_attr, + std::map* host_compute_core) { for (const auto& hc_core : list_from_attr) { - std::vector parts = str_util::Split(hc_core, ":"); + std::vector parts = str_util::Split(hc_core, ":"); if (parts.size() != 2) { return errors::InvalidArgument( "Malformed host_compute_core entry ", hc_core, diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h index 34f30eb7661bc1..9ba994a16a3c8e 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.h +++ b/tensorflow/compiler/tf2xla/side_effect_util.h @@ -61,8 +61,9 @@ bool HasSideEffectingNodes(const Graph& g); // Parse the mapping from outside_compilation_subgraph name to core number, // which is specified in an attr as a list of strings // :. -absl::Status ParseHostComputeCoreList(absl::Span list_from_attr, - std::map* host_compute_core); +absl::Status ParseHostComputeCoreList( + absl::Span list_from_attr, + std::map* host_compute_core); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc index 43623a8db8014f..193eb7c08bc08a 100644 --- a/tensorflow/compiler/tf2xla/test_util.cc +++ b/tensorflow/compiler/tf2xla/test_util.cc @@ -21,12 +21,12 @@ limitations under the License. 
namespace tensorflow { absl::Status InstantiateFunctionForTest( - const string& name, const FunctionLibraryDefinition& library, + const std::string& name, const FunctionLibraryDefinition& library, InstantiationResultForTest* result) { const FunctionDef* fdef = library.Find(name); TF_RET_CHECK(fdef != nullptr); - auto get_func_sig = [&library](const string& op, const OpDef** sig) { + auto get_func_sig = [&library](const std::string& op, const OpDef** sig) { return library.LookUpOpDef(op, sig); }; InstantiationResult inst; diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h index 2b2eb4f582af3e..2c9cdc1c352238 100644 --- a/tensorflow/compiler/tf2xla/test_util.h +++ b/tensorflow/compiler/tf2xla/test_util.h @@ -41,7 +41,7 @@ struct InstantiationResultForTest { // Instantiates a function, producing a GraphDef to compare against the // expected graph. absl::Status InstantiateFunctionForTest( - const string& name, const FunctionLibraryDefinition& library, + const std::string& name, const FunctionLibraryDefinition& library, InstantiationResultForTest* result); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc index 504e9d0246322e..eccc2dfaf8d4a4 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc @@ -32,7 +32,8 @@ namespace tensorflow { namespace tf2xla { namespace { -void PrintSupportedOps(const string& device, const string& regen_run) { +void PrintSupportedOps(const std::string& device, + const std::string& regen_run) { XlaOpRegistry::RegisterCompilationKernels(); std::vector kdefs = @@ -46,10 +47,10 @@ void PrintSupportedOps(const string& device, const string& regen_run) { << "Operator | Type Constraint\n" << "-------- | ---------------" << std::endl; for (const KernelDef* kdef : kdefs) { - std::vector constraints; + std::vector constraints; 
constraints.reserve(kdef->constraint().size()); for (const KernelDef::AttrConstraint& constraint : kdef->constraint()) { - std::vector types; + std::vector types; const auto& allowed_values = constraint.allowed_values().list().type(); types.reserve(allowed_values.size()); for (int type : allowed_values) { @@ -70,18 +71,18 @@ void PrintSupportedOps(const string& device, const string& regen_run) { } // namespace void SupportedOpsMain(int argc, char** argv, const char* regen_run) { - std::vector device_names = XlaOpRegistry::BackendNames(); + std::vector device_names = XlaOpRegistry::BackendNames(); std::sort(device_names.begin(), device_names.end()); // Set up and parse flags. - string device; + std::string device; std::vector flag_list = { {"device", &device, "Name of the compilation device for which to print supported ops, " "one of: " + absl::StrJoin(device_names, ",")}, }; - string usage = Flags::Usage(argv[0], flag_list); + std::string usage = Flags::Usage(argv[0], flag_list); bool parsed_flags_ok = Flags::Parse(&argc, argv, flag_list); QCHECK(parsed_flags_ok) << "\n" << usage; QCHECK(XlaOpRegistry::IsBackendRegistered(device)) diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index d61d66bfe53b72..72bd28f2b47a8c 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -118,8 +118,8 @@ TEST(ConvertGraphDefToXla, Sum) { TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); // Set up arguments. 
- auto x_literal = xla::LiteralUtil::CreateR0(10); - auto y_literal = xla::LiteralUtil::CreateR0(32); + auto x_literal = xla::LiteralUtil::CreateR0(10); + auto y_literal = xla::LiteralUtil::CreateR0(32); auto x_global_or = client->TransferToServer(x_literal); auto y_global_or = client->TransferToServer(y_literal); TF_EXPECT_OK(x_global_or.status()); @@ -140,23 +140,23 @@ TEST(ConvertGraphDefToXla, Sum) { ConvertGraphDefToXla(graph_def, config, client, &computation))); } -GraphDef EinsumGraph() { +GraphDef EinsumGraph(DataType dtype = DT_FLOAT) { GraphDef graph_def; NodeDef* x = graph_def.add_node(); x->set_name("x"); x->set_op("Placeholder"); - (*x->mutable_attr())["dtype"] = TypeAttrValue(DT_FLOAT); + (*x->mutable_attr())["dtype"] = TypeAttrValue(dtype); NodeDef* y = graph_def.add_node(); y->set_name("y"); y->set_op("Placeholder"); - (*y->mutable_attr())["dtype"] = TypeAttrValue(DT_FLOAT); + (*y->mutable_attr())["dtype"] = TypeAttrValue(dtype); NodeDef* einsum = graph_def.add_node(); einsum->set_name("einsum"); einsum->set_op("Einsum"); einsum->add_input("x"); einsum->add_input("y"); (*einsum->mutable_attr())["equation"] = StringAttrValue("ij,jk->ik"); - (*einsum->mutable_attr())["T"] = TypeAttrValue(DT_FLOAT); + (*einsum->mutable_attr())["T"] = TypeAttrValue(dtype); (*einsum->mutable_attr())["N"] = IntAttrValue(2); return graph_def; } @@ -233,6 +233,35 @@ TEST_F(ConvertGraphDefToXlaWithTF32Disabled, EXPECT_EQ(num_dots, 1); } +TEST_F(ConvertGraphDefToXlaWithTF32Disabled, + EinsumIsConvertedToDotWithDefaultPrecisionIfNotF32) { + GraphDef graph_def = EinsumGraph(DT_BFLOAT16); + tf2xla::Config config = EinsumConfig(); + + xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); + xla::XlaComputation computation; + TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); + + int num_dots = 0; + const xla::HloModuleProto& module_proto = computation.proto(); + for (const xla::HloComputationProto& computation_proto : + 
module_proto.computations()) { + for (const xla::HloInstructionProto& instruction_proto : + computation_proto.instructions()) { + if (instruction_proto.opcode() == "dot") { + num_dots++; + ASSERT_EQ(instruction_proto.precision_config().operand_precision_size(), + 2); + EXPECT_EQ(instruction_proto.precision_config().operand_precision(0), + xla::PrecisionConfig::DEFAULT); + EXPECT_EQ(instruction_proto.precision_config().operand_precision(1), + xla::PrecisionConfig::DEFAULT); + } + } + } + EXPECT_EQ(num_dots, 1); +} + GraphDef Conv2DGraph() { GraphDef graph_def; NodeDef* x = graph_def.add_node(); @@ -338,8 +367,8 @@ TEST(ConvertGraphDefToXla, SumWithUnusedArgument) { TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); // Set up arguments. - auto x_literal = xla::LiteralUtil::CreateR0(10); - auto y_literal = xla::LiteralUtil::CreateR0(32); + auto x_literal = xla::LiteralUtil::CreateR0(10); + auto y_literal = xla::LiteralUtil::CreateR0(32); auto x_global_or = client->TransferToServer(x_literal); auto y_global_or = client->TransferToServer(y_literal); auto unused_global_or = client->TransferToServer(y_literal); diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 9f21af2741dcde..042b572c234355 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -58,8 +58,9 @@ absl::Status ValidateTensorId(const tf2xla::TensorId& id) { return absl::OkStatus(); } -absl::Status CheckNameDuplicates(const string& kind, const string& name, - std::set* names) { +absl::Status CheckNameDuplicates(const std::string& kind, + const std::string& name, + std::set* names) { if (!name.empty()) { if (!names->insert(name).second) { return errors::InvalidArgument("duplicate ", kind, " name: ", name); @@ -68,12 +69,12 @@ absl::Status CheckNameDuplicates(const string& kind, const string& name, return absl::OkStatus(); } -absl::Status CheckFeedFetchNameConflicts(const string& 
kind, - const std::set& names) { +absl::Status CheckFeedFetchNameConflicts(const std::string& kind, + const std::set& names) { // We don't allow the feeds or fetches to contain both "foo" and "foo_data", // since that will cause a collision in codegen symbols. - for (const string& name : names) { - const string name_data(name + "_data"); + for (const std::string& name : names) { + const std::string name_data(name + "_data"); if (names.find(name_data) != names.end()) { return errors::InvalidArgument("conflicting ", kind, " name: ", name, " and ", name_data); @@ -227,7 +228,7 @@ absl::Status ReplaceRetvalInputWithArg( // the function to replace _Arg nodes in `const_input_index_to_node` with Const // inputs. absl::Status PropagateConstIntoFuncAttr( - Node* n, const string& attr_name, + Node* n, const std::string& attr_name, const absl::flat_hash_map& const_input_index_to_node, const FunctionLibraryDefinition* lookup_fld, FunctionLibraryDefinition* fld, bool passthrough_arg_to_retval = false) { @@ -255,7 +256,7 @@ absl::Status PropagateConstIntoFuncAttr( // Save rewritten function. FunctionDef replace_fdef; - string new_func_name = + std::string new_func_name = fld->UniqueFunctionName(absl::StrCat(func_attr.name(), "_const_")); const StackTracesMap* stack_traces = lookup_fld->GetStackTraces(func_attr.name()); @@ -301,7 +302,7 @@ absl::Status PropagateConstIntoIfNode( // Rewrite "then_branch" and "else_branch" function, replace usage of those // _Arg nodes with corresponding const node. 
for (const auto& attr_name : - std::vector{"then_branch", "else_branch"}) { + std::vector{"then_branch", "else_branch"}) { TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr( if_node, attr_name, const_input_index_to_node, lookup_fld, fld)); } @@ -309,13 +310,14 @@ absl::Status PropagateConstIntoIfNode( return absl::OkStatus(); } -using GraphCache = absl::flat_hash_map>; +using GraphCache = + absl::flat_hash_map>; absl::StatusOr FindOrInsert( GraphCache* cache, const NameAttrList& body_attr, const FunctionLibraryDefinition* lookup_fld, const FunctionLibraryDefinition* fallback_fld) { - const string name = body_attr.name(); + const std::string name = body_attr.name(); std::unique_ptr& value = (*cache)[name]; if (!value) { const FunctionDef* body_func = lookup_fld->Find(name); @@ -413,7 +415,7 @@ absl::Status PropagateConstIntoAndAroundWhileNode( absl::flat_hash_map const_input_index_to_mutable_node; NameAttrList body_attr; TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_attr)); - const string fn_name = body_attr.name(); + const std::string fn_name = body_attr.name(); const FunctionDef* body_func = lookup_fld->Find(fn_name); if (!body_func) { return errors::Internal("Propagate: Cannot find body function ", fn_name, @@ -461,7 +463,7 @@ absl::Status PropagateConstIntoAndAroundWhileNode( // Rewrite "cond" and "body" function, replace usage of those _Arg nodes with // corresponding const node. 
- for (const auto& attr_name : std::vector{"cond", "body"}) { + for (const auto& attr_name : std::vector{"cond", "body"}) { TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr( while_node, attr_name, const_input_index_to_node, lookup_fld, fld, /*passthrough_arg_to_retval=*/attr_name == "body")); @@ -487,7 +489,7 @@ absl::StatusOr IsLoopInvariant( } absl::Status ValidateConfig(const tf2xla::Config& config) { - std::set names; + std::set names; for (const tf2xla::Feed& feed : config.feed()) { TF_RETURN_IF_ERROR(ValidateTensorId(feed.id())); TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape())); @@ -508,19 +510,20 @@ absl::Status ValidateConfig(const tf2xla::Config& config) { absl::Status AddPlaceholdersForFeeds( const tf2xla::Config& config, const OpRegistryInterface* op_registry, - std::unordered_map* feed_remapping, GraphDef* graph_def) { + std::unordered_map* feed_remapping, + GraphDef* graph_def) { struct PlaceholderInfo { const tf2xla::Feed* feed = nullptr; // point to Feed in . - string placeholder_name; + std::string placeholder_name; DataType data_type = DT_INVALID; }; // Put each fed tensor into a map by name:port. A map is used for determinism // when creating placeholders (genrules want deterministic output). - std::map placeholder_info; + std::map placeholder_info; for (int i = 0; i < config.feed_size(); ++i) { const tf2xla::Feed* feed = &config.feed(i); - const string name_port = TensorIdToString(feed->id()); + const std::string name_port = TensorIdToString(feed->id()); PlaceholderInfo& info = placeholder_info[name_port]; info.feed = feed; info.placeholder_name = absl::StrCat("aot_feed_", feed->id().output_index(), @@ -529,7 +532,7 @@ absl::Status AddPlaceholdersForFeeds( } // Verify node exists and determine data type. 
- std::unordered_map name_to_node; + std::unordered_map name_to_node; for (int i = 0; i < graph_def->node_size(); ++i) { name_to_node[graph_def->node(i).name()] = &graph_def->node(i); } @@ -609,25 +612,25 @@ absl::Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, out->clear_node(); // Tensors needed for feeding. - std::set> feed_tensors; + std::set> feed_tensors; for (const tf2xla::Feed& feed : config.feed()) { feed_tensors.insert( std::make_pair(feed.id().node_name(), feed.id().output_index())); } // Maps node name to reachability. - std::unordered_map> node_by_name; + std::unordered_map> node_by_name; for (const NodeDef& node : in.node()) { node_by_name[node.name()] = std::pair(false, &node); } // Traverse. - std::queue name_queue; + std::queue name_queue; for (int i = 0; i < config.fetch_size(); ++i) { name_queue.push(config.fetch(i).id().node_name()); } while (!name_queue.empty()) { - const string name = name_queue.front(); + const std::string name = name_queue.front(); name_queue.pop(); auto find_it = node_by_name.find(name); @@ -642,9 +645,9 @@ absl::Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, map_entry.first = true; // Push input nodes of the currently visited node to name_queue. 
- for (const string& in_edge : map_entry.second->input()) { + for (const std::string& in_edge : map_entry.second->input()) { auto id = ParseTensorName(in_edge); - const string node_name = string(id.first); + const std::string node_name = std::string(id.first); if (feed_tensors.find(std::make_pair(node_name, id.second)) == feed_tensors.end()) { name_queue.push(node_name); @@ -668,7 +671,7 @@ absl::Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, return absl::OkStatus(); } -string TensorIdToString(const tf2xla::TensorId& id) { +std::string TensorIdToString(const tf2xla::TensorId& id) { return absl::StrCat(id.node_name(), ":", id.output_index()); } @@ -682,7 +685,7 @@ absl::Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { std::optional sharding, ParseShardingFromDevice( *possible_match, - /*num_cores_per_replica=*/std::numeric_limits::max(), + /*num_cores_per_replica=*/std::numeric_limits::max(), /*add_metadata=*/false)); if (sharding && sharding->type() == xla::OpSharding::MAXIMAL) { const int core_annotation = sharding.value().tile_assignment_devices(0); @@ -709,7 +712,7 @@ void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype, } namespace { -uint32 InitialRandomSeed() { +uint32_t InitialRandomSeed() { // Support plumbing the TF seed through to XLA is being worked on. // If a user wants deterministic behavior, their best option // is to start with a known checkpoint. This also handles issues when @@ -724,13 +727,13 @@ uint32 InitialRandomSeed() { } } // namespace -uint32 GetXLARandomSeed() { +uint32_t GetXLARandomSeed() { // We initialize counter with an odd number and increment it by two // everytime. This ensures that it will never be zero, even // after an overflow. When seeded with zero, some XLA backends // can return all zeros instead of random numbers. 
- static std::atomic counter(InitialRandomSeed()); - uint32 seed = counter.fetch_add(2); + static std::atomic counter(InitialRandomSeed()); + uint32_t seed = counter.fetch_add(2); std::srand(seed); return std::rand() | 1; } @@ -766,7 +769,7 @@ bool HasAssociatedFunction(const NodeDef& node_def, std::vector GetAssociatedFunctions( const Node& node, const FunctionLibraryDefinition* fld) { std::vector results; - const string& op = node.type_string(); + const std::string& op = node.type_string(); if (fld->Contains(op)) { // This is a function call node. AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); @@ -795,7 +798,7 @@ std::vector GetAssociatedFunctions( absl::Status RewriteAssociatedFunction( Graph* graph, Node* node, FunctionLibraryDefinition* fld, const AssociatedFunctionInfo& associated_function, - const string& rewritten_function_name) { + const std::string& rewritten_function_name) { switch (associated_function.type()) { case AssociatedFunctionInfo::kFunctionCallNode: { // Change this node to call the new function. 
@@ -834,7 +837,7 @@ absl::Status RewriteAssociatedFunction( GradientDef gradient_def; gradient_def.set_function_name(func.name()); gradient_def.set_gradient_func(rewritten_function_name); - string original_grad_func = fld->FindGradient(func.name()); + std::string original_grad_func = fld->FindGradient(func.name()); if (original_grad_func.empty()) { TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def)); } else if (original_grad_func != rewritten_function_name) { @@ -863,9 +866,9 @@ absl::Status RewriteAssociatedFunction( } absl::Status CachedFunctionHandles::GetOrInstantiate( - const string& func_name, AttrSlice attrs, + const std::string& func_name, AttrSlice attrs, FunctionLibraryRuntime::Handle* handle) { - string canonicalized_name = Canonicalize(func_name, attrs); + std::string canonicalized_name = Canonicalize(func_name, attrs); auto iter = handles_.find(canonicalized_name); if (iter != handles_.end()) { *handle = iter->second; @@ -919,8 +922,8 @@ absl::StatusOr ReplaceNode(Graph* g, Node* n, const NodeDef& node_def) { } absl::StatusOr BuildIdentityNode( - Graph* graph, const string& node_name, DataType dtype, const Node* input, - std::optional requested_device) { + Graph* graph, const std::string& node_name, DataType dtype, + const Node* input, std::optional requested_device) { // Create identity node. NodeDef ndef; ndef.set_name(node_name); @@ -975,7 +978,7 @@ absl::Status PruneUnreachableFunctionsFromGraph( g.ToGraphDef(&graph_def); FunctionLibraryDefinition reachable_functions = fld->ReachableDefinitions(graph_def); - for (const string& func_name : fld->ListFunctionNames()) { + for (const std::string& func_name : fld->ListFunctionNames()) { if (!reachable_functions.Find(func_name)) { TF_RETURN_IF_ERROR(fld->RemoveFunction(func_name)); } @@ -1106,7 +1109,7 @@ absl::Status RewriteTensorListWithConstElement(Graph* g, // Add rewritten backward While body function. 
FunctionDef new_fdef; - string new_name = fld->UniqueFunctionName( + std::string new_name = fld->UniqueFunctionName( absl::StrCat(bwd_body_attr.name(), "_tl_rewrite_")); TF_RETURN_IF_ERROR( GraphToFunctionDef(*bwd_fbody->graph, new_name, &new_fdef)); diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index f2ce3944ac158c..4da5a474d964dc 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -41,7 +41,8 @@ absl::Status ValidateConfig(const tf2xla::Config& config); // feeds). absl::Status AddPlaceholdersForFeeds( const tf2xla::Config& config, const OpRegistryInterface* op_registry, - std::unordered_map* feed_remapping, GraphDef* graph_def); + std::unordered_map* feed_remapping, + GraphDef* graph_def); // Returns in a copy of , pruned to only include fetches from // . @@ -49,7 +50,7 @@ absl::Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, GraphDef* out); // Returns node:port for the given . -string TensorIdToString(const tf2xla::TensorId& id); +std::string TensorIdToString(const tf2xla::TensorId& id); // Updates the sharding of based on the sharding of its neighbors. // If is true, outgoing edges from are considered; else incoming @@ -61,7 +62,7 @@ void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype, KernelDef* kdef); // Returns the next random seed to use for seeding xla rng. -uint32 GetXLARandomSeed(); +uint32_t GetXLARandomSeed(); // Indicates how a FunctionDef is associated with a graph node (e.g. the node is // a function call, or the node has function attrs). @@ -74,14 +75,14 @@ class AssociatedFunctionInfo { }; // The function is an attr of the node. 
- static AssociatedFunctionInfo FunctionAttr(const string& func_name, + static AssociatedFunctionInfo FunctionAttr(const std::string& func_name, const AttrValueMap& attrs, - const string& attr_name) { + const std::string& attr_name) { return AssociatedFunctionInfo(kFunctionAttr, func_name, attrs, attr_name); } // The node is a function call. - static AssociatedFunctionInfo FunctionCall(const string& func_name, + static AssociatedFunctionInfo FunctionCall(const std::string& func_name, const AttrValueMap& attrs) { // attr_name will not be used in this case. return AssociatedFunctionInfo(kFunctionCallNode, func_name, attrs, @@ -89,7 +90,7 @@ class AssociatedFunctionInfo { } // The node is a SymbolicGradient op. - static AssociatedFunctionInfo SymbolicGradient(const string& func_name, + static AssociatedFunctionInfo SymbolicGradient(const std::string& func_name, const AttrValueMap& attrs) { // attr_name will not be used in this case. return AssociatedFunctionInfo(kSymbolicGradient, func_name, attrs, @@ -98,15 +99,17 @@ class AssociatedFunctionInfo { AssociatedFunctionType type() const { return type_; } - const string& func_name() const { return func_name_; } + const std::string& func_name() const { return func_name_; } - const string& attr_name() const { return attr_name_; } + const std::string& attr_name() const { return attr_name_; } const AttrValueMap& attrs() const { return attrs_; } private: - AssociatedFunctionInfo(AssociatedFunctionType type, const string& func_name, - const AttrValueMap& attrs, const string& attr_name) + AssociatedFunctionInfo(AssociatedFunctionType type, + const std::string& func_name, + const AttrValueMap& attrs, + const std::string& attr_name) : type_(type), func_name_(func_name), attrs_(attrs), @@ -114,11 +117,11 @@ class AssociatedFunctionInfo { // Available for all instances. AssociatedFunctionType type_; - string func_name_; + std::string func_name_; AttrValueMap attrs_; // Only available if the function is defined in an attr. 
- string attr_name_; + std::string attr_name_; }; // Returns if the NodeDef has associated function. @@ -142,7 +145,7 @@ std::vector GetAssociatedFunctions( absl::Status RewriteAssociatedFunction( Graph* graph, Node* node, FunctionLibraryDefinition* fld, const AssociatedFunctionInfo& associated_function, - const string& rewritten_function_name); + const std::string& rewritten_function_name); // Class to act as cache for FunctionLibraryRuntime::Handle objects. class CachedFunctionHandles { @@ -152,7 +155,7 @@ class CachedFunctionHandles { // Populates `handle` for requested function and attributes. If we have // instantiated the function with the same attributes before, `handle` will be // cached handle; otherwise instantiate the function and populate `handle`. - absl::Status GetOrInstantiate(const string& func_name, AttrSlice attrs, + absl::Status GetOrInstantiate(const std::string& func_name, AttrSlice attrs, FunctionLibraryRuntime::Handle* handle); // Releases all handles in the cache. Returns first non-OK status if any; @@ -163,7 +166,7 @@ class CachedFunctionHandles { private: FunctionLibraryRuntime* flr_; - std::map handles_; + std::map handles_; CachedFunctionHandles(const CachedFunctionHandles&) = delete; void operator=(const CachedFunctionHandles&) = delete; @@ -179,9 +182,9 @@ struct OutEdgeInfo { absl::StatusOr ReplaceNode(Graph* g, Node* n, const NodeDef& node_def); // Helper function that builds an Identity node. -absl::StatusOr BuildIdentityNode(Graph* graph, const string& node_name, - DataType dtype, const Node* input, - std::optional requested_device); +absl::StatusOr BuildIdentityNode( + Graph* graph, const std::string& node_name, DataType dtype, + const Node* input, std::optional requested_device); // For "If"/"While" nodes, if some of their inputs are Const nodes, rewrite // body functions to use the Const nodes instead of original _Arg nodes. 
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index e66a8a38813474..ef64b82f50e5be 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -157,7 +157,7 @@ TEST(ValidateConfig, ConflictingFetchName) { ExpectErrorContains(ValidateConfig(config), "conflicting fetch name"); } -static tf2xla::Config FetchesConfig(std::vector fetches) { +static tf2xla::Config FetchesConfig(std::vector fetches) { tf2xla::Config config; for (const auto& fetch_node_name : fetches) { auto* fetch = config.add_fetch(); @@ -409,7 +409,7 @@ TEST(PropagateConstIntoFunctionalNodes, CopiedConstNodeHasUniqueName) { TF_ASSERT_OK(GetNodeAttr(while_node->def(), "body", &body_fn)); const FunctionDef* rewritten_body_fn = fld.Find(body_fn.name()); ASSERT_NE(rewritten_body_fn, nullptr); - std::unordered_map nodes; + std::unordered_map nodes; for (const NodeDef& node_def : rewritten_body_fn->node_def()) { nodes[node_def.name()] = node_def; } diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index ec456344bcfced..007ecef7492600 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -87,6 +87,9 @@ absl::Status DataTypeToPrimitiveType(DataType data_type, case tensorflow::DT_FLOAT8_E5M2FNUZ: *type = xla::F8E5M2FNUZ; return absl::OkStatus(); + case tensorflow::DT_FLOAT4_E2M1FN: + *type = xla::F4E2M1FN; + return absl::OkStatus(); case tensorflow::DT_BFLOAT16: *type = xla::BF16; return absl::OkStatus(); @@ -122,6 +125,7 @@ absl::StatusOr EncodePrimitiveTypeAsDataType( {xla::F8E4M3FNUZ, DT_FLOAT8_E4M3FNUZ}, {xla::F8E4M3B11FNUZ, DT_FLOAT8_E4M3B11FNUZ}, {xla::F8E5M2FNUZ, DT_FLOAT8_E5M2FNUZ}, + {xla::F4E2M1FN, DT_FLOAT4_E2M1FN}, {xla::BF16, DT_BFLOAT16}, {xla::F16, DT_HALF}, {xla::F32, DT_FLOAT}, diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc 
b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index 215decdb4d8843..add79c369b69ef 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -39,7 +39,7 @@ class XlaCompilationAllocator : public Allocator { XlaCompilationAllocator() {} ~XlaCompilationAllocator() override {} - string Name() override { return "xla_compilation"; } + std::string Name() override { return "xla_compilation"; } void* AllocateRaw(size_t alignment, size_t num_bytes) override { // Regardless of the size requested, always allocates an XlaExpression. diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index 4603bbf119a8bf..5ee45e499cb49e 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -25,16 +25,16 @@ limitations under the License. #include "absl/log/check.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/allocator.h" #include "xla/backends/cpu/runtime/rng_state_lib.h" -#include "xla/cpu_function_runtime.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { -int32 GetResultIndex(const int32* result_index_table, int32 num_results) { +int32_t GetResultIndex(const int32_t* result_index_table, int32_t num_results) { auto it = std::min_element(result_index_table, result_index_table + num_results); @@ -72,7 +72,7 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, alloc_mode == AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS; // Allocate arg and temp buffers. 
alloc_buffer_table_ = tensorflow::MallocContiguousBuffers( - static_data.buffer_infos_, static_data.num_buffers_, + absl::MakeConstSpan(static_data.buffer_infos_, static_data.num_buffers_), /*allocate_entry_params=*/allocate_entry_params, buffer_table_, /*annotate_initialized=*/true); // If Hlo profiling is enabled the generated code expects an appropriately @@ -150,7 +150,7 @@ int LookupNameIndex(absl::string_view name, const char** names) { } // namespace -int XlaCompiledCpuFunction::LookupArgIndex(const string& name) const { +int XlaCompiledCpuFunction::LookupArgIndex(const std::string& name) const { return LookupNameIndex(name, arg_names_); } @@ -162,7 +162,7 @@ int XlaCompiledCpuFunction::LookupVariableIndex(absl::string_view name) const { return num_args_ - num_variables_ + index; } -int XlaCompiledCpuFunction::LookupResultIndex(const string& name) const { +int XlaCompiledCpuFunction::LookupResultIndex(const std::string& name) const { return LookupNameIndex(name, result_names_); } diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 009650d76109bb..061982db6fd08f 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -28,9 +28,10 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" +#include "tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h" #include "xla/backends/cpu/alignment.h" +#include "xla/backends/cpu/buffer_allocation_info.h" #include "xla/backends/cpu/runtime/rng_state_lib.h" -#include "xla/cpu_function_runtime.h" #include "xla/executable_run_options.h" #include "xla/service/custom_call_status_internal.h" #include "tensorflow/core/platform/types.h" @@ -123,19 +124,19 @@ class XlaCompiledCpuFunction { // End serialized thunk execution specific // Contains information about the buffers used by the XLA computation. 
- const xla::cpu_function_runtime::BufferInfo* buffer_infos_ = nullptr; + const xla::cpu::BufferAllocationInfo* buffer_infos_ = nullptr; int32_t num_buffers_ = 0; // Result parameter i is described by // buffer_infos[result_index_table[i]]. - const int32* result_index_table_ = nullptr; + const int32_t* result_index_table_ = nullptr; // There are num_results result parameters. int64_t num_results_ = 0; // Entry parameter i is described by // buffer_infos[arg_index_table[i]]. - const int32* arg_index_table_ = nullptr; + const int32_t* arg_index_table_ = nullptr; // There are num_args entry parameters. int64_t num_args_ = 0; @@ -209,7 +210,7 @@ class XlaCompiledCpuFunction { // TODO(fschneider): For now this always returns an empty string because there // is no support for error reporting in XLA. Remove this once all callers are // updated. - string error_msg() const { return error_msg_; } + std::string error_msg() const { return error_msg_; } void set_error_msg(absl::string_view error_msg) { error_msg_ = error_msg; } @@ -251,9 +252,7 @@ class XlaCompiledCpuFunction { // called for each positional argument, in order to set the argument buffers. // // Allocated memory must be aligned to the size specified by - // xla::cpu_function_runtime::MinAlign(). If possible, use the functions in - // tensorflow/compiler/tf2xla/cpu_function_runtime.h to ensure correct - // alignment. + // xla::cpu::MinAlign(). // // Aliasing of argument and result buffers is not allowed, and results in // undefined behavior. @@ -304,7 +303,7 @@ class XlaCompiledCpuFunction { // The index remains constant for every instance of XlaCompiledCpuFunction // generated from the same static data, and might not be cheap to determine. // Recommended usage is to capture this in a variable for re-use. - int LookupArgIndex(const string& name) const; + int LookupArgIndex(const std::string& name) const; // Returns the 0-based index for the variable with the given `name`. 
// Returns -1 if the name wasn't found, or data isn't available. @@ -320,7 +319,7 @@ class XlaCompiledCpuFunction { // The index remains constant for every instance of XlaCompiledCpuFunction // generated from the same static data, and might not be cheap to determine. // Recommended usage is to capture this in a variable for re-use. - int LookupResultIndex(const string& name) const; + int LookupResultIndex(const std::string& name) const; // Returns the name of the argument at `index`. // Returns nullptr if `HasNameIndices() == false` or `index` is out of range. @@ -362,11 +361,11 @@ class XlaCompiledCpuFunction { return temp_allocation_index_; } - const xla::cpu_function_runtime::BufferInfo* buffer_infos() const { + const xla::cpu::BufferAllocationInfo* buffer_infos() const { return buffer_infos_; } - int32 num_buffers() const { return num_buffers_; } + int32_t num_buffers() const { return num_buffers_; } void** buffer_table() const { return buffer_table_; } @@ -415,7 +414,7 @@ class XlaCompiledCpuFunction { static void set_static_data_buffer_infos( StaticData* static_data, - const xla::cpu_function_runtime::BufferInfo* buffer_infos) { + const xla::cpu::BufferAllocationInfo* buffer_infos) { static_data->buffer_infos_ = buffer_infos; } @@ -425,7 +424,7 @@ class XlaCompiledCpuFunction { } static void set_static_data_result_index_table( - StaticData* static_data, const int32* result_index_table) { + StaticData* static_data, const int32_t* result_index_table) { static_data->result_index_table_ = result_index_table; } @@ -435,7 +434,7 @@ class XlaCompiledCpuFunction { } static void set_static_data_arg_index_table(StaticData* static_data, - const int32* arg_index_table) { + const int32_t* arg_index_table) { static_data->arg_index_table_ = arg_index_table; } @@ -531,22 +530,22 @@ class XlaCompiledCpuFunction { void** const buffer_table_; // Describes the buffers used by the XLA computation. 
- const xla::cpu_function_runtime::BufferInfo* const buffer_infos_; - const int32 num_buffers_; + const xla::cpu::BufferAllocationInfo* const buffer_infos_; + const int32_t num_buffers_; // Indices of expanded result tuple. - const int32 num_results_; - const int32* const result_index_table_; + const int32_t num_results_; + const int32_t* const result_index_table_; // Argument i needs to be placed in buffer_table_[arg_index_to_temp_index_[i]] // for XLA generated code to be able to find it. - const int32* const arg_index_table_; + const int32_t* const arg_index_table_; // The number of incoming arguments. - const int32 num_args_; + const int32_t num_args_; // The number of incoming variables. - const int32 num_variables_; + const int32_t num_variables_; // Shapes of the input arguments. const ShapeInfo* const arg_shape_infos_; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 9e761dc6003d80..5088badf28e9cb 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include +#include #include #include #include @@ -130,7 +131,7 @@ ComputeArgAndRetvalShardings(const Graph& graph) { [](const Node* n) -> absl::StatusOr> { TF_ASSIGN_OR_RETURN( auto sharding, - ParseShardingFromDevice(*n, std::numeric_limits::max(), + ParseShardingFromDevice(*n, std::numeric_limits::max(), /*add_metadata=*/false)); return sharding; }; @@ -173,7 +174,7 @@ absl::Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, xla_context->Ref(); absl::Status status; auto step_container = std::make_unique( - step_id, [&status, device](const string& name) { + step_id, [&status, device](const std::string& name) { status = device->resource_manager()->Cleanup(name); }); TF_RETURN_IF_ERROR(step_container->Create(device->resource_manager(), @@ -484,8 +485,8 @@ absl::Status BuildComputation( } // namespace -string XlaCompiler::Argument::HumanString() const { - string common; +std::string XlaCompiler::Argument::HumanString() const { + std::string common; if (!name.empty()) { common = absl::StrCat(" name=", name); } @@ -503,7 +504,7 @@ string XlaCompiler::Argument::HumanString() const { return absl::StrCat("kind=constant-resource", common, " value=", constant_value.DebugString()); case kResource: { - string output = absl::StrCat( + std::string output = absl::StrCat( "kind=resource", common, " resource_kind=", XlaResource::KindToString(resource_kind), " initialized=", initialized, " is_fast_mem=", fast_mem); @@ -543,7 +544,7 @@ XlaCompiler::Argument::DimensionSizesAsInlinedVector() const { } } -string XlaCompiler::Argument::ShapeHumanString() const { +std::string XlaCompiler::Argument::ShapeHumanString() const { if (absl::holds_alternative(shape)) { return std::get(shape).DebugString(); } else { @@ -592,9 +593,9 @@ XlaCompiler::~XlaCompiler() = default; int64_t XlaCompiler::NextStepId() { return next_step_id_++; } -uint64 XlaCompiler::SignatureHash::operator()( - const std::pair>& signature) const { - return std::hash()(signature.first); +uint64_t 
XlaCompiler::SignatureHash::operator()( + const std::pair>& signature) const { + return std::hash()(signature.first); } static absl::Status GetFunctionBody(const NameAttrList& function, @@ -703,9 +704,9 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { flib_runtime_->GetFunctionLibraryDefinition(), &shape_info) .IgnoreError(); auto node_name_index = graph->BuildNodeNameIndex(); - std::unordered_map> shape_map; + std::unordered_map> shape_map; for (const auto& node_shape_info : shape_info) { - const string& node_name = node_shape_info.first; + const std::string& node_name = node_shape_info.first; const std::vector& output_shapes = node_shape_info.second; const auto& node_iter = node_name_index.find(node_name); if (node_iter != node_name_index.end()) { @@ -726,9 +727,9 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { flib_runtime_->GetFunctionLibraryDefinition(), &shape_info) .IgnoreError(); auto node_name_index = graph->BuildNodeNameIndex(); - std::unordered_map> shape_map; + std::unordered_map> shape_map; for (const auto& node_shape_info : shape_info) { - const string& node_name = node_shape_info.first; + const std::string& node_name = node_shape_info.first; const std::vector& output_shapes = node_shape_info.second; const auto& node_iter = node_name_index.find(node_name); if (node_iter != node_name_index.end()) { @@ -754,7 +755,7 @@ std::vector GetValidControlRets( // the map with nodes in FunctionDef control_ret_nodes and later query it // using the nodes in `graph`. The Node pointers would be different but the // Node name is expected to remain the same between the two. 
- absl::flat_hash_map control_ret_nodes_map; + absl::flat_hash_map control_ret_nodes_map; for (int i = 0; i < orig_control_ret_nodes.size(); ++i) { const Node* n = orig_control_ret_nodes[i]; control_ret_nodes_map[n->name()] = i; @@ -814,7 +815,7 @@ absl::Status XlaCompiler::CompileFunction( const NameAttrList& fn_name_attrs, absl::Span args, XlaCompiler::CompilationResult* result) { - string function_id = + std::string function_id = Canonicalize(fn_name_attrs.name(), AttrSlice(&fn_name_attrs.attr())); VLOG(1) << "XlaCompiler::CompileFunction " << function_id; @@ -1325,7 +1326,7 @@ namespace { absl::Status ValidateFunctionDef(const FunctionDef* fdef, const FunctionLibraryDefinition& flib_def) { for (const NodeDef& node : fdef->node_def()) { - const string& op = node.op(); + const std::string& op = node.op(); if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) { continue; } @@ -1340,7 +1341,8 @@ absl::Status ValidateFunctionDef(const FunctionDef* fdef, // Returned pointer points to the internal string either in node's attributes // or in its NodeDef. This pointer is valid as long as the node has not been // modified. -absl::Status GetPotentialFunctionName(const Node& node, const string** name) { +absl::Status GetPotentialFunctionName(const Node& node, + const std::string** name) { if (node.IsPartitionedCall()) { const AttrValue* attr_value; TF_RETURN_IF_ERROR( @@ -1361,7 +1363,8 @@ absl::Status GetPotentialFunctionName(const Node& node, const string** name) { // given device_type, invalid data type, missing attributes...) absl::Status ValidateGraph(const Graph* graph, const FunctionLibraryDefinition& flib_def, - const DeviceType& device_type, const string& name) { + const DeviceType& device_type, + const std::string& name) { // Make sure the XLA compilation kernels are registered. This operation is // idempotent so it is fine if someone called it already. 
XlaOpRegistry::RegisterCompilationKernels(); @@ -1398,7 +1401,7 @@ absl::Status ValidateGraph(const Graph* graph, if (node->type_string() == FunctionLibraryDefinition::kGradientOp) { continue; } - const string* function_name; + const std::string* function_name; TF_RETURN_IF_ERROR(GetPotentialFunctionName(*node, &function_name)); const FunctionDef* fdef = flib_def.Find(*function_name); absl::Status s; @@ -1455,6 +1458,36 @@ class DummyStackTrace : public AbstractStackTrace { }; namespace { +const xla::HloInstructionProto* FindInstructionById( + const xla::HloComputationProto& computation, int64_t id) { + auto iter = + absl::c_find_if(computation.instructions(), + [id](const xla::HloInstructionProto& instruction) { + return instruction.id() == id; + }); + if (iter == computation.instructions().end()) { + return nullptr; + } + return &(*iter); +} + +bool ShouldAddPrecisionToInstruction( + const xla::HloInstructionProto& instruction, + const xla::HloComputationProto& computation) { + static constexpr std::array kOpsPossiblyUsingTF32 = { + "dot", "convolution"}; + if (!absl::c_linear_search(kOpsPossiblyUsingTF32, instruction.opcode())) { + return false; + } + if (instruction.shape().element_type() == xla::F32) { + return true; + } + return absl::c_any_of(instruction.operand_ids(), [&](int64_t operand_id) { + const xla::HloInstructionProto* operand = + FindInstructionById(computation, operand_id); + return operand && operand->shape().element_type() == xla::F32; + }); +} // Add precisions configs to the HLO module to avoid TensorFloat32 computations // in XLA. @@ -1462,13 +1495,7 @@ namespace { // Some operations, such as Einsum are converted through MlirXlaOpKernel, which // doesn't set the precisions, so we set them all here. // -// TODO(tdanyluk): We may want to restrict this logic to only set the operand -// precision for F32 operands. (Historically, it was set without regard to -// operand type in other parts of TF2XLA.) 
void IncreasePrecisionsToAvoidTF32(xla::HloModuleProto& module) { - static constexpr std::array kOpsPossiblyUsingTF32 = { - "dot", "convolution"}; - xla::PrecisionConfig precision_config; precision_config.add_operand_precision(xla::PrecisionConfig::HIGHEST); precision_config.add_operand_precision(xla::PrecisionConfig::HIGHEST); @@ -1476,8 +1503,7 @@ void IncreasePrecisionsToAvoidTF32(xla::HloModuleProto& module) { for (xla::HloComputationProto& computation : *module.mutable_computations()) { for (xla::HloInstructionProto& instruction : *computation.mutable_instructions()) { - if (absl::c_find(kOpsPossiblyUsingTF32, instruction.opcode()) != - kOpsPossiblyUsingTF32.end()) { + if (ShouldAddPrecisionToInstruction(instruction, computation)) { *instruction.mutable_precision_config() = precision_config; } } @@ -1487,7 +1513,7 @@ void IncreasePrecisionsToAvoidTF32(xla::HloModuleProto& module) { } // namespace absl::Status XlaCompiler::CompileGraph( - const XlaCompiler::CompileOptions& options, string const& name, + const XlaCompiler::CompileOptions& options, const std::string& name, std::unique_ptr graph, absl::Span args, CompilationResult* result) { VLOG(1) << "Executing graph symbolically to populate XlaBuilder.: " << name; @@ -1689,7 +1715,7 @@ xla::ChannelHandle XlaCompiler::NewChannel( return new_handle; } -absl::Status XlaCompiler::GetChannelHandle(const string& key, +absl::Status XlaCompiler::GetChannelHandle(const std::string& key, xla::ChannelHandle* channel) { auto result = channels_.emplace(key, xla::ChannelHandle()); if (result.second) { @@ -1701,7 +1727,7 @@ absl::Status XlaCompiler::GetChannelHandle(const string& key, } absl::Status XlaCompiler::GetHostToDeviceChannelHandle( - const string& key, xla::ChannelHandle* channel) { + const std::string& key, xla::ChannelHandle* channel) { auto result = channels_.emplace(key, xla::ChannelHandle()); if (result.second) { result.first->second = NewChannel(xla::ChannelHandle::HOST_TO_DEVICE); @@ -1712,7 +1738,7 @@ 
absl::Status XlaCompiler::GetHostToDeviceChannelHandle( } absl::Status XlaCompiler::GetDeviceToHostChannelHandle( - const string& key, xla::ChannelHandle* channel) { + const std::string& key, xla::ChannelHandle* channel) { auto result = channels_.emplace(key, xla::ChannelHandle()); if (result.second) { result.first->second = NewChannel(xla::ChannelHandle::DEVICE_TO_HOST); @@ -1724,7 +1750,7 @@ absl::Status XlaCompiler::GetDeviceToHostChannelHandle( namespace { -void SetTransfer(const string& key, absl::Span types, +void SetTransfer(const std::string& key, absl::Span types, absl::Span shapes, tf2xla::HostTransferMetadata* transfer) { transfer->set_key(key); @@ -1739,7 +1765,7 @@ void SetTransfer(const string& key, absl::Span types, } // namespace absl::Status XlaCompiler::SetDeviceToHostMetadata( - const string& key, absl::Span types, + const std::string& key, absl::Span types, absl::Span shapes) { if (host_compute_sends_.find(key) != host_compute_sends_.end()) { tf2xla::HostTransferMetadata& existing_transfer = host_compute_sends_[key]; @@ -1759,7 +1785,7 @@ absl::Status XlaCompiler::SetDeviceToHostMetadata( } absl::Status XlaCompiler::GetDeviceToHostShapes( - const string& key, std::vector* shapes) const { + const std::string& key, std::vector* shapes) const { const auto iter = host_compute_sends_.find(key); if (iter == host_compute_sends_.end()) { return errors::InvalidArgument( @@ -1774,7 +1800,7 @@ absl::Status XlaCompiler::GetDeviceToHostShapes( } absl::Status XlaCompiler::SetHostToDeviceMetadata( - const string& key, absl::Span types, + const std::string& key, absl::Span types, absl::Span shapes) { if (host_compute_recvs_.find(key) != host_compute_recvs_.end()) { tf2xla::HostTransferMetadata& existing_transfer = host_compute_recvs_[key]; @@ -1794,7 +1820,7 @@ absl::Status XlaCompiler::SetHostToDeviceMetadata( } absl::Status XlaCompiler::GetHostComputeControlDependency( - const string& host_compute_name, xla::XlaOp* handle) { + const std::string& 
host_compute_name, xla::XlaOp* handle) { const auto iter = host_compute_control_output_.find(host_compute_name); if (iter == host_compute_control_output_.end()) { return errors::InvalidArgument( @@ -1807,7 +1833,7 @@ absl::Status XlaCompiler::GetHostComputeControlDependency( } absl::Status XlaCompiler::SetHostComputeControlDependency( - const string& host_compute_name, const xla::XlaOp handle) { + const std::string& host_compute_name, const xla::XlaOp handle) { if (host_compute_control_output_.find(host_compute_name) != host_compute_control_output_.end()) { return errors::InvalidArgument( @@ -1819,7 +1845,7 @@ absl::Status XlaCompiler::SetHostComputeControlDependency( } void XlaCompiler::PushNodeTokenMapping() { - node_token_mapping_stack_.emplace(std::map{}); + node_token_mapping_stack_.emplace(std::map{}); } absl::Status XlaCompiler::PopNodeTokenMapping() { @@ -1832,7 +1858,7 @@ absl::Status XlaCompiler::PopNodeTokenMapping() { return absl::OkStatus(); } -absl::Status XlaCompiler::SetNodeToken(const string& node_name, +absl::Status XlaCompiler::SetNodeToken(const std::string& node_name, const xla::XlaOp op) { if (node_token_mapping_stack_.empty()) { return errors::FailedPrecondition( @@ -1847,7 +1873,8 @@ absl::Status XlaCompiler::SetNodeToken(const string& node_name, return absl::OkStatus(); } -absl::StatusOr XlaCompiler::GetNodeToken(const string& node_name) { +absl::StatusOr XlaCompiler::GetNodeToken( + const std::string& node_name) { if (node_token_mapping_stack_.empty()) { return errors::FailedPrecondition( "Calling GetNodeToken() when node_token_mapping_stack_ is " diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 2beb730eb06fa3..216125f9cb153e 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -277,7 +277,8 @@ class XlaCompiler { // Compiles a tensorflow::Graph into an xla::XlaComputation. 
// Similar to CompileFunction, but takes a Graph as input rather than a // function. - absl::Status CompileGraph(const CompileOptions& options, string const& name, + absl::Status CompileGraph(const CompileOptions& options, + const std::string& name, std::unique_ptr graph, absl::Span args, CompilationResult* result); @@ -295,31 +296,32 @@ class XlaCompiler { // Channel handles can be used to communicate between different // computations. Computations that communicate should be compiled with the // same XlaCompiler. - absl::Status GetChannelHandle(const string& key, xla::ChannelHandle* channel); + absl::Status GetChannelHandle(const std::string& key, + xla::ChannelHandle* channel); // Retrieves the host-to-device channel handle associated with `key`. // Allocates a new channel handle if none exists. - absl::Status GetHostToDeviceChannelHandle(const string& key, + absl::Status GetHostToDeviceChannelHandle(const std::string& key, xla::ChannelHandle* channel); // Retrieves the device-to-host channel handle associated with `key`. // Allocates a new channel handle if none exists. - absl::Status GetDeviceToHostChannelHandle(const string& key, + absl::Status GetDeviceToHostChannelHandle(const std::string& key, xla::ChannelHandle* channel); // Sets the shapes and types for the device to host transfer associated with // 'key'. - absl::Status SetDeviceToHostMetadata(const string& key, + absl::Status SetDeviceToHostMetadata(const std::string& key, absl::Span types, absl::Span shapes); // Gets the shapes the device to host transfer associated with 'key'. - absl::Status GetDeviceToHostShapes(const string& key, + absl::Status GetDeviceToHostShapes(const std::string& key, std::vector* shapes) const; // Sets the shapes and types for the host to device transfer associated with // 'key'. 
- absl::Status SetHostToDeviceMetadata(const string& key, + absl::Status SetHostToDeviceMetadata(const std::string& key, absl::Span types, absl::Span shapes); @@ -334,10 +336,10 @@ class XlaCompiler { // 'host_compute_name' can be any string the client wishes to use to identify // a given HostCompute Op as long as the names are unique within the // compilation. - absl::Status GetHostComputeControlDependency(const string& host_compute_name, - xla::XlaOp* handle); - absl::Status SetHostComputeControlDependency(const string& host_compute_name, - xla::XlaOp handle); + absl::Status GetHostComputeControlDependency( + const std::string& host_compute_name, xla::XlaOp* handle); + absl::Status SetHostComputeControlDependency( + const std::string& host_compute_name, xla::XlaOp handle); const Options& options() const { return options_; } xla::Client* client() const { return options_.client; } @@ -345,8 +347,8 @@ class XlaCompiler { void PushNodeTokenMapping(); absl::Status PopNodeTokenMapping(); - absl::Status SetNodeToken(const string& node_name, xla::XlaOp op); - absl::StatusOr GetNodeToken(const string& node_name); + absl::Status SetNodeToken(const std::string& node_name, xla::XlaOp op); + absl::StatusOr GetNodeToken(const std::string& node_name); // Sets the function body `fbody` to the one registered as `function`. absl::Status FindFunctionBody(const NameAttrList& function, @@ -405,20 +407,22 @@ class XlaCompiler { FunctionLibraryRuntime* flib_runtime_; // owned by pflr_. 
struct SignatureHash { - uint64 operator()( - const std::pair>& signature) const; + uint64_t operator()( + const std::pair>& signature) const; }; - std::unordered_map>, + std::unordered_map>, CompilationResult, SignatureHash> cache_; - std::unordered_map channels_; + std::unordered_map channels_; - std::unordered_map host_compute_sends_; - std::unordered_map host_compute_recvs_; + std::unordered_map + host_compute_sends_; + std::unordered_map + host_compute_recvs_; - std::unordered_map host_compute_control_output_; + std::unordered_map host_compute_control_output_; // This is used to store mapping. Side-effecting // ops call SetNodeToken() to record its token output, so later side-effecting @@ -427,7 +431,7 @@ class XlaCompiler { // It's a stack because we need a mapping like this for each level of nested // CompileGraph() call. In CompileGraph(), we will push a new mapping to the // stack, and pop the mapping before returning. - std::stack> node_token_mapping_stack_; + std::stack> node_token_mapping_stack_; XlaCompiler(const XlaCompiler&) = delete; void operator=(const XlaCompiler&) = delete; diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index a3090e81f84a82..2c149eacda678e 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -140,7 +140,7 @@ namespace { // compiled kernels. class DummyResourceForTest : public ResourceBase { public: - string DebugString() const override { return "dummy"; } + std::string DebugString() const override { return "dummy"; } void Increment() { ++value_; } int Get() { return value_; } @@ -268,8 +268,8 @@ TEST_F(XlaCompilerTest, Simple) { std::move(graph), args, &result)); // Tests that the generated computation works. 
- xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); - xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = client_->TransferToServer(param0_literal).value(); std::unique_ptr param1_data = @@ -281,7 +281,7 @@ TEST_F(XlaCompilerTest, Simple) { .value(); xla::Literal actual_literal = client_->Transfer(*actual).value(); - xla::Literal expected0 = xla::LiteralUtil::CreateR1({4, 143}); + xla::Literal expected0 = xla::LiteralUtil::CreateR1({4, 143}); xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } @@ -366,8 +366,8 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) { args, &result)); // Tests that the generated computation works. - xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); - xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = client_->TransferToServer(param0_literal).value(); std::unique_ptr param1_data = @@ -484,7 +484,7 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) { auto read = ops::ReadVariableOp( scope.WithControlDependencies(std::vector{write}), var, DT_INT32); - auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(scope.ToGraph(graph.get())); @@ -602,7 +602,7 @@ TEST_F(XlaCompilerTest, MixedOrderArguments) { auto read = ops::ReadVariableOp( scope.WithControlDependencies(std::vector{write}), var, DT_INT32); - auto read_plus_one = 
ops::Add(scope, read, ops::Const(scope, 1)); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(scope.ToGraph(graph.get())); @@ -680,7 +680,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { // func(a) { b=7; c=-a; return b, c; } Scope scope = Scope::NewRootScope().ExitOnError(); auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); - auto b = ops::Const(scope.WithOpName("B"), 7); + auto b = ops::Const(scope.WithOpName("B"), 7); auto c = ops::Neg(scope.WithOpName("C"), a); auto d = ops::_Retval(scope.WithOpName("D"), b, 0); auto e = ops::_Retval(scope.WithOpName("E"), c, 1); @@ -710,7 +710,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { EXPECT_FALSE(result.outputs[1].is_constant); // Tests that the generated computation works. - xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param0_data = client_->TransferToServer(param0_literal).value(); @@ -718,8 +718,8 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { client_->Execute(*result.computation, {param0_data.get()}).value(); xla::Literal actual_literal = client_->Transfer(*actual).value(); - xla::Literal expected0 = xla::LiteralUtil::CreateR0(7); - xla::Literal expected1 = xla::LiteralUtil::CreateR1({-7, -42}); + xla::Literal expected0 = xla::LiteralUtil::CreateR0(7); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({-7, -42}); xla::Literal expected = xla::LiteralUtil::MakeTuple({&expected0, &expected1}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal)); @@ -885,7 +885,7 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) { // The names of instructions were uniquified by the XlaBuilder and the // unique ids may be different, the rest of the fields should be // identical. 
- string str1, str2; + std::string str1, str2; LOG(INFO) << "instr1 = " << instr1.DebugString(); LOG(INFO) << "instr2 = " << instr2.DebugString(); instr1.AppendPartialToString(&str1); @@ -904,7 +904,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { auto flow = ops::Const(scope, {}); auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1"); auto grad2 = ops::TensorArrayGrad(scope, arg, grad1.flow_out, "grad2"); - auto index = ops::Const(scope, 1); + auto index = ops::Const(scope, 1); auto write = ops::TensorArrayWrite(scope, grad1.grad_handle, index, index, grad2.flow_out); auto read = ops::TensorArrayRead(scope, arg, index, write.flow_out, DT_INT32); @@ -933,12 +933,12 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { const XlaCompiler::ResourceUpdate& update = result.resource_updates[0]; EXPECT_EQ(0, update.input_index); EXPECT_EQ(DT_INT32, update.type); - EXPECT_EQ((std::set{"grad1", "grad2"}), + EXPECT_EQ((std::set{"grad1", "grad2"}), update.tensor_array_gradients_accessed); // Tests that the generated computation works. 
- xla::Literal input_base = xla::LiteralUtil::CreateR1({7, 42}); - xla::Literal input_grad2 = xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal input_base = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal input_grad2 = xla::LiteralUtil::CreateR1({-3, 101}); xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2}); std::unique_ptr param0_data = client_->TransferToServer(input).value(); @@ -947,10 +947,10 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { client_->Execute(*result.computation, {param0_data.get()}).value(); xla::Literal actual_literal = client_->Transfer(*actual).value(); - xla::Literal output_read = xla::LiteralUtil::CreateR0(42); - xla::Literal output_base = xla::LiteralUtil::CreateR1({7, 42}); - xla::Literal output_grad1 = xla::LiteralUtil::CreateR1({0, 1}); - xla::Literal output_grad2 = xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal output_read = xla::LiteralUtil::CreateR0(42); + xla::Literal output_base = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal output_grad1 = xla::LiteralUtil::CreateR1({0, 1}); + xla::Literal output_grad2 = xla::LiteralUtil::CreateR1({-3, 101}); xla::Literal output_resource = xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2}); xla::Literal expected_literal = @@ -964,7 +964,7 @@ TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) { auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0); auto flow = ops::Const(scope, {}); auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1"); - auto index = ops::Const(scope, 1); + auto index = ops::Const(scope, 1); auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32); auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0); std::unique_ptr graph(new Graph(OpRegistry::Global())); @@ -996,7 +996,7 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) { auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0); auto 
flow = ops::Const(scope, {}); auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad2"); - auto index = ops::Const(scope, 1); + auto index = ops::Const(scope, 1); auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32); auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0); std::unique_ptr graph(new Graph(OpRegistry::Global())); @@ -1067,8 +1067,8 @@ TEST_F(XlaCompilerTest, FunctionCallWithConstants) { std::unique_ptr graph(new Graph(OpRegistry::Global())); Scope scope = Scope::NewRootScope().ExitOnError(); - auto value = ops::Const(scope.WithOpName("value"), 1, {}); - auto shape = ops::Const(scope.WithOpName("shape"), {5}, {1}); + auto value = ops::Const(scope.WithOpName("value"), 1, {}); + auto shape = ops::Const(scope.WithOpName("shape"), {5}, {1}); TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib)); NodeDef def; @@ -1151,9 +1151,9 @@ TEST_F(XlaCompilerTest, SliceWithDynamicBegins) { std::unique_ptr graph(new Graph(OpRegistry::Global())); Scope scope = Scope::NewRootScope().ExitOnError(); - auto value = ops::Const(scope.WithOpName("shape"), {5}, {1}); + auto value = ops::Const(scope.WithOpName("shape"), {5}, {1}); auto begin = ops::_Arg(scope.WithOpName("arg"), DT_INT32, 0); - auto size = ops::Const(scope.WithOpName("value"), {1}, {1}); + auto size = ops::Const(scope.WithOpName("value"), {1}, {1}); TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib)); @@ -1188,8 +1188,8 @@ TEST_F(XlaCompilerTest, SliceWithDynamicBegins) { void RunAndCheckVariablesComputation( xla::Client* client, const XlaCompiler::CompilationResult& result) { - xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); - xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = client->TransferToServer(param0_literal).value(); std::unique_ptr param1_data = @@ -1201,8 
+1201,8 @@ void RunAndCheckVariablesComputation( .value(); xla::Literal actual_literal = client->Transfer(*actual).value(); - xla::Literal expected0 = xla::LiteralUtil::CreateR1({5, 144}); - xla::Literal expected1 = xla::LiteralUtil::CreateR1({4, 143}); + xla::Literal expected0 = xla::LiteralUtil::CreateR1({5, 144}); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({4, 143}); xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0, &expected1}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); @@ -1220,7 +1220,7 @@ TEST_F(XlaCompilerTest, Variables) { auto read = ops::ReadVariableOp( scope.WithControlDependencies(std::vector{write}), var, DT_INT32); - auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(scope.ToGraph(graph.get())); @@ -1356,7 +1356,7 @@ TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) { std::move(graph), args, &result)); // Tests that the generated computation works. 
- xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param1_data = client_->TransferToServer(param1_literal).value(); @@ -1379,7 +1379,7 @@ TEST_F(XlaCompilerTest, ReturnResourceHandle) { auto read = ops::ReadVariableOp( scope.WithControlDependencies(std::vector{write}), var, DT_INT32); - auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); auto r = ops::_Retval(scope.WithOpName("R"), var, 0); auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 1); @@ -1414,7 +1414,7 @@ absl::StatusOr> BuildTestGraph() { auto read = ops::ReadVariableOp( scope.WithControlDependencies(std::vector{write}), var, DT_INT32); - auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_RETURN_IF_ERROR(scope.ToGraph(graph.get())); @@ -1475,9 +1475,9 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { // Tests that the generated computation works. 
xla::Literal param0_literal = - xla::LiteralUtil::CreateR2({{4, 55}, {1, -3}}); + xla::LiteralUtil::CreateR2({{4, 55}, {1, -3}}); xla::Literal param1_literal = - xla::LiteralUtil::CreateR1({22, 11, 33, 404}); + xla::LiteralUtil::CreateR1({22, 11, 33, 404}); std::unique_ptr param0_data = client_->TransferToServer(param0_literal).value(); std::unique_ptr param1_data = @@ -1490,8 +1490,9 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { xla::Literal actual_literal = client_->Transfer(*actual).value(); xla::Literal expected0 = - xla::LiteralUtil::CreateR2({{27, 67}, {35, 402}}); - xla::Literal expected1 = xla::LiteralUtil::CreateR1({26, 66, 34, 401}); + xla::LiteralUtil::CreateR2({{27, 67}, {35, 402}}); + xla::Literal expected1 = + xla::LiteralUtil::CreateR1({26, 66, 34, 401}); xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0, &expected1}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); @@ -1547,9 +1548,9 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { // Tests that the generated computation works. 
xla::Literal param0_literal = - xla::LiteralUtil::CreateR1({4, 55, 1, -3}); + xla::LiteralUtil::CreateR1({4, 55, 1, -3}); xla::Literal param1_literal = - xla::LiteralUtil::CreateR1({22, 11, 33, 404}); + xla::LiteralUtil::CreateR1({22, 11, 33, 404}); std::unique_ptr param0_data = client_->TransferToServer(param0_literal).value(); std::unique_ptr param1_data = @@ -1561,8 +1562,10 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { .value(); xla::Literal actual_literal = client_->Transfer(*actual).value(); - xla::Literal expected0 = xla::LiteralUtil::CreateR1({27, 67, 35, 402}); - xla::Literal expected1 = xla::LiteralUtil::CreateR1({26, 66, 34, 401}); + xla::Literal expected0 = + xla::LiteralUtil::CreateR1({27, 67, 35, 402}); + xla::Literal expected1 = + xla::LiteralUtil::CreateR1({26, 66, 34, 401}); xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0, &expected1}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); @@ -1587,8 +1590,8 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { std::unique_ptr graph(new Graph(OpRegistry::Global())); Scope scope = Scope::NewRootScope().ExitOnError(); - auto value = ops::Const(scope.WithOpName("value"), 1, {}); - auto shape = ops::Const(scope.WithOpName("shape"), {5}, {1}); + auto value = ops::Const(scope.WithOpName("value"), 1, {}); + auto shape = ops::Const(scope.WithOpName("shape"), {5}, {1}); TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib)); NodeDef def; @@ -1684,7 +1687,8 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) { side_effecting_op.set_name("DummySideEffectingOp"); side_effecting_op.set_op("DummySideEffectingOp"); AddNodeAttr(kXlaTokenInputNodesAttrName, - std::vector{kXlaTokenArgNodeName}, &side_effecting_op); + std::vector{kXlaTokenArgNodeName}, + &side_effecting_op); AddNodeAttr(kXlaOriginalOutsideCompilationNodeName, side_effecting_op.name(), &side_effecting_op); absl::Status status; @@ -1768,8 +1772,8 @@ TEST_F(XlaCompilerTest, 
OpsWithTensorListInput) { } Scope scope = Scope::NewRootScope().ExitOnError(); - auto element_shape = ops::Const(scope, {1}, {1}); - auto max_elements = ops::Const(scope, {10}, {}); + auto element_shape = ops::Const(scope, {1}, {1}); + auto max_elements = ops::Const(scope, {10}, {}); auto arg = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0); std::initializer_list out = {arg, arg}; auto add_n = ops::AddN(scope, out); @@ -1822,7 +1826,7 @@ TEST_F(XlaCompilerTest, WhileWithResources) { auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0); auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 1); auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_RESOURCE, 2); - auto less = ops::Less(scope, arg0, ops::Const(scope, 10)); + auto less = ops::Less(scope, arg0, ops::Const(scope, 10)); (void)ops::_Retval(scope.WithOpName("ret"), less, 0); TF_ASSERT_OK(scope.ToGraph(graph.get())); FunctionDef fdef; @@ -1899,9 +1903,9 @@ TEST_F(XlaCompilerTest, WhileWithResources) { ASSERT_EQ(output2.input_index, 2); // Tests that the generated computation works. 
- xla::Literal literal0 = xla::LiteralUtil::CreateR0(0); - xla::Literal literal1 = xla::LiteralUtil::CreateR0(2); - xla::Literal literal2 = xla::LiteralUtil::CreateR0(1); + xla::Literal literal0 = xla::LiteralUtil::CreateR0(0); + xla::Literal literal1 = xla::LiteralUtil::CreateR0(2); + xla::Literal literal2 = xla::LiteralUtil::CreateR0(1); std::unique_ptr data0 = client_->TransferToServer(literal0).value(); std::unique_ptr data1 = @@ -1916,9 +1920,9 @@ TEST_F(XlaCompilerTest, WhileWithResources) { .value(); xla::Literal actual_literal = client_->Transfer(*actual).value(); - xla::Literal expected0 = xla::LiteralUtil::CreateR0(10); - xla::Literal expected1 = xla::LiteralUtil::CreateR0(2); - xla::Literal expected2 = xla::LiteralUtil::CreateR0(1); + xla::Literal expected0 = xla::LiteralUtil::CreateR0(10); + xla::Literal expected1 = xla::LiteralUtil::CreateR0(2); + xla::Literal expected2 = xla::LiteralUtil::CreateR0(1); xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0, &expected1, &expected2}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); @@ -1978,7 +1982,7 @@ TEST_F(XlaCompilerTest, SetShardingForReturnedTuple) { TEST_F(XlaCompilerTest, AliasResourceUpdates) { Scope scope = Scope::NewRootScope().ExitOnError(); - auto a = ops::Const(scope.WithOpName("A"), {1, 2}); + auto a = ops::Const(scope.WithOpName("A"), {1, 2}); auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); auto write = ops::AssignAddVariableOp(scope, var, a); auto read = ops::ReadVariableOp( @@ -2022,7 +2026,7 @@ TEST_F(XlaCompilerTest, AliasResourceUpdates) { TEST_F(XlaCompilerTest, SetDeviceToHostMetadataExactDuplicate) { XlaCompiler compiler(DefaultOptions()); - const string& key = "comm_key"; + const std::string& key = "comm_key"; std::vector types{DT_INT32}; std::vector shapes{TensorShape({2})}; @@ -2035,7 +2039,7 @@ TEST_F(XlaCompilerTest, SetDeviceToHostMetadataExactDuplicate) { TEST_F(XlaCompilerTest, 
SetDeviceToHostMetadataMismatchedDuplicate) { XlaCompiler compiler(DefaultOptions()); - const string& key = "comm_key"; + const std::string& key = "comm_key"; std::vector types{DT_INT32}; std::vector shapes{TensorShape({2})}; std::vector types2{DT_FLOAT}; @@ -2051,7 +2055,7 @@ TEST_F(XlaCompilerTest, SetDeviceToHostMetadataMismatchedDuplicate) { TEST_F(XlaCompilerTest, SetHostToDeviceMetadataExactDuplicate) { XlaCompiler compiler(DefaultOptions()); - const string& key = "comm_key"; + const std::string& key = "comm_key"; std::vector types{DT_INT32}; std::vector shapes{TensorShape({2})}; @@ -2064,7 +2068,7 @@ TEST_F(XlaCompilerTest, SetHostToDeviceMetadataExactDuplicate) { TEST_F(XlaCompilerTest, SetHostToDeviceMetadataMismatchedDuplicate) { XlaCompiler compiler(DefaultOptions()); - const string& key = "comm_key"; + const std::string& key = "comm_key"; std::vector types{DT_INT32}; std::vector shapes{TensorShape({2})}; std::vector types2{DT_FLOAT}; diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 92ddf0125aded1..fad607b1ae1333 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -67,7 +67,7 @@ XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, } } -string XlaContext::DebugString() const { return "XLA JIT context"; } +std::string XlaContext::DebugString() const { return "XLA JIT context"; } void XlaContext::SetRetval(int index, const XlaExpression& expression) { const int64_t retvals_size = retvals_.size(); @@ -84,7 +84,7 @@ XlaResource* XlaContext::AddResource(std::unique_ptr resource) { const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { return LookupOrCreate(type, &max_func_, [type] { - const string type_string = DataTypeString(type); + const std::string type_string = DataTypeString(type); VLOG(1) << "Building Max() for " << type_string; xla::XlaBuilder b("max<" + type_string + ">"); xla::PrimitiveType xla_type; 
@@ -100,7 +100,7 @@ const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) { return LookupOrCreate(type, &min_func_, [type] { - const string type_string = DataTypeString(type); + const std::string type_string = DataTypeString(type); VLOG(1) << "Building Min() for " << type_string; xla::XlaBuilder b("min<" + type_string + ">"); xla::PrimitiveType xla_type; @@ -116,7 +116,7 @@ const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) { const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) { return LookupOrCreate(type, &add_func_, [type] { - const string type_string = DataTypeString(type); + const std::string type_string = DataTypeString(type); VLOG(1) << "Building Add() for " << type_string; xla::XlaBuilder b("add<" + type_string + ">"); xla::PrimitiveType xla_type; @@ -133,7 +133,7 @@ const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) { const xla::XlaComputation* XlaContext::GetOrCreateLogAddExp( const DataType type) { return LookupOrCreate(type, &log_add_exp_func_, [type] { - const string type_string = DataTypeString(type); + const std::string type_string = DataTypeString(type); VLOG(1) << "Building LogAddExp() for " << type_string; xla::XlaBuilder b("log_add_exp<" + type_string + ">"); xla::PrimitiveType xla_type; @@ -154,7 +154,7 @@ const xla::XlaComputation* XlaContext::GetOrCreateLogAddExp( const xla::XlaComputation* XlaContext::GetOrCreateMul(const DataType type) { return LookupOrCreate(type, &mul_func_, [type] { - const string type_string = DataTypeString(type); + const std::string type_string = DataTypeString(type); VLOG(1) << "Building Mul() for " << type_string; xla::XlaBuilder b("mul<" + type_string + ">"); xla::PrimitiveType xla_type; diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 9184fb4300633c..1d72f0c756f364 100644 --- 
a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -50,7 +50,7 @@ class XlaContext : public ResourceBase { const Graph* graph); // Virtual method defined by ResourceBase. - string DebugString() const override; + std::string DebugString() const override; XlaCompiler* compiler() const { return compiler_; } diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc index 61bd10e413ccf3..e867dd14209ab8 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.cc +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -73,7 +73,7 @@ XlaExpression XlaExpression::Resource(XlaResource* resource) { return e; } -string XlaExpression::HumanString() const { +std::string XlaExpression::HumanString() const { switch (kind_) { case Kind::kInvalid: return "invalid"; diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h index d410b79a3da137..ed0041fc9942a0 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.h +++ b/tensorflow/compiler/tf2xla/xla_expression.h @@ -115,7 +115,7 @@ class XlaExpression { XlaResource* resource() const { return resource_; } // Returns a human-readable summary of the expression. - string HumanString() const; + std::string HumanString() const; // Returns the value of a kValue or kXlaOp as an xla::XlaOp. Returns // an erroneous XlaOp if the expression is not a constant or an expression. 
diff --git a/tensorflow/compiler/tf2xla/xla_expression_test.cc b/tensorflow/compiler/tf2xla/xla_expression_test.cc index 7a0cc34de9af2e..797002476aeb1c 100644 --- a/tensorflow/compiler/tf2xla/xla_expression_test.cc +++ b/tensorflow/compiler/tf2xla/xla_expression_test.cc @@ -38,14 +38,15 @@ class XlaExpressionTest : public ::testing::Test { void SetUp() override { client_ = xla::ClientLibrary::LocalClientOrDie(); builder_ = std::make_unique("acomputation"); - constant_ = test::AsScalar(42); - op_ = xla::ConstantR0(builder_.get(), 7); + constant_ = test::AsScalar(42); + op_ = xla::ConstantR0(builder_.get(), 7); non_constant_op_ = xla::Parameter( builder_.get(), 0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x"); resource_ = std::make_unique( - XlaResource::kVariable, /*arg_num=*/0, /*name=*/string("avariable"), - DT_INT32, TensorShape({17, 3}), op_, /*tensor_array_size=*/-1, - /*tensor_array_gradients=*/std::set(), + XlaResource::kVariable, /*arg_num=*/0, + /*name=*/std::string("avariable"), DT_INT32, TensorShape({17, 3}), op_, + /*tensor_array_size=*/-1, + /*tensor_array_gradients=*/std::set(), /*tensor_array_multiple_writes_aggregate=*/false); } @@ -87,8 +88,8 @@ TEST_F(XlaExpressionTest, AsXlaOp) { builder_->BuildConstantSubGraph(const_as_op)); TF_ASSERT_OK_AND_ASSIGN(xla::Literal value, client_->ComputeConstant(computation)); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(xla::LiteralUtil::CreateR0(42), - value)); + EXPECT_TRUE(xla::LiteralTestUtil::Equal( + xla::LiteralUtil::CreateR0(42), value)); } TEST_F(XlaExpressionTest, GetShape) { @@ -120,7 +121,7 @@ TEST_F(XlaExpressionTest, ResolveConstant) { std::optional op_constant, XlaExpression::XlaOp(op_, DT_INT32).ResolveConstant(client_)); ASSERT_TRUE(op_constant.has_value()); - test::ExpectTensorEqual(test::AsScalar(7), *op_constant); + test::ExpectTensorEqual(test::AsScalar(7), *op_constant); TF_ASSERT_OK_AND_ASSIGN(std::optional op_nonconstant, XlaExpression::XlaOp(non_constant_op_, DT_FLOAT) @@ -131,7 +132,7 @@ 
TEST_F(XlaExpressionTest, ResolveConstant) { std::optional constant_constant, XlaExpression::Constant(constant_).ResolveConstant(client_)); ASSERT_TRUE(constant_constant.has_value()); - test::ExpectTensorEqual(constant_, *constant_constant); + test::ExpectTensorEqual(constant_, *constant_constant); } TEST_F(XlaExpressionTest, ResolveConstantOnResource) { diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index eb91ed5c3f78d6..45814517342abc 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -233,7 +233,7 @@ absl::Status ResolveDeviceAssignment( // For GPU collectives, `xla_global_id`s are arbitrary integers, and XLA // requires a mapping from local device IDs to global device IDs. const DeviceMgr* device_mgr = ctx->function_library()->device_mgr(); - std::map global_device_ids; + absl::btree_map global_device_ids; for (int device_idx = 0; device_idx < params->group.group_size; device_idx++) { @@ -246,8 +246,8 @@ absl::Status ResolveDeviceAssignment( // This is a local device, so include it in the mapping. const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info = resolved_device->tensorflow_accelerator_device_info(); - global_device_ids[accelerator_device_info->stream->parent() - ->device_ordinal()] = + global_device_ids[xla::LocalDeviceId( + accelerator_device_info->stream->parent()->device_ordinal())] = device_attributes.xla_global_id(); } } diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 38f01c83db8251..0b3425e5b8524a 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -136,7 +136,7 @@ struct XlaResourceUpdate { bool modified; // If the resource is a TensorArray, the set of gradients read or written. 
- std::set tensor_array_gradients_accessed; + std::set tensor_array_gradients_accessed; }; struct XlaCompilationResult { diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index ad571976cbfcf5..b374e8c8e81dd6 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -26,13 +26,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" +#include "xla/backends/cpu/buffer_allocation_info.h" +#include "xla/backends/cpu/buffer_allocation_info_util.h" #include "xla/backends/cpu/codegen/compiled_function_library.h" #include "xla/client/client_library.h" #include "xla/client/executable_build_options.h" #include "xla/client/local_client.h" -#include "xla/cpu_function_runtime.h" #include "xla/hlo/builder/xla_computation.h" -#include "xla/service/cpu/buffer_info_util.h" #include "xla/service/cpu/cpu_aot_compilation_result.h" #include "xla/service/cpu/cpu_executable.h" #include "xla/service/platform_util.h" @@ -62,10 +62,10 @@ absl::StatusOr ComputeResultIndex( // Returns the number of results. int CountResults( - absl::Span buffer_infos) { + absl::Span buffer_infos) { int num_results = 0; for (const auto& info : buffer_infos) { - if (info.is_result_parameter()) { + if (info.is_result()) { ++num_results; } } @@ -76,12 +76,12 @@ int CountResults( // tf2xla::{Feed,Fetch,Variable}. We hold the actual strings in nonempty_names, // and hold arrays of pointers in name_ptrs, terminated by a nullptr entry. template -void CollectNames(const T& entries, std::vector* nonempty_names, +void CollectNames(const T& entries, std::vector* nonempty_names, std::vector* name_ptrs) { // First collect `nonempty_names`, to ensure the underlying strings won't // change out from under us. 
for (const auto& entry : entries) { - const string& name = entry.name(); + const std::string& name = entry.name(); if (!name.empty()) { nonempty_names->push_back(name); } @@ -90,7 +90,7 @@ void CollectNames(const T& entries, std::vector* nonempty_names, name_ptrs->reserve(entries.size() + 1); // +1 for nullptr array terminator size_t nonempty_index = 0; for (const auto& entry : entries) { - const string& name = entry.name(); + const std::string& name = entry.name(); if (!name.empty()) { name_ptrs->push_back(nonempty_names->at(nonempty_index).c_str()); ++nonempty_index; @@ -150,13 +150,18 @@ XlaJitCompiledCpuFunction::Compile( cpu_executable->buffer_assignment(); // Compute buffer infos and the result index, needed to run the raw function. - std::vector buffer_infos = - xla::cpu::CreateBufferInfosFromBufferAssignment(cpu_executable->module(), - buffer_assignment); - std::vector arg_index_table = - xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos); - std::vector result_index_table = - xla::cpu::CreateResultIndexTableFromBufferInfos(buffer_infos); + std::vector buffer_infos = + xla::cpu::CreateBufferAllocationInfos(cpu_executable->module(), + buffer_assignment); + + std::vector buffer_allocation_infos = + xla::cpu::CreateBufferAllocationInfos(cpu_executable->module(), + buffer_assignment); + + std::vector arg_index_table = + xla::cpu::CreateArgIndexTable(buffer_infos); + std::vector result_index_table = + xla::cpu::CreateResultIndexTable(buffer_infos); TF_ASSIGN_OR_RETURN(size_t result_index, ComputeResultIndex(buffer_assignment)); const int num_results = CountResults(buffer_infos); diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h index 8d142ffbe3254f..6f61f472a2fd5a 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h @@ -22,10 +22,11 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/log/check.h" +#include "tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function_thunks.h" +#include "xla/backends/cpu/buffer_allocation_info.h" #include "xla/client/local_client.h" -#include "xla/cpu_function_runtime.h" #include "xla/service/cpu/executable.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/platform/types.h" @@ -82,20 +83,20 @@ class XlaJitCompiledCpuFunction { XlaCompiledCpuFunction::StaticData static_data_; // The backing array for buffer infos. - std::vector buffer_infos_; + std::vector buffer_infos_; // The backing array for the arg index table. - std::vector arg_index_table_; + std::vector arg_index_table_; // The backing array for the result index table. - std::vector result_index_table_; + std::vector result_index_table_; // The backing arrays of arg and result names. We hold the actual strings in // nonempty_*_names_, and hold arrays of pointers in *_names_ for the static // data to refer to. - std::vector nonempty_arg_names_; - std::vector nonempty_variable_names_; - std::vector nonempty_result_names_; + std::vector nonempty_arg_names_; + std::vector nonempty_variable_names_; + std::vector nonempty_result_names_; std::vector arg_names_; std::vector variable_names_; std::vector result_names_; diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc index acac1efd73881f..b49e699d6e267f 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc @@ -182,18 +182,18 @@ TEST(XlaJitCompiledCpuFunction, Sum) { ASSERT_EQ(function.num_results(), 1); // Run the function and check results. 
- *static_cast(function.arg_data(0)) = 10; - *static_cast(function.arg_data(1)) = 32; + *static_cast(function.arg_data(0)) = 10; + *static_cast(function.arg_data(1)) = 32; EXPECT_TRUE(function.Run()); EXPECT_EQ(function.error_msg(), ""); - EXPECT_EQ(*static_cast(function.result_data(0)), 42); + EXPECT_EQ(*static_cast(function.result_data(0)), 42); // Run the function again. - *static_cast(function.arg_data(0)) = 100; - *static_cast(function.arg_data(1)) = 320; + *static_cast(function.arg_data(0)) = 100; + *static_cast(function.arg_data(1)) = 320; EXPECT_TRUE(function.Run()); EXPECT_EQ(function.error_msg(), ""); - EXPECT_EQ(*static_cast(function.result_data(0)), 420); + EXPECT_EQ(*static_cast(function.result_data(0)), 420); // Check name to index lookups. EXPECT_TRUE(function.HasNameIndices()); @@ -268,20 +268,20 @@ TEST(XlaJitCompiledCpuFunction, SumVariable) { ASSERT_EQ(function.num_results(), 2); // Run the function and check results. - *static_cast(function.arg_data(0)) = 10; - *static_cast(function.arg_data(1)) = 32; + *static_cast(function.arg_data(0)) = 10; + *static_cast(function.arg_data(1)) = 32; EXPECT_TRUE(function.Run()); EXPECT_EQ(function.error_msg(), ""); - EXPECT_EQ(*static_cast(function.result_data(0)), 10); - EXPECT_EQ(*static_cast(function.result_data(1)), 42); + EXPECT_EQ(*static_cast(function.result_data(0)), 10); + EXPECT_EQ(*static_cast(function.result_data(1)), 42); // Run the function again. - *static_cast(function.arg_data(0)) = 100; - *static_cast(function.arg_data(1)) = 320; + *static_cast(function.arg_data(0)) = 100; + *static_cast(function.arg_data(1)) = 320; EXPECT_TRUE(function.Run()); EXPECT_EQ(function.error_msg(), ""); - EXPECT_EQ(*static_cast(function.result_data(0)), 100); - EXPECT_EQ(*static_cast(function.result_data(1)), 420); + EXPECT_EQ(*static_cast(function.result_data(0)), 100); + EXPECT_EQ(*static_cast(function.result_data(1)), 420); // Check name to index lookups. 
EXPECT_TRUE(function.HasNameIndices()); @@ -325,7 +325,7 @@ TEST(XlaJitCompiledCpuFunction, CanCompileWithAdditionalPlatform) { int VisibleDeviceCount() const override { return 0; } - const string& Name() const override { return name_; } + const std::string& Name() const override { return name_; } absl::StatusOr> DescriptionForDevice( int ordinal) const override { @@ -338,7 +338,7 @@ TEST(XlaJitCompiledCpuFunction, CanCompileWithAdditionalPlatform) { } private: - string name_; + std::string name_; }; TF_EXPECT_OK( diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 4a570827029330..baefe0138d43dd 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -207,9 +207,9 @@ static absl::Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, return errors::InvalidArgument("value is not a scalar"); } if (literal.shape().element_type() == xla::S16) { - *out = literal.Get({}); + *out = literal.Get({}); } else if (literal.shape().element_type() == xla::S32) { - *out = literal.Get({}); + *out = literal.Get({}); } else if (literal.shape().element_type() == xla::S64) { *out = literal.Get({}); } else { @@ -370,7 +370,7 @@ static absl::Status LiteralToInt64Vector(const xla::LiteralSlice& literal, int64_t size = xla::ShapeUtil::ElementsIn(literal.shape()); if (literal.shape().element_type() == xla::S32) { for (int64_t i = 0; i < size; ++i) { - out->push_back(literal.Get({i})); + out->push_back(literal.Get({i})); } } else if (literal.shape().element_type() == xla::S64) { for (int64_t i = 0; i < size; ++i) { @@ -422,7 +422,7 @@ absl::Status XlaOpKernelContext::ConstantInputAsInt64Literal( case xla::S32: { *out = xla::Literal( xla::ShapeUtil::ChangeElementType(literal.shape(), xla::S64)); - auto src_data = literal.data(); + auto src_data = literal.data(); for (int64_t i = 0; i < src_data.size(); ++i) { out->data()[i] = src_data[i]; } @@ -677,7 +677,7 @@ 
xla::PrimitiveType XlaOpKernelContext::output_xla_type(int index) { return type; } -void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { +void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp handle) { SetOutputExpression( index, XlaExpression::XlaOp(handle, context_->expected_output_dtype(index))); @@ -688,7 +688,7 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { } void XlaOpKernelContext::SetTensorListOutput(int index, - const xla::XlaOp& handle) { + const xla::XlaOp handle) { SetOutputExpression(index, XlaExpression::TensorList(handle)); } @@ -811,7 +811,7 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul( const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) { const Tensor* tensor; - CHECK(context_->input(name, &tensor).ok()); + CHECK_OK(context_->input(name, &tensor)); return *tensor; } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index b0830d0766acb2..30de5a796d03a1 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -249,7 +249,7 @@ class XlaOpKernelContext { // Sets output `index` to the XlaOp `handle`. // All outputs should be set using SetOutput and SetConstantOutput, not // via the underlying OpKernelContext. - void SetOutput(int index, const xla::XlaOp& handle); + void SetOutput(int index, xla::XlaOp handle); // Sets output `index` to compile-time constant `host_tensor`, where // `host_tensor` is a tensor in host memory. It is preferable to use @@ -260,7 +260,7 @@ class XlaOpKernelContext { void SetOutputExpression(int index, const XlaExpression& expression); // Sets output `index` to the Tensor List `handle`. - void SetTensorListOutput(int index, const xla::XlaOp& handle); + void SetTensorListOutput(int index, xla::XlaOp handle); // Status handling. 
void SetStatus(const absl::Status& status) { context_->SetStatus(status); } @@ -341,27 +341,27 @@ class XlaOpKernelContext { // Gets an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::XlaComputation* GetOrCreateMax(const DataType type); + const xla::XlaComputation* GetOrCreateMax(DataType type); // Gets an XLA lambda to compute Min. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::XlaComputation* GetOrCreateMin(const DataType type); + const xla::XlaComputation* GetOrCreateMin(DataType type); // Gets an XLA lambda to compute Add. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::XlaComputation* GetOrCreateAdd(const DataType type); + const xla::XlaComputation* GetOrCreateAdd(DataType type); // Gets an XLA lambda to compute LogAddExp. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::XlaComputation* GetOrCreateLogAddExp(const DataType type); + const xla::XlaComputation* GetOrCreateLogAddExp(DataType type); // Gets an XLA lambda to compute Mul. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::XlaComputation* GetOrCreateMul(const DataType type); + const xla::XlaComputation* GetOrCreateMul(DataType type); // Returns stack trace encoded as a string at a given module, or an empty // string if none found. 
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 445065971f2a6a..c74db865769229 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -61,7 +61,7 @@ static absl::Status LaunchOpHasKernelForDevice(const DeviceType& device_type) { NodeDef node_def; node_def.set_name("_XlaLaunch-op"); node_def.set_op("XlaLaunch"); - string kernel_class_name; + std::string kernel_class_name; TF_RETURN_IF_ERROR(FindKernelDef(device_type, node_def, /*KernelDef*/ nullptr, &kernel_class_name)); VLOG(1) << "LaunchOpHasKernelForDevice" @@ -128,7 +128,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; } /* static */ void XlaOpRegistry::RegisterCompilationDevice( - const string& device_name, const DeviceRegistration& registration) { + const std::string& device_name, const DeviceRegistration& registration) { XlaOpRegistry& registry = Instance(); mutex_lock lock(registry.mutex_); auto result = @@ -138,7 +138,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; } /* static */ void XlaOpRegistry::RegisterBackend( - const string& compilation_device_name, + const std::string& compilation_device_name, absl::Span supported_types, BackendOpFilter op_filter) { XlaOpRegistry& registry = Instance(); mutex_lock lock(registry.mutex_); @@ -151,14 +151,14 @@ XlaOpRegistry::~XlaOpRegistry() = default; } /* static */ bool XlaOpRegistry::IsCompilationDevice( - const string& device_name) { + const std::string& device_name) { XlaOpRegistry& registry = Instance(); mutex_lock lock(registry.mutex_); return registry.backends_.find(device_name) != registry.backends_.end(); } /* static */ bool XlaOpRegistry::GetCompilationDevice( - const string& device_name, const DeviceRegistration** registration) { + const std::string& device_name, const DeviceRegistration** registration) { XlaOpRegistry& registry = Instance(); // Lazily register the CPU and GPU JIT devices the first time @@ -235,7 +235,7 @@ void 
XlaOpRegistry::RegisterCompilationKernels() { // 2. Process op registration without device allowlists: // this pass registers the kernels for all the other supported backends. for (auto& ops : registry.ops_) { - const string& op_name = ops.first; + const std::string& op_name = ops.first; std::vector>& op_registrations = ops.second; // Partition the op registration so that the ones with device allowlists // precede the one without device allowlist. @@ -247,7 +247,7 @@ void XlaOpRegistry::RegisterCompilationKernels() { // Collect a set of backend registered by ops with device allowlists. // The op registration without allowlists will register a generic kernel // for all other backends not in this set. - std::unordered_set allowlisted_backend; + std::unordered_set allowlisted_backend; for (auto& op_registration : op_registrations) { if (op_registration->has_device_allowlist) { allowlisted_backend.insert(op_registration->device_allowlist.begin(), @@ -267,7 +267,7 @@ void XlaOpRegistry::RegisterCompilationKernels() { } TF_CHECK_OK(lookup_status); - std::unordered_set type_attrs; + std::unordered_set type_attrs; for (const OpDef::AttrDef& attr_def : op_def->attr()) { if (attr_def.type() == "type" || attr_def.type() == "list(type)") { type_attrs.insert(attr_def.name()); @@ -309,7 +309,7 @@ void XlaOpRegistry::RegisterCompilationKernels() { // b) the types allowed by the OpDef, and // c) the type constraints. bool unsatisfiable_type_constraint = false; - for (const string& type_attr : type_attrs) { + for (const std::string& type_attr : type_attrs) { KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); attr_constraint->set_name(type_attr); auto* allowed_values = @@ -375,7 +375,7 @@ void XlaOpRegistry::RegisterCompilationKernels() { } std::vector XlaOpRegistry::DeviceKernels( - const string& compilation_device_name, + const std::string& compilation_device_name, bool include_compilation_only_kernels) { // Ensure compilation kernels registered. 
RegisterCompilationKernels(); @@ -403,8 +403,8 @@ std::vector XlaOpRegistry::DeviceKernels( return kernels; } -/*static*/ std::vector XlaOpRegistry::GetAllRegisteredOps() { - std::vector ops; +/*static*/ std::vector XlaOpRegistry::GetAllRegisteredOps() { + std::vector ops; XlaOpRegistry& registry = Instance(); mutex_lock lock(registry.mutex_); ops.reserve(registry.ops_.size()); @@ -416,7 +416,7 @@ std::vector XlaOpRegistry::DeviceKernels( } /*static*/ const std::unordered_set* -XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) { +XlaOpRegistry::CompileTimeConstantInputArgNames(const std::string& op) { XlaOpRegistry& registry = Instance(); mutex_lock lock(registry.mutex_); auto it = registry.ops_.find(op); @@ -435,10 +435,10 @@ XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) { DCHECK(op_def != nullptr || op_kernel != nullptr); - std::unordered_set compile_time_constant_inputs_from_attr; - std::vector compile_time_constant_inputs_vect_from_attr; + std::unordered_set compile_time_constant_inputs_from_attr; + std::vector compile_time_constant_inputs_vect_from_attr; - const std::unordered_set* compile_time_constant_inputs; + const std::unordered_set* compile_time_constant_inputs; if (TryGetNodeAttr(node_def, kXlaCompileTimeConstantInputsAttr, &compile_time_constant_inputs_vect_from_attr)) { @@ -459,7 +459,7 @@ XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) { << " required constants are: " << absl::StrJoin(*compile_time_constant_inputs, ", "); - for (const string& input : *compile_time_constant_inputs) { + for (const std::string& input : *compile_time_constant_inputs) { if (op_def) { NameRangeMap input_name_ranges; TF_RETURN_IF_ERROR( @@ -486,7 +486,7 @@ XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) { return absl::OkStatus(); } -/*static*/ bool XlaOpRegistry::IsMetadataOp(const string& op) { +/*static*/ bool XlaOpRegistry::IsMetadataOp(const std::string& op) { XlaOpRegistry& registry = Instance(); 
mutex_lock lock(registry.mutex_); auto it = registry.ops_.find(op); @@ -500,8 +500,8 @@ XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) { return it->second.front()->is_metadata_op; } -std::vector XlaOpRegistry::BackendNames() { - std::vector names; +std::vector XlaOpRegistry::BackendNames() { + std::vector names; XlaOpRegistry& registry = Instance(); mutex_lock lock(registry.mutex_); names.reserve(registry.backends_.size()); @@ -511,7 +511,7 @@ std::vector XlaOpRegistry::BackendNames() { return names; } -bool XlaOpRegistry::IsBackendRegistered(const string& name) { +bool XlaOpRegistry::IsBackendRegistered(const std::string& name) { XlaOpRegistry& registry = Instance(); mutex_lock lock(registry.mutex_); return registry.backends_.find(name) != registry.backends_.end(); @@ -524,7 +524,7 @@ XlaOpRegistry& XlaOpRegistry::Instance() { XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(absl::string_view name) { registration_.reset(new XlaOpRegistry::OpRegistration); - registration_->name = string(name); + registration_->name = std::string(name); } XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name( @@ -572,7 +572,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowStringType() { XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( absl::string_view attr_name, DataType allowed) { std::set& types = - registration_->type_constraints[string(attr_name)]; + registration_->type_constraints[std::string(attr_name)]; types.insert(allowed); return *this; } @@ -580,7 +580,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( absl::string_view attr_name, absl::Span allowed) { std::set& types = - registration_->type_constraints[string(attr_name)]; + registration_->type_constraints[std::string(attr_name)]; for (DataType t : allowed) { types.insert(t); } @@ -628,7 +628,7 @@ XlaBackendRegistrar::XlaBackendRegistrar( absl::string_view name, absl::Span 
types, XlaOpRegistry::BackendOpFilter op_filter) { XlaOpRegistry& registry = XlaOpRegistry::Instance(); - registry.RegisterBackend(string(name), types, op_filter); + registry.RegisterBackend(std::string(name), types, op_filter); AddSymbolicExecutionDevice(name); } diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 5eaf0fb2d42bfa..9ce6e263f8feb4 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -139,7 +139,7 @@ class XlaOpRegistry { // Describes how to compile operators assigned to a device. struct DeviceRegistration { // The name of the an XLA compilation device to use to compile code. - string compilation_device_name; + std::string compilation_device_name; // When should we autocluster operators assigned to this device? AutoclusteringPolicy autoclustering_policy; @@ -190,25 +190,25 @@ class XlaOpRegistry { // `backend_op_filter` should return true if the op should be registered on // the device; it may optionally modify the KernelDef. typedef bool (*BackendOpFilter)(KernelDef* kdef); - static void RegisterBackend(const string& compilation_device_name, + static void RegisterBackend(const std::string& compilation_device_name, absl::Span supported_types, BackendOpFilter op_filter); // Returns the names of the registered backends. - static std::vector BackendNames(); + static std::vector BackendNames(); // Returns true iff a backend with the given name is registered. - static bool IsBackendRegistered(const string& name); + static bool IsBackendRegistered(const std::string& name); // Registers `device_name` for XLA compilation, using information from // `registration`. // Does nothing if a registration for `device_name` already exists. 
- static void RegisterCompilationDevice(const string& device_name, + static void RegisterCompilationDevice(const std::string& device_name, const DeviceRegistration& registration); // Returns whether the device name is for the JIT device used exclusively for // TF2XLA conversion. - static bool IsCompilationDevice(const string& device_name); + static bool IsCompilationDevice(const std::string& device_name); // Returns the JIT device name associated with 'device_name', setting // 'jit_device_name', 'requires_jit', and 'enabled_jit_by_default', if they @@ -216,7 +216,7 @@ class XlaOpRegistry { // JIT device is registered. // '*enable_jit_by_default' is set to true if we should try to JIT using this // device when the JIT is enabled via the Session OptimizerOptions. - static bool GetCompilationDevice(const string& device_name, + static bool GetCompilationDevice(const std::string& device_name, const DeviceRegistration** registration); // Registers all JIT kernels on JIT devices, if not already registered. @@ -227,11 +227,11 @@ class XlaOpRegistry { // 'compilation_device_name'. Does not include kernels registered as // CompilationOnly, iff include_compilation_only_kernels=false. static std::vector DeviceKernels( - const string& compilation_device_name, + const std::string& compilation_device_name, bool include_compilation_only_kernels); // Returns all operations for which there are XLA kernels on any device. - static std::vector GetAllRegisteredOps(); + static std::vector GetAllRegisteredOps(); // Returns (via `result`) the indices of inputs to `node_def` that must be // compile-time constants. Returns an empty vector if the op is not @@ -265,11 +265,11 @@ class XlaOpRegistry { // Return names of arguments for a given op which are supposed to be // constants. 
static const std::unordered_set* - CompileTimeConstantInputArgNames(const string& op); + CompileTimeConstantInputArgNames(const std::string& op); // Returns true if `op` is a "metadata" op, one that only looks at the shapes // of its operands and not their values. - static bool IsMetadataOp(const string& op); + static bool IsMetadataOp(const std::string& op); private: friend class XlaBackendRegistrar; @@ -298,15 +298,15 @@ class XlaOpRegistry { }; // Map from compilation device names to a description of the backend. - std::unordered_map backends_ TF_GUARDED_BY(mutex_); + std::unordered_map backends_ TF_GUARDED_BY(mutex_); // Map from Tensorflow device names to the corresponding JIT device metadata. - std::unordered_map compilation_devices_ + std::unordered_map compilation_devices_ TF_GUARDED_BY(mutex_); // A description of a Tensorflow operator that can be compiled to XLA. struct OpRegistration { - string name; + std::string name; // Should this operator be registered only on compilation devices, without a // dummy kernel registered on the corresponding XLA device? @@ -325,15 +325,15 @@ class XlaOpRegistry { bool allow_string_type = false; // Mapping from attribute name to a list of supported types. - std::unordered_map> type_constraints; + std::unordered_map> type_constraints; // An optional allowlist of devices. If there is no allowlist, all devices // are permitted. bool has_device_allowlist = false; - std::unordered_set device_allowlist; + std::unordered_set device_allowlist; // Names of arguments that must be compile-time constants. - std::unordered_set compile_time_constant_inputs; + std::unordered_set compile_time_constant_inputs; // True if this is a "metadata" op, one that only looks at the shapes of its // operands and not their values. @@ -360,8 +360,8 @@ class XlaOpRegistry { // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP. 
// Registrations present under the same key must satisfy IsCompatible above, // and this is checked during registration. - std::unordered_map>> ops_ - TF_GUARDED_BY(mutex_); + std::unordered_map>> + ops_ TF_GUARDED_BY(mutex_); // Have we already registered the JIT kernels on the JIT devices? bool jit_kernels_registered_ = false; diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 5b894d07e121ba..962b0e473a826c 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -51,29 +51,29 @@ namespace tensorflow { } /*static*/ std::unique_ptr XlaResource::CreateStack( - string name, DataType type, int64_t max_size) { + std::string name, DataType type, int64_t max_size) { return std::make_unique( XlaResource::kStack, /*arg_num=*/-1, std::move(name), type, TensorShape(), /*initial_value=*/xla::XlaOp(), /*max_array_size=*/max_size, - /*tensor_array_gradients=*/std::set{}, + /*tensor_array_gradients=*/std::set{}, /*tensor_array_multiple_writes_aggregate=*/false); } /*static*/ std::unique_ptr XlaResource::CreateTensorArray( - string name, DataType type, TensorShape shape, xla::XlaOp initial_value, - int64_t max_array_size) { + std::string name, DataType type, TensorShape shape, + xla::XlaOp initial_value, int64_t max_array_size) { return std::make_unique( XlaResource::kTensorArray, /*arg_num=*/-1, std::move(name), type, shape, initial_value, max_array_size, - /*tensor_array_gradients=*/std::set{}, + /*tensor_array_gradients=*/std::set{}, /*tensor_array_multiple_writes_aggregate=*/false); } XlaResource::XlaResource( - Kind kind, int arg_num, string name, DataType type, TensorShape shape, + Kind kind, int arg_num, std::string name, DataType type, TensorShape shape, xla::XlaOp initial_value, int64_t max_array_size, - const std::set& tensor_array_gradients, + const std::set& tensor_array_gradients, bool tensor_array_multiple_writes_aggregate, const std::optional& 
definition_stack_trace) : kind_(kind), @@ -89,7 +89,7 @@ XlaResource::XlaResource( definition_stack_trace_(definition_stack_trace) { CHECK(kind_ != kInvalid); - for (const string& gradient : tensor_array_gradients) { + for (const std::string& gradient : tensor_array_gradients) { tensor_array_gradients_[gradient].reset(new XlaResource( /*kind=*/kTensorArray, /*arg_num=*/-1, /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_, @@ -163,7 +163,7 @@ absl::Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { value_ = xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()), - xla::ConstantR0(builder, 0)}); + xla::ConstantR0(builder, 0)}); break; } @@ -175,7 +175,7 @@ absl::Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { } absl::Status XlaResource::GetOrCreateTensorArrayGradient( - const string& source, xla::XlaBuilder* builder, + const std::string& source, xla::XlaBuilder* builder, XlaResource** gradient_out) { VLOG(2) << "Gradient lookup for resource: " << name_ << " gradient: " << source; @@ -214,9 +214,9 @@ absl::Status XlaResource::Pack(xla::XlaOp* pack, return absl::OkStatus(); } -absl::Status XlaResource::SetFromPack(const std::set& gradient_sources, - const xla::XlaOp pack, - xla::XlaBuilder* builder) { +absl::Status XlaResource::SetFromPack( + const std::set& gradient_sources, const xla::XlaOp pack, + xla::XlaBuilder* builder) { if (gradient_sources.empty()) { if (!initialized()) { initial_value_ = pack; diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index d4c8f7c1c9347f..07c826d21e8b3d 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -43,18 +43,19 @@ class XlaResource { static absl::string_view KindToString(Kind kind); // Creates a new Stack resource. 
- static std::unique_ptr CreateStack(string name, DataType type, + static std::unique_ptr CreateStack(std::string name, + DataType type, int64_t max_size); // Creates a new TensorArray resource. static std::unique_ptr CreateTensorArray( - string name, DataType type, TensorShape shape, xla::XlaOp initial_value, - int64_t max_array_size); + std::string name, DataType type, TensorShape shape, + xla::XlaOp initial_value, int64_t max_array_size); - XlaResource(Kind kind, int arg_num, string name, DataType type, + XlaResource(Kind kind, int arg_num, std::string name, DataType type, TensorShape shape, xla::XlaOp initial_value, int64_t max_array_size, - const std::set& tensor_array_gradients, + const std::set& tensor_array_gradients, bool tensor_array_multiple_writes_aggregate, const std::optional& definition_stack_trace = std::nullopt); @@ -72,7 +73,7 @@ class XlaResource { int arg_num() const { return arg_num_; } // A descriptive name for the resource, used in error messages. - const string& name() const { return name_; } + const std::string& name() const { return name_; } // Current type and value of the resource. Uninitialized resources are // represented by a default (zero) handle and type DT_INVALID. @@ -121,7 +122,7 @@ class XlaResource { // exist. The call target must be an initialized TensorArray resource. A // TensorArray can have multiple named gradients; see the operator // documentation for TensorArrayGradV3 for details. - absl::Status GetOrCreateTensorArrayGradient(const string& source, + absl::Status GetOrCreateTensorArrayGradient(const std::string& source, xla::XlaBuilder* builder, XlaResource** gradient_out); @@ -138,7 +139,7 @@ class XlaResource { // If `reset_initial_values` is true, sets the initial_values as well as the // values. // Opposite of Pack(). 
- absl::Status SetFromPack(const std::set& gradient_sources, + absl::Status SetFromPack(const std::set& gradient_sources, xla::XlaOp pack, xla::XlaBuilder* builder); bool IsOverwritten() { return is_overwritten_; } @@ -164,15 +165,15 @@ class XlaResource { // string, irrespective of the number of calls to TensorArrayGrad. The map // is ordered since values are packed into tuples by Pack() sorted by name // order. - const std::map>& tensor_array_gradients() - const { + const std::map>& + tensor_array_gradients() const { return tensor_array_gradients_; } private: const Kind kind_; const int arg_num_; - const string name_; + const std::string name_; DataType type_; TensorShape shape_; @@ -186,7 +187,7 @@ class XlaResource { int64_t max_array_size_ = -1; bool tensor_array_multiple_writes_aggregate_ = false; - std::map> tensor_array_gradients_; + std::map> tensor_array_gradients_; bool is_overwritten_ = false; std::optional definition_stack_trace_; diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 29aea47709cd6d..dd8da3665c5294 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -493,7 +493,6 @@ cc_library( "@local_tsl//tsl/platform:framework_lite_hdrs", "@local_xla//xla/tsl/framework:numeric_types.h", "@local_xla//xla/tsl/framework:type_traits.h", - "@local_xla//xla/tsl/platform/default:integral_types.h", ], visibility = ["//visibility:public"], deps = [ @@ -1537,7 +1536,6 @@ cc_library( hdrs = [ "//tensorflow/core/platform:tflite_portable_logging_hdrs", "@local_tsl//tsl/platform:tflite_portable_logging_hdrs", - "@local_xla//xla/tsl/platform/default:integral_types.h", ], compatible_with = get_compatible_with_portable(), copts = tf_copts(), @@ -1938,21 +1936,15 @@ tf_cc_tests( ) tf_cc_tests( - name = "cell_reader_test", + name = "test_utils_test", size = "small", srcs = [ - "//tensorflow/core/lib/monitoring:cell_reader_test.cc", "//tensorflow/core/lib/monitoring:test_utils_test.cc", ], deps = [ ":protos_all_cc", ":test", ":test_main", - 
"//tensorflow/core/lib/monitoring:cell_reader", - "//tensorflow/core/lib/monitoring:counter", - "//tensorflow/core/lib/monitoring:gauge", - "//tensorflow/core/lib/monitoring:percentile_sampler", - "//tensorflow/core/lib/monitoring:sampler", "//tensorflow/core/lib/monitoring:test_utils", "//tensorflow/core/lib/monitoring:types", "//tensorflow/core/platform:errors", diff --git a/tensorflow/core/activity_watcher/activity.h b/tensorflow/core/activity_watcher/activity.h index eecd207a33fe27..fba51b43f8a3ce 100644 --- a/tensorflow/core/activity_watcher/activity.h +++ b/tensorflow/core/activity_watcher/activity.h @@ -32,7 +32,7 @@ namespace tensorflow { namespace activity_watcher { -using ActivityId = tsl::uint64; +using ActivityId = uint64_t; constexpr ActivityId kActivityNotRecorded = 0; constexpr int kWatcherDisabled = 0; @@ -45,7 +45,7 @@ enum ActivityCategory { kRendezvous = 5, }; -static tsl::string ToString(ActivityCategory category) { +static std::string ToString(ActivityCategory category) { switch (category) { case ActivityCategory::kCollective: return "Collective"; @@ -64,17 +64,17 @@ static tsl::string ToString(ActivityCategory category) { // An activity to be recorded. struct Activity { - using Attributes = absl::flat_hash_map; + using Attributes = absl::flat_hash_map; // A human readable title of the activity. - tsl::string title; + std::string title; // The category of the activity. ActivityCategory category = ActivityCategory::kMisc; // Key/value pairs that are attached to the activity. 
Attributes attributes; Activity() = default; - Activity(tsl::string title, ActivityCategory category) + Activity(std::string title, ActivityCategory category) : title(std::move(title)), category(category) {} - Activity(tsl::string title, ActivityCategory category, Attributes attributes) + Activity(std::string title, ActivityCategory category, Attributes attributes) : title(std::move(title)), category(category), attributes(std::move(attributes)) {} diff --git a/tensorflow/core/activity_watcher/activity_utils.cc b/tensorflow/core/activity_watcher/activity_utils.cc index b3631076c5c2d9..58b3909a25789c 100644 --- a/tensorflow/core/activity_watcher/activity_utils.cc +++ b/tensorflow/core/activity_watcher/activity_utils.cc @@ -28,7 +28,7 @@ namespace tensorflow { namespace activity_watcher { std::unique_ptr ActivityFromContext( - OpKernelContext* context, tsl::string name, ActivityCategory category, + OpKernelContext* context, std::string name, ActivityCategory category, Activity::Attributes additional_attributes) { Activity::Attributes attributes(std::move(additional_attributes)); if (context) { diff --git a/tensorflow/core/activity_watcher/activity_utils.h b/tensorflow/core/activity_watcher/activity_utils.h index 64958cd5e09744..749ef1326ae565 100644 --- a/tensorflow/core/activity_watcher/activity_utils.h +++ b/tensorflow/core/activity_watcher/activity_utils.h @@ -29,7 +29,7 @@ namespace activity_watcher { // A convenient way to create an activity. Writes OpKernelContext information // and given attributes to a new activity and returns. 
std::unique_ptr ActivityFromContext( - OpKernelContext* context, tsl::string name, ActivityCategory category, + OpKernelContext* context, std::string name, ActivityCategory category, Activity::Attributes additional_attributes = Activity::Attributes()); } // namespace activity_watcher diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD index 76b8cc01324619..caf20c11b93566 100644 --- a/tensorflow/core/api_def/BUILD +++ b/tensorflow/core/api_def/BUILD @@ -65,6 +65,7 @@ cc_library( "//tensorflow/core:op_gen_lib", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings:str_format", ], ) diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index 7f844e88ba90c6..3c954cf076ddc8 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -43,26 +43,27 @@ namespace { constexpr char kApiDefFilePattern[] = "api_def_*.pbtxt"; -string DefaultApiDefDir() { +std::string DefaultApiDefDir() { return GetDataDependencyFilepath( io::JoinPath("tensorflow", "core", "api_def", "base_api")); } -string PythonApiDefDir() { +std::string PythonApiDefDir() { return GetDataDependencyFilepath( io::JoinPath("tensorflow", "core", "api_def", "python_api")); } // Reads golden ApiDef files and returns a map from file name to ApiDef file // contents. 
-void GetGoldenApiDefs(Env* env, const string& api_files_dir, - std::unordered_map* name_to_api_def) { - std::vector matching_paths; +void GetGoldenApiDefs( + Env* env, const std::string& api_files_dir, + std::unordered_map* name_to_api_def) { + std::vector matching_paths; TF_CHECK_OK(env->GetMatchingPaths( io::JoinPath(api_files_dir, kApiDefFilePattern), &matching_paths)); for (auto& file_path : matching_paths) { - string file_contents; + std::string file_contents; TF_CHECK_OK(ReadFileToString(env, file_path, &file_contents)); file_contents = PBTxtFromMultiline(file_contents); @@ -76,8 +77,9 @@ void GetGoldenApiDefs(Env* env, const string& api_files_dir, } void TestAllApiDefsHaveCorrespondingOp( - const OpList& ops, const std::unordered_map& api_defs_map) { - std::unordered_set op_names; + const OpList& ops, + const std::unordered_map& api_defs_map) { + std::unordered_set op_names; for (const auto& op : ops.op()) { op_names.insert(op.name()); } @@ -89,7 +91,8 @@ void TestAllApiDefsHaveCorrespondingOp( } void TestAllApiDefInputArgsAreValid( - const OpList& ops, const std::unordered_map& api_defs_map) { + const OpList& ops, + const std::unordered_map& api_defs_map) { for (const auto& op : ops.op()) { const auto api_def_iter = api_defs_map.find(op.name()); if (api_def_iter == api_defs_map.end()) { @@ -113,7 +116,8 @@ void TestAllApiDefInputArgsAreValid( } void TestAllApiDefOutputArgsAreValid( - const OpList& ops, const std::unordered_map& api_defs_map) { + const OpList& ops, + const std::unordered_map& api_defs_map) { for (const auto& op : ops.op()) { const auto api_def_iter = api_defs_map.find(op.name()); if (api_def_iter == api_defs_map.end()) { @@ -137,7 +141,8 @@ void TestAllApiDefOutputArgsAreValid( } void TestAllApiDefAttributeNamesAreValid( - const OpList& ops, const std::unordered_map& api_defs_map) { + const OpList& ops, + const std::unordered_map& api_defs_map) { for (const auto& op : ops.op()) { const auto api_def_iter = api_defs_map.find(op.name()); if 
(api_def_iter == api_defs_map.end()) { @@ -159,7 +164,7 @@ void TestAllApiDefAttributeNamesAreValid( } void TestDeprecatedAttributesSetCorrectly( - const std::unordered_map& api_defs_map) { + const std::unordered_map& api_defs_map) { for (const auto& name_and_api_def : api_defs_map) { int num_deprecated_endpoints = 0; const auto& api_def = name_and_api_def.second; @@ -186,7 +191,7 @@ void TestDeprecatedAttributesSetCorrectly( } void TestDeprecationVersionSetCorrectly( - const std::unordered_map& api_defs_map) { + const std::unordered_map& api_defs_map) { for (const auto& name_and_api_def : api_defs_map) { const auto& name = name_and_api_def.first; const auto& api_def = name_and_api_def.second; @@ -205,13 +210,13 @@ class BaseApiTest : public ::testing::Test { protected: BaseApiTest() { OpRegistry::Global()->Export(false, &ops_); - const std::vector multi_line_fields = {"description"}; + const std::vector multi_line_fields = {"description"}; Env* env = Env::Default(); GetGoldenApiDefs(env, DefaultApiDefDir(), &api_defs_map_); } OpList ops_; - std::unordered_map api_defs_map_; + std::unordered_map api_defs_map_; }; // Check that all ops have an ApiDef. @@ -233,7 +238,7 @@ TEST_F(BaseApiTest, AllApiDefsHaveCorrespondingOp) { TestAllApiDefsHaveCorrespondingOp(ops_, api_defs_map_); } -string GetOpDefHasDocStringError(const string& op_name) { +std::string GetOpDefHasDocStringError(const std::string& op_name) { return strings::Printf( "OpDef for %s has a doc string. " "Doc strings must be defined in ApiDef instead of OpDef. 
" @@ -301,13 +306,13 @@ class PythonApiTest : public ::testing::Test { protected: PythonApiTest() { OpRegistry::Global()->Export(false, &ops_); - const std::vector multi_line_fields = {"description"}; + const std::vector multi_line_fields = {"description"}; Env* env = Env::Default(); GetGoldenApiDefs(env, PythonApiDefDir(), &api_defs_map_); } OpList ops_; - std::unordered_map api_defs_map_; + std::unordered_map api_defs_map_; }; // Check that ApiDefs have a corresponding op. diff --git a/tensorflow/core/api_def/base_api/api_def_ComplexAbs.pbtxt b/tensorflow/core/api_def/base_api/api_def_ComplexAbs.pbtxt index 7c4db1f721a032..41868ddc6c649f 100644 --- a/tensorflow/core/api_def/base_api/api_def_ComplexAbs.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ComplexAbs.pbtxt @@ -1,5 +1,12 @@ op { graph_op_name: "ComplexAbs" + attr { + name: "Tout" + description: <