[Mlir-commits] [mlir] 4e4273c - [MLIR][NVVM] Add dot.accumulate.2way Op (#140518)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 4 00:59:49 PDT 2025


Author: Srinivasa Ravi
Date: 2025-06-04T13:29:46+05:30
New Revision: 4e4273c9409dfbbfb42ca74468eaf9bd843bc376

URL: https://github.com/llvm/llvm-project/commit/4e4273c9409dfbbfb42ca74468eaf9bd843bc376
DIFF: https://github.com/llvm/llvm-project/commit/4e4273c9409dfbbfb42ca74468eaf9bd843bc376.diff

LOG: [MLIR][NVVM] Add dot.accumulate.2way Op (#140518)

This change adds the `dot.accumulate.2way` Op to the NVVM dialect for the
16-bit to 8-bit dot-product accumulate operation.

PTX Spec Reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp2a

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Dialect/LLVMIR/nvvm.mlir
    mlir/test/Target/LLVMIR/nvvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 2424e3af80d2d..596a584d485ed 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3703,6 +3703,60 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
   }];
 }
 
+def NVVM_DotAccumulate2WayOp : NVVM_Op<"dot.accumulate.2way"> {
+  let summary = "Two-way 16-bit to 8-bit dot product-accumulate instruction";
+  let description = [{
+    Performs a two-way 16-bit to 8-bit dot-product which is accumulated in a
+    32-bit result.
+    Operand `a` is a vector of two 16-bit elements and operand `b` a vector
+    of four 8-bit elements between which the dot product is computed.
+
+    The `a_type` and `b_type` attributes specify the type of the elements in `a`
+    and `b` respectively.
+    If `a_type` or `b_type` is `signed`, then the elements in the corresponding
+    vector are sign-extended to 32-bit before the dot product is computed.
+    If `a_type` or `b_type` is `unsigned`, then the elements in the
+    corresponding vector are zero-extended to 32-bit instead.
+
+    The `b_hi` boolean attribute specifies which two bytes of `b` are used for
+    the dot product. If `b_hi` is true, then the dot product is computed
+    between `a` and elements at indices 2 and 3 of `b`. If `b_hi` is false,
+    then the dot product is computed between `a` and elements at indices 0 and
+    1 of `b`.
+
+    Operand `c` is a 32-bit integer to which the result is accumulated. It is
+    treated as holding a signed integer if either `a_type` or `b_type` is
+    `signed`.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp2a)
+  }];
+
+  let arguments = (ins
+    VectorOfLengthAndType<[2], [I16]>:$a,
+    DotAccumulateTypeAttr:$a_type,
+    VectorOfLengthAndType<[4], [I8]>:$b,
+    DotAccumulateTypeAttr:$b_type,
+    I32:$c,
+    BoolAttr:$b_hi
+  );
+
+  let results = (outs I32:$res);
+
+  let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
+
+  let extraClassDeclaration = [{
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                          llvm::IRBuilderBase &builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::DotAccumulate2WayOp::getIntrinsicIDAndArgs(
+                        *op, moduleTranslation, builder);
+    $res = createIntrinsicCall(builder, id, args);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM target attribute.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 648b6b087e592..a77ff1e32dc23 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1712,6 +1712,28 @@ NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
   return {ids[type], args};
 }
 
+// Returns the llvm.nvvm.idp2a.* intrinsic ID matching the op's a_type/b_type
+// attributes, paired with the intrinsic's call operands.
+NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto dotOp = cast<NVVM::DotAccumulate2WayOp>(op);
+  // Operand order: packed `a`, packed `b`, the b_hi flag, accumulator `c`.
+  llvm::Value *packedA = getAsPackedI32(mt.lookupValue(dotOp.getA()), builder);
+  llvm::Value *packedB = getAsPackedI32(mt.lookupValue(dotOp.getB()), builder);
+  llvm::SmallVector<llvm::Value *> args = {
+      packedA, packedB, builder.getInt1(dotOp.getBHi()),
+      mt.lookupValue(dotOp.getC())};
+  // The intrinsic suffix is <signedness of a>.<signedness of b>.
+  const bool aSigned = dotOp.getAType() == NVVM::DotAccumulateType::SIGNED;
+  const bool bSigned = dotOp.getBType() == NVVM::DotAccumulateType::SIGNED;
+  const llvm::Intrinsic::ID id =
+      aSigned ? (bSigned ? llvm::Intrinsic::nvvm_idp2a_s_s
+                         : llvm::Intrinsic::nvvm_idp2a_s_u)
+              : (bSigned ? llvm::Intrinsic::nvvm_idp2a_u_s
+                         : llvm::Intrinsic::nvvm_idp2a_u_u);
+  return {id, args};
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 77b302155cb12..a02d33f50e0d2 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -587,6 +587,15 @@ func.func @dot_accumulate_4way(%a_vec: vector<4xi8>, %b_vec: vector<4xi8>, %c: i
   return
 }
 
+// CHECK-LABEL: @dot_accumulate_2way
+func.func @dot_accumulate_2way(%a_vec: vector<2xi16>, %b_vec: vector<4xi8>, %c: i32) {
+  // CHECK:   nvvm.dot.accumulate.2way %{{.*}} <unsigned>, %{{.*}} <unsigned>, %{{.*}} {b_hi = false} : vector<2xi16>, vector<4xi8>
+  %1 = nvvm.dot.accumulate.2way %a_vec <unsigned>, %b_vec <unsigned>, %c {b_hi = false}: vector<2xi16>, vector<4xi8>
+  // CHECK:   nvvm.dot.accumulate.2way %{{.*}} <signed>, %{{.*}} <signed>, %{{.*}} {b_hi = true} : vector<2xi16>, vector<4xi8>
+  %3 = nvvm.dot.accumulate.2way %a_vec <signed>, %b_vec <signed>, %c {b_hi = true}: vector<2xi16>, vector<4xi8>
+  return
+}
+
 // -----
 
 // Just check these don't emit errors.

diff  --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index e892fc43f4a39..660d0a22dce9c 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -866,3 +866,41 @@ llvm.func @nvvm_dot_accumulate_4way(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32)
   %3 = nvvm.dot.accumulate.4way %a <signed>, %b <signed>, %c: vector<4xi8>, vector<4xi8>
   llvm.return
 }
