[Mlir-commits] [mlir] [MLIR][NVVM] Support generating all the ldmatrix intrinsics from NVVM ops (PR #148783)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 16 00:24:21 PDT 2025
https://github.com/Pecco-314 updated https://github.com/llvm/llvm-project/pull/148783
>From fac30998dfd72b1481210db77a85aa4c788f177a Mon Sep 17 00:00:00 2001
From: Gao Yanfeng <gaoyanfeng at linux.alibaba.com>
Date: Tue, 15 Jul 2025 14:17:13 +0800
Subject: [PATCH 1/2] Support generating all the ldmatrix intrinsics from NVVM
ops
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 36 ++++++-
.../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 3 +-
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 12 ++-
.../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 101 ++++++++++++++----
.../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 4 +-
mlir/test/Dialect/LLVMIR/invalid.mlir | 17 ++-
mlir/test/Dialect/LLVMIR/nvvm.mlir | 11 --
mlir/test/Target/LLVMIR/nvvmir.mlir | 44 ++++++--
8 files changed, 175 insertions(+), 53 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 45a8904375e2b..cfb21e8331d05 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1990,6 +1990,35 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
let hasVerifier = 1;
}
+def LdStMatrixShapeM8N8 : I32EnumAttrCase<"M8N8", 0, "m8n8">;
+def LdStMatrixShapeM8N16 : I32EnumAttrCase<"M8N16", 1, "m8n16">;
+def LdStMatrixShapeM16N8 : I32EnumAttrCase<"M16N8", 2, "m16n8">;
+def LdStMatrixShapeM16N16 : I32EnumAttrCase<"M16N16", 3, "m16n16">;
+
+def LdStMatrixShape : I32EnumAttr<"LdStMatrixShape", "Matrix shape for ldmatrix and stmatrix",
+ [LdStMatrixShapeM8N8, LdStMatrixShapeM8N16, LdStMatrixShapeM16N8, LdStMatrixShapeM16N16]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def LdStMatrixShapeAttr : EnumAttr<NVVM_Dialect, LdStMatrixShape, "ld_st_matrix_shape"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def LdStMatrixEltTypeB16 : I32EnumAttrCase<"B16", 0, "b16">;
+def LdStMatrixEltTypeB8 : I32EnumAttrCase<"B8", 1, "b8">;
+def LdStMatrixEltTypeB8X16_B6X16_P32 : I32EnumAttrCase<"B8X16_B6X16_P32", 2, "b8x16.b6x16_p32">;
+def LdStMatrixEltTypeB8X16_B4X16_P64 : I32EnumAttrCase<"B8X16_B4X16_P64", 3, "b8x16.b4x16_p64">;
+
+def LdStMatrixEltType : I32EnumAttr<"LdStMatrixEltType", "Element type for ldmatrix and stmatrix",
+ [LdStMatrixEltTypeB16, LdStMatrixEltTypeB8,
+ LdStMatrixEltTypeB8X16_B6X16_P32, LdStMatrixEltTypeB8X16_B4X16_P64]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def LdStMatrixEltTypeAttr : EnumAttr<NVVM_Dialect, LdStMatrixEltType, "ld_st_matrix_elttype"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
Arguments<(ins LLVM_PointerShared:$ptr,
Variadic<I32>:$sources,
@@ -2021,13 +2050,16 @@ def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
Results<(outs AnyType:$res)>,
- Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {
+ Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr: $num,
+ MMALayoutAttr: $layout,
+ LdStMatrixShapeAttr: $shape,
+ LdStMatrixEltTypeAttr: $elttype)> {
let summary = "cooperative matrix load";
string llvmBuilder = [{
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
- auto intId = getLdMatrixIntrinsicId($layout, $num);
+ auto intId = getLdMatrixIntrinsicId($layout, $num, $shape, $elttype);
$res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()});
}];
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 80b3d85488495..470dc2512a9ad 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -289,7 +289,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
ldMatrixResultType, srcPtr,
/*num=*/op.getNumTiles(),
/*layout=*/op.getTranspose() ? NVVM::MMALayout::col
- : NVVM::MMALayout::row);
+ : NVVM::MMALayout::row,
+ NVVM::LdStMatrixShape::M8N8, NVVM::LdStMatrixEltType::B16);
// The ldmatrix operation returns either a single i32 value or a struct of
// i32 values. Here we unpack those values and cast them back to their
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 6e29b129e8835..93c155b67fb5c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -806,14 +806,18 @@ LogicalResult NVVM::LdMatrixOp::verify() {
return emitOpError("expected num attribute to be 1, 2 or 4");
Type i32 = IntegerType::get(getContext(), 32);
- if (getNum() == 1 && getType() != i32)
+ uint32_t num = getNum();
+ if (getShape() == LdStMatrixShape::M16N16) {
+ num *= 2;
+ }
+ if (num == 1 && getType() != i32)
return emitOpError("expected destination type is i32");
- if (getNum() == 2 || getNum() == 4) {
+ if (num == 2 || num == 4) {
Type dstType = LLVM::LLVMStructType::getLiteral(
- getContext(), SmallVector<Type>(getNum(), i32));
+ getContext(), SmallVector<Type>(num, i32));
if (getType() != dstType)
return emitOpError("expected destination type is a structure of ")
- << getNum() << " elements of type i32";
+ << num << " elements of type i32";
}
return success();
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index eecca64c4bf81..5d13933519c54 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -134,33 +134,90 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
llvm_unreachable("unsupported vote kind");
}
-/// Return the intrinsic ID associated with ldmatrix for the given paramters.
-static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
- int32_t num) {
+static llvm::Intrinsic::ID
+getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
+ NVVM::LdStMatrixShape shape,
+ NVVM::LdStMatrixEltType elttype) {
if (layout == NVVM::MMALayout::row) {
- switch (num) {
- case 1:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
- case 2:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
- case 4:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
- default:
- llvm_unreachable("unsupported number of matrix");
+ if (shape == NVVM::LdStMatrixShape::M8N8 &&
+ elttype == NVVM::LdStMatrixEltType::B16) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
+ case 2:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
+ case 4:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
+ }
+ } else if (shape == NVVM::LdStMatrixShape::M8N16 &&
+ elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32;
+ case 2:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32;
+ case 4:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
+ }
+ } else if (shape == NVVM::LdStMatrixShape::M8N16 &&
+ elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64;
+ case 2:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64;
+ case 4:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64;
+ }
}
-
} else {
- switch (num) {
- case 1:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
- case 2:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
- case 4:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
- default:
- llvm_unreachable("unsupported number of matrix");
+ if (shape == NVVM::LdStMatrixShape::M8N8 &&
+ elttype == NVVM::LdStMatrixEltType::B16) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
+ case 2:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
+ case 4:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
+ }
+ } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+ elttype == NVVM::LdStMatrixEltType::B8) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
+ case 2:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
+ }
+ } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+ elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32;
+ case 2:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
+ }
+ } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+ elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64;
+ case 2:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64;
+ }
}
}
+ llvm_unreachable("unsupported matrix configuration");
}
/// Return the intrinsic ID associated with st.bulk for the given address type.
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index d0bc806e0aa8c..75a556f471373 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -159,7 +159,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec
// CHECK-LABEL: @ldmatrix_x4
func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> {
%c0 = arith.constant 0 : index
- // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32)
+ // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape<m8n8>} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)>
%a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
// CHECK: llvm.extractvalue
// CHECK: llvm.bitcast
@@ -179,7 +179,7 @@ func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> {
// CHECK-LABEL: @ldmatrix_x1
func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> {
%c0 = arith.constant 0 : index
- // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} {{.*}} -> i32
+ // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape<m8n8>} : {{.*}} -> i32
%a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>
// CHECK: llvm.bitcast
// CHECK: llvm.insertvalue
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index bd1106e304c60..f9def0877d71a 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1116,7 +1116,7 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<
llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
// expected-error at +1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
- %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr) -> i32
+ %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr) -> i32
llvm.return
}
@@ -1124,7 +1124,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
// expected-error at +1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}}
- %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
+ %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
llvm.return
}
@@ -1132,7 +1132,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
// expected-error at +1 {{'nvvm.ldmatrix' op expected destination type is i32}}
- %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
+ %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
llvm.return
}
@@ -1140,10 +1140,19 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
// expected-error at +1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
- %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
llvm.return
}
+// -----
+
+llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
+ // expected-error at +1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}}
+ %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> i32
+ llvm.return
+}
+
+
// -----
llvm.func @caller() {
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index c7fa41c98ac92..6a4edd0d22a08 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -385,17 +385,6 @@ llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
llvm.return
}
-// CHECK-LABEL: llvm.func @ld_matrix
-llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
- // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} : (!llvm.ptr<3>) -> i32
- %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
- // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 2 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
- %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
- // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
- %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
- llvm.return
-}
-
// CHECK-LABEL: llvm.func @redux_sync
llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 {
// CHECK: nvvm.redux.sync add %{{.*}}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index f86a04186f512..89429a762db92 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -559,17 +559,47 @@ llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
// CHECK-LABEL: @ld_matrix
llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})
- %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
+ %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}})
- %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}})
- %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
- // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}})
- %l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> i32
+ %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+ // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}})
+ %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}})
- %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}})
- %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+ %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+ // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+ %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> i32
+ // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+ %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+ %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>,elttype =#nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+ // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+ %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
+ // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+ %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+ %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+ // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8.p3(ptr addrspace(3) %{{.*}})
+ %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8.p3(ptr addrspace(3) %{{.*}})
+ %m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m16n16>,elttype =#nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+ // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+ %m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+ %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+ // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+ %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+ %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m16n16>,elttype =#nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
llvm.return
}
>From 784b7494581a1df03e16ec0faacc85d35c23960d Mon Sep 17 00:00:00 2001
From: Gao Yanfeng <gaoyanfeng at linux.alibaba.com>
Date: Wed, 16 Jul 2025 15:05:37 +0800
Subject: [PATCH 2/2] Modify the arguments of the ldmatrix op
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 16 +++------
.../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 3 +-
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 2 +-
.../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 18 +++++-----
.../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 4 +--
mlir/test/Dialect/LLVMIR/invalid.mlir | 10 +++---
mlir/test/Target/LLVMIR/nvvmir.mlir | 36 +++++++++----------
7 files changed, 41 insertions(+), 48 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index cfb21e8331d05..6af9f4e36be3d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1990,18 +1990,10 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
let hasVerifier = 1;
}
-def LdStMatrixShapeM8N8 : I32EnumAttrCase<"M8N8", 0, "m8n8">;
-def LdStMatrixShapeM8N16 : I32EnumAttrCase<"M8N16", 1, "m8n16">;
-def LdStMatrixShapeM16N8 : I32EnumAttrCase<"M16N8", 2, "m16n8">;
-def LdStMatrixShapeM16N16 : I32EnumAttrCase<"M16N16", 3, "m16n16">;
-
-def LdStMatrixShape : I32EnumAttr<"LdStMatrixShape", "Matrix shape for ldmatrix and stmatrix",
- [LdStMatrixShapeM8N8, LdStMatrixShapeM8N16, LdStMatrixShapeM16N8, LdStMatrixShapeM16N16]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::NVVM";
-}
-def LdStMatrixShapeAttr : EnumAttr<NVVM_Dialect, LdStMatrixShape, "ld_st_matrix_shape"> {
- let assemblyFormat = "`<` $value `>`";
+def LdStMatrixShapeAttr : NVVM_Attr<"LdStMatrixShape", "ld_st_matrix_shape"> {
+ let summary = "Matrix shape for ldmatrix and stmatrix";
+ let parameters = (ins "int":$m, "int":$n);
+ let assemblyFormat = "`<` struct(params) `>`";
}
def LdStMatrixEltTypeB16 : I32EnumAttrCase<"B16", 0, "b16">;
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 470dc2512a9ad..53eeabb16c984 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -285,12 +285,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
Value srcPtr =
getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
adaptor.getSrcMemref(), adaptor.getIndices());
+ auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8);
Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
ldMatrixResultType, srcPtr,
/*num=*/op.getNumTiles(),
/*layout=*/op.getTranspose() ? NVVM::MMALayout::col
: NVVM::MMALayout::row,
- NVVM::LdStMatrixShape::M8N8, NVVM::LdStMatrixEltType::B16);
+ /*shape=*/shape, /*elttype=*/NVVM::LdStMatrixEltType::B16);
// The ldmatrix operation returns either a single i32 value or a struct of
// i32 values. Here we unpack those values and cast them back to their
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 93c155b67fb5c..fbb78ed487448 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -807,7 +807,7 @@ LogicalResult NVVM::LdMatrixOp::verify() {
Type i32 = IntegerType::get(getContext(), 32);
uint32_t num = getNum();
- if (getShape() == LdStMatrixShape::M16N16) {
+ if (getShape().getM() == 16 && getShape().getN() == 16) {
num *= 2;
}
if (num == 1 && getType() != i32)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 5d13933519c54..098336cc035a4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -136,10 +136,10 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
static llvm::Intrinsic::ID
getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
- NVVM::LdStMatrixShape shape,
+ NVVM::LdStMatrixShapeAttr shape,
NVVM::LdStMatrixEltType elttype) {
if (layout == NVVM::MMALayout::row) {
- if (shape == NVVM::LdStMatrixShape::M8N8 &&
+ if (shape.getM() == 8 && shape.getN() == 8 &&
elttype == NVVM::LdStMatrixEltType::B16) {
switch (num) {
case 1:
@@ -149,7 +149,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
case 4:
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
}
- } else if (shape == NVVM::LdStMatrixShape::M8N16 &&
+ } else if (shape.getM() == 8 && shape.getN() == 16 &&
elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
switch (num) {
case 1:
@@ -162,7 +162,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
return llvm::Intrinsic::
nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
}
- } else if (shape == NVVM::LdStMatrixShape::M8N16 &&
+ } else if (shape.getM() == 8 && shape.getN() == 16 &&
elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
switch (num) {
case 1:
@@ -177,7 +177,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
}
}
} else {
- if (shape == NVVM::LdStMatrixShape::M8N8 &&
+ if (shape.getM() == 8 && shape.getN() == 8 &&
elttype == NVVM::LdStMatrixEltType::B16) {
switch (num) {
case 1:
@@ -187,7 +187,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
case 4:
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
}
- } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+ } else if (shape.getM() == 16 && shape.getN() == 16 &&
elttype == NVVM::LdStMatrixEltType::B8) {
switch (num) {
case 1:
@@ -195,7 +195,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
case 2:
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
}
- } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+ } else if (shape.getM() == 16 && shape.getN() == 16 &&
elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
switch (num) {
case 1:
@@ -205,7 +205,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
return llvm::Intrinsic::
nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
}
- } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+ } else if (shape.getM() == 16 && shape.getN() == 16 &&
elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
switch (num) {
case 1:
@@ -217,7 +217,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
}
}
}
- llvm_unreachable("unsupported matrix configuration");
+ llvm_unreachable("unknown ldmatrix kind");
}
/// Return the intrinsic ID associated with st.bulk for the given address type.
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 75a556f471373..2c0ed9b68a3c8 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -159,7 +159,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec
// CHECK-LABEL: @ldmatrix_x4
func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> {
%c0 = arith.constant 0 : index
- // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape<m8n8>} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)>
+ // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)>
%a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
// CHECK: llvm.extractvalue
// CHECK: llvm.bitcast
@@ -179,7 +179,7 @@ func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> {
// CHECK-LABEL: @ldmatrix_x1
func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> {
%c0 = arith.constant 0 : index
- // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape<m8n8>} : {{.*}} -> i32
+ // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : {{.*}} -> i32
%a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>
// CHECK: llvm.bitcast
// CHECK: llvm.insertvalue
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index f9def0877d71a..6c0c942a041c6 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1116,7 +1116,7 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<
llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
// expected-error at +1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
- %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr) -> i32
+ %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr) -> i32
llvm.return
}
@@ -1124,7 +1124,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
// expected-error at +1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}}
- %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
+ %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
llvm.return
}
@@ -1132,7 +1132,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
// expected-error at +1 {{'nvvm.ldmatrix' op expected destination type is i32}}
- %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
+ %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
llvm.return
}
@@ -1140,7 +1140,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
// expected-error at +1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
- %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
llvm.return
}
@@ -1148,7 +1148,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
// expected-error at +1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}}
- %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> i32
+ %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> i32
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 89429a762db92..69d791138ec71 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -559,47 +559,47 @@ llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
// CHECK-LABEL: @ld_matrix
llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})
- %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
+ %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}})
- %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}})
- %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+ %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}})
- %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
+ %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}})
- %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}})
- %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+ %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
- %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> i32
+ %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> i32
// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
- %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
- %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>,elttype =#nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+ %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>,elttype =#nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
- %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
+ %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
- %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
- %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+ %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8.p3(ptr addrspace(3) %{{.*}})
- %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8.p3(ptr addrspace(3) %{{.*}})
- %m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m16n16>,elttype =#nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+ %m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>,elttype =#nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
- %m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ %m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
- %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+ %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
- %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
- %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m16n16>,elttype =#nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+ %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>,elttype =#nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
llvm.return
}
More information about the Mlir-commits
mailing list