[Mlir-commits] [mlir] [mlir][nvgpu] Fix crash when mmaShape size is not three (PR #173490)

Longsheng Mou llvmlistbot at llvm.org
Wed Dec 24 04:56:35 PST 2025


https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/173490

>From 8f76afbf35a91f122dc6bb3f050e38a86675d814 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Wed, 24 Dec 2025 20:40:51 +0800
Subject: [PATCH] [mlir][nvgpu] Fix crash when mmaShape size is not three

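This PR fixes a crash when the mmaShape attribute of `nvgpu.mma.sync`
or `nvgpu.mma.sp.sync` does not have exactly three elements. The change
replaces the ArrayAttr-based mmaShape with DenseI64ArrayAttr and adds a
verifier check to ensure the attribute has exactly three elements. As a
result, the textual form of the attribute changes from
`mmaShape = [16, 8, 16]` to `mmaShape = array<i64: 16, 8, 16>`.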
---
 .../include/mlir/Dialect/NVGPU/IR/NVGPUOps.td | 24 ++-------
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    |  9 ++--
 .../Conversion/VectorToGPU/VectorToGPU.cpp    |  4 +-
 mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp    | 26 +++++-----
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 24 ++++-----
 ...fold-arith-vector-to-mma-ops-mma-sync.mlir |  4 +-
 .../vector-to-mma-ops-mma-sync.mlir           | 20 ++++----
 mlir/test/Dialect/NVGPU/invalid.mlir          | 49 +++++++++++++------
 .../Dialect/NVGPU/mma-sync-f32-to-tf32.mlir   |  6 +--
 .../Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir |  4 +-
 mlir/test/Dialect/NVGPU/roundtrip.mlir        | 16 +++---
 .../NVGPU/transform-matmul-to-nvvm.mlir       |  4 +-
 .../GPU/CUDA/sparse-mma-2-4-f16.mlir          |  2 +-
 .../sm80/transform-mma-sync-matmul-f32.mlir   |  2 +-
 14 files changed, 102 insertions(+), 92 deletions(-)

diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td
index 73d86283a5940..5b9ae8bb7a518 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td
@@ -66,16 +66,6 @@ class NVGPU_MmaSyncOp<string mnemonic> :
         NVGPU_Op<mnemonic,  [Pure,
                              PredOpTrait<"matrixA and matrixB have same element type",
                                          TCopVTEtIsSameAs<0, 1>>]> {
-  code extraBaseClassDeclaration = [{
-    std::array<int64_t, 3> getMmaShapeAsArray() {
-      ArrayAttr mmaShape = this->getMmaShape();
-      assert(mmaShape.size() == 3 && "mmaShape should be three integers");
-      return {::llvm::cast<IntegerAttr>(mmaShape[0]).getInt(),
-              ::llvm::cast<IntegerAttr>(mmaShape[1]).getInt(),
-              ::llvm::cast<IntegerAttr>(mmaShape[2]).getInt()};
-    }
-  }];
-
   let hasVerifier = 1;
 }
 
@@ -96,14 +86,14 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
     Example:
 
     ```mlir
-    %res = nvgpu.mma.sync (%matrixA, %matrixB, %matrixC) {mmaShape = [16, 8, 16]} :
+    %res = nvgpu.mma.sync (%matrixA, %matrixB, %matrixC) {mmaShape = array<i64: 16, 8, 16>} :
         (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
     ```
   }];
   let arguments = (ins AnyVectorOfNonZeroRank:$matrixA,
                        AnyVectorOfNonZeroRank:$matrixB,
                        AnyVectorOfNonZeroRank:$matrixC,
-                       I64ArrayAttr:$mmaShape,
+                       DenseI64ArrayAttr:$mmaShape,
                        OptionalAttr<UnitAttr>:$tf32Enabled);
 
   let results = (outs AnyVectorOfNonZeroRank:$res);
@@ -112,7 +102,7 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
     OpBuilder<(ins "Value":$matrixA,
                    "Value":$matrixB,
                    "Value":$matrixC,
-                   "ArrayAttr":$mmaShape)>,
+                   "DenseI64ArrayAttr":$mmaShape)>,
     OpBuilder<(ins "Value":$matrixA,
                    "Value":$matrixB,
                    "Value":$matrixC,
@@ -124,8 +114,6 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
     `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
     `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
   }];
-
-  let extraClassDeclaration = extraBaseClassDeclaration;
 }
 
 def NVGPU_MmaSparseSyncMetadataType : FixedVectorOfLengthAndType<[2], [I16]>,
@@ -151,7 +139,7 @@ def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
   Example (targetingthe f16 16x8x32 `mma.sp` PTX instruction):
 
   ```mlir
-  nvgpu.mma.sp.sync (%a, %b, %c) metadata (%meta) {mmaShape = [16, 8, 32]} :
+  nvgpu.mma.sp.sync (%a, %b, %c) metadata (%meta) {mmaShape = array<i64: 16, 8, 32>} :
     (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   ```
   }];
@@ -160,7 +148,7 @@ def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
                        AnyVectorOfNonZeroRank:$matrixB,
                        AnyVectorOfNonZeroRank:$matrixC,
                        NVGPU_MmaSparseSyncMetadataType:$sparseMetadata,
-                       I64ArrayAttr:$mmaShape,
+                       DenseI64ArrayAttr:$mmaShape,
                        DefaultValuedAttr<I32Attr, "0">:$sparsitySelector,
                        OptionalAttr<UnitAttr>:$tf32Enabled
                        );
@@ -179,8 +167,6 @@ def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
     `(` $matrixA`,` $matrixB`,` $matrixC `)` `metadata` `(` $sparseMetadata `)` attr-dict
     `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
   }];
-
-  let extraClassDeclaration = extraBaseClassDeclaration;
 }
 
 def NVGPU_DeviceAsyncCopyOp : NVGPU_Op<"device_async_copy", [
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 6edc8f5c86dd3..22b88b871efaa 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -340,7 +340,7 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
     VectorType bType = op.getMatrixA().getType();
     VectorType cType = op.getMatrixC().getType();
 
-    std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
+    ArrayRef<int64_t> gemmShape = op.getMmaShape();
 
     // Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
     bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
@@ -485,7 +485,7 @@ static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
 /// it's expected that the provided parameters correspond to a valid
 /// instruction.
 static std::string buildMmaSparseAsmString(
-    const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
+    ArrayRef<int64_t> shape, unsigned matASize, unsigned matBSize,
     unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
     NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
     std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
@@ -495,6 +495,7 @@ static std::string buildMmaSparseAsmString(
 
   std::string asmStr;
   llvm::raw_string_ostream ss(asmStr);
+  assert(shape.size() == 3 && "mmaShape should be three integers");
   ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
      << shape[2] << ".row.col.";
 
@@ -526,7 +527,7 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
     NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
     std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
     ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
-    int64_t metadataSelector, const std::array<int64_t, 3> &shape,
+    int64_t metadataSelector, ArrayRef<int64_t> shape,
     Type intrinsicResultType) {
   auto asmDialectAttr =
       LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
@@ -618,7 +619,7 @@ struct NVGPUMmaSparseSyncLowering
 
     FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
         b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
-        matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
+        matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShape(),
         intrinsicResTy);
     if (failed(intrinsicResult))
       return failure();
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 98434357f826f..d4aa893a5b420 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -1058,8 +1058,8 @@ convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
   int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
   int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
   int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
-  Value matmul = nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC,
-                                          rewriter.getI64ArrayAttr({m, n, k}));
+  Value matmul =
+      nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC, {m, n, k});
   valueMapping[op.getResult()] = matmul;
   return success();
 }
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 237aab4d7f309..eb2de91d30988 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -119,7 +119,8 @@ LogicalResult DeviceAsyncCopyOp::verify() {
 //===----------------------------------------------------------------------===//
 void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                       ::mlir::OperationState &odsState, Value matrixA,
-                      Value matrixB, Value matrixC, ArrayAttr mmaShape) {
+                      Value matrixB, Value matrixC,
+                      DenseI64ArrayAttr mmaShape) {
   build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
         mmaShape, UnitAttr());
 }
@@ -129,8 +130,7 @@ void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                       Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape,
                       bool tf32Enabled) {
   build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
-        odsBuilder.getI64ArrayAttr(mmaShape),
-        tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
+        mmaShape, tf32Enabled);
 }
 
 /// Performs verification for MmaSyncOp and MmaSparseSyncOp.
@@ -138,7 +138,7 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
                                      TypedValue<VectorType> matrixA,
                                      TypedValue<VectorType> matrixB,
                                      TypedValue<VectorType> matrixC,
-                                     const std::array<int64_t, 3> &mmaShape,
+                                     ArrayRef<int64_t> mmaShape,
                                      bool tf32Enabled, bool sparse = false) {
   // The verification for mma.sync covering various shapes and data types is
   // based on the fundamental tensor core shape.
@@ -209,7 +209,12 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
     return op->emitError() << "matrixC must be 2 dimensional vector";
   }
 
-  auto [m, n, k] = mmaShape;
+  if (mmaShape.size() != 3) {
+    return op->emitError() << "mmaShape should be three integers";
+  }
+  int64_t m = mmaShape[0];
+  int64_t n = mmaShape[1];
+  int64_t k = mmaShape[2];
 
   // verify warp-wide size for vector a
   int64_t sparseFactor = sparse ? 2 : 1;
@@ -262,7 +267,7 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
 
 LogicalResult MmaSyncOp::verify() {
   return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
-                         getMatrixC(), getMmaShapeAsArray(),
+                         getMatrixC(), getMmaShape(),
                          getOperation()->hasAttr(getTf32EnabledAttrName()));
 }
 
@@ -274,17 +279,16 @@ void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder,
                             Value matrixB, Value matrixC, Value sparseMetadata,
                             ArrayRef<int64_t> mmaShape) {
   build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
-        sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr());
+        sparseMetadata, mmaShape, 0, /*tf32Enabled=*/false);
 }
 
 LogicalResult MmaSparseSyncOp::verify() {
   unsigned sparsitySelector = getSparsitySelector();
   if (sparsitySelector > 1)
     return emitOpError() << "sparsity selector should be 0 or 1";
-  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
-                         getMatrixC(), getMmaShapeAsArray(),
-                         getOperation()->hasAttr(getTf32EnabledAttrName()),
-                         true);
+  return verifyMmaSyncOp(
+      this->getOperation(), getMatrixA(), getMatrixB(), getMatrixC(),
+      getMmaShape(), getOperation()->hasAttr(getTf32EnabledAttrName()), true);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 0eb44789fe31d..74bf571b4a64e 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -14,7 +14,7 @@ func.func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2:
   // CHECK-NOT: llvm.extractvalue
   // CHECK: [[d:%.+]] = nvvm.mma.sync
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 16>
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
   // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
   // CHECK: llvm.mlir.poison : !llvm.array<2 x vector<2xf16>>
@@ -31,7 +31,7 @@ func.func @m16n8k16_fp16_fp32(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %a
   // CHECK: [[d:%.+]] = nvvm.mma.sync
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 16>
   // CHECK-SAME: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
   // CHECK: [[undef:%.+]] = llvm.mlir.poison : vector<2xf32>
   // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
   // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>
@@ -59,7 +59,7 @@ func.func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: v
   // CHECK-NOT: llvm.extractvalue
   // CHECK: [[d:%.+]] = nvvm.mma.sync
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 8>
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 8>} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
   // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
   // CHECK: llvm.mlir.poison : !llvm.array<2 x vector<2xf16>>
@@ -90,7 +90,7 @@ func.func @m16n8k32_int8(%arg0: vector<4x4xi8>, %arg1: vector<2x4xi8>, %arg2: ve
   // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s8>
   // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s8>
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 32>
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 32>} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
   return %d : vector<2x2xi32>
 }
 
@@ -109,7 +109,7 @@ func.func @m16n8k32_i4(%arg0: vector<2x8xi4>, %arg1: vector<1x8xi4>, %arg2: vect
   // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s4>
   // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s4>
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 32>
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<2x8xi4>, vector<1x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 32>} : (vector<2x8xi4>, vector<1x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
   return %d : vector<2x2xi32>
 }
 
@@ -134,7 +134,7 @@ func.func @m16n8k64_i4(%arg0: vector<4x8xi4>, %arg1: vector<2x8xi4>, %arg2: vect
   // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s4>
   // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s4>
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 64>
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 64]} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 64>} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
   return %d : vector<2x2xi32>
 }
 
@@ -145,7 +145,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec
   // CHECK: llvm.extractvalue
   // CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}]
   // CHECK-SAME: shape = #nvvm.shape<m = 8, n = 8, k = 4>
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 8, 8, 4>} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
   // CHECK: llvm.mlir.poison : vector<2xf64>
   // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f64, f64)>
   // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f64, f64)>
@@ -201,7 +201,7 @@ func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: v
   // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<tf32>
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 4>
   // CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)>
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4], tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 4>, tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
   // CHECK: [[undef:%.+]] = llvm.mlir.poison : vector<2xf32>
   // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
   // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>
@@ -370,7 +370,7 @@ func.func @mma_sp_sync_f16_16832(%arg0: vector<4x2xf16>,
   // CHECK-SAME: %[[sparseMetadata]] :
   // CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 
-  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 32]} :
+  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 32>} :
     (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
 
   // CHECK-DAG: llvm.extractvalue %[[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
@@ -406,7 +406,7 @@ func.func @mma_sp_sync_f16_16816(%arg0: vector<2x2xf16>,
   // CHECK-SAME: %[[sparseMetadata]] :
   // CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 
-  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 16]} :
+  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 16>} :
     (vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   return %d : vector<2x2xf16>
 }
@@ -427,7 +427,7 @@ func.func @mma_sp_sync_f16_16816_01(%arg0: vector<2x2xf16>,
   // CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 
   %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3)
-       {mmaShape = [16, 8, 16], sparsitySelector = 1 : i32} :
+       {mmaShape = array<i64: 16, 8, 16>, sparsitySelector = 1 : i32} :
        (vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   return %d : vector<2x2xf16>
 }
@@ -465,7 +465,7 @@ func.func @mma_sp_sync_i8_16864(%arg0: vector<4x4xi8>,
   // CHECK-SAME: %[[sparseMetadata]] :
   // CHECK-SAME: -> !llvm.struct<(i32, i32, i32, i32)
 
-  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 64]} :
+  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 64>} :
     (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
   return %d : vector<2x2xi32>
 }
diff --git a/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
index 0afaa19d59d15..c42f5add697f0 100644
--- a/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
+++ b/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
@@ -29,7 +29,7 @@ func.func @m16n8k16_mmasync16816_f16_f16_f32_row_row_row(%arg0: memref<42x32xf16
   %B0_f32 = arith.extf %B0 : vector<8x16xf16> to vector<8x16xf32>
   %C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf32> to vector<16x8xf32>
   
-  // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+  // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
   %D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A_f32, %B0_f32, %C0 : vector<16x16xf32>, vector<8x16xf32> into vector<16x8xf32>
   vector.transfer_write %D0, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<42x64xf32, #gpu.address_space<workgroup>>
 
@@ -38,7 +38,7 @@ func.func @m16n8k16_mmasync16816_f16_f16_f32_row_row_row(%arg0: memref<42x32xf16
   %B1_f32 = arith.extf %B1 : vector<8x16xf16> to vector<8x16xf32>
   %C1 = vector.extract_strided_slice %C {offsets = [0, 8], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf32> to vector<16x8xf32>
 
-  // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+  // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
   %D1 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A_f32, %B1_f32, %C1 : vector<16x16xf32>, vector<8x16xf32> into vector<16x8xf32>
   vector.transfer_write %D1, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<42x64xf32, #gpu.address_space<workgroup>>
 
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
index 912f7fba59e60..b11d79a207fba 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
@@ -88,7 +88,7 @@ func.func @m16n8k32_int8_row_row_row(%arg0: memref<128x128xi8, #gpu.address_spac
   %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xi8, #gpu.address_space<workgroup>>, vector<16x32xi8>
   %B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xi8, #gpu.address_space<workgroup>>, vector<8x32xi8>
   %C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32>
-  // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = array<i64: 16, 8, 32>} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x32xi8>, vector<8x32xi8> into vector<16x8xi32>
 
   // CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[{{%.+}}]
@@ -153,7 +153,7 @@ func.func @m8n8k4_f64_row_row_row(%arg0: memref<128x128xf64>, %arg1: memref<128x
   %A = vector.transfer_read %arg0[%c1, %c1], %cst {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x4xf64>
   %B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xf64>, vector<8x4xf64>
   %C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x8xf64>
-  // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+  // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = array<i64: 8, 8, 4>} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<8x4xf64>, vector<8x4xf64> into vector<8x8xf64>
 
   // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowC0_map]]
@@ -234,7 +234,7 @@ func.func @m16n16k16_mmasync16816_fp16_f16_row_row_row(%arg0: memref<42x32xf16,
 
   // CHECK-DAG: [[fragmentB0:%.+]] = vector.extract_strided_slice [[fragmentB]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
   // CHECK-DAG: [[fragmentC0:%.+]] = vector.extract_strided_slice [[fragmentC]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
-  // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB0]], [[fragmentC0]]) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB0]], [[fragmentC0]]) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   %B0 = vector.extract_strided_slice %B {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
   %C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf16> to vector<16x8xf16>
   %D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B0, %C0 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
@@ -242,7 +242,7 @@ func.func @m16n16k16_mmasync16816_fp16_f16_row_row_row(%arg0: memref<42x32xf16,
 
   // CHECK-DAG: [[fragmentB1:%.+]] = vector.extract_strided_slice [[fragmentB]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
   // CHECK-DAG: [[fragmentC1:%.+]] = vector.extract_strided_slice [[fragmentC]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
-  // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB1]], [[fragmentC1]]) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB1]], [[fragmentC1]]) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   %B1 = vector.extract_strided_slice %B {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
   %C1 = vector.extract_strided_slice %C {offsets = [0, 8], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf16> to vector<16x8xf16>
   %D1 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B1, %C1 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
@@ -290,7 +290,7 @@ func.func @multi_dim_m16n8k16_fp16_row_row_row(%arg0: memref<4x32x1x32xf16, #gpu
 
   // CHECK-DAG: [[fragmentB0:%.+]] = vector.extract_strided_slice [[fragmentB]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
   // CHECK-DAG: [[fragmentC0:%.+]] = vector.extract_strided_slice [[fragmentC]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
-  // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB0]], [[fragmentC0]]) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB0]], [[fragmentC0]]) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   %B0 = vector.extract_strided_slice %B {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
   %C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf16> to vector<16x8xf16>
   %D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B0, %C0 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
@@ -424,7 +424,7 @@ func.func @m16n8k4_tf32_f32_row_row_row(%arg0: memref<20x20xf32, #gpu.address_sp
   // CHECK: [[b_frag:%.+]] = vector.insert [[b_el]], {{.*}} : f32 into vector<1x1xf32>
 
   // CHECK: [[d_frag:%.+]] = nvgpu.mma.sync([[a_frag]], [[b_frag]], [[c_frag]])
-  // CHECK-SAME: mmaShape = [16, 8, 4]
+  // CHECK-SAME: mmaShape = array<i64: 16, 8, 4>
   // CHECK-SAME: -> vector<2x2xf32>
   %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf32, #gpu.address_space<workgroup>>, vector<16x4xf32>
   %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf32, #gpu.address_space<workgroup>>, vector<8x4xf32>
@@ -487,7 +487,7 @@ func.func @m16n8k8_tf32_f32_row_row_row(%arg0: memref<20x20xf32, #gpu.address_sp
   // CHECK: [[b_frag1:%.+]] = vector.insert [[b_el1]], {{.*}} : f32 into vector<2x1xf32>
 
   // CHECK: [[d_frag:%.+]] = nvgpu.mma.sync([[a_frag]], [[b_frag1]], [[c_frag]])
-  // CHECK-SAME: mmaShape = [16, 8, 8]
+  // CHECK-SAME: mmaShape = array<i64: 16, 8, 8>
   // CHECK-SAME: -> vector<2x2xf32>
   %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf32, #gpu.address_space<workgroup>>, vector<16x8xf32>
   %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf32, #gpu.address_space<workgroup>>, vector<8x8xf32>
@@ -557,7 +557,7 @@ func.func @m16n8k8_tf32_f32_col_col_row(%arg0: memref<20x20xf32, #gpu.address_sp
   // CHECK: [[b_frag:%.+]] = nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = false}
 
   // CHECK: [[d_frag:%.+]] = nvgpu.mma.sync([[a_frag]], [[b_frag]], [[c_frag]])
-  // CHECK-SAME: mmaShape = [16, 8, 8]
+  // CHECK-SAME: mmaShape = array<i64: 16, 8, 8>
   // CHECK-SAME: -> vector<2x2xf32>
   %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<20x20xf32, #gpu.address_space<workgroup>>, vector<16x8xf32>
   %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf32, #gpu.address_space<workgroup>>, vector<8x8xf32>
@@ -628,7 +628,7 @@ func.func @m16n8k64_int4_row_col_row(%arg0: memref<128x128xi4, #gpu.address_spac
   %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xi4, #gpu.address_space<workgroup>>, vector<16x64xi4>
   %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xi4, #gpu.address_space<workgroup>>, vector<8x64xi4>
   %C = vector.transfer_read %arg2[%c0, %c0], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32>
-  // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 64]} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
+  // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = array<i64: 16, 8, 64>} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x64xi4>, vector<8x64xi4> into vector<16x8xi32>
 
   // CHECK: [[lane:%.+]] = gpu.lane_id
@@ -698,7 +698,7 @@ func.func @m16n8k32_int8_row_col_row(%arg0: memref<128x128xi8, #gpu.address_spac
   %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xi8, #gpu.address_space<workgroup>>, vector<16x32xi8>
   %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xi8, #gpu.address_space<workgroup>>, vector<8x32xi8>
   %C = vector.transfer_read %arg2[%c0, %c0], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32>
-  // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = array<i64: 16, 8, 32>} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x32xi8>, vector<8x32xi8> into vector<16x8xi32>
 
   // CHECK: [[lane:%.+]] = gpu.lane_id
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index f735e3f8cc623..0376a520bfaa7 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -58,81 +58,88 @@ func.func @ldmatrix_type_x4(%arg0: memref<128x128xf32, 3>) ->  vector<4x2xf16> {
 
 func.func @m16n8k16_fp16_vector_shape_a(%arg0: vector<4x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
   // expected-error @+1 {{expected 256 warp-wide matrix A elements}}
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   return %d : vector<2x2xf16>
 }
 // -----
 
 func.func @m16n8k16_fp16_vector_shape_b(%arg0: vector<4x2xf16>, %arg1: vector<2x4xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
   // expected-error @+1 {{expected 128 warp-wide matrix B elements}}
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x4xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x4xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   return %d : vector<2x2xf16>
 }
 // -----
 
 func.func @m16n8k16_fp16_vector_shape_c(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x4xf16>) -> vector<2x4xf16> {
   // expected-error @+1 {{expected 128 warp-wide matrix C elements}}
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x4xf16>) -> vector<2x4xf16>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x4xf16>) -> vector<2x4xf16>
   return %d : vector<2x4xf16>
 }
 // -----
 
 func.func @m16n8k16_fp16_vector_shape_a_extended(%arg0: vector<2x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
   // expected-error @+1 {{expected matrix A to be shaped (4 x 2)}}
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<2x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<2x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   return %d : vector<2x2xf16>
 }
 // -----
 
 func.func @m16n8k16_fp16_tf32Enabled(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
   // expected-error @+1 {{expected tf32 tensor cores only for F32 operands}}
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16], tf32Enabled} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>, tf32Enabled} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   return %d : vector<2x2xf16>
 }
 // -----
 
 func.func @m16n8k8_fp32_vector_shape_a(%arg0: vector<4x2xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
   // expected-error @+1 {{expected 128 warp-wide matrix A elements}}
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x2xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 8>} : (vector<4x2xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
   return %d : vector<2x2xf32>
 }
 // -----
 
 func.func @m16n8k8_fp32_vector_shape_a_extended(%arg0: vector<1x4xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
   // expected-error @+1 {{expected matrix A to be shaped (4 x 1)}}
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<1x4xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 8>} : (vector<1x4xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
   return %d : vector<2x2xf32>
 }
 // -----
 
 func.func @m8n8k4_fp64_vector_shape_a(%arg0: vector<1x2xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> {
   // expected-error @+1 {{expected 32 warp-wide matrix A elements}}
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x2xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 8, 8, 4>} : (vector<1x2xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
   return %d : vector<1x2xf64>
 }
 // -----
 
 func.func @m8n8k4_fp64_vector_shape_c_extended(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<2x1xf64>) -> vector<2x1xf64> {
   // expected-error @+1 {{expected matrix C to be shaped (1 x 2)}}
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<2x1xf64>) -> vector<2x1xf64>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 8, 8, 4>} : (vector<1x1xf64>, vector<1x1xf64>, vector<2x1xf64>) -> vector<2x1xf64>
   return %d : vector<2x1xf64>
 }
 // -----
 
 func.func @m16n8k32_int8_vector_shape_b(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
   // expected-error @+1 {{expected 256 warp-wide matrix B elements}}
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 32>} : (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
   return %d : vector<2x2xi32>
 }
 // -----
 
 func.func @m16n8k32_int32_datatype(%arg0: vector<4x4xi32>, %arg1: vector<2x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
   // expected-error @+1 {{op failed to verify that matrixA and matrixB have same element type}}
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi32>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 32>} : (vector<4x4xi32>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
   return %d : vector<2x2xi32>
 }
 // -----
 
+func.func @mma_sync_check_mmaShape(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{mmaShape should be three integers}}
+  %d = nvgpu.mma.sync(%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
+// -----
+
 func.func @async_cp_memory_space(%dst : memref<16xf32>, %src : memref<16xf32>, %i : index) -> () {
   // expected-error @below {{destination memref must have a memory space attribute of IntegerAttr(3) or gpu::AddressSpaceAttr(Workgroup)}}
   nvgpu.device_async_copy %src[%i], %dst[%i], 16 : memref<16xf32> to memref<16xf32>
@@ -189,13 +196,25 @@ func.func @mma_sp_sync_f16_16816(%arg0: vector<2x2xf16>,
                                  %arg2: vector<2x2xf16>,
                                  %arg3: vector<2xi16>) -> vector<2x2xf16> {
   // expected-error @+1 {{'nvgpu.mma.sp.sync' op sparsity selector should be 0 or 1}}
-  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 16], sparsitySelector = 42 : i32} :
+  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 16>, sparsitySelector = 42 : i32} :
        (vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   return %d : vector<2x2xf16>
 }
 
 // -----
 
+func.func @mma_sp_sync_check_mmaShape(%arg0: vector<4x2xf16>,
+                                      %arg1: vector<4x2xf16>,
+                                      %arg2: vector<2x2xf16>,
+                                      %arg3: vector<2xi16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{mmaShape should be three integers}}
+  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8>} :
+       (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
+
+// -----
+
 func.func @async_cp_zfill_f32_align1(
   %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
     // expected-error @+1 {{'nvgpu.device_async_copy' op bypassL1 does not satify alignment for 'memref<3x16x128xf32, 3>' with destination element 1. Unset bypassL1, or set destination element to 4}}
@@ -359,7 +378,7 @@ func.func @rcp_unsupported_ftz(%in : vector<16xf32>) {
 
 func.func @check_matrixA_dim(%arg0: vector<16xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
   // expected-error @+1 {{matrixA must be 2 dimensional vector}}
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<16xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<16xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   return %d : vector<2x2xf16>
 }
 
@@ -367,7 +386,7 @@ func.func @check_matrixA_dim(%arg0: vector<16xf16>, %arg1: vector<2x2xf16>, %arg
 
 func.func @check_matrixB_dim(%arg0: vector<4x4xf16>, %arg1: vector<4xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
   // expected-error @+1 {{matrixB must be 2 dimensional vector}}
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<4xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x4xf16>, vector<4xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   return %d : vector<2x2xf16>
 }
 
@@ -375,7 +394,7 @@ func.func @check_matrixB_dim(%arg0: vector<4x4xf16>, %arg1: vector<4xf16>, %arg2
 
 func.func @check_matrixC_dim(%arg0: vector<4x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<4xf16>) -> vector<2x2xf16> {
   // expected-error @+1 {{matrixC must be 2 dimensional vector}}
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<2x2xf16>, vector<4xf16>) -> vector<2x2xf16>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x4xf16>, vector<2x2xf16>, vector<4xf16>) -> vector<2x2xf16>
   return %d : vector<2x2xf16>
 }
 
diff --git a/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir b/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
index 6d0cf348273c3..1a9151c4c8a36 100644
--- a/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
+++ b/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
@@ -4,7 +4,7 @@
 func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
   // CHECK: nvgpu.mma.sync
   // CHECK-SAME: tf32Enabled
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 4>} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
   return %d : vector<2x2xf32>
 }
 
@@ -14,7 +14,7 @@ func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: v
 func.func @m16n8k8_tf32(%arg0: vector<4x1xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
   // CHECK: nvgpu.mma.sync
   // CHECK-SAME: tf32Enabled
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 8>} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
   return %d : vector<2x2xf32>
 }
 // -----
@@ -24,6 +24,6 @@ func.func @m16n8k8_tf32(%arg0: vector<4x1xf32>, %arg1: vector<2x1xf32>, %arg2: v
 //   CHECK-NOT: tf32Enabled
 //       CHECK: return
 func.func @mma_sync_f16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   return %d : vector<2x2xf16>
 }
diff --git a/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir b/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir
index 7db60254d2ec8..6fae05a252887 100644
--- a/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir
+++ b/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir
@@ -3,7 +3,7 @@
 // CHECK-LABEL: m16n8k4_tf32
 func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
   // expected-error @+1 {{TF32x3 is not supported at the moment for nvgpu.mma.sync on f32 datatype}}
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 4>} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
   return %d : vector<2x2xf32>
 }
 
@@ -12,7 +12,7 @@ func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: v
 // CHECK-LABEL: m16n8k8_tf32
 func.func @m16n8k8_tf32(%arg0: vector<4x1xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
   // expected-error @+1 {{TF32x3 is not supported at the moment for nvgpu.mma.sync on f32 datatype}}
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 8>} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
   return %d : vector<2x2xf32>
 }
 // -----
diff --git a/mlir/test/Dialect/NVGPU/roundtrip.mlir b/mlir/test/Dialect/NVGPU/roundtrip.mlir
index ad516b4d2c200..667c89b1f6c49 100644
--- a/mlir/test/Dialect/NVGPU/roundtrip.mlir
+++ b/mlir/test/Dialect/NVGPU/roundtrip.mlir
@@ -13,8 +13,8 @@ func.func @ldmatrix(%arg0: memref<?x?xf16, 3>, %x: index, %y: index) {
 func.func @mma_sync(%arg0: vector<4x2xf16>,
                %arg1: vector<2x2xf16>,
                %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
-//       CHECK: nvgpu.mma.sync(%{{.*}}, %{{.*}}, %{{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
-  %d = nvgpu.mma.sync(%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} :
+//       CHECK: nvgpu.mma.sync(%{{.*}}, %{{.*}}, %{{.*}}) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  %d = nvgpu.mma.sync(%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} :
     (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   return %d : vector<2x2xf16>
 }
@@ -25,9 +25,9 @@ func.func @mma_sp_sync_f16_16832(%arg0: vector<4x2xf16>,
                                  %arg2: vector<2x2xf16>,
                                  %arg3: vector<2xi16>) -> vector<2x2xf16> {
   //      CHECK: nvgpu.mma.sp.sync(%{{.*}}, %{{.*}}, %{{.*}}) metadata(%{{.+}}) {
-  // CHECK-SAME:   mmaShape = [16, 8, 32]
+  // CHECK-SAME:   mmaShape = array<i64: 16, 8, 32>
   // CHECK-SAME: (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
-  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 32]} :
+  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 32>} :
     (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   return %d : vector<2x2xf16>
 }
@@ -38,9 +38,9 @@ func.func @mma_sp_sync_f16_16816(%arg0: vector<2x2xf16>,
                                  %arg2: vector<2x2xf16>,
                                  %arg3: vector<2xi16>) -> vector<2x2xf16> {
   //      CHECK: nvgpu.mma.sp.sync(%{{.*}}, %{{.*}}, %{{.*}}) metadata(%{{.+}}) {
-  // CHECK-SAME:   mmaShape = [16, 8, 16]
+  // CHECK-SAME:   mmaShape = array<i64: 16, 8, 16>
   // CHECK-SAME: (vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
-  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 16]} :
+  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 16>} :
     (vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   return %d : vector<2x2xf16>
 }
@@ -51,9 +51,9 @@ func.func @mma_sp_sync_i8_16864(%arg0: vector<4x4xi8>,
                                 %arg2: vector<2x2xi32>,
                                 %arg3: vector<2xi16>) -> vector<2x2xi32> {
   //      CHECK: nvgpu.mma.sp.sync(%{{.*}}, %{{.*}}, %{{.*}}) metadata(%{{.+}}) {
-  // CHECK-SAME:   mmaShape = [16, 8, 64]
+  // CHECK-SAME:   mmaShape = array<i64: 16, 8, 64>
   // CHECK-SAME: (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
-  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 64]} :
+  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 64>} :
     (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
   return %d : vector<2x2xi32>
 }
diff --git a/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
index bbe27fe1b99d9..dd9d43e1be795 100644
--- a/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
+++ b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
@@ -48,7 +48,7 @@ func.func @matmul_16x8x4xf32_global(
 // CHECK:           %[[VAL_33:.*]] = vector.insert %[[VAL_26]], %[[VAL_32]] [1, 0] : f32 into vector<2x2xf32>
 // CHECK:           %[[RES:.*]] = vector.insert %[[VAL_29]], %[[VAL_33]] [1, 1] : f32 into vector<2x2xf32>
 //
-// CHECK:           %[[VAL_35:.*]] = nvgpu.mma.sync(%[[LHS]], %[[RHS]], %[[RES]]) {mmaShape = [16, 8, 4], tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+// CHECK:           %[[VAL_35:.*]] = nvgpu.mma.sync(%[[LHS]], %[[RHS]], %[[RES]]) {mmaShape = array<i64: 16, 8, 4>, tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
 //
 // CHECK:           %[[VAL_36:.*]] = vector.extract %[[VAL_35]][0, 0] : f32 from vector<2x2xf32>
 // CHECK:           %[[VAL_37:.*]] = vector.extract %[[VAL_35]][0, 1] : f32 from vector<2x2xf32>
@@ -96,7 +96,7 @@ func.func @matmul_16x8x16xf16_global(
   // CHECK-COUNT-4: memref.load {{.*}} : memref<16x8xf16>
   // CHECK-COUNT-4: vector.insert {{.*}} : f16 into vector<2x2xf16>
   //
-  //         CHECK: nvgpu.mma.sync(%{{.*}}) {mmaShape = [16, 8, 16]}
+  //         CHECK: nvgpu.mma.sync(%{{.*}}) {mmaShape = array<i64: 16, 8, 16>}
   //    CHECK-SAME:   : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   //
   // CHECK-COUNT-4: vector.extract %{{.*}} : f16 from vector<2x2xf16>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir
index d8d7c1c39db91..00338c6486413 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir
@@ -194,7 +194,7 @@ module attributes {gpu.container_module} {
       // within each group of four threads contribute metadata.
       %d = nvgpu.mma.sp.sync(%A_data, %B_data, %accum)
            metadata(%meta)
-           {mmaShape = [16, 8, 32]} : (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+           {mmaShape = array<i64: 16, 8, 32>} : (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
 
       //===----------------------------------------------------------------------===//
       // Write back results to gpu global memory
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir
index 2eef2ff8f3564..46ce7e8b55b2f 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir
@@ -3,7 +3,7 @@
 // RUN: | FileCheck %s --check-prefix=CHECK-MMA-SYNC
 
 // CHECK-MMA-SYNC-LABEL: func @main() {
-//       CHECK-MMA-SYNC:   nvgpu.mma.sync(%{{.*}}) {mmaShape = [16, 8, 4], tf32Enabled}
+//       CHECK-MMA-SYNC:   nvgpu.mma.sync(%{{.*}}) {mmaShape = array<i64: 16, 8, 4>, tf32Enabled}
 //  CHECK-MMA-SYNC-SAME:     : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
 
 // Tested to run locally in 1.7s.



More information about the Mlir-commits mailing list