[Mlir-commits] [mlir] [MLIR][NVVM] Add tcgen05.mma MLIR Ops (PR #164356)

Rajat Bajpai llvmlistbot at llvm.org
Tue Oct 21 05:14:45 PDT 2025


================
@@ -2557,6 +2569,525 @@ NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
   return {intrinsicID, args};
 }
 
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair
+Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                                    llvm::IRBuilderBase &builder) {
+
+  auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
+  llvm::SmallVector<llvm::Value *> args;
+
+  args.push_back(mt.lookupValue(thisOp.getD()));
+  llvm::Value *A = mt.lookupValue(thisOp.getA());
+  args.push_back(A);
+  args.push_back(mt.lookupValue(thisOp.getB()));
+  args.push_back(mt.lookupValue(thisOp.getIdesc()));
+  args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+
+  llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
+  bool hasScaleInputD = ScaleInputD != nullptr;
+  llvm::Value *DisableOutputLane =
+      mt.lookupValue(thisOp.getDisableOutputLane());
+  bool hasDisableOutputLane = DisableOutputLane != nullptr;
+
+  unsigned ctaGroup =
+      static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
+
+  bool isATensor = isa<llvm::PointerType>(A->getType());
+  bool enableAshift = thisOp.getAshift();
+
+  llvm::Intrinsic::ID ID = [&]() {
+    // [isATensor][enableAshift]
+    static constexpr llvm::Intrinsic::ID tcgen05MMAIDs[2][2] = {
+        // shared
+        {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
+        // tensor
+        {llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
+         llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift}};
+
+    // Scaled [isATensor][enableAshift]
+    static constexpr llvm::Intrinsic::ID scaledIDs[2][2] = {
+        // shared
+        {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
+        // tensor
+        {llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
+         llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift}};
+
+    // Scaled + disable output lane [isATensor][enableAshift][ctaGroup-1]
+    static constexpr llvm::Intrinsic::ID disableOutputLaneIDs[2][2][2] = {
+        // shared
+        {{llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
+          llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2},
+         {notIntrinsic, notIntrinsic}},
+        // tensor
+        {{llvm::Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
+          llvm::Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2},
+         {llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
+          llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift}}};
+
+    // Scaled + disable output lane [isATensor][enableAshift][ctaGroup-1]
+    static constexpr llvm::Intrinsic::ID scaledDisableOutputLaneIDs[2][2][2] = {
+        // shared
+        {{llvm::Intrinsic::
+              nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
+          llvm::Intrinsic::
+              nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2},
+         {notIntrinsic, notIntrinsic}},
+        // tensor
+        {{llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
+          llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2},
+         {llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift,
+          llvm::Intrinsic::
+              nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift}}};
+
+    if (hasDisableOutputLane) {
+      if (hasScaleInputD) {
+        args.push_back(ScaleInputD);
+        ID = scaledDisableOutputLaneIDs[isATensor][enableAshift][ctaGroup - 1];
+      } else
+        ID = disableOutputLaneIDs[isATensor][enableAshift][ctaGroup - 1];
+      args.push_back(DisableOutputLane);
+      args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+    } else {
+      if (hasScaleInputD) {
+        args.push_back(ScaleInputD);
+        ID = scaledIDs[isATensor][enableAshift];
+      } else
+        ID = tcgen05MMAIDs[isATensor][enableAshift];
+      args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+      args.push_back(builder.getInt32(ctaGroup));
+    }
+    return ID;
+  }();
----------------
rajatbajpai wrote:

I believe we can add a few more dimensions to the lookup table to simplify the conditional logic a little. For example,

```suggestion
// [hasDisableOutputLane][hasScaleInputD][isATensor][enableAshift][ctaGroup-1]
static constexpr llvm::Intrinsic::ID tcgen05MMAIDs[2][2][2][2][2] = {
    { // without disable output lane
        { // without scaled input
            { // shared
                {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
                {notIntrinsic, notIntrinsic}
            },
            { // tensor
                {llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
                 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift},
                {notIntrinsic, notIntrinsic}
            }
        },
        { // with scaled input
            { // shared
                {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
                {notIntrinsic, notIntrinsic}
            },
            { // tensor
                {llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
                 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift},
                {notIntrinsic, notIntrinsic}
            }
        }
    },
    { // with disable output lane
        { // without scaled input
            { // shared
                {llvm::Intrinsic::
                     nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
                 llvm::Intrinsic::
                     nvvm_tcgen05_mma_shared_disable_output_lane_cg2},
                {notIntrinsic, notIntrinsic}
            },
            { // tensor
                {llvm::Intrinsic::
                     nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
                 llvm::Intrinsic::
                     nvvm_tcgen05_mma_tensor_disable_output_lane_cg2},
                {llvm::Intrinsic::
                     nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
                 llvm::Intrinsic::
                     nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift}
            }
        },
        { // with scaled input
            { // shared
                {llvm::Intrinsic::
                     nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
                 llvm::Intrinsic::
                     nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2},
                {notIntrinsic, notIntrinsic}
            },
            { // tensor
                {llvm::Intrinsic::
                     nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
                 llvm::Intrinsic::
                     nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2},
                {llvm::Intrinsic::
                     nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift,
                 llvm::Intrinsic::
                     nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift}
            }
        }
    }
};

if (hasScaleInputD)
  args.push_back(ScaleInputD);

int ctaGroupIdx = hasDisableOutputLane ? (ctaGroup - 1) : 0;
llvm::Intrinsic::ID ID = tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD]
                                     [isATensor][enableAshift][ctaGroupIdx];

if (hasDisableOutputLane) {
  args.push_back(DisableOutputLane);
  args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
} else {
  args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
  args.push_back(builder.getInt32(ctaGroup));
}
```

https://github.com/llvm/llvm-project/pull/164356


More information about the Mlir-commits mailing list