[Mlir-commits] [mlir] [MLIR][NVVM] Add `dot.accumulate.4way` OP (PR #139043)

Srinivasa Ravi llvmlistbot at llvm.org
Sun May 11 21:37:12 PDT 2025


https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/139043

>From c8914d11ba4870a8195c9b8c39323e685836619a Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 30 Apr 2025 15:31:40 +0530
Subject: [PATCH] [MLIR][NVVM] Add support for dp4a instructions

This change adds the `dot.accumulate.4way` Op to the NVVM dialect to
perform a four-way byte dot product-accumulate operation (PTX `dp4a`).

For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 64 +++++++++++++++++++++
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp  | 28 +++++++++
 mlir/test/Dialect/LLVMIR/nvvm.mlir          |  9 +++
 mlir/test/Target/LLVMIR/nvvmir.mlir         | 22 +++++++
 4 files changed, 123 insertions(+)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 6540273b216e3..654aff71f25be 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3444,6 +3444,70 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// NVVM dot.accumulate.4way Op
+//===----------------------------------------------------------------------===//
+
+def DotAccumulate4WayS8 : I32EnumAttrCase<"S8", 1, "s8">;
+def DotAccumulate4WayU8 : I32EnumAttrCase<"U8", 0, "u8">;
+
+def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType",
+                              "NVVM DotAccumulate4WayType",
+                              [DotAccumulate4WayS8, DotAccumulate4WayU8]> {
+  let cppNamespace = "::mlir::NVVM";
+  let genSpecializedAttr = 0;
+}
+
+def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType, "dot_accumulate_4way_type"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
+  let summary = "Four-way byte dot product-accumulate instruction.";
+  let description = [{
+    Performs a four-way byte dot product which is accumulated in a 32-bit
+    result.
+    Operands `a` and `b` are vectors of 4 bytes between which the dot product
+    is computed.
+    The `a_type` and `b_type` attributes specify the element type of `a` and
+    `b`, respectively.
+    If an operand's type attribute is `s8`, the elements of that vector are
+    sign-extended to 32 bits before the dot product is computed.
+    If an operand's type attribute is `u8`, the elements of that vector are
+    zero-extended to 32 bits instead.
+    Operand `c` is a 32-bit integer to which the result is accumulated. It is
+    treated as holding a signed integer if either `a_type` or `b_type` is `s8`.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a)
+  }];
+
+  let arguments = (ins
+    VectorOfLengthAndType<[4], [I8]>:$a,
+    DotAccumulate4WayTypeAttr:$a_type,
+    VectorOfLengthAndType<[4], [I8]>:$b,
+    DotAccumulate4WayTypeAttr:$b_type,
+    I32:$c
+  );
+
+  let results = (outs I32:$res);
+
+  let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
+
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID
+    getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
+                   NVVM::DotAccumulate4WayType b_type);
+    static llvm::Value *getPackedArg(llvm::Value *arg, llvm::IRBuilderBase &builder);
+  }];
+
+  string llvmBuilder = [{
+    llvm::Intrinsic::ID id = NVVM::DotAccumulate4WayOp::getIntrinsicID($a_type, $b_type);
+    llvm::Value *argA = NVVM::DotAccumulate4WayOp::getPackedArg($a, builder);
+    llvm::Value *argB = NVVM::DotAccumulate4WayOp::getPackedArg($b, builder);
+    $res = createIntrinsicCall(builder, id, {argA, argB, $c});
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM target attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 3c3731a63e268..1ea3f96fa75f5 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -33,6 +33,7 @@
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/Type.h"
 #include "llvm/Support/Casting.h"
@@ -1203,6 +1204,13 @@ LogicalResult NVVM::VoteSyncOp::verify() {
   return success();
 }
 
+// Reinterprets the <4 x i8> operand `arg` as one i32, the packed form in
+// which the llvm.nvvm.idp4a.* intrinsics take their byte operands.
+llvm::Value *NVVM::DotAccumulate4WayOp::getPackedArg(
+    llvm::Value *arg, llvm::IRBuilderBase &builder) {
+  return builder.CreateBitCast(arg, builder.getInt32Ty());
+}
+
 //===----------------------------------------------------------------------===//
 // getIntrinsicID/getIntrinsicIDAndArgs methods
 //===----------------------------------------------------------------------===//
@@ -1590,6 +1598,26 @@ static void nvvmInferResultRanges(Operation *op, Value result,
   }
 }
 
+/// Returns the idp4a intrinsic that lowers `nvvm.dot.accumulate.4way` for
+/// the given element types of operands `a` and `b`.
+///
+/// Bytes of an operand are sign-extended when its type is `s8` and
+/// zero-extended when it is `u8`. Since the two operands choose their
+/// extension independently, the four type combinations map one-to-one onto
+/// `llvm.nvvm.idp4a.<a>.<b>`, where the first suffix gives the signedness of
+/// `a` and the second the signedness of `b`.
+llvm::Intrinsic::ID
+NVVM::DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
+                                          NVVM::DotAccumulate4WayType b_type) {
+  const bool aSigned = a_type == NVVM::DotAccumulate4WayType::S8;
+  const bool bSigned = b_type == NVVM::DotAccumulate4WayType::S8;
+  if (aSigned)
+    return bSigned ? llvm::Intrinsic::nvvm_idp4a_s_s
+                   : llvm::Intrinsic::nvvm_idp4a_s_u;
+  return bSigned ? llvm::Intrinsic::nvvm_idp4a_u_s
+                 : llvm::Intrinsic::nvvm_idp4a_u_u;
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index d3915492c38a0..e8425638cc9be 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -578,6 +578,15 @@ func.func @st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64)
   return
 }
 
