[Mlir-commits] [mlir] 305dc4e - [mlir][vector] Lower vector.gather with delinearization approach (#184706)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 13 14:58:44 PDT 2026


Author: Han-Chung Wang
Date: 2026-03-13T14:58:40-07:00
New Revision: 305dc4e5a9a623b8d1effff5dddcd5fdfe56d6a3

URL: https://github.com/llvm/llvm-project/commit/305dc4e5a9a623b8d1effff5dddcd5fdfe56d6a3
DIFF: https://github.com/llvm/llvm-project/commit/305dc4e5a9a623b8d1effff5dddcd5fdfe56d6a3.diff

LOG: [mlir][vector] Lower vector.gather with delinearization approach (#184706)

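The old implementation did not handle n-D memrefs correctly, which led
to wrong accesses: the gather index was added only to the innermost
base offset, so it never carried over into the outer dimensions. For example,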

```
 func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask: vector<2x3xi1>, %pass_thru: vector<2x3xf32>) -> vector<2x3xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<2x3xindex>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
  return %0 : vector<2x3xf32>
 }
```

is lowered to

```
  func.func @gather_memref_2d(%arg0: memref<?x?xf32>, %arg1: vector<2x3xindex>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %0 = ub.poison : vector<2x3xf32>
    %1 = vector.extract %arg3[0] : vector<3xf32> from vector<2x3xf32>
    %2 = vector.extract %arg2[0, 0] : i1 from vector<2x3xi1>
    %3 = vector.extract %arg1[0, 0] : index from vector<2x3xindex>
    %4 = arith.addi %3, %c1 : index
    %5 = scf.if %2 -> (vector<3xf32>) {
      %29 = vector.load %arg0[%c0, %4] : memref<?x?xf32>, vector<1xf32>
      %30 = vector.extract %29[0] : f32 from vector<1xf32>
      %31 = vector.insert %30, %1 [0] : f32 into vector<3xf32>
      scf.yield %31 : vector<3xf32>
    } else {
      scf.yield %1 : vector<3xf32>
    }
    // ...
```

This revision fixes it by using `linearize(baseOffsets) + gatherIndex`
followed by `delinearize` to recover the correct `n-D` load indices. This
is applied unconditionally to all memrefs of rank > 1.

Note that this also enables the strided memref cases, because we use the
delinearization approach.

---------

Signed-off-by: hanhanW <hanhan0912 at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
    mlir/test/Dialect/Vector/vector-gather-lowering.mlir
    mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 6bc8347bc6f76..7194d41d60df7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -163,6 +164,13 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
 /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
 /// loads/extracts are made conditional using `scf.if` ops.
+///
+/// For multi-dimensional memrefs (rank > 1), the gather index is combined
+/// with the offsets via linearize-then-delinearize to produce correct
+/// N-D load indices:
+///   idx = indices[i]
+///   flatIdx = linearize(offsets, memrefShape) + idx
+///   loadIndices = delinearize(flatIdx, memrefShape)
 struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
   using Base::Base;
 
@@ -183,22 +191,39 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
     Value condMask = op.getMask();
     Value base = op.getBase();
 
-    // vector.load requires the most minor memref dim to have unit stride
-    // (unless reading exactly 1 element)
+    // For multi-dimensional memrefs, use linearize+delinearize to compute
+    // correct N-D load indices from the 1-D gather index.
+    bool useDelinearization = false;
     if (auto memType = dyn_cast<MemRefType>(base.getType())) {
+      // vector.load requires the most minor memref dim to have unit stride
+      // (unless reading exactly 1 element).
       if (auto stridesAttr =
               dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
         if (stridesAttr.getStrides().back() != 1 &&
             resultTy.getNumElements() != 1)
-          return failure();
+          return rewriter.notifyMatchFailure(
+              op, "most minor memref dim must have unit stride");
       }
+
+      if (memType.getRank() > 1)
+        useDelinearization = true;
     }
 
     Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
         loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
         op.getIndices());
-    auto baseOffsets = llvm::to_vector(op.getOffsets());
-    Value lastBaseOffset = baseOffsets.back();
+    auto loadOffsets = llvm::to_vector(op.getOffsets());
+    Value lastLoadOffset = loadOffsets.back();
+
+    // Compute the memref shape and linearized offsets once, outside the
+    // per-element loop.
+    SmallVector<OpFoldResult> baseShape;
+    Value linearizedOffsets;
+    if (useDelinearization) {
+      baseShape = memref::getMixedSizes(rewriter, loc, base);
+      linearizedOffsets = affine::AffineLinearizeIndexOp::create(
+          rewriter, loc, loadOffsets, baseShape, /*disjoint=*/false);
+    }
 
     Value result = op.getPassThru();
     BoolAttr nontemporalAttr = nullptr;
@@ -210,8 +235,23 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
       Value condition =
           vector::ExtractOp::create(rewriter, loc, condMask, thisIdx);
       Value index = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
-      baseOffsets.back() =
-          rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
+
+      if (useDelinearization) {
+        // The gather index offsets the innermost dimension. Combine with
+        // the offsets by linearizing, adding the gather index, then
+        // delinearizing back to N-D indices:
+        //   flatIdx = linearize(offsets, shape) + idx
+        //   loadIndices = delinearize(flatIdx, shape)
+        Value flatIdx =
+            rewriter.createOrFold<arith::AddIOp>(loc, linearizedOffsets, index);
+        auto delinOp = affine::AffineDelinearizeIndexOp::create(
+            rewriter, loc, flatIdx, baseShape, /*hasOuterBound=*/true);
+        for (int64_t d = 0, rank = loadOffsets.size(); d < rank; ++d)
+          loadOffsets[d] = delinOp.getResult(d);
+      } else {
+        loadOffsets.back() =
+            rewriter.createOrFold<arith::AddIOp>(loc, lastLoadOffset, index);
+      }
 
       auto loadBuilder = [&](OpBuilder &b, Location loc) {
         Value extracted;
@@ -219,12 +259,12 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
           // `vector.load` does not support scalar result; emit a vector load
           // and extract the single result instead.
           Value load =
-              vector::LoadOp::create(b, loc, elemVecTy, base, baseOffsets,
+              vector::LoadOp::create(b, loc, elemVecTy, base, loadOffsets,
                                      nontemporalAttr, alignmentAttr);
           int64_t zeroIdx[1] = {0};
           extracted = vector::ExtractOp::create(b, loc, load, zeroIdx);
         } else {
-          extracted = tensor::ExtractOp::create(b, loc, base, baseOffsets);
+          extracted = tensor::ExtractOp::create(b, loc, base, loadOffsets);
         }
 
         Value newResult =

diff  --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index d4ff603c2b887..59b13e300e5e5 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -54,11 +54,13 @@ func.func @gather_memref_1d_i32_index(%base: memref<?xf32>, %v: vector<2xi32>, %
 // CHECK-DAG:     %[[C0:.+]]    = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.+]]    = arith.constant 1 : index
 // CHECK-DAG:     [[PTV0:%.+]]  = vector.extract [[PASS]][0] : vector<3xf32> from vector<2x3xf32>
+// CHECK:         %[[LIN:.+]]   = affine.linearize_index [%[[C0]], %[[C1]]] by
 // CHECK-DAG:     [[M0:%.+]]    = vector.extract [[MASK]][0, 0] : i1 from vector<2x3xi1>
 // CHECK-DAG:     [[IDX0:%.+]]  = vector.extract [[IDXVEC]][0, 0] : index from vector<2x3xindex>
-// CHECK-NEXT:    %[[OFF0:.+]]  = arith.addi [[IDX0]], %[[C1]] : index
-// CHECK-NEXT:    [[RES0:%.+]]  = scf.if [[M0]] -> (vector<3xf32>)
-// CHECK-NEXT:      [[LD0:%.+]]   = vector.load [[BASE]][%[[C0]], %[[OFF0]]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:         %[[FLAT0:.+]] = arith.addi %[[LIN]], [[IDX0]] : index
+// CHECK:         %[[DL0:.+]]:2 = affine.delinearize_index %[[FLAT0]] into
+// CHECK:         [[RES0:%.+]]  = scf.if [[M0]] -> (vector<3xf32>)
+// CHECK-NEXT:      [[LD0:%.+]]   = vector.load [[BASE]][%[[DL0]]#0, %[[DL0]]#1] : memref<?x?xf32>, vector<1xf32>
 // CHECK-NEXT:      [[ELEM0:%.+]] = vector.extract [[LD0]][0] : f32 from vector<1xf32>
 // CHECK-NEXT:      [[INS0:%.+]]  = vector.insert [[ELEM0]], [[PTV0]] [0] : f32 into vector<3xf32>
 // CHECK-NEXT:      scf.yield [[INS0]] : vector<3xf32>
@@ -289,3 +291,72 @@ func.func @scalable_gather_1d(%base: tensor<?xf32>, %v: vector<[2]xindex>, %mask
   %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<[2]xindex>, vector<[2]xi1>, vector<[2]xf32> into vector<[2]xf32>
   return %0 : vector<[2]xf32>
 }
+
+// Verify that gather on a 2D memref delinearizes the gather index.
+// With zero base offsets, the linearize and addi fold away.
+
+// CHECK-LABEL: @gather_memref_2d_delinearize
+// CHECK-SAME:    (%[[BASE:.+]]: memref<4x2xf32>,
+// CHECK-SAME:     %[[IDXVEC:.+]]: vector<4xi32>,
+// CHECK-SAME:     %[[MASK:.+]]: vector<4xi1>,
+// CHECK-SAME:     %[[PASS:.+]]: vector<4xf32>)
+// CHECK-DAG:     %[[IDXS:.+]] = arith.index_cast %[[IDXVEC]]
+//
+// CHECK-DAG:     %[[IDX0:.+]] = vector.extract %[[IDXS]][0]
+// CHECK:         %[[DL0:.+]]:2 = affine.delinearize_index %[[IDX0]] into (4, 2)
+// CHECK:         scf.if
+// CHECK:           vector.load %[[BASE]][%[[DL0]]#0, %[[DL0]]#1] : memref<4x2xf32>, vector<1xf32>
+//
+// CHECK:         %[[IDX1:.+]] = vector.extract %[[IDXS]][1]
+// CHECK:         affine.delinearize_index %[[IDX1]] into (4, 2)
+// CHECK:         scf.if
+// CHECK:           vector.load %[[BASE]][%{{.+}}, %{{.+}}] : memref<4x2xf32>, vector<1xf32>
+//
+// CHECK:         %[[IDX2:.+]] = vector.extract %[[IDXS]][2]
+// CHECK:         affine.delinearize_index %[[IDX2]] into (4, 2)
+// CHECK:         scf.if
+// CHECK:           vector.load %[[BASE]][%{{.+}}, %{{.+}}] : memref<4x2xf32>, vector<1xf32>
+//
+// CHECK:         %[[IDX3:.+]] = vector.extract %[[IDXS]][3]
+// CHECK:         affine.delinearize_index %[[IDX3]] into (4, 2)
+// CHECK:         scf.if
+// CHECK:           vector.load %[[BASE]][%{{.+}}, %{{.+}}] : memref<4x2xf32>, vector<1xf32>
+func.func @gather_memref_2d_delinearize(
+    %base: memref<4x2xf32>,
+    %v: vector<4xi32>, %mask: vector<4xi1>,
+    %pass_thru: vector<4xf32>) -> vector<4xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.gather %base[%c0, %c0][%v], %mask, %pass_thru
+    : memref<4x2xf32>, vector<4xi32>,
+      vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+// Verify that gather on a 2D memref with non-zero base offsets correctly
+// incorporates the offsets via linearize + add + delinearize.
+
+// CHECK-LABEL: @gather_memref_2d_delinearize_nonzero_offsets
+// CHECK-SAME:    (%[[BASE:.+]]: memref<4x2xf32>,
+// CHECK-SAME:     %[[OFF0:.+]]: index, %[[OFF1:.+]]: index,
+// CHECK-SAME:     %[[IDXVEC:.+]]: vector<2xi32>,
+// CHECK-SAME:     %[[MASK:.+]]: vector<2xi1>,
+// CHECK-SAME:     %[[PASS:.+]]: vector<2xf32>)
+// CHECK-DAG:     %[[IDXS:.+]] = arith.index_cast %[[IDXVEC]]
+// CHECK:         %[[LIN:.+]] = affine.linearize_index [%[[OFF0]], %[[OFF1]]] by (4, 2)
+// CHECK:         %[[IDX0:.+]] = vector.extract %[[IDXS]][0]
+// CHECK:         %[[FLAT:.+]] = arith.addi %[[LIN]], %[[IDX0]]
+// CHECK:         %[[DL:.+]]:2 = affine.delinearize_index %[[FLAT]] into (4, 2)
+// CHECK:         scf.if
+// CHECK:           vector.load %[[BASE]][%[[DL]]#0, %[[DL]]#1]
+func.func @gather_memref_2d_delinearize_nonzero_offsets(
+    %base: memref<4x2xf32>,
+    %off0: index, %off1: index,
+    %v: vector<2xi32>, %mask: vector<2xi1>,
+    %pass_thru: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.gather %base[%off0, %off1][%v], %mask, %pass_thru
+    : memref<4x2xf32>, vector<2xi32>,
+      vector<2xi1>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}

diff  --git a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
index 148891f3f8d20..94205a6c26ba2 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
@@ -164,11 +164,14 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
 
 // First shuffle + if ladder for row 0
 // CHECK: %[[ROW0_INIT:.*]] = vector.shuffle %[[PASS_CAST]], %[[PASS_CAST]] [0, 1, 2]
+// CHECK: %[[DIM0:.*]] = memref.dim %[[BASE]], %[[C0]]
+// CHECK: %[[DIM1:.*]] = memref.dim %[[BASE]], %[[C1]]
 // CHECK: %[[MASK_0_0:.*]] = vector.extract %[[MASK]][0, 0]
 // CHECK: %[[IDX_0_0:.*]] = vector.extract %[[IDX]][0, 0]
 // CHECK: %[[OFF_0_0:.*]] = arith.addi %[[IDX_0_0]], %[[C1]]
+// CHECK: %[[DL_0_0:.*]]:2 = affine.delinearize_index %[[OFF_0_0]] into (%[[DIM0]], %[[DIM1]])
 // CHECK: %[[IF_0_0:.*]] = scf.if %[[MASK_0_0]] -> (vector<3xf32>) {
-// CHECK:   %[[LOAD_0_0:.*]] = vector.load %[[BASE]][%[[C0]], %[[OFF_0_0]]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:   %[[LOAD_0_0:.*]] = vector.load %[[BASE]][%[[DL_0_0]]#0, %[[DL_0_0]]#1] : memref<?x?xf32>, vector<1xf32>
 // CHECK:   %[[ELEM_0_0:.*]] = vector.extract %[[LOAD_0_0]][0] : f32
 // CHECK:   %[[INS_0_0:.*]] = vector.insert %[[ELEM_0_0]], %[[ROW0_INIT]] [0] : f32 into vector<3xf32>
 // CHECK:   scf.yield %[[INS_0_0]] : vector<3xf32>
@@ -179,6 +182,7 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
 // CHECK: %[[MASK_0_1:.*]] = vector.extract %[[MASK]][0, 1]
 // CHECK: %[[IDX_0_1:.*]] = vector.extract %[[IDX]][0, 1]
 // CHECK: %[[OFF_0_1:.*]] = arith.addi %[[IDX_0_1]], %[[C1]]
+// CHECK: %[[DL_0_1:.*]]:2 = affine.delinearize_index %[[OFF_0_1]] into (%[[DIM0]], %[[DIM1]])
 // CHECK: %[[IF_0_1:.*]] = scf.if %[[MASK_0_1]] -> (vector<3xf32>)
 
 // … (similar checks for the rest of row 0, then row 1)
@@ -190,6 +194,7 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
 // CHECK: %[[MASK_1_0:.*]] = vector.extract %[[MASK]][1, 0]
 // CHECK: %[[IDX_1_0:.*]] = vector.extract %[[IDX]][1, 0]
 // CHECK: %[[OFF_1_0:.*]] = arith.addi %[[IDX_1_0]], %[[C1]]
+// CHECK: %[[DL_1_0:.*]]:2 = affine.delinearize_index %[[OFF_1_0]] into
 // CHECK: %[[IF_1_0:.*]] = scf.if %[[MASK_1_0]] -> (vector<3xf32>)
 
 // … (similar checks for remaining row 1 inserts)

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index dd2dfc4f3e441..ff3520a286cc8 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -783,8 +783,8 @@ struct TestVectorGatherLowering
            "loads";
   }
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<arith::ArithDialect, func::FuncDialect,
-                    memref::MemRefDialect, scf::SCFDialect,
+    registry.insert<affine::AffineDialect, arith::ArithDialect,
+                    func::FuncDialect, memref::MemRefDialect, scf::SCFDialect,
                     tensor::TensorDialect, vector::VectorDialect>();
   }
 


        


More information about the Mlir-commits mailing list