[Mlir-commits] [mlir] [MLIR][NVVM] Add support for dp4a instructions (PR #139043)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 8 01:15:04 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Srinivasa Ravi (Wolfram70)

<details>
<summary>Changes</summary>

This change adds the `dp4a` Op to the NVVM dialect to perform the four-way byte dot product-accumulate operation.

PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a

---
Full diff: https://github.com/llvm/llvm-project/pull/139043.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+48) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+8) 
- (modified) mlir/test/Dialect/LLVMIR/nvvm.mlir (+13) 
- (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+36) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 6540273b216e3..85b3e80711018 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3444,6 +3444,54 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// NVVM dp4a Op
+//===----------------------------------------------------------------------===//
+
+def NVVM_Dp4aOp : NVVM_Op<"dp4a"> {
+  let summary = "Four-way byte dot product-accumulate instruction.";
+  let description = [{
+    Performs a four-way byte dot-product which is accumulated in a 32-bit
+    result.
+    Operands `a` and `b` can be passed either as packed 32-bit inputs holding
+    4 byte-inputs for the dot product, or as vectors of 4 i8 elements.
+    The `a_signed` and `b_signed` unit attributes specify whether the
+    individual byte inputs in operands `a` and `b` are signed or unsigned
+    respectively.
+    Operand `c` is a 32-bit integer to which the result is accumulated. It is
+    treated as holding a signed integer if either `a` or `b` is signed.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a)
+  }];
+
+  let arguments = (ins
+    AnyTypeOf<[I32, VectorOfLengthAndType<[4], [I8]>]>:$a,
+    AnyTypeOf<[I32, VectorOfLengthAndType<[4], [I8]>]>:$b,
+    I32:$c,
+    DefaultValuedAttr<UnitAttr, "false">:$a_signed,
+    DefaultValuedAttr<UnitAttr, "false">:$b_signed
+  );
+
+  let results = (outs I32:$res);
+
+  let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($a) `,` type($b)";
+
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID getIntrinsicID(bool a_signed, bool b_signed);
+  }];
+
+  string llvmBuilder = [{
+    auto id = NVVM::Dp4aOp::getIntrinsicID($a_signed, $b_signed);
+    // vector<4xi8> operands are reinterpreted as packed i32 values. A bitcast
+    // to the operand's own type is folded away by IRBuilder, so i32 operands
+    // pass through unchanged and need no separate type check.
+    llvm::Type *i32Ty = builder.getInt32Ty();
+    llvm::Value *argA = builder.CreateBitCast($a, i32Ty);
+    llvm::Value *argB = builder.CreateBitCast($b, i32Ty);
+    $res = createIntrinsicCall(builder, id, {argA, argB, $c});
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM target attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 3c3731a63e268..a4100d7ce3bef 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1590,6 +1590,14 @@ static void nvvmInferResultRanges(Operation *op, Value result,
   }
 }
 
+llvm::Intrinsic::ID Dp4aOp::getIntrinsicID(bool a_signed, bool b_signed) {
+  // Indexed [a_signed][b_signed]; intrinsic suffixes are <a sign>_<b sign>.
+  static constexpr llvm::Intrinsic::ID ids[2][2] = {
+      {llvm::Intrinsic::nvvm_idp4a_u_u, llvm::Intrinsic::nvvm_idp4a_u_s},
+      {llvm::Intrinsic::nvvm_idp4a_s_u, llvm::Intrinsic::nvvm_idp4a_s_s}};
+  return ids[a_signed][b_signed];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index d3915492c38a0..53ef034821611 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -578,6 +578,19 @@ func.func @st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64)
   return
 }
 
+// CHECK-LABEL: @dp4a
+func.func @dp4a(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: vector<4xi8>, %c: i32) {
+  // CHECK:   nvvm.dp4a %{{.*}}, %{{.*}}, %{{.*}} : i32, i32
+  %0 = nvvm.dp4a %a, %b, %c: i32, i32
+  // CHECK:   nvvm.dp4a %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
+  %1 = nvvm.dp4a %a_vec, %b_vec, %c: vector<4xi8>, vector<4xi8>
+  // CHECK:   nvvm.dp4a %{{.*}}, %{{.*}}, %{{.*}} {a_signed, b_signed} : i32, i32
+  %2 = nvvm.dp4a %a, %b, %c {a_signed, b_signed}: i32, i32
+  // CHECK:   nvvm.dp4a %{{.*}}, %{{.*}}, %{{.*}} {a_signed, b_signed} : vector<4xi8>, vector<4xi8>
+  %3 = nvvm.dp4a %a_vec, %b_vec, %c {a_signed, b_signed}: vector<4xi8>, vector<4xi8>
+  return
+}
+
 // -----
 
 // Just check these don't emit errors.
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 3a0713f2feee8..4a116f6db37e5 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -844,3 +844,39 @@ llvm.func @nvvm_st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size:
   nvvm.st.bulk %addr_shared, size = %size, init = 0: !llvm.ptr<3>
   llvm.return
 }
+
+// -----
+// CHECK-LABEL: @nvvm_dp4a_packed
+llvm.func @nvvm_dp4a_packed(%a: i32, %b: i32, %c: i32) {
+  // CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %uu = nvvm.dp4a %a, %b, %c : i32, i32
+  // CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %su = nvvm.dp4a %a, %b, %c {a_signed} : i32, i32
+  // CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %us = nvvm.dp4a %a, %b, %c {b_signed} : i32, i32
+  // CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %ss = nvvm.dp4a %a, %b, %c {a_signed, b_signed} : i32, i32
+  llvm.return
+}
+
+// -----
+// CHECK-LABEL: @nvvm_dp4a_vec
+llvm.func @nvvm_dp4a_vec(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32) {
+  // CHECK: %[[A_CAST:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: %[[B_CAST:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %[[A_CAST]], i32 %[[B_CAST]], i32 %{{.*}})
+  %uu = nvvm.dp4a %a, %b, %c : vector<4xi8>, vector<4xi8>
+  // CHECK: %[[A_CAST:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: %[[B_CAST:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %[[A_CAST]], i32 %[[B_CAST]], i32 %{{.*}})
+  %su = nvvm.dp4a %a, %b, %c {a_signed} : vector<4xi8>, vector<4xi8>
+  // CHECK: %[[A_CAST:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: %[[B_CAST:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %[[A_CAST]], i32 %[[B_CAST]], i32 %{{.*}})
+  %us = nvvm.dp4a %a, %b, %c {b_signed} : vector<4xi8>, vector<4xi8>
+  // CHECK: %[[A_CAST:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: %[[B_CAST:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %[[A_CAST]], i32 %[[B_CAST]], i32 %{{.*}})
+  %ss = nvvm.dp4a %a, %b, %c {a_signed, b_signed} : vector<4xi8>, vector<4xi8>
+  llvm.return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/139043


More information about the Mlir-commits mailing list