[Mlir-commits] [mlir] [MLIR][NVVM] Update dot.accumulate NVVM Ops (PR #140518)

Srinivasa Ravi llvmlistbot at llvm.org
Mon May 19 02:30:20 PDT 2025


https://github.com/Wolfram70 created https://github.com/llvm/llvm-project/pull/140518

This change:
- Adds the dot.accumulate.2way Op to the NVVM dialect for the two-way 16-bit by 8-bit dot-product-accumulate operation (PTX `dp2a`).
- Refactors the recently added dot.accumulate.4way and adds a verifier.

>From 620724cf3b4d036ccb92946bbef52c4fe5ee3b27 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 15 May 2025 16:41:08 +0530
Subject: [PATCH] [MLIR][NVVM] Update dot.accumulate NVVM Ops

This change:
- Adds the dot.accumulate.2way Op to the NVVM dialect for the two-way
  16-bit by 8-bit dot-product-accumulate operation (PTX `dp2a`).
- Refactors the recently added dot.accumulate.4way and adds a verifier.
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 107 +++++++++++++++++---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp  |  68 ++++++++++++-
 mlir/test/Dialect/LLVMIR/nvvm.mlir          |  11 +-
 mlir/test/Target/LLVMIR/nvvmir-invalid.mlir |  32 ++++++
 mlir/test/Target/LLVMIR/nvvmir.mlir         |  38 +++++++
 5 files changed, 239 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 654aff71f25be..634251d6a9de1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3445,25 +3445,28 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
 }
 
 //===----------------------------------------------------------------------===//
-// NVVM dot.accumulate.4way Op
+// NVVM dot.accumulate Ops
 //===----------------------------------------------------------------------===//
 
-def DotAccumulate4WayS8 : I32EnumAttrCase<"S8", 1, "s8">;
-def DotAccumulate4WayU8 : I32EnumAttrCase<"U8", 0, "u8">;
+def DotAccumulateS8 : I32EnumAttrCase<"S8", 1, "s8">;
+def DotAccumulateU8 : I32EnumAttrCase<"U8", 0, "u8">;
+def DotAccumulateS16 : I32EnumAttrCase<"S16", 2, "s16">;
+def DotAccumulateU16 : I32EnumAttrCase<"U16", 3, "u16">;
 
-def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType",
-                              "NVVM DotAccumulate4WayType",
-                              [DotAccumulate4WayS8, DotAccumulate4WayU8]> {
+def DotAccumulateType : I32EnumAttr<"DotAccumulateType",
+                              "NVVM DotAccumulateType",
+                              [DotAccumulateS8, DotAccumulateU8, 
+                                DotAccumulateS16, DotAccumulateU16]> {
   let cppNamespace = "::mlir::NVVM";
   let genSpecializedAttr = 0;
 }
 
-def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType, "dot_accumulate_4way_type"> {
+def DotAccumulateTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulateType, "dot_accumulate_type"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
 def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
-  let summary = "Four-way byte dot product-accumulate instruction.";
+  let summary = "Four-way byte dot product-accumulate instruction";
   let description = [{
     Performs a four-way byte dot-product which is accumulated in a 32-bit
     result.
@@ -3481,11 +3484,13 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a)
   }];
   
+  let hasVerifier = 1;
+  
   let arguments = (ins
     VectorOfLengthAndType<[4], [I8]>:$a,
-    DotAccumulate4WayTypeAttr:$a_type,
+    DotAccumulateTypeAttr:$a_type,
     VectorOfLengthAndType<[4], [I8]>:$b,
-    DotAccumulate4WayTypeAttr:$b_type,
+    DotAccumulateTypeAttr:$b_type,
     I32:$c
   );
 
@@ -3495,8 +3500,8 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
   
   let extraClassDeclaration = [{
     static llvm::Intrinsic::ID
-    getIntrinsicID(NVVM::DotAccumulate4WayType a_type, 
-                   NVVM::DotAccumulate4WayType b_type);
+    getIntrinsicID(NVVM::DotAccumulateType a_type, 
+                   NVVM::DotAccumulateType b_type);
     llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
   }];
 
@@ -3508,6 +3513,84 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
   }];
 }
 
