diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc
index 69f8ba7e..a007dff1 100644
--- a/example/gpt2/main.cc
+++ b/example/gpt2/main.cc
@@ -220,8 +220,8 @@ void Train(const nn::parallel::Rank &rank) {
                 = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
             auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks();
             for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
-                (*mutable_chunks)[chunk_id] = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id),
-                                                                                        rank.thread_rank(), ddp_config);
+                (*mutable_chunks)[chunk_id]
+                    = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id), rank, ddp_config);
             }
         }
     } else if (ddp_world_size > 1) {
@@ -230,7 +230,7 @@ void Train(const nn::parallel::Rank &rank) {
         // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
         // are created during the conversion.
         auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
-        model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank(), ddp_config);
+        model = std::make_shared<DistributedDataParallel>(model, rank, ddp_config);
     }
 
     DistributedDataLoader train_loader(std::make_shared(FLAGS_input_bin, FLAGS_sequence_length),
diff --git a/example/llama3/main.cc b/example/llama3/main.cc
index 6d4c9a7b..2b1e2121 100644
--- a/example/llama3/main.cc
+++ b/example/llama3/main.cc
@@ -199,8 +199,8 @@ void Train(const nn::parallel::Rank &rank) {
                 = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
             auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks();
             for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
-                (*mutable_chunks)[chunk_id] = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id),
-                                                                                        rank.thread_rank(), ddp_config);
+                (*mutable_chunks)[chunk_id]
+                    = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id), rank, ddp_config);
             }
         }
     } else if (ddp_world_size > 1) {
@@ -210,7 +210,7 @@ void Train(const nn::parallel::Rank &rank) {
         // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
         // are created during the conversion.
         auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
-        model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank(), ddp_config);
+        model = std::make_shared<DistributedDataParallel>(model, rank, ddp_config);
     }
 
     DistributedDataLoader train_loader(std::make_shared(FLAGS_input_bin, FLAGS_sequence_length),
diff --git a/infini_train/include/nn/parallel/ddp/distributed_data_parallel.h b/infini_train/include/nn/parallel/ddp/distributed_data_parallel.h
index 9a7713c5..bba474bf 100644
--- a/infini_train/include/nn/parallel/ddp/distributed_data_parallel.h
+++ b/infini_train/include/nn/parallel/ddp/distributed_data_parallel.h
@@ -11,6 +11,7 @@ class Tensor;
 class Device;
 namespace nn::parallel {
 class DistributedDataParallelConfig;
+class Rank;
 } // namespace nn::parallel
 } // namespace infini_train
 
@@ -18,7 +19,7 @@ namespace infini_train::nn::parallel {
 
 class DistributedDataParallel : public nn::Module {
 public:
-    DistributedDataParallel(std::shared_ptr module, int thread_rank,
+    DistributedDataParallel(std::shared_ptr module, const Rank &rank,
                             DistributedDataParallelConfig ddp_config);
 
     std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
diff --git a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc
index 67197747..002fe318 100644
--- a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc
+++ b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc
@@ -11,6 +11,7 @@
 #include "infini_train/include/nn/modules/module.h"
 #include "infini_train/include/nn/parallel/parallel_functional.h"
 #include "infini_train/include/nn/parallel/process_group.h"
+#include "infini_train/include/nn/parallel/rank.h"
 #include "infini_train/include/nn/parallel/utils.h"
 #include "infini_train/include/tensor.h"
 
@@ -19,13 +20,13 @@ namespace {
 constexpr char kModuleName[] = "module";
 } // namespace
 
-DistributedDataParallel::DistributedDataParallel(std::shared_ptr module, int thread_rank,
+DistributedDataParallel::DistributedDataParallel(std::shared_ptr module, const Rank &rank,
                                                  const DistributedDataParallelConfig ddp_config)
     : ddp_config_(ddp_config),
-      ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(thread_rank))) {
+      ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(rank.GlobalRank()))) {
     for (auto &param : module->Parameters()) {
         auto device = param->GetDevice();
-        CHECK_EQ(device.index(), thread_rank) << "All parameters must be on the same device as the module";
+        CHECK_EQ(device.index(), rank.thread_rank()) << "All parameters must be on the same device as the module";
         if (!ddp_config.gradient_bucketing_enabled && !ddp_config.use_distributed_optimizer) {
             auto hook = std::make_unique(
                 function::ReduceOpType::kAvg, ddp_pg_);
@@ -33,7 +34,8 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr mod
         }
     }
     for (auto &buffer : module->Buffers()) {
-        CHECK_EQ(buffer->GetDevice().index(), thread_rank) << "All buffers must be on the same device as the module";
+        CHECK_EQ(buffer->GetDevice().index(), rank.thread_rank())
+            << "All buffers must be on the same device as the module";
     }
 
     modules_[kModuleName] = std::move(module);