[Mlir-commits] [mlir] aca088d - [MLIR][NVVM] Update dot.accumulate.4way NVVM Op (#141223)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 28 22:21:14 PDT 2025
Author: Srinivasa Ravi
Date: 2025-05-29T10:51:11+05:30
New Revision: aca088d802532c5c357c4be6e6fa6e6340d34df2
URL: https://github.com/llvm/llvm-project/commit/aca088d802532c5c357c4be6e6fa6e6340d34df2
DIFF: https://github.com/llvm/llvm-project/commit/aca088d802532c5c357c4be6e6fa6e6340d34df2.diff
LOG: [MLIR][NVVM] Update dot.accumulate.4way NVVM Op (#141223)
This change refactors and updates the `dot.accumulate.4way` NVVM Op to
be more descriptive and readable.
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/test/Dialect/LLVMIR/nvvm.mlir
mlir/test/Target/LLVMIR/nvvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 408537be0a5e4..2424e3af80d2d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3640,36 +3640,38 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>]> {
}
//===----------------------------------------------------------------------===//
-// NVVM dot.accumulate.4way Op
+// NVVM dot.accumulate Ops
//===----------------------------------------------------------------------===//
-def DotAccumulate4WayS8 : I32EnumAttrCase<"S8", 1, "s8">;
-def DotAccumulate4WayU8 : I32EnumAttrCase<"U8", 0, "u8">;
+def DotAccumulateUnsigned : I32EnumAttrCase<"UNSIGNED", 0, "unsigned">;
+def DotAccumulateSigned : I32EnumAttrCase<"SIGNED", 1, "signed">;
-def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType",
- "NVVM DotAccumulate4WayType",
- [DotAccumulate4WayS8, DotAccumulate4WayU8]> {
+def DotAccumulateType : I32EnumAttr<"DotAccumulateType",
+ "NVVM DotAccumulateType",
+ [DotAccumulateSigned, DotAccumulateUnsigned]> {
let cppNamespace = "::mlir::NVVM";
let genSpecializedAttr = 0;
}
-def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType, "dot_accumulate_4way_type"> {
+def DotAccumulateTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulateType, "dot_accumulate_type"> {
let assemblyFormat = "`<` $value `>`";
}
def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
- let summary = "Four-way byte dot product-accumulate instruction.";
+ let summary = "Four-way byte dot product-accumulate instruction";
let description = [{
Performs a four-way byte dot-product which is accumulated in a 32-bit
result.
Operand `a` and `b` are vectors of 4 bytes between which the dot product is
computed.
+
The `a_type` and `b_type` attributes specify the type of the elements in `a`
and `b` respectively.
- If `a_type` or `b_type` is `s8`, then the elements in the corresponding
+ If `a_type` or `b_type` is `signed`, then the elements in the corresponding
vector are sign-extended to 32-bit before the dot product is computed.
- If `a_type` or `b_type` is `u8`, then the elements in the corresponding
- vector are zero-extended to 32-bit instead.
+ If `a_type` or `b_type` is `unsigned`, then the elements in the
+ corresponding vector are zero-extended to 32-bit instead.
+
Operand `c` is a 32-bit integer to which the result is accumulated. It is
treated as holding a signed integer if any of `a_type` or `b_type` is `signed`.
@@ -3678,9 +3680,9 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
let arguments = (ins
VectorOfLengthAndType<[4], [I8]>:$a,
- DotAccumulate4WayTypeAttr:$a_type,
+ DotAccumulateTypeAttr:$a_type,
VectorOfLengthAndType<[4], [I8]>:$b,
- DotAccumulate4WayTypeAttr:$b_type,
+ DotAccumulateTypeAttr:$b_type,
I32:$c
);
@@ -3689,17 +3691,15 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
let extraClassDeclaration = [{
- static llvm::Intrinsic::ID
- getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
- NVVM::DotAccumulate4WayType b_type);
- llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder);
}];
string llvmBuilder = [{
- llvm::Intrinsic::ID id = NVVM::DotAccumulate4WayOp::getIntrinsicID($a_type, $b_type);
- llvm::Value* argA = op.getPackedArg($a, builder);
- llvm::Value* argB = op.getPackedArg($b, builder);
- $res = createIntrinsicCall(builder, id, {argA, argB, $c});
+ auto [id, args] = NVVM::DotAccumulate4WayOp::getIntrinsicIDAndArgs(
+ *op, moduleTranslation, builder);
+ $res = createIntrinsicCall(builder, id, args);
}];
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 8036ea27f524f..648b6b087e592 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1205,13 +1205,6 @@ LogicalResult NVVM::VoteSyncOp::verify() {
return success();
}
-llvm::Value *
-NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
- llvm::IRBuilderBase &builder) {
- return builder.CreateBitCast(arg,
- llvm::Type::getInt32Ty(builder.getContext()));
-}
-
/// Packs the given `field` into the `result`.
/// The `result` is 64-bits and each `field` can be 32-bits or narrower.
static llvm::Value *
@@ -1692,24 +1685,31 @@ static void nvvmInferResultRanges(Operation *op, Value result,
}
}
-llvm::Intrinsic::ID
-DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
- NVVM::DotAccumulate4WayType b_type) {
- bool is_a_siext = a_type == NVVM::DotAccumulate4WayType::S8;
- bool is_b_siext = b_type == NVVM::DotAccumulate4WayType::S8;
- unsigned type = (is_a_siext << 1) | is_b_siext;
- switch (type) {
- case 0:
- return llvm::Intrinsic::nvvm_idp4a_u_u;
- case 1:
- return llvm::Intrinsic::nvvm_idp4a_u_s;
- case 2:
- return llvm::Intrinsic::nvvm_idp4a_s_u;
- case 3:
- return llvm::Intrinsic::nvvm_idp4a_s_s;
- default:
- llvm_unreachable("Invalid DP4a type");
- }
+static llvm::Value *getAsPackedI32(llvm::Value *arg,
+ llvm::IRBuilderBase &builder) {
+ return builder.CreateBitCast(arg,
+ llvm::Type::getInt32Ty(builder.getContext()));
+}
+
+NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
+ args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
+ args.push_back(mt.lookupValue(curOp.getC()));
+
+ bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
+ bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
+ unsigned type = (isASigned << 1) | isBSigned;
+ const llvm::Intrinsic::ID ids[] = {
+ llvm::Intrinsic::nvvm_idp4a_u_u,
+ llvm::Intrinsic::nvvm_idp4a_u_s,
+ llvm::Intrinsic::nvvm_idp4a_s_u,
+ llvm::Intrinsic::nvvm_idp4a_s_s,
+ };
+ return {ids[type], args};
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index e8425638cc9be..77b302155cb12 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -579,11 +579,11 @@ func.func @st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64)
}
// CHECK-LABEL: @dot_accumulate_4way
-func.func @dot_accumulate_4way(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: vector<4xi8>, %c: i32) {
+func.func @dot_accumulate_4way(%a_vec: vector<4xi8>, %b_vec: vector<4xi8>, %c: i32) {
// CHECK: nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
- %1 = nvvm.dot.accumulate.4way %a_vec <u8>, %b_vec <u8>, %c: vector<4xi8>, vector<4xi8>
+ %1 = nvvm.dot.accumulate.4way %a_vec <unsigned>, %b_vec <unsigned>, %c: vector<4xi8>, vector<4xi8>
// CHECK: nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
- %3 = nvvm.dot.accumulate.4way %a_vec <s8>, %b_vec <s8>, %c: vector<4xi8>, vector<4xi8>
+ %3 = nvvm.dot.accumulate.4way %a_vec <signed>, %b_vec <signed>, %c: vector<4xi8>, vector<4xi8>
return
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index c6def56199f37..e892fc43f4a39 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -851,18 +851,18 @@ llvm.func @nvvm_dot_accumulate_4way(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32)
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
- %0 = nvvm.dot.accumulate.4way %a <u8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
+ %0 = nvvm.dot.accumulate.4way %a <unsigned>, %b <unsigned>, %c: vector<4xi8>, vector<4xi8>
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
- %1 = nvvm.dot.accumulate.4way %a <s8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
+ %1 = nvvm.dot.accumulate.4way %a <signed>, %b <unsigned>, %c: vector<4xi8>, vector<4xi8>
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
- %2 = nvvm.dot.accumulate.4way %a <u8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
+ %2 = nvvm.dot.accumulate.4way %a <unsigned>, %b <signed>, %c: vector<4xi8>, vector<4xi8>
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
- %3 = nvvm.dot.accumulate.4way %a <s8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
+ %3 = nvvm.dot.accumulate.4way %a <signed>, %b <signed>, %c: vector<4xi8>, vector<4xi8>
llvm.return
}
More information about the Mlir-commits
mailing list