[Mlir-commits] [mlir] [mlir][vector] Fix vector.gather lowering for strided memrefs. (PR #184706)

Han-Chung Wang llvmlistbot at llvm.org
Fri Mar 6 12:55:14 PST 2026


https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/184706

>From f62c8797eb4fef59296367c202b648693d6e2ca1 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Wed, 4 Mar 2026 15:58:17 -0800
Subject: [PATCH 1/4] [mlir][vector] Fix vector.gather lowering for strided
 memrefs.

The old implementation did not take strides into account, which leads to
wrong accesses. It is correct for contiguous memrefs, but not for strided
memrefs.

The revision fixes it by delinearizing the 1-D offset back into N-D
indices to shift the `baseOffsets` from the gather op.

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../Vector/Transforms/LowerVectorGather.cpp   | 65 +++++++++++++++++--
 .../Vector/vector-gather-lowering.mlir        | 35 ++++++++++
 2 files changed, 96 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 6bc8347bc6f76..7eec832a8c766 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -160,9 +161,41 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
   }
 };
 
+/// Returns true only if the given memref type is provably contiguous (dense
+/// row-major layout). Returns false if non-contiguous or if contiguity cannot
+/// be determined (e.g., dynamic dimensions/strides).
+static bool isContiguousMemRef(MemRefType memType) {
+  // Identity (default) layout is always dense row-major.
+  if (memType.getLayout().isIdentity())
+    return true;
+
+  // For explicit layouts, check if strides match contiguous.
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  if (failed(memType.getStridesAndOffset(strides, offset)))
+    return false; // Can't determine strides; assume non-contiguous.
+
+  int64_t expectedStride = 1;
+  for (int64_t d = memType.getRank() - 1; d >= 0; --d) {
+    int64_t dimSize = memType.getDimSize(d);
+    if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(strides[d]))
+      return false; // Can't prove contiguous; assume non-contiguous.
+    if (strides[d] != expectedStride)
+      return false;
+    expectedStride *= dimSize;
+  }
+  return true;
+}
+
 /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
 /// loads/extracts are made conditional using `scf.if` ops.
+///
+/// When the source memref is non-contiguous (e.g., from a `memref.subview`),
+/// the 1-D gather offset is delinearized back into N-D indices using the
+/// memref's shape. This is necessary because the gather offset was linearized
+/// assuming dense row-major layout; on strided memrefs, adding the linear
+/// offset to the last index does not correctly wrap to the next row.
 struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
   using Base::Base;
 
@@ -183,15 +216,21 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
     Value condMask = op.getMask();
     Value base = op.getBase();
 
-    // vector.load requires the most minor memref dim to have unit stride
-    // (unless reading exactly 1 element)
+    // Check if the source memref has non-contiguous strides, requiring
+    // delinearization of the gather indices.
+    bool needsDelinearization = false;
     if (auto memType = dyn_cast<MemRefType>(base.getType())) {
+      // vector.load requires the most minor memref dim to have unit stride
+      // (unless reading exactly 1 element).
       if (auto stridesAttr =
               dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
         if (stridesAttr.getStrides().back() != 1 &&
             resultTy.getNumElements() != 1)
           return failure();
       }
+
+      if (memType.getRank() > 1 && !isContiguousMemRef(memType))
+        needsDelinearization = true;
     }
 
     Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
@@ -200,6 +239,12 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
     auto baseOffsets = llvm::to_vector(op.getOffsets());
     Value lastBaseOffset = baseOffsets.back();
 
+    // Compute the delinearization basis once, outside the per-element loop.
+    SmallVector<Value> origBaseOffsets(baseOffsets);
+    SmallVector<OpFoldResult> basis;
+    if (needsDelinearization)
+      basis = memref::getMixedSizes(rewriter, loc, base);
+
     Value result = op.getPassThru();
     BoolAttr nontemporalAttr = nullptr;
     IntegerAttr alignmentAttr = op.getAlignmentAttr();
