[Mlir-commits] [mlir] [mlir][vector] Remove unit-stride check in Gather1DToConditionalLoads (PR #189178)

llvmlistbot at llvm.org
Sat Mar 28 09:58:12 PDT 2026


llvmbot wrote:


@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Jorn Tuyls (jtuyls)

<details>
<summary>Changes</summary>

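`Gather1DToConditionalLoads` scalarizes the gather into per-element loads of type `vector<1xelemTy>`, and a one-element `vector.load` is valid regardless of the stride of the most minor memref dimension, so the unit-stride check was unnecessary. The check caused gathers on rank-1 strided memrefs (from rank-reducing subviews of rank-3+ sources) to be left unlowered, which then crashed during LLVM translation.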
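
To illustrate why a one-element load is indifferent to stride, here is a minimal standalone C++ sketch of the scalarized gather semantics. It is not MLIR code and not part of this patch: `StridedView1D`, `loadOne`, and `gatherScalarized` are hypothetical names, and the view only models a rank-1 memref whose logical element `i` lives at `offset + i * stride`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of a rank-1 strided memref (not an MLIR API):
// logical element i is stored at data[offset + i * stride].
struct StridedView1D {
  const float *data;
  std::int64_t offset;
  std::int64_t stride;
  std::int64_t size;
};

// Models a one-element load. Only a single address is touched, so the
// stride determines where that address is, not whether the load is valid.
inline float loadOne(const StridedView1D &view, std::int64_t i) {
  assert(i >= 0 && i < view.size && "index out of bounds");
  return view.data[view.offset + i * view.stride];
}

// Models the scalarized gather: each lane conditionally loads one element
// at baseIndex + indices[lane], or keeps its pass-through value when the
// corresponding mask bit is unset.
inline std::vector<float>
gatherScalarized(const StridedView1D &view, std::int64_t baseIndex,
                 const std::vector<std::int64_t> &indices,
                 const std::vector<bool> &mask,
                 const std::vector<float> &passThru) {
  assert(indices.size() == mask.size() && indices.size() == passThru.size());
  std::vector<float> result = passThru;
  for (std::size_t lane = 0; lane < indices.size(); ++lane) {
    if (mask[lane])
      result[lane] = loadOne(view, baseIndex + indices[lane]);
  }
  return result;
}
```

With a view of size 4 and stride 6, mirroring the `memref<4xf32, strided<[6]>>` used in the new test, each masked-on lane reads exactly one element at a multiple of 6 from the offset, and masked-off lanes return their pass-through value unchanged.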

---
Full diff: https://github.com/llvm/llvm-project/pull/189178.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (-10) 
- (modified) mlir/test/Dialect/Vector/vector-gather-lowering.mlir (+28) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 7194d41d60df7..fd95ea0c39a54 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -195,16 +195,6 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
     // correct N-D load indices from the 1-D gather index.
     bool useDelinearization = false;
     if (auto memType = dyn_cast<MemRefType>(base.getType())) {
-      // vector.load requires the most minor memref dim to have unit stride
-      // (unless reading exactly 1 element).
-      if (auto stridesAttr =
-              dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
-        if (stridesAttr.getStrides().back() != 1 &&
-            resultTy.getNumElements() != 1)
-          return rewriter.notifyMatchFailure(
-              op, "most minor memref dim must have unit stride");
-      }
-
       if (memType.getRank() > 1)
         useDelinearization = true;
     }
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 59b13e300e5e5..1a5407e7f4752 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -360,3 +360,31 @@ func.func @gather_memref_2d_delinearize_nonzero_offsets(
       vector<2xi1>, vector<2xf32> into vector<2xf32>
   return %0 : vector<2xf32>
 }
+
+// -----
+
+// Verify that gather on a rank-1 strided memref (from rank-reducing subview)
+// is correctly scalarized to per-element vector.load ops.
+
+// CHECK-LABEL: @gather_rank1_strided_memref
+// CHECK-SAME:    (%[[BASE:.+]]: memref<4xf32, strided<[6]>>
+// CHECK-NOT:     vector.gather
+// CHECK:         vector.extract {{.*}}[0]
+// CHECK:         scf.if {{.*}} -> (vector<3xf32>)
+// CHECK:           vector.load %[[BASE]][%{{.*}}] : memref<4xf32, strided<[6]>>, vector<1xf32>
+// CHECK:         vector.extract {{.*}}[1]
+// CHECK:         scf.if {{.*}} -> (vector<3xf32>)
+// CHECK:           vector.load %[[BASE]][%{{.*}}] : memref<4xf32, strided<[6]>>, vector<1xf32>
+// CHECK:         vector.extract {{.*}}[2]
+// CHECK:         scf.if {{.*}} -> (vector<3xf32>)
+// CHECK:           vector.load %[[BASE]][%{{.*}}] : memref<4xf32, strided<[6]>>, vector<1xf32>
+func.func @gather_rank1_strided_memref(
+    %base: memref<4xf32, strided<[6]>>,
+    %v: vector<3xindex>, %mask: vector<3xi1>,
+    %pass_thru: vector<3xf32>) -> vector<3xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru
+    : memref<4xf32, strided<[6]>>, vector<3xindex>,
+      vector<3xi1>, vector<3xf32> into vector<3xf32>
+  return %0 : vector<3xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/189178

