[Mlir-commits] [mlir] a383817 - [mlir][Vector] Add a rewrite pattern for gather over a strided memref (#72991)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 30 08:33:26 PST 2023


Author: Andrzej WarzyƄski
Date: 2023-11-30T16:33:20Z
New Revision: a383817b7e24d948dd5e342e8df8d12d0f15d536

URL: https://github.com/llvm/llvm-project/commit/a383817b7e24d948dd5e342e8df8d12d0f15d536
DIFF: https://github.com/llvm/llvm-project/commit/a383817b7e24d948dd5e342e8df8d12d0f15d536.diff

LOG: [mlir][Vector] Add a rewrite pattern for gather over a strided memref (#72991)

This patch adds a rewrite pattern for `vector.gather` over a strided
memref, such as the one produced by the `memref.subview` below:

```mlir
%subview = memref.subview %arg0[0, 0] [100, 1] [1, 1] :
    memref<100x3xf32> to memref<100xf32, strided<[3]>>
%gather = vector.gather %subview[%c0] [%idxs], %cst_0, %cst :
    memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32>
    into vector<4xf32>
```

After applying the pattern added in this patch, the gather reads from the flattened memref and the index vector is scaled by the stride (here 3), so that index `%i` in the strided view maps to element `3 * %i` of `memref<300xf32>`:
```mlir
%collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :
    memref<100x3xf32> into memref<300xf32>
%1 = arith.muli %arg3, %cst : vector<4xindex>
%gather = vector.gather %collapse_shape[%c0] [%1], %cst_1, %cst_0 :
    memref<300xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
    into vector<4xf32>
```

Fixes https://github.com/openxla/iree/issues/15364.
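
The new pattern is registered through `populateVectorGatherLoweringPatterns` (see the diff below), so downstream passes that already populate the gather lowering patterns pick it up automatically. Below is a minimal sketch of wiring this up with the greedy rewrite driver; the `lowerGathers` helper and the exact include paths are assumptions for illustration, not part of this patch:

```cpp
// Sketch only: a hypothetical helper that applies the vector.gather lowering
// patterns (FlattenGather, RemoveStrideFromGatherSource,
// Gather1DToConditionalLoads) to a function or module.
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult lowerGathers(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Registers the gather lowering patterns, including the new
  // strided-source rewrite added by this patch.
  mlir::vector::populateVectorGatherLoweringPatterns(patterns, /*benefit=*/1);
  // Apply until fixpoint; the strided-source rewrite composes with the
  // flattening and scalarization patterns.
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```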

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
    mlir/test/Dialect/Vector/vector-gather-lowering.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 152aefa65effc3d..90128126d0fa102 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -96,6 +96,87 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
   }
 };
 
+/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
+/// MemRef with updated indices that model the strided access.
+///
+/// ```mlir
+///   %subview = memref.subview %M (...)
+///     : memref<100x3xf32> to memref<100xf32, strided<[3]>>
+///   %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
+/// ```
+/// ==>
+/// ```mlir
+///   %collapse_shape = memref.collapse_shape %M (...)
+///     : memref<100x3xf32> into memref<300xf32>
+///   %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
+///   %gather = vector.gather %collapse_shape[%new_idxs] (...)
+///     : memref<300xf32> (...)
+/// ```
+///
+/// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
+/// but should be fairly straightforward to extend beyond that.
+struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::GatherOp op,
+                                PatternRewriter &rewriter) const override {
+    Value base = op.getBase();
+
+    // TODO: Strided accesses might be coming from other ops as well
+    auto subview = base.getDefiningOp<memref::SubViewOp>();
+    if (!subview)
+      return failure();
+
+    auto sourceType = subview.getSource().getType();
+
+    // TODO: Allow ranks > 2.
+    if (sourceType.getRank() != 2)
+      return failure();
+
+    // Get strides
+    auto layout = subview.getResult().getType().getLayout();
+    auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
+    if (!stridedLayoutAttr)
+      return failure();
+
+    // TODO: Allow the access to be strided in multiple dimensions.
+    if (stridedLayoutAttr.getStrides().size() != 1)
+      return failure();
+
+    int64_t srcTrailingDim = sourceType.getShape().back();
+
+    // Assume that the stride matches the trailing dimension of the source
+    // memref.
+    // TODO: Relax this assumption.
+    if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
+      return failure();
+
+    // 1. Collapse the input memref so that it's "flat".
+    SmallVector<ReassociationIndices> reassoc = {{0, 1}};
+    Value collapsed = rewriter.create<memref::CollapseShapeOp>(
+        op.getLoc(), subview.getSource(), reassoc);
+
+    // 2. Generate new gather indices that will model the
+    // strided access.
+    IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
+    VectorType vType = op.getIndexVec().getType();
+    Value mulCst = rewriter.create<arith::ConstantOp>(
+        op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
+
+    Value newIdxs =
+        rewriter.create<arith::MulIOp>(op.getLoc(), op.getIndexVec(), mulCst);
+
+    // 3. Create an updated gather op with the collapsed input memref and the
+    // updated indices.
+    Value newGather = rewriter.create<vector::GatherOp>(
+        op.getLoc(), op.getResult().getType(), collapsed, op.getIndices(),
+        newIdxs, op.getMask(), op.getPassThru());
+    rewriter.replaceOp(op, newGather);
+
+    return success();
+  }
+};
+
 /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
 /// loads/extracts are made conditional using `scf.if` ops.
@@ -115,6 +196,16 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
 
     Value condMask = op.getMask();
     Value base = op.getBase();
+
+    // vector.load requires the innermost (most minor) memref dimension to
+    // have a unit stride.
+    if (auto memType = dyn_cast<MemRefType>(base.getType())) {
+      if (auto stridesAttr =
+              dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
+        if (stridesAttr.getStrides().back() != 1)
+          return failure();
+      }
+    }
+
     Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
         loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
         op.getIndexVec());
@@ -168,6 +259,6 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
 
 void mlir::vector::populateVectorGatherLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<FlattenGather, Gather1DToConditionalLoads>(patterns.getContext(),
-                                                          benefit);
+  patterns.add<FlattenGather, RemoveStrideFromGatherSource,
+               Gather1DToConditionalLoads>(patterns.getContext(), benefit);
 }

diff  --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 026bec8cd65d3f5..d047ac629d87ea6 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -151,3 +151,58 @@ func.func @gather_tensor_1d_none_set(%base: tensor<?xf32>, %v: vector<2xindex>,
   %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
   return %0 : vector<2xf32>
 }
+
+// Check that vector.gather of a strided memref is replaced with a
+// vector.gather with indices encoding the original strides. Note that multiple
+// patterns are run for this example, e.g.:
+  //  1. "remove stride from gather source"
+  //  2. "flatten gather"
+// However, the main goal is to the test Pattern 1 above.
+#map = affine_map<()[s0] -> (s0 * 4096)>
+func.func @strided_gather(%base : memref<100x3xf32>,
+                          %idxs : vector<4xindex>,
+                          %x : index, %y : index) -> vector<4xf32> {
+  %c0 = arith.constant 0 : index
+  %x_1 = affine.apply #map()[%x]
+  // Strided MemRef
+  %subview = memref.subview %base[0, 0] [100, 1] [1, 1] : memref<100x3xf32> to memref<100xf32, strided<[3]>>
+  %mask = arith.constant dense<true> : vector<4xi1>
+  %pass_thru = arith.constant dense<0.000000e+00> : vector<4xf32>
+  // Gather of a strided MemRef
+  %res = vector.gather %subview[%c0] [%idxs], %mask, %pass_thru : memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %res : vector<4xf32>
+}
+// CHECK-LABEL:   func.func @strided_gather(
+// CHECK-SAME:                         %[[base:.*]]: memref<100x3xf32>,
+// CHECK-SAME:                         %[[IDXS:.*]]: vector<4xindex>,
+// CHECK-SAME:                         %[[VAL_4:.*]]: index,
+// CHECK-SAME:                         %[[VAL_5:.*]]: index) -> vector<4xf32> {
+// CHECK:           %[[CST_3:.*]] = arith.constant dense<3> : vector<4xindex>
+// CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<4xi1>
+
+// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[base]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
+// CHECK:           %[[NEW_IDXS:.*]] = arith.muli %[[IDXS]], %[[CST_3]] : vector<4xindex>
+
+// CHECK:           %[[MASK_0:.*]] = vector.extract %[[MASK]][0] : i1 from vector<4xi1>
+// CHECK:           %[[IDX_0:.*]] = vector.extract %[[NEW_IDXS]][0] : index from vector<4xindex>
+// CHECK:           scf.if %[[MASK_0]] -> (vector<4xf32>)
+// CHECK:             %[[M_0:.*]] = vector.load %[[COLLAPSED]][%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
+// CHECK:             %[[V_0:.*]] = vector.extract %[[M_0]][0] : f32 from vector<1xf32>
+
+// CHECK:           %[[MASK_1:.*]] = vector.extract %[[MASK]][1] : i1 from vector<4xi1>
+// CHECK:           %[[IDX_1:.*]] = vector.extract %[[NEW_IDXS]][1] : index from vector<4xindex>
+// CHECK:           scf.if %[[MASK_1]] -> (vector<4xf32>)
+// CHECK:             %[[M_1:.*]] = vector.load %[[COLLAPSED]][%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
+// CHECK:             %[[V_1:.*]] = vector.extract %[[M_1]][0] : f32 from vector<1xf32>
+
+// CHECK:           %[[MASK_2:.*]] = vector.extract %[[MASK]][2] : i1 from vector<4xi1>
+// CHECK:           %[[IDX_2:.*]] = vector.extract %[[NEW_IDXS]][2] : index from vector<4xindex>
+// CHECK:           scf.if %[[MASK_2]] -> (vector<4xf32>)
+// CHECK:             %[[M_2:.*]] = vector.load %[[COLLAPSED]][%[[IDX_2]]] : memref<300xf32>, vector<1xf32>
+// CHECK:             %[[V_2:.*]] = vector.extract %[[M_2]][0] : f32 from vector<1xf32>
+
+// CHECK:           %[[MASK_3:.*]] = vector.extract %[[MASK]][3] : i1 from vector<4xi1>
+// CHECK:           %[[IDX_3:.*]] = vector.extract %[[NEW_IDXS]][3] : index from vector<4xindex>
+// CHECK:           scf.if %[[MASK_3]] -> (vector<4xf32>)
+// CHECK:             %[[M_3:.*]] = vector.load %[[COLLAPSED]][%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
+// CHECK:             %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>

More information about the Mlir-commits mailing list