+def DotAccumulate2WayModeLo : I32EnumAttrCase<"LO", 0, "lo">;
+def DotAccumulate2WayModeHi : I32EnumAttrCase<"HI", 1, "hi">;
+
+def DotAccumulate2WayMode : I32EnumAttr<"DotAccumulate2WayMode",
+                              "NVVM DotAccumulate2WayMode",
+                              [DotAccumulate2WayModeLo, DotAccumulate2WayModeHi]> {
+  let cppNamespace = "::mlir::NVVM";
+  let genSpecializedAttr = 0;
+}
+
+def DotAccumulate2WayModeAttr : EnumAttr<NVVM_Dialect, DotAccumulate2WayMode, "dot_accumulate_2way_mode"> {
+  let assemblyFormat = "$value";
+}
+
+def NVVM_DotAccumulate2WayOp : NVVM_Op<"dot.accumulate.2way"> {
+  let summary = "Two-way 16-bit to 8-bit dot product-accumulate instruction";
+  let description = [{
+    Performs a two-way 16-bit to 8-bit dot-product which is accumulated in a 
+    32-bit result.
+    Operand `a` is a vector of two 16-bit elements and operand `b` a vector 
+    of four 8-bit elements between which the dot product is computed.
+
+    The `a_type` and `b_type` attributes specify the type of the elements in `a`
+    and `b` respectively.
+    If `a_type` is `s16`, then the elements in `a` are sign-extended to 
+    32-bit before the dot product is computed.
+    If `a_type` is `u16`, then the elements in `a` are zero-extended to 
+    32-bit instead.
+    If `b_type` is `s8`, then the elements in `b` are sign-extended to 
+    32-bit before the dot product is computed.
+    If `b_type` is `u8`, then the elements in `b` are zero-extended to 
+    32-bit instead.
+
+    The `mode` attribute specifies which two bytes of `b` are used for the dot
+    product. If `mode` is `lo`, then the dot product is computed between `a`
+    and elements at indices 0 and 1 of `b`. If `mode` is `hi`, then the dot
+    product is computed between `a` and elements at indices 2 and 3 of `b`.
+    Operand `c` is a 32-bit integer to which the result is accumulated. It is
+    treated as holding a signed integer if any of `a_type` or `b_type` is 
+    signed.
+    
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp2a)
+  }];
+
+  let hasVerifier = 1;
+
+  let arguments = (ins
+    DotAccumulate2WayModeAttr:$mode,
+    VectorOfLengthAndType<[2], [I16]>:$a,
+    DotAccumulateTypeAttr:$a_type,
+    VectorOfLengthAndType<[4], [I8]>:$b,
+    DotAccumulateTypeAttr:$b_type,
+    I32:$c
+  );
+
+  let results = (outs I32:$res);
+
+  let assemblyFormat = "$mode $a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
+  
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID
+    getIntrinsicID(NVVM::DotAccumulateType a_type, 
+                   NVVM::DotAccumulateType b_type);
+    llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
+    llvm::Value* isHi(NVVM::DotAccumulate2WayMode mode, 
+                            llvm::IRBuilderBase& builder);
+  }];
+  
+  string llvmBuilder = [{
+    llvm::Intrinsic::ID id = NVVM::DotAccumulate2WayOp::getIntrinsicID($a_type, $b_type);
+    llvm::Value* argA = op.getPackedArg($a, builder);
+    llvm::Value* argB = op.getPackedArg($b, builder);
+    $res = createIntrinsicCall(builder, id,
+              {argA, argB, op.isHi($mode, builder), $c}
+            );
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM target attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 1ea3f96fa75f5..2b60a34edf313 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1211,6 +1211,46 @@ NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
                                llvm::Type::getInt32Ty(builder.getContext()));
 }
 
