[Mlir-commits] [mlir] [MLIR][NVVM] Support generating all the ldmatrix intrinsics from NVVM ops (PR #148783)

Gao Yanfeng llvmlistbot at llvm.org
Tue Jul 22 23:24:12 PDT 2025


https://github.com/Pecco-314 updated https://github.com/llvm/llvm-project/pull/148783

>From fac30998dfd72b1481210db77a85aa4c788f177a Mon Sep 17 00:00:00 2001
From: Gao Yanfeng <gaoyanfeng at linux.alibaba.com>
Date: Tue, 15 Jul 2025 14:17:13 +0800
Subject: [PATCH 1/8] Support generating all the ldmatrix intrinsics from NVVM
 ops

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   |  36 ++++++-
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    |   3 +-
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    |  12 ++-
 .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp  | 101 ++++++++++++++----
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir |   4 +-
 mlir/test/Dialect/LLVMIR/invalid.mlir         |  17 ++-
 mlir/test/Dialect/LLVMIR/nvvm.mlir            |  11 --
 mlir/test/Target/LLVMIR/nvvmir.mlir           |  44 ++++++--
 8 files changed, 175 insertions(+), 53 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 45a8904375e2b..cfb21e8331d05 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1990,6 +1990,35 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
   let hasVerifier = 1;
 }
 
+def LdStMatrixShapeM8N8 : I32EnumAttrCase<"M8N8", 0, "m8n8">;
+def LdStMatrixShapeM8N16 : I32EnumAttrCase<"M8N16", 1, "m8n16">;
+def LdStMatrixShapeM16N8 : I32EnumAttrCase<"M16N8", 2, "m16n8">;
+def LdStMatrixShapeM16N16 : I32EnumAttrCase<"M16N16", 3, "m16n16">;
+
+def LdStMatrixShape : I32EnumAttr<"LdStMatrixShape", "Matrix shape for ldmatrix and stmatrix",
+  [LdStMatrixShapeM8N8, LdStMatrixShapeM8N16, LdStMatrixShapeM16N8, LdStMatrixShapeM16N16]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def LdStMatrixShapeAttr : EnumAttr<NVVM_Dialect, LdStMatrixShape, "ld_st_matrix_shape"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def LdStMatrixEltTypeB16 : I32EnumAttrCase<"B16", 0, "b16">;
+def LdStMatrixEltTypeB8 : I32EnumAttrCase<"B8", 1, "b8">;
+def LdStMatrixEltTypeB8X16_B6X16_P32 : I32EnumAttrCase<"B8X16_B6X16_P32", 2, "b8x16.b6x16_p32">;
+def LdStMatrixEltTypeB8X16_B4X16_P64 : I32EnumAttrCase<"B8X16_B4X16_P64", 3, "b8x16.b4x16_p64">;
+
+def LdStMatrixEltType : I32EnumAttr<"LdStMatrixEltType", "Element type for ldmatrix and stmatrix",
+  [LdStMatrixEltTypeB16, LdStMatrixEltTypeB8,
+   LdStMatrixEltTypeB8X16_B6X16_P32, LdStMatrixEltTypeB8X16_B4X16_P64]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def LdStMatrixEltTypeAttr : EnumAttr<NVVM_Dialect, LdStMatrixEltType, "ld_st_matrix_elttype"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, 
   Arguments<(ins LLVM_PointerShared:$ptr, 
                  Variadic<I32>:$sources, 
@@ -2021,13 +2050,16 @@ def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
 
 def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
   Results<(outs AnyType:$res)>,
-  Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {
+  Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr: $num,
+                 MMALayoutAttr: $layout,
+                 LdStMatrixShapeAttr: $shape,
+                 LdStMatrixEltTypeAttr: $elttype)> {
 
   let summary = "cooperative matrix load";
 
   string llvmBuilder = [{
       auto operands = moduleTranslation.lookupValues(opInst.getOperands());
-      auto intId = getLdMatrixIntrinsicId($layout, $num);
+      auto intId = getLdMatrixIntrinsicId($layout, $num, $shape, $elttype);
       $res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()});
   }];
 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 80b3d85488495..470dc2512a9ad 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -289,7 +289,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
         ldMatrixResultType, srcPtr,
         /*num=*/op.getNumTiles(),
         /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
-                                     : NVVM::MMALayout::row);
+                                     : NVVM::MMALayout::row,
+        NVVM::LdStMatrixShape::M8N8, NVVM::LdStMatrixEltType::B16);
 
     // The ldmatrix operation returns either a single i32 value or a struct of
     // i32 values. Here we unpack those values and cast them back to their
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 6e29b129e8835..93c155b67fb5c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -806,14 +806,18 @@ LogicalResult NVVM::LdMatrixOp::verify() {
     return emitOpError("expected num attribute to be 1, 2 or 4");
 
   Type i32 = IntegerType::get(getContext(), 32);
-  if (getNum() == 1 && getType() != i32)
+  uint32_t num = getNum();
+  if (getShape() == LdStMatrixShape::M16N16) {
+    num *= 2;
+  }
+  if (num == 1 && getType() != i32)
     return emitOpError("expected destination type is i32");
-  if (getNum() == 2 || getNum() == 4) {
+  if (num == 2 || num == 4) {
     Type dstType = LLVM::LLVMStructType::getLiteral(
-        getContext(), SmallVector<Type>(getNum(), i32));
+        getContext(), SmallVector<Type>(num, i32));
     if (getType() != dstType)
       return emitOpError("expected destination type is a structure of ")
-             << getNum() << " elements of type i32";
+             << num << " elements of type i32";
   }
   return success();
 }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index eecca64c4bf81..5d13933519c54 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -134,33 +134,90 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
   llvm_unreachable("unsupported vote kind");
 }
 
