[Mlir-commits] [mlir] [mlir][vector] Fix crash on invalid `permutation_map` (PR #74925)

Rik Huijzer llvmlistbot at llvm.org
Sat Dec 9 02:50:12 PST 2023


https://github.com/rikhuijzer created https://github.com/llvm/llvm-project/pull/74925

Without this patch, MLIR crashes with
```
Assertion failed: (getNumDims() == map.getNumResults() && "Number of results mismatch"), function compose, file AffineMap.cpp, line 537.
```
during parsing.


>From edb5a7f00b93fef504940205bf3ed46df7b16afb Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Sat, 9 Dec 2023 11:46:05 +0100
Subject: [PATCH] [mlir][vector] Fix crash on invalid `permutation_map`

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 10 ++++++++++
 mlir/test/Dialect/Vector/invalid.mlir    | 22 ++++++++++++++++++++++
 2 files changed, 32 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c462b23e1133f..540959b486db9 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3815,6 +3815,11 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
     if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
       return parser.emitError(
           maskInfo.location, "does not support masks with vector element type");
+    if (vectorType.getRank() != permMap.getNumResults()) {
+      return parser.emitError(typesLoc,
+                              "expected the same rank for the vector and the "
+                              "results of the permutation map");
+    }
     // Instead of adding the mask type as an op type, compute it based on the
     // vector type and the permutation map (to keep the type signature small).
     auto maskType = inferTransferOpMaskType(vectorType, permMap);
@@ -4181,6 +4186,11 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
     if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
       return parser.emitError(
           maskInfo.location, "does not support masks with vector element type");
+    if (vectorType.getRank() != permMap.getNumResults()) {
+      return parser.emitError(typesLoc,
+                              "expected the same rank for the vector and the "
+                              "results of the permutation map");
+    }
     auto maskType = inferTransferOpMaskType(vectorType, permMap);
     if (parser.resolveOperand(maskInfo, maskType, result.operands))
       return failure();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index edb2689364a98..ad248d1e14e72 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -332,6 +332,28 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
 
 // -----
 
+#map1 = affine_map<(d0, d1, d2) -> (d0, 0, 0)>
+func.func @main(%m:  memref<1xi32>, %2: vector<1x32xi1>) -> vector<1x32xi32> {
+  %0 = arith.constant 1 : index
+  %1 = arith.constant 1 : i32
+  // expected-error@+1 {{expected the same rank for the vector and the results of the permutation map}}
+  %3 = vector.transfer_read %m[%0], %1, %2 { permutation_map = #map1 } : memref<1xi32>, vector<1x32xi32>
+  return %3 : vector<1x32xi32>
+}
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, 0, 0)>
+func.func @test_vector.transfer_write(%m:  memref<1xi32>, %2: vector<1x32xi32>) -> vector<1x32xi32> {
+  %0 = arith.constant 1 : index
+  %1 = arith.constant 1 : i32
+  // expected-error@+1 {{expected the same rank for the vector and the results of the permutation map}}
+  %3 = vector.transfer_write %2, %m[%0], %1 { permutation_map = #map1 } : vector<1x32xi32>, memref<1xi32>
+  return %3 : vector<1x32xi32>
+}
+
+// -----
+
 func.func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32



More information about the Mlir-commits mailing list