[Mlir-commits] [mlir] [mlir][vector] Fix crash on invalid `permutation_map` (PR #74925)
Rik Huijzer
llvmlistbot at llvm.org
Sat Dec 9 02:50:12 PST 2023
https://github.com/rikhuijzer created https://github.com/llvm/llvm-project/pull/74925
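Without this patch, MLIR crashes during parsing with
```
Assertion failed: (getNumDims() == map.getNumResults() && "Number of results mismatch"), function compose, file AffineMap.cpp, line 537.
```
when the `permutation_map` of a `vector.transfer_read` or `vector.transfer_write` has a different number of results than the rank of the vector type. With this patch, the parser reports an error instead of asserting.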
>From edb5a7f00b93fef504940205bf3ed46df7b16afb Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Sat, 9 Dec 2023 11:46:05 +0100
Subject: [PATCH] [mlir][vector] Fix crash on invalid `permutation_map`
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 10 ++++++++++
mlir/test/Dialect/Vector/invalid.mlir | 22 ++++++++++++++++++++++
2 files changed, 32 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c462b23e1133f..540959b486db9 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3815,6 +3815,11 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
return parser.emitError(
maskInfo.location, "does not support masks with vector element type");
+ if (vectorType.getRank() != permMap.getNumResults()) {
+ return parser.emitError(typesLoc,
+ "expected the same rank for the vector and the "
+ "results of the permutation map");
+ }
// Instead of adding the mask type as an op type, compute it based on the
// vector type and the permutation map (to keep the type signature small).
auto maskType = inferTransferOpMaskType(vectorType, permMap);
@@ -4181,6 +4186,11 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
return parser.emitError(
maskInfo.location, "does not support masks with vector element type");
+ if (vectorType.getRank() != permMap.getNumResults()) {
+ return parser.emitError(typesLoc,
+ "expected the same rank for the vector and the "
+ "results of the permutation map");
+ }
auto maskType = inferTransferOpMaskType(vectorType, permMap);
if (parser.resolveOperand(maskInfo, maskType, result.operands))
return failure();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index edb2689364a98..ad248d1e14e72 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -332,6 +332,28 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
// -----
+#map1 = affine_map<(d0, d1, d2) -> (d0, 0, 0)>  // 3 results, but the vector type below has rank 2
+func.func @test_vector.transfer_read(%m: memref<1xi32>, %2: vector<1x32xi1>) -> vector<1x32xi32> {
+ %0 = arith.constant 1 : index
+ %1 = arith.constant 1 : i32
+ // expected-error@+1 {{expected the same rank for the vector and the results of the permutation map}}
+ %3 = vector.transfer_read %m[%0], %1, %2 { permutation_map = #map1 } : memref<1xi32>, vector<1x32xi32>
+ return %3 : vector<1x32xi32>
+}
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, 0, 0)>  // 3 results, but the vector type below has rank 2
+func.func @test_vector.transfer_write(%m: memref<1xi32>, %2: vector<1x32xi32>) {
+ %0 = arith.constant 1 : index
+ %1 = arith.constant 1 : i32  // fills the mask slot; parsing fails on the rank check before the mask type is resolved
+ // expected-error@+1 {{expected the same rank for the vector and the results of the permutation map}}
+ vector.transfer_write %2, %m[%0], %1 { permutation_map = #map1 } : vector<1x32xi32>, memref<1xi32>
+ return
+}
+
+// -----
+
func.func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
More information about the Mlir-commits
mailing list