-/// Return the intrinsic ID associated with ldmatrix for the given paramters.
-static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
-                                                  int32_t num) {
+static llvm::Intrinsic::ID
+getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
+                       NVVM::LdStMatrixShape shape,
+                       NVVM::LdStMatrixEltType elttype) {
   if (layout == NVVM::MMALayout::row) {
-    switch (num) {
-    case 1:
-      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
-    case 2:
-      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
-    case 4:
-      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
-    default:
-      llvm_unreachable("unsupported number of matrix");
+    if (shape == NVVM::LdStMatrixShape::M8N8 &&
+        elttype == NVVM::LdStMatrixEltType::B16) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
+      case 2:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
+      case 4:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
+      }
+    } else if (shape == NVVM::LdStMatrixShape::M8N16 &&
+               elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32;
+      case 2:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32;
+      case 4:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
+      }
+    } else if (shape == NVVM::LdStMatrixShape::M8N16 &&
+               elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64;
+      case 2:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64;
+      case 4:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64;
+      }
     }
-
   } else {
-    switch (num) {
-    case 1:
-      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
-    case 2:
-      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
-    case 4:
-      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
-    default:
-      llvm_unreachable("unsupported number of matrix");
+    if (shape == NVVM::LdStMatrixShape::M8N8 &&
+        elttype == NVVM::LdStMatrixEltType::B16) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
+      case 2:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
+      case 4:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
+      }
+    } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+               elttype == NVVM::LdStMatrixEltType::B8) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
+      case 2:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
+      }
+    } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+               elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32;
+      case 2:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
+      }
+    } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+               elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64;
+      case 2:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64;
+      }
     }
   }
+  llvm_unreachable("unsupported matrix configuration");
 }
 
 /// Return the intrinsic ID associated with st.bulk for the given address type.
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index d0bc806e0aa8c..75a556f471373 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -159,7 +159,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec
 // CHECK-LABEL: @ldmatrix_x4
 func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) ->  vector<4x2xf16> {
   %c0  = arith.constant 0 : index
-  // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32)
+  // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape<m8n8>} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)>
   %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
   // CHECK: llvm.extractvalue
   // CHECK: llvm.bitcast
@@ -179,7 +179,7 @@ func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) ->  vector<4x2xf16> {
 // CHECK-LABEL: @ldmatrix_x1
 func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) ->  vector<1x2xf16> {
   %c0  = arith.constant 0 : index
-  // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} {{.*}} -> i32
+  // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape<m8n8>} : {{.*}} -> i32
   %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>
   // CHECK: llvm.bitcast
   // CHECK: llvm.insertvalue
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index bd1106e304c60..f9def0877d71a 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1116,7 +1116,7 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
   // expected-error at +1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr) -> i32
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr) -> i32
   llvm.return
 }
 
@@ -1124,7 +1124,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error at +1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}}
-  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
+  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
   llvm.return
 }
 
@@ -1132,7 +1132,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error at +1 {{'nvvm.ldmatrix' op expected destination type is i32}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
   llvm.return
 }
 
@@ -1140,10 +1140,19 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error at +1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
-  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   llvm.return
 }
 
+// -----
+
+llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error at +1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> i32
+  llvm.return
+}
+
+
 // -----
 
 llvm.func @caller() {
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index c7fa41c98ac92..6a4edd0d22a08 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -385,17 +385,6 @@ llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
   llvm.return
 }
 
-// CHECK-LABEL: llvm.func @ld_matrix
-llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
-  // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} : (!llvm.ptr<3>) -> i32
-  %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
-  // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 2 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
-  %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
-  // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
-  %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
-  llvm.return
-}
-
 // CHECK-LABEL: llvm.func @redux_sync
 llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 {
   // CHECK: nvvm.redux.sync  add %{{.*}}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index f86a04186f512..89429a762db92 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -559,17 +559,47 @@ llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
 // CHECK-LABEL: @ld_matrix
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})
-  %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
+  %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}})
-  %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}})
-  %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
-   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}})
-  %l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> i32
+  %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+  // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}})
+  %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}})
-  %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}})
-  %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+  // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+  %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> i32
+  // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+  %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+  %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>,elttype =#nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+  // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+  %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
+  // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+  %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+  %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+  // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8.p3(ptr addrspace(3) %{{.*}})
+  %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8.p3(ptr addrspace(3) %{{.*}})
+  %m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m16n16>,elttype =#nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+  // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+  %m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+  %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+  // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+  %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+  %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m16n16>,elttype =#nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
   llvm.return
 }
 

>From 784b7494581a1df03e16ec0faacc85d35c23960d Mon Sep 17 00:00:00 2001
From: Gao Yanfeng <gaoyanfeng at linux.alibaba.com>
Date: Wed, 16 Jul 2025 15:05:37 +0800
Subject: [PATCH 2/8] Modify the arguments of the ldmatrix op

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 16 +++------
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    |  3 +-
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    |  2 +-
 .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp  | 18 +++++-----
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir |  4 +--
 mlir/test/Dialect/LLVMIR/invalid.mlir         | 10 +++---
 mlir/test/Target/LLVMIR/nvvmir.mlir           | 36 +++++++++----------
 7 files changed, 41 insertions(+), 48 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index cfb21e8331d05..6af9f4e36be3d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1990,18 +1990,10 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
   let hasVerifier = 1;
 }
 