+LogicalResult NVVM::DotAccumulate4WayOp::verify() {
+  NVVM::DotAccumulateType aType = getAType();
+  NVVM::DotAccumulateType bType = getBType();
+
+  if (aType != NVVM::DotAccumulateType::S8 &&
+      aType != NVVM::DotAccumulateType::U8)
+    return emitOpError("a_type must be S8 or U8");
+  if (bType != NVVM::DotAccumulateType::S8 &&
+      bType != NVVM::DotAccumulateType::U8)
+    return emitOpError("b_type must be S8 or U8");
+
+  return success();
+}
+
+llvm::Value *
+NVVM::DotAccumulate2WayOp::getPackedArg(llvm::Value *arg,
+                                        llvm::IRBuilderBase &builder) {
+  return builder.CreateBitCast(arg,
+                               llvm::Type::getInt32Ty(builder.getContext()));
+}
+
+llvm::Value *NVVM::DotAccumulate2WayOp::isHi(NVVM::DotAccumulate2WayMode mode,
+                                             llvm::IRBuilderBase &builder) {
+  return builder.getInt1(mode == NVVM::DotAccumulate2WayMode::HI);
+}
+
+LogicalResult NVVM::DotAccumulate2WayOp::verify() {
+  NVVM::DotAccumulateType aType = getAType();
+  NVVM::DotAccumulateType bType = getBType();
+
+  if (aType != NVVM::DotAccumulateType::S16 &&
+      aType != NVVM::DotAccumulateType::U16)
+    return emitOpError("a_type must be S16 or U16");
+  if (bType != NVVM::DotAccumulateType::S8 &&
+      bType != NVVM::DotAccumulateType::U8)
+    return emitOpError("b_type must be S8 or U8");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // getIntrinsicID/getIntrinsicIDAndArgs methods
 //===----------------------------------------------------------------------===//
@@ -1599,10 +1639,10 @@ static void nvvmInferResultRanges(Operation *op, Value result,
 }
 
 llvm::Intrinsic::ID
-DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
-                                    NVVM::DotAccumulate4WayType b_type) {
-  bool is_a_siext = a_type == NVVM::DotAccumulate4WayType::S8;
-  bool is_b_siext = b_type == NVVM::DotAccumulate4WayType::S8;
+DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulateType a_type,
+                                    NVVM::DotAccumulateType b_type) {
+  bool is_a_siext = a_type == NVVM::DotAccumulateType::S8;
+  bool is_b_siext = b_type == NVVM::DotAccumulateType::S8;
   unsigned type = (is_a_siext << 1) | is_b_siext;
   switch (type) {
   case 0:
@@ -1618,6 +1658,26 @@ DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
   }
 }
 
+llvm::Intrinsic::ID
+DotAccumulate2WayOp::getIntrinsicID(NVVM::DotAccumulateType a_type,
+                                    NVVM::DotAccumulateType b_type) {
+  bool is_a_siext = a_type == NVVM::DotAccumulateType::S16;
+  bool is_b_siext = b_type == NVVM::DotAccumulateType::S8;
+  unsigned type = (is_a_siext << 1) | is_b_siext;
+  switch (type) {
+  case 0:
+    return llvm::Intrinsic::nvvm_idp2a_u_u;
+  case 1:
+    return llvm::Intrinsic::nvvm_idp2a_u_s;
+  case 2:
+    return llvm::Intrinsic::nvvm_idp2a_s_u;
+  case 3:
+    return llvm::Intrinsic::nvvm_idp2a_s_s;
+  default:
+    llvm_unreachable("Invalid DP2a type");
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index e8425638cc9be..5568e104afcab 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -579,7 +579,7 @@ func.func @st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64)
 }
 
 // CHECK-LABEL: @dot_accumulate_4way