@@ -210,8 +255,20 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
       Value condition =
           vector::ExtractOp::create(rewriter, loc, condMask, thisIdx);
       Value index = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
-      baseOffsets.back() =
-          rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
+
+      if (needsDelinearization) {
+        // The gather offset was linearized using dense row-major order:
+        //   offset = idx[0] * dim[1] * ... * dim[n] + ... + idx[n]
+        // Recover the N-D indices by delinearizing with the memref shape.
+        auto delinOp = affine::AffineDelinearizeIndexOp::create(
+            rewriter, loc, index, basis, /*hasOuterBound=*/true);
+        for (int64_t d = 0, rank = baseOffsets.size(); d < rank; ++d)
+          baseOffsets[d] = rewriter.createOrFold<arith::AddIOp>(
+              loc, origBaseOffsets[d], delinOp.getResult(d));
+      } else {
+        baseOffsets.back() =
+            rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
+      }
 
       auto loadBuilder = [&](OpBuilder &b, Location loc) {
         Value extracted;
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index d4ff603c2b887..75ed5579ba7d7 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -289,3 +289,38 @@ func.func @scalable_gather_1d(%base: tensor<?xf32>, %v: vector<[2]xindex>, %mask
   %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<[2]xindex>, vector<[2]xi1>, vector<[2]xf32> into vector<[2]xf32>
   return %0 : vector<[2]xf32>
 }
+
+// Verify that gather on a strided 2D memref delinearizes the 1-D gather offset
+// back into N-D indices via affine.delinearize_index.
+// CHECK-LABEL: @gather_strided_memref_2d
+// CHECK-SAME:    (%[[BASE:.+]]: memref<4x2xf32, strided<[4, 1]>>,
+// CHECK-SAME:     %[[IDXVEC:.+]]: vector<4xi32>,
+// CHECK-SAME:     %[[MASK:.+]]: vector<4xi1>,
+// CHECK-SAME:     %[[PASS:.+]]: vector<4xf32>)
+// CHECK-DAG:     %[[IDXS:.+]] = arith.index_cast %[[IDXVEC]]
+//
+// First element: delinearize the scalar offset into 2D indices.
+// CHECK-DAG:     %[[M0:.+]]   = vector.extract %[[MASK]][0]
+// CHECK-DAG:     %[[IDX0:.+]] = vector.extract %[[IDXS]][0]
+// CHECK:         %[[DL0:.+]]:2 = affine.delinearize_index %[[IDX0]] into (4, 2)
+// CHECK:         scf.if %[[M0]]
+// CHECK:           vector.load %[[BASE]][%[[DL0]]#0, %[[DL0]]#1] : memref<4x2xf32, strided<[4, 1]>>, vector<1xf32>
+// CHECK:         else
+//
+// Remaining 3 elements follow the same pattern.
+// CHECK:         affine.delinearize_index %{{.+}} into (4, 2)
+// CHECK:         vector.load %[[BASE]][%{{.+}}, %{{.+}}] : memref<4x2xf32, strided<[4, 1]>>, vector<1xf32>
+// CHECK:         affine.delinearize_index %{{.+}} into (4, 2)
+// CHECK:         vector.load %[[BASE]][%{{.+}}, %{{.+}}] : memref<4x2xf32, strided<[4, 1]>>, vector<1xf32>
+// CHECK:         affine.delinearize_index %{{.+}} into (4, 2)
+// CHECK:         vector.load %[[BASE]][%{{.+}}, %{{.+}}] : memref<4x2xf32, strided<[4, 1]>>, vector<1xf32>
+func.func @gather_strided_memref_2d(
+    %base: memref<4x2xf32, strided<[4, 1]>>,
+    %v: vector<4xi32>, %mask: vector<4xi1>,
+    %pass_thru: vector<4xf32>) -> vector<4xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.gather %base[%c0, %c0][%v], %mask, %pass_thru
+    : memref<4x2xf32, strided<[4, 1]>>, vector<4xi32>,
+      vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %0 : vector<4xf32>
+}