-def LdStMatrixShapeM8N8 : I32EnumAttrCase<"M8N8", 0, "m8n8">;
-def LdStMatrixShapeM8N16 : I32EnumAttrCase<"M8N16", 1, "m8n16">;
-def LdStMatrixShapeM16N8 : I32EnumAttrCase<"M16N8", 2, "m16n8">;
-def LdStMatrixShapeM16N16 : I32EnumAttrCase<"M16N16", 3, "m16n16">;
-
-def LdStMatrixShape : I32EnumAttr<"LdStMatrixShape", "Matrix shape for ldmatrix and stmatrix",
-  [LdStMatrixShapeM8N8, LdStMatrixShapeM8N16, LdStMatrixShapeM16N8, LdStMatrixShapeM16N16]> {
-  let genSpecializedAttr = 0;
-  let cppNamespace = "::mlir::NVVM";
-}
-def LdStMatrixShapeAttr : EnumAttr<NVVM_Dialect, LdStMatrixShape, "ld_st_matrix_shape"> {
-  let assemblyFormat = "`<` $value `>`";
+def LdStMatrixShapeAttr : NVVM_Attr<"LdStMatrixShape", "ld_st_matrix_shape"> {
+  let summary = "Matrix shape for ldmatrix and stmatrix";
+  let parameters = (ins "int":$m, "int":$n);
+  let assemblyFormat = "`<` struct(params) `>`";
 }
 
 def LdStMatrixEltTypeB16 : I32EnumAttrCase<"B16", 0, "b16">;
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 470dc2512a9ad..53eeabb16c984 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -285,12 +285,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
     Value srcPtr =
         getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
                              adaptor.getSrcMemref(), adaptor.getIndices());
+    auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8);
     Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
         ldMatrixResultType, srcPtr,
         /*num=*/op.getNumTiles(),
         /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
                                      : NVVM::MMALayout::row,
-        NVVM::LdStMatrixShape::M8N8, NVVM::LdStMatrixEltType::B16);
+        /*shape=*/shape, /*elttype=*/NVVM::LdStMatrixEltType::B16);
 
     // The ldmatrix operation returns either a single i32 value or a struct of
     // i32 values. Here we unpack those values and cast them back to their
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 93c155b67fb5c..fbb78ed487448 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -807,7 +807,7 @@ LogicalResult NVVM::LdMatrixOp::verify() {
 
   Type i32 = IntegerType::get(getContext(), 32);
   uint32_t num = getNum();
-  if (getShape() == LdStMatrixShape::M16N16) {
+  if (getShape().getM() == 16 && getShape().getN() == 16) {
     num *= 2;
   }
   if (num == 1 && getType() != i32)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 5d13933519c54..098336cc035a4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -136,10 +136,10 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
 
 static llvm::Intrinsic::ID
 getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
-                       NVVM::LdStMatrixShape shape,
+                       NVVM::LdStMatrixShapeAttr shape,
                        NVVM::LdStMatrixEltType elttype) {
   if (layout == NVVM::MMALayout::row) {
-    if (shape == NVVM::LdStMatrixShape::M8N8 &&
+    if (shape.getM() == 8 && shape.getN() == 8 &&
         elttype == NVVM::LdStMatrixEltType::B16) {
       switch (num) {
       case 1:
@@ -149,7 +149,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
       case 4:
         return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
       }
-    } else if (shape == NVVM::LdStMatrixShape::M8N16 &&
+    } else if (shape.getM() == 8 && shape.getN() == 16 &&
                elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
       switch (num) {
       case 1:
@@ -162,7 +162,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
         return llvm::Intrinsic::
             nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
       }
-    } else if (shape == NVVM::LdStMatrixShape::M8N16 &&
+    } else if (shape.getM() == 8 && shape.getN() == 16 &&
                elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
       switch (num) {
       case 1:
@@ -177,7 +177,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
       }
     }
   } else {
-    if (shape == NVVM::LdStMatrixShape::M8N8 &&
+    if (shape.getM() == 8 && shape.getN() == 8 &&
         elttype == NVVM::LdStMatrixEltType::B16) {
       switch (num) {
       case 1:
@@ -187,7 +187,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
       case 4:
         return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
       }
-    } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+    } else if (shape.getM() == 16 && shape.getN() == 16 &&
                elttype == NVVM::LdStMatrixEltType::B8) {
       switch (num) {
       case 1:
@@ -195,7 +195,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
       case 2:
         return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
       }
-    } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+    } else if (shape.getM() == 16 && shape.getN() == 16 &&
                elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
       switch (num) {
       case 1:
@@ -205,7 +205,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
         return llvm::Intrinsic::
             nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
       }
-    } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+    } else if (shape.getM() == 16 && shape.getN() == 16 &&
                elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
       switch (num) {
       case 1:
@@ -217,7 +217,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
       }
     }
   }
-  llvm_unreachable("unsupported matrix configuration");
+  llvm_unreachable("unknown ldmatrix kind");
 }
 
 /// Return the intrinsic ID associated with st.bulk for the given address type.
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 75a556f471373..2c0ed9b68a3c8 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -159,7 +159,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec
 // CHECK-LABEL: @ldmatrix_x4
 func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) ->  vector<4x2xf16> {
   %c0  = arith.constant 0 : index
-  // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape<m8n8>} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)>
+  // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)>
   %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
   // CHECK: llvm.extractvalue
   // CHECK: llvm.bitcast
@@ -179,7 +179,7 @@ func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) ->  vector<4x2xf16> {
 // CHECK-LABEL: @ldmatrix_x1
 func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) ->  vector<1x2xf16> {
   %c0  = arith.constant 0 : index
-  // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape<m8n8>} : {{.*}} -> i32
+  // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : {{.*}} -> i32
   %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>
   // CHECK: llvm.bitcast
   // CHECK: llvm.insertvalue
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index f9def0877d71a..6c0c942a041c6 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1116,7 +1116,7 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr) -> i32
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr) -> i32
   llvm.return
 }
 
