[Mlir-commits] [mlir] aca088d - [MLIR][NVVM] Update dot.accumulate.4way NVVM Op (#141223)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 28 22:21:14 PDT 2025


Author: Srinivasa Ravi
Date: 2025-05-29T10:51:11+05:30
New Revision: aca088d802532c5c357c4be6e6fa6e6340d34df2

URL: https://github.com/llvm/llvm-project/commit/aca088d802532c5c357c4be6e6fa6e6340d34df2
DIFF: https://github.com/llvm/llvm-project/commit/aca088d802532c5c357c4be6e6fa6e6340d34df2.diff

LOG: [MLIR][NVVM] Update dot.accumulate.4way NVVM Op (#141223)

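This change refactors and updates the `dot.accumulate.4way` NVVM Op to
be more descriptive and readable: the `s8`/`u8` operand-type values are
renamed to `signed`/`unsigned` (`DotAccumulate4WayType` becomes
`DotAccumulateType`), and the separate `getIntrinsicID`/`getPackedArg`
helpers are replaced by a single `getIntrinsicIDAndArgs` method.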

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Dialect/LLVMIR/nvvm.mlir
    mlir/test/Target/LLVMIR/nvvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 408537be0a5e4..2424e3af80d2d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3640,36 +3640,38 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>]> {
 }
 
 //===----------------------------------------------------------------------===//
-// NVVM dot.accumulate.4way Op
+// NVVM dot.accumulate Ops
 //===----------------------------------------------------------------------===//
 
-def DotAccumulate4WayS8 : I32EnumAttrCase<"S8", 1, "s8">;
-def DotAccumulate4WayU8 : I32EnumAttrCase<"U8", 0, "u8">;
+def DotAccumulateUnsigned : I32EnumAttrCase<"UNSIGNED", 0, "unsigned">;
+def DotAccumulateSigned : I32EnumAttrCase<"SIGNED", 1, "signed">;
 
-def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType",
-                              "NVVM DotAccumulate4WayType",
-                              [DotAccumulate4WayS8, DotAccumulate4WayU8]> {
+def DotAccumulateType : I32EnumAttr<"DotAccumulateType",
+                              "NVVM DotAccumulateType",
+                              [DotAccumulateSigned, DotAccumulateUnsigned]> {
   let cppNamespace = "::mlir::NVVM";
   let genSpecializedAttr = 0;
 }
 
-def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType, "dot_accumulate_4way_type"> {
+def DotAccumulateTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulateType, "dot_accumulate_type"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
 def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
-  let summary = "Four-way byte dot product-accumulate instruction.";
+  let summary = "Four-way byte dot product-accumulate instruction";
   let description = [{
     Performs a four-way byte dot-product which is accumulated in a 32-bit
     result.
     Operand `a` and `b` are vectors of 4 bytes between which the dot product is 
     computed.
+
     The `a_type` and `b_type` attributes specify the type of the elements in `a`
     and `b` respectively.
-    If `a_type` or `b_type` is `s8`, then the elements in the corresponding 
+    If `a_type` or `b_type` is `signed`, then the elements in the corresponding 
     vector are sign-extended to 32-bit before the dot product is computed.
-    If `a_type` or `b_type` is `u8`, then the elements in the corresponding 
-    vector are zero-extended to 32-bit instead.
+    If `a_type` or `b_type` is `unsigned`, then the elements in the 
+    corresponding vector are zero-extended to 32-bit instead.
+
     Operand `c` is a 32-bit integer to which the result is accumulated. It is
     treated as holding a signed integer if any of `a_type` or `b_type` is `s8`.
     
@@ -3678,9 +3680,9 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
   
   let arguments = (ins
     VectorOfLengthAndType<[4], [I8]>:$a,
-    DotAccumulate4WayTypeAttr:$a_type,
+    DotAccumulateTypeAttr:$a_type,
     VectorOfLengthAndType<[4], [I8]>:$b,
-    DotAccumulate4WayTypeAttr:$b_type,
+    DotAccumulateTypeAttr:$b_type,
     I32:$c
   );
 
@@ -3689,17 +3691,15 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
   let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
   
   let extraClassDeclaration = [{
-    static llvm::Intrinsic::ID
-    getIntrinsicID(NVVM::DotAccumulate4WayType a_type, 
-                   NVVM::DotAccumulate4WayType b_type);
-    llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                          llvm::IRBuilderBase &builder);
   }];
 
   string llvmBuilder = [{
-    llvm::Intrinsic::ID id = NVVM::DotAccumulate4WayOp::getIntrinsicID($a_type, $b_type);
-    llvm::Value* argA = op.getPackedArg($a, builder);
-    llvm::Value* argB = op.getPackedArg($b, builder);
-    $res = createIntrinsicCall(builder, id, {argA, argB, $c});
+    auto [id, args] = NVVM::DotAccumulate4WayOp::getIntrinsicIDAndArgs(
+                        *op, moduleTranslation, builder);
+    $res = createIntrinsicCall(builder, id, args);
   }];
 }
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 8036ea27f524f..648b6b087e592 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1205,13 +1205,6 @@ LogicalResult NVVM::VoteSyncOp::verify() {
   return success();
 }
 
-llvm::Value *
-NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
-                                        llvm::IRBuilderBase &builder) {
-  return builder.CreateBitCast(arg,
-                               llvm::Type::getInt32Ty(builder.getContext()));
-}
-
 /// Packs the given `field` into the `result`.
 /// The `result` is 64-bits and each `field` can be 32-bits or narrower.
 static llvm::Value *
@@ -1692,24 +1685,36 @@ static void nvvmInferResultRanges(Operation *op, Value result,
   }
 }
 
-llvm::Intrinsic::ID
-DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
-                                    NVVM::DotAccumulate4WayType b_type) {
-  bool is_a_siext = a_type == NVVM::DotAccumulate4WayType::S8;
-  bool is_b_siext = b_type == NVVM::DotAccumulate4WayType::S8;
-  unsigned type = (is_a_siext << 1) | is_b_siext;
-  switch (type) {
-  case 0:
-    return llvm::Intrinsic::nvvm_idp4a_u_u;
-  case 1:
-    return llvm::Intrinsic::nvvm_idp4a_u_s;
-  case 2:
-    return llvm::Intrinsic::nvvm_idp4a_s_u;
-  case 3:
-    return llvm::Intrinsic::nvvm_idp4a_s_s;
-  default:
-    llvm_unreachable("Invalid DP4a type");
-  }
+/// Bitcasts `arg` to i32; used to pack the <4 x i8> `a` and `b` operands.
+static llvm::Value *getAsPackedI32(llvm::Value *arg,
+                                   llvm::IRBuilderBase &builder) {
+  return builder.CreateBitCast(arg,
+                               llvm::Type::getInt32Ty(builder.getContext()));
+}
+
+/// Returns the `llvm.nvvm.idp4a.{u,s}.{u,s}` intrinsic selected by the op's
+/// `a_type`/`b_type` attributes, along with its operands: `a` and `b`
+/// bitcast to i32, followed by the accumulator `c`.
+NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
+
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
+  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
+  args.push_back(mt.lookupValue(curOp.getC()));
+
+  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
+  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
+  // Index into `ids`: bit 1 is set when `a` is signed, bit 0 when `b` is.
+  unsigned type = (isASigned << 1) | isBSigned;
+  const llvm::Intrinsic::ID ids[] = {
+      llvm::Intrinsic::nvvm_idp4a_u_u,
+      llvm::Intrinsic::nvvm_idp4a_u_s,
+      llvm::Intrinsic::nvvm_idp4a_s_u,
+      llvm::Intrinsic::nvvm_idp4a_s_s,
+  };
+  return {ids[type], args};
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index e8425638cc9be..77b302155cb12 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -579,11 +579,11 @@ func.func @st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64)
 }
 
 // CHECK-LABEL: @dot_accumulate_4way
-func.func @dot_accumulate_4way(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: vector<4xi8>, %c: i32) {
+func.func @dot_accumulate_4way(%a_vec: vector<4xi8>, %b_vec: vector<4xi8>, %c: i32) {
   // CHECK:   nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
-  %1 = nvvm.dot.accumulate.4way %a_vec <u8>, %b_vec <u8>, %c: vector<4xi8>, vector<4xi8>
+  %1 = nvvm.dot.accumulate.4way %a_vec <unsigned>, %b_vec <unsigned>, %c: vector<4xi8>, vector<4xi8>
   // CHECK:   nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
-  %3 = nvvm.dot.accumulate.4way %a_vec <s8>, %b_vec <s8>, %c: vector<4xi8>, vector<4xi8>
+  %3 = nvvm.dot.accumulate.4way %a_vec <signed>, %b_vec <signed>, %c: vector<4xi8>, vector<4xi8>
   return
 }
 

diff  --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index c6def56199f37..e892fc43f4a39 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -851,18 +851,18 @@ llvm.func @nvvm_dot_accumulate_4way(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32)
   // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
-  %0 = nvvm.dot.accumulate.4way %a <u8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
+  %0 = nvvm.dot.accumulate.4way %a <unsigned>, %b <unsigned>, %c: vector<4xi8>, vector<4xi8>
   // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
-  %1 = nvvm.dot.accumulate.4way %a <s8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
+  %1 = nvvm.dot.accumulate.4way %a <signed>, %b <unsigned>, %c: vector<4xi8>, vector<4xi8>
   // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
-  %2 = nvvm.dot.accumulate.4way %a <u8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
+  %2 = nvvm.dot.accumulate.4way %a <unsigned>, %b <signed>, %c: vector<4xi8>, vector<4xi8>
   // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
-  %3 = nvvm.dot.accumulate.4way %a <s8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
+  %3 = nvvm.dot.accumulate.4way %a <signed>, %b <signed>, %c: vector<4xi8>, vector<4xi8>
   llvm.return
 }


        


More information about the Mlir-commits mailing list