+// CHECK-LABEL: @dot_accumulate_4way
+func.func @dot_accumulate_4way(%a_vec: vector<4xi8>, %b_vec: vector<4xi8>, %c: i32) {
+  // CHECK:   nvvm.dot.accumulate.4way %{{.*}}<u8>, %{{.*}}<u8>, %{{.*}} : vector<4xi8>, vector<4xi8>
+  %0 = nvvm.dot.accumulate.4way %a_vec <u8>, %b_vec <u8>, %c: vector<4xi8>, vector<4xi8>
+  // CHECK:   nvvm.dot.accumulate.4way %{{.*}}<s8>, %{{.*}}<u8>, %{{.*}} : vector<4xi8>, vector<4xi8>
+  %1 = nvvm.dot.accumulate.4way %a_vec <s8>, %b_vec <u8>, %c: vector<4xi8>, vector<4xi8>
+  return
+}
+
 // -----
 
 // Just check these don't emit errors.
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 3a0713f2feee8..894b72733a46a 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -844,3 +844,25 @@ llvm.func @nvvm_st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size:
   nvvm.st.bulk %addr_shared, size = %size, init = 0: !llvm.ptr<3>
   llvm.return
 }
+
+// -----
+// CHECK-LABEL: @nvvm_dot_accumulate_4way
+llvm.func @nvvm_dot_accumulate_4way(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32) {
+  // CHECK: %[[A_UU:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: %[[B_UU:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %[[A_UU]], i32 %[[B_UU]], i32 %{{.*}})
+  %0 = nvvm.dot.accumulate.4way %a <u8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
+  // CHECK: %[[A_SU:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: %[[B_SU:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %[[A_SU]], i32 %[[B_SU]], i32 %{{.*}})
+  %1 = nvvm.dot.accumulate.4way %a <s8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
+  // CHECK: %[[A_US:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: %[[B_US:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %[[A_US]], i32 %[[B_US]], i32 %{{.*}})
+  %2 = nvvm.dot.accumulate.4way %a <u8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
+  // CHECK: %[[A_SS:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: %[[B_SS:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %[[A_SS]], i32 %[[B_SS]], i32 %{{.*}})
+  %3 = nvvm.dot.accumulate.4way %a <s8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
+  llvm.return
+}



More information about the Mlir-commits mailing list