[Mlir-commits] [mlir] [mlir][nvgpu] Fix crash when mmaShape size is not three (PR #173490)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 24 04:47:05 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse
@llvm/pr-subscribers-mlir
Author: Longsheng Mou (CoTinker)
<details>
<summary>Changes</summary>
This PR fixes a crash when the mmaShape attribute of `nvgpu.mma.sync` does not have exactly three elements. The change replaces the ArrayAttr-based mmaShape with DenseI64ArrayAttr and adds verifier checks to ensure the attribute has three elements. Fixes #173378.
---
Patch is 51.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/173490.diff
14 Files Affected:
- (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td (+5-19)
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+4-4)
- (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+2-2)
- (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+15-11)
- (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+12-12)
- (modified) mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir (+2-2)
- (modified) mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir (+10-10)
- (modified) mlir/test/Dialect/NVGPU/invalid.mlir (+34-15)
- (modified) mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir (+3-3)
- (modified) mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir (+2-2)
- (modified) mlir/test/Dialect/NVGPU/roundtrip.mlir (+8-8)
- (modified) mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir (+2-2)
- (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir (+1-1)
- (modified) mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir (+1-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td
index 73d86283a5940..5b9ae8bb7a518 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td
@@ -66,16 +66,6 @@ class NVGPU_MmaSyncOp<string mnemonic> :
NVGPU_Op<mnemonic, [Pure,
PredOpTrait<"matrixA and matrixB have same element type",
TCopVTEtIsSameAs<0, 1>>]> {
- code extraBaseClassDeclaration = [{
- std::array<int64_t, 3> getMmaShapeAsArray() {
- ArrayAttr mmaShape = this->getMmaShape();
- assert(mmaShape.size() == 3 && "mmaShape should be three integers");
- return {::llvm::cast<IntegerAttr>(mmaShape[0]).getInt(),
- ::llvm::cast<IntegerAttr>(mmaShape[1]).getInt(),
- ::llvm::cast<IntegerAttr>(mmaShape[2]).getInt()};
- }
- }];
-
let hasVerifier = 1;
}
@@ -96,14 +86,14 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
Example:
```mlir
- %res = nvgpu.mma.sync (%matrixA, %matrixB, %matrixC) {mmaShape = [16, 8, 16]} :
+ %res = nvgpu.mma.sync (%matrixA, %matrixB, %matrixC) {mmaShape = array<i64: 16, 8, 16>} :
(vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
```
}];
let arguments = (ins AnyVectorOfNonZeroRank:$matrixA,
AnyVectorOfNonZeroRank:$matrixB,
AnyVectorOfNonZeroRank:$matrixC,
- I64ArrayAttr:$mmaShape,
+ DenseI64ArrayAttr:$mmaShape,
OptionalAttr<UnitAttr>:$tf32Enabled);
let results = (outs AnyVectorOfNonZeroRank:$res);
@@ -112,7 +102,7 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
OpBuilder<(ins "Value":$matrixA,
"Value":$matrixB,
"Value":$matrixC,
- "ArrayAttr":$mmaShape)>,
+ "DenseI64ArrayAttr":$mmaShape)>,
OpBuilder<(ins "Value":$matrixA,
"Value":$matrixB,
"Value":$matrixC,
@@ -124,8 +114,6 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
`(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
`:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
}];
-
- let extraClassDeclaration = extraBaseClassDeclaration;
}
def NVGPU_MmaSparseSyncMetadataType : FixedVectorOfLengthAndType<[2], [I16]>,
@@ -151,7 +139,7 @@ def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
Example (targeting the f16 16x8x32 `mma.sp` PTX instruction):
```mlir
- nvgpu.mma.sp.sync (%a, %b, %c) metadata (%meta) {mmaShape = [16, 8, 32]} :
+ nvgpu.mma.sp.sync (%a, %b, %c) metadata (%meta) {mmaShape = array<i64: 16, 8, 32>} :
(vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
```
}];
@@ -160,7 +148,7 @@ def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
AnyVectorOfNonZeroRank:$matrixB,
AnyVectorOfNonZeroRank:$matrixC,
NVGPU_MmaSparseSyncMetadataType:$sparseMetadata,
- I64ArrayAttr:$mmaShape,
+ DenseI64ArrayAttr:$mmaShape,
DefaultValuedAttr<I32Attr, "0">:$sparsitySelector,
OptionalAttr<UnitAttr>:$tf32Enabled
);
@@ -179,8 +167,6 @@ def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
`(` $matrixA`,` $matrixB`,` $matrixC `)` `metadata` `(` $sparseMetadata `)` attr-dict
`:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
}];
-
- let extraClassDeclaration = extraBaseClassDeclaration;
}
def NVGPU_DeviceAsyncCopyOp : NVGPU_Op<"device_async_copy", [
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 6edc8f5c86dd3..7cd59b0f135fe 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -340,7 +340,7 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
VectorType bType = op.getMatrixA().getType();
VectorType cType = op.getMatrixC().getType();
- std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
+ ArrayRef<int64_t> gemmShape = op.getMmaShape();
// Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
@@ -485,7 +485,7 @@ static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
/// it's expected that the provided parameters correspond to a valid
/// instruction.
static std::string buildMmaSparseAsmString(
- const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
+ ArrayRef<int64_t> shape, unsigned matASize, unsigned matBSize,
unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
@@ -526,7 +526,7 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
- int64_t metadataSelector, const std::array<int64_t, 3> &shape,
+ int64_t metadataSelector, ArrayRef<int64_t> shape,
Type intrinsicResultType) {
auto asmDialectAttr =
LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
@@ -618,7 +618,7 @@ struct NVGPUMmaSparseSyncLowering
FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
- matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
+ matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShape(),
intrinsicResTy);
if (failed(intrinsicResult))
return failure();
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 98434357f826f..d4aa893a5b420 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -1058,8 +1058,8 @@ convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
- Value matmul = nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC,
- rewriter.getI64ArrayAttr({m, n, k}));
+ Value matmul =
+ nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC, {m, n, k});
valueMapping[op.getResult()] = matmul;
return success();
}
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 237aab4d7f309..eb2de91d30988 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -119,7 +119,8 @@ LogicalResult DeviceAsyncCopyOp::verify() {
//===----------------------------------------------------------------------===//
void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState, Value matrixA,
- Value matrixB, Value matrixC, ArrayAttr mmaShape) {
+ Value matrixB, Value matrixC,
+ DenseI64ArrayAttr mmaShape) {
build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
mmaShape, UnitAttr());
}
@@ -129,8 +130,7 @@ void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape,
bool tf32Enabled) {
build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
- odsBuilder.getI64ArrayAttr(mmaShape),
- tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
+ mmaShape, tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
}
/// Performs verification for MmaSyncOp and MmaSparseSyncOp.
@@ -138,7 +138,7 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
TypedValue<VectorType> matrixA,
TypedValue<VectorType> matrixB,
TypedValue<VectorType> matrixC,
- const std::array<int64_t, 3> &mmaShape,
+ ArrayRef<int64_t> mmaShape,
bool tf32Enabled, bool sparse = false) {
// The verification for mma.sync covering various shapes and data types is
// based on the fundamental tensor core shape.
@@ -209,7 +209,12 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
return op->emitError() << "matrixC must be 2 dimensional vector";
}
- auto [m, n, k] = mmaShape;
+ if (mmaShape.size() != 3) {
+ return op->emitError() << "mmaShape should be three integers";
+ }
+ int64_t m = mmaShape[0];
+ int64_t n = mmaShape[1];
+ int64_t k = mmaShape[2];
// verify warp-wide size for vector a
int64_t sparseFactor = sparse ? 2 : 1;
@@ -262,7 +267,7 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
LogicalResult MmaSyncOp::verify() {
return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
- getMatrixC(), getMmaShapeAsArray(),
+ getMatrixC(), getMmaShape(),
getOperation()->hasAttr(getTf32EnabledAttrName()));
}
@@ -274,17 +279,16 @@ void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder,
Value matrixB, Value matrixC, Value sparseMetadata,
ArrayRef<int64_t> mmaShape) {
build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
- sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr());
+ sparseMetadata, mmaShape, 0, UnitAttr());
}
LogicalResult MmaSparseSyncOp::verify() {
unsigned sparsitySelector = getSparsitySelector();
if (sparsitySelector > 1)
return emitOpError() << "sparsity selector should be 0 or 1";
- return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
- getMatrixC(), getMmaShapeAsArray(),
- getOperation()->hasAttr(getTf32EnabledAttrName()),
- true);
+ return verifyMmaSyncOp(
+ this->getOperation(), getMatrixA(), getMatrixB(), getMatrixC(),
+ getMmaShape(), getOperation()->hasAttr(getTf32EnabledAttrName()), true);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 0eb44789fe31d..74bf571b4a64e 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -14,7 +14,7 @@ func.func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2:
// CHECK-NOT: llvm.extractvalue
// CHECK: [[d:%.+]] = nvvm.mma.sync
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 16>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
// CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.mlir.poison : !llvm.array<2 x vector<2xf16>>
@@ -31,7 +31,7 @@ func.func @m16n8k16_fp16_fp32(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %a
// CHECK: [[d:%.+]] = nvvm.mma.sync
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 16>
// CHECK-SAME: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
// CHECK: [[undef:%.+]] = llvm.mlir.poison : vector<2xf32>
// CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
// CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>
@@ -59,7 +59,7 @@ func.func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: v
// CHECK-NOT: llvm.extractvalue
// CHECK: [[d:%.+]] = nvvm.mma.sync
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 8>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 8>} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
// CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.mlir.poison : !llvm.array<2 x vector<2xf16>>
@@ -90,7 +90,7 @@ func.func @m16n8k32_int8(%arg0: vector<4x4xi8>, %arg1: vector<2x4xi8>, %arg2: ve
// CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s8>
// CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s8>
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 32>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 32>} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
return %d : vector<2x2xi32>
}
@@ -109,7 +109,7 @@ func.func @m16n8k32_i4(%arg0: vector<2x8xi4>, %arg1: vector<1x8xi4>, %arg2: vect
// CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s4>
// CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s4>
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 32>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<2x8xi4>, vector<1x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 32>} : (vector<2x8xi4>, vector<1x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
return %d : vector<2x2xi32>
}
@@ -134,7 +134,7 @@ func.func @m16n8k64_i4(%arg0: vector<4x8xi4>, %arg1: vector<2x8xi4>, %arg2: vect
// CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s4>
// CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s4>
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 64>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 64]} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 64>} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
return %d : vector<2x2xi32>
}
@@ -145,7 +145,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec
// CHECK: llvm.extractvalue
// CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}]
// CHECK-SAME: shape = #nvvm.shape<m = 8, n = 8, k = 4>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 8, 8, 4>} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
// CHECK: llvm.mlir.poison : vector<2xf64>
// CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f64, f64)>
// CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f64, f64)>
@@ -201,7 +201,7 @@ func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: v
// CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<tf32>
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 4>
// CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4], tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 4>, tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
// CHECK: [[undef:%.+]] = llvm.mlir.poison : vector<2xf32>
// CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
// CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>
@@ -370,7 +370,7 @@ func.func @mma_sp_sync_f16_16832(%arg0: vector<4x2xf16>,
// CHECK-SAME: %[[sparseMetadata]] :
// CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 32]} :
+ %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 32>} :
(vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
// CHECK-DAG: llvm.extractvalue %[[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
@@ -406,7 +406,7 @@ func.func @mma_sp_sync_f16_16816(%arg0: vector<2x2xf16>,
// CHECK-SAME: %[[sparseMetadata]] :
// CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 16]} :
+ %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 16>} :
(vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
@@ -427,7 +427,7 @@ func.func @mma_sp_sync_f16_16816_01(%arg0: vector<2x2xf16>,
// CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
%d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3)
- {mmaShape = [16, 8, 16], sparsitySelector = 1 : i32} :
+ {mmaShape = array<i64: 16, 8, 16>, sparsitySelector = 1 : i32} :
(vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
@@ -465,7 +465,7 @@ func.func @mma_sp_sync_i8_16864(%arg0: vector<4x4xi8>,
// CHECK-SAME: %[[sparseMetadata]] :
// CHECK-SAME: -> !llvm.struct<(i32, i32, i32, i32)
- %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 64]} :
+ %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 64>} :
(vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
return %d : vector<2x2xi32>
}
diff --git a/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
index 0afaa19d59d15..c42f5add697f0 100644
--- a/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
+++ b/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
@@ -29,7 +29,7 @@ func.func @m16n8k16_mmasync16816_f16_f16_f32_row_row_row(%arg0: memref<42x32xf16
%B0_f32 = arith.extf %B0 : vector<8x16xf16> to vector<8x16xf32>
%C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf32> to vector<16x8xf32>
- // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+ // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
%D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A_f32, %B0_f32, %C0 : vector<16x16xf32>, vector<8x16xf32> into vector<16x8xf32>
vector.transfer_write %D0, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<42x64xf32, #gpu.address_space<workgroup>>
@@ -38,7 +38,7 @@ func.func @m16n8k16_mmasync16816_f16_f16_f32_row_row_row(%arg0: memref<42x32xf16
%B1_f32 = arith.extf %B1 : vector<8x16xf16> to vector<8x16xf32>
...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/173490
More information about the Mlir-commits mailing list