[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for gather over a strided memref (PR #72991)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Nov 28 02:42:40 PST 2023


================
@@ -96,6 +96,82 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
   }
 };
 
+/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
+/// MemRef with updated indices that model the strided access.
+///
+/// ```mlir
+/// %subview = memref.subview %M (...) to memref<100xf32, strided<[3]>>
+/// %gather = vector.gather %subview (...) : memref<100xf32, strided<[3]>>
+/// ```
+/// ==>
+/// ```mlir
+/// %collapse_shape = memref.collapse_shape %M (...) into memref<300xf32>
+/// %1 = arith.muli %idxs, %c3 : vector<4xindex>
+/// %gather = vector.gather %collapse_shape (...) : memref<300xf32> (...)
+/// ```
+///
+/// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
+/// but should be fairly straightforward to extend beyond that.
+struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::GatherOp op,
+                                PatternRewriter &rewriter) const override {
+    Value base = op.getBase();
+    if (!base.getDefiningOp())
+      return failure();
+
+    // TODO: Strided accesses might be coming from other ops as well
+    auto subview = dyn_cast<memref::SubViewOp>(base.getDefiningOp());
+    if (!subview)
+      return failure();
----------------
banach-space wrote:

Hm, I'm not sure that would work here. Even with `dyn_cast_if_present`, I'd still need the early-exit check, something like this:
```cpp
    auto subview = dyn_cast_if_present<memref::SubViewOp>(base.getDefiningOp());
    if (!subview)
      return failure();
```

In most places `dyn_cast_if_present` would be used like this:
```cpp
    if (auto subview = dyn_cast_if_present<memref::SubViewOp>(base.getDefiningOp())) {
       // some logic here
    }
```

I could do the same, but that would needlessly indent the whole function body, which I'd rather avoid. I'm also trying to follow https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code. Unless I am missing something 🤔
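
To make the trade-off concrete, here is a minimal standalone sketch, not the actual patch. `Operation`, `SubViewOp`, `OtherOp`, `dynCastIfPresent`, `matchEarlyExit`, `matchNested` and `demo` are all hypothetical stand-ins written in plain C++ with `dynamic_cast`; they only mimic how a null-tolerant cast behaves and are not the LLVM/MLIR API. It shows that both styles accept and reject the same inputs, so the choice is purely about indentation:
```cpp
#include <cassert>

// Hypothetical stand-ins for mlir::Operation and memref::SubViewOp.
struct Operation {
  virtual ~Operation() = default;
};
struct SubViewOp : Operation {
  int stride = 3;
};
struct OtherOp : Operation {};

// Hypothetical stand-in for a null-tolerant cast: a null input yields null.
// dynamic_cast already tolerates null; the explicit check only makes the
// contract visible.
template <typename To>
To *dynCastIfPresent(Operation *op) {
  return op ? dynamic_cast<To *>(op) : nullptr;
}

// Early-exit style: a single guard covers both "no defining op" and
// "defining op is not a subview"; the main logic stays unindented.
bool matchEarlyExit(Operation *definingOp, int &strideOut) {
  auto *subview = dynCastIfPresent<SubViewOp>(definingOp);
  if (!subview)
    return false;
  strideOut = subview->stride;
  return true;
}

// Nested style: same behaviour, but the main logic sits inside the `if`.
bool matchNested(Operation *definingOp, int &strideOut) {
  if (auto *subview = dynCastIfPresent<SubViewOp>(definingOp)) {
    strideOut = subview->stride;
    return true;
  }
  return false;
}

// Small usage example: both styles agree on a subview and on a null input.
void demo() {
  SubViewOp sv;
  int stride = 0;
  assert(matchEarlyExit(&sv, stride) && stride == 3);
  assert(!matchEarlyExit(nullptr, stride));
  assert(matchNested(&sv, stride) && !matchNested(nullptr, stride));
}
```

In this sketch the null-tolerant cast lets the separate "is there a defining op?" check fold into the single `if (!subview)` guard, but the guard itself does not go away.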

https://github.com/llvm/llvm-project/pull/72991