>From 971ba6ef6f8084402ed842e159ff217d58e8f097 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 5 Mar 2026 15:50:28 -0800
Subject: [PATCH 2/4] address comments

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../Vector/Transforms/LowerVectorGather.cpp   | 39 ++++++++++--------
 .../Vector/vector-gather-lowering.mlir        | 40 ++++++++++++++-----
 .../Dialect/Vector/TestVectorTransforms.cpp   |  7 ++--
 3 files changed, 58 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 7eec832a8c766..279879d4e0378 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -191,11 +191,11 @@ static bool isContiguousMemRef(MemRefType memType) {
 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
 /// loads/extracts are made conditional using `scf.if` ops.
 ///
-/// When the source memref is non-contiguous (e.g., from a `memref.subview`),
-/// the 1-D gather offset is delinearized back into N-D indices using the
-/// memref's shape. This is necessary because the gather offset was linearized
-/// assuming dense row-major layout; on strided memrefs, adding the linear
-/// offset to the last index does not correctly wrap to the next row.
+/// When the source memref has a non-identity layout (e.g., from a
+/// `memref.subview`), the gather index is combined with the base offsets via
+/// linearize-then-delinearize to produce correct N-D load indices:
+///   flatIdx = linearize(baseOffsets) + gatherIndex
+///   loadIndices = delinearize(flatIdx, memrefShape)
 struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
   using Base::Base;
 
@@ -216,8 +216,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
     Value condMask = op.getMask();
     Value base = op.getBase();
 
-    // Check if the source memref has non-contiguous strides, requiring
-    // delinearization of the gather indices.
+    // Check if the source memref has a non-identity layout, requiring
+    // linearize/delinearize to compute correct N-D load indices.
     bool needsDelinearization = false;
     if (auto memType = dyn_cast<MemRefType>(base.getType())) {
       // vector.load requires the most minor memref dim to have unit stride
@@ -239,11 +239,15 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
     auto baseOffsets = llvm::to_vector(op.getOffsets());
     Value lastBaseOffset = baseOffsets.back();
 
-    // Compute the delinearization basis once, outside the per-element loop.
-    SmallVector<Value> origBaseOffsets(baseOffsets);
+    // Compute the basis (memref shape) and linearized base offset once,
+    // outside the per-element loop.
     SmallVector<OpFoldResult> basis;
-    if (needsDelinearization)
+    Value linearizedBase;
+    if (needsDelinearization) {
       basis = memref::getMixedSizes(rewriter, loc, base);
+      linearizedBase = affine::AffineLinearizeIndexOp::create(
+          rewriter, loc, baseOffsets, basis, /*disjoint=*/false);
+    }
 
     Value result = op.getPassThru();
     BoolAttr nontemporalAttr = nullptr;
@@ -257,14 +261,17 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
       Value index = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
 
       if (needsDelinearization) {
-        // The gather offset was linearized using dense row-major order:
-        //   offset = idx[0] * dim[1] * ... * dim[n] + ... + idx[n]
-        // Recover the N-D indices by delinearizing with the memref shape.
+        // The gather index offsets the innermost dimension. Combine with
+        // the base offsets by linearizing, adding the gather index, then
+        // delinearizing back to N-D indices:
+        //   flatIdx = linearize(baseOffsets, shape) + gatherIndex
+        //   loadIndices = delinearize(flatIdx, shape)
+        Value flatIdx =
+            rewriter.createOrFold<arith::AddIOp>(loc, linearizedBase, index);
         auto delinOp = affine::AffineDelinearizeIndexOp::create(
-            rewriter, loc, index, basis, /*hasOuterBound=*/true);
+            rewriter, loc, flatIdx, basis, /*hasOuterBound=*/true);
         for (int64_t d = 0, rank = baseOffsets.size(); d < rank; ++d)
-          baseOffsets[d] = rewriter.createOrFold<arith::AddIOp>(
-              loc, origBaseOffsets[d], delinOp.getResult(d));
+          baseOffsets[d] = delinOp.getResult(d);
       } else {
         baseOffsets.back() =
             rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 75ed5579ba7d7..5439c30f64f24 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -290,24 +290,18 @@ func.func @scalable_gather_1d(%base: tensor<?xf32>, %v: vector<[2]xindex>, %mask
   return %0 : vector<[2]xf32>
 }
 
-// Verify that gather on a strided 2D memref delinearizes the 1-D gather offset
-// back into N-D indices via affine.delinearize_index.
+// Verify that gather on a strided 2D memref with zero base offsets
+// delinearizes the gather index directly (linearize and addi fold away).
 // CHECK-LABEL: @gather_strided_memref_2d
 // CHECK-SAME:    (%[[BASE:.+]]: memref<4x2xf32, strided<[4, 1]>>,
 // CHECK-SAME:     %[[IDXVEC:.+]]: vector<4xi32>,
 // CHECK-SAME:     %[[MASK:.+]]: vector<4xi1>,
 // CHECK-SAME:     %[[PASS:.+]]: vector<4xf32>)
 // CHECK-DAG:     %[[IDXS:.+]] = arith.index_cast %[[IDXVEC]]
-//
-// First element: delinearize the scalar offset into 2D indices.
-// CHECK-DAG:     %[[M0:.+]]   = vector.extract %[[MASK]][0]
 // CHECK-DAG:     %[[IDX0:.+]] = vector.extract %[[IDXS]][0]
 // CHECK:         %[[DL0:.+]]:2 = affine.delinearize_index %[[IDX0]] into (4, 2)
-// CHECK:         scf.if %[[M0]]
+// CHECK:         scf.if
 // CHECK:           vector.load %[[BASE]][%[[DL0]]#0, %[[DL0]]#1] : memref<4x2xf32, strided<[4, 1]>>, vector<1xf32>
-// CHECK:         else
-//
-// Remaining 3 elements follow the same pattern.
 // CHECK:         affine.delinearize_index %{{.+}} into (4, 2)
 // CHECK:         vector.load %[[BASE]][%{{.+}}, %{{.+}}] : memref<4x2xf32, strided<[4, 1]>>, vector<1xf32>
 // CHECK:         affine.delinearize_index %{{.+}} into (4, 2)
@@ -324,3 +318,31 @@ func.func @gather_strided_memref_2d(
       vector<4xi1>, vector<4xf32> into vector<4xf32>
   return %0 : vector<4xf32>
 }
+
+// -----
+
+// Verify that gather with non-zero base offsets on a strided memref correctly
+// incorporates the base offsets via linearize + add + delinearize.
+// CHECK-LABEL: @gather_strided_memref_2d_nonzero_base
+// CHECK-SAME:    (%[[BASE:.+]]: memref<4x2xf32, strided<[4, 1]>>,
+// CHECK-SAME:     %[[OFF0:.+]]: index, %[[OFF1:.+]]: index,
+// CHECK-SAME:     %[[IDXVEC:.+]]: vector<2xi32>,
+// CHECK-SAME:     %[[MASK:.+]]: vector<2xi1>,
+// CHECK-SAME:     %[[PASS:.+]]: vector<2xf32>)
+// CHECK-DAG:     %[[IDXS:.+]] = arith.index_cast %[[IDXVEC]]
+// CHECK:         %[[LIN:.+]] = affine.linearize_index [%[[OFF0]], %[[OFF1]]] by (4, 2)
+// CHECK:         %[[IDX0:.+]] = vector.extract %[[IDXS]][0]
+// CHECK:         %[[FLAT:.+]] = arith.addi %[[LIN]], %[[IDX0]]
+// CHECK:         %[[DL:.+]]:2 = affine.delinearize_index %[[FLAT]] into (4, 2)
+// CHECK:         scf.if
+// CHECK:           vector.load %[[BASE]][%[[DL]]#0, %[[DL]]#1]
+func.func @gather_strided_memref_2d_nonzero_base(
+    %base: memref<4x2xf32, strided<[4, 1]>>,
+    %off0: index, %off1: index,
+    %v: vector<2xi32>, %mask: vector<2xi1>,
+    %pass_thru: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.gather %base[%off0, %off1][%v], %mask, %pass_thru
+    : memref<4x2xf32, strided<[4, 1]>>, vector<2xi32>,
+      vector<2xi1>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index dd2dfc4f3e441..7728ad3dd2ad1 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -783,9 +783,10 @@ struct TestVectorGatherLowering
            "loads";
   }
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<arith::ArithDialect, func::FuncDialect,
-                    memref::MemRefDialect, scf::SCFDialect,
-                    tensor::TensorDialect, vector::VectorDialect>();
+    registry.insert<affine::AffineDialect, arith::ArithDialect,
+                    func::FuncDialect, memref::MemRefDialect,
+                    scf::SCFDialect, tensor::TensorDialect,
+                    vector::VectorDialect>();
   }
 
   void runOnOperation() override {

>From e0eeba6ef1343eb5d04c6dbdf556242c411f833b Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 5 Mar 2026 17:32:34 -0800
Subject: [PATCH 3/4] clang-format

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 7728ad3dd2ad1..ff3520a286cc8 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -784,9 +784,8 @@ struct TestVectorGatherLowering
   }
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<affine::AffineDialect, arith::ArithDialect,
-                    func::FuncDialect, memref::MemRefDialect,
-                    scf::SCFDialect, tensor::TensorDialect,
-                    vector::VectorDialect>();
+                    func::FuncDialect, memref::MemRefDialect, scf::SCFDialect,
+                    tensor::TensorDialect, vector::VectorDialect>();
   }
 
   void runOnOperation() override {

>From 4f780da27b9459f63220a8f46eb724cceb8eead1 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 6 Mar 2026 12:54:46 -0800
Subject: [PATCH 4/4] Always use delinearization approach when rank > 1

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../Vector/Transforms/LowerVectorGather.cpp   | 48 +++++--------------
 .../Vector/vector-gather-lowering.mlir        |  8 ++--
 2 files changed, 16 insertions(+), 40 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 279879d4e0378..495fcaa114a2e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -161,40 +161,14 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
   }
 };
 
