[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for gather over a strided memref (PR #72991)

Andrzej WarzyƄski llvmlistbot at llvm.org
Tue Nov 21 05:28:31 PST 2023


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/72991

This patch adds a rewrite pattern for `vector.gather` over a strided memref like the following:

```mlir
%subview = memref.subview %arg0[0, 0] [100, 1] [1, 1] :
    memref<100x3xf32> to memref<100xf32, strided<[3]>>
%gather = vector.gather %subview[%c0] [%idxs], %cst_0, %cst :
    memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32>
    into vector<4xf32>
```

This is rewritten into a gather over the collapsed, contiguous memref, with the indices scaled by the original stride:

```mlir
%collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :
    memref<100x3xf32> into memref<300xf32>
%1 = arith.muli %arg3, %cst : vector<4xindex>
%gather = vector.gather %collapse_shape[%c0] [%1], %cst_1, %cst_0 :
    memref<300xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
    into vector<4xf32>
```

Fixes https://github.com/openxla/iree/issues/15364.
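
The new pattern is registered through the existing `populateVectorGatherLoweringPatterns` entry point (see the diff below), so downstream users pick it up automatically. As a minimal sketch of how a pass might apply these patterns (the `lowerGathers` wrapper below is illustrative only, not part of this patch):

```cpp
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Illustrative helper (not part of this patch): applies the vector.gather
// lowering patterns, which after this change include
// RemoveStrideFromGatherSource, to all ops nested under `root`.
static LogicalResult lowerGathers(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateVectorGatherLoweringPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```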

From b1fbe79c0252a03d3f337544140549e1ce80436c Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 21 Nov 2023 13:13:03 +0000
Subject: [PATCH] [mlir][Vector] Add a rewrite pattern for gather over a
 strided memref

This patch adds a rewrite pattern for `vector.gather` over a strided
memref like the following:

```mlir
%subview = memref.subview %arg0[0, 0] [100, 1] [1, 1] :
    memref<100x3xf32> to memref<100xf32, strided<[3]>>
%gather = vector.gather %subview[%c0] [%idxs], %cst_0, %cst :
    memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32>
    into vector<4xf32>
```

This is rewritten into a gather over the collapsed, contiguous memref, with the indices scaled by the original stride:

```mlir
%collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :
    memref<100x3xf32> into memref<300xf32>
%1 = arith.muli %arg3, %cst : vector<4xindex>
%gather = vector.gather %collapse_shape[%c0] [%1], %cst_1, %cst_0 :
    memref<300xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
    into vector<4xf32>
```

Fixes https://github.com/openxla/iree/issues/15364.
---
 .../Vector/Transforms/LowerVectorGather.cpp   | 59 ++++++++++++++-
 .../Vector/vector-gather-lowering.mlir        | 74 +++++++++++++++++++
 2 files changed, 131 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 152aefa65effc3d..0f26b0241806d55 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -96,6 +96,61 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
   }
 };
 
+/// Rewrites a gather over a strided memref into a gather over the collapsed base memref, with the indices scaled by the stride.
+struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::GatherOp op,
+                                PatternRewriter &rewriter) const override {
+    Value base = op.getBase();
+
+    if (!base.getDefiningOp())
+      return failure();
+
+    // TODO: Strided accesses might be coming from other ops as well
+    auto subview = dyn_cast<memref::SubViewOp>(base.getDefiningOp());
+    if (!subview)
+      return failure();
+
+    // Get strides
+    auto layout = subview.getResult().getType().getLayout();
+    auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
+
+    // TODO: Allow the access to be strided in multiple dimensions.
+    if (!stridedLayoutAttr || stridedLayoutAttr.getStrides().size() != 1)
+      return failure();
+
+    // Make sure the stride matches the size of the source's trailing dim.
+    // TODO: Is this really needed?
+    int64_t lastDim = subview.getSource().getType().getShape().back();
+    if (stridedLayoutAttr.getStrides()[0] != lastDim)
+      return failure();
+
+    // Collapse the input memref and update the gather indices to model the
+    // strided access.
+    SmallVector<ReassociationIndices> reassoc = {{0, 1}};
+    Value collapsed = rewriter.create<memref::CollapseShapeOp>(
+        op.getLoc(), subview.getSource(), reassoc);
+
+    auto element = rewriter.getIndexAttr(lastDim);
+    auto vType = op.getIndexVec().getType();
+    Value mulCst = rewriter.create<arith::ConstantOp>(
+        op.getLoc(), vType, DenseElementsAttr::get(vType, element));
+
+    Value newIdxs =
+        rewriter.create<arith::MulIOp>(op.getLoc(), op.getIndexVec(), mulCst);
+
+    // Create an updated gather op - both the input memref and the indices have
+    // been updated.
+    Value newGather = rewriter.create<vector::GatherOp>(
+        op.getLoc(), op.getResult().getType(), collapsed, op.getIndices(),
+        newIdxs, op.getMask(), op.getPassThru());
+    rewriter.replaceOp(op, newGather);
+
+    return success();
+  }
+};
+
 /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
 /// loads/extracts are made conditional using `scf.if` ops.
@@ -168,6 +223,6 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
 
 void mlir::vector::populateVectorGatherLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<FlattenGather, Gather1DToConditionalLoads>(patterns.getContext(),
-                                                          benefit);
+  patterns.add<FlattenGather, RemoveStrideFromGatherSource,
+               Gather1DToConditionalLoads>(patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 026bec8cd65d3f5..f08e3b80008378d 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -151,3 +151,77 @@ func.func @gather_tensor_1d_none_set(%base: tensor<?xf32>, %v: vector<2xindex>,
   %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
   return %0 : vector<2xf32>
 }
+
+// Check that a vector.gather of a strided memref is replaced with a
+// vector.gather over the collapsed memref, with indices scaled by the stride.
+#map = affine_map<()[s0] -> (s0 * 4096)>
+#map1 = affine_map<()[s0] -> (s0 * -4096 + 518400, 4096)>
+func.func @example_1(%arg0 : memref<100x3xf32>, %arg0_1 : memref<100xf32>, %arg1: memref<518400xf32>, %idxs : vector<4xindex>, %x : index, %y : index) {
+  %c0 = arith.constant 0 : index
+  %x_1 = affine.apply #map()[%x]
+  %subview = memref.subview %arg0[0, 0] [100, 1] [1, 1] : memref<100x3xf32> to memref<100xf32, strided<[3]>>
+  %cst_0 = arith.constant dense<true> : vector<4xi1>
+  %cst = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %7 = vector.gather %subview[%c0] [%idxs], %cst_0, %cst : memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  %subview_1 = memref.subview %arg1[%x_1] [%y] [1] : memref<518400xf32> to memref<?xf32, strided<[1], offset: ?>>
+  vector.store %7, %subview_1[%c0] : memref<?xf32, strided<[1], offset: ?>>, vector<4xf32>
+  return
+}
+// CHECK-LABEL:   func.func @example_1(
+// CHECK-SAME:                         %[[VAL_0:.*]]: memref<100x3xf32>,
+// CHECK-SAME:                         %[[VAL_1:.*]]: memref<100xf32>,
+// CHECK-SAME:                         %[[VAL_2:.*]]: memref<518400xf32>,
+// CHECK-SAME:                         %[[VAL_3:.*]]: vector<4xindex>,
+// CHECK-SAME:                         %[[VAL_4:.*]]: index,
+// CHECK-SAME:                         %[[VAL_5:.*]]: index) {
+// CHECK:           %[[VAL_6:.*]] = arith.constant dense<3> : vector<4xindex>
+// CHECK:           %[[VAL_7:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK:           %[[VAL_8:.*]] = arith.constant dense<true> : vector<4xi1>
+// CHECK:           %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_10:.*]] = affine.apply #{{.*}}(){{\[}}%[[VAL_4]]]
+// CHECK:           %[[VAL_11:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
+// CHECK:           %[[VAL_12:.*]] = arith.muli %[[VAL_3]], %[[VAL_6]] : vector<4xindex>
+// CHECK:           %[[VAL_13:.*]] = vector.extract %[[VAL_8]][0] : i1 from vector<4xi1>
+// CHECK:           %[[VAL_14:.*]] = vector.extract %[[VAL_12]][0] : index from vector<4xindex>
+// CHECK:           %[[VAL_15:.*]] = scf.if %[[VAL_13]] -> (vector<4xf32>) {
+// CHECK:             %[[VAL_16:.*]] = vector.load %[[VAL_11]]{{\[}}%[[VAL_14]]] : memref<300xf32>, vector<1xf32>
+// CHECK:             %[[VAL_17:.*]] = vector.extract %[[VAL_16]][0] : f32 from vector<1xf32>
+// CHECK:             %[[VAL_18:.*]] = vector.insert %[[VAL_17]], %[[VAL_7]] [0] : f32 into vector<4xf32>
+// CHECK:             scf.yield %[[VAL_18]] : vector<4xf32>
+// CHECK:           } else {
+// CHECK:             scf.yield %[[VAL_7]] : vector<4xf32>
+// CHECK:           }
+// CHECK:           %[[VAL_19:.*]] = vector.extract %[[VAL_8]][1] : i1 from vector<4xi1>
+// CHECK:           %[[VAL_20:.*]] = vector.extract %[[VAL_12]][1] : index from vector<4xindex>
+// CHECK:           %[[VAL_21:.*]] = scf.if %[[VAL_19]] -> (vector<4xf32>) {
+// CHECK:             %[[VAL_22:.*]] = vector.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<300xf32>, vector<1xf32>
+// CHECK:             %[[VAL_23:.*]] = vector.extract %[[VAL_22]][0] : f32 from vector<1xf32>
+// CHECK:             %[[VAL_24:.*]] = vector.insert %[[VAL_23]], %[[VAL_15]] [1] : f32 into vector<4xf32>
+// CHECK:             scf.yield %[[VAL_24]] : vector<4xf32>
+// CHECK:           } else {
+// CHECK:             scf.yield %[[VAL_15]] : vector<4xf32>
+// CHECK:           }
+// CHECK:           %[[VAL_25:.*]] = vector.extract %[[VAL_8]][2] : i1 from vector<4xi1>
+// CHECK:           %[[VAL_26:.*]] = vector.extract %[[VAL_12]][2] : index from vector<4xindex>
+// CHECK:           %[[VAL_27:.*]] = scf.if %[[VAL_25]] -> (vector<4xf32>) {
+// CHECK:             %[[VAL_28:.*]] = vector.load %[[VAL_11]]{{\[}}%[[VAL_26]]] : memref<300xf32>, vector<1xf32>
+// CHECK:             %[[VAL_29:.*]] = vector.extract %[[VAL_28]][0] : f32 from vector<1xf32>
+// CHECK:             %[[VAL_30:.*]] = vector.insert %[[VAL_29]], %[[VAL_21]] [2] : f32 into vector<4xf32>
+// CHECK:             scf.yield %[[VAL_30]] : vector<4xf32>
+// CHECK:           } else {
+// CHECK:             scf.yield %[[VAL_21]] : vector<4xf32>
+// CHECK:           }
+// CHECK:           %[[VAL_31:.*]] = vector.extract %[[VAL_8]][3] : i1 from vector<4xi1>
+// CHECK:           %[[VAL_32:.*]] = vector.extract %[[VAL_12]][3] : index from vector<4xindex>
+// CHECK:           %[[VAL_33:.*]] = scf.if %[[VAL_31]] -> (vector<4xf32>) {
+// CHECK:             %[[VAL_34:.*]] = vector.load %[[VAL_11]]{{\[}}%[[VAL_32]]] : memref<300xf32>, vector<1xf32>
+// CHECK:             %[[VAL_35:.*]] = vector.extract %[[VAL_34]][0] : f32 from vector<1xf32>
+// CHECK:             %[[VAL_36:.*]] = vector.insert %[[VAL_35]], %[[VAL_27]] [3] : f32 into vector<4xf32>
+// CHECK:             scf.yield %[[VAL_36]] : vector<4xf32>
+// CHECK:           } else {
+// CHECK:             scf.yield %[[VAL_27]] : vector<4xf32>
+// CHECK:           }
+// CHECK:           %[[VAL_37:.*]] = memref.subview %[[VAL_2]]{{\[}}%[[VAL_10]]] {{\[}}%[[VAL_5]]] [1] : memref<518400xf32> to memref<?xf32, strided<[1], offset: ?>>
+// CHECK:           vector.store %[[VAL_33]], %[[VAL_37]]{{\[}}%[[VAL_9]]] : memref<?xf32, strided<[1], offset: ?>>, vector<4xf32>
+// CHECK:           return
+// CHECK:         }
