[Mlir-commits] [mlir] [MLIR][NVVM] Update dot.accumulate NVVM Ops (PR #140518)
Srinivasa Ravi
llvmlistbot at llvm.org
Tue May 20 23:45:20 PDT 2025
https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/140518
>From 7aeccec1ed7e3a311cd7bc1b05070cdf1ba2362a Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 15 May 2025 16:41:08 +0530
Subject: [PATCH] [MLIR][NVVM] Update dot.accumulate NVVM Ops
This change:
- Adds the dot.accumulate.2way Op to the NVVM dialect for the 16-bit to 8-bit
dot-product-accumulate operation.
- Refactors the recently added dot.accumulate.4way Op to share a common
signedness attribute (`s`/`u`, replacing `s8`/`u8`) with the new Op.
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 88 +++++++++++++++++----
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 42 +++++++++-
mlir/test/Dialect/LLVMIR/nvvm.mlir | 15 +++-
mlir/test/Target/LLVMIR/nvvmir.mlir | 46 ++++++++++-
4 files changed, 166 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 654aff71f25be..6421031195d1a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3445,25 +3445,25 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
}
//===----------------------------------------------------------------------===//
-// NVVM dot.accumulate.4way Op
+// NVVM dot.accumulate Ops
//===----------------------------------------------------------------------===//
-def DotAccumulate4WayS8 : I32EnumAttrCase<"S8", 1, "s8">;
-def DotAccumulate4WayU8 : I32EnumAttrCase<"U8", 0, "u8">;
+def DotAccumulateSigned : I32EnumAttrCase<"S", 1, "s">;
+def DotAccumulateUnsigned : I32EnumAttrCase<"U", 0, "u">;
-def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType",
- "NVVM DotAccumulate4WayType",
- [DotAccumulate4WayS8, DotAccumulate4WayU8]> {
+def DotAccumulateType : I32EnumAttr<"DotAccumulateType",
+ "NVVM DotAccumulateType",
+ [DotAccumulateSigned, DotAccumulateUnsigned]> {
let cppNamespace = "::mlir::NVVM";
let genSpecializedAttr = 0;
}
-def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType, "dot_accumulate_4way_type"> {
+def DotAccumulateTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulateType, "dot_accumulate_type"> {
let assemblyFormat = "`<` $value `>`";
}
def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
- let summary = "Four-way byte dot product-accumulate instruction.";
+ let summary = "Four-way byte dot product-accumulate instruction";
let description = [{
Performs a four-way byte dot-product which is accumulated in a 32-bit
result.
@@ -3471,9 +3471,9 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
computed.
The `a_type` and `b_type` attributes specify the type of the elements in `a`
and `b` respectively.
- If `a_type` or `b_type` is `s8`, then the elements in the corresponding
+ If `a_type` or `b_type` is `s`, then the elements in the corresponding
vector are sign-extended to 32-bit before the dot product is computed.
- If `a_type` or `b_type` is `u8`, then the elements in the corresponding
+ If `a_type` or `b_type` is `u`, then the elements in the corresponding
vector are zero-extended to 32-bit instead.
Operand `c` is a 32-bit integer to which the result is accumulated. It is
- treated as holding a signed integer if any of `a_type` or `b_type` is `s8`.
+ treated as holding a signed integer if any of `a_type` or `b_type` is `s`.
@@ -3483,9 +3483,9 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
let arguments = (ins
VectorOfLengthAndType<[4], [I8]>:$a,
- DotAccumulate4WayTypeAttr:$a_type,
+ DotAccumulateTypeAttr:$a_type,
VectorOfLengthAndType<[4], [I8]>:$b,
- DotAccumulate4WayTypeAttr:$b_type,
+ DotAccumulateTypeAttr:$b_type,
I32:$c
);
@@ -3495,8 +3495,8 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
let extraClassDeclaration = [{
static llvm::Intrinsic::ID
- getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
- NVVM::DotAccumulate4WayType b_type);
+ getIntrinsicID(NVVM::DotAccumulateType a_type,
+ NVVM::DotAccumulateType b_type);
llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
}];
@@ -3508,6 +3508,66 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
}];
}
+def NVVM_DotAccumulate2WayOp : NVVM_Op<"dot.accumulate.2way"> {
+ let summary = "Two-way 16-bit to 8-bit dot product-accumulate instruction";
+ let description = [{
+ Performs a two-way 16-bit to 8-bit dot-product which is accumulated in a
+ 32-bit result.
+ Operand `a` is a vector of two 16-bit elements and operand `b` a vector
+ of four 8-bit elements between which the dot product is computed.
+
+ The `a_type` and `b_type` attributes specify the type of the elements in `a`
+ and `b` respectively.
+ If `a_type` or `b_type` is `s`, then the elements in the corresponding
+ vector are sign-extended to 32-bit before the dot product is computed.
+ If `a_type` or `b_type` is `u`, then the elements in the corresponding
+ vector are zero-extended to 32-bit instead.
+
+ The `hi` boolean attribute specifies which two bytes of `b` are used for
+ the dot product. If `hi` is true, then the dot product is computed between
+ `a` and elements at indices 2 and 3 of `b`. If `hi` is false, then the dot
+ product is computed between `a` and elements at indices 0 and 1 of `b`.
+ By default, `hi` is false.
+
+ Operand `c` is a 32-bit integer to which the result is accumulated. It is
+ treated as holding a signed integer if any of `a_type` or `b_type` is
+ signed.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp2a)
+ }];
+
+ let arguments = (ins
+ VectorOfLengthAndType<[2], [I16]>:$a,
+ DotAccumulateTypeAttr:$a_type,
+ VectorOfLengthAndType<[4], [I8]>:$b,
+ DotAccumulateTypeAttr:$b_type,
+ I32:$c,
+ DefaultValuedAttr<BoolAttr, "false">:$hi
+ );
+
+ let results = (outs I32:$res);
+
+ let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
+
+ let extraClassDeclaration = [{
+ static llvm::Intrinsic::ID
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder,
+ llvm::SmallVector<llvm::Value *> &args);
+ llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
+ }];
+
+ string llvmBuilder = [{
+ llvm::SmallVector<llvm::Value *> args;
+
+ llvm::Intrinsic::ID
+ id = NVVM::DotAccumulate2WayOp::getIntrinsicIDAndArgs(
+ *op, moduleTranslation, builder, args);
+
+ $res = createIntrinsicCall(builder, id, args);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 1ea3f96fa75f5..d4b013d2e7f1e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1211,6 +1211,13 @@ NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
llvm::Type::getInt32Ty(builder.getContext()));
}
+llvm::Value *
+NVVM::DotAccumulate2WayOp::getPackedArg(llvm::Value *arg,
+ llvm::IRBuilderBase &builder) {
+ return builder.CreateBitCast(arg,
+ llvm::Type::getInt32Ty(builder.getContext()));
+}
+
//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
@@ -1599,10 +1606,10 @@ static void nvvmInferResultRanges(Operation *op, Value result,
}
llvm::Intrinsic::ID
-DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
- NVVM::DotAccumulate4WayType b_type) {
- bool is_a_siext = a_type == NVVM::DotAccumulate4WayType::S8;
- bool is_b_siext = b_type == NVVM::DotAccumulate4WayType::S8;
+DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulateType a_type,
+ NVVM::DotAccumulateType b_type) {
+ bool is_a_siext = a_type == NVVM::DotAccumulateType::S;
+ bool is_b_siext = b_type == NVVM::DotAccumulateType::S;
unsigned type = (is_a_siext << 1) | is_b_siext;
switch (type) {
case 0:
@@ -1618,6 +1625,33 @@ DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
}
}
+llvm::Intrinsic::ID DotAccumulate2WayOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder,
+ llvm::SmallVector<llvm::Value *> &args) {
+ auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
+
+ args.push_back(curOp.getPackedArg(mt.lookupValue(curOp.getA()), builder));
+ args.push_back(curOp.getPackedArg(mt.lookupValue(curOp.getB()), builder));
+ args.push_back(builder.getInt1(curOp.getHi()));
+ args.push_back(mt.lookupValue(curOp.getC()));
+
+ bool is_a_siext = curOp.getAType() == NVVM::DotAccumulateType::S;
+ bool is_b_siext = curOp.getBType() == NVVM::DotAccumulateType::S;
+ unsigned type = (is_a_siext << 1) | is_b_siext;
+ switch (type) {
+ case 0:
+ return llvm::Intrinsic::nvvm_idp2a_u_u;
+ case 1:
+ return llvm::Intrinsic::nvvm_idp2a_u_s;
+ case 2:
+ return llvm::Intrinsic::nvvm_idp2a_s_u;
+ case 3:
+ return llvm::Intrinsic::nvvm_idp2a_s_s;
+ default:
+ llvm_unreachable("Invalid DP2a type");
+ }
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index e8425638cc9be..69e19de40a68a 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -579,11 +579,20 @@ func.func @st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64)
}
// CHECK-LABEL: @dot_accumulate_4way
-func.func @dot_accumulate_4way(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: vector<4xi8>, %c: i32) {
+func.func @dot_accumulate_4way(%a_vec: vector<4xi8>, %b_vec: vector<4xi8>, %c: i32) {
// CHECK: nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
- %1 = nvvm.dot.accumulate.4way %a_vec <u8>, %b_vec <u8>, %c: vector<4xi8>, vector<4xi8>
+ %1 = nvvm.dot.accumulate.4way %a_vec <u>, %b_vec <u>, %c: vector<4xi8>, vector<4xi8>
// CHECK: nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
- %3 = nvvm.dot.accumulate.4way %a_vec <s8>, %b_vec <s8>, %c: vector<4xi8>, vector<4xi8>
+ %3 = nvvm.dot.accumulate.4way %a_vec <s>, %b_vec <s>, %c: vector<4xi8>, vector<4xi8>
+ return
+}
+
+// CHECK-LABEL: @dot_accumulate_2way
+func.func @dot_accumulate_2way(%a_vec: vector<2xi16>, %b_vec: vector<4xi8>, %c: i32) {
+ // CHECK: nvvm.dot.accumulate.2way %{{.*}}, %{{.*}}, %{{.*}} : vector<2xi16>, vector<4xi8>
+ %1 = nvvm.dot.accumulate.2way %a_vec <u>, %b_vec <u>, %c: vector<2xi16>, vector<4xi8>
+ // CHECK: nvvm.dot.accumulate.2way %{{.*}}, %{{.*}}, %{{.*}} {hi = true} : vector<2xi16>, vector<4xi8>
+ %3 = nvvm.dot.accumulate.2way %a_vec <s>, %b_vec <s>, %c {hi = true}: vector<2xi16>, vector<4xi8>
return
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 894b72733a46a..b26f6afedf7cf 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -851,18 +851,56 @@ llvm.func @nvvm_dot_accumulate_4way(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32)
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
- %0 = nvvm.dot.accumulate.4way %a <u8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
+ %0 = nvvm.dot.accumulate.4way %a <u>, %b <u>, %c: vector<4xi8>, vector<4xi8>
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
- %1 = nvvm.dot.accumulate.4way %a <s8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
+ %1 = nvvm.dot.accumulate.4way %a <s>, %b <u>, %c: vector<4xi8>, vector<4xi8>
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
- %2 = nvvm.dot.accumulate.4way %a <u8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
+ %2 = nvvm.dot.accumulate.4way %a <u>, %b <s>, %c: vector<4xi8>, vector<4xi8>
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
- %3 = nvvm.dot.accumulate.4way %a <s8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
+ %3 = nvvm.dot.accumulate.4way %a <s>, %b <s>, %c: vector<4xi8>, vector<4xi8>
+ llvm.return
+}
+
+// -----
+// CHECK-LABEL: @nvvm_dot_accumulate_2way
+llvm.func @nvvm_dot_accumulate_2way(%a: vector<2xi16>, %b: vector<4xi8>, %c: i32) {
+ // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+ // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+ %0 = nvvm.dot.accumulate.2way %a <u>, %b <u>, %c: vector<2xi16>, vector<4xi8>
+ // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+ // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+ %1 = nvvm.dot.accumulate.2way %a <u>, %b <u>, %c {hi = true}: vector<2xi16>, vector<4xi8>
+ // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+ // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: call i32 @llvm.nvvm.idp2a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+ %2 = nvvm.dot.accumulate.2way %a <s>, %b <u>, %c: vector<2xi16>, vector<4xi8>
+ // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+ // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: call i32 @llvm.nvvm.idp2a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+ %3 = nvvm.dot.accumulate.2way %a <s>, %b <u>, %c {hi = true}: vector<2xi16>, vector<4xi8>
+ // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+ // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: call i32 @llvm.nvvm.idp2a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+ %4 = nvvm.dot.accumulate.2way %a <u>, %b <s>, %c: vector<2xi16>, vector<4xi8>
+ // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+ // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: call i32 @llvm.nvvm.idp2a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+ %5 = nvvm.dot.accumulate.2way %a <u>, %b <s>, %c {hi = true}: vector<2xi16>, vector<4xi8>
+ // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+ // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: call i32 @llvm.nvvm.idp2a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+ %6 = nvvm.dot.accumulate.2way %a <s>, %b <s>, %c: vector<2xi16>, vector<4xi8>
+ // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+ // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: call i32 @llvm.nvvm.idp2a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+ %7 = nvvm.dot.accumulate.2way %a <s>, %b <s>, %c {hi = true}: vector<2xi16>, vector<4xi8>
llvm.return
}
More information about the Mlir-commits
mailing list