[Mlir-commits] [mlir] [MLIR] Support for dense and sparse MMA with block scaling (PR #170566)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 3 14:00:24 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Kirill Vedernikov (kvederni)
<details>
<summary>Changes</summary>
This change adds dense and sparse MMA-with-block-scaling intrinsics to the MLIR -> NVVM IR -> NVPTX flow. The NVVM and NVPTX implementations are based on PTX ISA 9.0.
---
Patch is 121.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/170566.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+449-2)
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+633-1)
- (added) mlir/test/Dialect/LLVMIR/nvvm-mma-blockscale.mlir (+525)
- (added) mlir/test/Dialect/LLVMIR/nvvm-mma-sparse-blockscale.mlir (+637)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index a96d65d3fcacd..1faa435fca6f9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2499,6 +2499,30 @@ class NVVM_MMA_OPS {
bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
subint_mma_sp_ops, int_mma_sp_ops);
+ // Block scale MMA operations (dense)
+ list<list<WMMA_REGS>> mxf4_mma_ops = MMA_OPS<
+ [GEOM<16,8,64>],
+ ["e2m1"], ["e2m1"], ["f32"], []>.ret;
+ list<list<WMMA_REGS>> mxf8f6f4_mma_ops = MMA_OPS<
+ [GEOM<16,8,32>],
+ ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"],
+ ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"],
+ ["f32"], []>.ret;
+ list<list<WMMA_REGS>> all_mma_block_scale_ops = !listconcat(
+ mxf4_mma_ops, mxf8f6f4_mma_ops);
+
+ // Block scale sparse MMA operations
+ list<list<WMMA_REGS>> mxf4xx_mma_sp_ops = MMA_OPS<
+ [GEOM<16,8,128>],
+ ["e2m1"], ["e2m1"], ["f32"], []>.ret;
+ list<list<WMMA_REGS>> mxf8f6f4_mma_sp_ops = MMA_OPS<
+ [GEOM<16,8,64>],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f32"], []>.ret;
+ list<list<WMMA_REGS>> all_mma_sp_block_scale_ops = !listconcat(
+ mxf4xx_mma_sp_ops, mxf8f6f4_mma_sp_ops);
+
}
def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -3332,7 +3356,7 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {
The optional `orderedMetadata` attribute specifies the metadata ordering:
- Absence (default): Uses standard sparse metadata ordering
- Presence: Uses ordered metadata (PTX ISA 8.5+, sm_90+)
-
+
The optional `kind` attribute specifies mixed-precision modes for FP8 operations:
- `f8f6f4`: Enables e3m2, e2m3, e2m1 FP8 types and f16 accumulator (PTX ISA 8.7+, sm_90+)
- Only valid with ordered metadata and m16n8k64 shape
@@ -3347,7 +3371,7 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {
sparseMetadata[%meta] selector[%sel]
{shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}}
: (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
-
+
// With ordered metadata:
%d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
sparseMetadata[%meta] selector[%sel]
@@ -3416,6 +3440,429 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {
let hasVerifier = 1;
}
+def ScaleVecSize1X : I32EnumAttrCase<"X1", 0, "x1">; // C++ symbol X1, value 0, assembly keyword `x1`
+def ScaleVecSize2X : I32EnumAttrCase<"X2", 1, "x2">; // C++ symbol X2, value 1, assembly keyword `x2`
+def ScaleVecSize4X : I32EnumAttrCase<"X4", 2, "x4">; // C++ symbol X4, value 2, assembly keyword `x4`
+
+def ScaleVecSize : I32EnumAttr< // enum of the three scale vector sizes; carried by the block-scale MMA ops as $scaleVecSize
+ "ScaleVecSize",
+ "MMA Scale Vector Sizes",
+ [ScaleVecSize1X, ScaleVecSize2X, ScaleVecSize4X]> {
+ let cppNamespace = "::mlir::NVVM";
+ let genSpecializedAttr = 0; // no specialized attribute from the enum itself; ScaleVecSizeAttr below wraps it
+}
+
+def ScaleVecSizeAttr : EnumAttr<NVVM_Dialect, ScaleVecSize, "scale_vec_size"> { // dialect attribute, written e.g. #nvvm.scale_vec_size<x2>
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def UE8M0 : I32EnumAttrCase<"UE8M0", 0, "ue8m0">; // C++ symbol UE8M0, value 0, assembly keyword `ue8m0`
+def UE4M3 : I32EnumAttrCase<"UE4M3", 1, "ue4m3">; // C++ symbol UE4M3, value 1, assembly keyword `ue4m3`
+
+def BlockScaleFormat : I32EnumAttr< // enum of the scale-factor formats; carried by the block-scale MMA ops as $blockScaleFormat
+ "BlockScaleFormat",
+ "MMA Block Scale Format",
+ [UE8M0, UE4M3]
+> {
+ let cppNamespace = "::mlir::NVVM";
+ let genSpecializedAttr = 0; // no specialized attribute from the enum itself; BlockScaleFormatAttr below wraps it
+}
+
+def BlockScaleFormatAttr : EnumAttr<NVVM_Dialect, BlockScaleFormat, "block_scale_format"> { // dialect attribute, written e.g. #nvvm.block_scale_format<ue8m0>
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def MMABlockScaleKindMXF8F6F4 : I32EnumAttrCase<"MXF8F6F4", 0, "mxf8f6f4">; // C++ symbol MXF8F6F4, value 0, assembly keyword `mxf8f6f4`
+def MMABlockScaleKindMXF4 : I32EnumAttrCase<"MXF4", 1, "mxf4">; // C++ symbol MXF4, value 1, assembly keyword `mxf4`
+def MMABlockScaleKindMXF4NVF4 : I32EnumAttrCase<"MXF4NVF4", 2, "mxf4nvf4">; // C++ symbol MXF4NVF4, value 2, assembly keyword `mxf4nvf4`
+
+def MMABlockScaleKind : I32EnumAttr< // enum of the block-scale kinds; carried by the block-scale MMA ops as $kind
+ "MMABlockScaleKind",
+ "Block Scale Kind",
+ [MMABlockScaleKindMXF8F6F4, MMABlockScaleKindMXF4, MMABlockScaleKindMXF4NVF4]> {
+ let cppNamespace = "::mlir::NVVM";
+ let genSpecializedAttr = 0; // no specialized attribute from the enum itself; MMABlockScaleKindAttr below wraps it
+}
+
+def MMABlockScaleKindAttr : EnumAttr<NVVM_Dialect, MMABlockScaleKind, "block_scale_kind"> { // dialect attribute, written e.g. #nvvm.block_scale_kind<mxf4>
+ let assemblyFormat = "`<` $value `>`";
+}
+
+/// Builds `id`, the C++ spelling of the llvm::Intrinsic enumerator for one mma.block_scale variant, by concatenating: the fixed prefix, A.geom, "_row_col", Kind, ScaleVecSize, the MMA_SIGNATURE suffix and SType.
+class MMA_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
+ WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret; // type-signature suffix computed by MMA_SIGNATURE (defined outside this section)
+ string id = "llvm::Intrinsic::nvvm_mma_block_scale"
+ # "_" # A.geom
+ # "_row_col"
+ # "_" # Kind
+ # !subst(".", "_", ScaleVecSize) // dots become underscores, e.g. ".scale_2x" -> "_scale_2x"; an empty ScaleVecSize contributes nothing
+ # signature
+ # "_" # SType;
+}
+
+/// Builds `id`, the C++ spelling of the llvm::Intrinsic enumerator for one mma.sp.block_scale variant (the prefix always names the ordered-metadata form), by concatenating: the fixed prefix, A.geom, "_row_col", Kind, ScaleVecSize, the MMA_SIGNATURE suffix and SType.
+class MMA_SP_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
+ WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret; // type-signature suffix computed by MMA_SIGNATURE (defined outside this section)
+ string id = "llvm::Intrinsic::nvvm_mma_sp_ordered_metadata_block_scale"
+ # "_" # A.geom
+ # "_row_col"
+ # "_" # Kind
+ # !subst(".", "_", ScaleVecSize) // dots become underscores, e.g. ".scale_2x" -> "_scale_2x"; an empty ScaleVecSize contributes nothing
+ # signature
+ # "_" # SType;
+}
+
+// `ret` is true only when (geometry of frags[0], kind, stype, scale_vec_size) matches one of the four arms of the !cond below; every other combination yields false.
+// The table is modeled on the NVVM_MMA_BLOCK_SCALE_SUPPORTED class from IntrinsicsNVVM.td (not visible in this section - presumably the two must stay in agreement; confirm when editing).
+class NVVM_MMA_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
+ string stype, string scale_vec_size> {
+ string geom = frags[0].geom; // only the first fragment's geometry is consulted
+ bit ret = !cond(
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf4"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_2x")),
+ !eq(stype, "ue8m0")) : true, // m16n8k64, mxf4, empty or .scale_2x, ue8m0
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(scale_vec_size, ".scale_2x"),
+ !eq(stype, "ue8m0")) : true, // m16n8k64, mxf4nvf4, .scale_2x, ue8m0
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(scale_vec_size, ".scale_4x"),
+ !eq(stype, "ue4m3")) : true, // m16n8k64, mxf4nvf4, .scale_4x, ue4m3
+ !and(!eq(geom, "m16n8k32"),
+ !eq(kind, "mxf8f6f4"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_1x")),
+ !eq(stype, "ue8m0")) : true, // m16n8k32, mxf8f6f4, empty or .scale_1x, ue8m0
+ true: false // catch-all: anything not listed above is unsupported
+ );
+}
+
+// `ret` is true only when (geometry of frags[0], kind, stype, scale_vec_size) matches one of the four arms of the !cond below; every other combination yields false.
+// The table is modeled on the NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED class from IntrinsicsNVVM.td (not visible in this section - presumably the two must stay in agreement; confirm when editing).
+class NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
+ string stype, string scale_vec_size> {
+ string geom = frags[0].geom; // only the first fragment's geometry is consulted
+ bit ret = !cond(
+ !and(!eq(geom, "m16n8k128"),
+ !eq(kind, "mxf4"),
+ !eq(stype, "ue8m0"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_2x"))): true, // m16n8k128, mxf4, ue8m0, empty or .scale_2x
+ !and(!eq(geom, "m16n8k128"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(stype, "ue8m0"),
+ !eq(scale_vec_size, ".scale_2x")): true, // m16n8k128, mxf4nvf4, ue8m0, .scale_2x
+ !and(!eq(geom, "m16n8k128"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(stype, "ue4m3"),
+ !eq(scale_vec_size, ".scale_4x")): true, // m16n8k128, mxf4nvf4, ue4m3, .scale_4x
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf8f6f4"),
+ !eq(stype, "ue8m0"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_1x"))): true, // m16n8k64, mxf8f6f4, ue8m0, empty or .scale_1x
+ true: false // catch-all: anything not listed above is unsupported
+ );
+}
+
+/// Builds `id`: C++ source text with one `if (<shape/type/kind/format/scale match>) return <intrinsic>;` statement per supported (op, kind, scale_vec_size, stype) combination of the mma.block_scale
+/// intrinsic; unsupported combinations contribute an empty string. NVVM_MmaBlockScaleBase splices this text into the body of its getIntrinsicID.
+class MMA_BLOCK_SCALE_INTR {
+ list<list<list<list<string>>>> cond0 =
+ !foreach(op, NVVM_MMA_OPS.all_mma_block_scale_ops,
+ !foreach(kind, ["mxf4", "mxf4nvf4", "mxf8f6f4"],
+ !foreach(scale_vec_size, ["", ".scale_1x", ".scale_2x", ".scale_4x"], // NOTE(review): the "" entry's condition compares "" with getScaleVecSizeStr(...), which the lambda in NVVM_MmaBlockScaleBase returns only for a value outside X1/X2/X4 - confirm this is intended
+ !foreach(stype, ["ue8m0", "ue4m3"],
+ !if(NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret,
+ "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k
+ # " && \"" # op[0].ptx_elt_type # "\" == eltypeA"
+ # " && \"" # op[1].ptx_elt_type # "\" == eltypeB"
+ # " && \"" # op[2].ptx_elt_type # "\" == eltypeC"
+ # " && \"" # kind # "\" == stringifyEnum(kind)"
+ # " && \"" # stype # "\" == stringifyEnum(blockScaleFormat)"
+ # " && \"" # scale_vec_size # "\" == getScaleVecSizeStr(scaleVecSize))\n"
+ # " return " #
+ MMA_BLOCK_SCALE_NAME<kind, stype, scale_vec_size, op[0], op[1], op[2], op[3]>.id # ";",
+ "") // if supported
+ ) // stype
+ ) // scale_vec_size
+ ) // kind
+ ); // all_mma_block_scale_ops
+ list<list<list<string>>> f1 = !foldl([[[""]]], cond0, acc, el, // f1, f2 and f3 each flatten one nesting level of cond0
+ !listconcat(acc, el));
+ list<list<string>> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el));
+ list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
+ string id = !foldl("", f3, acc, el, acc # "\n" # el); // join every entry (including the empty ones) with newlines
+}
+
+/// Builds `id`: C++ source text with one `if (<shape/type/kind/format/scale match>) return <intrinsic>;` statement per supported (op, kind, scale_vec_size, stype) combination of the mma.sp.block_scale
+/// intrinsic; unsupported combinations contribute an empty string. NVVM_MmaSpBlockScaleOp splices this text into the body of its getIntrinsicID.
+class MMA_SP_BLOCK_SCALE_INTR {
+ list<list<list<list<string>>>> cond0 =
+ !foreach(op, NVVM_MMA_OPS.all_mma_sp_block_scale_ops,
+ !foreach(kind, ["mxf4", "mxf4nvf4", "mxf8f6f4"],
+ !foreach(scale_vec_size, ["", ".scale_1x", ".scale_2x", ".scale_4x"], // NOTE(review): the "" entry's condition compares "" with getScaleVecSizeStr(...), which the lambda in NVVM_MmaSpBlockScaleOp returns only for a value outside X1/X2/X4 - confirm this is intended
+ !foreach(stype, ["ue8m0", "ue4m3"],
+ !if(NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret,
+ "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k
+ # " && \"" # op[0].ptx_elt_type # "\" == eltypeA"
+ # " && \"" # op[1].ptx_elt_type # "\" == eltypeB"
+ # " && \"" # op[2].ptx_elt_type # "\" == eltypeC"
+ # " && \"" # kind # "\" == stringifyEnum(kind)"
+ # " && \"" # stype # "\" == stringifyEnum(blockScaleFormat)"
+ # " && \"" # scale_vec_size # "\" == getScaleVecSizeStr(scaleVecSize))\n"
+ # " return " #
+ MMA_SP_BLOCK_SCALE_NAME<kind, stype, scale_vec_size, op[0], op[1], op[2], op[3]>.id # ";",
+ "") // if supported
+ ) // stype
+ ) // scale_vec_size
+ ) // kind
+ ); // all_mma_sp_block_scale_ops
+ list<list<list<string>>> f1 = !foldl([[[""]]], cond0, acc, el, // f1, f2 and f3 each flatten one nesting level of cond0
+ !listconcat(acc, el));
+ list<list<string>> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el));
+ list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
+ string id = !foldl("", f3, acc, el, acc # "\n" # el); // join every entry (including the empty ones) with newlines
+}
+
+// Common base class for the dense and sparse MMA block-scale ops: declares the struct result, reusable argument dags for derived ops to combine, a default getIntrinsicID helper, and turns on a custom assembly format and a verifier.
+class NVVM_MmaBlockScaleBase<string mnemonic, list<Trait> traits = []> :
+ NVVM_Op<mnemonic, !listconcat([AttrSizedOperandSegments], traits)> { // AttrSizedOperandSegments is always placed ahead of any caller-supplied traits
+
+ let results = (outs LLVM_AnyStruct:$res);
+
+ // Attributes shared by both variants: shape, optional A/B multiplicand PTX types, scale vector size, block scale format and block scale kind.
+ dag commonArguments = (ins
+ NVVM_MMAShapeAttr:$shape,
+ OptionalAttr<MMATypesAttr>:$multiplicandAPtxType,
+ OptionalAttr<MMATypesAttr>:$multiplicandBPtxType,
+ ScaleVecSizeAttr:$scaleVecSize,
+ BlockScaleFormatAttr:$blockScaleFormat,
+ MMABlockScaleKindAttr:$kind);
+
+ // Variadic operand groups for the A, B and C matrices.
+ dag commonVariadicOperands = (ins
+ Variadic<LLVM_Type>:$operandA,
+ Variadic<LLVM_Type>:$operandB,
+ Variadic<LLVM_Type>:$operandC);
+
+ // Scale operands: for each of A and B, one i32 scale-data value plus i16 byte-id and thread-id values.
+ dag commonScaleOperands = (ins
+ I32:$scaleAData,
+ I16:$byteIdA,
+ I16:$threadIdA,
+ I32:$scaleBData,
+ I16:$byteIdB,
+ I16:$threadIdB);
+
+ let extraClassDeclaration = !strconcat([{
+ static llvm::Intrinsic::ID getIntrinsicID(
+ int64_t m, int64_t n, uint64_t k,
+ mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
+ mlir::NVVM::MMATypes eltypeCEnum,
+ mlir::NVVM::ScaleVecSize scaleVecSize,
+ mlir::NVVM::BlockScaleFormat blockScaleFormat,
+ mlir::NVVM::MMABlockScaleKind kind) {
+ llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
+ llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
+ llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
+
+ auto getScaleVecSizeStr = [](ScaleVecSize svs) -> std::string {
+ switch (svs) {
+ case ScaleVecSize::X1: return ".scale_1x";
+ case ScaleVecSize::X2: return ".scale_2x";
+ case ScaleVecSize::X4: return ".scale_4x";
+ }
+ return "";
+ };
+ }],
+ MMA_BLOCK_SCALE_INTR<>.id, [{
+ return 0;
+ }
+
+ // Common declarations - implementations in NVVMDialect.cpp
+ MMATypes accumPtxType();
+ MMATypes resultPtxType();
+
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase& builder);
+ }]); // three pieces: getIntrinsicID prologue, the generated if/return chain from MMA_BLOCK_SCALE_INTR, then the `return 0` fallback and the remaining declarations
+
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+}
+
+def NVVM_MmaBlockScaleOp : NVVM_MmaBlockScaleBase<"mma.block_scale"> { // dense variant; keeps the base class's extraClassDeclaration (getIntrinsicID built from MMA_BLOCK_SCALE_INTR)
+
+ let summary = "cooperative matrix-multiply and accumulate with block scaling";
+
+ let description = [{
+ The `nvvm.mma.block_scale` operation collectively performs the operation
+ `D = matmul(A * SF_A, B * SF_B) + C` using all threads in a warp.
+
+ A, B, C and D are dense matrices and SF_A and SF_B are scaling factors.
+ Dimensions of SF_A and SF_B are based on scale vector sizes (x1, x2, x4),
+ and the data type must be either ue8m0 or ue4m3.
+
+ All the threads in the warp must execute the same `mma.block_scale` operation.
+
+ This operation follows the same design pattern as `nvvm.mma.sync`, with additional
+ scaling operands for both A and B matrices.
+
+ Example:
+ ```mlir
+ %d = nvvm.mma.block_scale A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x2>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf4nvf4>}
+ : (vector<4xf16>, vector<2xf16>, vector<2xf32>) -> !llvm.struct<(f32, f32)>
+ ```
+ }];
+
+ // Argument order: the common attributes, then the A/B/C variadic operands, then the scale operands for A and B.
+ let arguments = !con(commonArguments, commonVariadicOperands, commonScaleOperands);
+
+ let builders = [ // declared without an inline body; presumably implemented in NVVMDialect.cpp - not visible in this section
+ OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
+ "ValueRange":$operandB, "ValueRange":$operandC,
+ "Value":$scaleAData, "Value":$byteIdA, "Value":$threadIdA,
+ "Value":$scaleBData, "Value":$byteIdB, "Value":$threadIdB,
+ "ArrayRef<int64_t>":$shape,
+ "std::optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes,
+ "ScaleVecSize":$scaleVecSize,
+ "BlockScaleFormat":$blockScaleFormat,
+ "MMABlockScaleKind":$kind)>
+ ];
+
+ string llvmBuilder = [{
+ auto [id, args] = NVVM::MmaBlockScaleOp::getIntrinsicIDAndArgs(
+ *op, moduleTranslation, builder);
+ $res = createIntrinsicCall(builder, id, args);
+ }]; // translation: obtain the intrinsic id and argument list, emit a single intrinsic call and bind it to $res
+}
+
+def NVVM_MmaSpBlockScaleOp : NVVM_MmaBlockScaleBase<"mma.sp.block_scale"> {
+
+ let summary = "cooperative sparse matrix-multiply and accumulate with block scaling";
+
+ let description = [{
+ The `nvvm.mma.sp.block_scale` operation collectively performs the operation
+ `D = matmul(A_sparse * SF_A, B * SF_B) + C` using all threads in a warp.
+
+ A is a sparse matrix, and B, C and D are dense matrices.
+ SF_A and SF_B are scaling factors.
+ Dimensions of SF_A and SF_B are based on scale vector sizes (x1, x2, x4),
+ and the data type must be either ue8m0 or ue4m3.
+
+ This operation is similar to `nvvm.mma.block_scale` but with structured sparsity
+ in the A operand. The sparsity follows the 2:4 structured sparse pattern
+ where 2 out of every 4 elements are non-zero.
+
+ All the threads in the warp must execute the same `mma.sp.block_scale` operation.
+
+ The `sparseMetadata` operand provides the sparsity indices that indicate
+ which elements in the A operand are non-zero. The `sparsitySelector`
+ controls how the indices are distributed among threads in the warp and
+ should typically be 0 or 1.
+
+ This operation follows the same design pattern as `nvvm.mma.sp.sync`, with additional
+ scaling operands for both A and B matrices. Note that sparse block scale operations
+ always use ordered metadata (sm_90+).
+
+ Example:
+ ```mlir
+ %d = nvvm.mma.sp.block_scale A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 128>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x2>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf4>}
+ : (vector<2xf16>, vector<2xf16>, vector<2xf32>) -> !llvm.struct<(f32, f32)>
+ ```
+ }];
+
+ // Sparse-specific attributes and operands
+ dag sparseSpecificArguments = (ins
+ UnitAttr:$orderedMetadata);
+
+ dag sparseSpecificOperands = (ins
+ I32:$sparseMetadata,
+ I32:$sparsitySelector);
+
+ // Combine common and sparse-specific attributes and operands
+ let arguments = !con(commonArguments, sparseSpecificArguments,
+ commonVariadicOperands, sparseSpecificOperands,
+ commonScaleOperands);
+
+ // Override extraClassDeclaration to use sparse intrinsics
+ let extraClassDeclaration = !strconcat([{
+ static llvm::Intrinsic::ID getIntrinsicID(
+ int64_t m, int64_t n, uint64_t k,
+ mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
+ mlir::NVVM::MMATypes eltypeCEnum,
+ mlir::NVVM::ScaleVecSize scaleVecSize,
+ mlir::NVVM::BlockScaleFormat blockScaleFormat,
+ mlir::NVVM::MMABlockScaleKind kind) {
+ llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
+ llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
+ llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
+
+ auto getScaleVecSizeStr = [](ScaleVecSize svs) -> std::string {
+ switch (svs) {
+ case ScaleVecSize::X1: return ".scale_1x";
+ case ScaleVecSize::X2: return ".scale_2x";
+ case ScaleVecSize::X4: return ".scale_4x";
+ }
+ return "";
+ };
+ }],
+ MMA_SP_BLOCK_SCALE_INTR<>.id, [{
+ return 0;
+ }
+
+ MMATypes accumPtxType();
+ MMATypes resultPtxType();
+
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase& builder);
+ }]);
+
+ let builders = [
+ OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
+ "ValueRange":$operandB, "ValueRange":$operandC,
+ "Value":$sparseMetadata, "Value":$sparsitySelector,
+ "Value":$scaleAData, "Value":$byteIdA, "Value":$threadIdA,
+ "Value":$scaleBData, "Value":$byteIdB, "Value":$threadIdB,
+ "ArrayRef<int64_t>":$shape,
+ "std::optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes,
+ "ScaleVecSize":$scaleVecSize,
+ "BlockScaleFormat":$blockScaleFormat,
+ "MMABlockScaleKind":$kind)>
+ ];
+
+ string llvmBuilder = [{
+ auto [id, a...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/170566
More information about the Mlir-commits
mailing list