[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for gather over a strided memref (PR #72991)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Nov 28 05:07:57 PST 2023
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/72991
>From c65261785efc0a60dc9b457a0831557ca2f62d0f Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 21 Nov 2023 13:13:03 +0000
Subject: [PATCH 1/4] [mlir][Vector] Add a rewrite pattern for gather over a
strided memref
This patch adds a rewrite pattern for `vector.gather` over a strided
memref like the following:
```mlir
%subview = memref.subview %arg0[0, 0] [100, 1] [1, 1] :
memref<100x3xf32> to memref<100xf32, strided<[3]>>
%gather = vector.gather %subview[%c0] [%idxs], %cst_0, %cst :
memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32>
into vector<4xf32>
```
```mlir
%collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :
memref<100x3xf32> into memref<300xf32>
%1 = arith.muli %arg3, %cst : vector<4xindex>
%gather = vector.gather %collapse_shape[%c0] [%1], %cst_1, %cst_0 :
memref<300xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
into vector<4xf32>
```
Fixes https://github.com/openxla/iree/issues/15364.
---
.../Vector/Transforms/LowerVectorGather.cpp | 80 ++++++++++++++++++-
.../Vector/vector-gather-lowering.mlir | 54 +++++++++++++
2 files changed, 132 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 152aefa65effc3d..54b350d7ac3524c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -96,6 +96,82 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
}
};
+/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
+/// MemRef with updated indices that model the strided access.
+///
+/// ```mlir
+/// %subview = memref.subview %M (...) to memref<100xf32, strided<[3]>>
+/// %gather = vector.gather %subview (...) : memref<100xf32, strided<[3]>>
+/// ```
+/// ==>
+/// ```mlir
+/// %collapse_shape = memref.collapse_shape %M (...) into memref<300xf32>
+/// %1 = arith.muli %idxs, %c3 : vector<4xindex>
+/// %gather = vector.gather %collapse_shape (...) : memref<300xf32> (...)
+/// ```
+///
+/// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
+/// but should be fairly straightforward to extend beyond that.
+struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::GatherOp op,
+ PatternRewriter &rewriter) const override {
+ Value base = op.getBase();
+ if (!base.getDefiningOp())
+ return failure();
+
+ // TODO: Strided accesses might be coming from other ops as well
+ auto subview = dyn_cast<memref::SubViewOp>(base.getDefiningOp());
+ if (!subview)
+ return failure();
+
+ // TODO: Allows ranks > 2.
+ if (subview.getSource().getType().getRank() != 2)
+ return failure();
+
+ // Get strides
+ auto layout = subview.getResult().getType().getLayout();
+ auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
+
+ // TODO: Allow the access to be strided in multiple dimensions.
+ if (stridedLayoutAttr.getStrides().size() != 1)
+ return failure();
+
+ int64_t srcTrailingDim = subview.getSource().getType().getShape().back();
+
+ // Assume that the stride matches the trailing dimension of the source
+ // memref.
+ // TODO: Relax this assumption.
+ if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
+ return failure();
+
+ // 1. Collapse the input memref so that it's "flat".
+ SmallVector<ReassociationIndices> reassoc = {{0, 1}};
+ Value collapsed = rewriter.create<memref::CollapseShapeOp>(
+ op.getLoc(), subview.getSource(), reassoc);
+
+ // 2. Generate new gather indices that will model the
+ // strided access.
+ auto stride = rewriter.getIndexAttr(srcTrailingDim);
+ auto vType = op.getIndexVec().getType();
+ Value mulCst = rewriter.create<arith::ConstantOp>(
+ op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
+
+ Value newIdxs =
+ rewriter.create<arith::MulIOp>(op.getLoc(), op.getIndexVec(), mulCst);
+
+ // 3. Create an updated gather op with the collapsed input memref and the
+ // updated indices.
+ Value newGather = rewriter.create<vector::GatherOp>(
+ op.getLoc(), op.getResult().getType(), collapsed, op.getIndices(),
+ newIdxs, op.getMask(), op.getPassThru());
+ rewriter.replaceOp(op, newGather);
+
+ return success();
+ }
+};
+
/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
@@ -168,6 +244,6 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
void mlir::vector::populateVectorGatherLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<FlattenGather, Gather1DToConditionalLoads>(patterns.getContext(),
- benefit);
+ patterns.add<FlattenGather, RemoveStrideFromGatherSource,
+ Gather1DToConditionalLoads>(patterns.getContext(), benefit);
}
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 026bec8cd65d3f5..3de7f44e4fb3e27 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -151,3 +151,57 @@ func.func @gather_tensor_1d_none_set(%base: tensor<?xf32>, %v: vector<2xindex>,
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
return %0 : vector<2xf32>
}
+
+// Check that vector.gather of a strided memref is replaced with a
+// vector.gather with indices encoding the original strides. Note that with the
+// other patterns
+#map = affine_map<()[s0] -> (s0 * 4096)>
+#map1 = affine_map<()[s0] -> (s0 * -4096 + 518400, 4096)>
+func.func @strided_gather(%M_in : memref<100x3xf32>, %M_out: memref<518400xf32>, %idxs : vector<4xindex>, %x : index, %y : index) {
+ %c0 = arith.constant 0 : index
+ %x_1 = affine.apply #map()[%x]
+ // Strided MemRef
+ %subview = memref.subview %M_in[0, 0] [100, 1] [1, 1] : memref<100x3xf32> to memref<100xf32, strided<[3]>>
+ %cst_0 = arith.constant dense<true> : vector<4xi1>
+ %cst = arith.constant dense<0.000000e+00> : vector<4xf32>
+ // Gather of a strided MemRef
+ %7 = vector.gather %subview[%c0] [%idxs], %cst_0, %cst : memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ %subview_1 = memref.subview %M_out[%x_1] [%y] [1] : memref<518400xf32> to memref<?xf32, strided<[1], offset: ?>>
+ vector.store %7, %subview_1[%c0] : memref<?xf32, strided<[1], offset: ?>>, vector<4xf32>
+ return
+}
+// CHECK-LABEL: func.func @strided_gather(
+// CHECK-SAME: %[[M_in:.*]]: memref<100x3xf32>,
+// CHECK-SAME: %[[M_out:.*]]: memref<518400xf32>,
+// CHECK-SAME: %[[IDXS:.*]]: vector<4xindex>,
+// CHECK-SAME: %[[VAL_4:.*]]: index,
+// CHECK-SAME: %[[VAL_5:.*]]: index) {
+// CHECK: %[[CST_3:.*]] = arith.constant dense<3> : vector<4xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<4xi1>
+
+// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_in]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
+// CHECK: %[[NEW_IDXS:.*]] = arith.muli %[[IDXS]], %[[CST_3]] : vector<4xindex>
+
+// CHECK: %[[MASK_0:.*]] = vector.extract %[[MASK]][0] : i1 from vector<4xi1>
+// CHECK: %[[IDX_0:.*]] = vector.extract %[[NEW_IDXS]][0] : index from vector<4xindex>
+// CHECK: scf.if %[[MASK_0]] -> (vector<4xf32>)
+// CHECK: %[[M_0:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[V_0:.*]] = vector.extract %[[M_0]][0] : f32 from vector<1xf32>
+
+// CHECK: %[[MASK_1:.*]] = vector.extract %[[MASK]][1] : i1 from vector<4xi1>
+// CHECK: %[[IDX_1:.*]] = vector.extract %[[NEW_IDXS]][1] : index from vector<4xindex>
+// CHECK: scf.if %[[MASK_1]] -> (vector<4xf32>)
+// CHECK: %[[M_1:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[V_1:.*]] = vector.extract %[[M_1]][0] : f32 from vector<1xf32>
+
+// CHECK: %[[MASK_2:.*]] = vector.extract %[[MASK]][2] : i1 from vector<4xi1>
+// CHECK: %[[IDX_2:.*]] = vector.extract %[[NEW_IDXS]][2] : index from vector<4xindex>
+// CHECK: scf.if %[[MASK_2]] -> (vector<4xf32>)
+// CHECK: %[[M_2:.*]] = vector.load %[[COLLAPSED]][%[[IDX_2]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[V_2:.*]] = vector.extract %[[M_2]][0] : f32 from vector<1xf32>
+
+// CHECK: %[[MASK_3:.*]] = vector.extract %[[MASK]][3] : i1 from vector<4xi1>
+// CHECK: %[[IDX_3:.*]] = vector.extract %[[NEW_IDXS]][3] : index from vector<4xindex>
+// CHECK: scf.if %[[MASK_3]] -> (vector<4xf32>)
+// CHECK: %[[M_3:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>
>From bb71dbcfe1d6b0e0016712cdeb7c9941b23405d8 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 28 Nov 2023 11:42:37 +0000
Subject: [PATCH 2/4] fixup! [mlir][Vector] Add a rewrite pattern for gather
over a strided memref
Refine based on PR feedback
---
.../Vector/Transforms/LowerVectorGather.cpp | 20 ++++++-----
.../Vector/vector-gather-lowering.mlir | 36 ++++++++++---------
2 files changed, 32 insertions(+), 24 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 54b350d7ac3524c..74487db5cdfc2bb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -100,12 +100,12 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
/// MemRef with updated indices that model the strided access.
///
/// ```mlir
-/// %subview = memref.subview %M (...) to memref<100xf32, strided<[3]>>
+/// %subview = memref.subview %M (...) memref<100x3xf32> to memref<100xf32, strided<[3]>>
/// %gather = vector.gather %subview (...) : memref<100xf32, strided<[3]>>
/// ```
/// ==>
/// ```mlir
-/// %collapse_shape = memref.collapse_shape %M (...) into memref<300xf32>
+/// %collapse_shape = memref.collapse_shape %M (...) memref<100x3xf32> into memref<300xf32>
/// %1 = arith.muli %idxs, %c3 : vector<4xindex>
/// %gather = vector.gather %collapse_shape (...) : memref<300xf32> (...)
/// ```
@@ -122,23 +122,27 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
return failure();
// TODO: Strided accesses might be coming from other ops as well
- auto subview = dyn_cast<memref::SubViewOp>(base.getDefiningOp());
+ auto subview = base.getDefiningOp<memref::SubViewOp>();
if (!subview)
return failure();
- // TODO: Allows ranks > 2.
- if (subview.getSource().getType().getRank() != 2)
+ auto sourceType = subview.getSource().getType();
+
+ // TODO: Allow ranks > 2.
+ if (sourceType.getRank() != 2)
return failure();
// Get strides
auto layout = subview.getResult().getType().getLayout();
auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
+ if (!stridedLayoutAttr)
+ return failure();
// TODO: Allow the access to be strided in multiple dimensions.
if (stridedLayoutAttr.getStrides().size() != 1)
return failure();
- int64_t srcTrailingDim = subview.getSource().getType().getShape().back();
+ int64_t srcTrailingDim = sourceType.getShape().back();
// Assume that the stride matches the trailing dimension of the source
// memref.
@@ -153,8 +157,8 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
// 2. Generate new gather indices that will model the
// strided access.
- auto stride = rewriter.getIndexAttr(srcTrailingDim);
- auto vType = op.getIndexVec().getType();
+ IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
+ VectorType vType = op.getIndexVec().getType();
Value mulCst = rewriter.create<arith::ConstantOp>(
op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 3de7f44e4fb3e27..a7291c359d3511d 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -153,45 +153,49 @@ func.func @gather_tensor_1d_none_set(%base: tensor<?xf32>, %v: vector<2xindex>,
}
// Check that vector.gather of a strided memref is replaced with a
-// vector.gather with indices encoding the original strides. Note that with the
-// other patterns
+// vector.gather with indices encoding the original strides. Note that multiple
+// patterns are run for this example, e.g.:
+ // 1. "remove stride from gather source"
+ // 2. "flatten gather"
+// However, the main goal is to the test Pattern 1 above.
#map = affine_map<()[s0] -> (s0 * 4096)>
#map1 = affine_map<()[s0] -> (s0 * -4096 + 518400, 4096)>
-func.func @strided_gather(%M_in : memref<100x3xf32>, %M_out: memref<518400xf32>, %idxs : vector<4xindex>, %x : index, %y : index) {
+func.func @strided_gather(%base : memref<100x3xf32>,
+ %M_out: memref<518400xf32>,
+ %idxs : vector<4xindex>,
+ %x : index, %y : index) -> vector<4xf32> {
%c0 = arith.constant 0 : index
%x_1 = affine.apply #map()[%x]
// Strided MemRef
- %subview = memref.subview %M_in[0, 0] [100, 1] [1, 1] : memref<100x3xf32> to memref<100xf32, strided<[3]>>
- %cst_0 = arith.constant dense<true> : vector<4xi1>
- %cst = arith.constant dense<0.000000e+00> : vector<4xf32>
+ %subview = memref.subview %base[0, 0] [100, 1] [1, 1] : memref<100x3xf32> to memref<100xf32, strided<[3]>>
+ %mask = arith.constant dense<true> : vector<4xi1>
+ %pass_thru = arith.constant dense<0.000000e+00> : vector<4xf32>
// Gather of a strided MemRef
- %7 = vector.gather %subview[%c0] [%idxs], %cst_0, %cst : memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
- %subview_1 = memref.subview %M_out[%x_1] [%y] [1] : memref<518400xf32> to memref<?xf32, strided<[1], offset: ?>>
- vector.store %7, %subview_1[%c0] : memref<?xf32, strided<[1], offset: ?>>, vector<4xf32>
- return
+ %res = vector.gather %subview[%c0] [%idxs], %mask, %pass_thru : memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ return %res : vector<4xf32>
}
// CHECK-LABEL: func.func @strided_gather(
-// CHECK-SAME: %[[M_in:.*]]: memref<100x3xf32>,
+// CHECK-SAME: %[[base:.*]]: memref<100x3xf32>,
// CHECK-SAME: %[[M_out:.*]]: memref<518400xf32>,
// CHECK-SAME: %[[IDXS:.*]]: vector<4xindex>,
// CHECK-SAME: %[[VAL_4:.*]]: index,
-// CHECK-SAME: %[[VAL_5:.*]]: index) {
+// CHECK-SAME: %[[VAL_5:.*]]: index) -> vector<4xf32> {
// CHECK: %[[CST_3:.*]] = arith.constant dense<3> : vector<4xindex>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<4xi1>
-// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_in]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
+// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[base]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
// CHECK: %[[NEW_IDXS:.*]] = arith.muli %[[IDXS]], %[[CST_3]] : vector<4xindex>
// CHECK: %[[MASK_0:.*]] = vector.extract %[[MASK]][0] : i1 from vector<4xi1>
// CHECK: %[[IDX_0:.*]] = vector.extract %[[NEW_IDXS]][0] : index from vector<4xindex>
// CHECK: scf.if %[[MASK_0]] -> (vector<4xf32>)
-// CHECK: %[[M_0:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[M_0:.*]] = vector.load %[[COLLAPSED]][%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
// CHECK: %[[V_0:.*]] = vector.extract %[[M_0]][0] : f32 from vector<1xf32>
// CHECK: %[[MASK_1:.*]] = vector.extract %[[MASK]][1] : i1 from vector<4xi1>
// CHECK: %[[IDX_1:.*]] = vector.extract %[[NEW_IDXS]][1] : index from vector<4xindex>
// CHECK: scf.if %[[MASK_1]] -> (vector<4xf32>)
-// CHECK: %[[M_1:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[M_1:.*]] = vector.load %[[COLLAPSED]][%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
// CHECK: %[[V_1:.*]] = vector.extract %[[M_1]][0] : f32 from vector<1xf32>
// CHECK: %[[MASK_2:.*]] = vector.extract %[[MASK]][2] : i1 from vector<4xi1>
@@ -203,5 +207,5 @@ func.func @strided_gather(%M_in : memref<100x3xf32>, %M_out: memref<518400xf32>,
// CHECK: %[[MASK_3:.*]] = vector.extract %[[MASK]][3] : i1 from vector<4xi1>
// CHECK: %[[IDX_3:.*]] = vector.extract %[[NEW_IDXS]][3] : index from vector<4xindex>
// CHECK: scf.if %[[MASK_3]] -> (vector<4xf32>)
-// CHECK: %[[M_3:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[M_3:.*]] = vector.load %[[COLLAPSED]][%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
// CHECK: %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>
>From 2fad34c305999e2fa333bba620795db3ff485cd0 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 28 Nov 2023 11:54:22 +0000
Subject: [PATCH 3/4] fixup! [mlir][Vector] Add a rewrite pattern for gather
over a strided memref
Fix formatting
---
.../Dialect/Vector/Transforms/LowerVectorGather.cpp | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 74487db5cdfc2bb..3bbbf8167f52bfa 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -100,14 +100,15 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
/// MemRef with updated indices that model the strided access.
///
/// ```mlir
-/// %subview = memref.subview %M (...) memref<100x3xf32> to memref<100xf32, strided<[3]>>
-/// %gather = vector.gather %subview (...) : memref<100xf32, strided<[3]>>
+/// %subview = memref.subview %M (...) memref<100x3xf32> to memref<100xf32,
+/// strided<[3]>> %gather = vector.gather %subview (...) : memref<100xf32,
+/// strided<[3]>>
/// ```
/// ==>
/// ```mlir
-/// %collapse_shape = memref.collapse_shape %M (...) memref<100x3xf32> into memref<300xf32>
-/// %1 = arith.muli %idxs, %c3 : vector<4xindex>
-/// %gather = vector.gather %collapse_shape (...) : memref<300xf32> (...)
+/// %collapse_shape = memref.collapse_shape %M (...) memref<100x3xf32> into
+/// memref<300xf32> %1 = arith.muli %idxs, %c3 : vector<4xindex> %gather =
+/// vector.gather %collapse_shape (...) : memref<300xf32> (...)
/// ```
///
/// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
>From 7d9f17c6b1a0cf77c06f89b9d4ae60bfee08ad4d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 28 Nov 2023 13:07:02 +0000
Subject: [PATCH 4/4] fixup! [mlir][Vector] Add a rewrite pattern for gather
over a strided memref
Restrict Gather1DToConditionalLoads
---
.../Dialect/Vector/Transforms/LowerVectorGather.cpp | 11 +++++++++++
1 file changed, 11 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 3bbbf8167f52bfa..8372ffa157162c2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -196,6 +196,17 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
Value condMask = op.getMask();
Value base = op.getBase();
+
+ // vector.load requires the most minor memref dim to have unit stride
+ if (auto memType = dyn_cast<MemRefType>(base.getType())) {
+ memType.getLayout();
+ if (auto stridesAttr =
+ dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
+ if (stridesAttr.getStrides().back() != 1)
+ return failure();
+ }
+ }
+
Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
op.getIndexVec());
More information about the Mlir-commits
mailing list