-/// Returns true only if the given memref type is provably contiguous (dense
-/// row-major layout). Returns false if non-contiguous or if contiguity cannot
-/// be determined (e.g., dynamic dimensions/strides).
-static bool isContiguousMemRef(MemRefType memType) {
-  // Identity (default) layout is always dense row-major.
-  if (memType.getLayout().isIdentity())
-    return true;
-
-  // For explicit layouts, check if strides match contiguous.
-  SmallVector<int64_t> strides;
-  int64_t offset;
-  if (failed(memType.getStridesAndOffset(strides, offset)))
-    return false; // Can't determine strides; assume non-contiguous.
-
-  int64_t expectedStride = 1;
-  for (int64_t d = memType.getRank() - 1; d >= 0; --d) {
-    int64_t dimSize = memType.getDimSize(d);
-    if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(strides[d]))
-      return false; // Can't prove contiguous; assume non-contiguous.
-    if (strides[d] != expectedStride)
-      return false;
-    expectedStride *= dimSize;
-  }
-  return true;
-}
-
 /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
 /// loads/extracts are made conditional using `scf.if` ops.
 ///
-/// When the source memref has a non-identity layout (e.g., from a
-/// `memref.subview`), the gather index is combined with the base offsets via
-/// linearize-then-delinearize to produce correct N-D load indices:
-///   flatIdx = linearize(baseOffsets) + gatherIndex
+/// For multi-dimensional memrefs (rank > 1), the gather index is combined
+/// with the base offsets via linearize-then-delinearize to produce correct
+/// N-D load indices:
+///   flatIdx = linearize(baseOffsets, memrefShape) + gatherIndex
 ///   loadIndices = delinearize(flatIdx, memrefShape)
 struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
   using Base::Base;