-func.func @dot_accumulate_4way(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: vector<4xi8>, %c: i32) {
+func.func @dot_accumulate_4way(%a_vec: vector<4xi8>, %b_vec: vector<4xi8>, %c: i32) {
   // CHECK:   nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
   %1 = nvvm.dot.accumulate.4way %a_vec <u8>, %b_vec <u8>, %c: vector<4xi8>, vector<4xi8>
   // CHECK:   nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
@@ -587,6 +587,15 @@ func.func @dot_accumulate_4way(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: v
   return
 }
 
+// CHECK-LABEL: @dot_accumulate_2way
+func.func @dot_accumulate_2way(%a_vec: vector<2xi16>, %b_vec: vector<4xi8>, %c: i32) {
+  // CHECK:   nvvm.dot.accumulate.2way lo %{{.*}}, %{{.*}}, %{{.*}} : vector<2xi16>, vector<4xi8>
+  %1 = nvvm.dot.accumulate.2way lo %a_vec <u16>, %b_vec <u8>, %c: vector<2xi16>, vector<4xi8>
+  // CHECK:   nvvm.dot.accumulate.2way hi %{{.*}}, %{{.*}}, %{{.*}} : vector<2xi16>, vector<4xi8>
+  %3 = nvvm.dot.accumulate.2way hi %a_vec <s16>, %b_vec <s8>, %c: vector<2xi16>, vector<4xi8>
+  return
+}
+
 // -----
 
 // Just check these don't emit errors.
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index accec9c7af4f2..e350d5256b5a6 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -248,3 +248,35 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
   %res = nvvm.cvt.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> i16
   llvm.return
 }
+
+// -----
+
+llvm.func @nvvm_dot_accumulate_4way_invalid_type_a(%a_vec : vector<4xi8>, %b_vec : vector<4xi8>, %c : i32) {
+  // expected-error @below {{a_type must be S8 or U8}}
+  %res = nvvm.dot.accumulate.4way %a_vec <u16>, %b_vec <u8>, %c: vector<4xi8>, vector<4xi8>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_dot_accumulate_4way_invalid_type_b(%a_vec : vector<4xi8>, %b_vec : vector<4xi8>, %c : i32) {
+  // expected-error @below {{b_type must be S8 or U8}}
+  %res = nvvm.dot.accumulate.4way %a_vec <u8>, %b_vec <u16>, %c: vector<4xi8>, vector<4xi8>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_dot_accumulate_2way_invalid_type_a(%a_vec : vector<2xi16>, %b_vec : vector<4xi8>, %c : i32) {
+  // expected-error @below {{a_type must be S16 or U16}}
+  %res = nvvm.dot.accumulate.2way lo %a_vec <u8>, %b_vec <u8>, %c: vector<2xi16>, vector<4xi8>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_dot_accumulate_2way_invalid_type_b(%a_vec : vector<2xi16>, %b_vec : vector<4xi8>, %c : i32) {
+  // expected-error @below {{b_type must be S8 or U8}}
+  %res = nvvm.dot.accumulate.2way lo %a_vec <u16>, %b_vec <u16>, %c: vector<2xi16>, vector<4xi8>
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 894b72733a46a..4bd9326da2233 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -866,3 +866,41 @@ llvm.func @nvvm_dot_accumulate_4way(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32)
   %3 = nvvm.dot.accumulate.4way %a <s8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
   llvm.return
 }
+
+// -----
+// CHECK-LABEL: @nvvm_dot_accumulate_2way
+llvm.func @nvvm_dot_accumulate_2way(%a: vector<2xi16>, %b: vector<4xi8>, %c: i32) {
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+  %0 = nvvm.dot.accumulate.2way lo %a <u16>, %b <u8>, %c: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+  %1 = nvvm.dot.accumulate.2way hi %a <u16>, %b <u8>, %c: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+  %2 = nvvm.dot.accumulate.2way lo %a <s16>, %b <u8>, %c: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+  %3 = nvvm.dot.accumulate.2way hi %a <s16>, %b <u8>, %c: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+  %4 = nvvm.dot.accumulate.2way lo %a <u16>, %b <s8>, %c: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+  %5 = nvvm.dot.accumulate.2way hi %a <u16>, %b <s8>, %c: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+  %6 = nvvm.dot.accumulate.2way lo %a <s16>, %b <s8>, %c: vector<2xi16>, vector<4xi8>
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+  %7 = nvvm.dot.accumulate.2way hi %a <s16>, %b <s8>, %c: vector<2xi16>, vector<4xi8>  
+  llvm.return
+}



More information about the Mlir-commits mailing list