[Mlir-commits] [mlir] [mlir][vector] Remove unit-stride check in Gather1DToConditionalLoads (PR #189178)

Jorn Tuyls llvmlistbot at llvm.org
Sat Mar 28 09:57:38 PDT 2026


https://github.com/jtuyls created https://github.com/llvm/llvm-project/pull/189178

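Gather1DToConditionalLoads scalarizes the gather into per-element loads of type vector<1xelemTy>. vector.load only requires unit stride on the most minor memref dimension when it reads more than one element, so these loads are valid regardless of stride and the check was unnecessary. The check caused rank-1 strided gathers (from rank-reducing subviews of rank-3+ sources) to go unlowered and crash during LLVM translation.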
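
To illustrate why the stride does not matter for single-element loads, here is a minimal standalone C++ sketch of the scalarized gather. It is not the MLIR implementation: `StridedView1D`, `loadOne`, and `gatherScalarized` are hypothetical names introduced only for this example, and the rank-1 strided memref is modeled as a plain buffer with an offset and a stride.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of a rank-1 strided memref: logical element i lives at
// buffer[offset + i * stride].
struct StridedView1D {
  const float *buffer;
  std::int64_t offset;
  std::int64_t stride;
};

// Models a vector<1xf32> load. Only one element is read, so the stride is
// used to compute a single address and never to step between lanes.
inline float loadOne(const StridedView1D &view, std::int64_t index) {
  return view.buffer[view.offset + index * view.stride];
}

// Models the scalarized gather: one conditional single-element load per lane.
// Masked-off lanes keep their pass-through value.
inline std::vector<float>
gatherScalarized(const StridedView1D &view, std::int64_t base,
                 const std::vector<std::int64_t> &indices,
                 const std::vector<bool> &mask,
                 const std::vector<float> &passThru) {
  assert(indices.size() == mask.size() && indices.size() == passThru.size());
  std::vector<float> result = passThru;
  for (std::size_t lane = 0; lane < indices.size(); ++lane) {
    if (mask[lane])
      result[lane] = loadOne(view, base + indices[lane]);
  }
  return result;
}
```

With a stride of 6, as in the new test, each lane reads from an address six elements apart, yet every individual read touches exactly one element, so no contiguity is assumed.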

From 159948fc87eed0f06898497d679a8550b904b361 Mon Sep 17 00:00:00 2001
From: Jorn <jorn.tuyls at gmail.com>
Date: Sat, 28 Mar 2026 08:34:43 -0700
Subject: [PATCH] [mlir][vector] Remove unit-stride check in
 Gather1DToConditionalLoads

Each scalarized load is vector<1xelemTy>, valid regardless of stride.
The check caused rank-1 strided gathers (from rank-reducing subviews
of rank-3+ sources) to go unlowered and crash during LLVM translation.
---
 .../Vector/Transforms/LowerVectorGather.cpp   | 10 -------
 .../Vector/vector-gather-lowering.mlir        | 28 +++++++++++++++++++
 2 files changed, 28 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 7194d41d60df7..fd95ea0c39a54 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -195,16 +195,6 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
     // correct N-D load indices from the 1-D gather index.
     bool useDelinearization = false;
     if (auto memType = dyn_cast<MemRefType>(base.getType())) {
-      // vector.load requires the most minor memref dim to have unit stride
-      // (unless reading exactly 1 element).
-      if (auto stridesAttr =
-              dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
-        if (stridesAttr.getStrides().back() != 1 &&
-            resultTy.getNumElements() != 1)
-          return rewriter.notifyMatchFailure(
-              op, "most minor memref dim must have unit stride");
-      }
-
       if (memType.getRank() > 1)
         useDelinearization = true;
     }
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 59b13e300e5e5..1a5407e7f4752 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -360,3 +360,31 @@ func.func @gather_memref_2d_delinearize_nonzero_offsets(
       vector<2xi1>, vector<2xf32> into vector<2xf32>
   return %0 : vector<2xf32>
 }
+
+// -----
+
+// Verify that gather on a rank-1 strided memref (from rank-reducing subview)
+// is correctly scalarized to per-element vector.load ops.
+
+// CHECK-LABEL: @gather_rank1_strided_memref
+// CHECK-SAME:    (%[[BASE:.+]]: memref<4xf32, strided<[6]>>
+// CHECK-NOT:     vector.gather
+// CHECK:         vector.extract {{.*}}[0]
+// CHECK:         scf.if {{.*}} -> (vector<3xf32>)
+// CHECK:           vector.load %[[BASE]][%{{.*}}] : memref<4xf32, strided<[6]>>, vector<1xf32>
+// CHECK:         vector.extract {{.*}}[1]
+// CHECK:         scf.if {{.*}} -> (vector<3xf32>)
+// CHECK:           vector.load %[[BASE]][%{{.*}}] : memref<4xf32, strided<[6]>>, vector<1xf32>
+// CHECK:         vector.extract {{.*}}[2]
+// CHECK:         scf.if {{.*}} -> (vector<3xf32>)
+// CHECK:           vector.load %[[BASE]][%{{.*}}] : memref<4xf32, strided<[6]>>, vector<1xf32>
+func.func @gather_rank1_strided_memref(
+    %base: memref<4xf32, strided<[6]>>,
+    %v: vector<3xindex>, %mask: vector<3xi1>,
+    %pass_thru: vector<3xf32>) -> vector<3xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru
+    : memref<4xf32, strided<[6]>>, vector<3xindex>,
+      vector<3xi1>, vector<3xf32> into vector<3xf32>
+  return %0 : vector<3xf32>
+}


