[Mlir-commits] [mlir] [MLIR][NVVM] Add support for dp4a instructions (PR #139043)
Srinivasa Ravi
llvmlistbot at llvm.org
Thu May 8 03:19:23 PDT 2025
https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/139043
>From 2b259fc10478635c223c86b6e24bbb7bbc6ba6b9 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 30 Apr 2025 15:31:40 +0530
Subject: [PATCH] [MLIR][NVVM] Add support for dp4a instructions
This change adds the `dp4a` Op to the NVVM dialect to perform a four-way
byte dot-product-accumulate operation.
For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 47 +++++++++++++++++++++
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 15 +++++++
mlir/test/Dialect/LLVMIR/nvvm.mlir | 9 ++++
mlir/test/Target/LLVMIR/nvvmir.mlir | 22 ++++++++++
4 files changed, 93 insertions(+)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 6540273b216e3..1329a0035d178 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3444,6 +3444,53 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// NVVM dp4a Op
+//===----------------------------------------------------------------------===//
+
+def NVVM_Dp4aOp : NVVM_Op<"dp4a"> {
+  let summary = "Four-way byte dot product-accumulate instruction.";
+  let description = [{
+    Performs a four-way byte dot product whose result is accumulated into a
+    32-bit integer.
+
+    Operands `a` and `b` are vectors of four bytes between which the dot
+    product is computed. When lowered to LLVM IR, each vector is bitcast to a
+    single `i32` before being passed to the `llvm.nvvm.idp4a.*` intrinsic.
+
+    By default, each byte is zero-extended to 32 bits before the dot product
+    is computed. The `a_siext` and `b_siext` unit attributes indicate that the
+    bytes of the corresponding operand are signed and must be sign-extended
+    instead.
+
+    Operand `c` is the 32-bit integer into which the dot product is
+    accumulated. It is treated as a signed integer if either `a` or `b` is
+    sign-extended.
+
+    Example:
+
+    ```mlir
+    %r = nvvm.dp4a %a, %b, %c {a_siext} : vector<4xi8>, vector<4xi8>
+    ```
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a)
+  }];
+
+  let arguments = (ins
+    VectorOfLengthAndType<[4], [I8]>:$a,
+    VectorOfLengthAndType<[4], [I8]>:$b,
+    I32:$c,
+    DefaultValuedAttr<UnitAttr, "false">:$a_siext,
+    DefaultValuedAttr<UnitAttr, "false">:$b_siext
+  );
+
+  let results = (outs I32:$res);
+
+  let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($a) `,` type($b)";
+
+  let extraClassDeclaration = [{
+    /// Returns the `llvm.nvvm.idp4a.*` intrinsic matching the requested
+    /// signedness (sign- vs. zero-extension) of the `a` and `b` operands.
+    static llvm::Intrinsic::ID getIntrinsicID(bool a_siext, bool b_siext);
+    /// Reinterprets a `<4 x i8>` operand as the packed `i32` that the
+    /// intrinsic expects.
+    llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
+  }];
+
+  // Pack both byte vectors into i32 values, then call the intrinsic selected
+  // by the two signedness attributes with the accumulator as third operand.
+  string llvmBuilder = [{
+    llvm::Intrinsic::ID id = NVVM::Dp4aOp::getIntrinsicID($a_siext, $b_siext);
+    llvm::Value* argA = op.getPackedArg($a, builder);
+    llvm::Value* argB = op.getPackedArg($b, builder);
+    $res = createIntrinsicCall(builder, id, {argA, argB, $c});
+  }];
+}
+
//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 3c3731a63e268..96b954b0576d7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -34,6 +34,7 @@
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
@@ -1203,6 +1204,12 @@ LogicalResult NVVM::VoteSyncOp::verify() {
return success();
}
+/// Reinterprets a four-byte vector operand as a single 32-bit integer.
+/// The idp4a intrinsics take each byte operand packed into one i32, so only
+/// the bits are reinterpreted; no per-element conversion takes place.
+llvm::Value *NVVM::Dp4aOp::getPackedArg(llvm::Value *arg,
+                                        llvm::IRBuilderBase &builder) {
+  llvm::Type *packedTy = builder.getInt32Ty();
+  return builder.CreateBitCast(arg, packedTy);
+}
+
//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
@@ -1590,6 +1597,14 @@ static void nvvmInferResultRanges(Operation *op, Value result,
}
}
+/// Selects the `llvm.nvvm.idp4a.<a>.<b>` intrinsic for the given operand
+/// signedness. In the intrinsic name, `s` means the bytes of that operand are
+/// sign-extended and `u` means they are zero-extended.
+///
+/// Written without a helper macro: the previous function-like macro expanded
+/// to an unparenthesized conditional expression and was never #undef'd, so it
+/// leaked into the rest of the translation unit.
+llvm::Intrinsic::ID Dp4aOp::getIntrinsicID(bool a_siext, bool b_siext) {
+  if (a_siext)
+    return b_siext ? llvm::Intrinsic::nvvm_idp4a_s_s
+                   : llvm::Intrinsic::nvvm_idp4a_s_u;
+  return b_siext ? llvm::Intrinsic::nvvm_idp4a_u_s
+                 : llvm::Intrinsic::nvvm_idp4a_u_u;
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index d3915492c38a0..d38db8a085e8b 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -578,6 +578,15 @@ func.func @st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64)
return
}
+// Round-trips every combination of the signedness unit attributes, so a mix-up
+// between `a_siext` and `b_siext` in parsing or printing is caught.
+// CHECK-LABEL: @dp4a
+func.func @dp4a(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: vector<4xi8>, %c: i32) {
+  // CHECK: nvvm.dp4a %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
+  %0 = nvvm.dp4a %a_vec, %b_vec, %c : vector<4xi8>, vector<4xi8>
+  // CHECK: nvvm.dp4a %{{.*}}, %{{.*}}, %{{.*}} {a_siext} : vector<4xi8>, vector<4xi8>
+  %1 = nvvm.dp4a %a_vec, %b_vec, %c {a_siext} : vector<4xi8>, vector<4xi8>
+  // CHECK: nvvm.dp4a %{{.*}}, %{{.*}}, %{{.*}} {b_siext} : vector<4xi8>, vector<4xi8>
+  %2 = nvvm.dp4a %a_vec, %b_vec, %c {b_siext} : vector<4xi8>, vector<4xi8>
+  // CHECK: nvvm.dp4a %{{.*}}, %{{.*}}, %{{.*}} {a_siext, b_siext} : vector<4xi8>, vector<4xi8>
+  %3 = nvvm.dp4a %a_vec, %b_vec, %c {a_siext, b_siext} : vector<4xi8>, vector<4xi8>
+  return
+}
+
// -----
// Just check these don't emit errors.
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 3a0713f2feee8..6325f7e3aa14b 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -844,3 +844,25 @@ llvm.func @nvvm_st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size:
nvvm.st.bulk %addr_shared, size = %size, init = 0: !llvm.ptr<3>
llvm.return
}
+
+// -----
+// CHECK-LABEL: @nvvm_dp4a
+// CHECK-SAME: (<4 x i8> %[[A:[0-9]+]], <4 x i8> %[[B:[0-9]+]], i32 %[[C:[0-9]+]])
+llvm.func @nvvm_dp4a(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32) {
+  // Each op packs its own operands, so every call is preceded by a fresh pair
+  // of bitcasts. Tie them to the function arguments so that swapping `a` and
+  // `b`, or passing the wrong accumulator, makes the test fail.
+  // CHECK: %[[A_CAST:.*]] = bitcast <4 x i8> %[[A]] to i32
+  // CHECK: %[[B_CAST:.*]] = bitcast <4 x i8> %[[B]] to i32
+  // CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %[[A_CAST]], i32 %[[B_CAST]], i32 %[[C]])
+  %0 = nvvm.dp4a %a, %b, %c : vector<4xi8>, vector<4xi8>
+  // CHECK: %[[A_CAST:.*]] = bitcast <4 x i8> %[[A]] to i32
+  // CHECK: %[[B_CAST:.*]] = bitcast <4 x i8> %[[B]] to i32
+  // CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %[[A_CAST]], i32 %[[B_CAST]], i32 %[[C]])
+  %1 = nvvm.dp4a %a, %b, %c {a_siext} : vector<4xi8>, vector<4xi8>
+  // CHECK: %[[A_CAST:.*]] = bitcast <4 x i8> %[[A]] to i32
+  // CHECK: %[[B_CAST:.*]] = bitcast <4 x i8> %[[B]] to i32
+  // CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %[[A_CAST]], i32 %[[B_CAST]], i32 %[[C]])
+  %2 = nvvm.dp4a %a, %b, %c {b_siext} : vector<4xi8>, vector<4xi8>
+  // CHECK: %[[A_CAST:.*]] = bitcast <4 x i8> %[[A]] to i32
+  // CHECK: %[[B_CAST:.*]] = bitcast <4 x i8> %[[B]] to i32
+  // CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %[[A_CAST]], i32 %[[B_CAST]], i32 %[[C]])
+  %3 = nvvm.dp4a %a, %b, %c {a_siext, b_siext} : vector<4xi8>, vector<4xi8>
+  llvm.return
+}
More information about the Mlir-commits
mailing list