6 changes: 3 additions & 3 deletions example/gpt2/main.cc

@@ -220,8 +220,8 @@ void Train(const nn::parallel::Rank &rank) {
                 = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
             auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
             for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
-                (*mutable_chunks)[chunk_id] = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id),
-                                                                                        rank.thread_rank(), ddp_config);
+                (*mutable_chunks)[chunk_id]
+                    = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id), rank, ddp_config);
             }
         }
     } else if (ddp_world_size > 1) {
@@ -230,7 +230,7 @@ void Train(const nn::parallel::Rank &rank) {
         // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
         // are created during the conversion.
         auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
-        model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank(), ddp_config);
+        model = std::make_shared<DistributedDataParallel>(model, rank, ddp_config);
     }
 
     DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
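
The ordering constraint in the comment above deserves a concrete illustration: DistributedDataParallel registers its gradient hooks on the parameter tensors it sees at construction time, so any conversion that replaces those tensors must run first. A minimal sketch of the failure mode, in which BuildModel(), To(), and DataType::kBF16 are hypothetical stand-ins for the real conversion step:

// Sketch only: BuildModel(), To(), and DataType::kBF16 are assumed names,
// standing in for whatever conversion precedes DDP wrapping in the real code.
auto model = BuildModel();
model->To(DataType::kBF16); // replaces parameters with converted tensors
model = std::make_shared<DistributedDataParallel>(model, rank, ddp_config);
// Reversing the last two steps would be a bug: To() creates new parameter
// tensors, and the hooks DDP registered on the old ones are silently lost.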
6 changes: 3 additions & 3 deletions example/llama3/main.cc

@@ -199,8 +199,8 @@ void Train(const nn::parallel::Rank &rank) {
                 = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
             auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
             for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
-                (*mutable_chunks)[chunk_id] = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id),
-                                                                                        rank.thread_rank(), ddp_config);
+                (*mutable_chunks)[chunk_id]
+                    = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id), rank, ddp_config);
             }
         }
     } else if (ddp_world_size > 1) {
@@ -210,7 +210,7 @@ void Train(const nn::parallel::Rank &rank) {
         // are created during the conversion.
 
         auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
-        model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank(), ddp_config);
+        model = std::make_shared<DistributedDataParallel>(model, rank, ddp_config);
     }
 
     DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
3 changes: 2 additions & 1 deletion infini_train/include/nn/parallel/ddp/distributed_data_parallel.h

@@ -11,14 +11,15 @@ class Tensor;
 class Device;
 namespace nn::parallel {
 class DistributedDataParallelConfig;
+class Rank;
 } // namespace nn::parallel
 } // namespace infini_train
 
 namespace infini_train::nn::parallel {
 
 class DistributedDataParallel : public nn::Module {
 public:
-    DistributedDataParallel(std::shared_ptr<nn::Module> module, int thread_rank,
+    DistributedDataParallel(std::shared_ptr<nn::Module> module, const Rank &rank,
                             DistributedDataParallelConfig ddp_config);
 
     std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
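
Forward-declaring Rank avoids a new include in this header; only the .cc file pulls in rank.h. Judging from the call sites in distributed_data_parallel.cc, the constructor relies on two accessors, roughly shaped as follows (an assumed sketch, not the actual contents of rank.h):

// Assumed interface for nn::parallel::Rank, inferred from its call sites;
// the real definition lives in infini_train/include/nn/parallel/rank.h.
class Rank {
public:
    int thread_rank() const; // local rank: device index within this process
    int GlobalRank() const;  // rank that is unique across all processes
};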
10 changes: 6 additions & 4 deletions infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc

@@ -11,6 +11,7 @@
 #include "infini_train/include/nn/modules/module.h"
 #include "infini_train/include/nn/parallel/parallel_functional.h"
 #include "infini_train/include/nn/parallel/process_group.h"
+#include "infini_train/include/nn/parallel/rank.h"
 #include "infini_train/include/nn/parallel/utils.h"
 #include "infini_train/include/tensor.h"
 
@@ -19,21 +20,22 @@ namespace {
 constexpr char kModuleName[] = "module";
 } // namespace
 
-DistributedDataParallel::DistributedDataParallel(std::shared_ptr<nn::Module> module, int thread_rank,
+DistributedDataParallel::DistributedDataParallel(std::shared_ptr<nn::Module> module, const Rank &rank,
                                                  const DistributedDataParallelConfig ddp_config)
     : ddp_config_(ddp_config),
-      ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(thread_rank))) {
+      ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(rank.GlobalRank()))) {
     for (auto &param : module->Parameters()) {
         auto device = param->GetDevice();
-        CHECK_EQ(device.index(), thread_rank) << "All parameters must be on the same device as the module";
+        CHECK_EQ(device.index(), rank.thread_rank()) << "All parameters must be on the same device as the module";
         if (!ddp_config.gradient_bucketing_enabled && !ddp_config.use_distributed_optimizer) {
             auto hook = std::make_unique<infini_train::autograd::AllReducePostAccumulateHook>(
                 function::ReduceOpType::kAvg, ddp_pg_);
             param->RegisterPostAccumulateGradHook(std::move(hook));
         }
     }
     for (auto &buffer : module->Buffers()) {
-        CHECK_EQ(buffer->GetDevice().index(), thread_rank) << "All buffers must be on the same device as the module";
+        CHECK_EQ(buffer->GetDevice().index(), rank.thread_rank())
+            << "All buffers must be on the same device as the module";
     }
     modules_[kModuleName] = std::move(module);
 
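
Passing the whole Rank instead of a bare int lets the constructor separate two notions the old signature conflated: the data-parallel process group is now looked up by rank.GlobalRank(), which presumably stays unique in multi-process runs, while device placement is still checked against the process-local rank.thread_rank(). A minimal call-site sketch, in which GPT2, To(), Device, and DeviceType are hypothetical stand-ins for the setup the examples do elsewhere:

void Train(const nn::parallel::Rank &rank) {
    // Hypothetical setup: only the last two lines correspond to this diff.
    std::shared_ptr<nn::Module> model = std::make_shared<GPT2>(/*config=*/{});
    model->To(Device(DeviceType::kCUDA, rank.thread_rank())); // device must match thread_rank()
    auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = true};
    model = std::make_shared<DistributedDataParallel>(model, rank, ddp_config);
    // The constructor CHECKs that every parameter and buffer already sits on
    // device index rank.thread_rank() before installing its gradient hooks.
}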