[Mlir-commits] [mlir] [MLIR][NVVM] Support for dense and sparse MMA with block scaling (PR #170566)
Kirill Vedernikov
llvmlistbot at llvm.org
Fri Dec 12 03:03:46 PST 2025
https://github.com/kvederni updated https://github.com/llvm/llvm-project/pull/170566
>From 277da943277f20f3cc2cfcfbf1a5af07645722f2 Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Wed, 3 Dec 2025 22:55:24 +0100
Subject: [PATCH 01/11] [MLIR] Support for dense and sparse MMA with block
scaling
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 451 ++++++++++++-
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 634 ++++++++++++++++-
.../Dialect/LLVMIR/nvvm-mma-blockscale.mlir | 525 +++++++++++++++
.../LLVMIR/nvvm-mma-sparse-blockscale.mlir | 637 ++++++++++++++++++
4 files changed, 2244 insertions(+), 3 deletions(-)
create mode 100644 mlir/test/Dialect/LLVMIR/nvvm-mma-blockscale.mlir
create mode 100644 mlir/test/Dialect/LLVMIR/nvvm-mma-sparse-blockscale.mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index a96d65d3fcacd..1faa435fca6f9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2499,6 +2499,30 @@ class NVVM_MMA_OPS {
bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
subint_mma_sp_ops, int_mma_sp_ops);
+ // Block scale MMA operations (dense)
+ list<list<WMMA_REGS>> mxf4_mma_ops = MMA_OPS<
+ [GEOM<16,8,64>],
+ ["e2m1"], ["e2m1"], ["f32"], []>.ret;
+ list<list<WMMA_REGS>> mxf8f6f4_mma_ops = MMA_OPS<
+ [GEOM<16,8,32>],
+ ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"],
+ ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"],
+ ["f32"], []>.ret;
+ list<list<WMMA_REGS>> all_mma_block_scale_ops = !listconcat(
+ mxf4_mma_ops, mxf8f6f4_mma_ops);
+
+ // Block scale sparse MMA operations
+ list<list<WMMA_REGS>> mxf4xx_mma_sp_ops = MMA_OPS<
+ [GEOM<16,8,128>],
+ ["e2m1"], ["e2m1"], ["f32"], []>.ret;
+ list<list<WMMA_REGS>> mxf8f6f4_mma_sp_ops = MMA_OPS<
+ [GEOM<16,8,64>],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f32"], []>.ret;
+ list<list<WMMA_REGS>> all_mma_sp_block_scale_ops = !listconcat(
+ mxf4xx_mma_sp_ops, mxf8f6f4_mma_sp_ops);
+
}
def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -3332,7 +3356,7 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {
The optional `orderedMetadata` attribute specifies the metadata ordering:
- Absence (default): Uses standard sparse metadata ordering
- Presence: Uses ordered metadata (PTX ISA 8.5+, sm_90+)
-
+
The optional `kind` attribute specifies mixed-precision modes for FP8 operations:
- `f8f6f4`: Enables e3m2, e2m3, e2m1 FP8 types and f16 accumulator (PTX ISA 8.7+, sm_90+)
- Only valid with ordered metadata and m16n8k64 shape
@@ -3347,7 +3371,7 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {
sparseMetadata[%meta] selector[%sel]
{shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}}
: (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
-
+
// With ordered metadata:
%d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
sparseMetadata[%meta] selector[%sel]
@@ -3416,6 +3440,429 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {
let hasVerifier = 1;
}
+def ScaleVecSize1X : I32EnumAttrCase<"X1", 0, "x1">;
+def ScaleVecSize2X : I32EnumAttrCase<"X2", 1, "x2">;
+def ScaleVecSize4X : I32EnumAttrCase<"X4", 2, "x4">;
+
+def ScaleVecSize : I32EnumAttr<
+ "ScaleVecSize",
+ "MMA Scale Vector Sizes",
+ [ScaleVecSize1X, ScaleVecSize2X, ScaleVecSize4X]> {
+ let cppNamespace = "::mlir::NVVM";
+ let genSpecializedAttr = 0;
+}
+
+def ScaleVecSizeAttr : EnumAttr<NVVM_Dialect, ScaleVecSize, "scale_vec_size"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def UE8M0 : I32EnumAttrCase<"UE8M0", 0, "ue8m0">;
+def UE4M3 : I32EnumAttrCase<"UE4M3", 1, "ue4m3">;
+
+def BlockScaleFormat : I32EnumAttr<
+ "BlockScaleFormat",
+ "MMA Block Scale Format",
+ [UE8M0, UE4M3]
+> {
+ let cppNamespace = "::mlir::NVVM";
+ let genSpecializedAttr = 0;
+}
+
+def BlockScaleFormatAttr : EnumAttr<NVVM_Dialect, BlockScaleFormat, "block_scale_format"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def MMABlockScaleKindMXF8F6F4 : I32EnumAttrCase<"MXF8F6F4", 0, "mxf8f6f4">;
+def MMABlockScaleKindMXF4 : I32EnumAttrCase<"MXF4", 1, "mxf4">;
+def MMABlockScaleKindMXF4NVF4 : I32EnumAttrCase<"MXF4NVF4", 2, "mxf4nvf4">;
+
+def MMABlockScaleKind : I32EnumAttr<
+ "MMABlockScaleKind",
+ "Block Scale Kind",
+ [MMABlockScaleKindMXF8F6F4, MMABlockScaleKindMXF4, MMABlockScaleKindMXF4NVF4]> {
+ let cppNamespace = "::mlir::NVVM";
+ let genSpecializedAttr = 0;
+}
+
+def MMABlockScaleKindAttr : EnumAttr<NVVM_Dialect, MMABlockScaleKind, "block_scale_kind"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+/// Generate enum value of the mma.block_scale intrinsic.
+class MMA_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
+ WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+ string id = "llvm::Intrinsic::nvvm_mma_block_scale"
+ # "_" # A.geom
+ # "_row_col"
+ # "_" # Kind
+ # !subst(".", "_", ScaleVecSize)
+ # signature
+ # "_" # SType;
+}
+
+/// Generate enum value of the mma.sp.block_scale intrinsic.
+class MMA_SP_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
+ WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+ string id = "llvm::Intrinsic::nvvm_mma_sp_ordered_metadata_block_scale"
+ # "_" # A.geom
+ # "_row_col"
+ # "_" # Kind
+ # !subst(".", "_", ScaleVecSize)
+ # signature
+ # "_" # SType;
+}
+
+// Returns true if this combination is supported for MMA.BLOCK_SCALE ops.
+// This references the NVVM_MMA_BLOCK_SCALE_SUPPORTED class from IntrinsicsNVVM.td
+class NVVM_MMA_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
+ string stype, string scale_vec_size> {
+ string geom = frags[0].geom;
+ bit ret = !cond(
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf4"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_2x")),
+ !eq(stype, "ue8m0")) : true,
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(scale_vec_size, ".scale_2x"),
+ !eq(stype, "ue8m0")) : true,
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(scale_vec_size, ".scale_4x"),
+ !eq(stype, "ue4m3")) : true,
+ !and(!eq(geom, "m16n8k32"),
+ !eq(kind, "mxf8f6f4"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_1x")),
+ !eq(stype, "ue8m0")) : true,
+ true: false
+ );
+}
+
+// Returns true if this combination is supported for MMA.SP.BLOCK_SCALE ops.
+// This references the NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED class from IntrinsicsNVVM.td
+class NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
+ string stype, string scale_vec_size> {
+ string geom = frags[0].geom;
+ bit ret = !cond(
+ !and(!eq(geom, "m16n8k128"),
+ !eq(kind, "mxf4"),
+ !eq(stype, "ue8m0"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_2x"))): true,
+ !and(!eq(geom, "m16n8k128"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(stype, "ue8m0"),
+ !eq(scale_vec_size, ".scale_2x")): true,
+ !and(!eq(geom, "m16n8k128"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(stype, "ue4m3"),
+ !eq(scale_vec_size, ".scale_4x")): true,
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf8f6f4"),
+ !eq(stype, "ue8m0"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_1x"))): true,
+ true: false
+ );
+}
+
+/// Helper to create the mapping between the configuration and the mma.block_scale
+/// intrinsic enum value.
+class MMA_BLOCK_SCALE_INTR {
+ list<list<list<list<string>>>> cond0 =
+ !foreach(op, NVVM_MMA_OPS.all_mma_block_scale_ops,
+ !foreach(kind, ["mxf4", "mxf4nvf4", "mxf8f6f4"],
+ !foreach(scale_vec_size, ["", ".scale_1x", ".scale_2x", ".scale_4x"],
+ !foreach(stype, ["ue8m0", "ue4m3"],
+ !if(NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret,
+ "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k
+ # " && \"" # op[0].ptx_elt_type # "\" == eltypeA"
+ # " && \"" # op[1].ptx_elt_type # "\" == eltypeB"
+ # " && \"" # op[2].ptx_elt_type # "\" == eltypeC"
+ # " && \"" # kind # "\" == stringifyEnum(kind)"
+ # " && \"" # stype # "\" == stringifyEnum(blockScaleFormat)"
+ # " && \"" # scale_vec_size # "\" == getScaleVecSizeStr(scaleVecSize))\n"
+ # " return " #
+ MMA_BLOCK_SCALE_NAME<kind, stype, scale_vec_size, op[0], op[1], op[2], op[3]>.id # ";",
+ "") // if supported
+ ) // stype
+ ) // scale_vec_size
+ ) // kind
+ ); // all_mma_block_scale_ops
+ list<list<list<string>>> f1 = !foldl([[[""]]], cond0, acc, el,
+ !listconcat(acc, el));
+ list<list<string>> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el));
+ list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
+ string id = !foldl("", f3, acc, el, acc # "\n" # el);
+}
+
+/// Helper to create the mapping between the configuration and the mma.sp.block_scale
+/// intrinsic enum value.
+class MMA_SP_BLOCK_SCALE_INTR {
+ list<list<list<list<string>>>> cond0 =
+ !foreach(op, NVVM_MMA_OPS.all_mma_sp_block_scale_ops,
+ !foreach(kind, ["mxf4", "mxf4nvf4", "mxf8f6f4"],
+ !foreach(scale_vec_size, ["", ".scale_1x", ".scale_2x", ".scale_4x"],
+ !foreach(stype, ["ue8m0", "ue4m3"],
+ !if(NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret,
+ "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k
+ # " && \"" # op[0].ptx_elt_type # "\" == eltypeA"
+ # " && \"" # op[1].ptx_elt_type # "\" == eltypeB"
+ # " && \"" # op[2].ptx_elt_type # "\" == eltypeC"
+ # " && \"" # kind # "\" == stringifyEnum(kind)"
+ # " && \"" # stype # "\" == stringifyEnum(blockScaleFormat)"
+ # " && \"" # scale_vec_size # "\" == getScaleVecSizeStr(scaleVecSize))\n"
+ # " return " #
+ MMA_SP_BLOCK_SCALE_NAME<kind, stype, scale_vec_size, op[0], op[1], op[2], op[3]>.id # ";",
+ "") // if supported
+ ) // stype
+ ) // scale_vec_size
+ ) // kind
+ ); // all_mma_sp_block_scale_ops
+ list<list<list<string>>> f1 = !foldl([[[""]]], cond0, acc, el,
+ !listconcat(acc, el));
+ list<list<string>> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el));
+ list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
+ string id = !foldl("", f3, acc, el, acc # "\n" # el);
+}
+
+// Common base class for MMA block scale operations (dense and sparse)
+class NVVM_MmaBlockScaleBase<string mnemonic, list<Trait> traits = []> :
+ NVVM_Op<mnemonic, !listconcat([AttrSizedOperandSegments], traits)> {
+
+ let results = (outs LLVM_AnyStruct:$res);
+
+ // Common attributes shared by both dense and sparse variants
+ dag commonArguments = (ins
+ NVVM_MMAShapeAttr:$shape,
+ OptionalAttr<MMATypesAttr>:$multiplicandAPtxType,
+ OptionalAttr<MMATypesAttr>:$multiplicandBPtxType,
+ ScaleVecSizeAttr:$scaleVecSize,
+ BlockScaleFormatAttr:$blockScaleFormat,
+ MMABlockScaleKindAttr:$kind);
+
+ // Common variadic operands for A, B, C matrices
+ dag commonVariadicOperands = (ins
+ Variadic<LLVM_Type>:$operandA,
+ Variadic<LLVM_Type>:$operandB,
+ Variadic<LLVM_Type>:$operandC);
+
+ // Common scale operands for both A and B
+ dag commonScaleOperands = (ins
+ I32:$scaleAData,
+ I16:$byteIdA,
+ I16:$threadIdA,
+ I32:$scaleBData,
+ I16:$byteIdB,
+ I16:$threadIdB);
+
+ let extraClassDeclaration = !strconcat([{
+ static llvm::Intrinsic::ID getIntrinsicID(
+ int64_t m, int64_t n, uint64_t k,
+ mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
+ mlir::NVVM::MMATypes eltypeCEnum,
+ mlir::NVVM::ScaleVecSize scaleVecSize,
+ mlir::NVVM::BlockScaleFormat blockScaleFormat,
+ mlir::NVVM::MMABlockScaleKind kind) {
+ llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
+ llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
+ llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
+
+ auto getScaleVecSizeStr = [](ScaleVecSize svs) -> std::string {
+ switch (svs) {
+ case ScaleVecSize::X1: return ".scale_1x";
+ case ScaleVecSize::X2: return ".scale_2x";
+ case ScaleVecSize::X4: return ".scale_4x";
+ }
+ return "";
+ };
+ }],
+ MMA_BLOCK_SCALE_INTR<>.id, [{
+ return 0;
+ }
+
+ // Common declarations - implementations in NVVMDialect.cpp
+ MMATypes accumPtxType();
+ MMATypes resultPtxType();
+
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase& builder);
+ }]);
+
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+}
+
+def NVVM_MmaBlockScaleOp : NVVM_MmaBlockScaleBase<"mma.block_scale"> {
+
+ let summary = "cooperative matrix-multiply and accumulate with block scaling";
+
+ let description = [{
+ The `nvvm.mma.block_scale` operation collectively performs the operation
+ `D = matmul(A * SF_A, B * SF_B) + C` using all threads in a warp.
+
+ A, B, C and D are dense matrices and SF_A and SF_B are scaling factors.
+ Dimensions of SF_A and SF_B are based on scale vector sizes (x1, x2, x4),
+ and the data type must be either ue8m0 or ue4m3.
+
+ All the threads in the warp must execute the same `mma.block_scale` operation.
+
+ This operation follows the same design pattern as `nvvm.mma.sync`, with additional
+ scaling operands for both A and B matrices.
+
+ Example:
+ ```mlir
+ %d = nvvm.mma.block_scale A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x2>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf4nvf4>}
+ : (vector<4xf16>, vector<2xf16>, vector<2xf32>) -> !llvm.struct<(f32, f32)>
+ ```
+ }];
+
+ // Combine common attributes and operands
+ let arguments = !con(commonArguments, commonVariadicOperands, commonScaleOperands);
+
+ let builders = [
+ OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
+ "ValueRange":$operandB, "ValueRange":$operandC,
+ "Value":$scaleAData, "Value":$byteIdA, "Value":$threadIdA,
+ "Value":$scaleBData, "Value":$byteIdB, "Value":$threadIdB,
+ "ArrayRef<int64_t>":$shape,
+ "std::optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes,
+ "ScaleVecSize":$scaleVecSize,
+ "BlockScaleFormat":$blockScaleFormat,
+ "MMABlockScaleKind":$kind)>
+ ];
+
+ string llvmBuilder = [{
+ auto [id, args] = NVVM::MmaBlockScaleOp::getIntrinsicIDAndArgs(
+ *op, moduleTranslation, builder);
+ $res = createIntrinsicCall(builder, id, args);
+ }];
+}
+
+def NVVM_MmaSpBlockScaleOp : NVVM_MmaBlockScaleBase<"mma.sp.block_scale"> {
+
+ let summary = "cooperative sparse matrix-multiply and accumulate with block scaling";
+
+ let description = [{
+ The `nvvm.mma.sp.block_scale` operation collectively performs the operation
+ `D = matmul(A_sparse * SF_A, B * SF_B) + C` using all threads in a warp.
+
+ A is a sparse matrix, and B, C and D are dense matrices.
+ SF_A and SF_B are scaling factors.
+ Dimensions of SF_A and SF_B are based on scale vector sizes (x1, x2, x4),
+ and the data type must be either ue8m0 or ue4m3.
+
+ This operation is similar to `nvvm.mma.block_scale` but with structured sparsity
+ in the A operand. The sparsity follows the 2:4 structured sparse pattern
+ where 2 out of every 4 elements are non-zero.
+
+ All the threads in the warp must execute the same `mma.sp.block_scale` operation.
+
+ The `sparseMetadata` operand provides the sparsity indices that indicate
+ which elements in the A operand are non-zero. The `sparsitySelector`
+ controls how the indices are distributed among threads in the warp and
+ should typically be 0 or 1.
+
+ This operation follows the same design pattern as `nvvm.mma.sp.sync`, with additional
+ scaling operands for both A and B matrices. Note that sparse block scale operations
+ always use ordered metadata (sm_90+).
+
+ Example:
+ ```mlir
+ %d = nvvm.mma.sp.block_scale A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 128>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x2>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf4>}
+ : (vector<2xf16>, vector<2xf16>, vector<2xf32>) -> !llvm.struct<(f32, f32)>
+ ```
+ }];
+
+ // Sparse-specific attributes and operands
+ dag sparseSpecificArguments = (ins
+ UnitAttr:$orderedMetadata);
+
+ dag sparseSpecificOperands = (ins
+ I32:$sparseMetadata,
+ I32:$sparsitySelector);
+
+ // Combine common and sparse-specific attributes and operands
+ let arguments = !con(commonArguments, sparseSpecificArguments,
+ commonVariadicOperands, sparseSpecificOperands,
+ commonScaleOperands);
+
+ // Override extraClassDeclaration to use sparse intrinsics
+ let extraClassDeclaration = !strconcat([{
+ static llvm::Intrinsic::ID getIntrinsicID(
+ int64_t m, int64_t n, uint64_t k,
+ mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
+ mlir::NVVM::MMATypes eltypeCEnum,
+ mlir::NVVM::ScaleVecSize scaleVecSize,
+ mlir::NVVM::BlockScaleFormat blockScaleFormat,
+ mlir::NVVM::MMABlockScaleKind kind) {
+ llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
+ llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
+ llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
+
+ auto getScaleVecSizeStr = [](ScaleVecSize svs) -> std::string {
+ switch (svs) {
+ case ScaleVecSize::X1: return ".scale_1x";
+ case ScaleVecSize::X2: return ".scale_2x";
+ case ScaleVecSize::X4: return ".scale_4x";
+ }
+ return "";
+ };
+ }],
+ MMA_SP_BLOCK_SCALE_INTR<>.id, [{
+ return 0;
+ }
+
+ MMATypes accumPtxType();
+ MMATypes resultPtxType();
+
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase& builder);
+ }]);
+
+ let builders = [
+ OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
+ "ValueRange":$operandB, "ValueRange":$operandC,
+ "Value":$sparseMetadata, "Value":$sparsitySelector,
+ "Value":$scaleAData, "Value":$byteIdA, "Value":$threadIdA,
+ "Value":$scaleBData, "Value":$byteIdB, "Value":$threadIdB,
+ "ArrayRef<int64_t>":$shape,
+ "std::optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes,
+ "ScaleVecSize":$scaleVecSize,
+ "BlockScaleFormat":$blockScaleFormat,
+ "MMABlockScaleKind":$kind)>
+ ];
+
+ string llvmBuilder = [{
+ auto [id, args] = NVVM::MmaSpBlockScaleOp::getIntrinsicIDAndArgs(
+ *op, moduleTranslation, builder);
+ $res = createIntrinsicCall(builder, id, args);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// NVVM TMA Ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index ada4223ac12de..a06fe19b8ad1c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -694,7 +694,7 @@ void MmaOp::print(OpAsmPrinter &p) {
}
}
std::optional<MMATypes> inferredType =
- inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
+ MmaOp::inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
if (inferredType)
ignoreAttrNames.push_back(frag.ptxTypeAttr);
}
@@ -1559,6 +1559,638 @@ LogicalResult MmaSpOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// MMA Block Scale Operations - Shared Helpers
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Shared structure for MMA operand fragments (A, B, C)
+struct OperandFragment {
+ StringRef operandName;
+ StringRef ptxTypeAttr;
+ SmallVector<Value, 4> regs;
+ explicit OperandFragment(StringRef name, StringRef ptxTypeName)
+ : operandName(name), ptxTypeAttr(ptxTypeName) {}
+};
+
+// Helper to print operand list in the format: name[operands]
+void printOperandList(OpAsmPrinter &p, StringRef name,
+ ArrayRef<Value> operands) {
+ p << " " << name << "[";
+ p.printOperands(operands);
+ p << "]";
+}
+
+// Helper to parse operand list in the format: name[operands]
+LogicalResult parseMmaOperand(OpAsmParser &parser, StringRef operandName,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &regs) {
+ if (parser.parseKeyword(operandName).failed())
+ return failure();
+ if (parser.parseOperandList(regs,
+ OpAsmParser::Delimiter::OptionalSquare).failed())
+ return failure();
+ return success();
+}
+
+// Helper to process operand fragments and determine which attributes can be inferred
+template <typename Op>
+void processOperandFragments(Op &op, std::array<OperandFragment, 3> &frags,
+ SmallVectorImpl<Type> &regTypes,
+ SmallVectorImpl<StringRef> &ignoreAttrNames) {
+ for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
+ auto &frag = frags[fragIdx];
+ auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
+ for (auto operandIdx = varOperandSpec.first;
+ operandIdx < varOperandSpec.first + varOperandSpec.second;
+ operandIdx++) {
+ frag.regs.push_back(op.getOperand(operandIdx));
+ if (fragIdx == 0 && operandIdx == varOperandSpec.first) {
+ regTypes.push_back(op.getOperand(operandIdx).getType());
+ }
+ }
+ if (fragIdx < 2) {
+ regTypes.push_back(frag.regs[0].getType());
+ }
+ std::optional<MMATypes> inferredType =
+ MmaOp::inferOperandMMAType(regTypes.back(),
+ /*isAccumulator=*/fragIdx >= 2);
+ if (inferredType)
+ ignoreAttrNames.push_back(frag.ptxTypeAttr);
+ }
+}
+
+// Helper to parse type signature: (A_type, B_type, C_type)
+LogicalResult parseMmaTypeSignature(OpAsmParser &parser,
+ SmallVectorImpl<Type> &operandTypes) {
+ if (parser.parseColon().failed() || parser.parseLParen().failed())
+ return failure();
+
+ for (int i = 0; i < 3; i++) {
+ if (i > 0 && parser.parseComma().failed())
+ return failure();
+ Type ty;
+ if (parser.parseType(ty).failed())
+ return failure();
+ operandTypes.push_back(ty);
+ }
+
+ return parser.parseRParen();
+}
+
+// Helper to infer and set multiplicand PTX type attributes
+void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs,
+ const SmallVectorImpl<Type> &operandTypes) {
+ if (!attrs.get("multiplicandAPtxType")) {
+ if (auto inferredType = MmaOp::inferOperandMMAType(operandTypes[0],
+ false)) {
+ attrs.set("multiplicandAPtxType", MMATypesAttr::get(ctx, *inferredType));
+ }
+ }
+ if (!attrs.get("multiplicandBPtxType")) {
+ if (auto inferredType = MmaOp::inferOperandMMAType(operandTypes[1],
+ false)) {
+ attrs.set("multiplicandBPtxType", MMATypesAttr::get(ctx, *inferredType));
+ }
+ }
+}
+
+// Helper to add common block scale attributes
+void addBlockScaleAttributes(OpBuilder &builder, OperationState &result,
+ ArrayRef<int64_t> shape,
+ ScaleVecSize scaleVecSize,
+ BlockScaleFormat blockScaleFormat,
+ MMABlockScaleKind kind) {
+ MLIRContext *ctx = builder.getContext();
+ result.addAttribute("shape",
+ builder.getAttr<MMAShapeAttr>(shape[0], shape[1],
+ shape[2]));
+ result.addAttribute("scaleVecSize",
+ ScaleVecSizeAttr::get(ctx, scaleVecSize));
+ result.addAttribute("blockScaleFormat",
+ BlockScaleFormatAttr::get(ctx, blockScaleFormat));
+ result.addAttribute("kind", MMABlockScaleKindAttr::get(ctx, kind));
+}
+
+// Helper to infer and add multiplicand PTX types to builder
+void addInferredMultiplicandTypes(MLIRContext *ctx, OperationState &result,
+ ValueRange operandA, ValueRange operandB,
+ std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
+ if (multiplicandPtxTypes) {
+ result.addAttribute("multiplicandAPtxType",
+ MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
+ result.addAttribute("multiplicandBPtxType",
+ MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
+ } else {
+ if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
+ result.addAttribute("multiplicandAPtxType",
+ MMATypesAttr::get(ctx, *res));
+ if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
+ result.addAttribute("multiplicandBPtxType",
+ MMATypesAttr::get(ctx, *res));
+ }
+}
+
+// Template helper for common accumPtxType/resultPtxType implementation
+template <typename OpTy>
+MMATypes inferPtxTypeFromResult(OpTy op) {
+ return *MmaOp::inferOperandMMAType(
+ cast<LLVM::LLVMStructType>(op.getRes().getType()).getBody()[0],
+ /*isAccumulator=*/true);
+}
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// MmaBlockScaleOp
+//===----------------------------------------------------------------------===//
+
+void MmaBlockScaleOp::print(OpAsmPrinter &p) {
+ SmallVector<Type, 4> regTypes;
+ std::array<OperandFragment, 3> frags{
+ OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
+ OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
+ OperandFragment("C", "")};
+ SmallVector<StringRef, 4> ignoreAttrNames{
+ mlir::NVVM::MmaBlockScaleOp::getOperandSegmentSizeAttr()};
+
+ processOperandFragments(*this, frags, regTypes, ignoreAttrNames);
+
+ // Print A, B, C operands
+ for (const auto &frag : frags)
+ printOperandList(p, frag.operandName, frag.regs);
+
+ // Print scale operands
+ printOperandList(p, "scaleA",
+ {getScaleAData(), getByteIdA(), getThreadIdA()});
+ printOperandList(p, "scaleB",
+ {getScaleBData(), getByteIdB(), getThreadIdB()});
+
+ p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
+
+ // Print type signature
+ p << " : (";
+ llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
+ frags[1].regs[0].getType(),
+ frags[2].regs[0].getType()}, p);
+ p << ")";
+ p.printArrowTypeList(TypeRange{this->getRes().getType()});
+}
+
+ParseResult MmaBlockScaleOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ struct LocalOperandFragment {
+ std::optional<MMATypes> elemtype;
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
+ };
+
+ Builder &builder = parser.getBuilder();
+ std::array<LocalOperandFragment, 3> frags;
+ NamedAttrList namedAttributes;
+
+ // Parse A[...] B[...] C[...]
+ if (parseMmaOperand(parser, "A", frags[0].regs).failed() ||
+ parseMmaOperand(parser, "B", frags[1].regs).failed() ||
+ parseMmaOperand(parser, "C", frags[2].regs).failed())
+ return failure();
+
+ // Parse scale operands: scaleA[...] scaleB[...]
+ SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleAOperands, scaleBOperands;
+ if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() ||
+ parseMmaOperand(parser, "scaleB", scaleBOperands).failed())
+ return failure();
+
+ if (parser.parseOptionalAttrDict(namedAttributes).failed())
+ return failure();
+
+ // Parse type signature
+ SmallVector<Type, 3> operandTypes;
+ if (parseMmaTypeSignature(parser, operandTypes).failed())
+ return failure();
+
+ // Parse result type
+ SmallVector<Type, 1> resultTypes;
+ if (parser.parseArrowTypeList(resultTypes).failed())
+ return failure();
+
+ // Infer element types and resolve operands
+ for (const auto &[idx, frag] : llvm::enumerate(frags)) {
+ frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
+ /*isAccumulator=*/idx >= 2);
+ if (parser.resolveOperands(frag.regs, operandTypes[idx],
+ parser.getNameLoc(), result.operands).failed())
+ return failure();
+ }
+
+ // Resolve scale operands
+ SmallVector<Type, 3> scaleTypes = {
+ builder.getI32Type(), builder.getI16Type(), builder.getI16Type()
+ };
+ if (parser.resolveOperands(scaleAOperands, scaleTypes,
+ parser.getNameLoc(), result.operands).failed() ||
+ parser.resolveOperands(scaleBOperands, scaleTypes,
+ parser.getNameLoc(), result.operands).failed())
+ return failure();
+
+ // Add attributes
+ result.addAttributes(namedAttributes);
+ inferAndSetMultiplicandTypes(parser.getContext(),
+ result.attributes, operandTypes);
+
+ result.addTypes(resultTypes);
+ result.addAttribute(
+ MmaBlockScaleOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({static_cast<int32_t>(frags[0].regs.size()),
+ static_cast<int32_t>(frags[1].regs.size()),
+ static_cast<int32_t>(frags[2].regs.size()),
+ 1, // scaleAData
+ 1, // byteIdA
+ 1, // threadIdA
+ 1, // scaleBData
+ 1, // byteIdB
+ 1 // threadIdB
+ }));
+ return success();
+}
+
+void MmaBlockScaleOp::build(OpBuilder &builder, OperationState &result,
+ Type resultType, ValueRange operandA,
+ ValueRange operandB, ValueRange operandC,
+ Value scaleAData, Value byteIdA, Value threadIdA,
+ Value scaleBData, Value byteIdB, Value threadIdB,
+ ArrayRef<int64_t> shape,
+ std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
+ ScaleVecSize scaleVecSize,
+ BlockScaleFormat blockScaleFormat,
+ MMABlockScaleKind kind) {
+ assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
+
+ addBlockScaleAttributes(builder, result, shape, scaleVecSize,
+ blockScaleFormat, kind);
+
+ result.addOperands(operandA);
+ result.addOperands(operandB);
+ result.addOperands(operandC);
+ result.addOperands({scaleAData, byteIdA, threadIdA,
+ scaleBData, byteIdB, threadIdB});
+
+ addInferredMultiplicandTypes(builder.getContext(), result,
+ operandA, operandB, multiplicandPtxTypes);
+
+ result.addTypes(resultType);
+ result.addAttribute(
+ MmaBlockScaleOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
+ static_cast<int32_t>(operandB.size()),
+ static_cast<int32_t>(operandC.size()),
+ 1, // scaleAData
+ 1, // byteIdA
+ 1, // threadIdA
+ 1, // scaleBData
+ 1, // byteIdB
+ 1 // threadIdB
+ }));
+}
+
+MMATypes MmaBlockScaleOp::accumPtxType() {
+ return inferPtxTypeFromResult(*this);
+}
+
+MMATypes MmaBlockScaleOp::resultPtxType() {
+ return inferPtxTypeFromResult(*this);
+}
+
+NVVM::IDArgPair MmaBlockScaleOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::MmaBlockScaleOp>(op);
+
+ SmallVector<llvm::Value *> args;
+ // Add A, B, C operands
+ for (Value operand : curOp.getOperandA())
+ args.push_back(mt.lookupValue(operand));
+ for (Value operand : curOp.getOperandB())
+ args.push_back(mt.lookupValue(operand));
+ for (Value operand : curOp.getOperandC())
+ args.push_back(mt.lookupValue(operand));
+
+ // Add scale operands
+ args.push_back(mt.lookupValue(curOp.getScaleAData()));
+ args.push_back(mt.lookupValue(curOp.getByteIdA()));
+ args.push_back(mt.lookupValue(curOp.getThreadIdA()));
+ args.push_back(mt.lookupValue(curOp.getScaleBData()));
+ args.push_back(mt.lookupValue(curOp.getByteIdB()));
+ args.push_back(mt.lookupValue(curOp.getThreadIdB()));
+
+ unsigned intId = MmaBlockScaleOp::getIntrinsicID(
+ curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
+ *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
+ curOp.accumPtxType(),
+ curOp.getScaleVecSize(), curOp.getBlockScaleFormat(), curOp.getKind());
+
+ return {intId, args};
+}
+
+/// Verifies that shape, kind, scale-vector size, block-scale format and the
+/// multiplicand element types form a supported dense block-scaled MMA:
+///   - m16n8k64: A and B must be e2m1. Kind mxf4 requires x2 + ue8m0; kind
+///     mxf4nvf4 requires either x2 + ue8m0 or x4 + ue4m3.
+///   - m16n8k32: kind mxf8f6f4 with x1 + ue8m0; A and B must each be one of
+///     e2m1, e2m3, e3m2, e4m3, e5m2.
+/// Checks are not short-circuited: each failing check emits its own error and
+/// overwrites `result`.
+LogicalResult MmaBlockScaleOp::verify() {
+  LogicalResult result = success();
+  int m = getShape().getM();
+  int n = getShape().getN();
+  int k = getShape().getK();
+
+  // Element types accepted by kind mxf8f6f4. An absent attribute (nullopt)
+  // compares unequal to every enumerator and is rejected as well, so the
+  // unconditional dereference in getIntrinsicIDAndArgs() stays safe.
+  auto isMxf8f6f4Type = [](std::optional<NVVM::MMATypes> type) {
+    return type == NVVM::MMATypes::e2m1 || type == NVVM::MMATypes::e2m3 ||
+           type == NVVM::MMATypes::e3m2 || type == NVVM::MMATypes::e4m3 ||
+           type == NVVM::MMATypes::e5m2;
+  };
+
+  if (m == 16 && n == 8 && k == 64) {
+    if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
+        getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
+      result = emitOpError(
+          "unsupported MMATypes attribute for mma.m16n8k64.(mxf4nvf4|mxf4)");
+    if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
+      if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
+        result = emitOpError(
+            "unsupported ScaleVecSize attribute for mma.m16n8k64.mxf4");
+      if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
+        result = emitOpError(
+            "unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4");
+    } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
+      if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
+             getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
+            (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
+             getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
+        result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
+                             "attributes for mma.m16n8k64.mxf4nvf4");
+    } else
+      result = emitOpError("unsupported Kind attribute for mma.m16n8k64");
+  } else if (m == 16 && n == 8 && k == 32) {
+    // This shape is only defined for the FP4/FP6/FP8 element types above.
+    if (!isMxf8f6f4Type(getMultiplicandAPtxType()) ||
+        !isMxf8f6f4Type(getMultiplicandBPtxType()))
+      result = emitOpError(
+          "unsupported MMATypes attribute for mma.m16n8k32.mxf8f6f4");
+    if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
+          getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
+          getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
+      result = emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
+                           "attributes for mma.m16n8k32");
+  } else
+    result = emitOpError("unsupported Geom for mma with block scaling");
+  return result;
+}
+
+//===----------------------------------------------------------------------===//
+// MmaSpBlockScaleOp
+//===----------------------------------------------------------------------===//
+
+/// Prints the custom assembly form of the sparse block-scaled MMA op:
+/// the A/B/C register fragments, the sparsity operands, the two scale operand
+/// triples, the attribute dictionary, and finally the type signature
+/// `(typeA, typeB, typeC) -> result-type`.
+void MmaSpBlockScaleOp::print(OpAsmPrinter &p) {
+  std::array<OperandFragment, 3> fragments{
+      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
+      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
+      OperandFragment("C", "")};
+  // The segment-size attribute is implied by the bracketed operand lists and
+  // is never printed; processOperandFragments may extend this list.
+  SmallVector<StringRef, 4> elidedAttrNames{
+      mlir::NVVM::MmaSpBlockScaleOp::getOperandSegmentSizeAttr()};
+  SmallVector<Type, 4> fragmentRegTypes;
+
+  processOperandFragments(*this, fragments, fragmentRegTypes, elidedAttrNames);
+
+  // A[...] B[...] C[...]
+  for (const auto &fragment : fragments)
+    printOperandList(p, fragment.operandName, fragment.regs);
+
+  // sparseMetadata[...] selector[...]
+  printOperandList(p, "sparseMetadata", {getSparseMetadata()});
+  printOperandList(p, "selector", {getSparsitySelector()});
+
+  // scaleA[data, byteId, threadId] scaleB[data, byteId, threadId]
+  printOperandList(p, "scaleA",
+                   {getScaleAData(), getByteIdA(), getThreadIdA()});
+  printOperandList(p, "scaleB",
+                   {getScaleBData(), getByteIdB(), getThreadIdB()});
+
+  p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrNames);
+
+  // The signature lists the type of the first register of each fragment.
+  SmallVector<Type, 3> signatureTypes;
+  for (const auto &fragment : fragments)
+    signatureTypes.push_back(fragment.regs[0].getType());
+
+  p << " : (";
+  llvm::interleaveComma(signatureTypes, p);
+  p << ")";
+  p.printArrowTypeList(TypeRange{getRes().getType()});
+}
+
+/// Parses the custom assembly form of the sparse block-scaled MMA op:
+///   A[...] B[...] C[...] sparseMetadata[...] selector[...]
+///   scaleA[data, byteId, threadId] scaleB[data, byteId, threadId]
+///   attr-dict : (typeA, typeB, typeC) -> result-types
+/// Operands are appended to `result` in exactly that order, which is the
+/// order described by the operand segment sizes set at the end.
+ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser,
+                                     OperationState &result) {
+  struct ParsedFragment {
+    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
+    std::optional<MMATypes> elemtype;
+  };
+
+  Builder &builder = parser.getBuilder();
+  auto nameLoc = parser.getNameLoc();
+  std::array<ParsedFragment, 3> fragments;
+  NamedAttrList parsedAttrs;
+
+  // Register fragments: A[...] B[...] C[...]
+  if (parseMmaOperand(parser, "A", fragments[0].regs).failed())
+    return failure();
+  if (parseMmaOperand(parser, "B", fragments[1].regs).failed())
+    return failure();
+  if (parseMmaOperand(parser, "C", fragments[2].regs).failed())
+    return failure();
+
+  // Sparsity operands: sparseMetadata[...] selector[...]
+  SmallVector<OpAsmParser::UnresolvedOperand, 1> metadataRefs;
+  SmallVector<OpAsmParser::UnresolvedOperand, 1> selectorRefs;
+  if (parseMmaOperand(parser, "sparseMetadata", metadataRefs).failed())
+    return failure();
+  if (parseMmaOperand(parser, "selector", selectorRefs).failed())
+    return failure();
+
+  // Scale operands: scaleA[...] scaleB[...]
+  SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleARefs;
+  SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleBRefs;
+  if (parseMmaOperand(parser, "scaleA", scaleARefs).failed())
+    return failure();
+  if (parseMmaOperand(parser, "scaleB", scaleBRefs).failed())
+    return failure();
+
+  if (parser.parseOptionalAttrDict(parsedAttrs).failed())
+    return failure();
+
+  // `: (typeA, typeB, typeC)`
+  SmallVector<Type, 3> operandTypes;
+  if (parseMmaTypeSignature(parser, operandTypes).failed())
+    return failure();
+
+  // `-> result-types`
+  SmallVector<Type, 1> resultTypes;
+  if (parser.parseArrowTypeList(resultTypes).failed())
+    return failure();
+
+  // Every register of a fragment shares the fragment's signature type; only
+  // the third fragment (C) is treated as an accumulator.
+  for (unsigned idx = 0, numFrags = fragments.size(); idx != numFrags; ++idx) {
+    ParsedFragment &fragment = fragments[idx];
+    fragment.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
+                                                   /*isAccumulator=*/idx >= 2);
+    if (parser.resolveOperands(fragment.regs, operandTypes[idx], nameLoc,
+                               result.operands).failed())
+      return failure();
+  }
+
+  // Sparse metadata and selector are resolved as i32.
+  Type i32Type = builder.getI32Type();
+  if (parser.resolveOperands(metadataRefs, i32Type, nameLoc,
+                             result.operands).failed())
+    return failure();
+  if (parser.resolveOperands(selectorRefs, i32Type, nameLoc,
+                             result.operands).failed())
+    return failure();
+
+  // Each scale triple is resolved as (i32 data, i16 byte id, i16 thread id).
+  Type i16Type = builder.getI16Type();
+  SmallVector<Type, 3> scaleTypes{i32Type, i16Type, i16Type};
+  if (parser.resolveOperands(scaleARefs, scaleTypes, nameLoc,
+                             result.operands).failed())
+    return failure();
+  if (parser.resolveOperands(scaleBRefs, scaleTypes, nameLoc,
+                             result.operands).failed())
+    return failure();
+
+  // Explicit attributes first, then multiplicand types inferred from the
+  // signature.
+  result.addAttributes(parsedAttrs);
+  inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
+                               operandTypes);
+
+  // orderedMetadata is mandatory, so it is added when the text omits it.
+  if (!result.attributes.get("orderedMetadata"))
+    result.addAttribute("orderedMetadata", builder.getUnitAttr());
+
+  result.addTypes(resultTypes);
+
+  // Variadic A/B/C segments followed by eight single-operand segments.
+  const int32_t segmentSizes[] = {
+      static_cast<int32_t>(fragments[0].regs.size()),
+      static_cast<int32_t>(fragments[1].regs.size()),
+      static_cast<int32_t>(fragments[2].regs.size()),
+      1, // sparseMetadata
+      1, // sparsitySelector
+      1, // scaleAData
+      1, // byteIdA
+      1, // threadIdA
+      1, // scaleBData
+      1, // byteIdB
+      1  // threadIdB
+  };
+  result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr(segmentSizes));
+  return success();
+}
+
+/// Populates `result` for a sparse block-scaled MMA: block-scale attributes,
+/// the mandatory `orderedMetadata` unit attribute, all operands in segment
+/// order, the multiplicand type attributes (taken from `multiplicandPtxTypes`
+/// or inferred by addInferredMultiplicandTypes), the result type and the
+/// operand segment sizes. `shape` must hold exactly (m, n, k).
+void MmaSpBlockScaleOp::build(OpBuilder &builder, OperationState &result,
+    Type resultType, ValueRange operandA,
+    ValueRange operandB, ValueRange operandC,
+    Value sparseMetadata, Value sparsitySelector,
+    Value scaleAData, Value byteIdA, Value threadIdA,
+    Value scaleBData, Value byteIdB, Value threadIdB,
+    ArrayRef<int64_t> shape,
+    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
+    ScaleVecSize scaleVecSize,
+    BlockScaleFormat blockScaleFormat,
+    MMABlockScaleKind kind) {
+  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
+
+  // Attributes describing the block-scaled operation itself.
+  addBlockScaleAttributes(builder, result, shape, scaleVecSize,
+                          blockScaleFormat, kind);
+  result.addAttribute("orderedMetadata", builder.getUnitAttr());
+
+  // Operands, in the order the segment sizes below describe.
+  result.addOperands(operandA);
+  result.addOperands(operandB);
+  result.addOperands(operandC);
+  result.addOperands({sparseMetadata, sparsitySelector});
+  result.addOperands({scaleAData, byteIdA, threadIdA});
+  result.addOperands({scaleBData, byteIdB, threadIdB});
+
+  addInferredMultiplicandTypes(builder.getContext(), result, operandA,
+                               operandB, multiplicandPtxTypes);
+
+  result.addTypes(resultType);
+
+  // Variadic A/B/C segments followed by eight single-operand segments.
+  const int32_t segmentSizes[] = {
+      static_cast<int32_t>(operandA.size()),
+      static_cast<int32_t>(operandB.size()),
+      static_cast<int32_t>(operandC.size()),
+      1, // sparseMetadata
+      1, // sparsitySelector
+      1, // scaleAData
+      1, // byteIdA
+      1, // threadIdA
+      1, // scaleBData
+      1, // byteIdB
+      1  // threadIdB
+  };
+  result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr(segmentSizes));
+}
+
+/// Returns the accumulator PTX type; for this op it is obtained from the
+/// result via inferPtxTypeFromResult.
+MMATypes MmaSpBlockScaleOp::accumPtxType() {
+  MMATypes accumType = inferPtxTypeFromResult(*this);
+  return accumType;
+}
+
+/// Returns the PTX type of the result, obtained via inferPtxTypeFromResult.
+MMATypes MmaSpBlockScaleOp::resultPtxType() {
+  MMATypes resultType = inferPtxTypeFromResult(*this);
+  return resultType;
+}
+
+/// Collects the translated LLVM values of all operands (A, B, C, sparse
+/// metadata, sparsity selector, scale-A triple, scale-B triple, in that order)
+/// and selects the intrinsic ID from shape, multiplicand types, accumulator
+/// type, scale-vector size, block-scale format and kind.
+/// Both multiplicand type attributes are dereferenced unconditionally, so they
+/// must be present on `op`. `builder` is not used.
+NVVM::IDArgPair MmaSpBlockScaleOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto spOp = cast<NVVM::MmaSpBlockScaleOp>(op);
+
+  SmallVector<llvm::Value *> args;
+  auto appendTranslated = [&](Value mlirValue) {
+    args.push_back(mt.lookupValue(mlirValue));
+  };
+
+  // Variadic register fragments.
+  for (Value reg : spOp.getOperandA())
+    appendTranslated(reg);
+  for (Value reg : spOp.getOperandB())
+    appendTranslated(reg);
+  for (Value reg : spOp.getOperandC())
+    appendTranslated(reg);
+
+  // Sparse metadata and selector.
+  appendTranslated(spOp.getSparseMetadata());
+  appendTranslated(spOp.getSparsitySelector());
+
+  // Scale operands for A, then for B.
+  appendTranslated(spOp.getScaleAData());
+  appendTranslated(spOp.getByteIdA());
+  appendTranslated(spOp.getThreadIdA());
+  appendTranslated(spOp.getScaleBData());
+  appendTranslated(spOp.getByteIdB());
+  appendTranslated(spOp.getThreadIdB());
+
+  unsigned intId = MmaSpBlockScaleOp::getIntrinsicID(
+      spOp.getShape().getM(), spOp.getShape().getN(), spOp.getShape().getK(),
+      *spOp.getMultiplicandAPtxType(), *spOp.getMultiplicandBPtxType(),
+      spOp.accumPtxType(), spOp.getScaleVecSize(), spOp.getBlockScaleFormat(),
+      spOp.getKind());
+
+  return {intId, args};
+}
+
+/// Verifies that a sparse block-scaled MMA carries `orderedMetadata` and that
+/// shape, kind, scale-vector size, block-scale format and the multiplicand
+/// element types form a supported combination:
+///   - m16n8k128: A and B must be e2m1. Kind mxf4 requires x2 + ue8m0; kind
+///     mxf4nvf4 requires either x2 + ue8m0 or x4 + ue4m3.
+///   - m16n8k64: kind mxf8f6f4 with x1 + ue8m0; A and B must each be one of
+///     e2m1, e2m3, e3m2, e4m3, e5m2.
+/// A missing `orderedMetadata` fails immediately; the remaining checks are not
+/// short-circuited, so each failing check emits its own error.
+LogicalResult MmaSpBlockScaleOp::verify() {
+  // Check that orderedMetadata is present
+  if (!getOrderedMetadata()) {
+    return emitOpError("'orderedMetadata' attribute is mandatory");
+  }
+
+  LogicalResult result = success();
+  int m = getShape().getM();
+  int n = getShape().getN();
+  int k = getShape().getK();
+
+  // Element types accepted by kind mxf8f6f4. An absent attribute (nullopt)
+  // compares unequal to every enumerator and is rejected as well, so the
+  // unconditional dereference in getIntrinsicIDAndArgs() stays safe.
+  auto isMxf8f6f4Type = [](std::optional<NVVM::MMATypes> type) {
+    return type == NVVM::MMATypes::e2m1 || type == NVVM::MMATypes::e2m3 ||
+           type == NVVM::MMATypes::e3m2 || type == NVVM::MMATypes::e4m3 ||
+           type == NVVM::MMATypes::e5m2;
+  };
+
+  if (m == 16 && n == 8 && k == 128) {
+    if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
+        getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
+      result = emitOpError(
+          "unsupported MMATypes attribute for mma.m16n8k128.(mxf4nvf4|mxf4)");
+    if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
+      if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
+        result = emitOpError(
+            "unsupported ScaleVecSize attribute for mma.m16n8k128.mxf4");
+      if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
+        result = emitOpError(
+            "unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4");
+    } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
+      if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
+             getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
+            (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
+             getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
+        result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
+                             "attributes for mma.m16n8k128.mxf4nvf4");
+    } else
+      result = emitOpError("unsupported Kind attribute for mma.m16n8k128");
+  } else if (m == 16 && n == 8 && k == 64) {
+    // This shape is only defined for the FP4/FP6/FP8 element types above.
+    if (!isMxf8f6f4Type(getMultiplicandAPtxType()) ||
+        !isMxf8f6f4Type(getMultiplicandBPtxType()))
+      result = emitOpError(
+          "unsupported MMATypes attribute for mma.m16n8k64.mxf8f6f4");
+    if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
+          getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
+          getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
+      result = emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
+                           "attributes for mma.m16n8k64");
+  } else
+    result = emitOpError("unsupported Geom for sparse mma with block scaling");
+  return result;
+}
+
LogicalResult ShflOp::verify() {
auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-blockscale.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-blockscale.mlir
new file mode 100644
index 0000000000000..fbd0203d19904
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-blockscale.mlir
@@ -0,0 +1,525 @@
+// RUN: mlir-opt %s -split-input-file | FileCheck %s
+
+// This file contains tests for all dense MMA block scale operations in the NVVM dialect
+// Based on PTX ISA documentation:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-with-block-scaling
+//
+// MMA block scale operations perform matrix multiply-accumulate with block scaling:
+// D = matmul(A * SF_A, B * SF_B) + C
+// where SF_A and SF_B are scaling factors with dimensions based on scale vector size.
+
+// =============================================================================
+// MXF8F6F4 Block Scale MMA Operations (m16n8k32) - All Type Combinations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m1
+func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m1(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m3
+func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e3m2
+func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e3m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e4m3
+func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e5m2
+func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e5m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m1
+func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m1(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m3
+func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e3m2
+func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e3m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e4m3
+func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e5m2
+func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e5m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m1
+func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m1(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m3
+func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e3m2
+func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e3m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e4m3
+func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e5m2
+func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e5m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m1
+func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m1(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m3
+func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e3m2
+func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e3m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e4m3
+func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e5m2
+func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e5m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m1
+func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m1(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m3
+func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e3m2
+func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e3m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e4m3
+func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]  // dense case: A = e5m2, B = e4m3, m16n8k32, kind mxf8f6f4, x1 ue8m0 scales
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e5m2
+func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e5m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]  // dense case: A = e5m2, B = e5m2, m16n8k32, kind mxf8f6f4, x1 ue8m0 scales
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// =============================================================================
+// MXF4 Block Scale MMA Operations (m16n8k64)
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mxf4_blockscale_mma
+func.func @nvvm_mxf4_blockscale_mma(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]  // dense case: A = B = e2m1, m16n8k64, kind mxf4, x2 ue8m0 scales
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x2>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// =============================================================================
+// MXF4NVF4 Block Scale MMA Operations (m16n8k64)
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mxf4nvf4_blockscale_mma_ue8m0
+func.func @nvvm_mxf4nvf4_blockscale_mma_ue8m0(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]  // dense case: A = B = e2m1, m16n8k64, kind mxf4nvf4, x2 ue8m0 scales
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x2>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf4nvf4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf4nvf4_blockscale_mma_ue4m3
+func.func @nvvm_mxf4nvf4_blockscale_mma_ue4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]  // dense case: A = B = e2m1, m16n8k64, kind mxf4nvf4, x4 ue4m3 scales
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x4>,
+ blockScaleFormat = #nvvm.block_scale_format<ue4m3>,
+ kind = #nvvm.block_scale_kind<mxf4nvf4>}
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sparse-blockscale.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sparse-blockscale.mlir
new file mode 100644
index 0000000000000..2e72012bcf722
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sparse-blockscale.mlir
@@ -0,0 +1,637 @@
+// RUN: mlir-opt %s -split-input-file | FileCheck %s
+
+// This file contains tests for all sparse MMA block scale operations in the NVVM dialect,
+// based on the PTX ISA documentation:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-with-block-scaling
+//
+// Sparse MMA block scale operations perform matrix multiply-accumulate with block scaling
+// on sparse matrices: D = matmul(A * SF_A, B * SF_B) + C
+// where A is structured-sparse (2:4 for mxf8f6f4; pair-wise 4:8 for e2m1 with mxf4/mxf4nvf4) and SF_A, SF_B are scaling factors.
+
+// =============================================================================
+// MXF8F6F4 Sparse Block Scale MMA Operations (m16n8k64) - All Type Combinations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m1
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m1(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e2m1, B = e2m1, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m3
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e2m1, B = e2m3, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e3m2
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e3m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e2m1, B = e3m2, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e4m3
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e2m1, B = e4m3, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e5m2
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e5m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e2m1, B = e5m2, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m1
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m1(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e2m3, B = e2m1, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m3
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e2m3, B = e2m3, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e3m2
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e3m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e2m3, B = e3m2, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e4m3
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e2m3, B = e4m3, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e5m2
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e5m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e2m3, B = e5m2, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m1
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m1(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e3m2, B = e2m1, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m3
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e3m2, B = e2m3, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e3m2
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e3m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e3m2, B = e3m2, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e4m3
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e3m2, B = e4m3, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e5m2
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e5m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e3m2, B = e5m2, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m1
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m1(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e4m3, B = e2m1, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m3
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e4m3, B = e2m3, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e3m2
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e3m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e4m3, B = e3m2, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e4m3
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e4m3, B = e4m3, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e5m2
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e5m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e4m3, B = e5m2, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m1
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m1(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e5m2, B = e2m1, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m3
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e5m2, B = e2m3, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e3m2
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e3m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e5m2, B = e3m2, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e4m3
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]  // sparse case: A = e5m2, B = e4m3, m16n8k64, kind mxf8f6f4, x1 ue8m0 scales
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]  // A scale operands: i32 data, i16 byteId, i16 threadId
+ scaleB[%scaleBData, %byteIdB, %threadIdB]  // B scale operands: i32 data, i16 byteId, i16 threadId
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}  // bare name in the dictionary, i.e. a unit attribute
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return  // %0 is unused; the pattern above matches operand syntax only, not attributes
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e5m2
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e5m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ orderedMetadata}
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// =============================================================================
+// MXF4 Sparse Block Scale MMA Operations (m16n8k128)
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mxf4_sp_blockscale_mma
+func.func @nvvm_mxf4_sp_blockscale_mma(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 128>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x2>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf4>,
+ orderedMetadata}
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// =============================================================================
+// MXF4NVF4 Sparse Block Scale MMA Operations (m16n8k128)
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mxf4nvf4_sp_blockscale_mma_ue8m0
+func.func @nvvm_mxf4nvf4_sp_blockscale_mma_ue8m0(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 128>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x2>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf4nvf4>,
+ orderedMetadata}
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
+
+// CHECK-LABEL: @nvvm_mxf4nvf4_sp_blockscale_mma_ue4m3
+func.func @nvvm_mxf4nvf4_sp_blockscale_mma_ue4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
+ %sparseMetadata: i32, %sparsitySelector: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 128>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x4>,
+ blockScaleFormat = #nvvm.block_scale_format<ue4m3>,
+ kind = #nvvm.block_scale_kind<mxf4nvf4>,
+ orderedMetadata}
+ : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
+ return
+}
>From c664fd01cfa2bf91973c5d4fbdd6fc935f4bdd6f Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Wed, 3 Dec 2025 23:30:30 +0100
Subject: [PATCH 02/11] [MLIR] Fixes for code formatting
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 323 +++++++++++----------
1 file changed, 168 insertions(+), 155 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index a06fe19b8ad1c..2d40ecf6e5ee6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -693,8 +693,8 @@ void MmaOp::print(OpAsmPrinter &p) {
regTypes.push_back(this->getOperand(operandIdx).getType());
}
}
- std::optional<MMATypes> inferredType =
- MmaOp::inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
+ std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
+ regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
if (inferredType)
ignoreAttrNames.push_back(frag.ptxTypeAttr);
}
@@ -1582,21 +1582,23 @@ void printOperandList(OpAsmPrinter &p, StringRef name,
}
// Helper to parse operand list in the format: name[operands]
-LogicalResult parseMmaOperand(OpAsmParser &parser, StringRef operandName,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &regs) {
+LogicalResult
+parseMmaOperand(OpAsmParser &parser, StringRef operandName,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &regs) {
if (parser.parseKeyword(operandName).failed())
return failure();
- if (parser.parseOperandList(regs,
- OpAsmParser::Delimiter::OptionalSquare).failed())
+ if (parser.parseOperandList(regs, OpAsmParser::Delimiter::OptionalSquare)
+ .failed())
return failure();
return success();
}
-// Helper to process operand fragments and determine which attributes can be inferred
+// Helper to process operand fragments and determine which attributes can be
+// inferred
template <typename Op>
void processOperandFragments(Op &op, std::array<OperandFragment, 3> &frags,
- SmallVectorImpl<Type> &regTypes,
- SmallVectorImpl<StringRef> &ignoreAttrNames) {
+ SmallVectorImpl<Type> &regTypes,
+ SmallVectorImpl<StringRef> &ignoreAttrNames) {
for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
auto &frag = frags[fragIdx];
auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
@@ -1639,16 +1641,16 @@ LogicalResult parseMmaTypeSignature(OpAsmParser &parser,
// Helper to infer and set multiplicand PTX type attributes
void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs,
- const SmallVectorImpl<Type> &operandTypes) {
+ const SmallVectorImpl<Type> &operandTypes) {
if (!attrs.get("multiplicandAPtxType")) {
- if (auto inferredType = MmaOp::inferOperandMMAType(operandTypes[0],
- false)) {
+ if (auto inferredType =
+ MmaOp::inferOperandMMAType(operandTypes[0], false)) {
attrs.set("multiplicandAPtxType", MMATypesAttr::get(ctx, *inferredType));
}
}
if (!attrs.get("multiplicandBPtxType")) {
- if (auto inferredType = MmaOp::inferOperandMMAType(operandTypes[1],
- false)) {
+ if (auto inferredType =
+ MmaOp::inferOperandMMAType(operandTypes[1], false)) {
attrs.set("multiplicandBPtxType", MMATypesAttr::get(ctx, *inferredType));
}
}
@@ -1656,37 +1658,34 @@ void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs,
// Helper to add common block scale attributes
void addBlockScaleAttributes(OpBuilder &builder, OperationState &result,
- ArrayRef<int64_t> shape,
- ScaleVecSize scaleVecSize,
+ ArrayRef<int64_t> shape, ScaleVecSize scaleVecSize,
BlockScaleFormat blockScaleFormat,
MMABlockScaleKind kind) {
MLIRContext *ctx = builder.getContext();
- result.addAttribute("shape",
- builder.getAttr<MMAShapeAttr>(shape[0], shape[1],
- shape[2]));
- result.addAttribute("scaleVecSize",
- ScaleVecSizeAttr::get(ctx, scaleVecSize));
+ result.addAttribute(
+ "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
+ result.addAttribute(
+ "scaleVecSize", ScaleVecSizeAttr::get(ctx, scaleVecSize));
result.addAttribute("blockScaleFormat",
BlockScaleFormatAttr::get(ctx, blockScaleFormat));
result.addAttribute("kind", MMABlockScaleKindAttr::get(ctx, kind));
}
// Helper to infer and add multiplicand PTX types to builder
-void addInferredMultiplicandTypes(MLIRContext *ctx, OperationState &result,
- ValueRange operandA, ValueRange operandB,
- std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
+void addInferredMultiplicandTypes(
+ MLIRContext *ctx, OperationState &result, ValueRange operandA,
+ ValueRange operandB,
+ std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
if (multiplicandPtxTypes) {
result.addAttribute("multiplicandAPtxType",
- MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
+ MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
result.addAttribute("multiplicandBPtxType",
- MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
+ MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
} else {
if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
- result.addAttribute("multiplicandAPtxType",
- MMATypesAttr::get(ctx, *res));
+ result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
- result.addAttribute("multiplicandBPtxType",
- MMATypesAttr::get(ctx, *res));
+ result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
}
}
@@ -1730,7 +1729,8 @@ void MmaBlockScaleOp::print(OpAsmPrinter &p) {
p << " : (";
llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
frags[1].regs[0].getType(),
- frags[2].regs[0].getType()}, p);
+ frags[2].regs[0].getType()},
+ p);
p << ")";
p.printArrowTypeList(TypeRange{this->getRes().getType()});
}
@@ -1775,52 +1775,55 @@ ParseResult MmaBlockScaleOp::parse(OpAsmParser &parser,
for (const auto &[idx, frag] : llvm::enumerate(frags)) {
frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
/*isAccumulator=*/idx >= 2);
- if (parser.resolveOperands(frag.regs, operandTypes[idx],
- parser.getNameLoc(), result.operands).failed())
+ if (parser
+ .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
+ result.operands)
+ .failed())
return failure();
}
// Resolve scale operands
- SmallVector<Type, 3> scaleTypes = {
- builder.getI32Type(), builder.getI16Type(), builder.getI16Type()
- };
- if (parser.resolveOperands(scaleAOperands, scaleTypes,
- parser.getNameLoc(), result.operands).failed() ||
- parser.resolveOperands(scaleBOperands, scaleTypes,
- parser.getNameLoc(), result.operands).failed())
+ SmallVector<Type, 3> scaleTypes = {builder.getI32Type(), builder.getI16Type(),
+ builder.getI16Type()};
+ if (parser
+ .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
+ result.operands)
+ .failed() ||
+ parser
+ .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
+ result.operands)
+ .failed())
return failure();
// Add attributes
result.addAttributes(namedAttributes);
- inferAndSetMultiplicandTypes(parser.getContext(),
- result.attributes, operandTypes);
+ inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
+ operandTypes);
result.addTypes(resultTypes);
- result.addAttribute(
- MmaBlockScaleOp::getOperandSegmentSizeAttr(),
- builder.getDenseI32ArrayAttr({static_cast<int32_t>(frags[0].regs.size()),
- static_cast<int32_t>(frags[1].regs.size()),
- static_cast<int32_t>(frags[2].regs.size()),
- 1, // scaleAData
- 1, // byteIdA
- 1, // threadIdA
- 1, // scaleBData
- 1, // byteIdB
- 1 // threadIdB
- }));
+ result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({
+ static_cast<int32_t>(frags[0].regs.size()),
+ static_cast<int32_t>(frags[1].regs.size()),
+ static_cast<int32_t>(frags[2].regs.size()),
+ 1, // scaleAData
+ 1, // byteIdA
+ 1, // threadIdA
+ 1, // scaleBData
+ 1, // byteIdB
+ 1 // threadIdB
+ }));
return success();
}
-void MmaBlockScaleOp::build(OpBuilder &builder, OperationState &result,
- Type resultType, ValueRange operandA,
- ValueRange operandB, ValueRange operandC,
- Value scaleAData, Value byteIdA, Value threadIdA,
- Value scaleBData, Value byteIdB, Value threadIdB,
- ArrayRef<int64_t> shape,
- std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
- ScaleVecSize scaleVecSize,
- BlockScaleFormat blockScaleFormat,
- MMABlockScaleKind kind) {
+void MmaBlockScaleOp::build(
+ OpBuilder &builder, OperationState &result, Type resultType,
+ ValueRange operandA, ValueRange operandB, ValueRange operandC,
+ Value scaleAData, Value byteIdA, Value threadIdA, Value scaleBData,
+ Value byteIdB, Value threadIdB, ArrayRef<int64_t> shape,
+ std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
+ ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
+ MMABlockScaleKind kind) {
assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
addBlockScaleAttributes(builder, result, shape, scaleVecSize,
@@ -1829,25 +1832,25 @@ void MmaBlockScaleOp::build(OpBuilder &builder, OperationState &result,
result.addOperands(operandA);
result.addOperands(operandB);
result.addOperands(operandC);
- result.addOperands({scaleAData, byteIdA, threadIdA,
- scaleBData, byteIdB, threadIdB});
+ result.addOperands(
+ {scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB});
- addInferredMultiplicandTypes(builder.getContext(), result,
- operandA, operandB, multiplicandPtxTypes);
+ addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
+ multiplicandPtxTypes);
result.addTypes(resultType);
- result.addAttribute(
- MmaBlockScaleOp::getOperandSegmentSizeAttr(),
- builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
- static_cast<int32_t>(operandB.size()),
- static_cast<int32_t>(operandC.size()),
- 1, // scaleAData
- 1, // byteIdA
- 1, // threadIdA
- 1, // scaleBData
- 1, // byteIdB
- 1 // threadIdB
- }));
+ result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({
+ static_cast<int32_t>(operandA.size()),
+ static_cast<int32_t>(operandB.size()),
+ static_cast<int32_t>(operandC.size()),
+ 1, // scaleAData
+ 1, // byteIdA
+ 1, // threadIdA
+ 1, // scaleBData
+ 1, // byteIdB
+ 1 // threadIdB
+ }));
}
MMATypes MmaBlockScaleOp::accumPtxType() {
@@ -1882,8 +1885,8 @@ NVVM::IDArgPair MmaBlockScaleOp::getIntrinsicIDAndArgs(
unsigned intId = MmaBlockScaleOp::getIntrinsicID(
curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
*curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
- curOp.accumPtxType(),
- curOp.getScaleVecSize(), curOp.getBlockScaleFormat(), curOp.getKind());
+ curOp.accumPtxType(), curOp.getScaleVecSize(),
+ curOp.getBlockScaleFormat(), curOp.getKind());
return {intId, args};
}
@@ -1908,9 +1911,9 @@ LogicalResult MmaBlockScaleOp::verify() {
"unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4");
} else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
- getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
- (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
- getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
+ getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
+ (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
+ getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
"attributes for mma.m16n8k64.mxf4nvf4");
} else
@@ -1919,8 +1922,9 @@ LogicalResult MmaBlockScaleOp::verify() {
if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
- result = emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
- "attributes for mma.m16n8k32");
+ result =
+ emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
+ "attributes for mma.m16n8k32");
} else
result = emitOpError("unsupported Geom for mma with block scaling");
return result;
@@ -1961,7 +1965,8 @@ void MmaSpBlockScaleOp::print(OpAsmPrinter &p) {
p << " : (";
llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
frags[1].regs[0].getType(),
- frags[2].regs[0].getType()}, p);
+ frags[2].regs[0].getType()},
+ p);
p << ")";
p.printArrowTypeList(TypeRange{this->getRes().getType()});
}
@@ -1984,7 +1989,8 @@ ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser,
return failure();
// Parse sparse-specific operands
- SmallVector<OpAsmParser::UnresolvedOperand, 1> metadataOperands, selectorOperands;
+ SmallVector<OpAsmParser::UnresolvedOperand, 1> metadataOperands,
+ selectorOperands;
if (parseMmaOperand(parser, "sparseMetadata", metadataOperands).failed() ||
parseMmaOperand(parser, "selector", selectorOperands).failed())
return failure();
@@ -2012,67 +2018,74 @@ ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser,
for (const auto &[idx, frag] : llvm::enumerate(frags)) {
frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
/*isAccumulator=*/idx >= 2);
- if (parser.resolveOperands(frag.regs, operandTypes[idx],
- parser.getNameLoc(), result.operands).failed())
+ if (parser
+ .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
+ result.operands)
+ .failed())
return failure();
}
// Resolve sparse metadata and selector
Type i32Type = builder.getI32Type();
- if (parser.resolveOperands(metadataOperands, i32Type,
- parser.getNameLoc(), result.operands).failed() ||
- parser.resolveOperands(selectorOperands, i32Type,
- parser.getNameLoc(), result.operands).failed())
+ if (parser
+ .resolveOperands(metadataOperands, i32Type, parser.getNameLoc(),
+ result.operands)
+ .failed() ||
+ parser
+ .resolveOperands(selectorOperands, i32Type, parser.getNameLoc(),
+ result.operands)
+ .failed())
return failure();
// Resolve scale operands
- SmallVector<Type, 3> scaleTypes = {
- i32Type, builder.getI16Type(), builder.getI16Type()
- };
- if (parser.resolveOperands(scaleAOperands, scaleTypes,
- parser.getNameLoc(), result.operands).failed() ||
- parser.resolveOperands(scaleBOperands, scaleTypes,
- parser.getNameLoc(), result.operands).failed())
+ SmallVector<Type, 3> scaleTypes = {i32Type, builder.getI16Type(),
+ builder.getI16Type()};
+ if (parser
+ .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
+ result.operands)
+ .failed() ||
+ parser
+ .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
+ result.operands)
+ .failed())
return failure();
// Add attributes
result.addAttributes(namedAttributes);
- inferAndSetMultiplicandTypes(parser.getContext(),
- result.attributes, operandTypes);
+ inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
+ operandTypes);
// orderedMetadata is mandatory
if (!result.attributes.get("orderedMetadata"))
result.addAttribute("orderedMetadata", builder.getUnitAttr());
result.addTypes(resultTypes);
- result.addAttribute(
- MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
- builder.getDenseI32ArrayAttr({static_cast<int32_t>(frags[0].regs.size()),
- static_cast<int32_t>(frags[1].regs.size()),
- static_cast<int32_t>(frags[2].regs.size()),
- 1, // sparseMetadata
- 1, // sparsitySelector
- 1, // scaleAData
- 1, // byteIdA
- 1, // threadIdA
- 1, // scaleBData
- 1, // byteIdB
- 1 // threadIdB
- }));
+ result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({
+ static_cast<int32_t>(frags[0].regs.size()),
+ static_cast<int32_t>(frags[1].regs.size()),
+ static_cast<int32_t>(frags[2].regs.size()),
+ 1, // sparseMetadata
+ 1, // sparsitySelector
+ 1, // scaleAData
+ 1, // byteIdA
+ 1, // threadIdA
+ 1, // scaleBData
+ 1, // byteIdB
+ 1 // threadIdB
+ }));
return success();
}
-void MmaSpBlockScaleOp::build(OpBuilder &builder, OperationState &result,
- Type resultType, ValueRange operandA,
- ValueRange operandB, ValueRange operandC,
- Value sparseMetadata, Value sparsitySelector,
- Value scaleAData, Value byteIdA, Value threadIdA,
- Value scaleBData, Value byteIdB, Value threadIdB,
- ArrayRef<int64_t> shape,
- std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
- ScaleVecSize scaleVecSize,
- BlockScaleFormat blockScaleFormat,
- MMABlockScaleKind kind) {
+void MmaSpBlockScaleOp::build(
+ OpBuilder &builder, OperationState &result, Type resultType,
+ ValueRange operandA, ValueRange operandB, ValueRange operandC,
+ Value sparseMetadata, Value sparsitySelector, Value scaleAData,
+ Value byteIdA, Value threadIdA, Value scaleBData, Value byteIdB,
+ Value threadIdB, ArrayRef<int64_t> shape,
+ std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
+ ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
+ MMABlockScaleKind kind) {
assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
addBlockScaleAttributes(builder, result, shape, scaleVecSize,
@@ -2082,28 +2095,27 @@ void MmaSpBlockScaleOp::build(OpBuilder &builder, OperationState &result,
result.addOperands(operandA);
result.addOperands(operandB);
result.addOperands(operandC);
- result.addOperands({sparseMetadata, sparsitySelector,
- scaleAData, byteIdA, threadIdA,
- scaleBData, byteIdB, threadIdB});
+ result.addOperands({sparseMetadata, sparsitySelector, scaleAData, byteIdA,
+ threadIdA, scaleBData, byteIdB, threadIdB});
- addInferredMultiplicandTypes(builder.getContext(), result,
- operandA, operandB, multiplicandPtxTypes);
+ addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
+ multiplicandPtxTypes);
result.addTypes(resultType);
- result.addAttribute(
- MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
- builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
- static_cast<int32_t>(operandB.size()),
- static_cast<int32_t>(operandC.size()),
- 1, // sparseMetadata
- 1, // sparsitySelector
- 1, // scaleAData
- 1, // byteIdA
- 1, // threadIdA
- 1, // scaleBData
- 1, // byteIdB
- 1 // threadIdB
- }));
+ result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({
+ static_cast<int32_t>(operandA.size()),
+ static_cast<int32_t>(operandB.size()),
+ static_cast<int32_t>(operandC.size()),
+ 1, // sparseMetadata
+ 1, // sparsitySelector
+ 1, // scaleAData
+ 1, // byteIdA
+ 1, // threadIdA
+ 1, // scaleBData
+ 1, // byteIdB
+ 1 // threadIdB
+ }));
}
MMATypes MmaSpBlockScaleOp::accumPtxType() {
@@ -2142,8 +2154,8 @@ NVVM::IDArgPair MmaSpBlockScaleOp::getIntrinsicIDAndArgs(
unsigned intId = MmaSpBlockScaleOp::getIntrinsicID(
curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
*curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
- curOp.accumPtxType(),
- curOp.getScaleVecSize(), curOp.getBlockScaleFormat(), curOp.getKind());
+ curOp.accumPtxType(), curOp.getScaleVecSize(),
+ curOp.getBlockScaleFormat(), curOp.getKind());
return {intId, args};
}
@@ -2173,9 +2185,9 @@ LogicalResult MmaSpBlockScaleOp::verify() {
"unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4");
} else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
- getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
- (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
- getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
+ getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
+ (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
+ getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
"attributes for mma.m16n8k128.mxf4nvf4");
} else
@@ -2184,8 +2196,9 @@ LogicalResult MmaSpBlockScaleOp::verify() {
if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
- result = emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
- "attributes for mma.m16n8k64");
+ result =
+ emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
+ "attributes for mma.m16n8k64");
} else
result = emitOpError("unsupported Geom for sparse mma with block scaling");
return result;
>From 765647ab594549bd2678994c0a20fc7692e3432d Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Wed, 3 Dec 2025 23:31:42 +0100
Subject: [PATCH 03/11] [MLIR] One more fix for code formatting
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 2d40ecf6e5ee6..3387c69025e20 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1664,8 +1664,7 @@ void addBlockScaleAttributes(OpBuilder &builder, OperationState &result,
MLIRContext *ctx = builder.getContext();
result.addAttribute(
"shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
- result.addAttribute(
- "scaleVecSize", ScaleVecSizeAttr::get(ctx, scaleVecSize));
+ result.addAttribute("scaleVecSize", ScaleVecSizeAttr::get(ctx, scaleVecSize));
result.addAttribute("blockScaleFormat",
BlockScaleFormatAttr::get(ctx, blockScaleFormat));
result.addAttribute("kind", MMABlockScaleKindAttr::get(ctx, kind));
>From 5febdda3f2742ddec22d3b299ea7e378c14e22d8 Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Tue, 9 Dec 2025 15:18:36 +0100
Subject: [PATCH 04/11] [MLIR] Fixes according to feedback from PR170566
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 49 ++++------
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 94 +++++++++---------
.../nvvm/tcgen05-mma-block-scale-shared.mlir | 96 +++++++++----------
.../nvvm/tcgen05-mma-block-scale-tensor.mlir | 96 +++++++++----------
.../LLVMIR/nvvm/tcgen05-mma-invalid.mlir | 8 +-
.../tcgen05-mma-sp-block-scale-shared.mlir | 96 +++++++++----------
.../tcgen05-mma-sp-block-scale-tensor.mlir | 96 +++++++++----------
7 files changed, 261 insertions(+), 274 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 1faa435fca6f9..e78300ee00f60 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3485,6 +3485,19 @@ def MMABlockScaleKind : I32EnumAttr<
}
def MMABlockScaleKindAttr : EnumAttr<NVVM_Dialect, MMABlockScaleKind, "block_scale_kind"> {
+ let description = [{
+ The MMABlockScaleKind attribute describes the allowed set of types for matrix A and B in the *.mma.{sp}.block_scale Op. The following are supported types for each kind:
+
+ ```
+ +--------------+-------------------------------------------+
+ | Matrix Kind | supported types for A / B |
+ +--------------+-------------------------------------------+
+ | mxf8f6f4 | e4m3, e5m2, e2m3, e3m2, e2m1 |
+ | mxf4 | e2m1 |
+ | mxf4nvf4 | e2m1 |
+ +--------------+-------------------------------------------+
+ ```
+ }];
let assemblyFormat = "`<` $value `>`";
}
@@ -5896,36 +5909,6 @@ def NVVM_Tcgen05MMASparseOp : NVVM_Op<"tcgen05.mma.sp",
}];
}
-def Tcgen05MMAKindMXF8F6F4 : I32EnumAttrCase<"MXF8F6F4", 0, "mxf8f6f4">;
-def Tcgen05MMAKindMXF4 : I32EnumAttrCase<"MXF4", 1, "mxf4">;
-def Tcgen05MMAKindMXF4NVF4 : I32EnumAttrCase<"MXF4NVF4", 2, "mxf4nvf4">;
-
-def Tcgen05MMABlockScaleKind : I32EnumAttr<
- "Tcgen05MMABlockScaleKind",
- "tcgen05.mma.block_scale supported types",
- [Tcgen05MMAKindMXF8F6F4, Tcgen05MMAKindMXF4, Tcgen05MMAKindMXF4NVF4]> {
- let cppNamespace = "::mlir::NVVM";
- let genSpecializedAttr = 0;
-}
-
-def Tcgen05MMABlockScaleKindAttr : EnumAttr<NVVM_Dialect, Tcgen05MMABlockScaleKind,
- "tcgen05_mma_block_scale_kind"> {
- let description = [{
- The Tcgen05MMABlockScaleKind attribute describes the allowed set of types for matrix A and B in the tcgen05.mma.{sp}.block_scale Op. The following are supported types for each kind:
-
- ```
- +--------------+-------------------------------------------+
- | Matrix Kind | supported types for A / B |
- +--------------+-------------------------------------------+
- | mxf8f6f4 | e4m3, e5m3, e2m3, e3m2, e2m1 |
- | mxf4 | e2m1 |
- | mxf4nvf4 | e2m1 |
- +--------------+-------------------------------------------+
- ```
- }];
- let assemblyFormat = "`<` $value `>`";
-}
-
def Tcgen05MMABlockScaleDefault : I32EnumAttrCase<"DEFAULT", 0, "default">;
def Tcgen05MMABlockScaleBlock16 : I32EnumAttrCase<"BLOCK16", 1, "block16">;
def Tcgen05MMABlockScaleBlock32 : I32EnumAttrCase<"BLOCK32", 2, "block32">;
@@ -5970,7 +5953,7 @@ def NVVM_Tcgen05MMABlockScaleOp : NVVM_Op<"tcgen05.mma.block_scale",
- `idesc` is a 32 bit value representing the [Instruction Descriptor](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instruction-descriptor)
Required Attributes:
- - `kind` is a Tcgen05MMABlockScaleKind attribute
+ - `kind` is a MMABlockScaleKind attribute
- `ctaGroup` specifies CTA group configuration
* cta_1: MMA will be performed on the current thread's CTA
@@ -5983,7 +5966,7 @@ def NVVM_Tcgen05MMABlockScaleOp : NVVM_Op<"tcgen05.mma.block_scale",
}];
let arguments = (ins
- Tcgen05MMABlockScaleKindAttr:$kind,
+ MMABlockScaleKindAttr:$kind,
CTAGroupKindAttr:$ctaGroup,
DefaultValuedAttr<Tcgen05MMABlockScaleAttr,
"Tcgen05MMABlockScale::DEFAULT">:$blockScale,
@@ -6045,7 +6028,7 @@ def NVVM_Tcgen05MMASparseBlockScaleOp : NVVM_Op<"tcgen05.mma.sp.block_scale",
}];
let arguments = (ins
- Tcgen05MMABlockScaleKindAttr:$kind,
+ MMABlockScaleKindAttr:$kind,
CTAGroupKindAttr:$ctaGroup,
DefaultValuedAttr<Tcgen05MMABlockScaleAttr,
"Tcgen05MMABlockScale::DEFAULT">:$blockScale,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 3387c69025e20..764ebafa14a1c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -667,18 +667,18 @@ MMATypes MmaOp::resultPtxType() {
void MmaOp::print(OpAsmPrinter &p) {
SmallVector<Type, 4> regTypes;
- struct OperandFragment {
+ struct MMAOperandFragment {
StringRef operandName;
StringRef ptxTypeAttr;
SmallVector<Value, 4> regs;
- explicit OperandFragment(StringRef name, StringRef ptxTypeName)
+ explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
: operandName(name), ptxTypeAttr(ptxTypeName) {}
};
- std::array<OperandFragment, 3> frags{
- OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
- OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
- OperandFragment("C", "")};
+ std::array<MMAOperandFragment, 3> frags{
+ MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
+ MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
+ MMAOperandFragment("C", "")};
SmallVector<StringRef, 4> ignoreAttrNames{
mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
@@ -699,7 +699,7 @@ void MmaOp::print(OpAsmPrinter &p) {
ignoreAttrNames.push_back(frag.ptxTypeAttr);
}
- auto printMmaOperand = [&](const OperandFragment &frag) -> void {
+ auto printMmaOperand = [&](const MMAOperandFragment &frag) -> void {
p << " " << frag.operandName;
p << "[";
p.printOperands(frag.regs);
@@ -780,20 +780,20 @@ void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
// attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
// `->` type($res)
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
- struct OperandFragment {
+ struct MMAOperandFragment {
std::optional<MMATypes> elemtype;
SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
SmallVector<Type> regTypes;
};
Builder &builder = parser.getBuilder();
- std::array<OperandFragment, 4> frags;
+ std::array<MMAOperandFragment, 4> frags;
NamedAttrList namedAttributes;
// A helper to parse the operand segments.
auto parseMmaOperand = [&](StringRef operandName,
- OperandFragment &frag) -> LogicalResult {
+ MMAOperandFragment &frag) -> LogicalResult {
if (parser.parseKeyword(operandName).failed())
return failure();
if (parser
@@ -1120,19 +1120,19 @@ MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
void MmaSpOp::print(OpAsmPrinter &p) {
SmallVector<Type, 4> regTypes;
- struct OperandFragment {
+ struct MMAOperandFragment {
StringRef operandName;
StringRef ptxTypeAttr;
SmallVector<Value, 4> regs;
- explicit OperandFragment(StringRef name, StringRef ptxTypeName)
+ explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
: operandName(name), ptxTypeAttr(ptxTypeName) {}
};
- std::array<OperandFragment, 5> frags{
- OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
- OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
- OperandFragment("C", ""), OperandFragment("sparseMetadata", ""),
- OperandFragment("selector", "")};
+ std::array<MMAOperandFragment, 5> frags{
+ MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
+ MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
+ MMAOperandFragment("C", ""), MMAOperandFragment("sparseMetadata", ""),
+ MMAOperandFragment("selector", "")};
SmallVector<StringRef, 4> ignoreAttrNames{
mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};
@@ -1158,7 +1158,7 @@ void MmaSpOp::print(OpAsmPrinter &p) {
frags[3].regs.push_back(getSparseMetadata());
frags[4].regs.push_back(getSparsitySelector());
- auto printMmaSpOperand = [&](const OperandFragment &frag) -> void {
+ auto printMmaSpOperand = [&](const MMAOperandFragment &frag) -> void {
p << " " << frag.operandName;
p << "[";
p.printOperands(frag.regs);
@@ -1223,20 +1223,20 @@ void MmaSpOp::build(
}
ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) {
- struct OperandFragment {
+ struct MMAOperandFragment {
std::optional<MMATypes> elemtype;
SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
SmallVector<Type> regTypes;
};
Builder &builder = parser.getBuilder();
- std::array<OperandFragment, 6> frags; // A, B, C, sparseMetadata, selector
+ std::array<MMAOperandFragment, 6> frags; // A, B, C, sparseMetadata, selector
NamedAttrList namedAttributes;
// A helper to parse the operand segments.
auto parseMmaSpOperand = [&](StringRef operandName,
- OperandFragment &frag) -> LogicalResult {
+ MMAOperandFragment &frag) -> LogicalResult {
if (parser.parseKeyword(operandName).failed())
return failure();
if (parser
@@ -1565,11 +1565,11 @@ LogicalResult MmaSpOp::verify() {
namespace {
// Shared structure for MMA operand fragments (A, B, C)
-struct OperandFragment {
+struct MMAOperandFragment {
StringRef operandName;
StringRef ptxTypeAttr;
SmallVector<Value, 4> regs;
- explicit OperandFragment(StringRef name, StringRef ptxTypeName)
+ explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
: operandName(name), ptxTypeAttr(ptxTypeName) {}
};
@@ -1596,7 +1596,7 @@ parseMmaOperand(OpAsmParser &parser, StringRef operandName,
// Helper to process operand fragments and determine which attributes can be
// inferred
template <typename Op>
-void processOperandFragments(Op &op, std::array<OperandFragment, 3> &frags,
+void processOperandFragments(Op &op, std::array<MMAOperandFragment, 3> &frags,
SmallVectorImpl<Type> &regTypes,
SmallVectorImpl<StringRef> &ignoreAttrNames) {
for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
@@ -1703,10 +1703,10 @@ MMATypes inferPtxTypeFromResult(OpTy op) {
void MmaBlockScaleOp::print(OpAsmPrinter &p) {
SmallVector<Type, 4> regTypes;
- std::array<OperandFragment, 3> frags{
- OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
- OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
- OperandFragment("C", "")};
+ std::array<MMAOperandFragment, 3> frags{
+ MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
+ MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
+ MMAOperandFragment("C", "")};
SmallVector<StringRef, 4> ignoreAttrNames{
mlir::NVVM::MmaBlockScaleOp::getOperandSegmentSizeAttr()};
@@ -1915,8 +1915,9 @@ LogicalResult MmaBlockScaleOp::verify() {
getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
"attributes for mma.m16n8k64.mxf4nvf4");
- } else
+ } else {
result = emitOpError("unsupported Kind attribute for mma.m16n8k64");
+ }
} else if (m == 16 && n == 8 && k == 32) {
if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
@@ -1924,8 +1925,9 @@ LogicalResult MmaBlockScaleOp::verify() {
result =
emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
"attributes for mma.m16n8k32");
- } else
+ } else {
result = emitOpError("unsupported Geom for mma with block scaling");
+ }
return result;
}
@@ -1935,10 +1937,10 @@ LogicalResult MmaBlockScaleOp::verify() {
void MmaSpBlockScaleOp::print(OpAsmPrinter &p) {
SmallVector<Type, 4> regTypes;
- std::array<OperandFragment, 3> frags{
- OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
- OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
- OperandFragment("C", "")};
+ std::array<MMAOperandFragment, 3> frags{
+ MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
+ MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
+ MMAOperandFragment("C", "")};
SmallVector<StringRef, 4> ignoreAttrNames{
mlir::NVVM::MmaSpBlockScaleOp::getOperandSegmentSizeAttr()};
@@ -2189,8 +2191,9 @@ LogicalResult MmaSpBlockScaleOp::verify() {
getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
"attributes for mma.m16n8k128.mxf4nvf4");
- } else
+ } else {
result = emitOpError("unsupported Kind attribute for mma.m16n8k128");
+ }
} else if (m == 16 && n == 8 && k == 64) {
if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
@@ -2198,8 +2201,9 @@ LogicalResult MmaSpBlockScaleOp::verify() {
result =
emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
"attributes for mma.m16n8k64");
- } else
+ } else {
result = emitOpError("unsupported Geom for sparse mma with block scaling");
+ }
return result;
}
@@ -5046,7 +5050,7 @@ mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
auto kind = thisOp.getKind();
auto blockScale = thisOp.getBlockScale();
llvm::Intrinsic::ID ID = [&]() {
- if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) {
+ if (kind == NVVM::MMABlockScaleKind::MXF8F6F4) {
if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
return isATensor ? llvm::Intrinsic::
nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
@@ -5059,7 +5063,7 @@ mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
: llvm::Intrinsic::
nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
}
- } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) {
+ } else if (kind == NVVM::MMABlockScaleKind::MXF4) {
if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
return isATensor
? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
@@ -5070,7 +5074,7 @@ mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
: llvm::Intrinsic::
nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
}
- } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) {
+ } else if (kind == NVVM::MMABlockScaleKind::MXF4NVF4) {
if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
return isATensor
? llvm::Intrinsic::
@@ -5094,16 +5098,16 @@ mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
static LogicalResult
verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp,
- NVVM::Tcgen05MMABlockScaleKind kind,
+ NVVM::MMABlockScaleKind kind,
NVVM::Tcgen05MMABlockScale blockScale,
Location loc) {
if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
- kind == Tcgen05MMABlockScaleKind::MXF4NVF4)
+ kind == MMABlockScaleKind::MXF4NVF4)
return emitError(loc, "mxf4nvf4 requires block scale attribute");
if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
- kind != Tcgen05MMABlockScaleKind::MXF4NVF4)
+ kind != MMABlockScaleKind::MXF4NVF4)
return emitError(loc,
llvm::formatv("{} kind does not support block16 attribute",
stringifyEnum(kind)));
@@ -5146,7 +5150,7 @@ mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
auto kind = thisOp.getKind();
auto blockScale = thisOp.getBlockScale();
llvm::Intrinsic::ID ID = [&]() {
- if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) {
+ if (kind == NVVM::MMABlockScaleKind::MXF8F6F4) {
if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
return isATensor ? llvm::Intrinsic::
nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale
@@ -5159,7 +5163,7 @@ mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
: llvm::Intrinsic::
nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32;
}
- } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) {
+ } else if (kind == NVVM::MMABlockScaleKind::MXF4) {
if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
return isATensor ? llvm::Intrinsic::
nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale
@@ -5172,7 +5176,7 @@ mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
: llvm::Intrinsic::
nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32;
}
- } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) {
+ } else if (kind == NVVM::MMABlockScaleKind::MXF4NVF4) {
if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
return isATensor
? llvm::Intrinsic::
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir
index db4574bfaf78f..9f7dd3ed4b6b4 100644
--- a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir
@@ -5,35 +5,35 @@ llvm.func @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -43,35 +43,35 @@ llvm.func @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -81,35 +81,35 @@ llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_de
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -119,35 +119,35 @@ llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_de
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -157,35 +157,35 @@ llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -195,35 +195,35 @@ llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-tensor.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-tensor.mlir
index a15c3fb73de9c..7f0376951a047 100644
--- a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-tensor.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-tensor.mlir
@@ -5,35 +5,35 @@ llvm.func @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -43,35 +43,35 @@ llvm.func @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -81,35 +81,35 @@ llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_tm
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -119,35 +119,35 @@ llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_tm
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -157,35 +157,35 @@ llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -195,35 +195,35 @@ llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-invalid.mlir
index f46b35a910fd9..7b6b6c24c0180 100644
--- a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-invalid.mlir
@@ -44,7 +44,7 @@ llvm.func @nvvm_tcgen05_mma_ashift(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>
llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_default(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scalea: !llvm.ptr<6>, %scaleb: !llvm.ptr<6>) {
// expected-error @below {{mxf4nvf4 requires block scale attribute}}
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scalea, %scaleb
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -54,7 +54,7 @@ llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_default(%d_tmem : !llvm.ptr<6>,
llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_default(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scalea: !llvm.ptr<6>, %scaleb: !llvm.ptr<6>) {
// expected-error @below {{mxf4 kind does not support block16 attribute}}
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scalea, %scaleb
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, ashift, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, ashift, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -104,7 +104,7 @@ llvm.func @nvvm_tcgen05_mma_sp_ashift(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr
llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_default(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scalea: !llvm.ptr<6>, %scaleb: !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) {
// expected-error @below {{mxf4nvf4 requires block scale attribute}}
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scalea, %scaleb
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -114,6 +114,6 @@ llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_default(%d_tmem : !llvm.ptr<
llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_default(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scalea: !llvm.ptr<6>, %scaleb: !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) {
// expected-error @below {{mxf4 kind does not support block16 attribute}}
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scalea, %scaleb
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, ashift, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, ashift, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-shared.mlir
index 5c7eabee71b4e..7cd0989d07b8b 100644
--- a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-shared.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-shared.mlir
@@ -5,35 +5,35 @@ llvm.func @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -43,35 +43,35 @@ llvm.func @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -81,35 +81,35 @@ llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -119,35 +119,35 @@ llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -157,35 +157,35 @@ llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -195,35 +195,35 @@ llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-tensor.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-tensor.mlir
index 3200411aee213..717f760589b70 100644
--- a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-tensor.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-tensor.mlir
@@ -5,35 +5,35 @@ llvm.func @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -43,35 +43,35 @@ llvm.func @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -81,35 +81,35 @@ llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -119,35 +119,35 @@ llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -157,35 +157,35 @@ llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -195,35 +195,35 @@ llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3)
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b
- {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
>From b87ec837844dd886a37385270a1f32b61ad9fe6c Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Tue, 9 Dec 2025 15:34:05 +0100
Subject: [PATCH 05/11] [MLIR] Fixes for code formatting
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 8 +++-----
1 file changed, 3 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 764ebafa14a1c..68150afc9ede7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -5096,11 +5096,9 @@ mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
return {ID, args};
}
-static LogicalResult
-verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp,
- NVVM::MMABlockScaleKind kind,
- NVVM::Tcgen05MMABlockScale blockScale,
- Location loc) {
+static LogicalResult verifyTcgen05MMABlockScaleOp(
+ NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::MMABlockScaleKind kind,
+ NVVM::Tcgen05MMABlockScale blockScale, Location loc) {
if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
kind == MMABlockScaleKind::MXF4NVF4)
>From d78c4b996bd2a41fcab652c35cce5a9034bc7417 Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Wed, 10 Dec 2025 06:13:12 +0100
Subject: [PATCH 06/11] [MLIR] More updates for PR170566
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 7 -------
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 20 ++------------------
2 files changed, 2 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index e78300ee00f60..3659d7deb0910 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3698,10 +3698,6 @@ class NVVM_MmaBlockScaleBase<string mnemonic, list<Trait> traits = []> :
return 0;
}
- // Common declarations - implementations in NVVMDialect.cpp
- MMATypes accumPtxType();
- MMATypes resultPtxType();
-
static mlir::NVVM::IDArgPair
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
llvm::IRBuilderBase& builder);
@@ -3848,9 +3844,6 @@ def NVVM_MmaSpBlockScaleOp : NVVM_MmaBlockScaleBase<"mma.sp.block_scale"> {
return 0;
}
- MMATypes accumPtxType();
- MMATypes resultPtxType();
-
static mlir::NVVM::IDArgPair
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
llvm::IRBuilderBase& builder);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 68150afc9ede7..0d0ef4ef43dc4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1852,14 +1852,6 @@ void MmaBlockScaleOp::build(
}));
}
-MMATypes MmaBlockScaleOp::accumPtxType() {
- return inferPtxTypeFromResult(*this);
-}
-
-MMATypes MmaBlockScaleOp::resultPtxType() {
- return inferPtxTypeFromResult(*this);
-}
-
NVVM::IDArgPair MmaBlockScaleOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto curOp = cast<NVVM::MmaBlockScaleOp>(op);
@@ -1884,7 +1876,7 @@ NVVM::IDArgPair MmaBlockScaleOp::getIntrinsicIDAndArgs(
unsigned intId = MmaBlockScaleOp::getIntrinsicID(
curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
*curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
- curOp.accumPtxType(), curOp.getScaleVecSize(),
+ inferPtxTypeFromResult(curOp), curOp.getScaleVecSize(),
curOp.getBlockScaleFormat(), curOp.getKind());
return {intId, args};
@@ -2119,14 +2111,6 @@ void MmaSpBlockScaleOp::build(
}));
}
-MMATypes MmaSpBlockScaleOp::accumPtxType() {
- return inferPtxTypeFromResult(*this);
-}
-
-MMATypes MmaSpBlockScaleOp::resultPtxType() {
- return inferPtxTypeFromResult(*this);
-}
-
NVVM::IDArgPair MmaSpBlockScaleOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto curOp = cast<NVVM::MmaSpBlockScaleOp>(op);
@@ -2155,7 +2139,7 @@ NVVM::IDArgPair MmaSpBlockScaleOp::getIntrinsicIDAndArgs(
unsigned intId = MmaSpBlockScaleOp::getIntrinsicID(
curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
*curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
- curOp.accumPtxType(), curOp.getScaleVecSize(),
+ inferPtxTypeFromResult(curOp), curOp.getScaleVecSize(),
curOp.getBlockScaleFormat(), curOp.getKind());
return {intId, args};
>From 23d75e119eee07581ae9324d259ad3af425a983e Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Wed, 10 Dec 2025 12:22:17 +0100
Subject: [PATCH 07/11] [MLIR][NVVM] Updates according to feedback from
PR170566
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 37 ++++++++++++----------
1 file changed, 21 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 0d0ef4ef43dc4..2c2e62b130203 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1627,14 +1627,19 @@ LogicalResult parseMmaTypeSignature(OpAsmParser &parser,
if (parser.parseColon().failed() || parser.parseLParen().failed())
return failure();
- for (int i = 0; i < 3; i++) {
- if (i > 0 && parser.parseComma().failed())
- return failure();
+ auto typeParser = [&]() {
Type ty;
if (parser.parseType(ty).failed())
return failure();
operandTypes.push_back(ty);
- }
+ return success();
+ };
+ if (parser.parseCommaSeparatedList(typeParser))
+ return failure();
+
+ if (operandTypes.size() != 3)
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected exactly 3 types");
return parser.parseRParen();
}
@@ -1656,18 +1661,18 @@ void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs,
}
}
-// Helper to add common block scale attributes
-void addBlockScaleAttributes(OpBuilder &builder, OperationState &result,
+// Helper to add common block scale properties
+template <typename OpType>
+void addBlockScaleProperties(OpBuilder &builder, OperationState &result,
ArrayRef<int64_t> shape, ScaleVecSize scaleVecSize,
BlockScaleFormat blockScaleFormat,
MMABlockScaleKind kind) {
MLIRContext *ctx = builder.getContext();
- result.addAttribute(
- "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
- result.addAttribute("scaleVecSize", ScaleVecSizeAttr::get(ctx, scaleVecSize));
- result.addAttribute("blockScaleFormat",
- BlockScaleFormatAttr::get(ctx, blockScaleFormat));
- result.addAttribute("kind", MMABlockScaleKindAttr::get(ctx, kind));
+ auto &properties = result.getOrAddProperties<typename OpType::Properties>();
+ properties.setShape(builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
+ properties.setScaleVecSize(ScaleVecSizeAttr::get(ctx, scaleVecSize));
+ properties.setBlockScaleFormat(BlockScaleFormatAttr::get(ctx, blockScaleFormat));
+ properties.setKind(MMABlockScaleKindAttr::get(ctx, kind));
}
// Helper to infer and add multiplicand PTX types to builder
@@ -1825,8 +1830,8 @@ void MmaBlockScaleOp::build(
MMABlockScaleKind kind) {
assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
- addBlockScaleAttributes(builder, result, shape, scaleVecSize,
- blockScaleFormat, kind);
+ addBlockScaleProperties<MmaBlockScaleOp>(builder, result, shape, scaleVecSize,
+ blockScaleFormat, kind);
result.addOperands(operandA);
result.addOperands(operandB);
@@ -2081,8 +2086,8 @@ void MmaSpBlockScaleOp::build(
MMABlockScaleKind kind) {
assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
- addBlockScaleAttributes(builder, result, shape, scaleVecSize,
- blockScaleFormat, kind);
+ addBlockScaleProperties<MmaSpBlockScaleOp>(builder, result, shape, scaleVecSize,
+ blockScaleFormat, kind);
result.addAttribute("orderedMetadata", builder.getUnitAttr());
result.addOperands(operandA);
>From 6267a79592175b48fa7c819bec80423b518a35b8 Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Wed, 10 Dec 2025 12:44:47 +0100
Subject: [PATCH 08/11] [MLIR][NVVM] Code formatting was fixed.
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 2c2e62b130203..5d06ba057164f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1669,9 +1669,11 @@ void addBlockScaleProperties(OpBuilder &builder, OperationState &result,
MMABlockScaleKind kind) {
MLIRContext *ctx = builder.getContext();
auto &properties = result.getOrAddProperties<typename OpType::Properties>();
- properties.setShape(builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
+ properties.setShape(
+ builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
properties.setScaleVecSize(ScaleVecSizeAttr::get(ctx, scaleVecSize));
- properties.setBlockScaleFormat(BlockScaleFormatAttr::get(ctx, blockScaleFormat));
+ properties.setBlockScaleFormat(
+ BlockScaleFormatAttr::get(ctx, blockScaleFormat));
properties.setKind(MMABlockScaleKindAttr::get(ctx, kind));
}
@@ -1830,8 +1832,8 @@ void MmaBlockScaleOp::build(
MMABlockScaleKind kind) {
assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
- addBlockScaleProperties<MmaBlockScaleOp>(builder, result, shape, scaleVecSize,
- blockScaleFormat, kind);
+ addBlockScaleProperties<MmaBlockScaleOp>(
+ builder, result, shape, scaleVecSize, blockScaleFormat, kind);
result.addOperands(operandA);
result.addOperands(operandB);
>From 89b342047d518ce8514255b98402740285fa4e50 Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Wed, 10 Dec 2025 12:57:39 +0100
Subject: [PATCH 09/11] [MLIR][NVVM] One more fix for code formatting
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 5d06ba057164f..7541125189aa9 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1832,8 +1832,8 @@ void MmaBlockScaleOp::build(
MMABlockScaleKind kind) {
assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
- addBlockScaleProperties<MmaBlockScaleOp>(
- builder, result, shape, scaleVecSize, blockScaleFormat, kind);
+ addBlockScaleProperties<MmaBlockScaleOp>(builder, result, shape, scaleVecSize,
+ blockScaleFormat, kind);
result.addOperands(operandA);
result.addOperands(operandB);
@@ -2088,8 +2088,8 @@ void MmaSpBlockScaleOp::build(
MMABlockScaleKind kind) {
assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
- addBlockScaleProperties<MmaSpBlockScaleOp>(builder, result, shape, scaleVecSize,
- blockScaleFormat, kind);
+ addBlockScaleProperties<MmaSpBlockScaleOp>(
+ builder, result, shape, scaleVecSize, blockScaleFormat, kind);
result.addAttribute("orderedMetadata", builder.getUnitAttr());
result.addOperands(operandA);
>From ae5f53d01cb330f31661430244af4373353b113c Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Wed, 10 Dec 2025 16:02:32 +0100
Subject: [PATCH 10/11] [MLIR][NVVM] Fixes for static and namespace usage
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 38 ++++++++++++----------
1 file changed, 21 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7541125189aa9..254e929d39aa3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1572,17 +1572,18 @@ struct MMAOperandFragment {
explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
: operandName(name), ptxTypeAttr(ptxTypeName) {}
};
+} // namespace
// Helper to print operand list in the format: name[operands]
-void printOperandList(OpAsmPrinter &p, StringRef name,
- ArrayRef<Value> operands) {
+static void printOperandList(OpAsmPrinter &p, StringRef name,
+ ArrayRef<Value> operands) {
p << " " << name << "[";
p.printOperands(operands);
p << "]";
}
// Helper to parse operand list in the format: name[operands]
-LogicalResult
+static LogicalResult
parseMmaOperand(OpAsmParser &parser, StringRef operandName,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &regs) {
if (parser.parseKeyword(operandName).failed())
@@ -1596,9 +1597,10 @@ parseMmaOperand(OpAsmParser &parser, StringRef operandName,
// Helper to process operand fragments and determine which attributes can be
// inferred
template <typename Op>
-void processOperandFragments(Op &op, std::array<MMAOperandFragment, 3> &frags,
- SmallVectorImpl<Type> &regTypes,
- SmallVectorImpl<StringRef> &ignoreAttrNames) {
+static void
+processOperandFragments(Op &op, std::array<MMAOperandFragment, 3> &frags,
+ SmallVectorImpl<Type> &regTypes,
+ SmallVectorImpl<StringRef> &ignoreAttrNames) {
for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
auto &frag = frags[fragIdx];
auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
@@ -1622,8 +1624,9 @@ void processOperandFragments(Op &op, std::array<MMAOperandFragment, 3> &frags,
}
// Helper to parse type signature: (A_type, B_type, C_type)
-LogicalResult parseMmaTypeSignature(OpAsmParser &parser,
- SmallVectorImpl<Type> &operandTypes) {
+static LogicalResult
+parseMmaTypeSignature(OpAsmParser &parser,
+ SmallVectorImpl<Type> &operandTypes) {
if (parser.parseColon().failed() || parser.parseLParen().failed())
return failure();
@@ -1645,8 +1648,9 @@ LogicalResult parseMmaTypeSignature(OpAsmParser &parser,
}
// Helper to infer and set multiplicand PTX type attributes
-void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs,
- const SmallVectorImpl<Type> &operandTypes) {
+static void
+inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs,
+ const SmallVectorImpl<Type> &operandTypes) {
if (!attrs.get("multiplicandAPtxType")) {
if (auto inferredType =
MmaOp::inferOperandMMAType(operandTypes[0], false)) {
@@ -1663,10 +1667,11 @@ void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs,
// Helper to add common block scale properties
template <typename OpType>
-void addBlockScaleProperties(OpBuilder &builder, OperationState &result,
- ArrayRef<int64_t> shape, ScaleVecSize scaleVecSize,
- BlockScaleFormat blockScaleFormat,
- MMABlockScaleKind kind) {
+static void addBlockScaleProperties(OpBuilder &builder, OperationState &result,
+ ArrayRef<int64_t> shape,
+ ScaleVecSize scaleVecSize,
+ BlockScaleFormat blockScaleFormat,
+ MMABlockScaleKind kind) {
MLIRContext *ctx = builder.getContext();
auto &properties = result.getOrAddProperties<typename OpType::Properties>();
properties.setShape(
@@ -1678,7 +1683,7 @@ void addBlockScaleProperties(OpBuilder &builder, OperationState &result,
}
// Helper to infer and add multiplicand PTX types to builder
-void addInferredMultiplicandTypes(
+static void addInferredMultiplicandTypes(
MLIRContext *ctx, OperationState &result, ValueRange operandA,
ValueRange operandB,
std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
@@ -1697,12 +1702,11 @@ void addInferredMultiplicandTypes(
// Template helper for common accumPtxType/resultPtxType implementation
template <typename OpTy>
-MMATypes inferPtxTypeFromResult(OpTy op) {
+static MMATypes inferPtxTypeFromResult(OpTy op) {
return *MmaOp::inferOperandMMAType(
cast<LLVM::LLVMStructType>(op.getRes().getType()).getBody()[0],
/*isAccumulator=*/true);
}
-} // namespace
//===----------------------------------------------------------------------===//
// MmaBlockScaleOp
>From 12afe48b94a1633dc867b79875468df0325e2eb9 Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Thu, 11 Dec 2025 20:43:22 +0100
Subject: [PATCH 11/11] [MLIR][NVVM] A typo fixed in description (e5m3 should
be e5m2).
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 3659d7deb0910..ff82a99d498d4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3492,7 +3492,7 @@ def MMABlockScaleKindAttr : EnumAttr<NVVM_Dialect, MMABlockScaleKind, "block_sca
+--------------+-------------------------------------------+
| Matrix Kind | supported types for A / B |
+--------------+-------------------------------------------+
- | mxf8f6f4 | e4m3, e5m3, e2m3, e3m2, e2m1 |
+ | mxf8f6f4 | e4m3, e5m2, e2m3, e3m2, e2m1 |
| mxf4 | e2m1 |
| mxf4nvf4 | e2m1 |
+--------------+-------------------------------------------+
More information about the Mlir-commits
mailing list