Tensor Read/Write Handles #30
base: main
Conversation
Pull request overview
This PR introduces explicit read and write tensor handle APIs so clients can acquire and release IPC-based access to tensor shards with server-side locking, and refactors tensor storage on the NodeAgent to use a concurrent TensorShard map. It also wires new request/response message types through the messaging layer and exposes the new capabilities to Python, with updated end-to-end tests validating read/write and concurrent-read behavior.
Changes:
- Add protocol-level `GetReadHandle`/`ReleaseReadHandle` and `GetWriteHandle`/`ReleaseWriteHandle` messages plus corresponding client-side methods and Python bindings, including per-client UUIDs to identify lock ownership (a hedged usage sketch follows this list).
- Refactor `NodeAgent` to store allocated tensors as `TensorShard` objects in a concurrent map and to track active read/write locks via `TensorShardReadHandle`/`TensorShardWriteHandle` per client and shard.
- Update `test_client_node_agent` GPU tests to use the new handle APIs, add coverage for write handles and multiple concurrent read handles, and adjust distributed allocation tests to verify via read handles.
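As a rough illustration of the intended call sequence (not code from this PR; the `setu::Client` namespace, the include path, and the PascalCase method names below are assumptions mirroring the Python bindings `get_read_handle`, `release_read_handle`, `get_write_handle`, `release_write_handle`), a client would acquire and release handles roughly like this:

```cpp
// Hypothetical usage sketch; see csrc/setu/client/Client.h for the real API.
// Client, ShardId, and the Get*/Release* method names are assumptions here.
#include "setu/client/Client.h"

void ExampleAccess(setu::Client& client, const ShardId& shard_id) {
  // Shared access: several clients may hold read handles on the same shard.
  auto read_spec = client.GetReadHandle(shard_id);  // IPC spec for mapping
  // ... map the shard through IPC and read from it ...
  client.ReleaseReadHandle(shard_id);

  // Exclusive access: the server-side write lock admits one writer at a time.
  auto write_spec = client.GetWriteHandle(shard_id);
  // ... mutate the shard ...
  client.ReleaseWriteHandle(shard_id);
}
```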
Reviewed changes
Copilot reviewed 28 out of 28 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `test/test_client_node_agent.py` | Updates existing tests to use `get_read_handle` and adds new GPU tests for write handles, concurrent read handles, and updated distributed allocation checks. |
| `csrc/setu/node_manager/NodeAgent.h` | Declares new read/write handle request handlers, introduces a concurrent `TensorShardsConcurrentMap`, and adds per-client read/write lock tracking maps within the `Handler`. |
| `csrc/setu/node_manager/NodeAgent.cpp` | Implements handling for Get/Release read/write handle requests using `TensorShard` and lock handles, updates the client message dispatch, and changes tensor allocation to store `TensorShard` objects in the concurrent map. |
| `csrc/setu/commons/messages/ReleaseWriteHandleResponse.h` | Defines the `ReleaseWriteHandleResponse` message struct, including serialization, deserialization, and a string representation. |
| `csrc/setu/commons/messages/ReleaseWriteHandleResponse.cpp` | Implements binary serialization and deserialization logic for `ReleaseWriteHandleResponse`. |
| `csrc/setu/commons/messages/ReleaseWriteHandleRequest.h` | Introduces `ReleaseWriteHandleRequest`, carrying `ClientId` and `ShardId` with validation and string formatting helpers. |
| `csrc/setu/commons/messages/ReleaseWriteHandleRequest.cpp` | Implements serialization and deserialization for `ReleaseWriteHandleRequest`. |
| `csrc/setu/commons/messages/ReleaseReadHandleResponse.h` | Defines the `ReleaseReadHandleResponse` struct for acknowledging read handle release. |
| `csrc/setu/commons/messages/ReleaseReadHandleResponse.cpp` | Implements serialization and deserialization for `ReleaseReadHandleResponse`. |
| `csrc/setu/commons/messages/ReleaseReadHandleRequest.h` | Adds `ReleaseReadHandleRequest` to request releasing a read handle, including `ClientId` and `ShardId` fields with validation. |
| `csrc/setu/commons/messages/ReleaseReadHandleRequest.cpp` | Implements serialization and deserialization for `ReleaseReadHandleRequest`. |
| `csrc/setu/commons/messages/Messages.h` | Replaces `GetTensorHandle*` in the unified message variants with the new Get/Release read/write handle messages and includes their headers, so dispatch and serialization work with the new API. |
| `csrc/setu/commons/messages/GetWriteHandleResponse.h` | Introduces `GetWriteHandleResponse`, carrying an optional `TensorIPCSpec` for write access and associated metadata. |
| `csrc/setu/commons/messages/GetWriteHandleResponse.cpp` | Implements serialization/deserialization of `GetWriteHandleResponse`, including the optional IPC spec. |
| `csrc/setu/commons/messages/GetWriteHandleRequest.h` | Defines `GetWriteHandleRequest`, including the `ClientId` and `ShardId` used to request a write handle. |
| `csrc/setu/commons/messages/GetWriteHandleRequest.cpp` | Implements serialization and deserialization for `GetWriteHandleRequest`. |
| `csrc/setu/commons/messages/GetReadHandleResponse.h` | Adds `GetReadHandleResponse`, paralleling the write variant but for read access. |
| `csrc/setu/commons/messages/GetReadHandleResponse.cpp` | Implements serialization/deserialization logic for `GetReadHandleResponse`. |
| `csrc/setu/commons/messages/GetReadHandleRequest.h` | Introduces `GetReadHandleRequest`, holding `ClientId` and `ShardId` for requesting a read handle. |
| `csrc/setu/commons/messages/GetReadHandleRequest.cpp` | Implements serialization/deserialization for `GetReadHandleRequest`. |
| `csrc/setu/commons/datatypes/TensorShardHandle.h` | Adjusts `TensorShardReadHandle`/`TensorShardWriteHandle` to obtain device pointers via `TensorShard::GetDevicePtr` and adds `GetTensor()` accessors to expose the underlying tensor safely. |
| `csrc/setu/commons/datatypes/TensorShard.h` | Refactors `TensorShard` to own a `torch::Tensor` instead of a raw device pointer, adds accessors for the tensor and its device pointer, and defines `TensorShardsConcurrentMap` using a concurrent map type. |
| `csrc/setu/commons/datatypes/Pybind.cpp` | Updates the Python binding for `TensorShard` to construct from a `torch::Tensor` and exposes `get_device_ptr`/`get_tensor` methods instead of a raw `device_ptr` field. |
| `csrc/setu/commons/Types.h` | Introduces `ClientId` as a UUID alias to identify clients in protocol messages and lock tracking. |
| `csrc/setu/commons/BoostCommon.h` | Adds an alias `ConcurrentMap` intended to wrap Boost's concurrent flat map and provides UUID helpers used to generate per-client IDs. |
| `csrc/setu/client/Pybind.cpp` | Exposes the new `Client` methods (`get_read_handle`, `release_read_handle`, `get_write_handle`, `release_write_handle`, `get_client_id`) to Python and updates the exported error code enum. |
| `csrc/setu/client/Client.h` | Extends the client API with read/write handle acquisition and release methods plus a `GetClientId()` accessor, and introduces a `ClientId` member. |
| `csrc/setu/client/Client.cpp` | Implements the new client methods by sending/receiving the new handle messages, assigns a UUID-based client identity used on the ZMQ socket, and logs handle operations for debugging. |
```cpp
  /// Active read locks: shard_id -> (client_id -> read handle)
  /// Handle existence keeps the shared_lock alive
  std::unordered_map<ClientId, ClientReadLockMap> active_read_locks_;

  /// Active write locks: shard_id -> (client_id -> write handle)
```
Copilot AI · Feb 3, 2026
active_read_locks_ and active_write_locks_ are only cleaned up when a client explicitly sends the corresponding Release*Handle request; if a client process crashes or loses connectivity while holding a lock, the associated handle entry (and therefore the shared/unique lock) will remain held for the lifetime of the NodeAgent process, potentially blocking future readers/writers on that shard. Consider adding a cleanup mechanism (e.g., on client disconnect or via lease/timeout) so that failed clients cannot hold locks indefinitely.
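One way to bound this, sketched below with standard-library types only (the `LeasedHandle` struct, the string-based `ClientId`/`ShardId` placeholders, and the reaping function are hypothetical, not part of the PR), is to stamp each granted handle with a lease deadline and let the handler loop periodically drop expired entries, which destroys the handle and releases its lock:

```cpp
// Hypothetical lease-based cleanup for orphaned read/write handles.
#include <chrono>
#include <string>
#include <unordered_map>

using ClientId = std::string;  // placeholder for the project's UUID alias
using ShardId = std::string;   // placeholder for the project's UUID alias
using Clock = std::chrono::steady_clock;

struct LeasedHandle {
  // The real maps would store TensorShardReadHandlePtr / WriteHandlePtr here.
  Clock::time_point expires_at;
};

using ClientLeaseMap = std::unordered_map<ShardId, LeasedHandle>;

// Called periodically from the handler loop (or when a heartbeat is missed):
// erasing an expired entry destroys the handle, releasing the underlying lock.
inline void ReapExpiredLeases(
    std::unordered_map<ClientId, ClientLeaseMap>& active_locks) {
  const auto now = Clock::now();
  for (auto& [client, leases] : active_locks) {
    std::erase_if(leases, [now](const auto& kv) {
      return kv.second.expires_at <= now;
    });
  }
}
```

Clients holding a lease would need to renew it before the deadline; the renewal cadence is a policy choice left open here.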
```cpp
  using PriorityQueue = ::boost::concurrent::sync_priority_queue<Args...>;

  template <typename Key, typename Value, typename Hash = boost::hash<Key>,
            typename Pred = std::equal_to<Key>>
  using ConcurrentMap = ::boost::concurrent_flat_map<Key, Value, Hash, Pred>;
```
Copilot AI · Feb 3, 2026
ConcurrentMap is aliased to ::boost::concurrent_flat_map, but the header you include (<boost/unordered/concurrent_flat_map.hpp>) defines boost::unordered::concurrent_flat_map, so this alias will not compile. Please update the alias to reference the correct boost::unordered::concurrent_flat_map type so that TensorShardsConcurrentMap and other concurrent maps build correctly.
Suggested change:

```diff
-using ConcurrentMap = ::boost::concurrent_flat_map<Key, Value, Hash, Pred>;
+using ConcurrentMap =
+    ::boost::unordered::concurrent_flat_map<Key, Value, Hash, Pred>;
```
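For reference, a minimal standalone sketch of the corrected alias in use (the string key and `int` payload below are placeholders, not the project's shard types); `boost::unordered::concurrent_flat_map` exposes no iterators, so elements are reached through `visit()`/`cvisit()`:

```cpp
#include <boost/container_hash/hash.hpp>
#include <boost/unordered/concurrent_flat_map.hpp>
#include <iostream>
#include <string>

template <typename Key, typename Value, typename Hash = boost::hash<Key>,
          typename Pred = std::equal_to<Key>>
using ConcurrentMap =
    ::boost::unordered::concurrent_flat_map<Key, Value, Hash, Pred>;

int main() {
  // Placeholder for TensorShardsConcurrentMap: shard id -> payload.
  ConcurrentMap<std::string, int> shards;
  shards.insert({"shard-0", 42});

  // Mutate an entry in place under the map's internal synchronization.
  shards.visit("shard-0", [](auto& kv) { kv.second += 1; });

  // Read-only access to the same entry.
  shards.cvisit("shard-0", [](const auto& kv) {
    std::cout << kv.first << " -> " << kv.second << "\n";
  });
  return 0;
}
```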
Copilot AI · Feb 3, 2026
The documentation comments for the lock-tracking maps describe the layout as shard_id -> (client_id -> handle), but the actual type is std::unordered_map<ClientId, ClientReadLockMap/ClientWriteLockMap>, i.e., client_id -> (shard_id -> handle). Please update the comments to match the data structure to avoid confusion for future maintainers reasoning about lock ownership.
Suggested change:

```diff
-  /// Active read locks: shard_id -> (client_id -> read handle)
+  /// Active read locks: client_id -> (shard_id -> read handle)
   /// Handle existence keeps the shared_lock alive
   std::unordered_map<ClientId, ClientReadLockMap> active_read_locks_;
-  /// Active write locks: shard_id -> (client_id -> write handle)
+  /// Active write locks: client_id -> (shard_id -> write handle)
```
```diff
-  TensorShard(TensorShardMetadata metadata_param, DevicePtr device_ptr_param)
-      : metadata(std::move(metadata_param)), device_ptr(device_ptr_param) {
-    ASSERT_VALID_POINTER_ARGUMENT(device_ptr_param);
+  TensorShard(TensorShardMetadata metadata_param, torch::Tensor tensor_param)
```
We must template this over the tensor type. We mustn't hardcode it to take torch::Tensor.
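A possible shape for that, sketched standalone (the `TensorShardMetadata` stand-in and the `data_ptr()` assumption on the tensor type are hypothetical, and the real class may keep more of its current interface):

```cpp
#include <utility>

struct TensorShardMetadata {};  // stand-in for the real metadata type

// Templating the shard over its tensor type keeps torch::Tensor out of the
// core datatype; the framework-specific instantiation lives at the call site.
template <typename TensorT>
class TensorShard {
 public:
  TensorShard(TensorShardMetadata metadata_param, TensorT tensor_param)
      : metadata_(std::move(metadata_param)),
        tensor_(std::move(tensor_param)) {}

  const TensorT& GetTensor() const { return tensor_; }

  // Assumes TensorT exposes data_ptr(), as torch::Tensor does; other tensor
  // types could be adapted through a traits specialization instead.
  void* GetDevicePtr() const { return tensor_.data_ptr(); }

 private:
  TensorShardMetadata metadata_;
  TensorT tensor_;
};

// e.g. using TorchTensorShard = TensorShard<torch::Tensor>;
```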
```cpp
  using ClientWriteLockMap =
      std::unordered_map<ShardId, TensorShardWriteHandlePtr>;

  /// Active read locks: shard_id -> (client_id -> read handle)
```
We have to spend more time fleshing out this design. The biggest issue I see is that attempting to acquire even one write handle can deadlock the system: the write lock blocks without yielding control back to the handler loop, which prevents other clients from releasing the read handles that are blocking the write from proceeding.
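One way around that, purely as a sketch (the retry queue, the `PendingWrite` struct, and the string-based `ShardId` placeholder are not from this PR), is to never block the handler thread: attempt the lock with `try_lock`, and on failure park the request and retry after the loop has processed further messages, including any pending `ReleaseReadHandle` requests:

```cpp
#include <deque>
#include <optional>
#include <shared_mutex>
#include <string>

using ShardId = std::string;  // placeholder for the project's UUID alias

struct PendingWrite {
  ShardId shard_id;
  // ... the original request and its reply route would be carried here ...
};

// Non-blocking acquisition: returns the lock on success, nothing otherwise.
inline std::optional<std::unique_lock<std::shared_mutex>> TryAcquireWrite(
    std::shared_mutex& shard_mutex) {
  std::unique_lock<std::shared_mutex> lock(shard_mutex, std::try_to_lock);
  if (!lock.owns_lock()) {
    return std::nullopt;
  }
  return lock;
}

inline void HandleWriteRequest(std::shared_mutex& shard_mutex,
                               PendingWrite request,
                               std::deque<PendingWrite>& retry_queue) {
  if (auto lock = TryAcquireWrite(shard_mutex)) {
    // Success: move the lock into a write handle and reply with the IPC spec
    // (omitted here).
  } else {
    // Readers still hold the shard; revisit on a later loop iteration so the
    // handler keeps draining release requests instead of deadlocking.
    retry_queue.push_back(std::move(request));
  }
}
```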
This patch allows clients to get and release tensor read and write handles.