@@ -1124,7 +1124,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}}
-  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
+  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
   llvm.return
 }
 
@@ -1132,7 +1132,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
   llvm.return
 }
 
@@ -1140,7 +1140,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
-  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   llvm.return
 }
 
@@ -1148,7 +1148,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> i32
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> i32
   llvm.return
 }
 
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 89429a762db92..69d791138ec71 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -559,47 +559,47 @@ llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
 // CHECK-LABEL: @ld_matrix
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})
-  %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
+  %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}})
-  %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}})
-  %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
 
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}})
-  %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
+  %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}})
-  %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}})
-  %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
 
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
-  %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> i32
+  %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> i32
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
-  %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
-  %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>,elttype =#nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>,elttype =#nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
 
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
-  %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
+  %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
-  %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
-  %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
 
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8.p3(ptr addrspace(3) %{{.*}})
-  %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8.p3(ptr addrspace(3) %{{.*}})
-  %m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m16n16>,elttype =#nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>,elttype =#nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
 
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
-  %m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
-  %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
 
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
-  %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
-  %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m16n16>,elttype =#nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>,elttype =#nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
   llvm.return
 }
 

>From 3b7c285b7e8370bd4962a39da7ff0b0f0bacdf26 Mon Sep 17 00:00:00 2001
From: Gao Yanfeng <gaoyanfeng at linux.alibaba.com>
Date: Mon, 21 Jul 2025 16:24:21 +0800
Subject: [PATCH 3/8] Follow the convention of eltType in the naming

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   |  4 +--
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    |  2 +-
 .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp  | 16 ++++-----
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir |  4 +--
 mlir/test/Dialect/LLVMIR/invalid.mlir         | 10 +++---
 mlir/test/Target/LLVMIR/nvvmir.mlir           | 36 +++++++++----------
 6 files changed, 36 insertions(+), 36 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 6af9f4e36be3d..0f12829c2b9ad 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2045,13 +2045,13 @@ def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
   Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr: $num,
                  MMALayoutAttr: $layout,
                  LdStMatrixShapeAttr: $shape,
-                 LdStMatrixEltTypeAttr: $elttype)> {
+                 LdStMatrixEltTypeAttr: $eltType)> {
 
   let summary = "cooperative matrix load";
 
   string llvmBuilder = [{
       auto operands = moduleTranslation.lookupValues(opInst.getOperands());
-      auto intId = getLdMatrixIntrinsicId($layout, $num, $shape, $elttype);
+      auto intId = getLdMatrixIntrinsicId($layout, $num, $shape, $eltType);
       $res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()});
   }];
 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 53eeabb16c984..b1ed889194ab0 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -291,7 +291,7 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
         /*num=*/op.getNumTiles(),
         /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
                                      : NVVM::MMALayout::row,
-        /*shape=*/shape, /*elttype=*/NVVM::LdStMatrixEltType::B16);
+        /*shape=*/shape, /*eltType=*/NVVM::LdStMatrixEltType::B16);
 
     // The ldmatrix operation returns either a single i32 value or a struct of
     // i32 values. Here we unpack those values and cast them back to their
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 098336cc035a4..0904d2b49f184 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -137,10 +137,10 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
 static llvm::Intrinsic::ID
 getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
                        NVVM::LdStMatrixShapeAttr shape,
