[Mlir-commits] [mlir] [MLIR][NVVM] Fix kFactor for fp8/fp6/fp4 types in MmaSpOp verifier. Improve mma tests. (PR #183133)
Kirill Vedernikov
llvmlistbot at llvm.org
Wed Feb 25 05:14:01 PST 2026
https://github.com/kvederni updated https://github.com/llvm/llvm-project/pull/183133
>From f96887d188d23cb9bde199a04f286f40283da591 Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Tue, 24 Feb 2026 20:08:09 +0100
Subject: [PATCH] [MLIR][NVVM] Fix kFactor for fp8/fp6/fp4 types in MmaSpOp
verifier. Improve mma tests.
Fix an incorrect kFactor value for e4m3/e5m2, e3m2/e2m3, e2m1 types in
MmaSpOp::verify(). The kFactor for these types was set to 32 but should
be 16.
kFactor is used to compute the expected number of operand A/B register
fragments. With kFactor=32 (wrong) and the only allowed shape m16n8k64,
the fragment count was incorrect: the verifier accepted two i32 fragments
each for A and B (as the old tests used). With kFactor=16 (correct), it
expects four i32 fragments each for A and B, which matches the PTX ISA
definition for mma.sp with fp8/fp6/fp4 A/B operands.
PTX ISA reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-sparse-mma
Also improve existing MLIR dialect tests for nvvm.mma.sp.sync and add
new mlir-translate tests covering mma, mma.sp, and blockscale variants.
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 2 +-
.../test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir | 131 +--
.../Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir | 146 ++-
mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir | 144 ++-
.../Target/LLVMIR/nvvm/mma-blockscale.mlir | 697 +++++++++++++++
mlir/test/Target/LLVMIR/nvvm/mma-sp-kind.mlir | 191 ++++
.../Target/LLVMIR/nvvm/mma-sp-ordered.mlir | 380 ++++++++
mlir/test/Target/LLVMIR/nvvm/mma-sp.mlir | 325 +++++++
.../LLVMIR/nvvm/mma-sparse-blockscale.mlir | 842 ++++++++++++++++++
9 files changed, 2636 insertions(+), 222 deletions(-)
create mode 100644 mlir/test/Target/LLVMIR/nvvm/mma-blockscale.mlir
create mode 100644 mlir/test/Target/LLVMIR/nvvm/mma-sp-kind.mlir
create mode 100644 mlir/test/Target/LLVMIR/nvvm/mma-sp-ordered.mlir
create mode 100644 mlir/test/Target/LLVMIR/nvvm/mma-sp.mlir
create mode 100644 mlir/test/Target/LLVMIR/nvvm/mma-sparse-blockscale.mlir
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index fa085d407d6ec..10e23ef21b219 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1423,7 +1423,7 @@ LogicalResult MmaSpOp::verify() {
case MMATypes::e3m2:
case MMATypes::e2m3:
case MMATypes::e2m1:
- kFactor = 32;
+ kFactor = 16;
multiplicandFragType = i32Ty;
expectedResult.push_back(f16x2x2StructTy);
expectedResult.push_back(f32x4StructTy);
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir
index ff3e91b89016d..9236fbbaa6022 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir
@@ -22,12 +22,13 @@
// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e4m3_f16
func.func @nvvm_mma_sp_kind_m16n8k64_e4m3_f16(
- %a0 : i32, %a1 : i32,
- %b0 : i32, %b1 : i32,
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : vector<2xf16>, %c1 : vector<2xf16>,
- %meta : i32, %sel : i32) {
- // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ %meta : i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
sparseMetadata[%meta] selector[%sel]
{kind = #nvvm.mma_kind<f8f6f4>,
orderedMetadata,
@@ -35,17 +36,18 @@ func.func @nvvm_mma_sp_kind_m16n8k64_e4m3_f16(
multiplicandBPtxType = #nvvm.mma_type<e4m3>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- return
+ return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
}
// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e4m3_f32
func.func @nvvm_mma_sp_kind_m16n8k64_e4m3_f32(
- %a0 : i32, %a1 : i32,
- %b0 : i32, %b1 : i32,
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
- // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
{kind = #nvvm.mma_kind<f8f6f4>,
orderedMetadata,
@@ -53,7 +55,7 @@ func.func @nvvm_mma_sp_kind_m16n8k64_e4m3_f32(
multiplicandBPtxType = #nvvm.mma_type<e4m3>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// =============================================================================
@@ -62,12 +64,13 @@ func.func @nvvm_mma_sp_kind_m16n8k64_e4m3_f32(
// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e5m2_f16
func.func @nvvm_mma_sp_kind_m16n8k64_e5m2_f16(
- %a0 : i32, %a1 : i32,
- %b0 : i32, %b1 : i32,
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : vector<2xf16>, %c1 : vector<2xf16>,
- %meta : i32, %sel : i32) {
- // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ %meta : i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
sparseMetadata[%meta] selector[%sel]
{kind = #nvvm.mma_kind<f8f6f4>,
orderedMetadata,
@@ -75,17 +78,18 @@ func.func @nvvm_mma_sp_kind_m16n8k64_e5m2_f16(
multiplicandBPtxType = #nvvm.mma_type<e5m2>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- return
+ return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
}
// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e5m2_f32
func.func @nvvm_mma_sp_kind_m16n8k64_e5m2_f32(
- %a0 : i32, %a1 : i32,
- %b0 : i32, %b1 : i32,
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
- // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
{kind = #nvvm.mma_kind<f8f6f4>,
orderedMetadata,
@@ -93,7 +97,7 @@ func.func @nvvm_mma_sp_kind_m16n8k64_e5m2_f32(
multiplicandBPtxType = #nvvm.mma_type<e5m2>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// =============================================================================
@@ -103,12 +107,13 @@ func.func @nvvm_mma_sp_kind_m16n8k64_e5m2_f32(
// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e3m2_f16
func.func @nvvm_mma_sp_kind_m16n8k64_e3m2_f16(
- %a0 : i32, %a1 : i32,
- %b0 : i32, %b1 : i32,
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : vector<2xf16>, %c1 : vector<2xf16>,
- %meta : i32, %sel : i32) {
- // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e3m2>, multiplicandBPtxType = #nvvm.mma_type<e3m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ %meta : i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e3m2>, multiplicandBPtxType = #nvvm.mma_type<e3m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
sparseMetadata[%meta] selector[%sel]
{kind = #nvvm.mma_kind<f8f6f4>,
orderedMetadata,
@@ -116,17 +121,18 @@ func.func @nvvm_mma_sp_kind_m16n8k64_e3m2_f16(
multiplicandBPtxType = #nvvm.mma_type<e3m2>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- return
+ return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
}
// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e3m2_f32
func.func @nvvm_mma_sp_kind_m16n8k64_e3m2_f32(
- %a0 : i32, %a1 : i32,
- %b0 : i32, %b1 : i32,
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
- // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e3m2>, multiplicandBPtxType = #nvvm.mma_type<e3m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e3m2>, multiplicandBPtxType = #nvvm.mma_type<e3m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
{kind = #nvvm.mma_kind<f8f6f4>,
orderedMetadata,
@@ -134,7 +140,7 @@ func.func @nvvm_mma_sp_kind_m16n8k64_e3m2_f32(
multiplicandBPtxType = #nvvm.mma_type<e3m2>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// =============================================================================
@@ -144,12 +150,13 @@ func.func @nvvm_mma_sp_kind_m16n8k64_e3m2_f32(
// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m3_f16
func.func @nvvm_mma_sp_kind_m16n8k64_e2m3_f16(
- %a0 : i32, %a1 : i32,
- %b0 : i32, %b1 : i32,
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : vector<2xf16>, %c1 : vector<2xf16>,
- %meta : i32, %sel : i32) {
- // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m3>, multiplicandBPtxType = #nvvm.mma_type<e2m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ %meta : i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m3>, multiplicandBPtxType = #nvvm.mma_type<e2m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
sparseMetadata[%meta] selector[%sel]
{kind = #nvvm.mma_kind<f8f6f4>,
orderedMetadata,
@@ -157,17 +164,18 @@ func.func @nvvm_mma_sp_kind_m16n8k64_e2m3_f16(
multiplicandBPtxType = #nvvm.mma_type<e2m3>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- return
+ return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
}
// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m3_f32
func.func @nvvm_mma_sp_kind_m16n8k64_e2m3_f32(
- %a0 : i32, %a1 : i32,
- %b0 : i32, %b1 : i32,
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
- // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m3>, multiplicandBPtxType = #nvvm.mma_type<e2m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m3>, multiplicandBPtxType = #nvvm.mma_type<e2m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
{kind = #nvvm.mma_kind<f8f6f4>,
orderedMetadata,
@@ -175,7 +183,7 @@ func.func @nvvm_mma_sp_kind_m16n8k64_e2m3_f32(
multiplicandBPtxType = #nvvm.mma_type<e2m3>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// =============================================================================
@@ -185,12 +193,13 @@ func.func @nvvm_mma_sp_kind_m16n8k64_e2m3_f32(
// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m1_f16
func.func @nvvm_mma_sp_kind_m16n8k64_e2m1_f16(
- %a0 : i32, %a1 : i32,
- %b0 : i32, %b1 : i32,
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : vector<2xf16>, %c1 : vector<2xf16>,
- %meta : i32, %sel : i32) {
- // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m1>, multiplicandBPtxType = #nvvm.mma_type<e2m1>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ %meta : i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m1>, multiplicandBPtxType = #nvvm.mma_type<e2m1>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
sparseMetadata[%meta] selector[%sel]
{kind = #nvvm.mma_kind<f8f6f4>,
orderedMetadata,
@@ -198,17 +207,18 @@ func.func @nvvm_mma_sp_kind_m16n8k64_e2m1_f16(
multiplicandBPtxType = #nvvm.mma_type<e2m1>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- return
+ return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
}
// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m1_f32
func.func @nvvm_mma_sp_kind_m16n8k64_e2m1_f32(
- %a0 : i32, %a1 : i32,
- %b0 : i32, %b1 : i32,
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
- // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m1>, multiplicandBPtxType = #nvvm.mma_type<e2m1>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m1>, multiplicandBPtxType = #nvvm.mma_type<e2m1>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
{kind = #nvvm.mma_kind<f8f6f4>,
orderedMetadata,
@@ -216,6 +226,5 @@ func.func @nvvm_mma_sp_kind_m16n8k64_e2m1_f32(
multiplicandBPtxType = #nvvm.mma_type<e2m1>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
-
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir
index a4e2812e54c12..3fc9643a39ae5 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir
@@ -18,14 +18,15 @@ func.func @nvvm_mma_sp_ordered_m16n8k16_f16_f16(
%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : vector<2xf16>, %c1 : vector<2xf16>,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
%0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
sparseMetadata[%meta] selector[%sel]
{orderedMetadata,
shape = #nvvm.shape<m = 16, n = 8, k = 16>}
: (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- return
+ return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
}
// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_f16_f32
@@ -33,14 +34,15 @@ func.func @nvvm_mma_sp_ordered_m16n8k16_f16_f32(
%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
{orderedMetadata,
shape = #nvvm.shape<m = 16, n = 8, k = 16>}
: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// =============================================================================
@@ -52,14 +54,15 @@ func.func @nvvm_mma_sp_ordered_m16n8k32_f16_f16(
%a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>,
%c0 : vector<2xf16>, %c1 : vector<2xf16>,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
%0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
sparseMetadata[%meta] selector[%sel]
{orderedMetadata,
shape = #nvvm.shape<m = 16, n = 8, k = 32>}
: (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- return
+ return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
}
// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_f16_f32
@@ -67,14 +70,15 @@ func.func @nvvm_mma_sp_ordered_m16n8k32_f16_f32(
%a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
{orderedMetadata,
shape = #nvvm.shape<m = 16, n = 8, k = 32>}
: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// =============================================================================
@@ -86,7 +90,8 @@ func.func @nvvm_mma_sp_ordered_m16n8k16_bf16_f32(
%a0 : i32, %a1 : i32,
%b0 : i32, %b1 : i32,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -95,7 +100,7 @@ func.func @nvvm_mma_sp_ordered_m16n8k16_bf16_f32(
multiplicandBPtxType = #nvvm.mma_type<bf16>,
shape = #nvvm.shape<m = 16, n = 8, k = 16>}
: (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_bf16_f32
@@ -103,7 +108,8 @@ func.func @nvvm_mma_sp_ordered_m16n8k32_bf16_f32(
%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -112,7 +118,7 @@ func.func @nvvm_mma_sp_ordered_m16n8k32_bf16_f32(
multiplicandBPtxType = #nvvm.mma_type<bf16>,
shape = #nvvm.shape<m = 16, n = 8, k = 32>}
: (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// =============================================================================
@@ -124,7 +130,8 @@ func.func @nvvm_mma_sp_ordered_m16n8k8_tf32_f32(
%a0 : i32, %a1 : i32,
%b0 : i32, %b1 : i32,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -133,7 +140,7 @@ func.func @nvvm_mma_sp_ordered_m16n8k8_tf32_f32(
multiplicandBPtxType = #nvvm.mma_type<tf32>,
shape = #nvvm.shape<m = 16, n = 8, k = 8>}
: (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_tf32_f32
@@ -141,7 +148,8 @@ func.func @nvvm_mma_sp_ordered_m16n8k16_tf32_f32(
%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -150,7 +158,7 @@ func.func @nvvm_mma_sp_ordered_m16n8k16_tf32_f32(
multiplicandBPtxType = #nvvm.mma_type<tf32>,
shape = #nvvm.shape<m = 16, n = 8, k = 16>}
: (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// =============================================================================
@@ -162,7 +170,8 @@ func.func @nvvm_mma_sp_ordered_m16n8k32_s8_s32(
%a0 : i32, %a1 : i32,
%b0 : i32, %b1 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -172,7 +181,7 @@ func.func @nvvm_mma_sp_ordered_m16n8k32_s8_s32(
intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
shape = #nvvm.shape<m = 16, n = 8, k = 32>}
: (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
- return
+ return %0 : !llvm.struct<(i32, i32, i32, i32)>
}
// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_s8_s32_satfinite
@@ -180,7 +189,8 @@ func.func @nvvm_mma_sp_ordered_m16n8k32_s8_s32_satfinite(
%a0 : i32, %a1 : i32,
%b0 : i32, %b1 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -190,7 +200,7 @@ func.func @nvvm_mma_sp_ordered_m16n8k32_s8_s32_satfinite(
intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>,
shape = #nvvm.shape<m = 16, n = 8, k = 32>}
: (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
- return
+ return %0 : !llvm.struct<(i32, i32, i32, i32)>
}
// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_s8_s32
@@ -198,7 +208,8 @@ func.func @nvvm_mma_sp_ordered_m16n8k64_s8_s32(
%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -208,7 +219,7 @@ func.func @nvvm_mma_sp_ordered_m16n8k64_s8_s32(
intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
- return
+ return %0 : !llvm.struct<(i32, i32, i32, i32)>
}
// =============================================================================
@@ -220,7 +231,8 @@ func.func @nvvm_mma_sp_ordered_m16n8k32_u8_s32(
%a0 : i32, %a1 : i32,
%b0 : i32, %b1 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -230,7 +242,7 @@ func.func @nvvm_mma_sp_ordered_m16n8k32_u8_s32(
intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
shape = #nvvm.shape<m = 16, n = 8, k = 32>}
: (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
- return
+ return %0 : !llvm.struct<(i32, i32, i32, i32)>
}
// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_u8_s32
@@ -238,7 +250,8 @@ func.func @nvvm_mma_sp_ordered_m16n8k64_u8_s32(
%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -248,7 +261,7 @@ func.func @nvvm_mma_sp_ordered_m16n8k64_u8_s32(
intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
- return
+ return %0 : !llvm.struct<(i32, i32, i32, i32)>
}
// =============================================================================
@@ -260,7 +273,8 @@ func.func @nvvm_mma_sp_ordered_m16n8k64_s4_s32(
%a0 : i32, %a1 : i32,
%b0 : i32, %b1 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -270,7 +284,7 @@ func.func @nvvm_mma_sp_ordered_m16n8k64_s4_s32(
intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
- return
+ return %0 : !llvm.struct<(i32, i32, i32, i32)>
}
// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k128_s4_s32
@@ -278,7 +292,8 @@ func.func @nvvm_mma_sp_ordered_m16n8k128_s4_s32(
%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -288,7 +303,7 @@ func.func @nvvm_mma_sp_ordered_m16n8k128_s4_s32(
intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
shape = #nvvm.shape<m = 16, n = 8, k = 128>}
: (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
- return
+ return %0 : !llvm.struct<(i32, i32, i32, i32)>
}
// =============================================================================
@@ -300,7 +315,8 @@ func.func @nvvm_mma_sp_ordered_m16n8k64_u4_s32(
%a0 : i32, %a1 : i32,
%b0 : i32, %b1 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -310,7 +326,7 @@ func.func @nvvm_mma_sp_ordered_m16n8k64_u4_s32(
intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
- return
+ return %0 : !llvm.struct<(i32, i32, i32, i32)>
}
// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k128_u4_s32
@@ -318,7 +334,8 @@ func.func @nvvm_mma_sp_ordered_m16n8k128_u4_s32(
%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -328,7 +345,7 @@ func.func @nvvm_mma_sp_ordered_m16n8k128_u4_s32(
intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
shape = #nvvm.shape<m = 16, n = 8, k = 128>}
: (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
- return
+ return %0 : !llvm.struct<(i32, i32, i32, i32)>
}
// =============================================================================
@@ -336,38 +353,22 @@ func.func @nvvm_mma_sp_ordered_m16n8k128_u4_s32(
// NOTE: FP8 ordered metadata requires PTX ISA 8.7+ and sm_90+
// =============================================================================
-// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e4m3_f16
-func.func @nvvm_mma_sp_ordered_m16n8k64_e4m3_f16(
- %a0 : i32, %a1 : i32,
- %b0 : i32, %b1 : i32,
- %c0 : vector<2xf16>, %c1 : vector<2xf16>,
- %meta : i32, %sel : i32) {
- // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
- sparseMetadata[%meta] selector[%sel]
- {orderedMetadata,
- multiplicandAPtxType = #nvvm.mma_type<e4m3>,
- multiplicandBPtxType = #nvvm.mma_type<e4m3>,
- shape = #nvvm.shape<m = 16, n = 8, k = 64>}
- : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- return
-}
-
// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e4m3_f32
func.func @nvvm_mma_sp_ordered_m16n8k64_e4m3_f32(
- %a0 : i32, %a1 : i32,
- %b0 : i32, %b1 : i32,
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
- // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
{orderedMetadata,
multiplicandAPtxType = #nvvm.mma_type<e4m3>,
multiplicandBPtxType = #nvvm.mma_type<e4m3>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// =============================================================================
@@ -375,37 +376,20 @@ func.func @nvvm_mma_sp_ordered_m16n8k64_e4m3_f32(
// NOTE: FP8 ordered metadata requires PTX ISA 8.7+ and sm_90+
// =============================================================================
-// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e5m2_f16
-func.func @nvvm_mma_sp_ordered_m16n8k64_e5m2_f16(
- %a0 : i32, %a1 : i32,
- %b0 : i32, %b1 : i32,
- %c0 : vector<2xf16>, %c1 : vector<2xf16>,
- %meta : i32, %sel : i32) {
- // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
- sparseMetadata[%meta] selector[%sel]
- {orderedMetadata,
- multiplicandAPtxType = #nvvm.mma_type<e5m2>,
- multiplicandBPtxType = #nvvm.mma_type<e5m2>,
- shape = #nvvm.shape<m = 16, n = 8, k = 64>}
- : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- return
-}
-
// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e5m2_f32
func.func @nvvm_mma_sp_ordered_m16n8k64_e5m2_f32(
- %a0 : i32, %a1 : i32,
- %b0 : i32, %b1 : i32,
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
- // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
{orderedMetadata,
multiplicandAPtxType = #nvvm.mma_type<e5m2>,
multiplicandBPtxType = #nvvm.mma_type<e5m2>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
-
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir
index e7122aac61baf..84aaa750981d4 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir
@@ -20,13 +20,14 @@ func.func @nvvm_mma_sp_m16n8k16_f16_f16(
%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : vector<2xf16>, %c1 : vector<2xf16>,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
%0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
sparseMetadata[%meta] selector[%sel]
{shape = #nvvm.shape<m = 16, n = 8, k = 16>}
: (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- return
+ return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
}
// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_f16_f32
@@ -34,13 +35,14 @@ func.func @nvvm_mma_sp_m16n8k16_f16_f32(
%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
{shape = #nvvm.shape<m = 16, n = 8, k = 16>}
: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// =============================================================================
@@ -52,13 +54,14 @@ func.func @nvvm_mma_sp_m16n8k32_f16_f16(
%a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>,
%c0 : vector<2xf16>, %c1 : vector<2xf16>,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
%0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
sparseMetadata[%meta] selector[%sel]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>}
: (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- return
+ return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
}
// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_f16_f32
@@ -66,13 +69,14 @@ func.func @nvvm_mma_sp_m16n8k32_f16_f32(
%a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
{shape = #nvvm.shape<m = 16, n = 8, k = 32>}
: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// =============================================================================
@@ -84,7 +88,8 @@ func.func @nvvm_mma_sp_m16n8k16_bf16_f32(
%a0 : i32, %a1 : i32,
%b0 : i32, %b1 : i32,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -92,7 +97,7 @@ func.func @nvvm_mma_sp_m16n8k16_bf16_f32(
multiplicandBPtxType = #nvvm.mma_type<bf16>,
shape = #nvvm.shape<m = 16, n = 8, k = 16>}
: (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_bf16_f32
@@ -100,7 +105,8 @@ func.func @nvvm_mma_sp_m16n8k32_bf16_f32(
%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -108,7 +114,7 @@ func.func @nvvm_mma_sp_m16n8k32_bf16_f32(
multiplicandBPtxType = #nvvm.mma_type<bf16>,
shape = #nvvm.shape<m = 16, n = 8, k = 32>}
: (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// =============================================================================
@@ -120,7 +126,8 @@ func.func @nvvm_mma_sp_m16n8k8_tf32_f32(
%a0 : i32, %a1 : i32,
%b0 : i32, %b1 : i32,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -128,7 +135,7 @@ func.func @nvvm_mma_sp_m16n8k8_tf32_f32(
multiplicandBPtxType = #nvvm.mma_type<tf32>,
shape = #nvvm.shape<m = 16, n = 8, k = 8>}
: (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_tf32_f32
@@ -136,7 +143,8 @@ func.func @nvvm_mma_sp_m16n8k16_tf32_f32(
%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -144,7 +152,7 @@ func.func @nvvm_mma_sp_m16n8k16_tf32_f32(
multiplicandBPtxType = #nvvm.mma_type<tf32>,
shape = #nvvm.shape<m = 16, n = 8, k = 16>}
: (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// =============================================================================
@@ -156,7 +164,8 @@ func.func @nvvm_mma_sp_m16n8k32_s8_s32(
%a0 : i32, %a1 : i32,
%b0 : i32, %b1 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -165,7 +174,7 @@ func.func @nvvm_mma_sp_m16n8k32_s8_s32(
intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
shape = #nvvm.shape<m = 16, n = 8, k = 32>}
: (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
- return
+ return %0 : !llvm.struct<(i32, i32, i32, i32)>
}
// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_s8_s32_satfinite
@@ -173,7 +182,8 @@ func.func @nvvm_mma_sp_m16n8k32_s8_s32_satfinite(
%a0 : i32, %a1 : i32,
%b0 : i32, %b1 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -182,7 +192,7 @@ func.func @nvvm_mma_sp_m16n8k32_s8_s32_satfinite(
intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>,
shape = #nvvm.shape<m = 16, n = 8, k = 32>}
: (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
- return
+ return %0 : !llvm.struct<(i32, i32, i32, i32)>
}
// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_s8_s32
@@ -190,7 +200,8 @@ func.func @nvvm_mma_sp_m16n8k64_s8_s32(
%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -199,7 +210,7 @@ func.func @nvvm_mma_sp_m16n8k64_s8_s32(
intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
- return
+ return %0 : !llvm.struct<(i32, i32, i32, i32)>
}
// =============================================================================
@@ -211,7 +222,8 @@ func.func @nvvm_mma_sp_m16n8k32_u8_s32(
%a0 : i32, %a1 : i32,
%b0 : i32, %b1 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -220,7 +232,7 @@ func.func @nvvm_mma_sp_m16n8k32_u8_s32(
intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
shape = #nvvm.shape<m = 16, n = 8, k = 32>}
: (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
- return
+ return %0 : !llvm.struct<(i32, i32, i32, i32)>
}
// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_u8_s32
@@ -228,7 +240,8 @@ func.func @nvvm_mma_sp_m16n8k64_u8_s32(
%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -237,7 +250,7 @@ func.func @nvvm_mma_sp_m16n8k64_u8_s32(
intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
- return
+ return %0 : !llvm.struct<(i32, i32, i32, i32)>
}
// =============================================================================
@@ -249,7 +262,8 @@ func.func @nvvm_mma_sp_m16n8k64_s4_s32(
%a0 : i32, %a1 : i32,
%b0 : i32, %b1 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -258,7 +272,7 @@ func.func @nvvm_mma_sp_m16n8k64_s4_s32(
intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
- return
+ return %0 : !llvm.struct<(i32, i32, i32, i32)>
}
// CHECK-LABEL: @nvvm_mma_sp_m16n8k128_s4_s32
@@ -266,7 +280,8 @@ func.func @nvvm_mma_sp_m16n8k128_s4_s32(
%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -275,7 +290,7 @@ func.func @nvvm_mma_sp_m16n8k128_s4_s32(
intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
shape = #nvvm.shape<m = 16, n = 8, k = 128>}
: (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
- return
+ return %0 : !llvm.struct<(i32, i32, i32, i32)>
}
// =============================================================================
@@ -287,7 +302,8 @@ func.func @nvvm_mma_sp_m16n8k64_u4_s32(
%a0 : i32, %a1 : i32,
%b0 : i32, %b1 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -296,7 +312,7 @@ func.func @nvvm_mma_sp_m16n8k64_u4_s32(
intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
- return
+ return %0 : !llvm.struct<(i32, i32, i32, i32)>
}
// CHECK-LABEL: @nvvm_mma_sp_m16n8k128_u4_s32
@@ -304,7 +320,8 @@ func.func @nvvm_mma_sp_m16n8k128_u4_s32(
%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
- %meta : i32, %sel : i32) {
+ %meta : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
// CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
@@ -313,78 +330,47 @@ func.func @nvvm_mma_sp_m16n8k128_u4_s32(
intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
shape = #nvvm.shape<m = 16, n = 8, k = 128>}
: (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
- return
+ return %0 : !llvm.struct<(i32, i32, i32, i32)>
}
// =============================================================================
// FP8 (e4m3) Sparse MMA Operations
// =============================================================================
-// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e4m3_f16
-func.func @nvvm_mma_sp_m16n8k64_e4m3_f16(
- %a0 : i32, %a1 : i32,
- %b0 : i32, %b1 : i32,
- %c0 : vector<2xf16>, %c1 : vector<2xf16>,
- %meta : i32, %sel : i32) {
- // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
- sparseMetadata[%meta] selector[%sel]
- {multiplicandAPtxType = #nvvm.mma_type<e4m3>,
- multiplicandBPtxType = #nvvm.mma_type<e4m3>,
- shape = #nvvm.shape<m = 16, n = 8, k = 64>}
- : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- return
-}
-
// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e4m3_f32
func.func @nvvm_mma_sp_m16n8k64_e4m3_f32(
- %a0 : i32, %a1 : i32,
- %b0 : i32, %b1 : i32,
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
- // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
{multiplicandAPtxType = #nvvm.mma_type<e4m3>,
multiplicandBPtxType = #nvvm.mma_type<e4m3>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// =============================================================================
// FP8 (e5m2) Sparse MMA Operations
// =============================================================================
-// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e5m2_f16
-func.func @nvvm_mma_sp_m16n8k64_e5m2_f16(
- %a0 : i32, %a1 : i32,
- %b0 : i32, %b1 : i32,
- %c0 : vector<2xf16>, %c1 : vector<2xf16>,
- %meta : i32, %sel : i32) {
- // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
- sparseMetadata[%meta] selector[%sel]
- {multiplicandAPtxType = #nvvm.mma_type<e5m2>,
- multiplicandBPtxType = #nvvm.mma_type<e5m2>,
- shape = #nvvm.shape<m = 16, n = 8, k = 64>}
- : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- return
-}
-
// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e5m2_f32
func.func @nvvm_mma_sp_m16n8k64_e5m2_f32(
- %a0 : i32, %a1 : i32,
- %b0 : i32, %b1 : i32,
+ %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+ %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %meta : i32, %sel : i32) {
- // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ %meta : i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
sparseMetadata[%meta] selector[%sel]
{multiplicandAPtxType = #nvvm.mma_type<e5m2>,
multiplicandBPtxType = #nvvm.mma_type<e5m2>,
shape = #nvvm.shape<m = 16, n = 8, k = 64>}
: (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- return
+ return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
-
diff --git a/mlir/test/Target/LLVMIR/nvvm/mma-blockscale.mlir b/mlir/test/Target/LLVMIR/nvvm/mma-blockscale.mlir
new file mode 100644
index 0000000000000..edc59ed20193f
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/mma-blockscale.mlir
@@ -0,0 +1,697 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m1
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m1(
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e2m1.e2m1.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m3
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m3(
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e2m1.e2m3.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e3m2
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e3m2(
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e2m1.e3m2.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e4m3
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e4m3(
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e2m1.e4m3.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e5m2
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e5m2(
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e2m1.e5m2.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m1
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m1(
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e2m3.e2m1.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m3
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m3(
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e2m3.e2m3.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e3m2
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e3m2(
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e2m3.e3m2.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e4m3
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e4m3(
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e2m3.e4m3.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e5m2
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e5m2(
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e2m3.e5m2.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m1
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m1(
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e3m2.e2m1.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m3
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m3(
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e3m2.e2m3.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e3m2
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e3m2(
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e3m2.e3m2.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e4m3
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e4m3(
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e3m2.e4m3.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e5m2
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e5m2(
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e3m2.e5m2.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m1
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m1(
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e4m3.e2m1.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m3
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m3(
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e4m3.e2m3.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e3m2
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e3m2(
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e4m3.e3m2.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e4m3
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e4m3(
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e4m3.e4m3.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e5m2
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e5m2( // Translation test: mxf8f6f4 block-scaled MMA with A = e4m3, B = e5m2.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments
+ %b0: i32, %b1: i32, // B fragments
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32, // accumulator inputs
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16, // operands of scaleA[...]
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> { // operands of scaleB[...]
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e4m3.e5m2.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>, // expected name ends in ".ue8m0"
+ kind = #nvvm.block_scale_kind<mxf8f6f4>, // expected name contains ".mxf8f6f4."
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>, // expected name contains ".e4m3.e5m2." (A type first, then B)
+ scaleVecSize = #nvvm.scale_vec_size<x1>, // expected name contains ".scale.1x."
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>} // expected name contains ".m16n8k32."
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m1
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m1( // Translation test: mxf8f6f4 block-scaled MMA with A = e5m2, B = e2m1.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments
+ %b0: i32, %b1: i32, // B fragments
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32, // accumulator inputs
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16, // operands of scaleA[...]
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> { // operands of scaleB[...]
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e5m2.e2m1.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>, // expected name ends in ".ue8m0"
+ kind = #nvvm.block_scale_kind<mxf8f6f4>, // expected name contains ".mxf8f6f4."
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>, // expected name contains ".e5m2.e2m1." (A type first, then B)
+ scaleVecSize = #nvvm.scale_vec_size<x1>, // expected name contains ".scale.1x."
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>} // expected name contains ".m16n8k32."
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m3
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m3( // Translation test: mxf8f6f4 block-scaled MMA with A = e5m2, B = e2m3.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments
+ %b0: i32, %b1: i32, // B fragments
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32, // accumulator inputs
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16, // operands of scaleA[...]
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> { // operands of scaleB[...]
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e5m2.e2m3.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>, // expected name ends in ".ue8m0"
+ kind = #nvvm.block_scale_kind<mxf8f6f4>, // expected name contains ".mxf8f6f4."
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>, // expected name contains ".e5m2.e2m3." (A type first, then B)
+ scaleVecSize = #nvvm.scale_vec_size<x1>, // expected name contains ".scale.1x."
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>} // expected name contains ".m16n8k32."
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e3m2
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e3m2( // Translation test: mxf8f6f4 block-scaled MMA with A = e5m2, B = e3m2.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments
+ %b0: i32, %b1: i32, // B fragments
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32, // accumulator inputs
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16, // operands of scaleA[...]
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> { // operands of scaleB[...]
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e5m2.e3m2.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>, // expected name ends in ".ue8m0"
+ kind = #nvvm.block_scale_kind<mxf8f6f4>, // expected name contains ".mxf8f6f4."
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>, // expected name contains ".e5m2.e3m2." (A type first, then B)
+ scaleVecSize = #nvvm.scale_vec_size<x1>, // expected name contains ".scale.1x."
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>} // expected name contains ".m16n8k32."
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e4m3
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e4m3( // Translation test: mxf8f6f4 block-scaled MMA with A = e5m2, B = e4m3.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments
+ %b0: i32, %b1: i32, // B fragments
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32, // accumulator inputs
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16, // operands of scaleA[...]
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> { // operands of scaleB[...]
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e5m2.e4m3.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>, // expected name ends in ".ue8m0"
+ kind = #nvvm.block_scale_kind<mxf8f6f4>, // expected name contains ".mxf8f6f4."
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>, // expected name contains ".e5m2.e4m3." (A type first, then B)
+ scaleVecSize = #nvvm.scale_vec_size<x1>, // expected name contains ".scale.1x."
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>} // expected name contains ".m16n8k32."
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e5m2
+llvm.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e5m2( // Translation test: mxf8f6f4 block-scaled MMA with A = e5m2, B = e5m2.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments
+ %b0: i32, %b1: i32, // B fragments
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32, // accumulator inputs
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16, // operands of scaleA[...]
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> { // operands of scaleB[...]
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k32.row.col.mxf8f6f4.scale.1x.f32.e5m2.e5m2.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>, // expected name ends in ".ue8m0"
+ kind = #nvvm.block_scale_kind<mxf8f6f4>, // expected name contains ".mxf8f6f4."
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>, // expected name contains ".e5m2.e5m2."
+ scaleVecSize = #nvvm.scale_vec_size<x1>, // expected name contains ".scale.1x."
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>} // expected name contains ".m16n8k32."
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf4_blockscale_mma
+llvm.func @nvvm_mxf4_blockscale_mma( // Translation test: mxf4 block-scaled MMA, e2m1 x e2m1, shape m16n8k64, scale vector x2.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments
+ %b0: i32, %b1: i32, // B fragments
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32, // accumulator inputs
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16, // operands of scaleA[...]
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> { // operands of scaleB[...]
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k64.row.col.mxf4.scale.2x.f32.e2m1.e2m1.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>, // expected name ends in ".ue8m0"
+ kind = #nvvm.block_scale_kind<mxf4>, // expected name contains ".mxf4."
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>, // expected name contains ".e2m1.e2m1."
+ scaleVecSize = #nvvm.scale_vec_size<x2>, // expected name contains ".scale.2x."
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>} // expected name contains ".m16n8k64."
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf4nvf4_blockscale_mma_ue8m0
+llvm.func @nvvm_mxf4nvf4_blockscale_mma_ue8m0( // Translation test: mxf4nvf4 block-scaled MMA, e2m1 x e2m1, ue8m0 scale format, scale vector x2.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments
+ %b0: i32, %b1: i32, // B fragments
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32, // accumulator inputs
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16, // operands of scaleA[...]
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> { // operands of scaleB[...]
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k64.row.col.mxf4nvf4.scale.2x.f32.e2m1.e2m1.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>, // expected name ends in ".ue8m0"
+ kind = #nvvm.block_scale_kind<mxf4nvf4>, // expected name contains ".mxf4nvf4."
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>, // expected name contains ".e2m1.e2m1."
+ scaleVecSize = #nvvm.scale_vec_size<x2>, // expected name contains ".scale.2x."
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>} // expected name contains ".m16n8k64."
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf4nvf4_blockscale_mma_ue4m3
+llvm.func @nvvm_mxf4nvf4_blockscale_mma_ue4m3( // Translation test: mxf4nvf4 block-scaled MMA, e2m1 x e2m1, ue4m3 scale format, scale vector x4.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments
+ %b0: i32, %b1: i32, // B fragments
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32, // accumulator inputs
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16, // operands of scaleA[...]
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> { // operands of scaleB[...]
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k64.row.col.mxf4nvf4.scale.4x.f32.e2m1.e2m1.f32.ue4m3(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue4m3>, // expected name ends in ".ue4m3"
+ kind = #nvvm.block_scale_kind<mxf4nvf4>, // expected name contains ".mxf4nvf4."
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>, // expected name contains ".e2m1.e2m1."
+ scaleVecSize = #nvvm.scale_vec_size<x4>, // expected name contains ".scale.4x."
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>} // expected name contains ".m16n8k64."
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf4nvf4_blockscale_mma_ue8m0_x4
+llvm.func @nvvm_mxf4nvf4_blockscale_mma_ue8m0_x4( // Translation test: mxf4nvf4 block-scaled MMA, e2m1 x e2m1, ue8m0 scale format, scale vector x4.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments
+ %b0: i32, %b1: i32, // B fragments
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32, // accumulator inputs
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16, // operands of scaleA[...]
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> { // operands of scaleB[...]
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.block.scale.m16n8k64.row.col.mxf4nvf4.scale.4x.f32.e2m1.e2m1.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1]
+ C[%c0, %c1, %c2, %c3]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>, // expected name ends in ".ue8m0"
+ kind = #nvvm.block_scale_kind<mxf4nvf4>, // expected name contains ".mxf4nvf4."
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>, // expected name contains ".e2m1.e2m1."
+ scaleVecSize = #nvvm.scale_vec_size<x4>, // expected name contains ".scale.4x."
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>} // expected name contains ".m16n8k64."
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/mma-sp-kind.mlir b/mlir/test/Target/LLVMIR/nvvm/mma-sp-kind.mlir
new file mode 100644
index 0000000000000..66dbaf5c6d766
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/mma-sp-kind.mlir
@@ -0,0 +1,191 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e4m3_f16
+llvm.func @nvvm_mma_sp_kind_m16n8k64_e4m3_f16( // Translation test: sparse MMA m16n8k64, kind f8f6f4, A/B = e4m3, f16 accumulators.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32, // B fragments
+ %c0: vector<2xf16>, %c1: vector<2xf16>, // accumulator inputs
+ %meta: i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> { // %meta feeds sparseMetadata[...]
+ %sel = llvm.mlir.constant(0 : i32) : i32 // selector operand; shows up as the trailing "i32 0" of the expected call
+ // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k64.row.col.kind.f8f6f4.f16.e4m3.e4m3.f16(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>, // expected name contains ".kind.f8f6f4."
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>, // expected name contains ".e4m3.e4m3."
+ orderedMetadata, // expected name contains ".sp.ordered.metadata."
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>} // expected name contains ".m16n8k64."
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ llvm.return %res : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e4m3_f32
+llvm.func @nvvm_mma_sp_kind_m16n8k64_e4m3_f32( // Translation test: sparse MMA m16n8k64, kind f8f6f4, A/B = e4m3, f32 accumulators.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32, // B fragments
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32, // accumulator inputs
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> { // %meta feeds sparseMetadata[...]
+ %sel = llvm.mlir.constant(0 : i32) : i32 // selector operand; shows up as the trailing "i32 0" of the expected call
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k64.row.col.f32.e4m3.e4m3.f32(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>, // NOTE(review): the expected name above has no ".kind.f8f6f4." segment, unlike the e4m3 f16-accumulator test - confirm intended
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>, // expected name contains ".e4m3.e4m3."
+ orderedMetadata, // expected name contains ".sp.ordered.metadata."
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>} // expected name contains ".m16n8k64."
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e5m2_f16
+llvm.func @nvvm_mma_sp_kind_m16n8k64_e5m2_f16( // Translation test: sparse MMA m16n8k64, kind f8f6f4, A/B = e5m2, f16 accumulators.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32, // B fragments
+ %c0: vector<2xf16>, %c1: vector<2xf16>, // accumulator inputs
+ %meta: i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> { // %meta feeds sparseMetadata[...]
+ %sel = llvm.mlir.constant(0 : i32) : i32 // selector operand; shows up as the trailing "i32 0" of the expected call
+ // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k64.row.col.kind.f8f6f4.f16.e5m2.e5m2.f16(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>, // expected name contains ".kind.f8f6f4."
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>, // expected name contains ".e5m2.e5m2."
+ orderedMetadata, // expected name contains ".sp.ordered.metadata."
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>} // expected name contains ".m16n8k64."
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ llvm.return %res : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e5m2_f32
+llvm.func @nvvm_mma_sp_kind_m16n8k64_e5m2_f32( // Translation test: sparse MMA m16n8k64, kind f8f6f4, A/B = e5m2, f32 accumulators.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32, // B fragments
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32, // accumulator inputs
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> { // %meta feeds sparseMetadata[...]
+ %sel = llvm.mlir.constant(0 : i32) : i32 // selector operand; shows up as the trailing "i32 0" of the expected call
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k64.row.col.f32.e5m2.e5m2.f32(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>, // NOTE(review): the expected name above has no ".kind.f8f6f4." segment, unlike the e5m2 f16-accumulator test - confirm intended
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>, // expected name contains ".e5m2.e5m2."
+ orderedMetadata, // expected name contains ".sp.ordered.metadata."
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>} // expected name contains ".m16n8k64."
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e3m2_f16
+llvm.func @nvvm_mma_sp_kind_m16n8k64_e3m2_f16( // Translation test: sparse MMA m16n8k64, kind f8f6f4, A/B = e3m2, f16 accumulators.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32, // B fragments
+ %c0: vector<2xf16>, %c1: vector<2xf16>, // accumulator inputs
+ %meta: i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> { // %meta feeds sparseMetadata[...]
+ %sel = llvm.mlir.constant(0 : i32) : i32 // selector operand; shows up as the trailing "i32 0" of the expected call
+ // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k64.row.col.kind.f8f6f4.f16.e3m2.e3m2.f16(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>, // expected name contains ".kind.f8f6f4."
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>, // expected name contains ".e3m2.e3m2."
+ orderedMetadata, // expected name contains ".sp.ordered.metadata."
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>} // expected name contains ".m16n8k64."
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ llvm.return %res : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e3m2_f32
+llvm.func @nvvm_mma_sp_kind_m16n8k64_e3m2_f32( // Translation test: sparse MMA m16n8k64, kind f8f6f4, A/B = e3m2, f32 accumulators.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32, // B fragments
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32, // accumulator inputs
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> { // %meta feeds sparseMetadata[...]
+ %sel = llvm.mlir.constant(0 : i32) : i32 // selector operand; shows up as the trailing "i32 0" of the expected call
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k64.row.col.kind.f8f6f4.f32.e3m2.e3m2.f32(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>, // expected name contains ".kind.f8f6f4."
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>, // expected name contains ".e3m2.e3m2."
+ orderedMetadata, // expected name contains ".sp.ordered.metadata."
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>} // expected name contains ".m16n8k64."
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m3_f16
+llvm.func @nvvm_mma_sp_kind_m16n8k64_e2m3_f16( // Translation test: sparse MMA m16n8k64, kind f8f6f4, A/B = e2m3, f16 accumulators.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32, // B fragments
+ %c0: vector<2xf16>, %c1: vector<2xf16>, // accumulator inputs
+ %meta: i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> { // %meta feeds sparseMetadata[...]
+ %sel = llvm.mlir.constant(0 : i32) : i32 // selector operand; shows up as the trailing "i32 0" of the expected call
+ // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k64.row.col.kind.f8f6f4.f16.e2m3.e2m3.f16(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>, // expected name contains ".kind.f8f6f4."
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>, // expected name contains ".e2m3.e2m3."
+ orderedMetadata, // expected name contains ".sp.ordered.metadata."
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>} // expected name contains ".m16n8k64."
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ llvm.return %res : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m3_f32
+llvm.func @nvvm_mma_sp_kind_m16n8k64_e2m3_f32( // Translation test: sparse MMA m16n8k64, kind f8f6f4, A/B = e2m3, f32 accumulators.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32, // B fragments
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32, // accumulator inputs
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> { // %meta feeds sparseMetadata[...]
+ %sel = llvm.mlir.constant(0 : i32) : i32 // selector operand; shows up as the trailing "i32 0" of the expected call
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k64.row.col.kind.f8f6f4.f32.e2m3.e2m3.f32(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>, // expected name contains ".kind.f8f6f4."
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>, // expected name contains ".e2m3.e2m3."
+ orderedMetadata, // expected name contains ".sp.ordered.metadata."
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>} // expected name contains ".m16n8k64."
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m1_f16
+llvm.func @nvvm_mma_sp_kind_m16n8k64_e2m1_f16( // Translation test: sparse MMA m16n8k64, kind f8f6f4, A/B = e2m1, f16 accumulators.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32, // B fragments
+ %c0: vector<2xf16>, %c1: vector<2xf16>, // accumulator inputs
+ %meta: i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> { // %meta feeds sparseMetadata[...]
+ %sel = llvm.mlir.constant(0 : i32) : i32 // selector operand; shows up as the trailing "i32 0" of the expected call
+ // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k64.row.col.kind.f8f6f4.f16.e2m1.e2m1.f16(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>, // expected name contains ".kind.f8f6f4."
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>, // expected name contains ".e2m1.e2m1."
+ orderedMetadata, // expected name contains ".sp.ordered.metadata."
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>} // expected name contains ".m16n8k64."
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ llvm.return %res : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m1_f32
+llvm.func @nvvm_mma_sp_kind_m16n8k64_e2m1_f32( // Translation test: sparse MMA m16n8k64, kind f8f6f4, A/B = e2m1, f32 accumulators.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32, // B fragments
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32, // accumulator inputs
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> { // %meta feeds sparseMetadata[...]
+ %sel = llvm.mlir.constant(0 : i32) : i32 // selector operand; shows up as the trailing "i32 0" of the expected call
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k64.row.col.kind.f8f6f4.f32.e2m1.e2m1.f32(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {kind = #nvvm.mma_kind<f8f6f4>, // expected name contains ".kind.f8f6f4."
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>, // expected name contains ".e2m1.e2m1."
+ orderedMetadata, // expected name contains ".sp.ordered.metadata."
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>} // expected name contains ".m16n8k64."
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/mma-sp-ordered.mlir b/mlir/test/Target/LLVMIR/nvvm/mma-sp-ordered.mlir
new file mode 100644
index 0000000000000..33b4387ee8389
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/mma-sp-ordered.mlir
@@ -0,0 +1,380 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_f16_f16
+llvm.func @nvvm_mma_sp_ordered_m16n8k16_f16_f16( // Translation test: ordered-metadata sparse MMA m16n8k16, f16 operands, f16 accumulators.
+ %a0: vector<2xf16>, %a1: vector<2xf16>, // A fragments
+ %b0: vector<2xf16>, %b1: vector<2xf16>, // B fragments
+ %c0: vector<2xf16>, %c1: vector<2xf16>, // accumulator inputs
+ %meta: i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> { // %meta feeds sparseMetadata[...]
+ %sel = llvm.mlir.constant(0 : i32) : i32 // selector operand; shows up as the trailing "i32 0" of the expected call
+ // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k16.row.col.f16.f16(<2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata, // expected name contains ".sp.ordered.metadata."; no multiplicand type attributes are given in this test
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>} // expected name contains ".m16n8k16."
+ : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ llvm.return %res : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_f16_f32
+llvm.func @nvvm_mma_sp_ordered_m16n8k16_f16_f32( // Translation test: ordered-metadata sparse MMA m16n8k16, f16 operands, f32 accumulators.
+ %a0: vector<2xf16>, %a1: vector<2xf16>, // A fragments
+ %b0: vector<2xf16>, %b1: vector<2xf16>, // B fragments
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32, // accumulator inputs
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> { // %meta feeds sparseMetadata[...]
+ %sel = llvm.mlir.constant(0 : i32) : i32 // selector operand; shows up as the trailing "i32 0" of the expected call
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k16.row.col.f32.f32(<2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata, // expected name contains ".sp.ordered.metadata."; no multiplicand type attributes are given in this test
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>} // expected name contains ".m16n8k16."
+ : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_f16_f16
+llvm.func @nvvm_mma_sp_ordered_m16n8k32_f16_f16( // Translation test: ordered-metadata sparse MMA m16n8k32, f16 operands, f16 accumulators.
+ %a0: vector<2xf16>, %a1: vector<2xf16>, %a2: vector<2xf16>, %a3: vector<2xf16>, // A fragments
+ %b0: vector<2xf16>, %b1: vector<2xf16>, %b2: vector<2xf16>, %b3: vector<2xf16>, // B fragments
+ %c0: vector<2xf16>, %c1: vector<2xf16>, // accumulator inputs
+ %meta: i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> { // %meta feeds sparseMetadata[...]
+ %sel = llvm.mlir.constant(0 : i32) : i32 // selector operand; shows up as the trailing "i32 0" of the expected call
+ // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k32.row.col.f16.f16(<2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata, // expected name contains ".sp.ordered.metadata."; no multiplicand type attributes are given in this test
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>} // expected name contains ".m16n8k32."
+ : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ llvm.return %res : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_f16_f32
+llvm.func @nvvm_mma_sp_ordered_m16n8k32_f16_f32( // Translation test: ordered-metadata sparse MMA m16n8k32, f16 operands, f32 accumulators.
+ %a0: vector<2xf16>, %a1: vector<2xf16>, %a2: vector<2xf16>, %a3: vector<2xf16>, // A fragments
+ %b0: vector<2xf16>, %b1: vector<2xf16>, %b2: vector<2xf16>, %b3: vector<2xf16>, // B fragments
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32, // accumulator inputs
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> { // %meta feeds sparseMetadata[...]
+ %sel = llvm.mlir.constant(0 : i32) : i32 // selector operand; shows up as the trailing "i32 0" of the expected call
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k32.row.col.f32.f32(<2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {orderedMetadata, // expected name contains ".sp.ordered.metadata."; no multiplicand type attributes are given in this test
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>} // expected name contains ".m16n8k32."
+ : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_bf16_f32
+llvm.func @nvvm_mma_sp_ordered_m16n8k16_bf16_f32( // Translation test: ordered-metadata sparse MMA m16n8k16, bf16 operands, f32 accumulators.
+ %a0: i32, %a1: i32, // A fragments (passed as i32)
+ %b0: i32, %b1: i32, // B fragments (passed as i32)
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32, // accumulator inputs
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> { // %meta feeds sparseMetadata[...]
+ %sel = llvm.mlir.constant(0 : i32) : i32 // selector operand; shows up as the trailing "i32 0" of the expected call
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k16.row.col.bf16(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<bf16>,
+ multiplicandBPtxType = #nvvm.mma_type<bf16>, // expected name ends in ".bf16"
+ orderedMetadata, // expected name contains ".sp.ordered.metadata."
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>} // expected name contains ".m16n8k16."
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_bf16_f32
+llvm.func @nvvm_mma_sp_ordered_m16n8k32_bf16_f32( // Translation test: ordered-metadata sparse MMA m16n8k32, bf16 operands, f32 accumulators.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32, // A fragments (passed as i32)
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32, // B fragments (passed as i32)
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32, // accumulator inputs
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> { // %meta feeds sparseMetadata[...]
+ %sel = llvm.mlir.constant(0 : i32) : i32 // selector operand; shows up as the trailing "i32 0" of the expected call
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k32.row.col.bf16(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<bf16>,
+ multiplicandBPtxType = #nvvm.mma_type<bf16>, // expected name ends in ".bf16"
+ orderedMetadata, // expected name contains ".sp.ordered.metadata."
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>} // expected name contains ".m16n8k32."
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k8_tf32_f32
+llvm.func @nvvm_mma_sp_ordered_m16n8k8_tf32_f32( // Translation test: ordered-metadata sparse MMA m16n8k8, tf32 operands, f32 accumulators.
+ %a0: i32, %a1: i32, // A fragments (passed as i32)
+ %b0: i32, %b1: i32, // B fragments (passed as i32)
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32, // accumulator inputs
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> { // %meta feeds sparseMetadata[...]
+ %sel = llvm.mlir.constant(0 : i32) : i32 // selector operand; shows up as the trailing "i32 0" of the expected call
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k8.row.col.tf32(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<tf32>,
+ multiplicandBPtxType = #nvvm.mma_type<tf32>, // expected name ends in ".tf32"
+ orderedMetadata, // expected name contains ".sp.ordered.metadata."
+ shape = #nvvm.shape<m = 16, n = 8, k = 8>} // expected name contains ".m16n8k8."
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_tf32_f32
+llvm.func @nvvm_mma_sp_ordered_m16n8k16_tf32_f32( // Case: tf32 A/B, f32 C and result, shape m16n8k16, ordered metadata.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k16.row.col.tf32(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<tf32>,
+ multiplicandBPtxType = #nvvm.mma_type<tf32>,
+ orderedMetadata, // Expected intrinsic name contains "ordered.metadata".
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_s8_s32
+llvm.func @nvvm_mma_sp_ordered_m16n8k32_s8_s32( // Case: s8 A/B, i32 C and result, shape m16n8k32, wrapped overflow, ordered metadata.
+ %a0: i32, %a1: i32,
+ %b0: i32, %b1: i32,
+ %c0: i32, %c1: i32, %c2: i32, %c3: i32,
+ %meta: i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k32.row.col.s8(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ multiplicandAPtxType = #nvvm.mma_type<s8>,
+ multiplicandBPtxType = #nvvm.mma_type<s8>,
+ orderedMetadata, // Expected intrinsic name contains "ordered.metadata".
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_s8_s32_satfinite
+llvm.func @nvvm_mma_sp_ordered_m16n8k32_s8_s32_satfinite( // Case: s8 A/B, i32 C and result, shape m16n8k32, satfinite overflow, ordered metadata.
+ %a0: i32, %a1: i32,
+ %b0: i32, %b1: i32,
+ %c0: i32, %c1: i32, %c2: i32, %c3: i32,
+ %meta: i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k32.row.col.satfinite.s8(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, // Expected intrinsic name contains "satfinite".
+ multiplicandAPtxType = #nvvm.mma_type<s8>,
+ multiplicandBPtxType = #nvvm.mma_type<s8>,
+ orderedMetadata, // Expected intrinsic name contains "ordered.metadata".
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_s8_s32
+llvm.func @nvvm_mma_sp_ordered_m16n8k64_s8_s32( // Case: s8 A/B, i32 C and result, shape m16n8k64, wrapped overflow, ordered metadata.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: i32, %c1: i32, %c2: i32, %c3: i32,
+ %meta: i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k64.row.col.s8(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ multiplicandAPtxType = #nvvm.mma_type<s8>,
+ multiplicandBPtxType = #nvvm.mma_type<s8>,
+ orderedMetadata, // Expected intrinsic name contains "ordered.metadata".
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_u8_s32
+llvm.func @nvvm_mma_sp_ordered_m16n8k32_u8_s32( // Case: u8 A/B, i32 C and result, shape m16n8k32, wrapped overflow, ordered metadata.
+ %a0: i32, %a1: i32,
+ %b0: i32, %b1: i32,
+ %c0: i32, %c1: i32, %c2: i32, %c3: i32,
+ %meta: i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k32.row.col.u8(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ multiplicandAPtxType = #nvvm.mma_type<u8>,
+ multiplicandBPtxType = #nvvm.mma_type<u8>,
+ orderedMetadata, // Expected intrinsic name contains "ordered.metadata".
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_u8_s32
+llvm.func @nvvm_mma_sp_ordered_m16n8k64_u8_s32( // Case: u8 A/B, i32 C and result, shape m16n8k64, wrapped overflow, ordered metadata.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: i32, %c1: i32, %c2: i32, %c3: i32,
+ %meta: i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k64.row.col.u8(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ multiplicandAPtxType = #nvvm.mma_type<u8>,
+ multiplicandBPtxType = #nvvm.mma_type<u8>,
+ orderedMetadata, // Expected intrinsic name contains "ordered.metadata".
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_s4_s32
+llvm.func @nvvm_mma_sp_ordered_m16n8k64_s4_s32( // Case: s4 A/B, i32 C and result, shape m16n8k64, wrapped overflow, ordered metadata.
+ %a0: i32, %a1: i32,
+ %b0: i32, %b1: i32,
+ %c0: i32, %c1: i32, %c2: i32, %c3: i32,
+ %meta: i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k64.row.col.s4(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ multiplicandAPtxType = #nvvm.mma_type<s4>,
+ multiplicandBPtxType = #nvvm.mma_type<s4>,
+ orderedMetadata, // Expected intrinsic name contains "ordered.metadata".
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k128_s4_s32
+llvm.func @nvvm_mma_sp_ordered_m16n8k128_s4_s32( // Case: s4 A/B, i32 C and result, shape m16n8k128, wrapped overflow, ordered metadata.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: i32, %c1: i32, %c2: i32, %c3: i32,
+ %meta: i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k128.row.col.s4(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ multiplicandAPtxType = #nvvm.mma_type<s4>,
+ multiplicandBPtxType = #nvvm.mma_type<s4>,
+ orderedMetadata, // Expected intrinsic name contains "ordered.metadata".
+ shape = #nvvm.shape<m = 16, n = 8, k = 128>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_u4_s32
+llvm.func @nvvm_mma_sp_ordered_m16n8k64_u4_s32( // Case: u4 A/B, i32 C and result, shape m16n8k64, wrapped overflow, ordered metadata.
+ %a0: i32, %a1: i32,
+ %b0: i32, %b1: i32,
+ %c0: i32, %c1: i32, %c2: i32, %c3: i32,
+ %meta: i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k64.row.col.u4(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ multiplicandAPtxType = #nvvm.mma_type<u4>,
+ multiplicandBPtxType = #nvvm.mma_type<u4>,
+ orderedMetadata, // Expected intrinsic name contains "ordered.metadata".
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k128_u4_s32
+llvm.func @nvvm_mma_sp_ordered_m16n8k128_u4_s32( // Case: u4 A/B, i32 C and result, shape m16n8k128, wrapped overflow, ordered metadata.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: i32, %c1: i32, %c2: i32, %c3: i32,
+ %meta: i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k128.row.col.u4(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ multiplicandAPtxType = #nvvm.mma_type<u4>,
+ multiplicandBPtxType = #nvvm.mma_type<u4>,
+ orderedMetadata, // Expected intrinsic name contains "ordered.metadata".
+ shape = #nvvm.shape<m = 16, n = 8, k = 128>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e4m3_f16
+llvm.func @nvvm_mma_sp_ordered_m16n8k64_e4m3_f16( // Case: e4m3 A/B, vector<2xf16> C and result, shape m16n8k64, ordered metadata.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: vector<2xf16>, %c1: vector<2xf16>,
+ %meta: i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k64.row.col.kind.f8f6f4.f16.e4m3.e4m3.f16(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ orderedMetadata, // Expected intrinsic name contains "ordered.metadata".
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e4m3_f32
+llvm.func @nvvm_mma_sp_ordered_m16n8k64_e4m3_f32( // Case: e4m3 A/B, f32 C and result, shape m16n8k64, ordered metadata.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k64.row.col.f32.e4m3.e4m3.f32(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ orderedMetadata, // Expected intrinsic name contains "ordered.metadata".
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e5m2_f16
+llvm.func @nvvm_mma_sp_ordered_m16n8k64_e5m2_f16( // Case: e5m2 A/B, vector<2xf16> C and result, shape m16n8k64, ordered metadata.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: vector<2xf16>, %c1: vector<2xf16>,
+ %meta: i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k64.row.col.kind.f8f6f4.f16.e5m2.e5m2.f16(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ orderedMetadata, // Expected intrinsic name contains "ordered.metadata".
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e5m2_f32
+llvm.func @nvvm_mma_sp_ordered_m16n8k64_e5m2_f32( // Case: e5m2 A/B, f32 C and result, shape m16n8k64, ordered metadata.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.m16n8k64.row.col.f32.e5m2.e5m2.f32(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ orderedMetadata, // Expected intrinsic name contains "ordered.metadata".
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/mma-sp.mlir b/mlir/test/Target/LLVMIR/nvvm/mma-sp.mlir
new file mode 100644
index 0000000000000..9e238f0491627
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/mma-sp.mlir
@@ -0,0 +1,325 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_f16_f16
+llvm.func @nvvm_mma_sp_m16n8k16_f16_f16( // Case: vector<2xf16> A/B/C and result, shape m16n8k16; no multiplicand type attributes are set.
+ %a0: vector<2xf16>, %a1: vector<2xf16>,
+ %b0: vector<2xf16>, %b1: vector<2xf16>,
+ %c0: vector<2xf16>, %c1: vector<2xf16>,
+ %meta: i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
+ // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.sp.m16n8k16.row.col.f16.f16(<2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ %res = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_f16_f32
+llvm.func @nvvm_mma_sp_m16n8k16_f16_f32( // Case: vector<2xf16> A/B, f32 C and result, shape m16n8k16; no multiplicand type attributes are set.
+ %a0: vector<2xf16>, %a1: vector<2xf16>,
+ %b0: vector<2xf16>, %b1: vector<2xf16>,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.m16n8k16.row.col.f32.f32(<2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ %res = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_f16_f16
+llvm.func @nvvm_mma_sp_m16n8k32_f16_f16( // Case: vector<2xf16> A/B/C and result, shape m16n8k32; no multiplicand type attributes are set.
+ %a0: vector<2xf16>, %a1: vector<2xf16>, %a2: vector<2xf16>, %a3: vector<2xf16>,
+ %b0: vector<2xf16>, %b1: vector<2xf16>, %b2: vector<2xf16>, %b3: vector<2xf16>,
+ %c0: vector<2xf16>, %c1: vector<2xf16>,
+ %meta: i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
+ // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.sp.m16n8k32.row.col.f16.f16(<2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_f16_f32
+llvm.func @nvvm_mma_sp_m16n8k32_f16_f32( // Case: vector<2xf16> A/B, f32 C and result, shape m16n8k32; no multiplicand type attributes are set.
+ %a0: vector<2xf16>, %a1: vector<2xf16>, %a2: vector<2xf16>, %a3: vector<2xf16>,
+ %b0: vector<2xf16>, %b1: vector<2xf16>, %b2: vector<2xf16>, %b3: vector<2xf16>,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.m16n8k32.row.col.f32.f32(<2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, <2 x half> {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_bf16_f32
+llvm.func @nvvm_mma_sp_m16n8k16_bf16_f32( // Case: bf16 A/B, f32 C and result, shape m16n8k16.
+ %a0: i32, %a1: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.m16n8k16.row.col.bf16(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ %res = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<bf16>,
+ multiplicandBPtxType = #nvvm.mma_type<bf16>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_bf16_f32
+llvm.func @nvvm_mma_sp_m16n8k32_bf16_f32( // Case: bf16 A/B, f32 C and result, shape m16n8k32.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.m16n8k32.row.col.bf16(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<bf16>,
+ multiplicandBPtxType = #nvvm.mma_type<bf16>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k8_tf32_f32
+llvm.func @nvvm_mma_sp_m16n8k8_tf32_f32( // Case: tf32 A/B, f32 C and result, shape m16n8k8.
+ %a0: i32, %a1: i32,
+ %b0: i32, %b1: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.m16n8k8.row.col.tf32(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ %res = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<tf32>,
+ multiplicandBPtxType = #nvvm.mma_type<tf32>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 8>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_tf32_f32
+llvm.func @nvvm_mma_sp_m16n8k16_tf32_f32( // Case: tf32 A/B, f32 C and result, shape m16n8k16.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.m16n8k16.row.col.tf32(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<tf32>,
+ multiplicandBPtxType = #nvvm.mma_type<tf32>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_s8_s32
+llvm.func @nvvm_mma_sp_m16n8k32_s8_s32( // Case: s8 A/B, i32 C and result, shape m16n8k32, wrapped overflow.
+ %a0: i32, %a1: i32,
+ %b0: i32, %b1: i32,
+ %c0: i32, %c1: i32, %c2: i32, %c3: i32,
+ %meta: i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.sp.m16n8k32.row.col.s8(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ %res = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ multiplicandAPtxType = #nvvm.mma_type<s8>,
+ multiplicandBPtxType = #nvvm.mma_type<s8>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_s8_s32_satfinite
+llvm.func @nvvm_mma_sp_m16n8k32_s8_s32_satfinite( // Case: s8 A/B, i32 C and result, shape m16n8k32, satfinite overflow.
+ %a0: i32, %a1: i32,
+ %b0: i32, %b1: i32,
+ %c0: i32, %c1: i32, %c2: i32, %c3: i32,
+ %meta: i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.sp.m16n8k32.row.col.satfinite.s8(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ %res = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, // Expected intrinsic name contains "satfinite".
+ multiplicandAPtxType = #nvvm.mma_type<s8>,
+ multiplicandBPtxType = #nvvm.mma_type<s8>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_s8_s32
+llvm.func @nvvm_mma_sp_m16n8k64_s8_s32( // Case: s8 A/B, i32 C and result, shape m16n8k64, wrapped overflow.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: i32, %c1: i32, %c2: i32, %c3: i32,
+ %meta: i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.sp.m16n8k64.row.col.s8(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ multiplicandAPtxType = #nvvm.mma_type<s8>,
+ multiplicandBPtxType = #nvvm.mma_type<s8>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_u8_s32
+llvm.func @nvvm_mma_sp_m16n8k32_u8_s32( // Case: u8 A/B, i32 C and result, shape m16n8k32, wrapped overflow.
+ %a0: i32, %a1: i32,
+ %b0: i32, %b1: i32,
+ %c0: i32, %c1: i32, %c2: i32, %c3: i32,
+ %meta: i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.sp.m16n8k32.row.col.u8(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ %res = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ multiplicandAPtxType = #nvvm.mma_type<u8>,
+ multiplicandBPtxType = #nvvm.mma_type<u8>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_u8_s32
+llvm.func @nvvm_mma_sp_m16n8k64_u8_s32( // Case: u8 A/B, i32 C and result, shape m16n8k64, wrapped overflow.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: i32, %c1: i32, %c2: i32, %c3: i32,
+ %meta: i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.sp.m16n8k64.row.col.u8(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ multiplicandAPtxType = #nvvm.mma_type<u8>,
+ multiplicandBPtxType = #nvvm.mma_type<u8>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_s4_s32
+llvm.func @nvvm_mma_sp_m16n8k64_s4_s32( // Case: s4 A/B, i32 C and result, shape m16n8k64, wrapped overflow.
+ %a0: i32, %a1: i32,
+ %b0: i32, %b1: i32,
+ %c0: i32, %c1: i32, %c2: i32, %c3: i32,
+ %meta: i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.sp.m16n8k64.row.col.s4(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ %res = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ multiplicandAPtxType = #nvvm.mma_type<s4>,
+ multiplicandBPtxType = #nvvm.mma_type<s4>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k128_s4_s32
+llvm.func @nvvm_mma_sp_m16n8k128_s4_s32( // Case: s4 A/B, i32 C and result, shape m16n8k128, wrapped overflow.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: i32, %c1: i32, %c2: i32, %c3: i32,
+ %meta: i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.sp.m16n8k128.row.col.s4(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ multiplicandAPtxType = #nvvm.mma_type<s4>,
+ multiplicandBPtxType = #nvvm.mma_type<s4>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 128>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_u4_s32
+llvm.func @nvvm_mma_sp_m16n8k64_u4_s32( // Case: u4 A/B, i32 C and result, shape m16n8k64, wrapped overflow.
+ %a0: i32, %a1: i32,
+ %b0: i32, %b1: i32,
+ %c0: i32, %c1: i32, %c2: i32, %c3: i32,
+ %meta: i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.sp.m16n8k64.row.col.u4(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ %res = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ multiplicandAPtxType = #nvvm.mma_type<u4>,
+ multiplicandBPtxType = #nvvm.mma_type<u4>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k128_u4_s32
+llvm.func @nvvm_mma_sp_m16n8k128_u4_s32( // Case: u4 A/B, i32 C and result, shape m16n8k128, wrapped overflow.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: i32, %c1: i32, %c2: i32, %c3: i32,
+ %meta: i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.sp.m16n8k128.row.col.u4(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+ multiplicandAPtxType = #nvvm.mma_type<u4>,
+ multiplicandBPtxType = #nvvm.mma_type<u4>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 128>}
+ : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e4m3_f32
+llvm.func @nvvm_mma_sp_m16n8k64_e4m3_f32( // Case: e4m3 A/B, f32 C and result, shape m16n8k64.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.m16n8k64.row.col.f32.e4m3.e4m3.f32(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e5m2_f32
+llvm.func @nvvm_mma_sp_m16n8k64_e5m2_f32( // Case: e5m2 A/B, f32 C and result, shape m16n8k64.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.m16n8k64.row.col.f32.e5m2.e5m2.f32(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0)
+ %sel = llvm.mlir.constant(0 : i32) : i32 // Selector is the constant 0, hence the literal "i32 0" in the expected call.
+ %res = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta] selector[%sel]
+ {multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> // One type per operand group, in A, B, C order.
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/mma-sparse-blockscale.mlir b/mlir/test/Target/LLVMIR/nvvm/mma-sparse-blockscale.mlir
new file mode 100644
index 0000000000000..27d06b79a1796
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/mma-sparse-blockscale.mlir
@@ -0,0 +1,842 @@
+// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m1
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m1( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e2m1 and B type e2m1.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e2m1.e2m1.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m3
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m3( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e2m1 and B type e2m3.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e2m1.e2m3.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e3m2
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e3m2( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e2m1 and B type e3m2.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e2m1.e3m2.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e4m3
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e4m3( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e2m1 and B type e4m3.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e2m1.e4m3.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e5m2
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e5m2( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e2m1 and B type e5m2.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e2m1.e5m2.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m1
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m1( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e2m3 and B type e2m1.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e2m3.e2m1.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m3
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m3( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e2m3 and B type e2m3.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e2m3.e2m3.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e3m2
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e3m2( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e2m3 and B type e3m2.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e2m3.e3m2.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e4m3
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e4m3( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e2m3 and B type e4m3.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e2m3.e4m3.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e5m2
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e5m2( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e2m3 and B type e5m2.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e2m3.e5m2.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m1
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m1( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e3m2 and B type e2m1.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e3m2.e2m1.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m3
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m3( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e3m2 and B type e2m3.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e3m2.e2m3.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e3m2
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e3m2( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e3m2 and B type e3m2.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e3m2.e3m2.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e4m3
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e4m3( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e3m2 and B type e4m3.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e3m2.e4m3.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e5m2
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e5m2( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e3m2 and B type e5m2.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e3m2.e5m2.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e3m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m1
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m1( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e4m3 and B type e2m1.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e4m3.e2m1.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m3
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m3( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e4m3 and B type e2m3.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e4m3.e2m3.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e3m2
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e3m2( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e4m3 and B type e3m2.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e4m3.e3m2.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e4m3
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e4m3( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e4m3 and B type e4m3.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e4m3.e4m3.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e5m2
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e5m2( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e4m3 and B type e5m2.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e4m3.e5m2.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m1
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m1( // Test case: nvvm.mma.sp.block_scale (kind mxf8f6f4, shape m16n8k64) with A type e5m2 and B type e2m1.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32, // passed to the op as sparseMetadata
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // constant selector; the expected call below matches it as the literal `i32 0`
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e5m2.e2m1.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ orderedMetadata, // presumably selects the `ordered.metadata` intrinsic variant expected above - confirm against the op's lowering
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m3
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m3( // Test case: nvvm.mma.sp.block_scale with kind mxf8f6f4, A = e5m2, B = e2m3, scale vector x1, scale format ue8m0, shape m16n8k64, orderedMetadata set; the expected output is a call to the intrinsic whose name encodes the same attributes.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // The selector is the constant 0; the expected call below spells it as the literal "i32 0" instead of a value pattern.
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e5m2.e2m3.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m3>,
+ orderedMetadata,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e3m2
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e3m2( // Test case: nvvm.mma.sp.block_scale with kind mxf8f6f4, A = e5m2, B = e3m2, scale vector x1, scale format ue8m0, shape m16n8k64, orderedMetadata set; the expected output is a call to the intrinsic whose name encodes the same attributes.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // The selector is the constant 0; the expected call below spells it as the literal "i32 0" instead of a value pattern.
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e5m2.e3m2.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e3m2>,
+ orderedMetadata,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e4m3
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e4m3( // Test case: nvvm.mma.sp.block_scale with kind mxf8f6f4, A = e5m2, B = e4m3, scale vector x1, scale format ue8m0, shape m16n8k64, orderedMetadata set; the expected output is a call to the intrinsic whose name encodes the same attributes.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // The selector is the constant 0; the expected call below spells it as the literal "i32 0" instead of a value pattern.
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e5m2.e4m3.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+ orderedMetadata,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e5m2
+llvm.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e5m2( // Test case: nvvm.mma.sp.block_scale with kind mxf8f6f4, A = e5m2, B = e5m2, scale vector x1, scale format ue8m0, shape m16n8k64, orderedMetadata set; the expected output is a call to the intrinsic whose name encodes the same attributes.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // The selector is the constant 0; the expected call below spells it as the literal "i32 0" instead of a value pattern.
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k64.row.col.mxf8f6f4.scale.1x.f32.e5m2.e5m2.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf8f6f4>,
+ multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+ multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+ orderedMetadata,
+ scaleVecSize = #nvvm.scale_vec_size<x1>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf4_sp_blockscale_mma
+llvm.func @nvvm_mxf4_sp_blockscale_mma( // Test case: nvvm.mma.sp.block_scale with kind mxf4, A = B = e2m1, scale vector x2, scale format ue8m0, shape m16n8k128, orderedMetadata set; the expected output is a call to the intrinsic whose name encodes the same attributes.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // The selector is the constant 0; the expected call below spells it as the literal "i32 0" instead of a value pattern.
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k128.row.col.mxf4.scale.2x.f32.e2m1.e2m1.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ orderedMetadata,
+ scaleVecSize = #nvvm.scale_vec_size<x2>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 128>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf4nvf4_sp_blockscale_mma_ue8m0
+llvm.func @nvvm_mxf4nvf4_sp_blockscale_mma_ue8m0( // Test case: nvvm.mma.sp.block_scale with kind mxf4nvf4, A = B = e2m1, scale vector x2, scale format ue8m0, shape m16n8k128, orderedMetadata set; the expected output is a call to the intrinsic whose name encodes the same attributes.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // The selector is the constant 0; the expected call below spells it as the literal "i32 0" instead of a value pattern.
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k128.row.col.mxf4nvf4.scale.2x.f32.e2m1.e2m1.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf4nvf4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ orderedMetadata,
+ scaleVecSize = #nvvm.scale_vec_size<x2>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 128>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf4nvf4_sp_blockscale_mma_ue4m3
+llvm.func @nvvm_mxf4nvf4_sp_blockscale_mma_ue4m3( // Test case: nvvm.mma.sp.block_scale with kind mxf4nvf4, A = B = e2m1, scale vector x4, scale format ue4m3, shape m16n8k128, orderedMetadata set; the expected output is a call to the intrinsic whose name encodes the same attributes.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // The selector is the constant 0; the expected call below spells it as the literal "i32 0" instead of a value pattern.
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k128.row.col.mxf4nvf4.scale.4x.f32.e2m1.e2m1.f32.ue4m3(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue4m3>,
+ kind = #nvvm.block_scale_kind<mxf4nvf4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ orderedMetadata,
+ scaleVecSize = #nvvm.scale_vec_size<x4>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 128>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// CHECK-LABEL: @nvvm_mxf4nvf4_sp_blockscale_mma_ue8m0_x4
+llvm.func @nvvm_mxf4nvf4_sp_blockscale_mma_ue8m0_x4( // Test case: nvvm.mma.sp.block_scale with kind mxf4nvf4, A = B = e2m1, scale vector x4, scale format ue8m0, shape m16n8k128, orderedMetadata set; the expected output is a call to the intrinsic whose name encodes the same attributes.
+ %a0: i32, %a1: i32, %a2: i32, %a3: i32,
+ %b0: i32, %b1: i32, %b2: i32, %b3: i32,
+ %c0: f32, %c1: f32, %c2: f32, %c3: f32,
+ %meta: i32,
+ %scaleA0: i32, %scaleA1: i16, %scaleA2: i16,
+ %scaleB0: i32, %scaleB1: i16, %scaleB2: i16) -> !llvm.struct<(f32, f32, f32, f32)> {
+ %sel = llvm.mlir.constant(0 : i32) : i32 // The selector is the constant 0; the expected call below spells it as the literal "i32 0" instead of a value pattern.
+ // CHECK: call { float, float, float, float } @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k128.row.col.mxf4nvf4.scale.4x.f32.e2m1.e2m1.f32.ue8m0(i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, float {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 0, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}}, i32 {{%[0-9]+}}, i16 {{%[0-9]+}}, i16 {{%[0-9]+}})
+ %res = nvvm.mma.sp.block_scale
+ A[%a0, %a1, %a2, %a3]
+ B[%b0, %b1, %b2, %b3]
+ C[%c0, %c1, %c2, %c3]
+ sparseMetadata[%meta]
+ selector[%sel]
+ scaleA[%scaleA0, %scaleA1, %scaleA2]
+ scaleB[%scaleB0, %scaleB1, %scaleB2]
+ {blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf4nvf4>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ orderedMetadata,
+ scaleVecSize = #nvvm.scale_vec_size<x4>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 128>}
+ : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %res : !llvm.struct<(f32, f32, f32, f32)>
+}
More information about the Mlir-commits
mailing list