[Mlir-commits] [mlir] [MLIR][NVVM] Support generating all the ldmatrix intrinsics from NVVM ops (PR #148783)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 14 23:30:17 PDT 2025
https://github.com/Pecco-314 created https://github.com/llvm/llvm-project/pull/148783
Previously, the NVVM dialect's ldmatrix operation could only generate a limited subset of the available NVVM ldmatrix intrinsics. In particular, the intrinsics for the new ldmatrix variants introduced with Blackwell were not accessible through the NVVM ops. This commit extends the ldmatrix operation with `shape` and `elttype` attributes so that it can generate all available ldmatrix intrinsics.
>From fac30998dfd72b1481210db77a85aa4c788f177a Mon Sep 17 00:00:00 2001
From: Gao Yanfeng <gaoyanfeng at linux.alibaba.com>
Date: Tue, 15 Jul 2025 14:17:13 +0800
Subject: [PATCH] Support generating all the ldmatrix intrinsics from NVVM ops
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 36 ++++++-
.../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 3 +-
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 12 ++-
.../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 101 ++++++++++++++----
.../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 4 +-
mlir/test/Dialect/LLVMIR/invalid.mlir | 17 ++-
mlir/test/Dialect/LLVMIR/nvvm.mlir | 11 --
mlir/test/Target/LLVMIR/nvvmir.mlir | 44 ++++++--
8 files changed, 175 insertions(+), 53 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 45a8904375e2b..cfb21e8331d05 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1990,6 +1990,35 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
let hasVerifier = 1;
}
+def LdStMatrixShapeM8N8 : I32EnumAttrCase<"M8N8", 0, "m8n8">;
+def LdStMatrixShapeM8N16 : I32EnumAttrCase<"M8N16", 1, "m8n16">;
+def LdStMatrixShapeM16N8 : I32EnumAttrCase<"M16N8", 2, "m16n8">;
+def LdStMatrixShapeM16N16 : I32EnumAttrCase<"M16N16", 3, "m16n16">;
+
+def LdStMatrixShape : I32EnumAttr<"LdStMatrixShape", "Matrix shape for ldmatrix and stmatrix",
+ [LdStMatrixShapeM8N8, LdStMatrixShapeM8N16, LdStMatrixShapeM16N8, LdStMatrixShapeM16N16]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def LdStMatrixShapeAttr : EnumAttr<NVVM_Dialect, LdStMatrixShape, "ld_st_matrix_shape"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def LdStMatrixEltTypeB16 : I32EnumAttrCase<"B16", 0, "b16">;
+def LdStMatrixEltTypeB8 : I32EnumAttrCase<"B8", 1, "b8">;
+def LdStMatrixEltTypeB8X16_B6X16_P32 : I32EnumAttrCase<"B8X16_B6X16_P32", 2, "b8x16.b6x16_p32">;
+def LdStMatrixEltTypeB8X16_B4X16_P64 : I32EnumAttrCase<"B8X16_B4X16_P64", 3, "b8x16.b4x16_p64">;
+
+def LdStMatrixEltType : I32EnumAttr<"LdStMatrixEltType", "Element type for ldmatrix and stmatrix",
+ [LdStMatrixEltTypeB16, LdStMatrixEltTypeB8,
+ LdStMatrixEltTypeB8X16_B6X16_P32, LdStMatrixEltTypeB8X16_B4X16_P64]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def LdStMatrixEltTypeAttr : EnumAttr<NVVM_Dialect, LdStMatrixEltType, "ld_st_matrix_elttype"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
Arguments<(ins LLVM_PointerShared:$ptr,
Variadic<I32>:$sources,
@@ -2021,13 +2050,16 @@ def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
Results<(outs AnyType:$res)>,
- Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {
+ Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr: $num,
+ MMALayoutAttr: $layout,
+ LdStMatrixShapeAttr: $shape,
+ LdStMatrixEltTypeAttr: $elttype)> {
let summary = "cooperative matrix load";
string llvmBuilder = [{
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
- auto intId = getLdMatrixIntrinsicId($layout, $num);
+ auto intId = getLdMatrixIntrinsicId($layout, $num, $shape, $elttype);
$res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()});
}];
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 80b3d85488495..470dc2512a9ad 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -289,7 +289,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
ldMatrixResultType, srcPtr,
/*num=*/op.getNumTiles(),
/*layout=*/op.getTranspose() ? NVVM::MMALayout::col
- : NVVM::MMALayout::row);
+ : NVVM::MMALayout::row,
+ NVVM::LdStMatrixShape::M8N8, NVVM::LdStMatrixEltType::B16);
// The ldmatrix operation returns either a single i32 value or a struct of
// i32 values. Here we unpack those values and cast them back to their
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 6e29b129e8835..93c155b67fb5c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -806,14 +806,18 @@ LogicalResult NVVM::LdMatrixOp::verify() {
return emitOpError("expected num attribute to be 1, 2 or 4");
Type i32 = IntegerType::get(getContext(), 32);
- if (getNum() == 1 && getType() != i32)
+ uint32_t num = getNum();
+ if (getShape() == LdStMatrixShape::M16N16) {
+ num *= 2;
+ }
+ if (num == 1 && getType() != i32)
return emitOpError("expected destination type is i32");
- if (getNum() == 2 || getNum() == 4) {
+ if (num == 2 || num == 4) {
Type dstType = LLVM::LLVMStructType::getLiteral(
- getContext(), SmallVector<Type>(getNum(), i32));
+ getContext(), SmallVector<Type>(num, i32));
if (getType() != dstType)
return emitOpError("expected destination type is a structure of ")
- << getNum() << " elements of type i32";
+ << num << " elements of type i32";
}
return success();
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index eecca64c4bf81..5d13933519c54 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -134,33 +134,90 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
llvm_unreachable("unsupported vote kind");
}
-/// Return the intrinsic ID associated with ldmatrix for the given paramters.
-static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
- int32_t num) {
+static llvm::Intrinsic::ID
+getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
+ NVVM::LdStMatrixShape shape,
+ NVVM::LdStMatrixEltType elttype) {
if (layout == NVVM::MMALayout::row) {
- switch (num) {
- case 1:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
- case 2:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
- case 4:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
- default:
- llvm_unreachable("unsupported number of matrix");
+ if (shape == NVVM::LdStMatrixShape::M8N8 &&
+ elttype == NVVM::LdStMatrixEltType::B16) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
+ case 2:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
+ case 4:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
+ }
+ } else if (shape == NVVM::LdStMatrixShape::M8N16 &&
+ elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32;
+ case 2:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32;
+ case 4:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
+ }
+ } else if (shape == NVVM::LdStMatrixShape::M8N16 &&
+ elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64;
+ case 2:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64;
+ case 4:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64;
+ }
}
-
} else {
- switch (num) {
- case 1:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
- case 2:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
- case 4:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
- default:
- llvm_unreachable("unsupported number of matrix");
+ if (shape == NVVM::LdStMatrixShape::M8N8 &&
+ elttype == NVVM::LdStMatrixEltType::B16) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
+ case 2:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
+ case 4:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
+ }
+ } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+ elttype == NVVM::LdStMatrixEltType::B8) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
+ case 2:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
+ }
+ } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+ elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32;
+ case 2:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
+ }
+ } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+ elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64;
+ case 2:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64;
+ }
}
}
+ llvm_unreachable("unsupported matrix configuration");
}
/// Return the intrinsic ID associated with st.bulk for the given address type.
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index d0bc806e0aa8c..75a556f471373 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -159,7 +159,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec
// CHECK-LABEL: @ldmatrix_x4
func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> {
%c0 = arith.constant 0 : index
- // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32)
+ // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape<m8n8>} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)>
%a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
// CHECK: llvm.extractvalue
// CHECK: llvm.bitcast
@@ -179,7 +179,7 @@ func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> {
// CHECK-LABEL: @ldmatrix_x1
func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> {
%c0 = arith.constant 0 : index
- // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} {{.*}} -> i32
+ // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape<m8n8>} : {{.*}} -> i32
%a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>
// CHECK: llvm.bitcast
// CHECK: llvm.insertvalue
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index bd1106e304c60..f9def0877d71a 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1116,7 +1116,7 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<
llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
// expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
- %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr) -> i32
+ %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr) -> i32
llvm.return
}
@@ -1124,7 +1124,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
// expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}}
- %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
+ %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
llvm.return
}
@@ -1132,7 +1132,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
// expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}}
- %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
+ %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
llvm.return
}
@@ -1140,10 +1140,19 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
// expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
- %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
llvm.return
}
+// -----
+
+llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
+ // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}}
+ %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> i32
+ llvm.return
+}
+
+
// -----
llvm.func @caller() {
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index c7fa41c98ac92..6a4edd0d22a08 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -385,17 +385,6 @@ llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
llvm.return
}
-// CHECK-LABEL: llvm.func @ld_matrix
-llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
- // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} : (!llvm.ptr<3>) -> i32
- %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
- // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 2 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
- %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
- // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
- %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
- llvm.return
-}
-
// CHECK-LABEL: llvm.func @redux_sync
llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 {
// CHECK: nvvm.redux.sync add %{{.*}}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index f86a04186f512..89429a762db92 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -559,17 +559,47 @@ llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
// CHECK-LABEL: @ld_matrix
llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})
- %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
+ %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}})
- %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}})
- %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
- // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}})
- %l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> i32
+ %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+ // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}})
+ %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}})
- %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}})
- %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+ %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+ // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+ %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> i32
+ // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+ %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+ %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>,elttype =#nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+ // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+ %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
+ // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+ %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+ %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+ // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8.p3(ptr addrspace(3) %{{.*}})
+ %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8.p3(ptr addrspace(3) %{{.*}})
+ %m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m16n16>,elttype =#nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+ // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+ %m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+ %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+ // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+ %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+ %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m16n16>,elttype =#nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
llvm.return
}
More information about the Mlir-commits
mailing list