+
+// -----
+// CHECK-LABEL: @nvvm_dot_accumulate_2way
+llvm.func @nvvm_dot_accumulate_2way(%a: vector<2xi16>, %b: vector<4xi8>, %c: i32) {
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+  %0 = nvvm.dot.accumulate.2way %a <unsigned>, %b <unsigned>, %c {b_hi = false} : vector<2xi16>, vector<4xi8> // a unsigned, b unsigned, b_hi = false
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+  %1 = nvvm.dot.accumulate.2way %a <unsigned>, %b <unsigned>, %c {b_hi = true}: vector<2xi16>, vector<4xi8> // a unsigned, b unsigned, b_hi = true
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+  %2 = nvvm.dot.accumulate.2way %a <signed>, %b <unsigned>, %c {b_hi = false}: vector<2xi16>, vector<4xi8> // a signed, b unsigned, b_hi = false
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+  %3 = nvvm.dot.accumulate.2way %a <signed>, %b <unsigned>, %c {b_hi = true}: vector<2xi16>, vector<4xi8> // a signed, b unsigned, b_hi = true
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+  %4 = nvvm.dot.accumulate.2way %a <unsigned>, %b <signed>, %c {b_hi = false}: vector<2xi16>, vector<4xi8> // a unsigned, b signed, b_hi = false
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+  %5 = nvvm.dot.accumulate.2way %a <unsigned>, %b <signed>, %c {b_hi = true}: vector<2xi16>, vector<4xi8> // a unsigned, b signed, b_hi = true
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 false, i32 %{{.*}})
+  %6 = nvvm.dot.accumulate.2way %a <signed>, %b <signed>, %c {b_hi = false}: vector<2xi16>, vector<4xi8> // a signed, b signed, b_hi = false
+  // CHECK: %[[a_cast:.*]] = bitcast <2 x i16> %{{.*}} to i32
+  // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+  // CHECK: call i32 @llvm.nvvm.idp2a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i1 true, i32 %{{.*}})
+  %7 = nvvm.dot.accumulate.2way %a <signed>, %b <signed>, %c {b_hi = true}: vector<2xi16>, vector<4xi8> // a signed, b signed, b_hi = true
+  llvm.return
+}


        


More information about the Mlir-commits mailing list