-                       NVVM::LdStMatrixEltType elttype) {
+                       NVVM::LdStMatrixEltType eltType) {
   if (layout == NVVM::MMALayout::row) {
     if (shape.getM() == 8 && shape.getN() == 8 &&
-        elttype == NVVM::LdStMatrixEltType::B16) {
+        eltType == NVVM::LdStMatrixEltType::B16) {
       switch (num) {
       case 1:
         return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
@@ -150,7 +150,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
         return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
       }
     } else if (shape.getM() == 8 && shape.getN() == 16 &&
-               elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
+               eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
       switch (num) {
       case 1:
         return llvm::Intrinsic::
@@ -163,7 +163,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
             nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
       }
     } else if (shape.getM() == 8 && shape.getN() == 16 &&
-               elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
+               eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
       switch (num) {
       case 1:
         return llvm::Intrinsic::
@@ -178,7 +178,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
     }
   } else {
     if (shape.getM() == 8 && shape.getN() == 8 &&
-        elttype == NVVM::LdStMatrixEltType::B16) {
+        eltType == NVVM::LdStMatrixEltType::B16) {
       switch (num) {
       case 1:
         return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
@@ -188,7 +188,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
         return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
       }
     } else if (shape.getM() == 16 && shape.getN() == 16 &&
-               elttype == NVVM::LdStMatrixEltType::B8) {
+               eltType == NVVM::LdStMatrixEltType::B8) {
       switch (num) {
       case 1:
         return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
@@ -196,7 +196,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
         return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
       }
     } else if (shape.getM() == 16 && shape.getN() == 16 &&
-               elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
+               eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
       switch (num) {
       case 1:
         return llvm::Intrinsic::
@@ -206,7 +206,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
             nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
       }
     } else if (shape.getM() == 16 && shape.getN() == 16 &&
-               elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
+               eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
       switch (num) {
       case 1:
         return llvm::Intrinsic::
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 2c0ed9b68a3c8..bc2204ee63ede 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -159,7 +159,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec
 // CHECK-LABEL: @ldmatrix_x4
 func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) ->  vector<4x2xf16> {
   %c0  = arith.constant 0 : index
-  // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)>
+  // CHECK: nvvm.ldmatrix {{%.+}} {eltType = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)>
   %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
   // CHECK: llvm.extractvalue
   // CHECK: llvm.bitcast
@@ -179,7 +179,7 @@ func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) ->  vector<4x2xf16> {
 // CHECK-LABEL: @ldmatrix_x1
 func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) ->  vector<1x2xf16> {
   %c0  = arith.constant 0 : index
-  // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : {{.*}} -> i32
+  // CHECK: nvvm.ldmatrix {{%.+}} {eltType = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : {{.*}} -> i32
   %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>
   // CHECK: llvm.bitcast
   // CHECK: llvm.insertvalue
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 6c0c942a041c6..f13960708f800 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1116,7 +1116,7 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr) -> i32
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr) -> i32
   llvm.return
 }
 
@@ -1124,7 +1124,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}}
-  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
+  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
   llvm.return
 }
 
@@ -1132,7 +1132,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
   llvm.return
 }
 
@@ -1140,7 +1140,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
-  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   llvm.return
 }
 
@@ -1148,7 +1148,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> i32
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> i32
   llvm.return
 }
 
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 69d791138ec71..c4a15b1d62b4a 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -559,47 +559,47 @@ llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
 // CHECK-LABEL: @ld_matrix
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})
-  %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
+  %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}})
-  %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}})
-  %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
 
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}})
-  %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
+  %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}})
-  %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}})
-  %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
 
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
-  %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> i32
+  %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> i32
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
-  %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
-  %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>,elttype =#nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>,eltType =#nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
 
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
-  %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
+  %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
-  %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
-  %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
 
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8.p3(ptr addrspace(3) %{{.*}})
-  %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8.p3(ptr addrspace(3) %{{.*}})
-  %m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>,elttype =#nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>,eltType =#nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
 
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
-  %m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
-  %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
 
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
-  %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
-  %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>,elttype =#nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>,eltType =#nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
   llvm.return
 }
 

>From 7af2a81556cd3f4e196c1a41a805681929346ebc Mon Sep 17 00:00:00 2001
From: Gao Yanfeng <gaoyanfeng at linux.alibaba.com>
Date: Mon, 21 Jul 2025 17:08:32 +0800
Subject: [PATCH 4/8] Add verifier for ldmatrix

---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 55 +++++++++++++----
 mlir/test/Dialect/LLVMIR/invalid.mlir      | 69 +++++++++++++++++++---
 2 files changed, 106 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index fbb78ed487448..a4912fc6990b6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -802,23 +802,58 @@ LogicalResult NVVM::LdMatrixOp::verify() {
   if (addressSpace != NVVM::kSharedMemorySpace)
     return emitOpError("expected source pointer in memory space 3");
 
-  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
-    return emitOpError("expected num attribute to be 1, 2 or 4");
+  uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN();
+  if (m == 8 && n == 8) {
+    if (num != 1 && num != 2 && num != 4) {
+      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
+                         "matrix");
+    }
+    if (getEltType() != LdStMatrixEltType::B16) {
+      return emitOpError("expected element type to be b16 for 8x8 matrix");
+    }
+  } else if (m == 8 && n == 16) {
+    if (num != 1 && num != 2 && num != 4) {
+      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
+                         "matrix");
+    }
+    if (getLayout() != MMALayout::row) {
+      return emitOpError("expected layout to be row for 8x16 matrix");
+    }
+    if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
+        getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
+      return emitOpError("expected element type to be b8x16.b4x16_p64 or "
+                         "b8x16.b6x16_p32 for 8x16 matrix");
+    }
+  } else if (m == 16 && n == 16) {
+    if (num != 1 && num != 2) {
+      return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
+                         "matrix");
+    }
+    if (getLayout() != MMALayout::col) {
+      return emitOpError("expected layout to be col for 16x16 matrix");
+    }
+    if (getEltType() != LdStMatrixEltType::B8 &&
+        getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
+        getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
+      return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
+                         "b8x16.b6x16_p32 for 16x16 matrix");
+    }
+  } else {
+    return emitOpError("expected shape to be 8x8, 8x16 or 16x16");
+  }
 
   Type i32 = IntegerType::get(getContext(), 32);
-  uint32_t num = getNum();
-  if (getShape().getM() == 16 && getShape().getN() == 16) {
-    num *= 2;
-  }
-  if (num == 1 && getType() != i32)
+  uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
+  if (numElements == 1 && getType() != i32)
     return emitOpError("expected destination type is i32");
