[Mlir-commits] [mlir] [MLIR][NVVM] Add tcgen05.mma MLIR Ops (PR #164356)
Guray Ozen
llvmlistbot at llvm.org
Mon Oct 27 09:42:14 PDT 2025
================
@@ -2694,6 +2706,587 @@ NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
return {intrinsicID, args};
}
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair
+Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ const bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+
+  // Index order: [hasDisableOutputLane][hasScaleInputD][isATensor]
+  //              [ctaGroup - 1][aShift]
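+  // For example, {hasDisableOutputLane = false, hasScaleInputD = true,
+  // isATensor = true, cg2, aShift = false} selects
+  // llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d. Entries set to
+  // notIntrinsic mark invalid combinations (aShift requires matrix A in
+  // tensor memory).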
+ static constexpr llvm::Intrinsic::ID tcgen05MMAIDs[2][2][2][2][2] = {
+      // without disable output lane
+ {// without scale input D
+ {
+ // shared
+ {// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}},
+ {// tensor
+ {
+ // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
+ },
+ {
+ // cg2
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
+ }},
+ },
+ // with scale input D
+ { // shared
+ {// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}},
+ {// tensor
+ {
+ // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
+ },
+ {
+ // cg2
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
+ }}}},
+ // with disable output lane
+ { // without scale input D
+ { // shared
+ {// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
+ notIntrinsic}},
+        {// tensor
+         {
+             // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
+ },
+ // cg2
+ {
+ llvm::Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift,
+ }}},
+ // with scale input D
+ { // shared
+ {// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
+ notIntrinsic}},
+ // tensor
+ {// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift},
+ // cg2
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift,
+ }}}}};
+
+ llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
+ bool hasScaleInputD = ScaleInputD != nullptr;
+
+ llvm::Value *DisableOutputLane =
+ mt.lookupValue(thisOp.getDisableOutputLane());
+ bool hasDisableOutputLane = DisableOutputLane != nullptr;
+
+ const unsigned ctaGroup =
+ static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
+
+ llvm::Intrinsic::ID ID =
+ tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
+ [ctaGroup - 1][thisOp.getAShift()];
+
+ assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp.");
+
+ if (hasScaleInputD)
+ args.push_back(ScaleInputD);
+
+ if (hasDisableOutputLane)
+ args.push_back(DisableOutputLane);
+
+ args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+
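+  // The disable_output_lane intrinsics encode the CTA group in their name
+  // (_cg1/_cg2), so the explicit cta_group argument is passed only otherwise.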
+ if (!hasDisableOutputLane)
+ args.push_back(builder.getInt32(ctaGroup));
+
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
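+  // Final argument order: (d, a, b, idesc, enableInputD[, scaleInputD]
+  // [, disableOutputLane], kind[, ctaGroup], collectorOp).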
+ return {ID, args};
+}
+
+static LogicalResult
+verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane,
+ NVVM::CTAGroupKind ctaGroup, bool hasAShift,
+ NVVM::Tcgen05MMACollectorOp collectorOp, Location loc) {
+ LogicalResult res = success();
+
+ if (disableOutputLane) {
+ mlir::VectorType disableOutputLaneType =
+ cast<mlir::VectorType>(disableOutputLane.getType());
+ if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 &&
+ disableOutputLaneType.getNumElements() != 4) ||
+ (ctaGroup == NVVM::CTAGroupKind::CTA_2 &&
+ disableOutputLaneType.getNumElements() != 8))
+ res = emitError(loc) << "Disable Output Lane of length "
+ << disableOutputLaneType.getNumElements()
+ << " is incompatible with CtaGroupAttr";
+ }
+
+ if (hasAShift && !isATensor)
+ res = emitError(
+ loc, "A-shift can be applied only when matrix A is in tensor memory");
+
+  if (hasAShift && (collectorOp == Tcgen05MMACollectorOp::FILL ||
+                    collectorOp == Tcgen05MMACollectorOp::USE))
+    res = emitError(
+        loc, "collector buffer operations fill and use cannot be combined "
+             "with ashift");
+ return res;
+}
+
+LogicalResult Tcgen05MMAOp::verify() {
+ return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
+ getDisableOutputLane(), getCtaGroup(), getAShift(),
+ getCollectorOp(), getLoc());
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma.sp functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair Tcgen05MMASparseOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+ args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
+
+  // Index order: [hasDisableOutputLane][hasScaleInputD][isATensor]
+  //              [ctaGroup - 1][aShift]
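+  // Indexed the same way as the dense tcgen05MMAIDs table above; for
+  // example, {false, true, true, cg2, false} selects
+  // llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d.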
+ static constexpr llvm::Intrinsic::ID tcgen05MMASparseIDs[2][2][2][2][2] = {
+      // without disable output lane
+ {// without scale input D
+ {
+ // shared
+ {// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}},
+ {// tensor
+ {
+ // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
+ },
+ {
+ // cg2
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
+ }},
+ },
+ // with scale input D
+ { // shared
+ {// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d, notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d, notIntrinsic}},
+ {// tensor
+ {
+ // cg1
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
+ },
+ {
+ // cg2
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
+ llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
+ }}}},
+ // with disable output lane
+ { // without scale input D
+ { // shared
+ {// cg1
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2,
+ notIntrinsic}},
+        {// tensor
+         {
+             // cg1
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift,
+ },
+ // cg2
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift,
+ }}},
+ // with scale input D
+ { // shared
+ {// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1,
+ notIntrinsic},
+ // cg2
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2,
+ notIntrinsic}},
+ // tensor
+ {// cg1
+ {llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift},
+ // cg2
+ {
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2,
+ llvm::Intrinsic::
+ nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift,
+ }}}}};
+
+ llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
+ bool hasScaleInputD = ScaleInputD != nullptr;
+
+ llvm::Value *DisableOutputLane =
+ mt.lookupValue(thisOp.getDisableOutputLane());
+ bool hasDisableOutputLane = DisableOutputLane != nullptr;
+
+  const unsigned ctaGroup =
+ static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));
+
+ llvm::Intrinsic::ID ID =
+ tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
+ [ctaGroup - 1][thisOp.getAShift()];
+
+ assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMASparseOp.");
+
+ if (hasScaleInputD)
+ args.push_back(ScaleInputD);
+
+ if (hasDisableOutputLane)
+ args.push_back(DisableOutputLane);
+
+ args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
+
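+  // As in the dense case, cta_group is passed explicitly only when the
+  // intrinsic name does not already encode it via _cg1/_cg2.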
+ if (!hasDisableOutputLane)
+ args.push_back(builder.getInt32(ctaGroup));
+
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
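+  // Final argument order mirrors the dense op, with sparseMetadata inserted
+  // after enableInputD.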
+ return {ID, args};
+}
+
+LogicalResult Tcgen05MMASparseOp::verify() {
+ return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
+ getDisableOutputLane(), getCtaGroup(), getAShift(),
+ getCollectorOp(), getLoc());
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma.block_scale functions
+//===----------------------------------------------------------------------===//
+
+mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+
+ auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixD()));
+
+ llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
+ bool isATensor = isa<llvm::PointerType>(A->getType());
+ args.push_back(A);
+
+ args.push_back(mt.lookupValue(thisOp.getMatrixB()));
+ args.push_back(mt.lookupValue(thisOp.getIdesc()));
+ args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
+ args.push_back(mt.lookupValue(thisOp.getScaleA()));
+ args.push_back(mt.lookupValue(thisOp.getScaleB()));
+ args.push_back(builder.getInt32(
+ static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()))));
+ args.push_back(
+ builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
+
+ auto kind = thisOp.getKind();
+ auto blockScale = thisOp.getBlockScale();
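+  // Valid combinations: mxf8f6f4 and mxf4 support the default and block32
+  // scales; mxf4nvf4 requires an explicit block16 or block32 scale.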
+ llvm::Intrinsic::ID ID = [&]() {
+ if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
+ return isATensor ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale;
+ } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
+ }
+ } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
+ return isATensor
+ ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
+ : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale;
+ } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
+ }
+ } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) {
+ if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32;
+      } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
+ return isATensor
+ ? llvm::Intrinsic::
+ nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16
+ : llvm::Intrinsic::
+ nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16;
+ }
+ }
+ llvm_unreachable("Invalid tcgen05.mma.block_scale attributes");
+ }();
+
+ return {ID, args};
+}
+
+static LogicalResult
+verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp,
+ NVVM::Tcgen05MMABlockScaleKind kind,
+ NVVM::Tcgen05MMABlockScale blockScale,
+ Location loc) {
+ LogicalResult res = success();
+
+ if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
+ kind == Tcgen05MMABlockScaleKind::MXF4NVF4)
+ res = emitError(loc, "mxf4nvf4 requires block scale attribute");
+
+ if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
+ kind != Tcgen05MMABlockScaleKind::MXF4NVF4)
+    res = emitError(loc,
+                    llvm::formatv("{0} kind does not support block16 attribute",
+                                  stringifyEnum(kind)));
+ return res;
----------------
grypp wrote:
same comment about the style of the verifier.
https://github.com/llvm/llvm-project/pull/164356