[Mlir-commits] [mlir] [MLIR][NVVM] Add tcgen05.mma MLIR Ops (PR #164356)

Guray Ozen llvmlistbot at llvm.org
Mon Oct 27 09:41:53 PDT 2025


================
@@ -2694,6 +2706,587 @@ NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
   return {intrinsicID, args};
 }
 
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma functions
+//===----------------------------------------------------------------------===//
+
+/// Selects the tcgen05.mma intrinsic variant for this op and builds its
+/// argument list.
+///
+/// The variant is looked up in a static table indexed by: whether the optional
+/// disable-output-lane operand is present, whether the optional scale-input-D
+/// operand is present, whether matrix A is an LLVM pointer (tensor memory) or
+/// not (shared variants), the CTA group, and the A-shift flag.
+///
+/// Arguments are emitted in the order:
+///   D, A, B, idesc, enable-input-D, [scale-input-D], [disable-output-lane],
+///   kind, [cta-group], collector-op
+/// The cta-group immediate is omitted for the disable-output-lane variants,
+/// whose intrinsic names already carry a _cg1/_cg2 suffix.
+mlir::NVVM::IDArgPair
+Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                                    llvm::IRBuilderBase &builder) {
+
+  auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
+  llvm::SmallVector<llvm::Value *> args;
+
+  args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+  llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+  // A pointer-typed A selects the tensor-memory variants; any other type
+  // selects the shared variants.
+  const bool isATensor = isa<llvm::PointerType>(A->getType());
+  args.push_back(A);
+
+  args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+  args.push_back(mt.lookupValue(thisOp.getIdesc()));
+  args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+
+  // Indexed as
+  // [hasDisableOutputLane][hasScaleInputD][isATensor][ctaGroup - 1][aShift].
+  // Shared variants have no A-shift form, so those slots hold notIntrinsic.
+  static constexpr llvm::Intrinsic::ID tcgen05MMAIDs[2][2][2][2][2] = {
+      // without disable output lane
+      {// without scale input D
+       {// shared
+        {// cg1
+         {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
+         // cg2
+         {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}},
+        // tensor
+        {// cg1
+         {llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
+          llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift},
+         // cg2
+         {llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
+          llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift}}},
+       // with scale input D
+       {// shared
+        {// cg1
+         {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
+         // cg2
+         {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}},
+        // tensor
+        {// cg1
+         {llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
+          llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift},
+         // cg2
+         {llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
+          llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift}}}},
+      // with disable output lane
+      {// without scale input D
+       {// shared
+        {// cg1
+         {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
+          notIntrinsic},
+         // cg2
+         {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
+          notIntrinsic}},
+        // tensor
+        {// cg1
+         {llvm::Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
+          llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift},
+         // cg2
+         {llvm::Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
+          llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift}}},
+       // with scale input D
+       {// shared
+        {// cg1
+         {llvm::Intrinsic::
+              nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
+          notIntrinsic},
+         // cg2
+         {llvm::Intrinsic::
+              nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
+          notIntrinsic}},
+        // tensor
+        {// cg1
+         {llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
+          llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift},
+         // cg2
+         {llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
+          llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift}}}}};
+
+  // Decide presence of the optional operands from the op itself rather than
+  // from whether a translated llvm::Value happens to be null.
+  mlir::Value scaleInputD = thisOp.getScaleInputD();
+  const bool hasScaleInputD = static_cast<bool>(scaleInputD);
+
+  mlir::Value disableLanes = thisOp.getDisableOutputLane();
+  const bool hasDisableOutputLane = static_cast<bool>(disableLanes);
+
+  const unsigned ctaGroup =
+      static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
+  // The table only has slots for CTA groups 1 and 2; anything else would
+  // index out of bounds below.
+  assert((ctaGroup == 1 || ctaGroup == 2) &&
+         "Unexpected CTA group for Tcgen05MMAOp.");
+
+  llvm::Intrinsic::ID ID =
+      tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
+                   [ctaGroup - 1][thisOp.getAShift()];
+
+  assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp.");
+
+  if (hasScaleInputD)
+    args.push_back(mt.lookupValue(scaleInputD));
+
+  if (hasDisableOutputLane)
+    args.push_back(mt.lookupValue(disableLanes));
+
+  args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+
+  // Disable-output-lane variants encode the CTA group in the intrinsic name,
+  // so the cta-group immediate is passed only to the other variants.
+  if (!hasDisableOutputLane)
+    args.push_back(builder.getInt32(ctaGroup));
+
+  args.push_back(
+      builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+  return {ID, args};
+}
+
+/// Shared verification logic for tcgen05.mma ops.
+///
+/// Rules checked, in order:
+///   1. A disable-output-lane vector, when present, must have 4 elements for
+///      CTA_1 and 8 elements for CTA_2.
+///   2. A-shift requires matrix A to be in tensor memory.
+///   3. A-shift cannot be combined with the FILL or USE collector operation.
+/// Each violated rule emits its own error at `loc`. The result is failure if
+/// any rule was violated.
+static LogicalResult
+verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane,
+                   NVVM::CTAGroupKind ctaGroup, bool hasAShift,
+                   NVVM::Tcgen05MMACollectorOp collectorOp, Location loc) {
+  bool hasError = false;
+
+  if (disableOutputLane) {
+    auto laneVecType = cast<mlir::VectorType>(disableOutputLane.getType());
+    const auto numLanes = laneVecType.getNumElements();
+    // Only CTA_1 and CTA_2 impose a lane-count requirement.
+    const bool wrongLength =
+        (ctaGroup == NVVM::CTAGroupKind::CTA_1)   ? numLanes != 4
+        : (ctaGroup == NVVM::CTAGroupKind::CTA_2) ? numLanes != 8
+                                                  : false;
+    if (wrongLength) {
+      emitError(loc) << "Disable Output Lane of length " << numLanes
+                     << " is incompatible with CtaGroupAttr";
+      hasError = true;
+    }
+  }
+
+  if (hasAShift && !isATensor) {
+    emitError(loc,
+              "A-shift can be applied only when matrix A is in tensor memory");
+    hasError = true;
+  }
+
+  const bool isFillOrUse = collectorOp == Tcgen05MMACollectorOp::FILL ||
+                           collectorOp == Tcgen05MMACollectorOp::USE;
+  if (hasAShift && isFillOrUse) {
+    emitError(loc,
+              "Cannot use collector buffer operation fill or use with ashift");
+    hasError = true;
+  }
+
+  return failure(hasError);
+}
+
+/// Verifies Tcgen05MMAOp by delegating to the shared tcgen05.mma checks.
+LogicalResult Tcgen05MMAOp::verify() {
+  // An LLVM pointer-typed A operand is treated as residing in tensor memory.
+  const bool aIsTensor =
+      isa<LLVM::LLVMPointerType>(getMatrixA().getType());
+  return verifyTcgen05MMAOp(aIsTensor, getDisableOutputLane(), getCtaGroup(),
+                            getAShift(), getCollectorOp(), getLoc());
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma.sp functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair Tcgen05MMASparseOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+
+  auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op);
+  llvm::SmallVector<llvm::Value *> args;
+
+  args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+  llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+  bool isATensor = isa<llvm::PointerType>(A->getType());
+  args.push_back(A);
+
+  args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+  args.push_back(mt.lookupValue(thisOp.getIdesc()));
+  args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+  args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
+
+  // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift];
+  static constexpr llvm::Intrinsic::ID tcgen05MMASparseIDs[2][2][2][2][2] = {
----------------
grypp wrote:

The same comment applies here: please create this table in the same style.

https://github.com/llvm/llvm-project/pull/164356


More information about the Mlir-commits mailing list