[Mlir-commits] [mlir] [MLIR][NVVM] Add tcgen05.mma MLIR Ops (PR #164356)
    Rajat Bajpai 
    llvmlistbot at llvm.org
    Tue Oct 21 22:08:31 PDT 2025

================
@@ -2557,6 +2569,525 @@ NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
   return {intrinsicID, args};
 }
 
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair
+Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                                    llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
+  llvm::SmallVector<llvm::Value *> args;
+
+  args.push_back(mt.lookupValue(thisOp.getD()));
+  llvm::Value *A = mt.lookupValue(thisOp.getA());
+  args.push_back(A);
+  args.push_back(mt.lookupValue(thisOp.getB()));
+  args.push_back(mt.lookupValue(thisOp.getIdesc()));
+  args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+
+  llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
+  bool hasScaleInputD = ScaleInputD != nullptr;
+  llvm::Value *DisableOutputLane =
+      mt.lookupValue(thisOp.getDisableOutputLane());
+  bool hasDisableOutputLane = DisableOutputLane != nullptr;
+
+  unsigned ctaGroup =
+      static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
+
+  bool isATensor = isa<llvm::PointerType>(A->getType());
+  bool enableAshift = thisOp.getAshift();
+
+  llvm::Intrinsic::ID ID = [&]() {
+    // [isATensor][enableAshift]
+    static constexpr llvm::Intrinsic::ID tcgen05MMAIDs[2][2] = {
+        // shared
+        {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
+        // tensor
+        {llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
+         llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift}};
+
+    // Scaled [isATensor][enableAshift]
+    static constexpr llvm::Intrinsic::ID scaledIDs[2][2] = {
+        // shared
+        {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
+        // tensor
+        {llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
+         llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift}};
+
+    // Disable output lane [isATensor][enableAshift][ctaGroup-1]
+    static constexpr llvm::Intrinsic::ID disableOutputLaneIDs[2][2][2] = {
+        // shared
+        {{llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
+          llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2},
+         {notIntrinsic, notIntrinsic}},
+        // tensor
+        {{llvm::Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
+          llvm::Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2},
+         {llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
+          llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift}}};
+
+    // Scaled + disable output lane [isATensor][enableAshift][ctaGroup-1]
+    static constexpr llvm::Intrinsic::ID scaledDisableOutputLaneIDs[2][2][2] = {
+        // shared
+        {{llvm::Intrinsic::
+              nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
+          llvm::Intrinsic::
+              nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2},
+         {notIntrinsic, notIntrinsic}},
+        // tensor
+        {{llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
+          llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2},
+         {llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift,
+          llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift}}};
+
+    llvm::Intrinsic::ID id = notIntrinsic;
+    if (hasDisableOutputLane) {
+      if (hasScaleInputD) {
+        args.push_back(ScaleInputD);
+        id = scaledDisableOutputLaneIDs[isATensor][enableAshift][ctaGroup - 1];
+      } else {
+        id = disableOutputLaneIDs[isATensor][enableAshift][ctaGroup - 1];
+      }
+      args.push_back(DisableOutputLane);
+      args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+    } else {
+      if (hasScaleInputD) {
+        args.push_back(ScaleInputD);
+        id = scaledIDs[isATensor][enableAshift];
+      } else {
+        id = tcgen05MMAIDs[isATensor][enableAshift];
+      }
+      args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+      args.push_back(builder.getInt32(ctaGroup));
+    }
+    return id;
+  }();
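
For readers skimming the archive, here is a minimal standalone sketch of the table-lookup dispatch pattern used above. The IntrinsicID type, the notIntrinsic sentinel, and the mma*/select names are hypothetical stand-ins for illustration, not the actual LLVM definitions:

#include <cassert>
#include <cstdio>

// Hypothetical stand-ins for llvm::Intrinsic::ID and the NVVM intrinsics.
using IntrinsicID = unsigned;
constexpr IntrinsicID notIntrinsic = 0;
constexpr IntrinsicID mmaShared = 1;       // stands in for ..._mma_shared
constexpr IntrinsicID mmaTensor = 2;       // stands in for ..._mma_tensor
constexpr IntrinsicID mmaTensorAshift = 3; // stands in for ..._mma_tensor_ashift

// Indexed as [isATensor][enableAshift]; an ashift variant only exists when
// operand A lives in tensor memory, so the shared row holds notIntrinsic.
constexpr IntrinsicID mmaIDs[2][2] = {
    {mmaShared, notIntrinsic},
    {mmaTensor, mmaTensorAshift},
};

IntrinsicID select(bool isATensor, bool enableAshift) {
  IntrinsicID id = mmaIDs[isATensor][enableAshift];
  assert(id != notIntrinsic && "unsupported predicate combination");
  return id;
}

int main() {
  std::printf("%u\n", select(/*isATensor=*/true, /*enableAshift=*/true)); // 3
}

Encoding the boolean predicates as array indices keeps the selection branch-free, and unsupported combinations surface explicitly as notIntrinsic entries instead of hiding in missed else branches.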
----------------
rajatbajpai wrote:
I believe you meant 96 to 128 bytes. We're not currently pressed for stack space, and we can always optimize this later if the need arises. In the meantime, this table-driven approach keeps the selection logic readable (a quick size check follows below).
https://github.com/llvm/llvm-project/pull/164356
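
A quick back-of-the-envelope check of the lower figure, assuming sizeof(llvm::Intrinsic::ID) == 4 (Intrinsic::ID is a typedef of unsigned): the four constexpr tables hold 4 + 4 + 8 + 8 = 24 entries, so 24 * 4 = 96 bytes; the 128-byte upper bound presumably leaves room for padding or alignment.

#include <cstdio>

int main() {
  // Entry counts of the four lookup tables quoted above; the 4-byte ID size
  // is an assumption matching llvm::Intrinsic::ID (a typedef of unsigned).
  constexpr unsigned idSize = 4;
  constexpr unsigned entries = 2 * 2        // tcgen05MMAIDs[2][2]
                               + 2 * 2      // scaledIDs[2][2]
                               + 2 * 2 * 2  // disableOutputLaneIDs[2][2][2]
                               + 2 * 2 * 2; // scaledDisableOutputLaneIDs[2][2][2]
  std::printf("%u entries x %u bytes = %u bytes\n", entries, idSize,
              entries * idSize); // prints: 24 entries x 4 bytes = 96 bytes
}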
    
    
More information about the Mlir-commits mailing list