@@ -216,9 +190,9 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
     Value condMask = op.getMask();
     Value base = op.getBase();
 
-    // Check if the source memref has a non-identity layout, requiring
-    // linearize/delinearize to compute correct N-D load indices.
-    bool needsDelinearization = false;
+    // For multi-dimensional memrefs, use linearize+delinearize to compute
+    // correct N-D load indices from the 1-D gather offset.
+    bool useDelinearization = false;
     if (auto memType = dyn_cast<MemRefType>(base.getType())) {
       // vector.load requires the most minor memref dim to have unit stride
       // (unless reading exactly 1 element).
@@ -229,8 +203,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
           return failure();
       }
 
-      if (memType.getRank() > 1 && !isContiguousMemRef(memType))
-        needsDelinearization = true;
+      if (memType.getRank() > 1)
+        useDelinearization = true;
     }
 
     Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
@@ -243,7 +217,7 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
     // outside the per-element loop.
     SmallVector<OpFoldResult> basis;
     Value linearizedBase;
-    if (needsDelinearization) {
+    if (useDelinearization) {
       basis = memref::getMixedSizes(rewriter, loc, base);
       linearizedBase = affine::AffineLinearizeIndexOp::create(
           rewriter, loc, baseOffsets, basis, /*disjoint=*/false);
@@ -260,7 +234,7 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
           vector::ExtractOp::create(rewriter, loc, condMask, thisIdx);
       Value index = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
 
-      if (needsDelinearization) {
+      if (useDelinearization) {
         // The gather index offsets the innermost dimension. Combine with
         // the base offsets by linearizing, adding the gather index, then
         // delinearizing back to N-D indices:
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 5439c30f64f24..48a22958c0f67 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -54,11 +54,13 @@ func.func @gather_memref_1d_i32_index(%base: memref<?xf32>, %v: vector<2xi32>, %
 // CHECK-DAG:     %[[C0:.+]]    = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.+]]    = arith.constant 1 : index
 // CHECK-DAG:     [[PTV0:%.+]]  = vector.extract [[PASS]][0] : vector<3xf32> from vector<2x3xf32>
+// CHECK:         %[[LIN:.+]]   = affine.linearize_index [%[[C0]], %[[C1]]] by
 // CHECK-DAG:     [[M0:%.+]]    = vector.extract [[MASK]][0, 0] : i1 from vector<2x3xi1>
 // CHECK-DAG:     [[IDX0:%.+]]  = vector.extract [[IDXVEC]][0, 0] : index from vector<2x3xindex>
-// CHECK-NEXT:    %[[OFF0:.+]]  = arith.addi [[IDX0]], %[[C1]] : index
-// CHECK-NEXT:    [[RES0:%.+]]  = scf.if [[M0]] -> (vector<3xf32>)
-// CHECK-NEXT:      [[LD0:%.+]]   = vector.load [[BASE]][%[[C0]], %[[OFF0]]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:         %[[FLAT0:.+]] = arith.addi %[[LIN]], [[IDX0]] : index
+// CHECK:         %[[DL0:.+]]:2 = affine.delinearize_index %[[FLAT0]] into
+// CHECK:         [[RES0:%.+]]  = scf.if [[M0]] -> (vector<3xf32>)
+// CHECK-NEXT:      [[LD0:%.+]]   = vector.load [[BASE]][%[[DL0]]#0, %[[DL0]]#1] : memref<?x?xf32>, vector<1xf32>
 // CHECK-NEXT:      [[ELEM0:%.+]] = vector.extract [[LD0]][0] : f32 from vector<1xf32>
 // CHECK-NEXT:      [[INS0:%.+]]  = vector.insert [[ELEM0]], [[PTV0]] [0] : f32 into vector<3xf32>
 // CHECK-NEXT:      scf.yield [[INS0]] : vector<3xf32>



More information about the Mlir-commits mailing list