Client Free Shards #28
Conversation
Pull request overview
Adds shard lifecycle management so client-owned shards can be freed (both locally on NodeAgent and in MetaStore), introduces shard-registration validation (dtype/dim/overlap checks), and expands Python/C++ test coverage for these behaviors.
Changes:
- Add `FreeShardsRequest`/`FreeShardsResponse` message flow: Client → NodeAgent → Coordinator → MetaStore (plus NodeAgent local tensor cleanup).
- Add MetaStore shard validation (dtype/dim consistency, overlap prevention) and `FreeShard` support.
- Update Client shard tracking to group shards by tensor name and expose updated Python bindings; add new Python + C++ tests.
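In rough Python terms, the request/response pair added by this PR can be sketched as two small message objects. The field names (`shard_ids`, `request_id`, `error_code`) mirror the C++ diffs reviewed in this PR, but the defaults and wire format here are assumptions for illustration only:

```python
import uuid
from dataclasses import dataclass, field

# Hypothetical sketch of the C++ FreeShardsRequest/FreeShardsResponse pair;
# not the actual serialization format used by setu.
@dataclass
class FreeShardsRequest:
    shard_ids: list  # shards the client wants released
    request_id: str = field(default_factory=lambda: uuid.uuid4().hex)

@dataclass
class FreeShardsResponse:
    request_id: str  # echoes the originating request's ID
    error_code: int  # 0 = success, nonzero = failure

# A response is correlated with its request via request_id.
req = FreeShardsRequest(shard_ids=["shard-0", "shard-1"])
resp = FreeShardsResponse(request_id=req.request_id, error_code=0)
assert resp.request_id == req.request_id
```

The same correlation-by-request-ID pattern appears in the NodeAgent's `request_id_to_client_identity_` map discussed in the review below.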
Reviewed changes
Copilot reviewed 19 out of 19 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| test/test_client_node_agent.py | Adds multi-node validation tests and client-destruction shard-free tests; adjusts fixture scope. |
| csrc/test/native/MetaStoreTest.cpp | Adds C++ unit tests for shard validation and FreeShard behavior. |
| csrc/setu/node_manager/NodeAgent.h | Declares FreeShards request/response handlers and local free helper. |
| csrc/setu/node_manager/NodeAgent.cpp | Implements FreeShards handling and local tensor/metadata cleanup. |
| csrc/setu/metastore/MetaStore.h | Adds FreeShard API and validation helper; renames internal shard tracking struct/map. |
| csrc/setu/metastore/MetaStore.cpp | Implements registration validation and FreeShard; updates internal bookkeeping. |
| csrc/setu/coordinator/Coordinator.h | Adds handler declaration for FreeShardsRequest. |
| csrc/setu/coordinator/Coordinator.cpp | Wires FreeShardsRequest handling and registration failure propagation. |
| csrc/setu/commons/messages/Pybind.cpp | Exposes FreeShards request/response to Python. |
| csrc/setu/commons/messages/Messages.h | Extends message variants to include FreeShards request/response. |
| csrc/setu/commons/messages/FreeShardsRequest.h | Introduces FreeShardsRequest message definition. |
| csrc/setu/commons/messages/FreeShardsRequest.cpp | Implements FreeShardsRequest serialization/deserialization. |
| csrc/setu/commons/messages/FreeShardsResponse.h | Introduces FreeShardsResponse message definition. |
| csrc/setu/commons/messages/FreeShardsResponse.cpp | Implements FreeShardsResponse serialization/deserialization. |
| csrc/setu/commons/datatypes/TensorShardSpec.h | Adds Overlaps() API for shard-range overlap checks. |
| csrc/setu/commons/datatypes/TensorShardSpec.cpp | Implements Overlaps() overlap logic. |
| csrc/setu/client/Pybind.cpp | Updates Python get_shards() binding to flatten per-tensor shard map. |
| csrc/setu/client/Client.h | Changes shard storage to unordered_map<TensorName, vector<...>>; adds FreeShards() helper and new shard accessors. |
| csrc/setu/client/Client.cpp | Tracks shards per tensor and frees shards on destruction via new FreeShards request. |
Comments suppressed due to low confidence (1)
test/test_client_node_agent.py:374
`multi_node_infrastructure` starts a NodeAgent pinned to `cuda:1`, but only checks `torch.cuda.is_available()`. On single-GPU systems this will fail. Consider skipping when `torch.cuda.device_count() < 2` (or choosing device indices based on available GPUs).
```python
@pytest.fixture(scope="module")
def multi_node_infrastructure():
    """Start Coordinator and two NodeAgents for distributed tensor tests."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
```
```cpp
}

ErrorCode error_code =
    all_freed ? ErrorCode::kSuccess : ErrorCode::kInvalidArguments;
```
Curious how we'll handle partial frees. I don't think it'll impact much right now, but maybe it's better to communicate back information about the shards that couldn't be freed (for whatever reason)?
I guess we need to be extra careful, or we could end up with inconsistent states between the NodeAgent and Coordinator as well.
I agree, but it would probably be better to do a separate refactoring for communication and error handling, since it's pretty chaotic across the codebase right now.
Force-pushed `d285d0f` to `9024b9c`.
Pull request overview
Copilot reviewed 19 out of 19 changed files in this pull request and generated 6 comments.
```cpp
  registered_shards_data_.erase(tensor_name);
  LOG_DEBUG("Removed empty tensor entry: {}", tensor_name);
}
```

Copilot AI · Feb 4, 2026
After erasing the tensor_name entry from registered_shards_data_ on line 211, the function continues to use tensor_name in the LOG_DEBUG on line 212. While this is safe because tensor_name is a local reference to the key, it could be clearer. Additionally, the function returns inside the loop after finding and freeing a shard, which is correct. However, consider adding a comment explaining why early return is safe here (because shard IDs are unique across all tensors).
Suggested change:
```diff
- registered_shards_data_.erase(tensor_name);
- LOG_DEBUG("Removed empty tensor entry: {}", tensor_name);
- }
+ // Copy the tensor name so we can safely use it after erasing the map entry.
+ const auto tensor_name_copy = tensor_name;
+ registered_shards_data_.erase(tensor_name_copy);
+ LOG_DEBUG("Removed empty tensor entry: {}", tensor_name_copy);
+ }
+ // Shard IDs are globally unique across all tensors, so it's safe to
+ // return early once we've found and freed the matching shard.
```
```cpp
void Coordinator::Handler::HandleFreeShardsRequest(
    const Identity& node_agent_identity, const FreeShardsRequest& request) {
  LOG_INFO("Coordinator received FreeShardsRequest with {} shards",
           request.shard_ids.size());

  bool all_freed = true;
  for (const auto& shard_id : request.shard_ids) {
    if (!metastore_.FreeShard(shard_id)) {
      LOG_WARNING("Failed to free shard: {}", shard_id);
      all_freed = false;
    }
  }

  ErrorCode error_code =
      all_freed ? ErrorCode::kSuccess : ErrorCode::kInvalidArguments;
  FreeShardsResponse response(request.request_id, error_code);
  outbox_queue_.push(OutboxMessage{node_agent_identity, response});
}
```

Copilot AI · Feb 4, 2026
When freeing multiple shards, if some succeed and some fail, the function sets error_code to kInvalidArguments. However, this doesn't provide information about which shards failed. Consider either: 1) using a more specific error code, 2) logging which specific shards failed to free for debugging, or 3) returning partial success information to the client. The current implementation makes it difficult to diagnose which shard caused the failure.
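One way to act on this suggestion, sketched here in Python (the real MetaStore API is C++), is to collect the ids that fail and return them alongside the overall status. The `FreeShard(shard_id) -> bool` interface of `FakeMetaStore` is a hypothetical stand-in mirroring the C++ method above:

```python
def free_shards(metastore, shard_ids):
    """Free each shard and collect the ids that could not be freed, so the
    response can report partial failures instead of a bare error code."""
    failed = [sid for sid in shard_ids if not metastore.FreeShard(sid)]
    return (len(failed) == 0, failed)

class FakeMetaStore:
    """Minimal stand-in: freeing succeeds only for known shard ids."""
    def __init__(self, known):
        self.known = set(known)

    def FreeShard(self, shard_id):
        if shard_id in self.known:
            self.known.remove(shard_id)
            return True
        return False

store = FakeMetaStore({"a", "b"})
ok, failed = free_shards(store, ["a", "x"])
# "a" is freed, "x" is unknown, so the caller learns exactly which id failed
assert ok is False and failed == ["x"]
```

Returning the failed ids (rather than just a boolean) would let the Coordinator put them in the `FreeShardsResponse`, addressing the diagnosability concern raised here.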
```cpp
// Free local tensor resources
for (const auto& shard_id : request.shard_ids) {
  FreeTensor(shard_id);
}

request_id_to_client_identity_[request.request_id] = client_identity;

// Forward request to coordinator to update MetaStore
```

Copilot AI · Feb 4, 2026
The NodeAgent frees local tensor resources before forwarding the request to the Coordinator. If the Coordinator's metastore update fails (e.g., shard doesn't exist in metastore), the local resources are already freed but the client receives an error. This could lead to state inconsistency. Consider either: 1) forwarding to the Coordinator first and only freeing local resources after receiving a successful response, or 2) accepting this ordering but documenting that local cleanup happens regardless of metastore success.
Suggested change:
```cpp
// NOTE:
// We intentionally free local tensor resources (storage and metadata)
// immediately upon receiving the FreeShardsRequest, *before* forwarding
// the request to the Coordinator for MetaStore updates.
//
// This means that if the Coordinator's metastore update later fails
// (for example, if a shard is unknown or already removed on the
// Coordinator side), the local NodeAgent will still have already
// removed its corresponding resources, while the client may observe
// an error from the Coordinator. In other words, local cleanup is
// performed regardless of metastore update success, and the
// Coordinator's view is updated on a best-effort basis.
//
// This ordering is accepted by design; callers should not assume
// that a failed metastore update implies that resources remain
// allocated on the NodeAgent.
for (const auto& shard_id : request.shard_ids) {
  FreeTensor(shard_id);
}

request_id_to_client_identity_[request.request_id] = client_identity;

// Forward request to coordinator to update MetaStore (best-effort)
```
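The first alternative the reviewer mentions (forward to the Coordinator first, and free local resources only for confirmed shards) can be sketched as follows. `forward_to_coordinator` and `free_local` are hypothetical stand-ins for the real NodeAgent machinery, not actual setu APIs:

```python
def handle_free_shards(shard_ids, forward_to_coordinator, free_local):
    """Free local tensor resources only for shards the Coordinator
    confirmed it removed from the MetaStore, avoiding the state divergence
    the review warns about."""
    confirmed = forward_to_coordinator(shard_ids)  # ids freed in MetaStore
    for sid in confirmed:
        free_local(sid)
    return confirmed

freed_locally = []
confirmed = handle_free_shards(
    ["s1", "s2", "s3"],
    # Fake Coordinator that rejects "s2" (e.g., unknown shard).
    forward_to_coordinator=lambda ids: [i for i in ids if i != "s2"],
    free_local=freed_locally.append,
)
# "s2" was rejected, so it is never freed locally either
assert freed_locally == ["s1", "s3"]
```

The trade-off is an extra round-trip before local cleanup; the PR instead accepts the free-first ordering and documents it, as the suggested NOTE above spells out.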
```cpp
struct FreeShardsRequest : public BaseRequest {
  /// @brief Constructs a request with auto-generated request ID.
  explicit FreeShardsRequest(std::vector<ShardId> shard_ids_param)
      : BaseRequest(), shard_ids(std::move(shard_ids_param)) {
```

Copilot AI · Feb 4, 2026
The FreeShardsRequest constructor asserts that shard_ids cannot be empty. However, FreeShards() in Client.cpp checks if client_shards_ is empty and returns early (line 191-193). This means an empty FreeShardsRequest will never be created when client_shards_ is empty, so the assertion in the constructor will never trigger in this code path. Consider either removing the assertion if empty lists should be allowed, or documenting why empty requests are explicitly disallowed to catch programming errors in other code paths.
Suggested change:
```cpp
    : BaseRequest(), shard_ids(std::move(shard_ids_param)) {
  // Enforce class invariant: a FreeShardsRequest must always reference at
  // least one shard. Some callers (e.g., client code) may already guard
  // against empty shard lists before constructing this object, but this
  // assertion defends against other incorrect call sites by treating an
  // empty list as a programming error.
```
```cpp
bool TensorShardSpec::Overlaps(const TensorShardSpec& other) const {
  // Shards overlap if and only if ALL dimensions overlap
  for (std::size_t i = 0; i < dims.size(); ++i) {
    // Ranges [start1, end1) and [start2, end2) overlap iff
    // start1 < end2 && start2 < end1
    bool dim_overlaps =
        dims[i].start < other.dims[i].end && other.dims[i].start < dims[i].end;
    if (!dim_overlaps) {
      return false;  // Found non-overlapping dimension
    }
  }
  return true;  // All dimensions overlap
}
```

Copilot AI · Feb 4, 2026
The Overlaps function assumes both TensorShardSpec objects have the same number of dimensions (as noted in the doc comment). However, the function doesn't validate this precondition and will access out-of-bounds indices if the dimension counts differ. While ValidateShardRegistration checks dimension count before calling Overlaps, consider adding an assertion at the start of Overlaps to catch misuse and make the precondition explicit in the code.
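The half-open-interval logic, with the equal-rank assertion the review asks for, can be sketched in Python. `overlaps` here takes lists of `(start, end)` pairs as a stand-in for the C++ `dims` member:

```python
def overlaps(a, b):
    """Half-open range overlap across all dimensions, mirroring
    TensorShardSpec::Overlaps. The leading assertion makes the
    same-rank precondition explicit instead of implicit."""
    assert len(a) == len(b), "shards must have the same number of dimensions"
    # [s1, e1) and [s2, e2) overlap iff s1 < e2 and s2 < e1; the shards
    # overlap only if every dimension's ranges overlap.
    return all(s1 < e2 and s2 < e1 for (s1, e1), (s2, e2) in zip(a, b))

# 2-D shards that intersect in both dimensions:
assert overlaps([(0, 4), (0, 4)], [(2, 6), (3, 5)]) is True
# Disjoint along dimension 0, so no overlap regardless of dimension 1:
assert overlaps([(0, 4), (0, 4)], [(4, 8), (0, 4)]) is False
```

Without the assertion, `zip` would silently truncate to the shorter spec, hiding the misuse the comment warns about; the C++ version would instead index out of bounds.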
```cpp
ASSERT_VALID_ARGUMENTS(it != client_shards_.end(),
                       "No shards found for tensor: {}", name);
```

Copilot AI · Feb 4, 2026
GetShards throws an assertion if the tensor name doesn't exist. This is inconsistent with the return type pattern used elsewhere in the codebase. Consider returning an empty vector or using std::optional to indicate when a tensor name is not found, allowing callers to handle this case gracefully without assertions.
Suggested change:
```diff
- ASSERT_VALID_ARGUMENTS(it != client_shards_.end(),
-                        "No shards found for tensor: {}", name);
+ if (it == client_shards_.end()) {
+   static const std::vector<TensorShardRefPtr> kEmptyShards;
+   LOG_DEBUG("No shards found for tensor: {}", name);
+   return kEmptyShards;
+ }
```
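The graceful-lookup behavior suggested here is the familiar "default on miss" pattern; in Python terms (with a plain dict standing in for the C++ `unordered_map`, as a hypothetical sketch):

```python
def get_shards(client_shards, name):
    """Return the shards registered under `name`, or an empty list when the
    tensor is unknown, instead of asserting on a missing key."""
    return client_shards.get(name, [])

shards = {"embedding": ["shard-0", "shard-1"]}
assert get_shards(shards, "embedding") == ["shard-0", "shard-1"]
assert get_shards(shards, "missing") == []  # no assertion, just empty
```

The C++ suggestion returns a `static const` empty vector for the same reason a shared default works here: callers get a valid, empty container they can iterate without special-casing the miss.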
Force-pushed `9024b9c` to `aebb063`.
Force-pushed `aebb063` to `85be695`.
Pull request overview
Copilot reviewed 19 out of 19 changed files in this pull request and generated 2 comments.
```cpp
using setu::commons::datatypes::TensorShardRefPtr;
using setu::commons::messages::ClientRequest;
using setu::commons::messages::FreeShardsRequest;
using setu::commons::messages::FreeShardsResponse;
```

Copilot AI · Feb 5, 2026
The FreeShardsResponse type is imported but never used in this file. Since FreeShards doesn't wait for or process a response, this import should be removed to avoid confusion.
Suggested change:
```diff
- using setu::commons::messages::FreeShardsResponse;
```
```diff
- @pytest.fixture(scope="function")
+ @pytest.fixture(scope="module")
```
Copilot AI · Feb 5, 2026
Changing the fixture scope from "function" to "module" means all tests using this fixture will share the same infrastructure (Coordinator and NodeAgents). While tests use unique tensor names, the asynchronous nature of FreeShardsRequest (which doesn't wait for a response in the client) could lead to race conditions between tests. If a test completes and its client disconnects (triggering FreeShardsRequest), but the next test starts before the FreeShardsResponse is processed, it could potentially see stale shard state. Consider adding explicit synchronization points or keeping function scope to ensure proper test isolation.
This patch does the following:
Merge after #25