Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion include/infinicore/quantization/awq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,23 @@ class AWQ : public BaseQuantization {
// information and support multiple quantization schemes.
public:
explicit AWQ(const nlohmann::json &quant_config)
: BaseQuantization(quant_config) {};
: BaseQuantization(quant_config){};

infinicore::quantization::QuantScheme
get_quant_scheme() const override {
return infinicore::quantization::QuantScheme::AWQ_W4A16;
};

int get_packing_num() const {
// For AWQ, we pack 8 int4 weights into a single int32 value.
return 32 / this->get_or<int>("bits", 4); // Default to 8 if not specified in config
}

int get_group_size() const {
// For simplicity, we return a fixed group size here. In a more complete implementation,
// this could be extracted from quant_config_ to support different group sizes.
return this->get_or<int>("group_size", 128); // Standard AWQ group size
}
};

} // namespace infinicore::quantization
26 changes: 25 additions & 1 deletion include/infinicore/quantization/base_quantization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,34 @@ namespace infinicore::quantization {
class BaseQuantization {
// Base class for quantization schemes. Intended to be extended to support various quantization methods.
public:
explicit BaseQuantization(const nlohmann::json &quant_config) : quant_config_(quant_config) {};
explicit BaseQuantization(const nlohmann::json &quant_config) : quant_config_(quant_config){};
virtual ~BaseQuantization() = default;

virtual infinicore::quantization::QuantScheme get_quant_scheme() const = 0;
template <typename T>
T get(const std::string &key) const {
if (!quant_config_.contains(key)) {
throw std::out_of_range("Key '" + key + "' not found in config.");
}
try {
return quant_config_.at(key).get<T>();
} catch (const nlohmann::json::type_error &e) {
throw std::runtime_error("Type conversion failed for key '" + key + "': " + std::string(e.what()));
}
}

template <typename T>
T get_or(const std::string &key, const T &default_value) const {
if (!quant_config_.contains(key) || quant_config_.at(key).is_null()) {
return default_value;
}
try {
return quant_config_.at(key).get<T>();
} catch (const nlohmann::json::type_error &) {
// If type conversion fails, return default value
return default_value;
}
}

protected:
nlohmann::json quant_config_;
Expand Down
95 changes: 95 additions & 0 deletions src/infinicore/nn/linear.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "infinicore/ops.hpp"
#include "infinicore/ops/distributed/allreduce.hpp"
#include "infinicore/ops/linear.hpp"
#include "infinicore/ops/linear_w4a16_awq.hpp"
#include "infinicore/ops/linear_w8a8i8.hpp"
#include <optional>
#include <spdlog/spdlog.h>
Expand Down Expand Up @@ -43,6 +44,15 @@ Tensor BaseLinear::compute_linear(Tensor &input) const {
auto output = infinicore::op::linear_w8a8i8(input_contiguous->contiguous(), weight_packed_tensor, weight_scale_tensor, bias_opt);
return output;
}
case infinicore::quantization::QuantScheme::AWQ_W4A16: {
Tensor input_contiguous = input->is_contiguous() ? input : input->contiguous();
Tensor qweight = static_cast<const Tensor &>(weight_);
Tensor qzeros = static_cast<const Tensor &>(weight_zeros_);
Tensor scales = static_cast<const Tensor &>(weight_scale_);
std::optional<Tensor> bias_opt = has_bias_ ? std::make_optional<Tensor>(static_cast<const Tensor &>(bias_)) : std::nullopt;
auto output = infinicore::op::linear_w4a16_awq(input_contiguous->contiguous(), qweight, scales, qzeros, bias_opt);
return output;
}
default: {
// Ensure input is contiguous before creating views (required for matmul)
// This prevents hanging when input tensor has non-contiguous memory layout
Expand Down Expand Up @@ -116,6 +126,20 @@ Linear::Linear(size_t in_features, size_t out_features,
}
break;
}
case infinicore::quantization::QuantScheme::AWQ_W4A16: {
weight_ = infinicore::nn::Parameter({out_features, in_features}, infinicore::DataType::I32, device);
this->register_parameter("qweight", weight_);
weight_zeros_ = infinicore::nn::Parameter({out_features, in_features}, infinicore::DataType::I32, device);
this->register_parameter("qzeros", weight_zeros_);
weight_scale_ = infinicore::nn::Parameter({out_features, in_features}, dtype_, device);
this->register_parameter("scales", weight_scale_);
if (bias) {
INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device));
} else {
bias_ = Parameter();
}
break;
}
default: {
// Initialize parameters using macro
INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device));
Expand Down Expand Up @@ -190,6 +214,39 @@ ColumnParallelLinear::ColumnParallelLinear(size_t in_features, size_t out_featur
}
break;
}
case infinicore::quantization::QuantScheme::AWQ_W4A16: {
auto awq_ptr = std::static_pointer_cast<infinicore::quantization::AWQ>(this->quantization_);
int group_size = awq_ptr->get_group_size();
int packing_num = awq_ptr->get_packing_num();

weight_ = infinicore::nn::Parameter({in_features, out_features / packing_num},
infinicore::DataType::I32,
device, 1, tp_rank_, tp_size_);
this->register_parameter("qweight", weight_);

// Weight scale: [out_features, in_features / group_size]
// One FP32 scale per group of weights (group_size=128)

weight_scale_ = infinicore::nn::Parameter({in_features / group_size, out_features},
dtype_,
device, 1, tp_rank_, tp_size_);
this->register_parameter("scales", weight_scale_);

// Weight zeros (zero points): [out_features, in_features / group_size]
// AWQ implementations (e.g., AutoAWQ) typically store zero points as I32
// for symmetric/asymmetric quantization support
weight_zeros_ = infinicore::nn::Parameter({in_features / group_size, out_features / packing_num},
infinicore::DataType::I32,
device, 1, tp_rank_, tp_size_);

this->register_parameter("qzeros", weight_zeros_);
if (bias) {
INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, 0, 0, 1));
} else {
bias_ = Parameter();
}
break;
}
default: {
// Initialize parameters using macro
INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device,
Expand Down Expand Up @@ -261,6 +318,44 @@ RowParallelLinear::RowParallelLinear(size_t in_features, size_t out_features, st
}
break;
}
case infinicore::quantization::QuantScheme::AWQ_W4A16: {
// AWQ W4A16 for RowParallelLinear:切分维度为 in_features(权重矩阵的第1维)
// - Weight: packed int4 in I32 containers (8 int4 per I32)
// - Group-wise quantization with group_size=128
// - Scale and zero points stored per group along in_features dimension

auto awq_ptr = std::static_pointer_cast<infinicore::quantization::AWQ>(this->quantization_);
int group_size = awq_ptr->get_group_size();
int packing_num = awq_ptr->get_packing_num();

// Packed weight: [out_features, in_features / 8]
weight_ = infinicore::nn::Parameter({in_features, out_features / packing_num},
infinicore::DataType::I32,
device, 0, tp_rank_, tp_size_);
this->register_parameter("qweight", weight_);

// Weight scale: [out_features, in_features / group_size]

weight_scale_ = infinicore::nn::Parameter({in_features / group_size, out_features},
dtype_,
device, 0, tp_rank_, tp_size_);
this->register_parameter("scales", weight_scale_);
// Weight zeros (zero points): [out_features, in_features / group_size]
weight_zeros_ = infinicore::nn::Parameter({in_features / group_size, out_features / packing_num},
infinicore::DataType::I32,
device, 0, tp_rank_, tp_size_);
this->register_parameter("qzeros", weight_zeros_);

// Bias handling in RowParallelLinear:
// - Only rank 0 holds the full bias (after all-reduce on output)
// - Other ranks have empty bias parameter
if (bias && (0 == tp_rank_)) {
INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, 0, 0, 1));
} else {
bias_ = Parameter();
}
break;
}
default: {
// Initialize parameters using macro
INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device,
Expand Down
20 changes: 12 additions & 8 deletions src/infinicore/ops/linear_w4a16_awq/linear_w4a16_awq.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "infinicore/ops/linear_w4a16_awq.hpp"
#include "infinicore/ops/dequantize_awq.hpp"
#include "infinicore/ops/gemm.hpp"

#include "infinicore/ops/rearrange.hpp"
namespace infinicore::op {

Tensor linear_w4a16_awq(Tensor input,
Expand All @@ -12,7 +12,8 @@ Tensor linear_w4a16_awq(Tensor input,

// Input is of shape [M, K], Weight_packed is of shape [N, K],stirdes is [N, 1]
Size ndim = input->ndim();
Size out_features = weight_packed->shape()[0];
Size element_size = weight_packed->element_size();
Size out_features = weight_packed->shape()[1] * element_size * 2;

// Assign memory to out variables
auto output_shape = input->shape();
Expand All @@ -33,7 +34,7 @@ void linear_w4a16_awq_(Tensor out,

auto weight_packed_shape = weight_packed->shape();
Size out_features = weight_packed_shape[0];
Size in_features = weight_packed_shape[1];
Size in_features = weight_packed_shape[1] * 8;

Size ndim = input->ndim();
assert(out->ndim() == ndim);
Expand All @@ -43,18 +44,21 @@ void linear_w4a16_awq_(Tensor out,
for (size_t i = 0; i < ndim - 1; ++i) {
N *= input_shape[i];
}

auto weight = Tensor::empty(
{out_features, in_features},
out->dtype(),
weight_packed->device());
float alpha = 1.0f;
float beta = 0.0f;
op::dequantize_awq_(weight, weight_packed, weight_scale, weight_zeros);
bias = std::make_optional(bias.value()->as_strided({N, out_features}, {0, 1}));
gemm_(out->view({N, out_features}),
input->view({N, in_features}),
weight->permute({1, 0}), alpha, beta);
if (bias.has_value()) {
rearrange_(out,
bias.value()->as_strided({N, in_features}, {0, 1}));
beta = 1.0f;
}
gemm_(out,
input->view({N, out_features}),
weight, alpha, beta);
}

} // namespace infinicore::op