[Mlir-commits] [mlir] [mlir][nvgpu] Fix crash when mmaShape size is not three (PR #173490)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 24 04:47:05 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

This PR fixes a crash when the mmaShape attribute of `nvgpu.mma.sync` does not have exactly three elements. The change replaces the ArrayAttr-based mmaShape with DenseI64ArrayAttr and adds verifier checks to ensure the attribute has three elements. Fixes #<!-- -->173378.

---

Patch is 51.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/173490.diff


14 Files Affected:

- (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td (+5-19) 
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+4-4) 
- (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+2-2) 
- (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+15-11) 
- (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+12-12) 
- (modified) mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir (+2-2) 
- (modified) mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir (+10-10) 
- (modified) mlir/test/Dialect/NVGPU/invalid.mlir (+34-15) 
- (modified) mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir (+3-3) 
- (modified) mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir (+2-2) 
- (modified) mlir/test/Dialect/NVGPU/roundtrip.mlir (+8-8) 
- (modified) mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir (+1-1) 
- (modified) mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir (+1-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td
index 73d86283a5940..5b9ae8bb7a518 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td
@@ -66,16 +66,6 @@ class NVGPU_MmaSyncOp<string mnemonic> :
         NVGPU_Op<mnemonic,  [Pure,
                              PredOpTrait<"matrixA and matrixB have same element type",
                                          TCopVTEtIsSameAs<0, 1>>]> {
-  code extraBaseClassDeclaration = [{
-    std::array<int64_t, 3> getMmaShapeAsArray() {
-      ArrayAttr mmaShape = this->getMmaShape();
-      assert(mmaShape.size() == 3 && "mmaShape should be three integers");
-      return {::llvm::cast<IntegerAttr>(mmaShape[0]).getInt(),
-              ::llvm::cast<IntegerAttr>(mmaShape[1]).getInt(),
-              ::llvm::cast<IntegerAttr>(mmaShape[2]).getInt()};
-    }
-  }];
-
   let hasVerifier = 1;
 }
 
@@ -96,14 +86,14 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
     Example:
 
     ```mlir
-    %res = nvgpu.mma.sync (%matrixA, %matrixB, %matrixC) {mmaShape = [16, 8, 16]} :
+    %res = nvgpu.mma.sync (%matrixA, %matrixB, %matrixC) {mmaShape = array<i64: 16, 8, 16>} :
         (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
     ```
   }];
   let arguments = (ins AnyVectorOfNonZeroRank:$matrixA,
                        AnyVectorOfNonZeroRank:$matrixB,
                        AnyVectorOfNonZeroRank:$matrixC,
-                       I64ArrayAttr:$mmaShape,
+                       DenseI64ArrayAttr:$mmaShape,
                        OptionalAttr<UnitAttr>:$tf32Enabled);
 
   let results = (outs AnyVectorOfNonZeroRank:$res);
@@ -112,7 +102,7 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
     OpBuilder<(ins "Value":$matrixA,
                    "Value":$matrixB,
                    "Value":$matrixC,
-                   "ArrayAttr":$mmaShape)>,
+                   "DenseI64ArrayAttr":$mmaShape)>,
     OpBuilder<(ins "Value":$matrixA,
                    "Value":$matrixB,
                    "Value":$matrixC,
@@ -124,8 +114,6 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
     `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
     `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
   }];
-
-  let extraClassDeclaration = extraBaseClassDeclaration;
 }
 
 def NVGPU_MmaSparseSyncMetadataType : FixedVectorOfLengthAndType<[2], [I16]>,
@@ -151,7 +139,7 @@ def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
   Example (targeting the f16 16x8x32 `mma.sp` PTX instruction):
 
   ```mlir
-  nvgpu.mma.sp.sync (%a, %b, %c) metadata (%meta) {mmaShape = [16, 8, 32]} :
+  nvgpu.mma.sp.sync (%a, %b, %c) metadata (%meta) {mmaShape = array<i64: 16, 8, 32>} :
     (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   ```
   }];
@@ -160,7 +148,7 @@ def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
                        AnyVectorOfNonZeroRank:$matrixB,
                        AnyVectorOfNonZeroRank:$matrixC,
                        NVGPU_MmaSparseSyncMetadataType:$sparseMetadata,
-                       I64ArrayAttr:$mmaShape,
+                       DenseI64ArrayAttr:$mmaShape,
                        DefaultValuedAttr<I32Attr, "0">:$sparsitySelector,
                        OptionalAttr<UnitAttr>:$tf32Enabled
                        );