-  if (num == 2 || num == 4) {
+  if (numElements == 2 || numElements == 4) {
     Type dstType = LLVM::LLVMStructType::getLiteral(
-        getContext(), SmallVector<Type>(num, i32));
+        getContext(), SmallVector<Type>(numElements, i32));
     if (getType() != dstType)
       return emitOpError("expected destination type is a structure of ")
-             << num << " elements of type i32";
+             << numElements << " elements of type i32";
   }
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index f13960708f800..fc38ed4c6ecfe 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1114,7 +1114,7 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<
 
 // -----
 
-llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
+llvm.func @ld_matrix(%arg0: !llvm.ptr) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
   %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr) -> i32
   llvm.return
@@ -1122,15 +1122,15 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
 
 // -----
 
-llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
-  // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}}
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x8 matrix}}
   %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
   llvm.return
 }
 
 // -----
 
-llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}}
   %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
   llvm.return
@@ -1138,7 +1138,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
 
 // -----
 
-llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
   %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   llvm.return
@@ -1146,12 +1146,65 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
 
 // -----
 
-llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
-  // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> i32
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b16 for 8x8 matrix}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> i32
+  llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x16 matrix}}
+  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType  = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32)>
+  llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error@+1 {{'nvvm.ldmatrix' op expected layout to be row for 8x16 matrix}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType  = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
+  llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b8x16.b4x16_p64 or b8x16.b6x16_p32 for 8x16 matrix}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType  = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> i32
+  llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1 or 2 for 16x16 matrix}}
+  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error@+1 {{'nvvm.ldmatrix' op expected layout to be col for 16x16 matrix}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
   llvm.return
 }
 
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b8, b8x16.b4x16_p64 or b8x16.b6x16_p32 for 16x16 matrix}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
+  llvm.return
+}
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> i32
+  llvm.return
+}
 
 // -----
 

>From c0f544e727955f4803cbe04fc3a5eb19f422b68e Mon Sep 17 00:00:00 2001
From: Gao Yanfeng <gaoyanfeng at linux.alibaba.com>
Date: Wed, 23 Jul 2025 11:18:30 +0800
Subject: [PATCH 5/8] Change "ld_st_matrix_elttype" to "ld_st_matrix_elt_type"

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   |  2 +-
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir |  4 +--
 mlir/test/Dialect/LLVMIR/invalid.mlir         | 24 ++++++-------
 mlir/test/Target/LLVMIR/nvvmir.mlir           | 36 +++++++++----------
 4 files changed, 33 insertions(+), 33 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0f12829c2b9ad..e7f7d37d6b17d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2007,7 +2007,7 @@ def LdStMatrixEltType : I32EnumAttr<"LdStMatrixEltType", "Element type for ldmat
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::NVVM";
 }
-def LdStMatrixEltTypeAttr : EnumAttr<NVVM_Dialect, LdStMatrixEltType, "ld_st_matrix_elttype"> {
+def LdStMatrixEltTypeAttr : EnumAttr<NVVM_Dialect, LdStMatrixEltType, "ld_st_matrix_elt_type"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index bc2204ee63ede..090c8bb42ac17 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -159,7 +159,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec
 // CHECK-LABEL: @ldmatrix_x4
 func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) ->  vector<4x2xf16> {
   %c0  = arith.constant 0 : index
-  // CHECK: nvvm.ldmatrix {{%.+}} {eltType = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)>
+  // CHECK: nvvm.ldmatrix {{%.+}} {eltType = #nvvm.ld_st_matrix_elt_type<b16>, layout = #nvvm.mma_layout<row>, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)>
   %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
   // CHECK: llvm.extractvalue
   // CHECK: llvm.bitcast
@@ -179,7 +179,7 @@ func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) ->  vector<4x2xf16> {
 // CHECK-LABEL: @ldmatrix_x1
 func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) ->  vector<1x2xf16> {
   %c0  = arith.constant 0 : index
-  // CHECK: nvvm.ldmatrix {{%.+}} {eltType = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : {{.*}} -> i32
+  // CHECK: nvvm.ldmatrix {{%.+}} {eltType = #nvvm.ld_st_matrix_elt_type<b16>, layout = #nvvm.mma_layout<row>, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : {{.*}} -> i32
   %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>
   // CHECK: llvm.bitcast
   // CHECK: llvm.insertvalue
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index fc38ed4c6ecfe..e0e42e46830bc 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1116,7 +1116,7 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<
 
 llvm.func @ld_matrix(%arg0: !llvm.ptr) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr) -> i32
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr) -> i32
   llvm.return
 }
 
@@ -1124,7 +1124,7 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr) {
 
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x8 matrix}}
-  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
+  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> i32
   llvm.return
 }
 
@@ -1132,7 +1132,7 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
   llvm.return
 }
 
@@ -1140,7 +1140,7 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
-  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   llvm.return
 }
 
@@ -1148,7 +1148,7 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b16 for 8x8 matrix}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> i32
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> i32
   llvm.return
 }
 
@@ -1156,7 +1156,7 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x16 matrix}}
-  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType  = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32)>
+  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32)>
   llvm.return
 }
 
@@ -1164,7 +1164,7 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected layout to be row for 8x16 matrix}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType  = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
   llvm.return
 }
 
@@ -1172,7 +1172,7 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b8x16.b4x16_p64 or b8x16.b6x16_p32 for 8x16 matrix}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType  = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> i32
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> i32
   llvm.return
 }
 
@@ -1180,7 +1180,7 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1 or 2 for 16x16 matrix}}
-  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
   llvm.return
 }
 
