[Mlir-commits] [mlir] [MLIR][NVVM] Add tcgen05.mma MLIR Ops (PR #164356)
Rajat Bajpai
llvmlistbot at llvm.org
Tue Oct 21 05:14:45 PDT 2025
================
@@ -2557,6 +2569,525 @@ NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
return {intrinsicID, args};
}
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair
+Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getD()));
+ llvm::Value *A = mt.lookupValue(thisOp.getA());
+ args.push_back(A);
+ args.push_back(mt.lookupValue(thisOp.getB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+
+ llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
+ bool hasScaleInputD = ScaleInputD != nullptr;
+ llvm::Value *DisableOutputLane =
+ mt.lookupValue(thisOp.getDisableOutputLane());
+ bool hasDisableOutputLane = DisableOutputLane != nullptr;
+
+ unsigned ctaGroup =
+ static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
+
+ bool isATensor = isa<llvm::PointerType>(A->getType());
+ bool enableAshift = thisOp.getAshift();
+
+ llvm::Intrinsic::ID ID = [&]() {
+ // [isATensor][enableAshift]
+ static constexpr llvm::Intrinsic::ID tcgen05MMAIDs[2][2] = {
+ // shared
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
+ // tensor
+ {llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift}};
+
+ // Scaled [isATensor][enableAshift]
+ static constexpr llvm::Intrinsic::ID scaledIDs[2][2] = {
+ // shared
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
+ // tensor
+ {llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift}};
+
+ // Disable output lane (no scale) [isATensor][enableAshift][ctaGroup-1]
+ static constexpr llvm::Intrinsic::ID disableOutputLaneIDs[2][2][2] = {
+ // shared
+ {{llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
+ llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2},
+ {notIntrinsic, notIntrinsic}},
+ // tensor
+ {{llvm::Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2},
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift}}};
+
+ // Scaled + disable output lane [isATensor][enableAshift][ctaGroup-1]
+ static constexpr llvm::Intrinsic::ID scaledDisableOutputLaneIDs[2][2][2] = {
+ // shared
+ {{llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2},
+ {notIntrinsic, notIntrinsic}},
+ // tensor
+ {{llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2},
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift}}};
+
+ if (hasDisableOutputLane) {
+ if (hasScaleInputD) {
+ args.push_back(ScaleInputD);
+ ID = scaledDisableOutputLaneIDs[isATensor][enableAshift][ctaGroup - 1];
+ } else
+ ID = disableOutputLaneIDs[isATensor][enableAshift][ctaGroup - 1];
+ args.push_back(DisableOutputLane);
+ args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+ } else {
+ if (hasScaleInputD) {
+ args.push_back(ScaleInputD);
+ ID = scaledIDs[isATensor][enableAshift];
+ } else
+ ID = tcgen05MMAIDs[isATensor][enableAshift];
+ args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+ args.push_back(builder.getInt32(ctaGroup));
+ }
+ return ID;
+ }();
----------------
rajatbajpai wrote:
I believe we can add a few more dimensions to simplify the conditional logic a little. For example:
```suggestion
// [hasDisableOutputLane][hasScaleInputD][isATensor][enableAshift][ctaGroup-1]
static constexpr llvm::Intrinsic::ID tcgen05MMAIDs[2][2][2][2][2] = {
{ // without disable output lane
{ // without scaled input
{ // shared
{llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
{notIntrinsic, notIntrinsic}
},
{ // tensor
{llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift},
{notIntrinsic, notIntrinsic}
}
},
{ // with scaled input
{ // shared
{llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
{notIntrinsic, notIntrinsic}
},
{ // tensor
{llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift},
{notIntrinsic, notIntrinsic}
}
}
},
{ // with disable output lane
{ // without scaled input
{ // shared
{llvm::Intrinsic::
nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
llvm::Intrinsic::
nvvm_tcgen05_mma_shared_disable_output_lane_cg2},
{notIntrinsic, notIntrinsic}
},
{ // tensor
{llvm::Intrinsic::
nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
llvm::Intrinsic::
nvvm_tcgen05_mma_tensor_disable_output_lane_cg2},
{llvm::Intrinsic::
nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
llvm::Intrinsic::
nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift}
}
},
{ // with scaled input
{ // shared
{llvm::Intrinsic::
nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
llvm::Intrinsic::
nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2},
{notIntrinsic, notIntrinsic}
},
{ // tensor
{llvm::Intrinsic::
nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
llvm::Intrinsic::
nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2},
{llvm::Intrinsic::
nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift,
llvm::Intrinsic::
nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift}
}
}
}
};
if (hasScaleInputD)
args.push_back(ScaleInputD);
int ctaGroupIdx = hasDisableOutputLane ? (ctaGroup - 1) : 0;
llvm::Intrinsic::ID ID = tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD]
[isATensor][enableAshift][ctaGroupIdx];
if (hasDisableOutputLane) {
args.push_back(DisableOutputLane);
args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
} else {
args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
args.push_back(builder.getInt32(ctaGroup));
}
```
https://github.com/llvm/llvm-project/pull/164356
More information about the Mlir-commits
mailing list