[llvm] [mlir] [MLIR][NVVM][NVPTX] Support for new mma/mma.sp variants from PTX 9.1 (PR #182325)
Kirill Vedernikov via llvm-commits
llvm-commits at lists.llvm.org
Fri Feb 20 01:18:28 PST 2026
https://github.com/kvederni updated https://github.com/llvm/llvm-project/pull/182325
>From 52f6c54887409810f1e10edc8b92d6afe92fb84b Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Thu, 19 Feb 2026 18:30:10 +0100
Subject: [PATCH 1/2] [MLIR][NVVM][NVPTX] Support for new mma/mma.sp variants
from PTX 9.1 Updated MLIR mma/mma.sp block scale tests with struct instead of
vector
---
llvm/include/llvm/IR/IntrinsicsNVVM.td | 6 +-
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 39 +-
llvm/lib/Target/NVPTX/NVPTXSubtarget.h | 8 +
llvm/test/CodeGen/NVPTX/wmma-ptx91-sm120a.py | 12 +
llvm/test/CodeGen/NVPTX/wmma-ptx91-sm120f.py | 12 +
llvm/test/CodeGen/NVPTX/wmma.py | 26 +-
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 6 +-
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 6 +-
.../Dialect/LLVMIR/nvvm-mma-blockscale.mlir | 385 ++++++++------
.../LLVMIR/nvvm-mma-sparse-blockscale.mlir | 474 +++++++++++-------
10 files changed, 649 insertions(+), 325 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/wmma-ptx91-sm120a.py
create mode 100644 llvm/test/CodeGen/NVPTX/wmma-ptx91-sm120f.py
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index f2e1bcb5517c8..f0c3c32086954 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -973,7 +973,8 @@ class NVVM_MMA_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind, string
!and(!eq(geom, "m16n8k64"),
!eq(kind, "mxf4nvf4"),
!eq(scale_vec_size, ".scale_4x"),
- !eq(stype, "ue4m3")) : true,
+ !or(!eq(stype, "ue4m3"),
+ !eq(stype, "ue8m0"))) : true,
!and(!eq(geom, "m16n8k32"),
!eq(kind, "mxf8f6f4"),
!or(!eq(scale_vec_size, ""),
@@ -1111,7 +1112,8 @@ class NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
!and(!eq(geom, "m16n8k128"),
!eq(kind, "mxf4nvf4"),
- !eq(stype, "ue4m3"),
+ !or(!eq(stype, "ue4m3"),
+ !eq(stype, "ue8m0")),
!eq(scale_vec_size, ".scale_4x")): true,
!and(!eq(geom, "m16n8k64"),
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 7b7b11d14ecc8..102ce8cc37236 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4936,7 +4936,8 @@ def INT_PTX_SREG_WARPSIZE :
// In addition to target-independent fields provided by WMMA_REGS, it adds
// the fields commonly used to implement specific PTX instruction -- register
// types and names, constraints, parts of assembly, etc.
-class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "">
+class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "",
+ string kind = "", string stype = "", string scale = "">
: WMMA_REGS<r.geom, r.frag, r.ptx_elt_type,
!or(!eq(op, "mma.sp"), !eq(op, "mma.sp.block_scale"))> {
// NVPTX register types used to carry fragment data.
@@ -4978,6 +4979,18 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
// longer the case, we can concat all per-fragment predicates to enforce that
// all fragments of the instruction are viable.
list<Predicate> Predicates = !cond(
+ !and(!eq(op, "mma.block_scale"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(stype, "ue8m0"),
+ !eq(scale, ".scale_4x"))
+ : [callSubtarget<"hasMMABlockScaleNF4X4E8">],
+
+ !and(!eq(op, "mma.sp.block_scale"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(stype, "ue8m0"),
+ !eq(scale, ".scale_4x"))
+ : [callSubtarget<"hasMMASparseBlockScaleNF4X4E8">],
+
!and(!eq(op, "mma.sp.block_scale"),
!or(!eq(kind, "mxf4nvf4"),
!eq(kind, "mxf4"))) : [callSubtarget<"hasMMASparseBlockScaleF4">],
@@ -5396,10 +5409,14 @@ defset list<WMMA_INSTR> MMA_BLOCK_SCALEs = {
foreach stype = ["ue8m0", "ue4m3"] in {
foreach op = NVVM_MMA_OPS.all_mma_block_scale_ops in {
if NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
- def : MMA_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.block_scale", "", kind>,
- WMMA_REGINFO<op[1], "mma.block_scale", "", kind>,
- WMMA_REGINFO<op[2], "mma.block_scale", "", kind>,
- WMMA_REGINFO<op[3], "mma.block_scale", "", kind>,
+ def : MMA_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.block_scale", "", kind,
+ stype, scale_vec_size>,
+ WMMA_REGINFO<op[1], "mma.block_scale", "", kind,
+ stype, scale_vec_size>,
+ WMMA_REGINFO<op[2], "mma.block_scale", "", kind,
+ stype, scale_vec_size>,
+ WMMA_REGINFO<op[3], "mma.block_scale", "", kind,
+ stype, scale_vec_size>,
kind, stype, scale_vec_size>;
}
} // op
@@ -5518,10 +5535,14 @@ defset list<WMMA_INSTR> MMA_SP_BLOCK_SCALEs = {
foreach stype = ["ue8m0", "ue4m3"] in {
foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
- def : MMA_SP_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
- WMMA_REGINFO<op[1], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
- WMMA_REGINFO<op[2], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
- WMMA_REGINFO<op[3], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
+ def : MMA_SP_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.sp.block_scale", "sp::ordered_metadata", kind,
+ stype, scale_vec_size>,
+ WMMA_REGINFO<op[1], "mma.sp.block_scale", "sp::ordered_metadata", kind,
+ stype, scale_vec_size>,
+ WMMA_REGINFO<op[2], "mma.sp.block_scale", "sp::ordered_metadata", kind,
+ stype, scale_vec_size>,
+ WMMA_REGINFO<op[3], "mma.sp.block_scale", "sp::ordered_metadata", kind,
+ stype, scale_vec_size>,
kind, stype, scale_vec_size>;
}
} // op
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index 8286fa9f9559b..bbe4e2f44f4e6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -184,6 +184,14 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
return hasPTXWithAccelSMs(87, {120, 121});
}
+ bool hasMMABlockScaleNF4X4E8() const {
+ return hasPTXWithFamilySMs(91, {120});
+ }
+
+ bool hasMMASparseBlockScaleNF4X4E8() const {
+ return hasPTXWithAccelSMs(91, {120, 121});
+ }
+
// f32x2 instructions in Blackwell family
bool hasF32x2Instructions() const;
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx91-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx91-sm120a.py
new file mode 100644
index 0000000000000..090b81c9de8ab
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx91-sm120a.py
@@ -0,0 +1,12 @@
+# Check all variants of instructions supported by PTX91 on SM120a
+# RUN: %python %s --ptx=91 --gpu-arch=120a > %t-ptx91-sm_120a.ll
+# RUN: llc < %t-ptx91-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx91 \
+# RUN: | FileCheck %t-ptx91-sm_120a.ll
+# RUN: %if ptxas-sm_120a && ptxas-isa-9.1 %{ \
+# RUN: llc < %t-ptx91-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx91 \
+# RUN: | %ptxas-verify -arch=sm_120a \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx91-sm120f.py b/llvm/test/CodeGen/NVPTX/wmma-ptx91-sm120f.py
new file mode 100644
index 0000000000000..df5d900f9a92d
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx91-sm120f.py
@@ -0,0 +1,12 @@
+# Check all variants of instructions supported by PTX91 on SM120f
+# RUN: %python %s --ptx=91 --gpu-arch=120f > %t-ptx91-sm_120f.ll
+# RUN: llc < %t-ptx91-sm_120f.ll -mtriple=nvptx64 -mcpu=sm_120f -mattr=+ptx91 \
+# RUN: | FileCheck %t-ptx91-sm_120f.ll
+# RUN: %if ptxas-sm_120f && ptxas-isa-9.1 %{ \
+# RUN: llc < %t-ptx91-sm_120f.ll -mtriple=nvptx64 -mcpu=sm_120f -mattr=+ptx91 \
+# RUN: | %ptxas-verify -arch=sm_120f \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 134105e970146..2bd0796c68f52 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -1175,6 +1175,16 @@ def is_mma_block_scale_variant_supported(op, kind, scale_vec_size, stype):
):
return False
+ if (
+ kind == "mxf4nvf4"
+ and scale_vec_size == ".scale_vec::4X"
+ and stype == "ue8m0"
+ and not (
+ ptx_version >= 91 and sm_version == 120 and has_family_specific_features()
+ )
+ ):
+ return False
+
if (
op.a.geom == "m16n8k64"
and kind == "mxf4"
@@ -1194,7 +1204,7 @@ def is_mma_block_scale_variant_supported(op, kind, scale_vec_size, stype):
if (
op.a.geom == "m16n8k64"
and kind == "mxf4nvf4"
- and stype == "ue4m3"
+ and stype in ["ue4m3", "ue8m0"]
and scale_vec_size == ".scale_vec::4X"
):
return True
@@ -1560,6 +1570,18 @@ def is_mma_sp_block_scale_variant_supported(op, kind, scale_vec_size, stype):
):
return False
+ if (
+ kind == "mxf4nvf4"
+ and scale_vec_size == ".scale_vec::4X"
+ and stype == "ue8m0"
+ and not (
+ ptx_version >= 91
+ and (sm_version == 120 or sm_version == 121)
+ and has_arch_accel_features()
+ )
+ ):
+ return False
+
if (
op.a.geom == "m16n8k128"
and kind == "mxf4"
@@ -1579,7 +1601,7 @@ def is_mma_sp_block_scale_variant_supported(op, kind, scale_vec_size, stype):
if (
op.a.geom == "m16n8k128"
and kind == "mxf4nvf4"
- and stype == "ue4m3"
+ and stype in ["ue4m3", "ue8m0"]
and scale_vec_size == ".scale_vec::4X"
):
return True
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 9725cdb6a4c4d..74c5a3fa7f1dd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3498,7 +3498,8 @@ class NVVM_MMA_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
!and(!eq(geom, "m16n8k64"),
!eq(kind, "mxf4nvf4"),
!eq(scale_vec_size, ".scale_4x"),
- !eq(stype, "ue4m3")) : true,
+ !or(!eq(stype, "ue4m3"),
+ !eq(stype, "ue8m0"))) : true,
!and(!eq(geom, "m16n8k32"),
!eq(kind, "mxf8f6f4"),
!or(!eq(scale_vec_size, ""),
@@ -3525,7 +3526,8 @@ class NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
!eq(scale_vec_size, ".scale_2x")): true,
!and(!eq(geom, "m16n8k128"),
!eq(kind, "mxf4nvf4"),
- !eq(stype, "ue4m3"),
+ !or(!eq(stype, "ue4m3"),
+ !eq(stype, "ue8m0")),
!eq(scale_vec_size, ".scale_4x")): true,
!and(!eq(geom, "m16n8k64"),
!eq(kind, "mxf8f6f4"),
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index b51eb3408df00..ddab6ae409f82 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1917,7 +1917,8 @@ LogicalResult MmaBlockScaleOp::verify() {
if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
(getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
- getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
+ (getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3 ||
+ getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))))
result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
"attributes for mma.m16n8k64.mxf4nvf4");
} else {
@@ -2185,7 +2186,8 @@ LogicalResult MmaSpBlockScaleOp::verify() {
if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
(getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
- getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
+ (getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3 ||
+ getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))))
result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
"attributes for mma.m16n8k128.mxf4nvf4");
} else {
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-blockscale.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-blockscale.mlir
index fbd0203d19904..0785e5b124b45 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm-mma-blockscale.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-blockscale.mlir
@@ -13,11 +13,14 @@
// =============================================================================
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m1
-func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m1(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m1(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -26,16 +29,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m1(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m3
-func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -44,16 +50,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m3(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e3m2
-func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e3m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e3m2(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -62,16 +71,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e3m2(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e4m3
-func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e4m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -80,16 +92,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e4m3(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e5m2
-func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e5m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e5m2(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -98,16 +113,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e5m2(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m1
-func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m1(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m1(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -116,16 +134,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m1(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m3
-func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -134,16 +155,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m3(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e3m2
-func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e3m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e3m2(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -152,16 +176,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e3m2(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e4m3
-func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e4m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -170,16 +197,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e4m3(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e5m2
-func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e5m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e5m2(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -188,16 +218,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e5m2(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m1
-func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m1(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m1(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -206,16 +239,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m1(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m3
-func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -224,16 +260,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m3(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e3m2
-func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e3m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e3m2(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -242,16 +281,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e3m2(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e4m3
-func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e4m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -260,16 +302,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e4m3(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e5m2
-func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e5m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e5m2(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -278,16 +323,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e5m2(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m1
-func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m1(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m1(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -296,16 +344,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m1(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m3
-func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -314,16 +365,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m3(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e3m2
-func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e3m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e3m2(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -332,16 +386,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e3m2(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e4m3
-func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e4m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -350,16 +407,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e4m3(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e5m2
-func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e5m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e5m2(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -368,16 +428,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e5m2(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m1
-func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m1(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m1(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -386,16 +449,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m1(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m3
-func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -404,16 +470,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m3(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e3m2
-func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e3m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e3m2(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -422,16 +491,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e3m2(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e4m3
-func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e4m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -440,16 +512,19 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e4m3(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e5m2
-func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e5m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e5m2(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>,
@@ -458,8 +533,8 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e5m2(%a: vector<4xi32>, %b: vector<
scaleVecSize = #nvvm.scale_vec_size<x1>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// =============================================================================
@@ -467,11 +542,14 @@ func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e5m2(%a: vector<4xi32>, %b: vector<
// =============================================================================
// CHECK-LABEL: @nvvm_mxf4_blockscale_mma
-func.func @nvvm_mxf4_blockscale_mma(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf4_blockscale_mma(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 64>,
@@ -480,8 +558,8 @@ func.func @nvvm_mxf4_blockscale_mma(%a: vector<4xi32>, %b: vector<2xi32>, %c: ve
scaleVecSize = #nvvm.scale_vec_size<x2>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// =============================================================================
@@ -489,11 +567,14 @@ func.func @nvvm_mxf4_blockscale_mma(%a: vector<4xi32>, %b: vector<2xi32>, %c: ve
// =============================================================================
// CHECK-LABEL: @nvvm_mxf4nvf4_blockscale_mma_ue8m0
-func.func @nvvm_mxf4nvf4_blockscale_mma_ue8m0(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf4nvf4_blockscale_mma_ue8m0(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 64>,
@@ -502,16 +583,19 @@ func.func @nvvm_mxf4nvf4_blockscale_mma_ue8m0(%a: vector<4xi32>, %b: vector<2xi3
scaleVecSize = #nvvm.scale_vec_size<x2>,
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf4nvf4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf4nvf4_blockscale_mma_ue4m3
-func.func @nvvm_mxf4nvf4_blockscale_mma_ue4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>,
+func.func @nvvm_mxf4nvf4_blockscale_mma_ue4m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
scaleA[%scaleAData, %byteIdA, %threadIdA]
scaleB[%scaleBData, %byteIdB, %threadIdB]
{shape = #nvvm.shape<m = 16, n = 8, k = 64>,
@@ -520,6 +604,27 @@ func.func @nvvm_mxf4nvf4_blockscale_mma_ue4m3(%a: vector<4xi32>, %b: vector<2xi3
scaleVecSize = #nvvm.scale_vec_size<x4>,
blockScaleFormat = #nvvm.block_scale_format<ue4m3>,
kind = #nvvm.block_scale_kind<mxf4nvf4>}
- : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf4nvf4_blockscale_mma_ue8m0_x4
+func.func @nvvm_mxf4nvf4_blockscale_mma_ue8m0_x4(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x4>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf4nvf4>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sparse-blockscale.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sparse-blockscale.mlir
index 2e72012bcf722..c29a906b49eed 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm-mma-sparse-blockscale.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sparse-blockscale.mlir
@@ -13,12 +13,16 @@
// =============================================================================
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m1
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m1(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m1(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -30,17 +34,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m1(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m3
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -52,17 +60,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m3(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e3m2
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e3m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e3m2(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -74,17 +86,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e3m2(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e4m3
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e4m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -96,17 +112,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e4m3(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e5m2
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e5m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e5m2(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -118,17 +138,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e5m2(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m1
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m1(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m1(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -140,17 +164,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m1(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m3
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -162,17 +190,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m3(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e3m2
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e3m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e3m2(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -184,17 +216,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e3m2(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e4m3
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e4m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -206,17 +242,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e4m3(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e5m2
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e5m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e5m2(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -228,17 +268,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e5m2(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m1
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m1(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m1(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -250,17 +294,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m1(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m3
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -272,17 +320,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m3(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e3m2
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e3m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e3m2(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -294,17 +346,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e3m2(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e4m3
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e4m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -316,17 +372,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e4m3(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e5m2
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e5m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e5m2(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -338,17 +398,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e5m2(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m1
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m1(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m1(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -360,17 +424,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m1(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m3
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -382,17 +450,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m3(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e3m2
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e3m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e3m2(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -404,17 +476,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e3m2(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e4m3
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e4m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -426,17 +502,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e4m3(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e5m2
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e5m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e5m2(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -448,17 +528,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e5m2(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m1
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m1(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m1(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -470,17 +554,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m1(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m3
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -492,17 +580,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m3(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e3m2
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e3m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e3m2(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -514,17 +606,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e3m2(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e4m3
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e4m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -536,17 +632,21 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e4m3(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e5m2
-func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e5m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e5m2(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -558,8 +658,8 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e5m2(%a: vector<4xi32>, %b: vect
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf8f6f4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// =============================================================================
@@ -567,12 +667,16 @@ func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e5m2(%a: vector<4xi32>, %b: vect
// =============================================================================
// CHECK-LABEL: @nvvm_mxf4_sp_blockscale_mma
-func.func @nvvm_mxf4_sp_blockscale_mma(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf4_sp_blockscale_mma(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -584,8 +688,8 @@ func.func @nvvm_mxf4_sp_blockscale_mma(%a: vector<4xi32>, %b: vector<4xi32>, %c:
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// =============================================================================
@@ -593,12 +697,16 @@ func.func @nvvm_mxf4_sp_blockscale_mma(%a: vector<4xi32>, %b: vector<4xi32>, %c:
// =============================================================================
// CHECK-LABEL: @nvvm_mxf4nvf4_sp_blockscale_mma_ue8m0
-func.func @nvvm_mxf4nvf4_sp_blockscale_mma_ue8m0(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf4nvf4_sp_blockscale_mma_ue8m0(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -610,17 +718,21 @@ func.func @nvvm_mxf4nvf4_sp_blockscale_mma_ue8m0(%a: vector<4xi32>, %b: vector<4
blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
kind = #nvvm.block_scale_kind<mxf4nvf4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mxf4nvf4_sp_blockscale_mma_ue4m3
-func.func @nvvm_mxf4nvf4_sp_blockscale_mma_ue4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>,
- %sparseMetadata: i32, %sparsitySelector: i32,
+func.func @nvvm_mxf4nvf4_sp_blockscale_mma_ue4m3(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
%scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
- %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) {
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
- %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%sparseMetadata]
selector[%sparsitySelector]
scaleA[%scaleAData, %byteIdA, %threadIdA]
@@ -632,6 +744,32 @@ func.func @nvvm_mxf4nvf4_sp_blockscale_mma_ue4m3(%a: vector<4xi32>, %b: vector<4
blockScaleFormat = #nvvm.block_scale_format<ue4m3>,
kind = #nvvm.block_scale_kind<mxf4nvf4>,
orderedMetadata}
- : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>
- return
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf4nvf4_sp_blockscale_mma_ue8m0_x4
+func.func @nvvm_mxf4nvf4_sp_blockscale_mma_ue8m0_x4(%a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %sparseMetadata: i32,
+ %scaleAData: i32, %byteIdA: i16, %threadIdA: i16,
+ %scaleBData: i32, %byteIdB: i16, %threadIdB: i16)
+ -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sparsitySelector = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}]
+ %0 = nvvm.mma.sp.block_scale A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%sparseMetadata]
+ selector[%sparsitySelector]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 128>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x4>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf4nvf4>,
+ orderedMetadata}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
>From 17404fd159a2cef47b2976c80e20595e933c8b52 Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Fri, 20 Feb 2026 10:15:33 +0100
Subject: [PATCH 2/2] [NVPTX] renamed predicate functions for newly added
mma/mma.sp
---
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 4 ++--
llvm/lib/Target/NVPTX/NVPTXSubtarget.h | 4 ++--
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 102ce8cc37236..090ee63589fcc 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4983,13 +4983,13 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "",
!eq(kind, "mxf4nvf4"),
!eq(stype, "ue8m0"),
!eq(scale, ".scale_4x"))
- : [callSubtarget<"hasMMABlockScaleNF4X4E8">],
+ : [callSubtarget<"hasMMAWithMXF4NVF4Scale4xE8M0">],
!and(!eq(op, "mma.sp.block_scale"),
!eq(kind, "mxf4nvf4"),
!eq(stype, "ue8m0"),
!eq(scale, ".scale_4x"))
- : [callSubtarget<"hasMMASparseBlockScaleNF4X4E8">],
+ : [callSubtarget<"hasMMASparseWithMXF4NVF4Scale4xE8M0">],
!and(!eq(op, "mma.sp.block_scale"),
!eq(kind, "mxf4nvf4"),
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index bbe4e2f44f4e6..8ff34284e6b16 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -184,11 +184,11 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
return hasPTXWithAccelSMs(87, {120, 121});
}
- bool hasMMABlockScaleNF4X4E8() const {
+ bool hasMMAWithMXF4NVF4Scale4xE8M0() const {
return hasPTXWithFamilySMs(91, {120});
}
- bool hasMMASparseBlockScaleNF4X4E8() const {
+ bool hasMMASparseWithMXF4NVF4Scale4xE8M0() const {
return hasPTXWithAccelSMs(91, {120, 121});
}
More information about the llvm-commits
mailing list