[Mlir-commits] [mlir] b91d5af - [MLIR][Vector] Allow any strided memref for one-element vector.load in lowering vector.gather (#122437)
llvmlistbot at llvm.org
Sun Jan 12 08:02:45 PST 2025
Author: Twice
Date: 2025-01-12T16:02:41Z
New Revision: b91d5af1ac3ad2c18b1dfde2061a6ac1d638e6e4
URL: https://github.com/llvm/llvm-project/commit/b91d5af1ac3ad2c18b1dfde2061a6ac1d638e6e4
DIFF: https://github.com/llvm/llvm-project/commit/b91d5af1ac3ad2c18b1dfde2061a6ac1d638e6e4.diff
LOG: [MLIR][Vector] Allow any strided memref for one-element vector.load in lowering vector.gather (#122437)
`Gather1DToConditionalLoads` currently checks whether the stride of the
most minor (innermost) dim of the input memref is 1. If it is not, the
rewrite pattern is not applied. However, according to the verifier of
`vector.load` here:
https://github.com/llvm/llvm-project/blob/4e32271e8b304eb018c69f74c16edd1668fcdaf3/mlir/lib/Dialect/Vector/IR/VectorOps.cpp#L4971-L4975
if the result vector type of `vector.load` contains only one element,
the stride requirement on the input memref does not apply. In that case
the input memref may have any strided layout attribute. So we can lower
more `vector.gather` cases by relaxing this check.
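The relaxed condition can be read as a small predicate. The sketch below is a standalone C++ model of it, not MLIR code. The function name `canLowerGatherToConditionalLoads` is hypothetical. It takes the innermost stride, which is empty when the memref has no `StridedLayoutAttr`, and the number of result elements. It mirrors the `if` in the diff.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for the stride check in Gather1DToConditionalLoads.
// `innermostStride` is empty when the memref layout is not a
// StridedLayoutAttr (e.g. an identity layout), which the pattern never rejects.
bool canLowerGatherToConditionalLoads(std::optional<int64_t> innermostStride,
                                      int64_t numResultElements) {
  if (!innermostStride)
    return true;  // no strided layout to inspect
  // Before the patch: reject whenever the innermost stride is not 1.
  // After the patch: also accept when the result has exactly one element,
  // since a one-element vector.load is valid for any stride.
  return *innermostStride == 1 || numResultElements == 1;
}
```

For example, a `memref<4xf32, strided<[2]>>` gathered into `vector<1xf32>` passes (`stride = 2`, one element). The same memref gathered into `vector<2xf32>` is still rejected. This matches the two test cases added in the diff.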
As shown in the test case added in this patch
[here](https://github.com/llvm/llvm-project/blob/1933fbad58302814ccce5991a9320c0967f3571b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir#L151),
a `vector.gather` from a memref with a non-unit innermost stride can now
be lowered successfully when the result vector contains exactly one element.
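The following hypothetical C++ model shows why the element count matters. The names `StridedMemRef1D` and `gatherConditionalLoads` are illustrative, not MLIR APIs. The model represents a 1-D `vector.gather` from a memref with layout `strided<[S]>`.

In this layout, logical element `j` sits at physical offset `j * S`. With `S = 2`, neighbouring logical elements are not adjacent in memory. That is presumably why a multi-element `vector.load` needs unit stride, while a single-element read does not. Each lane below performs one single-element read guarded by its mask bit. This loosely follows the `scf.if` / `vector.load` / `vector.insert` shape in the one-element CHECK lines.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of a 1-D memref<?xf32, strided<[S]>>:
// logical element j lives at physical offset j * stride in `data`.
struct StridedMemRef1D {
  std::vector<float> data;
  int64_t stride;
  float load(int64_t j) const {
    return data[static_cast<std::size_t>(j * stride)];
  }
};

// Reference semantics of a 1-D gather, written lane by lane:
// result[i] = mask[i] ? base[offset + idx[i]] : passThru[i].
// Each active lane touches exactly one element, so contiguity never matters.
std::vector<float> gatherConditionalLoads(const StridedMemRef1D &base,
                                          int64_t offset,
                                          const std::vector<int64_t> &idx,
                                          const std::vector<bool> &mask,
                                          std::vector<float> passThru) {
  for (std::size_t i = 0; i < idx.size(); ++i)
    if (mask[i])
      passThru[i] = base.load(offset + idx[i]);  // one-element read
  return passThru;  // masked-off lanes keep their pass-through value
}
```

With `data = {0, 1, ..., 7}` and `stride = 2`, logical element 3 is the physical value 6. A one-lane gather at index 3 with the mask set therefore yields `{6}`.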
---------
Signed-off-by: PragmaTwice <twice at apache.org>
Co-authored-by: Andrzej Warzyński <andrzej.warzynski at gmail.com>
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
mlir/test/Dialect/Vector/vector-gather-lowering.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index f1a5aa7664d2f3..3b38505becd188 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -206,10 +206,12 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
Value base = op.getBase();
// vector.load requires the most minor memref dim to have unit stride
+ // (unless reading exactly 1 element)
if (auto memType = dyn_cast<MemRefType>(base.getType())) {
if (auto stridesAttr =
dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
- if (stridesAttr.getStrides().back() != 1)
+ if (stridesAttr.getStrides().back() != 1 &&
+ resultTy.getNumElements() != 1)
return failure();
}
}
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 5ad3a23e0ba15c..20e9400ed698d4 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -136,6 +136,34 @@ func.func @gather_tensor_1d(%base: tensor<?xf32>, %v: vector<2xindex>, %mask: ve
return %0 : vector<2xf32>
}
+// CHECK-LABEL: @gather_memref_non_unit_stride_read_1_element
+// CHECK: %[[MASK:.*]] = vector.extract %arg2[0] : i1 from vector<1xi1>
+// CHECK: %[[IDX:.*]] = vector.extract %arg1[0] : index from vector<1xindex>
+// CHECK: %[[RET:.*]] = scf.if %[[MASK]] -> (vector<1xf32>) {
+// CHECK: %[[VEC:.*]] = vector.load %arg0[%[[IDX]]] : memref<4xf32, strided<[2]>>, vector<1xf32>
+// CHECK: %[[VAL:.*]] = vector.extract %[[VEC]][0] : f32 from vector<1xf32>
+// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %arg3 [0] : f32 into vector<1xf32>
+// CHECK: scf.yield %[[RES]] : vector<1xf32>
+// CHECK: } else {
+// CHECK: scf.yield %arg3 : vector<1xf32>
+// CHECK: }
+// CHECK: return %[[RET]] : vector<1xf32>
+func.func @gather_memref_non_unit_stride_read_1_element(%base: memref<4xf32, strided<[2]>>, %v: vector<1xindex>, %mask: vector<1xi1>, %pass_thru: vector<1xf32>) -> vector<1xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<4xf32, strided<[2]>>, vector<1xindex>, vector<1xi1>, vector<1xf32> into vector<1xf32>
+ return %0 : vector<1xf32>
+}
+
+// CHECK-LABEL: @gather_memref_non_unit_stride_read_more_than_1_element
+// CHECK: %[[CONST:.*]] = arith.constant 0 : index
+// CHECK: %[[RET:.*]] = vector.gather %arg0[%[[CONST]]] [%arg1], %arg2, %arg3 : memref<4xf32, strided<[2]>>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
+// CHECK: return %[[RET]] : vector<2xf32>
+func.func @gather_memref_non_unit_stride_read_more_than_1_element(%base: memref<4xf32, strided<[2]>>, %v: vector<2xindex>, %mask: vector<2xi1>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<4xf32, strided<[2]>>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
// CHECK-LABEL: @gather_tensor_2d
// CHECK: scf.if
// CHECK: tensor.extract