[Mlir-commits] [mlir] [MLIR][Vector] Allow any strided memref for one-element vector.load in lowering vector.gather (PR #122437)

llvmlistbot at llvm.org
Fri Jan 10 05:50:27 PST 2025


llvmbot wrote:


@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

<details>
<summary>Changes</summary>

`Gather1DToConditionalLoads` currently checks whether the stride of the most minor dimension of the input memref is 1, and the rewrite pattern is not applied if it is not. However, according to the `vector.load` verifier here:
https://github.com/llvm/llvm-project/blob/4e32271e8b304eb018c69f74c16edd1668fcdaf3/mlir/lib/Dialect/Vector/IR/VectorOps.cpp#L4971-L4975

... if the result vector type of `vector.load` contains only one element, the stride requirement on the input memref does not apply, i.e. the input memref may have any strided layout attribute in that case.

This PR therefore relaxes the check so that lowering `vector.gather` covers these cases as well.
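
As an illustration only, the precondition before and after this change can be modeled as a standalone predicate. This is a sketch, not the MLIR API: the function names `oldCheckAllows` and `relaxedCheckAllows` and the plain integer parameters are hypothetical stand-ins for `stridesAttr.getStrides().back()` and `resultTy.getNumElements()` in the actual pattern.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model of the check before this PR: the pattern applies only
// when the most minor memref dimension has unit stride.
constexpr bool oldCheckAllows(int64_t innermostStride) {
  return innermostStride == 1;
}

// Hypothetical model of the relaxed check: the pattern bails out only when
// the stride is non-unit AND the result vector has more than one element,
// so it applies when either condition holds.
constexpr bool relaxedCheckAllows(int64_t innermostStride,
                                  int64_t numResultElements) {
  return innermostStride == 1 || numResultElements == 1;
}

// Small sanity checks mirroring the cases discussed above.
inline void checkExamples() {
  assert(!oldCheckAllows(2));        // strided<[2]> was always rejected
  assert(relaxedCheckAllows(2, 1));  // strided<[2]> with vector<1xf32>: accepted
  assert(!relaxedCheckAllows(2, 2)); // strided<[2]> with vector<2xf32>: rejected
  assert(relaxedCheckAllows(1, 2));  // unit stride: accepted as before
}
```

The `relaxedCheckAllows(2, 1)` case corresponds to the new `@gather_strided_memref_1d` test, which gathers a `vector<1xf32>` from a `memref<4xf32, strided<[2]>>`.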



---
Full diff: https://github.com/llvm/llvm-project/pull/122437.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (+4-2) 
- (modified) mlir/test/Dialect/Vector/vector-gather-lowering.mlir (+18) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index f1a5aa7664d2f3..4aff565b81b453 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -205,11 +205,13 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
     Value condMask = op.getMask();
     Value base = op.getBase();
 
-    // vector.load requires the most minor memref dim to have unit stride
+    // vector.load requires the most minor memref dim to have unit stride,
+    // or the result vector type to have only one element
     if (auto memType = dyn_cast<MemRefType>(base.getType())) {
       if (auto stridesAttr =
               dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
-        if (stridesAttr.getStrides().back() != 1)
+        if (stridesAttr.getStrides().back() != 1 &&
+            resultTy.getNumElements() != 1)
           return failure();
       }
     }
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 5ad3a23e0ba15c..5d7aff6f8762ad 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -136,6 +136,24 @@ func.func @gather_tensor_1d(%base: tensor<?xf32>, %v: vector<2xindex>, %mask: ve
   return %0 : vector<2xf32>
 }
 
+// CHECK-LABEL: @gather_strided_memref_1d
+// CHECK: %[[MASK:.*]] = vector.extract %arg2[0] : i1 from vector<1xi1>
+// CHECK: %1 = vector.extract %arg1[0] : index from vector<1xindex>
+// CHECK: %[[RET:.*]] = scf.if %[[MASK]] -> (vector<1xf32>) {
+// CHECK:   %[[VEC:.*]] = vector.load %arg0[%1] : memref<4xf32, strided<[2]>>, vector<1xf32>
+// CHECK:   %[[VAL:.*]] = vector.extract %[[VEC]][0] : f32 from vector<1xf32>
+// CHECK:   %[[RES:.*]] = vector.insert %[[VAL]], %arg3 [0] : f32 into vector<1xf32>
+// CHECK:   scf.yield %[[RES]] : vector<1xf32>
+// CHECK: } else {
+// CHECK:    scf.yield %arg3 : vector<1xf32>
+// CHECK: }
+// CHECK: return %[[RET]] : vector<1xf32>
+func.func @gather_strided_memref_1d(%base: memref<4xf32, strided<[2]>>, %v: vector<1xindex>, %mask: vector<1xi1>, %pass_thru: vector<1xf32>) -> vector<1xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<4xf32, strided<[2]>>, vector<1xindex>, vector<1xi1>, vector<1xf32> into vector<1xf32>
+  return %0 : vector<1xf32>
+}
+
 // CHECK-LABEL: @gather_tensor_2d
 // CHECK:  scf.if
 // CHECK:    tensor.extract

``````````

</details>


https://github.com/llvm/llvm-project/pull/122437

