[Mlir-commits] [mlir] aa2dc79 - [mlir][vector] Fix rewrite pattern API violation in `VectorToSCF` (#77909)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 12 04:44:58 PST 2024


Author: Matthias Springer
Date: 2024-01-12T13:44:54+01:00
New Revision: aa2dc792abd5f6b061e277607722b9b773ce2178

URL: https://github.com/llvm/llvm-project/commit/aa2dc792abd5f6b061e277607722b9b773ce2178
DIFF: https://github.com/llvm/llvm-project/commit/aa2dc792abd5f6b061e277607722b9b773ce2178.diff

LOG: [mlir][vector] Fix rewrite pattern API violation in `VectorToSCF` (#77909)

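A rewrite pattern is not allowed to change the IR if it returns
"failure". `UnrollTransferReadConversion` violated this rule: it built
the result vector (inserting a `vector.splat`, see the log below) before
checking for scalable dimensions, and then returned "failure". This
commit moves the creation of the result vector after the
scalable-dimension check and replaces the bare `failure()` returns with
`notifyMatchFailure` messages. This fixes
`test/Conversion/VectorToSCF/vector-to-scf.mlir` when running with
`MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`.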

```
Processing operation : 'vector.transfer_read'(0x55823a409a60) {
  %5 = "vector.transfer_read"(%arg0, %0, %0, %2, %4) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 1>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<?x4xf32>, index, index, f32, vector<[4]x4xi1>) -> vector<[4]x4xf32>

  * Pattern (anonymous namespace)::lowering_n_d_unrolled::UnrollTransferReadConversion : 'vector.transfer_read -> ()' {
Trying to match "(anonymous namespace)::lowering_n_d_unrolled::UnrollTransferReadConversion"
    ** Insert  : 'vector.splat'(0x55823a445640)
"(anonymous namespace)::lowering_n_d_unrolled::UnrollTransferReadConversion" result 0
  } -> failure : pattern failed to match

LLVM ERROR: pattern returned failure but IR did change
```

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index a1aff1ab36a52b..44fbac1935fed7 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1060,10 +1060,10 @@ struct UnrollTransferReadConversion
     setHasBoundedRewriteRecursion();
   }
 
-  /// Return the vector into which the newly created TransferReadOp results
-  /// are inserted.
-  Value getResultVector(TransferReadOp xferOp,
-                        PatternRewriter &rewriter) const {
+  /// Get or build the vector into which the newly created TransferReadOp
+  /// results are inserted.
+  Value buildResultVector(PatternRewriter &rewriter,
+                          TransferReadOp xferOp) const {
     if (auto insertOp = getInsertOp(xferOp))
       return insertOp.getDest();
     Location loc = xferOp.getLoc();
@@ -1098,24 +1098,27 @@ struct UnrollTransferReadConversion
   LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                 PatternRewriter &rewriter) const override {
     if (xferOp.getVectorType().getRank() <= options.targetRank)
-      return failure();
+      return rewriter.notifyMatchFailure(
+          xferOp, "vector rank is less or equal to target rank");
     if (isTensorOp(xferOp) && !options.lowerTensors)
-      return failure();
+      return rewriter.notifyMatchFailure(
+          xferOp, "transfers operating on tensors are excluded");
     // Transfer ops that modify the element type are not supported atm.
     if (xferOp.getVectorType().getElementType() !=
         xferOp.getShapedType().getElementType())
-      return failure();
-
-    auto insertOp = getInsertOp(xferOp);
-    auto vec = getResultVector(xferOp, rewriter);
-    auto vecType = dyn_cast<VectorType>(vec.getType());
+      return rewriter.notifyMatchFailure(
+          xferOp, "not yet supported: element type mismatch");
     auto xferVecType = xferOp.getVectorType();
-
     if (xferVecType.getScalableDims()[0]) {
       // Cannot unroll a scalable dimension at compile time.
-      return failure();
+      return rewriter.notifyMatchFailure(
+          xferOp, "scalable dimensions cannot be unrolled");
     }
 
+    auto insertOp = getInsertOp(xferOp);
+    auto vec = buildResultVector(rewriter, xferOp);
+    auto vecType = dyn_cast<VectorType>(vec.getType());
+
     VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);
 
     int64_t dimSize = xferVecType.getShape()[0];


        


More information about the Mlir-commits mailing list