[Mlir-commits] [mlir] [mlir][vector] Remove unit-stride check in Gather1DToConditionalLoads (PR #189178)
Jorn Tuyls
llvmlistbot at llvm.org
Sat Mar 28 09:57:38 PDT 2026
https://github.com/jtuyls created https://github.com/llvm/llvm-project/pull/189178
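Gather1DToConditionalLoads scalarizes the gather, and each scalarized load is a vector<1xelemTy>, which vector.load accepts regardless of the innermost stride. The removed check instead tested the stride against the gather's total result size (resultTy.getNumElements()). Because of that, rank-1 strided gathers (e.g. those produced by rank-reducing subviews of rank-3+ sources) were left unlowered and crashed during LLVM translation.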
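The argument can be illustrated with a small plain-C++ model. This is a sketch under stated assumptions, not MLIR code: `StridedMemRef1D`, `loadVector`, and `gatherToConditionalLoads` are hypothetical names, and the "multi-element loads need unit stride" rule is modeled on the comment in the removed check. The point it shows is that a per-lane, single-element load never depends on the innermost stride being 1.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of a rank-1 strided memref, e.g.
// memref<4xf32, strided<[6]>>: logical element i lives at
// data[offset + i * stride].
struct StridedMemRef1D {
  const float *data;
  std::int64_t offset;
  std::int64_t stride;
  std::int64_t size;
};

// Hypothetical stand-in for vector.load: reads n logical elements starting
// at idx. Assumption (from the removed check's comment): a multi-element
// load requires unit stride, while a single-element load does not.
std::vector<float> loadVector(const StridedMemRef1D &m, std::int64_t idx,
                              std::int64_t n) {
  assert((n == 1 || m.stride == 1) && "multi-element load needs unit stride");
  std::vector<float> out;
  for (std::int64_t k = 0; k < n; ++k)
    out.push_back(m.data[m.offset + (idx + k) * m.stride]);
  return out;
}

// Scalarized gather: the result starts as the pass-through vector and each
// active lane overwrites its slot with one single-element load, loosely
// mirroring the per-lane scf.if + vector.load ... vector<1xf32> structure.
template <std::size_t N>
std::array<float, N> gatherToConditionalLoads(
    const StridedMemRef1D &m, std::int64_t baseIdx,
    const std::array<std::int64_t, N> &indices,
    const std::array<bool, N> &mask, std::array<float, N> passThru) {
  std::array<float, N> result = passThru;
  for (std::size_t i = 0; i < N; ++i) {
    if (!mask[i])
      continue;  // masked-off lanes keep the pass-through value
    // n == 1, so the stride assertion in loadVector can never fire.
    result[i] = loadVector(m, baseIdx + indices[i], 1)[0];
  }
  return result;
}
```

For a layout like the test's memref<4xf32, strided<[6]>> (offset 0) with base index 0, lane i reads data[v[i] * 6]. The old check would have rejected this case because the gather produces 3 elements, even though every individual load reads only 1.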
From 159948fc87eed0f06898497d679a8550b904b361 Mon Sep 17 00:00:00 2001
From: Jorn <jorn.tuyls at gmail.com>
Date: Sat, 28 Mar 2026 08:34:43 -0700
Subject: [PATCH] [mlir][vector] Remove unit-stride check in
Gather1DToConditionalLoads
Each scalarized load is vector<1xelemTy>, valid regardless of stride.
The check caused rank-1 strided gathers (from rank-reducing subviews
of rank-3+ sources) to go unlowered and crash during LLVM translation.
---
.../Vector/Transforms/LowerVectorGather.cpp | 10 -------
.../Vector/vector-gather-lowering.mlir | 28 +++++++++++++++++++
2 files changed, 28 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 7194d41d60df7..fd95ea0c39a54 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -195,16 +195,6 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
// correct N-D load indices from the 1-D gather index.
bool useDelinearization = false;
if (auto memType = dyn_cast<MemRefType>(base.getType())) {
- // vector.load requires the most minor memref dim to have unit stride
- // (unless reading exactly 1 element).
- if (auto stridesAttr =
- dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
- if (stridesAttr.getStrides().back() != 1 &&
- resultTy.getNumElements() != 1)
- return rewriter.notifyMatchFailure(
- op, "most minor memref dim must have unit stride");
- }
-
if (memType.getRank() > 1)
useDelinearization = true;
}
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 59b13e300e5e5..1a5407e7f4752 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -360,3 +360,31 @@ func.func @gather_memref_2d_delinearize_nonzero_offsets(
vector<2xi1>, vector<2xf32> into vector<2xf32>
return %0 : vector<2xf32>
}
+
+// -----
+
+// Verify that gather on a rank-1 strided memref (from rank-reducing subview)
+// is correctly scalarized to per-element vector.load ops.
+
+// CHECK-LABEL: @gather_rank1_strided_memref
+// CHECK-SAME: (%[[BASE:.+]]: memref<4xf32, strided<[6]>>
+// CHECK-NOT: vector.gather
+// CHECK: vector.extract {{.*}}[0]
+// CHECK: scf.if {{.*}} -> (vector<3xf32>)
+// CHECK: vector.load %[[BASE]][%{{.*}}] : memref<4xf32, strided<[6]>>, vector<1xf32>
+// CHECK: vector.extract {{.*}}[1]
+// CHECK: scf.if {{.*}} -> (vector<3xf32>)
+// CHECK: vector.load %[[BASE]][%{{.*}}] : memref<4xf32, strided<[6]>>, vector<1xf32>
+// CHECK: vector.extract {{.*}}[2]
+// CHECK: scf.if {{.*}} -> (vector<3xf32>)
+// CHECK: vector.load %[[BASE]][%{{.*}}] : memref<4xf32, strided<[6]>>, vector<1xf32>
+func.func @gather_rank1_strided_memref(
+ %base: memref<4xf32, strided<[6]>>,
+ %v: vector<3xindex>, %mask: vector<3xi1>,
+ %pass_thru: vector<3xf32>) -> vector<3xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.gather %base[%c0][%v], %mask, %pass_thru
+ : memref<4xf32, strided<[6]>>, vector<3xindex>,
+ vector<3xi1>, vector<3xf32> into vector<3xf32>
+ return %0 : vector<3xf32>
+}