@@ -1188,7 +1188,7 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected layout to be col for 16x16 matrix}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
   llvm.return
 }
 
@@ -1196,13 +1196,13 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b8, b8x16.b4x16_p64 or b8x16.b6x16_p32 for 16x16 matrix}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> i32
   llvm.return
 }
 
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> i32
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> i32
   llvm.return
 }
 
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index c4a15b1d62b4a..f502a2f6b5769 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -559,47 +559,47 @@ llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
 // CHECK-LABEL: @ld_matrix
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})
-  %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
+  %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> i32
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}})
-  %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}})
-  %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
 
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}})
-  %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
+  %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> i32
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}})
-  %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}})
-  %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
 
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
-  %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> i32
+  %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> i32
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
-  %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
-  %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>,eltType =#nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>,eltType =#nvvm.ld_st_matrix_elt_type<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
 
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
-  %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
+  %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
-  %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
-  %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
 
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8.p3(ptr addrspace(3) %{{.*}})
-  %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8.p3(ptr addrspace(3) %{{.*}})
-  %m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>,eltType =#nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>,eltType =#nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
 
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
-  %m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
-  %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
 
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
-  %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
-  %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>,eltType =#nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>,eltType =#nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
   llvm.return
 }
 

>From dfd19cbc592d70ccdc5f4f9462c84de33f549598 Mon Sep 17 00:00:00 2001
From: Gao Yanfeng <gaoyanfeng at linux.alibaba.com>
Date: Wed, 23 Jul 2025 11:27:21 +0800
Subject: [PATCH 6/8] Simplify the structure of `getLdMatrixIntrinsicId`

