[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for gather over a strided memref (PR #72991)
Cullen Rhodes
llvmlistbot at llvm.org
Thu Nov 30 02:18:38 PST 2023
================
@@ -96,6 +96,89 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
}
};
+/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
+/// MemRef with updated indices that model the strided access.
+///
+/// ```mlir
+/// %subview = memref.subview %M (...)
+/// : memref<100x3xf32> to memref<100xf32, strided<[3]>>
+/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
+/// ```
+/// ==>
+/// ```mlir
+/// %collapse_shape = memref.collapse_shape %M (...)
+/// : memref<100x3xf32> into memref<300xf32>
+/// %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
+/// %gather = vector.gather %collapse_shape[%new_idxs] (...)
+/// : memref<300xf32> (...)
+/// ```
+///
+/// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
+/// but should be fairly straightforward to extend beyond that.
+struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::GatherOp op,
+ PatternRewriter &rewriter) const override {
+ Value base = op.getBase();
+ if (!base.getDefiningOp())
+ return failure();
----------------
c-rhodes wrote:
This `if (!base.getDefiningOp()) return failure();` check can be removed now.
https://github.com/llvm/llvm-project/pull/72991
More information about the Mlir-commits mailing list