[Mlir-commits] [mlir] [mlir][nvgpu] Fix crash when mmaShape size is not three (PR #173490)
Longsheng Mou
llvmlistbot at llvm.org
Wed Dec 24 04:56:35 PST 2025
https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/173490
>From 8f76afbf35a91f122dc6bb3f050e38a86675d814 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Wed, 24 Dec 2025 20:40:51 +0800
Subject: [PATCH] [mlir][nvgpu] Fix crash when mmaShape size is not three
This PR fixes a crash when the mmaShape attribute of `nvgpu.mma.sync`
does not have exactly three elements. The change replaces the ArrayAttr-based
mmaShape with DenseI64ArrayAttr and adds verifier checks to ensure the
attribute has three elements.
---
.../include/mlir/Dialect/NVGPU/IR/NVGPUOps.td | 24 ++-------
.../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 9 ++--
.../Conversion/VectorToGPU/VectorToGPU.cpp | 4 +-
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 26 +++++-----
.../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 24 ++++-----
...fold-arith-vector-to-mma-ops-mma-sync.mlir | 4 +-
.../vector-to-mma-ops-mma-sync.mlir | 20 ++++----
mlir/test/Dialect/NVGPU/invalid.mlir | 49 +++++++++++++------
.../Dialect/NVGPU/mma-sync-f32-to-tf32.mlir | 6 +--
.../Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir | 4 +-
mlir/test/Dialect/NVGPU/roundtrip.mlir | 16 +++---
.../NVGPU/transform-matmul-to-nvvm.mlir | 4 +-
.../GPU/CUDA/sparse-mma-2-4-f16.mlir | 2 +-
.../sm80/transform-mma-sync-matmul-f32.mlir | 2 +-
14 files changed, 102 insertions(+), 92 deletions(-)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td
index 73d86283a5940..5b9ae8bb7a518 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td
@@ -66,16 +66,6 @@ class NVGPU_MmaSyncOp<string mnemonic> :
NVGPU_Op<mnemonic, [Pure,
PredOpTrait<"matrixA and matrixB have same element type",
TCopVTEtIsSameAs<0, 1>>]> {
- code extraBaseClassDeclaration = [{
- std::array<int64_t, 3> getMmaShapeAsArray() {
- ArrayAttr mmaShape = this->getMmaShape();
- assert(mmaShape.size() == 3 && "mmaShape should be three integers");
- return {::llvm::cast<IntegerAttr>(mmaShape[0]).getInt(),
- ::llvm::cast<IntegerAttr>(mmaShape[1]).getInt(),
- ::llvm::cast<IntegerAttr>(mmaShape[2]).getInt()};
- }
- }];
-
let hasVerifier = 1;
}
@@ -96,14 +86,14 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
Example:
```mlir
- %res = nvgpu.mma.sync (%matrixA, %matrixB, %matrixC) {mmaShape = [16, 8, 16]} :
+ %res = nvgpu.mma.sync (%matrixA, %matrixB, %matrixC) {mmaShape = array<i64: 16, 8, 16>} :
(vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
```
}];
let arguments = (ins AnyVectorOfNonZeroRank:$matrixA,
AnyVectorOfNonZeroRank:$matrixB,
AnyVectorOfNonZeroRank:$matrixC,
- I64ArrayAttr:$mmaShape,
+ DenseI64ArrayAttr:$mmaShape,
OptionalAttr<UnitAttr>:$tf32Enabled);
let results = (outs AnyVectorOfNonZeroRank:$res);
@@ -112,7 +102,7 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
OpBuilder<(ins "Value":$matrixA,
"Value":$matrixB,
"Value":$matrixC,
- "ArrayAttr":$mmaShape)>,
+ "DenseI64ArrayAttr":$mmaShape)>,
OpBuilder<(ins "Value":$matrixA,
"Value":$matrixB,
"Value":$matrixC,
@@ -124,8 +114,6 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
`(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
`:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
}];
-
- let extraClassDeclaration = extraBaseClassDeclaration;
}
def NVGPU_MmaSparseSyncMetadataType : FixedVectorOfLengthAndType<[2], [I16]>,
@@ -151,7 +139,7 @@ def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
Example (targeting the f16 16x8x32 `mma.sp` PTX instruction):
```mlir
- nvgpu.mma.sp.sync (%a, %b, %c) metadata (%meta) {mmaShape = [16, 8, 32]} :
+ nvgpu.mma.sp.sync (%a, %b, %c) metadata (%meta) {mmaShape = array<i64: 16, 8, 32>} :
(vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
```
}];
@@ -160,7 +148,7 @@ def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
AnyVectorOfNonZeroRank:$matrixB,
AnyVectorOfNonZeroRank:$matrixC,
NVGPU_MmaSparseSyncMetadataType:$sparseMetadata,
- I64ArrayAttr:$mmaShape,
+ DenseI64ArrayAttr:$mmaShape,
DefaultValuedAttr<I32Attr, "0">:$sparsitySelector,
OptionalAttr<UnitAttr>:$tf32Enabled
);
@@ -179,8 +167,6 @@ def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
`(` $matrixA`,` $matrixB`,` $matrixC `)` `metadata` `(` $sparseMetadata `)` attr-dict
`:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
}];
-
- let extraClassDeclaration = extraBaseClassDeclaration;
}
def NVGPU_DeviceAsyncCopyOp : NVGPU_Op<"device_async_copy", [
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 6edc8f5c86dd3..22b88b871efaa 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -340,7 +340,7 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
VectorType bType = op.getMatrixA().getType();
VectorType cType = op.getMatrixC().getType();
- std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
+ ArrayRef<int64_t> gemmShape = op.getMmaShape();
// Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
@@ -485,7 +485,7 @@ static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
/// it's expected that the provided parameters correspond to a valid
/// instruction.
static std::string buildMmaSparseAsmString(
- const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
+ ArrayRef<int64_t> shape, unsigned matASize, unsigned matBSize,
unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
@@ -495,6 +495,7 @@ static std::string buildMmaSparseAsmString(
std::string asmStr;
llvm::raw_string_ostream ss(asmStr);
+ assert(shape.size() == 3 && "mmaShape should be three integers");
ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
<< shape[2] << ".row.col.";
@@ -526,7 +527,7 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
- int64_t metadataSelector, const std::array<int64_t, 3> &shape,
+ int64_t metadataSelector, ArrayRef<int64_t> shape,
Type intrinsicResultType) {
auto asmDialectAttr =
LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
@@ -618,7 +619,7 @@ struct NVGPUMmaSparseSyncLowering
FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
- matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
+ matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShape(),
intrinsicResTy);
if (failed(intrinsicResult))
return failure();
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 98434357f826f..d4aa893a5b420 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -1058,8 +1058,8 @@ convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
- Value matmul = nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC,
- rewriter.getI64ArrayAttr({m, n, k}));
+ Value matmul =
+ nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC, {m, n, k});
valueMapping[op.getResult()] = matmul;
return success();
}
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 237aab4d7f309..eb2de91d30988 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -119,7 +119,8 @@ LogicalResult DeviceAsyncCopyOp::verify() {
//===----------------------------------------------------------------------===//
void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState, Value matrixA,
- Value matrixB, Value matrixC, ArrayAttr mmaShape) {
+ Value matrixB, Value matrixC,
+ DenseI64ArrayAttr mmaShape) {
build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
mmaShape, UnitAttr());
}
@@ -129,8 +130,7 @@ void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape,
bool tf32Enabled) {
build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
- odsBuilder.getI64ArrayAttr(mmaShape),
- tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
+ mmaShape, tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
}
/// Performs verification for MmaSyncOp and MmaSparseSyncOp.
@@ -138,7 +138,7 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
TypedValue<VectorType> matrixA,
TypedValue<VectorType> matrixB,
TypedValue<VectorType> matrixC,
- const std::array<int64_t, 3> &mmaShape,
+ ArrayRef<int64_t> mmaShape,
bool tf32Enabled, bool sparse = false) {
// The verification for mma.sync covering various shapes and data types is
// based on the fundamental tensor core shape.
@@ -209,7 +209,12 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
return op->emitError() << "matrixC must be 2 dimensional vector";
}
- auto [m, n, k] = mmaShape;
+ if (mmaShape.size() != 3) {
+ return op->emitError() << "mmaShape should be three integers";
+ }
+ int64_t m = mmaShape[0];
+ int64_t n = mmaShape[1];
+ int64_t k = mmaShape[2];
// verify warp-wide size for vector a
int64_t sparseFactor = sparse ? 2 : 1;
@@ -262,7 +267,7 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
LogicalResult MmaSyncOp::verify() {
return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
- getMatrixC(), getMmaShapeAsArray(),
+ getMatrixC(), getMmaShape(),
getOperation()->hasAttr(getTf32EnabledAttrName()));
}
@@ -274,17 +279,16 @@ void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder,
Value matrixB, Value matrixC, Value sparseMetadata,
ArrayRef<int64_t> mmaShape) {
build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
- sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr());
+ sparseMetadata, mmaShape, 0, UnitAttr());
}
LogicalResult MmaSparseSyncOp::verify() {
unsigned sparsitySelector = getSparsitySelector();
if (sparsitySelector > 1)
return emitOpError() << "sparsity selector should be 0 or 1";
- return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
- getMatrixC(), getMmaShapeAsArray(),
- getOperation()->hasAttr(getTf32EnabledAttrName()),
- true);
+ return verifyMmaSyncOp(
+ this->getOperation(), getMatrixA(), getMatrixB(), getMatrixC(),
+ getMmaShape(), getOperation()->hasAttr(getTf32EnabledAttrName()), true);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 0eb44789fe31d..74bf571b4a64e 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -14,7 +14,7 @@ func.func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2:
// CHECK-NOT: llvm.extractvalue
// CHECK: [[d:%.+]] = nvvm.mma.sync
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 16>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
// CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.mlir.poison : !llvm.array<2 x vector<2xf16>>
@@ -31,7 +31,7 @@ func.func @m16n8k16_fp16_fp32(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %a
// CHECK: [[d:%.+]] = nvvm.mma.sync
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 16>
// CHECK-SAME: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
// CHECK: [[undef:%.+]] = llvm.mlir.poison : vector<2xf32>
// CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
// CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>
@@ -59,7 +59,7 @@ func.func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: v
// CHECK-NOT: llvm.extractvalue
// CHECK: [[d:%.+]] = nvvm.mma.sync
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 8>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 8>} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
// CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.mlir.poison : !llvm.array<2 x vector<2xf16>>
@@ -90,7 +90,7 @@ func.func @m16n8k32_int8(%arg0: vector<4x4xi8>, %arg1: vector<2x4xi8>, %arg2: ve
// CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s8>
// CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s8>
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 32>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 32>} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
return %d : vector<2x2xi32>
}
@@ -109,7 +109,7 @@ func.func @m16n8k32_i4(%arg0: vector<2x8xi4>, %arg1: vector<1x8xi4>, %arg2: vect
// CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s4>
// CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s4>
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 32>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<2x8xi4>, vector<1x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 32>} : (vector<2x8xi4>, vector<1x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
return %d : vector<2x2xi32>
}
@@ -134,7 +134,7 @@ func.func @m16n8k64_i4(%arg0: vector<4x8xi4>, %arg1: vector<2x8xi4>, %arg2: vect
// CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s4>
// CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s4>
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 64>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 64]} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 64>} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
return %d : vector<2x2xi32>
}
@@ -145,7 +145,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec
// CHECK: llvm.extractvalue
// CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}]
// CHECK-SAME: shape = #nvvm.shape<m = 8, n = 8, k = 4>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 8, 8, 4>} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
// CHECK: llvm.mlir.poison : vector<2xf64>
// CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f64, f64)>
// CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f64, f64)>
@@ -201,7 +201,7 @@ func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: v
// CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<tf32>
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 4>
// CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4], tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 4>, tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
// CHECK: [[undef:%.+]] = llvm.mlir.poison : vector<2xf32>
// CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
// CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>
@@ -370,7 +370,7 @@ func.func @mma_sp_sync_f16_16832(%arg0: vector<4x2xf16>,
// CHECK-SAME: %[[sparseMetadata]] :
// CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 32]} :
+ %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 32>} :
(vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
// CHECK-DAG: llvm.extractvalue %[[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
@@ -406,7 +406,7 @@ func.func @mma_sp_sync_f16_16816(%arg0: vector<2x2xf16>,
// CHECK-SAME: %[[sparseMetadata]] :
// CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 16]} :
+ %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 16>} :
(vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
@@ -427,7 +427,7 @@ func.func @mma_sp_sync_f16_16816_01(%arg0: vector<2x2xf16>,
// CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
%d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3)
- {mmaShape = [16, 8, 16], sparsitySelector = 1 : i32} :
+ {mmaShape = array<i64: 16, 8, 16>, sparsitySelector = 1 : i32} :
(vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
@@ -465,7 +465,7 @@ func.func @mma_sp_sync_i8_16864(%arg0: vector<4x4xi8>,
// CHECK-SAME: %[[sparseMetadata]] :
// CHECK-SAME: -> !llvm.struct<(i32, i32, i32, i32)
- %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 64]} :
+ %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 64>} :
(vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
return %d : vector<2x2xi32>
}
diff --git a/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
index 0afaa19d59d15..c42f5add697f0 100644
--- a/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
+++ b/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
@@ -29,7 +29,7 @@ func.func @m16n8k16_mmasync16816_f16_f16_f32_row_row_row(%arg0: memref<42x32xf16
%B0_f32 = arith.extf %B0 : vector<8x16xf16> to vector<8x16xf32>
%C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf32> to vector<16x8xf32>
- // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+ // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
%D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A_f32, %B0_f32, %C0 : vector<16x16xf32>, vector<8x16xf32> into vector<16x8xf32>
vector.transfer_write %D0, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<42x64xf32, #gpu.address_space<workgroup>>
@@ -38,7 +38,7 @@ func.func @m16n8k16_mmasync16816_f16_f16_f32_row_row_row(%arg0: memref<42x32xf16
%B1_f32 = arith.extf %B1 : vector<8x16xf16> to vector<8x16xf32>
%C1 = vector.extract_strided_slice %C {offsets = [0, 8], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf32> to vector<16x8xf32>
- // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+ // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
%D1 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A_f32, %B1_f32, %C1 : vector<16x16xf32>, vector<8x16xf32> into vector<16x8xf32>
vector.transfer_write %D1, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<42x64xf32, #gpu.address_space<workgroup>>
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
index 912f7fba59e60..b11d79a207fba 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
@@ -88,7 +88,7 @@ func.func @m16n8k32_int8_row_row_row(%arg0: memref<128x128xi8, #gpu.address_spac
%A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xi8, #gpu.address_space<workgroup>>, vector<16x32xi8>
%B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xi8, #gpu.address_space<workgroup>>, vector<8x32xi8>
%C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32>
- // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+ // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = array<i64: 16, 8, 32>} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x32xi8>, vector<8x32xi8> into vector<16x8xi32>
// CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[{{%.+}}]
@@ -153,7 +153,7 @@ func.func @m8n8k4_f64_row_row_row(%arg0: memref<128x128xf64>, %arg1: memref<128x
%A = vector.transfer_read %arg0[%c1, %c1], %cst {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x4xf64>
%B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xf64>, vector<8x4xf64>
%C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x8xf64>
- // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+ // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = array<i64: 8, 8, 4>} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<8x4xf64>, vector<8x4xf64> into vector<8x8xf64>
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowC0_map]]
@@ -234,7 +234,7 @@ func.func @m16n16k16_mmasync16816_fp16_f16_row_row_row(%arg0: memref<42x32xf16,
// CHECK-DAG: [[fragmentB0:%.+]] = vector.extract_strided_slice [[fragmentB]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
// CHECK-DAG: [[fragmentC0:%.+]] = vector.extract_strided_slice [[fragmentC]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
- // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB0]], [[fragmentC0]]) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB0]], [[fragmentC0]]) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
%B0 = vector.extract_strided_slice %B {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
%C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf16> to vector<16x8xf16>
%D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B0, %C0 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
@@ -242,7 +242,7 @@ func.func @m16n16k16_mmasync16816_fp16_f16_row_row_row(%arg0: memref<42x32xf16,
// CHECK-DAG: [[fragmentB1:%.+]] = vector.extract_strided_slice [[fragmentB]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
// CHECK-DAG: [[fragmentC1:%.+]] = vector.extract_strided_slice [[fragmentC]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
- // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB1]], [[fragmentC1]]) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB1]], [[fragmentC1]]) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
%B1 = vector.extract_strided_slice %B {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
%C1 = vector.extract_strided_slice %C {offsets = [0, 8], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf16> to vector<16x8xf16>
%D1 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B1, %C1 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
@@ -290,7 +290,7 @@ func.func @multi_dim_m16n8k16_fp16_row_row_row(%arg0: memref<4x32x1x32xf16, #gpu
// CHECK-DAG: [[fragmentB0:%.+]] = vector.extract_strided_slice [[fragmentB]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
// CHECK-DAG: [[fragmentC0:%.+]] = vector.extract_strided_slice [[fragmentC]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
- // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB0]], [[fragmentC0]]) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB0]], [[fragmentC0]]) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
%B0 = vector.extract_strided_slice %B {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
%C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf16> to vector<16x8xf16>
%D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B0, %C0 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
@@ -424,7 +424,7 @@ func.func @m16n8k4_tf32_f32_row_row_row(%arg0: memref<20x20xf32, #gpu.address_sp
// CHECK: [[b_frag:%.+]] = vector.insert [[b_el]], {{.*}} : f32 into vector<1x1xf32>
// CHECK: [[d_frag:%.+]] = nvgpu.mma.sync([[a_frag]], [[b_frag]], [[c_frag]])
- // CHECK-SAME: mmaShape = [16, 8, 4]
+ // CHECK-SAME: mmaShape = array<i64: 16, 8, 4>
// CHECK-SAME: -> vector<2x2xf32>
%A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf32, #gpu.address_space<workgroup>>, vector<16x4xf32>
%B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf32, #gpu.address_space<workgroup>>, vector<8x4xf32>
@@ -487,7 +487,7 @@ func.func @m16n8k8_tf32_f32_row_row_row(%arg0: memref<20x20xf32, #gpu.address_sp
// CHECK: [[b_frag1:%.+]] = vector.insert [[b_el1]], {{.*}} : f32 into vector<2x1xf32>
// CHECK: [[d_frag:%.+]] = nvgpu.mma.sync([[a_frag]], [[b_frag1]], [[c_frag]])
- // CHECK-SAME: mmaShape = [16, 8, 8]
+ // CHECK-SAME: mmaShape = array<i64: 16, 8, 8>
// CHECK-SAME: -> vector<2x2xf32>
%A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf32, #gpu.address_space<workgroup>>, vector<16x8xf32>
%B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf32, #gpu.address_space<workgroup>>, vector<8x8xf32>
@@ -557,7 +557,7 @@ func.func @m16n8k8_tf32_f32_col_col_row(%arg0: memref<20x20xf32, #gpu.address_sp
// CHECK: [[b_frag:%.+]] = nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = false}
// CHECK: [[d_frag:%.+]] = nvgpu.mma.sync([[a_frag]], [[b_frag]], [[c_frag]])
- // CHECK-SAME: mmaShape = [16, 8, 8]
+ // CHECK-SAME: mmaShape = array<i64: 16, 8, 8>
// CHECK-SAME: -> vector<2x2xf32>
%A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<20x20xf32, #gpu.address_space<workgroup>>, vector<16x8xf32>
%B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf32, #gpu.address_space<workgroup>>, vector<8x8xf32>
@@ -628,7 +628,7 @@ func.func @m16n8k64_int4_row_col_row(%arg0: memref<128x128xi4, #gpu.address_spac
%A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xi4, #gpu.address_space<workgroup>>, vector<16x64xi4>
%B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xi4, #gpu.address_space<workgroup>>, vector<8x64xi4>
%C = vector.transfer_read %arg2[%c0, %c0], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32>
- // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 64]} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
+ // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = array<i64: 16, 8, 64>} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x64xi4>, vector<8x64xi4> into vector<16x8xi32>
// CHECK: [[lane:%.+]] = gpu.lane_id
@@ -698,7 +698,7 @@ func.func @m16n8k32_int8_row_col_row(%arg0: memref<128x128xi8, #gpu.address_spac
%A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xi8, #gpu.address_space<workgroup>>, vector<16x32xi8>
%B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xi8, #gpu.address_space<workgroup>>, vector<8x32xi8>
%C = vector.transfer_read %arg2[%c0, %c0], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32>
- // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+ // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = array<i64: 16, 8, 32>} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x32xi8>, vector<8x32xi8> into vector<16x8xi32>
// CHECK: [[lane:%.+]] = gpu.lane_id
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index f735e3f8cc623..0376a520bfaa7 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -58,81 +58,88 @@ func.func @ldmatrix_type_x4(%arg0: memref<128x128xf32, 3>) -> vector<4x2xf16> {
func.func @m16n8k16_fp16_vector_shape_a(%arg0: vector<4x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
// expected-error @+1 {{expected 256 warp-wide matrix A elements}}
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
// -----
func.func @m16n8k16_fp16_vector_shape_b(%arg0: vector<4x2xf16>, %arg1: vector<2x4xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
// expected-error @+1 {{expected 128 warp-wide matrix B elements}}
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x4xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x4xf16>, vector<2x2xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
// -----
func.func @m16n8k16_fp16_vector_shape_c(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x4xf16>) -> vector<2x4xf16> {
// expected-error @+1 {{expected 128 warp-wide matrix C elements}}
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x4xf16>) -> vector<2x4xf16>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x4xf16>) -> vector<2x4xf16>
return %d : vector<2x4xf16>
}
// -----
func.func @m16n8k16_fp16_vector_shape_a_extended(%arg0: vector<2x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
// expected-error @+1 {{expected matrix A to be shaped (4 x 2)}}
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<2x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<2x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
// -----
func.func @m16n8k16_fp16_tf32Enabled(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
// expected-error @+1 {{expected tf32 tensor cores only for F32 operands}}
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16], tf32Enabled} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>, tf32Enabled} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
// -----
func.func @m16n8k8_fp32_vector_shape_a(%arg0: vector<4x2xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
// expected-error @+1 {{expected 128 warp-wide matrix A elements}}
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x2xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 8>} : (vector<4x2xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
return %d : vector<2x2xf32>
}
// -----
func.func @m16n8k8_fp32_vector_shape_a_extended(%arg0: vector<1x4xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
// expected-error @+1 {{expected matrix A to be shaped (4 x 1)}}
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<1x4xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 8>} : (vector<1x4xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
return %d : vector<2x2xf32>
}
// -----
func.func @m8n8k4_fp64_vector_shape_a(%arg0: vector<1x2xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> {
// expected-error @+1 {{expected 32 warp-wide matrix A elements}}
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x2xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 8, 8, 4>} : (vector<1x2xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
return %d : vector<1x2xf64>
}
// -----
func.func @m8n8k4_fp64_vector_shape_c_extended(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<2x1xf64>) -> vector<2x1xf64> {
// expected-error @+1 {{expected matrix C to be shaped (1 x 2)}}
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<2x1xf64>) -> vector<2x1xf64>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 8, 8, 4>} : (vector<1x1xf64>, vector<1x1xf64>, vector<2x1xf64>) -> vector<2x1xf64>
return %d : vector<2x1xf64>
}
// -----
func.func @m16n8k32_int8_vector_shape_b(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
// expected-error @+1 {{expected 256 warp-wide matrix B elements}}
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 32>} : (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
return %d : vector<2x2xi32>
}
// -----
func.func @m16n8k32_int32_datatype(%arg0: vector<4x4xi32>, %arg1: vector<2x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
// expected-error @+1 {{op failed to verify that matrixA and matrixB have same element type}}
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi32>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 32>} : (vector<4x4xi32>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
return %d : vector<2x2xi32>
}
// -----
+func.func @mma_sync_check_mmaShape(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+ // expected-error @+1 {{mmaShape should be three integers}}
+ %d = nvgpu.mma.sync(%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ return %d : vector<2x2xf16>
+}
+// -----
+
func.func @async_cp_memory_space(%dst : memref<16xf32>, %src : memref<16xf32>, %i : index) -> () {
// expected-error @below {{destination memref must have a memory space attribute of IntegerAttr(3) or gpu::AddressSpaceAttr(Workgroup)}}
nvgpu.device_async_copy %src[%i], %dst[%i], 16 : memref<16xf32> to memref<16xf32>
@@ -189,13 +196,25 @@ func.func @mma_sp_sync_f16_16816(%arg0: vector<2x2xf16>,
%arg2: vector<2x2xf16>,
%arg3: vector<2xi16>) -> vector<2x2xf16> {
// expected-error @+1 {{'nvgpu.mma.sp.sync' op sparsity selector should be 0 or 1}}
- %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 16], sparsitySelector = 42 : i32} :
+ %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 16>, sparsitySelector = 42 : i32} :
(vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
// -----
+func.func @mma_sp_sync_check_mmaShape(%arg0: vector<4x2xf16>,
+ %arg1: vector<4x2xf16>,
+ %arg2: vector<2x2xf16>,
+ %arg3: vector<2xi16>) -> vector<2x2xf16> {
+ // expected-error @+1 {{mmaShape should be three integers}}
+ %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8>} :
+ (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ return %d : vector<2x2xf16>
+}
+
+// -----
+
func.func @async_cp_zfill_f32_align1(
%src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
// expected-error @+1 {{'nvgpu.device_async_copy' op bypassL1 does not satify alignment for 'memref<3x16x128xf32, 3>' with destination element 1. Unset bypassL1, or set destination element to 4}}
@@ -359,7 +378,7 @@ func.func @rcp_unsupported_ftz(%in : vector<16xf32>) {
func.func @check_matrixA_dim(%arg0: vector<16xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
// expected-error @+1 {{matrixA must be 2 dimensional vector}}
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<16xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<16xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
@@ -367,7 +386,7 @@ func.func @check_matrixA_dim(%arg0: vector<16xf16>, %arg1: vector<2x2xf16>, %arg
func.func @check_matrixB_dim(%arg0: vector<4x4xf16>, %arg1: vector<4xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
// expected-error @+1 {{matrixB must be 2 dimensional vector}}
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<4xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x4xf16>, vector<4xf16>, vector<2x2xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
@@ -375,7 +394,7 @@ func.func @check_matrixB_dim(%arg0: vector<4x4xf16>, %arg1: vector<4xf16>, %arg2
func.func @check_matrixC_dim(%arg0: vector<4x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<4xf16>) -> vector<2x2xf16> {
// expected-error @+1 {{matrixC must be 2 dimensional vector}}
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<2x2xf16>, vector<4xf16>) -> vector<2x2xf16>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x4xf16>, vector<2x2xf16>, vector<4xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
diff --git a/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir b/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
index 6d0cf348273c3..1a9151c4c8a36 100644
--- a/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
+++ b/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
@@ -4,7 +4,7 @@
func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
// CHECK: nvgpu.mma.sync
// CHECK-SAME: tf32Enabled
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 4>} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
return %d : vector<2x2xf32>
}
@@ -14,7 +14,7 @@ func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: v
func.func @m16n8k8_tf32(%arg0: vector<4x1xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
// CHECK: nvgpu.mma.sync
// CHECK-SAME: tf32Enabled
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 8>} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
return %d : vector<2x2xf32>
}
// -----
@@ -24,6 +24,6 @@ func.func @m16n8k8_tf32(%arg0: vector<4x1xf32>, %arg1: vector<2x1xf32>, %arg2: v
// CHECK-NOT: tf32Enabled
// CHECK: return
func.func @mma_sync_f16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
diff --git a/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir b/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir
index 7db60254d2ec8..6fae05a252887 100644
--- a/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir
+++ b/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir
@@ -3,7 +3,7 @@
// CHECK-LABEL: m16n8k4_tf32
func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
// expected-error @+1 {{TF32x3 is not supported at the moment for nvgpu.mma.sync on f32 datatype}}
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 4>} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
return %d : vector<2x2xf32>
}
@@ -12,7 +12,7 @@ func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: v
// CHECK-LABEL: m16n8k8_tf32
func.func @m16n8k8_tf32(%arg0: vector<4x1xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
// expected-error @+1 {{TF32x3 is not supported at the moment for nvgpu.mma.sync on f32 datatype}}
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 8>} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
return %d : vector<2x2xf32>
}
// -----
diff --git a/mlir/test/Dialect/NVGPU/roundtrip.mlir b/mlir/test/Dialect/NVGPU/roundtrip.mlir
index ad516b4d2c200..667c89b1f6c49 100644
--- a/mlir/test/Dialect/NVGPU/roundtrip.mlir
+++ b/mlir/test/Dialect/NVGPU/roundtrip.mlir
@@ -13,8 +13,8 @@ func.func @ldmatrix(%arg0: memref<?x?xf16, 3>, %x: index, %y: index) {
func.func @mma_sync(%arg0: vector<4x2xf16>,
%arg1: vector<2x2xf16>,
%arg2: vector<2x2xf16>) -> vector<2x2xf16> {
-// CHECK: nvgpu.mma.sync(%{{.*}}, %{{.*}}, %{{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
- %d = nvgpu.mma.sync(%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} :
+// CHECK: nvgpu.mma.sync(%{{.*}}, %{{.*}}, %{{.*}}) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ %d = nvgpu.mma.sync(%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} :
(vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
@@ -25,9 +25,9 @@ func.func @mma_sp_sync_f16_16832(%arg0: vector<4x2xf16>,
%arg2: vector<2x2xf16>,
%arg3: vector<2xi16>) -> vector<2x2xf16> {
// CHECK: nvgpu.mma.sp.sync(%{{.*}}, %{{.*}}, %{{.*}}) metadata(%{{.+}}) {
- // CHECK-SAME: mmaShape = [16, 8, 32]
+ // CHECK-SAME: mmaShape = array<i64: 16, 8, 32>
// CHECK-SAME: (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
- %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 32]} :
+ %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 32>} :
(vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
@@ -38,9 +38,9 @@ func.func @mma_sp_sync_f16_16816(%arg0: vector<2x2xf16>,
%arg2: vector<2x2xf16>,
%arg3: vector<2xi16>) -> vector<2x2xf16> {
// CHECK: nvgpu.mma.sp.sync(%{{.*}}, %{{.*}}, %{{.*}}) metadata(%{{.+}}) {
- // CHECK-SAME: mmaShape = [16, 8, 16]
+ // CHECK-SAME: mmaShape = array<i64: 16, 8, 16>
// CHECK-SAME: (vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
- %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 16]} :
+ %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 16>} :
(vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
@@ -51,9 +51,9 @@ func.func @mma_sp_sync_i8_16864(%arg0: vector<4x4xi8>,
%arg2: vector<2x2xi32>,
%arg3: vector<2xi16>) -> vector<2x2xi32> {
// CHECK: nvgpu.mma.sp.sync(%{{.*}}, %{{.*}}, %{{.*}}) metadata(%{{.+}}) {
- // CHECK-SAME: mmaShape = [16, 8, 64]
+ // CHECK-SAME: mmaShape = array<i64: 16, 8, 64>
// CHECK-SAME: (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
- %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 64]} :
+ %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 64>} :
(vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
return %d : vector<2x2xi32>
}
diff --git a/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
index bbe27fe1b99d9..dd9d43e1be795 100644
--- a/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
+++ b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
@@ -48,7 +48,7 @@ func.func @matmul_16x8x4xf32_global(
// CHECK: %[[VAL_33:.*]] = vector.insert %[[VAL_26]], %[[VAL_32]] [1, 0] : f32 into vector<2x2xf32>
// CHECK: %[[RES:.*]] = vector.insert %[[VAL_29]], %[[VAL_33]] [1, 1] : f32 into vector<2x2xf32>
//
-// CHECK: %[[VAL_35:.*]] = nvgpu.mma.sync(%[[LHS]], %[[RHS]], %[[RES]]) {mmaShape = [16, 8, 4], tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+// CHECK: %[[VAL_35:.*]] = nvgpu.mma.sync(%[[LHS]], %[[RHS]], %[[RES]]) {mmaShape = array<i64: 16, 8, 4>, tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
//
// CHECK: %[[VAL_36:.*]] = vector.extract %[[VAL_35]][0, 0] : f32 from vector<2x2xf32>
// CHECK: %[[VAL_37:.*]] = vector.extract %[[VAL_35]][0, 1] : f32 from vector<2x2xf32>
@@ -96,7 +96,7 @@ func.func @matmul_16x8x16xf16_global(
// CHECK-COUNT-4: memref.load {{.*}} : memref<16x8xf16>
// CHECK-COUNT-4: vector.insert {{.*}} : f16 into vector<2x2xf16>
//
- // CHECK: nvgpu.mma.sync(%{{.*}}) {mmaShape = [16, 8, 16]}
+ // CHECK: nvgpu.mma.sync(%{{.*}}) {mmaShape = array<i64: 16, 8, 16>}
// CHECK-SAME: : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
//
// CHECK-COUNT-4: vector.extract %{{.*}} : f16 from vector<2x2xf16>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir
index d8d7c1c39db91..00338c6486413 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir
@@ -194,7 +194,7 @@ module attributes {gpu.container_module} {
// within each group of four threads contribute metadata.
%d = nvgpu.mma.sp.sync(%A_data, %B_data, %accum)
metadata(%meta)
- {mmaShape = [16, 8, 32]} : (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ {mmaShape = array<i64: 16, 8, 32>} : (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
//===----------------------------------------------------------------------===//
// Write back results to gpu global memory
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir
index 2eef2ff8f3564..46ce7e8b55b2f 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir
@@ -3,7 +3,7 @@
// RUN: | FileCheck %s --check-prefix=CHECK-MMA-SYNC
// CHECK-MMA-SYNC-LABEL: func @main() {
-// CHECK-MMA-SYNC: nvgpu.mma.sync(%{{.*}}) {mmaShape = [16, 8, 4], tf32Enabled}
+// CHECK-MMA-SYNC: nvgpu.mma.sync(%{{.*}}) {mmaShape = array<i64: 16, 8, 4>, tf32Enabled}
// CHECK-MMA-SYNC-SAME: : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
// Tested to run locally in 1.7s.
More information about the Mlir-commits mailing list