[Mlir-commits] [mlir] [MLIR][NVVM] Add tcgen05.mma MLIR Ops (PR #164356)

Rajat Bajpai llvmlistbot at llvm.org
Tue Oct 21 22:08:31 PDT 2025


================
@@ -2557,6 +2569,525 @@ NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
   return {intrinsicID, args};
 }
 
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair
+Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                                    llvm::IRBuilderBase &builder) {
+
+  auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
+  llvm::SmallVector<llvm::Value *> args;
+
+  args.push_back(mt.lookupValue(thisOp.getD()));
+  llvm::Value *A = mt.lookupValue(thisOp.getA());
+  args.push_back(A);
+  args.push_back(mt.lookupValue(thisOp.getB()));
+  args.push_back(mt.lookupValue(thisOp.getIdesc()));
+  args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+
+  llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
+  bool hasScaleInputD = ScaleInputD != nullptr;
+  llvm::Value *DisableOutputLane =
+      mt.lookupValue(thisOp.getDisableOutputLane());
+  bool hasDisableOutputLane = DisableOutputLane != nullptr;
+
+  unsigned ctaGroup =
+      static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
+
+  bool isATensor = isa<llvm::PointerType>(A->getType());
+  bool enableAshift = thisOp.getAshift();
+
+  llvm::Intrinsic::ID ID = [&]() {
+    // [isATensor][enableAshift]
+    static constexpr llvm::Intrinsic::ID tcgen05MMAIDs[2][2] = {
+        // shared
+        {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
+        // tensor
+        {llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
+         llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift}};
+
+    // Scaled [isATensor][enableAshift]
+    static constexpr llvm::Intrinsic::ID scaledIDs[2][2] = {
+        // shared
+        {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
+        // tensor
+        {llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
+         llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift}};
+
+    // Disable output lane [isATensor][enableAshift][ctaGroup-1]
+    static constexpr llvm::Intrinsic::ID disableOutputLaneIDs[2][2][2] = {
+        // shared
+        {{llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
+          llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2},
+         {notIntrinsic, notIntrinsic}},
+        // tensor
+        {{llvm::Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
+          llvm::Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2},
+         {llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
+          llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift}}};
+
+    // Scaled + disable output lane [isATensor][enableAshift][ctaGroup-1]
+    static constexpr llvm::Intrinsic::ID scaledDisableOutputLaneIDs[2][2][2] = {
+        // shared
+        {{llvm::Intrinsic::
+              nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
+          llvm::Intrinsic::
+              nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2},
+         {notIntrinsic, notIntrinsic}},
+        // tensor
+        {{llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
+          llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2},
+         {llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift,
+          llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift}}};
+
+    if (hasDisableOutputLane) {
+      if (hasScaleInputD) {
+        args.push_back(ScaleInputD);
+        ID = scaledDisableOutputLaneIDs[isATensor][enableAshift][ctaGroup - 1];
+      } else
+        ID = disableOutputLaneIDs[isATensor][enableAshift][ctaGroup - 1];
+      args.push_back(DisableOutputLane);
+      args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+    } else {
+      if (hasScaleInputD) {
+        args.push_back(ScaleInputD);
+        ID = scaledIDs[isATensor][enableAshift];
+      } else
+        ID = tcgen05MMAIDs[isATensor][enableAshift];
+      args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+      args.push_back(builder.getInt32(ctaGroup));
+    }
+    return ID;
+  }();
----------------
rajatbajpai wrote:

I believe you meant 96 to 128 bytes. We're not currently pressed for stack space, and we can always optimize this later if the need arises. In the meantime, this approach keeps the intrinsic-selection logic more readable.
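For context, a rough back-of-the-envelope sketch of the table footprint being discussed (this is not code from the PR; it assumes `llvm::Intrinsic::ID` is a 4-byte unsigned integer, and `IntrinsicID` is a stand-in type so the snippet compiles on its own):

```cpp
// Hypothetical size estimate for the four constexpr lookup tables above.
// Assumes llvm::Intrinsic::ID is a 4-byte unsigned integer; IntrinsicID is
// a stand-in so this snippet is self-contained.
#include <cstddef>

using IntrinsicID = unsigned;

constexpr std::size_t kNumIDs = 2 * 2        // tcgen05MMAIDs
                                + 2 * 2      // scaledIDs
                                + 2 * 2 * 2  // disableOutputLaneIDs
                                + 2 * 2 * 2; // scaledDisableOutputLaneIDs

// 24 IDs * 4 bytes = 96 bytes in total.
static_assert(kNumIDs * sizeof(IntrinsicID) == 96, "expected 96 bytes");
```

Since the tables are declared `static constexpr`, compilers typically place them in read-only static data rather than on the stack, so the cost is a fixed ~96 bytes of constants regardless of how often the lambda runs.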

https://github.com/llvm/llvm-project/pull/164356