---
 .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp  | 57 ++++++++-----------
 1 file changed, 25 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 0904d2b49f184..33b784eb56504 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -138,19 +138,26 @@ static llvm::Intrinsic::ID
 getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
                        NVVM::LdStMatrixShapeAttr shape,
                        NVVM::LdStMatrixEltType eltType) {
-  if (layout == NVVM::MMALayout::row) {
-    if (shape.getM() == 8 && shape.getN() == 8 &&
-        eltType == NVVM::LdStMatrixEltType::B16) {
-      switch (num) {
-      case 1:
-        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
-      case 2:
-        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
-      case 4:
-        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
-      }
-    } else if (shape.getM() == 8 && shape.getN() == 16 &&
-               eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
+  if (shape.getM() == 8 && shape.getN() == 8) {
+    switch (num) {
+    case 1:
+      return (layout == NVVM::MMALayout::row)
+                 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16
+                 : llvm::Intrinsic::
+                       nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
+    case 2:
+      return (layout == NVVM::MMALayout::row)
+                 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16
+                 : llvm::Intrinsic::
+                       nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
+    case 4:
+      return (layout == NVVM::MMALayout::row)
+                 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16
+                 : llvm::Intrinsic::
+                       nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
+    }
+  } else if (shape.getM() == 8 && shape.getN() == 16) {
+    if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
       switch (num) {
       case 1:
         return llvm::Intrinsic::
@@ -162,8 +169,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
         return llvm::Intrinsic::
             nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
       }
-    } else if (shape.getM() == 8 && shape.getN() == 16 &&
-               eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
+    } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
       switch (num) {
       case 1:
         return llvm::Intrinsic::
@@ -176,27 +182,15 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
             nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64;
       }
     }
-  } else {
-    if (shape.getM() == 8 && shape.getN() == 8 &&
-        eltType == NVVM::LdStMatrixEltType::B16) {
-      switch (num) {
-      case 1:
-        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
-      case 2:
-        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
-      case 4:
-        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
-      }
-    } else if (shape.getM() == 16 && shape.getN() == 16 &&
-               eltType == NVVM::LdStMatrixEltType::B8) {
+  } else if (shape.getM() == 16 && shape.getN() == 16) {
+    if (eltType == NVVM::LdStMatrixEltType::B8) {
       switch (num) {
       case 1:
         return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
       case 2:
         return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
       }
-    } else if (shape.getM() == 16 && shape.getN() == 16 &&
-               eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
+    } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
       switch (num) {
       case 1:
         return llvm::Intrinsic::
@@ -205,8 +199,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
         return llvm::Intrinsic::
             nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
       }
-    } else if (shape.getM() == 16 && shape.getN() == 16 &&
-               eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
+    } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
       switch (num) {
       case 1:
         return llvm::Intrinsic::

>From 40b098b53dd7dc2c6aec1eb39a8eeb0d84b63e45 Mon Sep 17 00:00:00 2001
From: Gao Yanfeng <gaoyanfeng at linux.alibaba.com>
Date: Wed, 23 Jul 2025 11:32:47 +0800
Subject: [PATCH 7/8] Move the negative test to nvvmir-invalid.mlir

---
 mlir/test/Dialect/LLVMIR/invalid.mlir       | 94 ---------------------
 mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 94 +++++++++++++++++++++
 2 files changed, 94 insertions(+), 94 deletions(-)

diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index e0e42e46830bc..c094f37a15e66 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1114,100 +1114,6 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<
 
 // -----
 
-llvm.func @ld_matrix(%arg0: !llvm.ptr) {
-  // expected-error at +1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr) -> i32
-  llvm.return
-}
-
-// -----
-
-llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
-  // expected-error at +1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x8 matrix}}
-  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> i32
-  llvm.return
-}
-
-// -----
-
-llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
-  // expected-error at +1 {{'nvvm.ldmatrix' op expected destination type is i32}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
-  llvm.return
-}
-
-// -----
-
-llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
-  // expected-error at +1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
-  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
-  llvm.return
-}
-
-// -----
-
-llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
-  // expected-error at +1 {{'nvvm.ldmatrix' op expected element type to be b16 for 8x8 matrix}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> i32
-  llvm.return
-}
-
-// -----
-
-llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
-  // expected-error at +1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x16 matrix}}
-  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32)>
-  llvm.return
-}
-
-// -----
-
-llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
-  // expected-error at +1 {{'nvvm.ldmatrix' op expected layout to be row for 8x16 matrix}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
-  llvm.return
-}
-
-// -----
-
-llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
-  // expected-error at +1 {{'nvvm.ldmatrix' op expected element type to be b8x16.b4x16_p64 or b8x16.b6x16_p32 for 8x16 matrix}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> i32
-  llvm.return
-}
-
-// -----
-
-llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
-  // expected-error at +1 {{'nvvm.ldmatrix' op expected num attribute to be 1 or 2 for 16x16 matrix}}
-  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
-  llvm.return
-}
-
-// -----
-
-llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
-  // expected-error at +1 {{'nvvm.ldmatrix' op expected layout to be col for 16x16 matrix}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
-  llvm.return
-}
-
-// -----
-
-llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
-  // expected-error at +1 {{'nvvm.ldmatrix' op expected element type to be b8, b8x16.b4x16_p64 or b8x16.b6x16_p32 for 16x16 matrix}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> i32
-  llvm.return
-}
-
-llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
-  // expected-error at +1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> i32
-  llvm.return
-}
-
-// -----
-
 llvm.func @caller() {
   // expected-error @below {{expected function call to produce a value}}
   llvm.call @callee() : () -> ()
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 8c4f0aafd36a7..0f713b9c918ee 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -312,3 +312,97 @@ llvm.func @nvvm_prefetch_uniform_with_invalid_addr_space(%global_ptr: !llvm.ptr<
   nvvm.prefetch level = L1 uniform, %global_ptr : !llvm.ptr<1>
   llvm.return
 }
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr) {
+  // expected-error at +1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr) -> i32
+  llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error at +1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x8 matrix}}
+  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> i32
+  llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error at +1 {{'nvvm.ldmatrix' op expected destination type is i32}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
+  llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error at +1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
+  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error at +1 {{'nvvm.ldmatrix' op expected element type to be b16 for 8x8 matrix}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> i32
+  llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error at +1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x16 matrix}}
+  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32)>
+  llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error at +1 {{'nvvm.ldmatrix' op expected layout to be row for 8x16 matrix}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
+  llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error at +1 {{'nvvm.ldmatrix' op expected element type to be b8x16.b4x16_p64 or b8x16.b6x16_p32 for 8x16 matrix}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> i32
+  llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error at +1 {{'nvvm.ldmatrix' op expected num attribute to be 1 or 2 for 16x16 matrix}}
+  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error at +1 {{'nvvm.ldmatrix' op expected layout to be col for 16x16 matrix}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  llvm.return
+}
+
+// -----
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error at +1 {{'nvvm.ldmatrix' op expected element type to be b8, b8x16.b4x16_p64 or b8x16.b6x16_p32 for 16x16 matrix}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> i32
+  llvm.return
+}
+
+llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error at +1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType  = #nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> i32
+  llvm.return
+}
\ No newline at end of file

>From a7d309966ebf0dfa27d1e88849c894365c931561 Mon Sep 17 00:00:00 2001
From: Gao Yanfeng <gaoyanfeng at linux.alibaba.com>
Date: Wed, 23 Jul 2025 11:38:42 +0800
Subject: [PATCH 8/8] Keep the pointer as shared instead of any

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 2 +-
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp  | 5 -----
 mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 8 --------
 3 files changed, 1 insertion(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index e7f7d37d6b17d..1bc7bd0d744ef 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2042,7 +2042,7 @@ def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
 
 def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
   Results<(outs AnyType:$res)>,
-  Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr: $num,
+  Arguments<(ins LLVM_PointerShared: $ptr, I32Attr: $num,
                  MMALayoutAttr: $layout,
                  LdStMatrixShapeAttr: $shape,
                  LdStMatrixEltTypeAttr: $eltType)> {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index a4912fc6990b6..42389c0008a80 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -797,11 +797,6 @@ LogicalResult NVVM::WMMAMmaOp::verify() {
 }
 
 LogicalResult NVVM::LdMatrixOp::verify() {
-  unsigned addressSpace =
-      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
-  if (addressSpace != NVVM::kSharedMemorySpace)
-    return emitOpError("expected source pointer in memory space 3");
-
   uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN();
   if (m == 8 && n == 8) {
     if (num != 1 && num != 2 && num != 4) {
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 0f713b9c918ee..254373ec3f47c 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -315,14 +315,6 @@ llvm.func @nvvm_prefetch_uniform_with_invalid_addr_space(%global_ptr: !llvm.ptr<
 
 // -----
 
-llvm.func @ld_matrix(%arg0: !llvm.ptr) {
-  // expected-error at +1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr) -> i32
-  llvm.return
-}
-
-// -----
-
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error at +1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x8 matrix}}
   %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType  = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> i32



More information about the Mlir-commits mailing list