[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for gather over a strided memref (PR #72991)

Diego Caballero llvmlistbot at llvm.org
Mon Nov 27 06:15:03 PST 2023

@@ -96,6 +96,82 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
+/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
+/// MemRef with updated indices that model the strided access.
+/// ```mlir
+/// %subview = memref.subview %M (...) to memref<100xf32, strided<[3]>>
+/// %gather = vector.gather %subview (...) : memref<100xf32, strided<[3]>>
+/// ```
+/// ==>
+/// ```mlir
+/// %collapse_shape = memref.collapse_shape %M (...) into memref<300xf32>
+/// %1 = arith.muli %idxs, %c3 : vector<4xindex>
+/// %gather = vector.gather %collapse_shape (...) : memref<300xf32> (...)
+/// ```
+/// At the moment this is effectively limited to reading a 1D vector from a 2D
+/// MemRef, but it should be fairly straightforward to extend beyond that.
+struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::GatherOp op,
+                                PatternRewriter &rewriter) const override {
+    Value base = op.getBase();
+    if (!base.getDefiningOp())
+      return failure();
+    // TODO: Strided accesses might be coming from other ops as well
+    auto subview = dyn_cast<memref::SubViewOp>(base.getDefiningOp());
+    if (!subview)
+      return failure();
+    // TODO: Allow ranks > 2.
+    if (subview.getSource().getType().getRank() != 2)
+      return failure();
+    // Get strides
+    auto layout = subview.getResult().getType().getLayout();
+    auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
+    // TODO: Allow the access to be strided in multiple dimensions.
+    if (stridedLayoutAttr.getStrides().size() != 1)
+      return failure();
+    int64_t srcTrailingDim = subview.getSource().getType().getShape().back();
+    // Assume that the stride matches the trailing dimension of the source
+    // memref.
+    // TODO: Relax this assumption.
+    if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
+      return failure();
+    // 1. Collapse the input memref so that it's "flat".
+    SmallVector<ReassociationIndices> reassoc = {{0, 1}};
+    Value collapsed = rewriter.create<memref::CollapseShapeOp>(
+        op.getLoc(), subview.getSource(), reassoc);
+    // 2. Generate new gather indices that will model the
+    // strided access.
+    auto stride = rewriter.getIndexAttr(srcTrailingDim);
+    auto vType = op.getIndexVec().getType();
dcaballe wrote:

nit: spell out the types here instead of using `auto`.


More information about the Mlir-commits mailing list