[Mlir-commits] [mlir] [mlir][vector] Remove unit-stride check in Gather1DToConditionalLoads (PR #189178)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Mar 28 09:58:12 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir
Author: Jorn Tuyls (jtuyls)
<details>
<summary>Changes</summary>
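Each scalarized load has type vector<1xelemTy>, and `vector.load` imposes no unit-stride requirement on single-element loads, so the load is valid regardless of stride. The removed check also keyed on the gather's total result size rather than on the per-element load size. As a result, rank-1 strided gathers (e.g. from rank-reducing subviews of rank-3+ sources) were left unlowered and later crashed during LLVM translation.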
---
Full diff: https://github.com/llvm/llvm-project/pull/189178.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (-10)
- (modified) mlir/test/Dialect/Vector/vector-gather-lowering.mlir (+28)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 7194d41d60df7..fd95ea0c39a54 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -195,16 +195,6 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
// correct N-D load indices from the 1-D gather index.
bool useDelinearization = false;
if (auto memType = dyn_cast<MemRefType>(base.getType())) {
- // vector.load requires the most minor memref dim to have unit stride
- // (unless reading exactly 1 element).
- if (auto stridesAttr =
- dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
- if (stridesAttr.getStrides().back() != 1 &&
- resultTy.getNumElements() != 1)
- return rewriter.notifyMatchFailure(
- op, "most minor memref dim must have unit stride");
- }
-
if (memType.getRank() > 1)
useDelinearization = true;
}
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 59b13e300e5e5..1a5407e7f4752 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -360,3 +360,31 @@ func.func @gather_memref_2d_delinearize_nonzero_offsets(
vector<2xi1>, vector<2xf32> into vector<2xf32>
return %0 : vector<2xf32>
}
+
+// -----
+
+// Verify that gather on a rank-1 strided memref (from rank-reducing subview)
+// is correctly scalarized to per-element vector.load ops.
+
+// CHECK-LABEL: @gather_rank1_strided_memref
+// CHECK-SAME: (%[[BASE:.+]]: memref<4xf32, strided<[6]>>
+// CHECK-NOT: vector.gather
+// CHECK: vector.extract {{.*}}[0]
+// CHECK: scf.if {{.*}} -> (vector<3xf32>)
+// CHECK: vector.load %[[BASE]][%{{.*}}] : memref<4xf32, strided<[6]>>, vector<1xf32>
+// CHECK: vector.extract {{.*}}[1]
+// CHECK: scf.if {{.*}} -> (vector<3xf32>)
+// CHECK: vector.load %[[BASE]][%{{.*}}] : memref<4xf32, strided<[6]>>, vector<1xf32>
+// CHECK: vector.extract {{.*}}[2]
+// CHECK: scf.if {{.*}} -> (vector<3xf32>)
+// CHECK: vector.load %[[BASE]][%{{.*}}] : memref<4xf32, strided<[6]>>, vector<1xf32>
+func.func @gather_rank1_strided_memref(
+ %base: memref<4xf32, strided<[6]>>,
+ %v: vector<3xindex>, %mask: vector<3xi1>,
+ %pass_thru: vector<3xf32>) -> vector<3xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.gather %base[%c0][%v], %mask, %pass_thru
+ : memref<4xf32, strided<[6]>>, vector<3xindex>,
+ vector<3xi1>, vector<3xf32> into vector<3xf32>
+ return %0 : vector<3xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/189178
More information about the Mlir-commits mailing list