[Mlir-commits] [mlir] [MLIR] Support for dense and sparse MMA with block scaling (PR #170566)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 3 14:00:24 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Kirill Vedernikov (kvederni)

<details>
<summary>Changes</summary>

This change adds dense and sparse MMA-with-block-scaling intrinsics to the MLIR -> NVVM IR -> NVPTX flow. The NVVM and NVPTX implementation is based on PTX ISA 9.0.

---

The patch is 121.75 KiB and is truncated to 20.00 KiB below; the full version is at https://github.com/llvm/llvm-project/pull/170566.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+449-2) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+633-1) 
- (added) mlir/test/Dialect/LLVMIR/nvvm-mma-blockscale.mlir (+525) 
- (added) mlir/test/Dialect/LLVMIR/nvvm-mma-sparse-blockscale.mlir (+637) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index a96d65d3fcacd..1faa435fca6f9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2499,6 +2499,30 @@ class NVVM_MMA_OPS {
             bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
             subint_mma_sp_ops, int_mma_sp_ops);
 
+  // Block scale MMA operations (dense)
+  list<list<WMMA_REGS>> mxf4_mma_ops = MMA_OPS<
+            [GEOM<16,8,64>],
+            ["e2m1"], ["e2m1"], ["f32"], []>.ret;
+  list<list<WMMA_REGS>> mxf8f6f4_mma_ops = MMA_OPS<
+            [GEOM<16,8,32>],
+            ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"],
+            ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"],
+            ["f32"], []>.ret;
+  list<list<WMMA_REGS>> all_mma_block_scale_ops = !listconcat(
+            mxf4_mma_ops, mxf8f6f4_mma_ops);
+
+  // Block scale sparse MMA operations
+  list<list<WMMA_REGS>> mxf4xx_mma_sp_ops = MMA_OPS<
+            [GEOM<16,8,128>],
+            ["e2m1"], ["e2m1"], ["f32"], []>.ret;
+  list<list<WMMA_REGS>> mxf8f6f4_mma_sp_ops = MMA_OPS<
+            [GEOM<16,8,64>],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["f32"], []>.ret;
+  list<list<WMMA_REGS>> all_mma_sp_block_scale_ops = !listconcat(
+            mxf4xx_mma_sp_ops, mxf8f6f4_mma_sp_ops);
+
 }
 
 def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -3332,7 +3356,7 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {
     The optional `orderedMetadata` attribute specifies the metadata ordering:
     - Absence (default): Uses standard sparse metadata ordering
     - Presence: Uses ordered metadata (PTX ISA 8.5+, sm_90+)
-    
+
     The optional `kind` attribute specifies mixed-precision modes for FP8 operations:
     - `f8f6f4`: Enables e3m2, e2m3, e2m1 FP8 types and f16 accumulator (PTX ISA 8.7+, sm_90+)
     - Only valid with ordered metadata and m16n8k64 shape
@@ -3347,7 +3371,7 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {
                           sparseMetadata[%meta] selector[%sel]
                           {shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}}
         : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
-    
+
     // With ordered metadata:
     %d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
                           sparseMetadata[%meta] selector[%sel]
@@ -3416,6 +3440,429 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {
   let hasVerifier = 1;
 }
 
+// Scale vector size used by nvvm.mma.block_scale / nvvm.mma.sp.block_scale.
+// The C++ getIntrinsicID helpers translate X1 / X2 / X4 into the ".scale_1x",
+// ".scale_2x" and ".scale_4x" parts of the intrinsic name.
+def ScaleVecSize1X  : I32EnumAttrCase<"X1", 0, "x1">;
+def ScaleVecSize2X  : I32EnumAttrCase<"X2", 1, "x2">;
+def ScaleVecSize4X  : I32EnumAttrCase<"X4", 2, "x4">;
+
+def ScaleVecSize : I32EnumAttr<
+  "ScaleVecSize",
+  "MMA Scale Vector Sizes",
+  [ScaleVecSize1X, ScaleVecSize2X, ScaleVecSize4X]> {
+    let cppNamespace = "::mlir::NVVM";
+    // No specialized attribute class; ScaleVecSizeAttr below wraps this enum.
+    let genSpecializedAttr = 0;
+}
+
+// Dialect attribute for the enum, written in IR as e.g.
+// #nvvm.scale_vec_size<x2>.
+def ScaleVecSizeAttr : EnumAttr<NVVM_Dialect, ScaleVecSize, "scale_vec_size"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+// Element format of the block scale factors (SF_A / SF_B).
+// Per the *_BLOCK_SCALE_SUPPORTED tables, ue4m3 is accepted only together
+// with kind mxf4nvf4 and scale_vec_size x4; every other supported
+// combination uses ue8m0.
+def UE8M0 : I32EnumAttrCase<"UE8M0", 0, "ue8m0">;
+def UE4M3 : I32EnumAttrCase<"UE4M3", 1, "ue4m3">;
+
+def BlockScaleFormat : I32EnumAttr<
+  "BlockScaleFormat",
+  "MMA Block Scale Format",
+  [UE8M0, UE4M3]
+> {
+  let cppNamespace = "::mlir::NVVM";
+  // No specialized attribute class; BlockScaleFormatAttr below wraps this enum.
+  let genSpecializedAttr = 0;
+}
+
+// Dialect attribute for the enum, written in IR as e.g.
+// #nvvm.block_scale_format<ue8m0>.
+def BlockScaleFormatAttr : EnumAttr<NVVM_Dialect, BlockScaleFormat, "block_scale_format"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+// Kind of block-scaled MMA. Per the *_BLOCK_SCALE_SUPPORTED tables:
+//   mxf4, mxf4nvf4 -> e2m1 operands, m16n8k64 (dense) / m16n8k128 (sparse)
+//   mxf8f6f4       -> e2m1/e2m3/e3m2/e5m2/e4m3 operands,
+//                     m16n8k32 (dense) / m16n8k64 (sparse)
+def MMABlockScaleKindMXF8F6F4  : I32EnumAttrCase<"MXF8F6F4", 0, "mxf8f6f4">;
+def MMABlockScaleKindMXF4  : I32EnumAttrCase<"MXF4", 1, "mxf4">;
+def MMABlockScaleKindMXF4NVF4  : I32EnumAttrCase<"MXF4NVF4", 2, "mxf4nvf4">;
+
+def MMABlockScaleKind : I32EnumAttr<
+  "MMABlockScaleKind",
+  "Block Scale Kind",
+  [MMABlockScaleKindMXF8F6F4, MMABlockScaleKindMXF4, MMABlockScaleKindMXF4NVF4]> {
+    let cppNamespace = "::mlir::NVVM";
+    // No specialized attribute class; MMABlockScaleKindAttr below wraps this enum.
+    let genSpecializedAttr = 0;
+}
+
+// Dialect attribute for the enum, written in IR as e.g.
+// #nvvm.block_scale_kind<mxf4>.
+def MMABlockScaleKindAttr : EnumAttr<NVVM_Dialect, MMABlockScaleKind, "block_scale_kind"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+/// Generate the C++ spelling of the mma.block_scale intrinsic enum value.
+///
+/// `id` has the form
+///   llvm::Intrinsic::nvvm_mma_block_scale_<geom>_row_col_<Kind><svs><sig>_<SType>
+/// - <geom> is A.geom (e.g. m16n8k64).
+/// - <svs> is ScaleVecSize with '.' replaced by '_': ".scale_2x" becomes
+///   "_scale_2x", and "" contributes nothing.
+/// - <sig> is MMA_SIGNATURE<A, B, C, D>.ret, which is defined elsewhere and
+///   presumably encodes the fragment element types. TODO confirm.
+class MMA_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
+                           WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+  string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+  string id = "llvm::Intrinsic::nvvm_mma_block_scale"
+              # "_" # A.geom
+              # "_row_col"
+              # "_" # Kind
+              # !subst(".", "_", ScaleVecSize)
+              # signature
+              # "_" # SType;
+}
+
+/// Generate the C++ spelling of the mma.sp.block_scale intrinsic enum value.
+///
+/// Same layout as MMA_BLOCK_SCALE_NAME, but with the prefix
+///   llvm::Intrinsic::nvvm_mma_sp_ordered_metadata_block_scale
+/// so only the ordered-metadata sparse intrinsics are ever selected.
+/// ScaleVecSize is turned into a name part by replacing '.' with '_'
+/// (e.g. ".scale_4x" -> "_scale_4x"; "" contributes nothing).
+class MMA_SP_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
+                              WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+  string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+  string id = "llvm::Intrinsic::nvvm_mma_sp_ordered_metadata_block_scale"
+              # "_" # A.geom
+              # "_row_col"
+              # "_" # Kind
+              # !subst(".", "_", ScaleVecSize)
+              # signature
+              # "_" # SType;
+}
+
+// Returns true if this combination is supported for MMA.BLOCK_SCALE ops.
+// NOTE(review): this is a separate local definition, not a reference. It is
+// meant to match NVVM_MMA_BLOCK_SCALE_SUPPORTED in IntrinsicsNVVM.td, so the
+// two must be kept in sync by hand.
+//
+// frags[0] is the A fragment, so `geom` is the MMA shape. Accepted (dense):
+//   m16n8k64, mxf4,     ue8m0, "" or .scale_2x
+//   m16n8k64, mxf4nvf4, ue8m0, .scale_2x
+//   m16n8k64, mxf4nvf4, ue4m3, .scale_4x
+//   m16n8k32, mxf8f6f4, ue8m0, "" or .scale_1x
+class NVVM_MMA_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
+                                     string stype, string scale_vec_size> {
+  string geom = frags[0].geom;
+  bit ret = !cond(
+    !and(!eq(geom, "m16n8k64"),
+         !eq(kind, "mxf4"),
+         !or(!eq(scale_vec_size, ""),
+             !eq(scale_vec_size, ".scale_2x")),
+         !eq(stype, "ue8m0")) : true,
+    !and(!eq(geom, "m16n8k64"),
+         !eq(kind, "mxf4nvf4"),
+         !eq(scale_vec_size, ".scale_2x"),
+         !eq(stype, "ue8m0")) : true,
+    !and(!eq(geom, "m16n8k64"),
+         !eq(kind, "mxf4nvf4"),
+         !eq(scale_vec_size, ".scale_4x"),
+         !eq(stype, "ue4m3")) : true,
+    !and(!eq(geom, "m16n8k32"),
+         !eq(kind, "mxf8f6f4"),
+         !or(!eq(scale_vec_size, ""),
+             !eq(scale_vec_size, ".scale_1x")),
+         !eq(stype, "ue8m0")) : true,
+    true: false
+  );
+}
+
+// Returns true if this combination is supported for MMA.SP.BLOCK_SCALE ops.
+// NOTE(review): this is a separate local definition, not a reference. It is
+// meant to match NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED in IntrinsicsNVVM.td, so
+// the two must be kept in sync by hand.
+//
+// frags[0] is the A fragment, so `geom` is the MMA shape. The sparse shapes
+// have twice the K of the dense ones. Accepted:
+//   m16n8k128, mxf4,     ue8m0, "" or .scale_2x
+//   m16n8k128, mxf4nvf4, ue8m0, .scale_2x
+//   m16n8k128, mxf4nvf4, ue4m3, .scale_4x
+//   m16n8k64,  mxf8f6f4, ue8m0, "" or .scale_1x
+class NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
+                                        string stype, string scale_vec_size> {
+  string geom = frags[0].geom;
+  bit ret = !cond(
+    !and(!eq(geom, "m16n8k128"),
+         !eq(kind, "mxf4"),
+         !eq(stype, "ue8m0"),
+         !or(!eq(scale_vec_size, ""),
+             !eq(scale_vec_size, ".scale_2x"))): true,
+    !and(!eq(geom, "m16n8k128"),
+         !eq(kind, "mxf4nvf4"),
+         !eq(stype, "ue8m0"),
+         !eq(scale_vec_size, ".scale_2x")): true,
+    !and(!eq(geom, "m16n8k128"),
+         !eq(kind, "mxf4nvf4"),
+         !eq(stype, "ue4m3"),
+         !eq(scale_vec_size, ".scale_4x")): true,
+    !and(!eq(geom, "m16n8k64"),
+         !eq(kind, "mxf8f6f4"),
+         !eq(stype, "ue8m0"),
+         !or(!eq(scale_vec_size, ""),
+             !eq(scale_vec_size, ".scale_1x"))): true,
+    true: false
+  );
+}
+
+/// Helper to create the mapping between the configuration and the mma.block_scale
+/// intrinsic enum value.
+///
+/// `id` is a block of C++ statements with one
+///   if (m == .. && n == .. && k == .. && <A/B/C element types>
+///       && <kind> && <block scale format> && <scale vec size>)
+///     return <intrinsic id>;
+/// for each supported (fragment set, kind, scale_vec_size, stype) combination.
+/// Only the A, B and C element types are compared; D is not.
+/// The code refers to m, n, k, eltypeA, eltypeB, eltypeC, kind,
+/// blockScaleFormat, scaleVecSize and getScaleVecSizeStr, so it must be
+/// spliced into a function that has all of them in scope.
+/// NOTE(review): the entries generated for scale_vec_size "" can never match.
+/// getScaleVecSizeStr yields ".scale_1x", ".scale_2x" or ".scale_4x" for
+/// every valid ScaleVecSize value.
+class MMA_BLOCK_SCALE_INTR {
+  // Nested lists indexed [op][kind][scale_vec_size][stype]. Unsupported
+  // combinations produce an empty string.
+  list<list<list<list<string>>>> cond0 =
+    !foreach(op, NVVM_MMA_OPS.all_mma_block_scale_ops,
+      !foreach(kind, ["mxf4", "mxf4nvf4", "mxf8f6f4"],
+        !foreach(scale_vec_size, ["", ".scale_1x", ".scale_2x", ".scale_4x"],
+          !foreach(stype, ["ue8m0", "ue4m3"],
+            !if(NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret,
+                "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k
+                # " && \"" # op[0].ptx_elt_type # "\" == eltypeA"
+                # " && \"" # op[1].ptx_elt_type # "\" == eltypeB"
+                # " && \"" # op[2].ptx_elt_type # "\" == eltypeC"
+                # " && \"" # kind # "\" == stringifyEnum(kind)"
+                # " && \"" # stype # "\" == stringifyEnum(blockScaleFormat)"
+                # " && \"" # scale_vec_size # "\" == getScaleVecSizeStr(scaleVecSize))\n"
+                # "  return " #
+                MMA_BLOCK_SCALE_NAME<kind, stype, scale_vec_size, op[0], op[1], op[2], op[3]>.id # ";",
+                "") // if supported
+          ) // stype
+        ) // scale_vec_size
+      ) // kind
+    ); // all_mma_block_scale_ops
+  // Flatten one nesting level per fold, then join with newlines. The [""]
+  // seeds and the skipped combinations only add empty entries, which become
+  // blank lines in `id`.
+  list<list<list<string>>> f1 = !foldl([[[""]]], cond0, acc, el,
+                                       !listconcat(acc, el));
+  list<list<string>> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el));
+  list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
+  string id = !foldl("", f3, acc, el, acc # "\n" # el);
+}
+
+/// Helper to create the mapping between the configuration and the mma.sp.block_scale
+/// intrinsic enum value.
+///
+/// Sparse counterpart of MMA_BLOCK_SCALE_INTR. It iterates
+/// NVVM_MMA_OPS.all_mma_sp_block_scale_ops, filters with
+/// NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED and returns
+/// MMA_SP_BLOCK_SCALE_NAME ids (ordered-metadata sparse intrinsics).
+/// The generated C++ refers to m, n, k, eltypeA, eltypeB, eltypeC, kind,
+/// blockScaleFormat, scaleVecSize and getScaleVecSizeStr, which must be in
+/// scope where `id` is spliced.
+/// NOTE(review): as in the dense table, the entries for scale_vec_size ""
+/// cannot match, because getScaleVecSizeStr never yields "" for a valid enum.
+class MMA_SP_BLOCK_SCALE_INTR {
+  // Nested lists indexed [op][kind][scale_vec_size][stype]. Unsupported
+  // combinations produce an empty string.
+  list<list<list<list<string>>>> cond0 =
+    !foreach(op, NVVM_MMA_OPS.all_mma_sp_block_scale_ops,
+      !foreach(kind, ["mxf4", "mxf4nvf4", "mxf8f6f4"],
+        !foreach(scale_vec_size, ["", ".scale_1x", ".scale_2x", ".scale_4x"],
+          !foreach(stype, ["ue8m0", "ue4m3"],
+            !if(NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret,
+                "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k
+                # " && \"" # op[0].ptx_elt_type # "\" == eltypeA"
+                # " && \"" # op[1].ptx_elt_type # "\" == eltypeB"
+                # " && \"" # op[2].ptx_elt_type # "\" == eltypeC"
+                # " && \"" # kind # "\" == stringifyEnum(kind)"
+                # " && \"" # stype # "\" == stringifyEnum(blockScaleFormat)"
+                # " && \"" # scale_vec_size # "\" == getScaleVecSizeStr(scaleVecSize))\n"
+                # "  return " #
+                MMA_SP_BLOCK_SCALE_NAME<kind, stype, scale_vec_size, op[0], op[1], op[2], op[3]>.id # ";",
+                "") // if supported
+          ) // stype
+        ) // scale_vec_size
+      ) // kind
+    ); // all_mma_sp_block_scale_ops
+  // Flatten one nesting level per fold, then join with newlines. The [""]
+  // seeds and the skipped combinations only add empty entries, which become
+  // blank lines in `id`.
+  list<list<list<string>>> f1 = !foldl([[[""]]], cond0, acc, el,
+                                       !listconcat(acc, el));
+  list<list<string>> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el));
+  list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
+  string id = !foldl("", f3, acc, el, acc # "\n" # el);
+}
+
+// Common base class for MMA block scale operations (dense and sparse)
+//
+// Provides the shared LLVM-struct result, the attribute and operand dags that
+// the concrete ops combine with !con, and the C++ helpers used when
+// translating to LLVM IR. The base does not set `arguments`; each op builds
+// it from these dags.
+class NVVM_MmaBlockScaleBase<string mnemonic, list<Trait> traits = []> :
+    NVVM_Op<mnemonic, !listconcat([AttrSizedOperandSegments], traits)> {
+
+  let results = (outs LLVM_AnyStruct:$res);
+
+  // Common attributes shared by both dense and sparse variants
+  dag commonArguments = (ins
+           NVVM_MMAShapeAttr:$shape,
+           OptionalAttr<MMATypesAttr>:$multiplicandAPtxType,
+           OptionalAttr<MMATypesAttr>:$multiplicandBPtxType,
+           ScaleVecSizeAttr:$scaleVecSize,
+           BlockScaleFormatAttr:$blockScaleFormat,
+           MMABlockScaleKindAttr:$kind);
+
+  // Common variadic operands for A, B, C matrices
+  dag commonVariadicOperands = (ins
+           Variadic<LLVM_Type>:$operandA,
+           Variadic<LLVM_Type>:$operandB,
+           Variadic<LLVM_Type>:$operandC);
+
+  // Common scale operands for both A and B: for each matrix, an i32 scale
+  // value followed by i16 byte-id and thread-id selectors.
+  dag commonScaleOperands = (ins
+             I32:$scaleAData,
+             I16:$byteIdA,
+             I16:$threadIdA,
+             I32:$scaleBData,
+             I16:$byteIdB,
+             I16:$threadIdB);
+
+  // getIntrinsicID is assembled from three pieces:
+  //   1. a prologue that stringifies the element-type enums and defines the
+  //      getScaleVecSizeStr lambda,
+  //   2. the generated if-chain from MMA_BLOCK_SCALE_INTR (dense intrinsics),
+  //   3. a `return 0;` fallback when no combination matches (presumably
+  //      llvm::Intrinsic::not_intrinsic; confirm).
+  // NVVM_MmaSpBlockScaleOp overrides this to splice in the sparse table.
+  let extraClassDeclaration = !strconcat([{
+      static llvm::Intrinsic::ID getIntrinsicID(
+            int64_t m, int64_t n, uint64_t k,
+            mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
+            mlir::NVVM::MMATypes eltypeCEnum,
+            mlir::NVVM::ScaleVecSize scaleVecSize,
+            mlir::NVVM::BlockScaleFormat blockScaleFormat,
+            mlir::NVVM::MMABlockScaleKind kind) {
+        llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
+        llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
+        llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
+
+        auto getScaleVecSizeStr = [](ScaleVecSize svs) -> std::string {
+          switch (svs) {
+            case ScaleVecSize::X1: return ".scale_1x";
+            case ScaleVecSize::X2: return ".scale_2x";
+            case ScaleVecSize::X4: return ".scale_4x";
+          }
+          return "";
+        };
+        }],
+        MMA_BLOCK_SCALE_INTR<>.id, [{
+        return 0;
+      }
+
+      // Common declarations - implementations in NVVMDialect.cpp
+      MMATypes accumPtxType();
+      MMATypes resultPtxType();
+
+      static mlir::NVVM::IDArgPair
+      getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                            llvm::IRBuilderBase& builder);
+    }]);
+
+  // Parser/printer and verifier are hand-written in C++ (presumably in
+  // NVVMDialect.cpp alongside the declarations above; confirm).
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
+// Dense warp-level MMA with block scaling (nvvm.mma.block_scale). Intrinsic
+// selection uses the dense table spliced in by NVVM_MmaBlockScaleBase.
+def NVVM_MmaBlockScaleOp : NVVM_MmaBlockScaleBase<"mma.block_scale"> {
+
+  let summary = "cooperative matrix-multiply and accumulate with block scaling";
+
+  let description = [{
+    The `nvvm.mma.block_scale` operation collectively performs the operation
+    `D = matmul(A * SF_A, B * SF_B) + C` using all threads in a warp.
+
+    A, B, C and D are dense matrices and SF_A and SF_B are scaling factors.
+    Dimensions of SF_A and SF_B are based on scale vector sizes (x1, x2, x4),
+    and the data type must be either ue8m0 or ue4m3.
+
+    All the threads in the warp must execute the same `mma.block_scale` operation.
+
+    This operation follows the same design pattern as `nvvm.mma.sync`, with additional
+    scaling operands for both A and B matrices.
+
+    Example:
+    ```mlir
+    %d = nvvm.mma.block_scale A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+                              scaleA[%scaleAData, %byteIdA, %threadIdA]
+                              scaleB[%scaleBData, %byteIdB, %threadIdB]
+                              {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+                               multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+                               multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+                               scaleVecSize = #nvvm.scale_vec_size<x2>,
+                               blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+                               kind = #nvvm.block_scale_kind<mxf4nvf4>}
+        : (vector<4xf16>, vector<2xf16>, vector<2xf32>) -> !llvm.struct<(f32, f32)>
+    ```
+  }];
+
+  // Combine common attributes and operands. Operand order is the variadic
+  // A, B, C groups followed by the six scale operands (A's, then B's).
+  let arguments = !con(commonArguments, commonVariadicOperands, commonScaleOperands);
+
+  // Convenience builder taking a plain shape array and an optional pair of
+  // multiplicand PTX types (presumably A then B; confirm in NVVMDialect.cpp).
+  let builders = [
+      OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
+        "ValueRange":$operandB, "ValueRange":$operandC,
+        "Value":$scaleAData, "Value":$byteIdA, "Value":$threadIdA,
+        "Value":$scaleBData, "Value":$byteIdB, "Value":$threadIdB,
+        "ArrayRef<int64_t>":$shape,
+        "std::optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes,
+        "ScaleVecSize":$scaleVecSize,
+        "BlockScaleFormat":$blockScaleFormat,
+        "MMABlockScaleKind":$kind)>
+    ];
+
+  // LLVM IR translation: getIntrinsicIDAndArgs picks the intrinsic and builds
+  // its argument list; the call result becomes $res.
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::MmaBlockScaleOp::getIntrinsicIDAndArgs(
+                      *op, moduleTranslation, builder);
+    $res = createIntrinsicCall(builder, id, args);
+  }];
+}
+
+def NVVM_MmaSpBlockScaleOp : NVVM_MmaBlockScaleBase<"mma.sp.block_scale"> {
+
+  let summary = "cooperative sparse matrix-multiply and accumulate with block scaling";
+
+  let description = [{
+    The `nvvm.mma.sp.block_scale` operation collectively performs the operation
+    `D = matmul(A_sparse * SF_A, B * SF_B) + C` using all threads in a warp.
+
+    A is a sparse matrix, and B, C and D are dense matrices.
+    SF_A and SF_B are scaling factors.
+    Dimensions of SF_A and SF_B are based on scale vector sizes (x1, x2, x4),
+    and the data type must be either ue8m0 or ue4m3.
+
+    This operation is similar to `nvvm.mma.block_scale` but with structured sparsity
+    in the A operand. The sparsity follows the 2:4 structured sparse pattern
+    where 2 out of every 4 elements are non-zero.
+
+    All the threads in the warp must execute the same `mma.sp.block_scale` operation.
+
+    The `sparseMetadata` operand provides the sparsity indices that indicate
+    which elements in the A operand are non-zero. The `sparsitySelector`
+    controls how the indices are distributed among threads in the warp and
+    should typically be 0 or 1.
+
+    This operation follows the same design pattern as `nvvm.mma.sp.sync`, with additional
+    scaling operands for both A and B matrices. Note that sparse block scale operations
+    always use ordered metadata (sm_90+).
+
+    Example:
+    ```mlir
+    %d = nvvm.mma.sp.block_scale A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+                                 sparseMetadata[%meta] selector[%sel]
+                                 scaleA[%scaleAData, %byteIdA, %threadIdA]
+                                 scaleB[%scaleBData, %byteIdB, %threadIdB]
+                                 {shape = #nvvm.shape<m = 16, n = 8, k = 128>,
+                                  multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+                                  multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+                                  scaleVecSize = #nvvm.scale_vec_size<x2>,
+                                  blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+                                  kind = #nvvm.block_scale_kind<mxf4>}
+        : (vector<2xf16>, vector<2xf16>, vector<2xf32>) -> !llvm.struct<(f32, f32)>
+    ```
+  }];
+
+  // Sparse-specific attributes and operands
+  dag sparseSpecificArguments = (ins
+           UnitAttr:$orderedMetadata);
+
+  dag sparseSpecificOperands = (ins
+             I32:$sparseMetadata,
+             I32:$sparsitySelector);
+
+  // Combine common and sparse-specific attributes and operands
+  let arguments = !con(commonArguments, sparseSpecificArguments, 
+                       commonVariadicOperands, sparseSpecificOperands,
+                       commonScaleOperands);
+
+  // Override extraClassDeclaration to use sparse intrinsics
+  let extraClassDeclaration = !strconcat([{
+      static llvm::Intrinsic::ID getIntrinsicID(
+            int64_t m, int64_t n, uint64_t k,
+            mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
+            mlir::NVVM::MMATypes eltypeCEnum,
+            mlir::NVVM::ScaleVecSize scaleVecSize,
+            mlir::NVVM::BlockScaleFormat blockScaleFormat,
+            mlir::NVVM::MMABlockScaleKind kind) {
+        llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
+        llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
+        llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
+
+        auto getScaleVecSizeStr = [](ScaleVecSize svs) -> std::string {
+          switch (svs) {
+            case ScaleVecSize::X1: return ".scale_1x";
+            case ScaleVecSize::X2: return ".scale_2x";
+            case ScaleVecSize::X4: return ".scale_4x";
+          }
+          return "";
+        };
+        }],
+        MMA_SP_BLOCK_SCALE_INTR<>.id, [{
+        return 0;
+      }
+
+      MMATypes accumPtxType();
+      MMATypes resultPtxType();
+
+      static mlir::NVVM::IDArgPair
+      getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                            llvm::IRBuilderBase& builder);
+    }]);
+
+  let builders = [
+      OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
+        "ValueRange":$operandB, "ValueRange":$operandC,
+        "Value":$sparseMetadata, "Value":$sparsitySelector,
+        "Value":$scaleAData, "Value":$byteIdA, "Value":$threadIdA,
+        "Value":$scaleBData, "Value":$byteIdB, "Value":$threadIdB,
+        "ArrayRef<int64_t>":$shape,
+        "std::optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes,
+        "ScaleVecSize":$scaleVecSize,
+        "BlockScaleFormat":$blockScaleFormat,
+        "MMABlockScaleKind":$kind)>
+    ];
+
+  string llvmBuilder = [{
+    auto [id, a...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/170566


More information about the Mlir-commits mailing list