@@ -179,8 +167,6 @@ def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
     `(` $matrixA`,` $matrixB`,` $matrixC `)` `metadata` `(` $sparseMetadata `)` attr-dict
     `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
   }];
-
-  let extraClassDeclaration = extraBaseClassDeclaration;
 }
 
 def NVGPU_DeviceAsyncCopyOp : NVGPU_Op<"device_async_copy", [
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 6edc8f5c86dd3..7cd59b0f135fe 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -340,7 +340,7 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
     VectorType bType = op.getMatrixA().getType();
     VectorType cType = op.getMatrixC().getType();
 
-    std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
+    ArrayRef<int64_t> gemmShape = op.getMmaShape();
 
     // Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
     bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
@@ -485,7 +485,7 @@ static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
 /// it's expected that the provided parameters correspond to a valid
 /// instruction.
 static std::string buildMmaSparseAsmString(
-    const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
+    ArrayRef<int64_t> shape, unsigned matASize, unsigned matBSize,
     unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
     NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
     std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
@@ -526,7 +526,7 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
     NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
     std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
     ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
-    int64_t metadataSelector, const std::array<int64_t, 3> &shape,
+    int64_t metadataSelector, ArrayRef<int64_t> shape,
     Type intrinsicResultType) {
   auto asmDialectAttr =
       LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
@@ -618,7 +618,7 @@ struct NVGPUMmaSparseSyncLowering
 
     FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
         b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
-        matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
+        matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShape(),
         intrinsicResTy);
     if (failed(intrinsicResult))
       return failure();
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 98434357f826f..d4aa893a5b420 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -1058,8 +1058,8 @@ convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
   int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
   int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
   int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
-  Value matmul = nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC,
-                                          rewriter.getI64ArrayAttr({m, n, k}));
+  Value matmul =
+      nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC, {m, n, k});
   valueMapping[op.getResult()] = matmul;
   return success();
 }
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 237aab4d7f309..eb2de91d30988 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -119,7 +119,8 @@ LogicalResult DeviceAsyncCopyOp::verify() {
 //===----------------------------------------------------------------------===//
 void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                       ::mlir::OperationState &odsState, Value matrixA,
-                      Value matrixB, Value matrixC, ArrayAttr mmaShape) {
+                      Value matrixB, Value matrixC,
+                      DenseI64ArrayAttr mmaShape) {
   build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
         mmaShape, UnitAttr());
 }
@@ -129,8 +130,7 @@ void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                       Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape,
                       bool tf32Enabled) {
   build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
-        odsBuilder.getI64ArrayAttr(mmaShape),
-        tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
+        mmaShape, tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
 }
 
 /// Performs verification for MmaSyncOp and MmaSparseSyncOp.
@@ -138,7 +138,7 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
                                      TypedValue<VectorType> matrixA,
                                      TypedValue<VectorType> matrixB,
                                      TypedValue<VectorType> matrixC,
-                                     const std::array<int64_t, 3> &mmaShape,
+                                     ArrayRef<int64_t> mmaShape,
                                      bool tf32Enabled, bool sparse = false) {
   // The verification for mma.sync covering various shapes and data types is
   // based on the fundamental tensor core shape.
@@ -209,7 +209,12 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
     return op->emitError() << "matrixC must be 2 dimensional vector";
   }
 
-  auto [m, n, k] = mmaShape;
+  if (mmaShape.size() != 3) {
+    return op->emitError() << "mmaShape should be three integers";
+  }
+  int64_t m = mmaShape[0];
+  int64_t n = mmaShape[1];
+  int64_t k = mmaShape[2];
 
   // verify warp-wide size for vector a
   int64_t sparseFactor = sparse ? 2 : 1;
@@ -262,7 +267,7 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
 
 LogicalResult MmaSyncOp::verify() {
   return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
-                         getMatrixC(), getMmaShapeAsArray(),
+                         getMatrixC(), getMmaShape(),
                          getOperation()->hasAttr(getTf32EnabledAttrName()));
 }
 
@@ -274,17 +279,16 @@ void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder,
                             Value matrixB, Value matrixC, Value sparseMetadata,
                             ArrayRef<int64_t> mmaShape) {
   build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
-        sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr());
+        sparseMetadata, mmaShape, 0, UnitAttr());
 }
 
 LogicalResult MmaSparseSyncOp::verify() {
   unsigned sparsitySelector = getSparsitySelector();
   if (sparsitySelector > 1)
     return emitOpError() << "sparsity selector should be 0 or 1";
-  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
-                         getMatrixC(), getMmaShapeAsArray(),
-                         getOperation()->hasAttr(getTf32EnabledAttrName()),
-                         true);
+  return verifyMmaSyncOp(
+      this->getOperation(), getMatrixA(), getMatrixB(), getMatrixC(),
+      getMmaShape(), getOperation()->hasAttr(getTf32EnabledAttrName()), true);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 0eb44789fe31d..74bf571b4a64e 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -14,7 +14,7 @@ func.func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2:
   // CHECK-NOT: llvm.extractvalue
   // CHECK: [[d:%.+]] = nvvm.mma.sync
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 16>
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
   // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
   // CHECK: llvm.mlir.poison : !llvm.array<2 x vector<2xf16>>
@@ -31,7 +31,7 @@ func.func @m16n8k16_fp16_fp32(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %a
   // CHECK: [[d:%.+]] = nvvm.mma.sync
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 16>
   // CHECK-SAME: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
   // CHECK: [[undef:%.+]] = llvm.mlir.poison : vector<2xf32>
   // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
   // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>
@@ -59,7 +59,7 @@ func.func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: v
   // CHECK-NOT: llvm.extractvalue
   // CHECK: [[d:%.+]] = nvvm.mma.sync
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 8>
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 8>} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
   // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
   // CHECK: llvm.mlir.poison : !llvm.array<2 x vector<2xf16>>
@@ -90,7 +90,7 @@ func.func @m16n8k32_int8(%arg0: vector<4x4xi8>, %arg1: vector<2x4xi8>, %arg2: ve
   // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s8>
   // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s8>
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 32>
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 32>} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
   return %d : vector<2x2xi32>
 }
 
@@ -109,7 +109,7 @@ func.func @m16n8k32_i4(%arg0: vector<2x8xi4>, %arg1: vector<1x8xi4>, %arg2: vect
   // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s4>
   // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s4>
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 32>
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<2x8xi4>, vector<1x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 32>} : (vector<2x8xi4>, vector<1x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
   return %d : vector<2x2xi32>
 }
 
@@ -134,7 +134,7 @@ func.func @m16n8k64_i4(%arg0: vector<4x8xi4>, %arg1: vector<2x8xi4>, %arg2: vect
   // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s4>
   // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s4>
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 64>
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 64]} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 64>} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
   return %d : vector<2x2xi32>
 }
 
@@ -145,7 +145,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec
   // CHECK: llvm.extractvalue
   // CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}]
   // CHECK-SAME: shape = #nvvm.shape<m = 8, n = 8, k = 4>
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 8, 8, 4>} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
   // CHECK: llvm.mlir.poison : vector<2xf64>
   // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f64, f64)>
   // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f64, f64)>
@@ -201,7 +201,7 @@ func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: v
   // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<tf32>
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 4>
   // CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)>
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4], tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 4>, tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
   // CHECK: [[undef:%.+]] = llvm.mlir.poison : vector<2xf32>
   // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
   // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>
@@ -370,7 +370,7 @@ func.func @mma_sp_sync_f16_16832(%arg0: vector<4x2xf16>,
   // CHECK-SAME: %[[sparseMetadata]] :
   // CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 
-  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 32]} :
+  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 32>} :
     (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
 
   // CHECK-DAG: llvm.extractvalue %[[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
@@ -406,7 +406,7 @@ func.func @mma_sp_sync_f16_16816(%arg0: vector<2x2xf16>,
   // CHECK-SAME: %[[sparseMetadata]] :
   // CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 
-  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 16]} :
+  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 16>} :
     (vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   return %d : vector<2x2xf16>
 }
@@ -427,7 +427,7 @@ func.func @mma_sp_sync_f16_16816_01(%arg0: vector<2x2xf16>,
   // CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 
   %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3)
-       {mmaShape = [16, 8, 16], sparsitySelector = 1 : i32} :
+       {mmaShape = array<i64: 16, 8, 16>, sparsitySelector = 1 : i32} :
        (vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   return %d : vector<2x2xf16>
 }
@@ -465,7 +465,7 @@ func.func @mma_sp_sync_i8_16864(%arg0: vector<4x4xi8>,
   // CHECK-SAME: %[[sparseMetadata]] :
   // CHECK-SAME: -> !llvm.struct<(i32, i32, i32, i32)
 
-  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 64]} :
+  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 64>} :
     (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
   return %d : vector<2x2xi32>
 }
diff --git a/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
index 0afaa19d59d15..c42f5add697f0 100644
--- a/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
+++ b/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
@@ -29,7 +29,7 @@ func.func @m16n8k16_mmasync16816_f16_f16_f32_row_row_row(%arg0: memref<42x32xf16
   %B0_f32 = arith.extf %B0 : vector<8x16xf16> to vector<8x16xf32>
   %C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf32> to vector<16x8xf32>
   
-  // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+  // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
   %D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A_f32, %B0_f32, %C0 : vector<16x16xf32>, vector<8x16xf32> into vector<16x8xf32>
   vector.transfer_write %D0, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<42x64xf32, #gpu.address_space<workgroup>>
 
@@ -38,7 +38,7 @@ func.func @m16n8k16_mmasync16816_f16_f16_f32_row_row_row(%arg0: memref<42x32xf16
   %B1_f32 = arith.extf %B1 : vector<8x16xf16> to vector<8x16xf32>
 ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/173490


More information about the Mlir-commits mailing list