[Mlir-commits] [mlir] 305dc4e - [mlir][vector] Lower vector.gather with delinearization approach (#184706)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 13 14:58:44 PDT 2026
Author: Han-Chung Wang
Date: 2026-03-13T14:58:40-07:00
New Revision: 305dc4e5a9a623b8d1effff5dddcd5fdfe56d6a3
URL: https://github.com/llvm/llvm-project/commit/305dc4e5a9a623b8d1effff5dddcd5fdfe56d6a3
DIFF: https://github.com/llvm/llvm-project/commit/305dc4e5a9a623b8d1effff5dddcd5fdfe56d6a3.diff
LOG: [mlir][vector] Lower vector.gather with delinearization approach (#184706)
The old implementation did not handle n-D memrefs correctly, which led
to wrong memory accesses. E.g.,
```
func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask: vector<2x3xi1>, %pass_thru: vector<2x3xf32>) -> vector<2x3xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<2x3xindex>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
return %0 : vector<2x3xf32>
}
```
is lowered to
```
func.func @gather_memref_2d(%arg0: memref<?x?xf32>, %arg1: vector<2x3xindex>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = ub.poison : vector<2x3xf32>
%1 = vector.extract %arg3[0] : vector<3xf32> from vector<2x3xf32>
%2 = vector.extract %arg2[0, 0] : i1 from vector<2x3xi1>
%3 = vector.extract %arg1[0, 0] : index from vector<2x3xindex>
%4 = arith.addi %3, %c1 : index
%5 = scf.if %2 -> (vector<3xf32>) {
%29 = vector.load %arg0[%c0, %4] : memref<?x?xf32>, vector<1xf32>
%30 = vector.extract %29[0] : f32 from vector<1xf32>
%31 = vector.insert %30, %1 [0] : f32 into vector<3xf32>
scf.yield %31 : vector<3xf32>
} else {
scf.yield %1 : vector<3xf32>
}
// ...
```
The revision fixes it by using `linearize(baseOffsets) + gatherIndex`
followed by `delinearize` to recover correct `n-D` load indices. This is
applied unconditionally for all rank > 1 memrefs.
Note that it also enables the cases with strides because we use the
delinearization approach.
---------
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
mlir/test/Dialect/Vector/vector-gather-lowering.mlir
mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 6bc8347bc6f76..7194d41d60df7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -163,6 +164,13 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
+///
+/// For multi-dimensional memrefs (rank > 1), the gather index is combined
+/// with the offsets via linearize-then-delinearize to produce correct
+/// N-D load indices:
+/// idx = indices[i]
+/// flatIdx = linearize(offsets, memrefShape) + idx
+/// loadIndices = delinearize(flatIdx, memrefShape)
struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
using Base::Base;
@@ -183,22 +191,39 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
Value condMask = op.getMask();
Value base = op.getBase();
- // vector.load requires the most minor memref dim to have unit stride
- // (unless reading exactly 1 element)
+ // For multi-dimensional memrefs, use linearize+delinearize to compute
+ // correct N-D load indices from the 1-D gather index.
+ bool useDelinearization = false;
if (auto memType = dyn_cast<MemRefType>(base.getType())) {
+ // vector.load requires the most minor memref dim to have unit stride
+ // (unless reading exactly 1 element).
if (auto stridesAttr =
dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
if (stridesAttr.getStrides().back() != 1 &&
resultTy.getNumElements() != 1)
- return failure();
+ return rewriter.notifyMatchFailure(
+ op, "most minor memref dim must have unit stride");
}
+
+ if (memType.getRank() > 1)
+ useDelinearization = true;
}
Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
op.getIndices());
- auto baseOffsets = llvm::to_vector(op.getOffsets());
- Value lastBaseOffset = baseOffsets.back();
+ auto loadOffsets = llvm::to_vector(op.getOffsets());
+ Value lastLoadOffset = loadOffsets.back();
+
+ // Compute the memref shape and linearized offsets once, outside the
+ // per-element loop.
+ SmallVector<OpFoldResult> baseShape;
+ Value linearizedOffsets;
+ if (useDelinearization) {
+ baseShape = memref::getMixedSizes(rewriter, loc, base);
+ linearizedOffsets = affine::AffineLinearizeIndexOp::create(
+ rewriter, loc, loadOffsets, baseShape, /*disjoint=*/false);
+ }
Value result = op.getPassThru();
BoolAttr nontemporalAttr = nullptr;
@@ -210,8 +235,23 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
Value condition =
vector::ExtractOp::create(rewriter, loc, condMask, thisIdx);
Value index = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
- baseOffsets.back() =
- rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
+
+ if (useDelinearization) {
+ // The gather index offsets the innermost dimension. Combine with
+ // the offsets by linearizing, adding the gather index, then
+ // delinearizing back to N-D indices:
+ // flatIdx = linearize(offsets, shape) + idx
+ // loadIndices = delinearize(flatIdx, shape)
+ Value flatIdx =
+ rewriter.createOrFold<arith::AddIOp>(loc, linearizedOffsets, index);
+ auto delinOp = affine::AffineDelinearizeIndexOp::create(
+ rewriter, loc, flatIdx, baseShape, /*hasOuterBound=*/true);
+ for (int64_t d = 0, rank = loadOffsets.size(); d < rank; ++d)
+ loadOffsets[d] = delinOp.getResult(d);
+ } else {
+ loadOffsets.back() =
+ rewriter.createOrFold<arith::AddIOp>(loc, lastLoadOffset, index);
+ }
auto loadBuilder = [&](OpBuilder &b, Location loc) {
Value extracted;
@@ -219,12 +259,12 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
// `vector.load` does not support scalar result; emit a vector load
// and extract the single result instead.
Value load =
- vector::LoadOp::create(b, loc, elemVecTy, base, baseOffsets,
+ vector::LoadOp::create(b, loc, elemVecTy, base, loadOffsets,
nontemporalAttr, alignmentAttr);
int64_t zeroIdx[1] = {0};
extracted = vector::ExtractOp::create(b, loc, load, zeroIdx);
} else {
- extracted = tensor::ExtractOp::create(b, loc, base, baseOffsets);
+ extracted = tensor::ExtractOp::create(b, loc, base, loadOffsets);
}
Value newResult =
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index d4ff603c2b887..59b13e300e5e5 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -54,11 +54,13 @@ func.func @gather_memref_1d_i32_index(%base: memref<?xf32>, %v: vector<2xi32>, %
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: [[PTV0:%.+]] = vector.extract [[PASS]][0] : vector<3xf32> from vector<2x3xf32>
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[C0]], %[[C1]]] by
// CHECK-DAG: [[M0:%.+]] = vector.extract [[MASK]][0, 0] : i1 from vector<2x3xi1>
// CHECK-DAG: [[IDX0:%.+]] = vector.extract [[IDXVEC]][0, 0] : index from vector<2x3xindex>
-// CHECK-NEXT: %[[OFF0:.+]] = arith.addi [[IDX0]], %[[C1]] : index
-// CHECK-NEXT: [[RES0:%.+]] = scf.if [[M0]] -> (vector<3xf32>)
-// CHECK-NEXT: [[LD0:%.+]] = vector.load [[BASE]][%[[C0]], %[[OFF0]]] : memref<?x?xf32>, vector<1xf32>
+// CHECK: %[[FLAT0:.+]] = arith.addi %[[LIN]], [[IDX0]] : index
+// CHECK: %[[DL0:.+]]:2 = affine.delinearize_index %[[FLAT0]] into
+// CHECK: [[RES0:%.+]] = scf.if [[M0]] -> (vector<3xf32>)
+// CHECK-NEXT: [[LD0:%.+]] = vector.load [[BASE]][%[[DL0]]#0, %[[DL0]]#1] : memref<?x?xf32>, vector<1xf32>
// CHECK-NEXT: [[ELEM0:%.+]] = vector.extract [[LD0]][0] : f32 from vector<1xf32>
// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[ELEM0]], [[PTV0]] [0] : f32 into vector<3xf32>
// CHECK-NEXT: scf.yield [[INS0]] : vector<3xf32>
@@ -289,3 +291,72 @@ func.func @scalable_gather_1d(%base: tensor<?xf32>, %v: vector<[2]xindex>, %mask
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<[2]xindex>, vector<[2]xi1>, vector<[2]xf32> into vector<[2]xf32>
return %0 : vector<[2]xf32>
}
+
+// Verify that gather on a 2D memref delinearizes the gather index.
+// With zero base offsets, the linearize and addi fold away.
+
+// CHECK-LABEL: @gather_memref_2d_delinearize
+// CHECK-SAME: (%[[BASE:.+]]: memref<4x2xf32>,
+// CHECK-SAME: %[[IDXVEC:.+]]: vector<4xi32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<4xi1>,
+// CHECK-SAME: %[[PASS:.+]]: vector<4xf32>)
+// CHECK-DAG: %[[IDXS:.+]] = arith.index_cast %[[IDXVEC]]
+//
+// CHECK-DAG: %[[IDX0:.+]] = vector.extract %[[IDXS]][0]
+// CHECK: %[[DL0:.+]]:2 = affine.delinearize_index %[[IDX0]] into (4, 2)
+// CHECK: scf.if
+// CHECK: vector.load %[[BASE]][%[[DL0]]#0, %[[DL0]]#1] : memref<4x2xf32>, vector<1xf32>
+//
+// CHECK: %[[IDX1:.+]] = vector.extract %[[IDXS]][1]
+// CHECK: affine.delinearize_index %[[IDX1]] into (4, 2)
+// CHECK: scf.if
+// CHECK: vector.load %[[BASE]][%{{.+}}, %{{.+}}] : memref<4x2xf32>, vector<1xf32>
+//
+// CHECK: %[[IDX2:.+]] = vector.extract %[[IDXS]][2]
+// CHECK: affine.delinearize_index %[[IDX2]] into (4, 2)
+// CHECK: scf.if
+// CHECK: vector.load %[[BASE]][%{{.+}}, %{{.+}}] : memref<4x2xf32>, vector<1xf32>
+//
+// CHECK: %[[IDX3:.+]] = vector.extract %[[IDXS]][3]
+// CHECK: affine.delinearize_index %[[IDX3]] into (4, 2)
+// CHECK: scf.if
+// CHECK: vector.load %[[BASE]][%{{.+}}, %{{.+}}] : memref<4x2xf32>, vector<1xf32>
+func.func @gather_memref_2d_delinearize(
+ %base: memref<4x2xf32>,
+ %v: vector<4xi32>, %mask: vector<4xi1>,
+ %pass_thru: vector<4xf32>) -> vector<4xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.gather %base[%c0, %c0][%v], %mask, %pass_thru
+ : memref<4x2xf32>, vector<4xi32>,
+ vector<4xi1>, vector<4xf32> into vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// -----
+
+// Verify that gather on a 2D memref with non-zero base offsets correctly
+// incorporates the offsets via linearize + add + delinearize.
+
+// CHECK-LABEL: @gather_memref_2d_delinearize_nonzero_offsets
+// CHECK-SAME: (%[[BASE:.+]]: memref<4x2xf32>,
+// CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index,
+// CHECK-SAME: %[[IDXVEC:.+]]: vector<2xi32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2xi1>,
+// CHECK-SAME: %[[PASS:.+]]: vector<2xf32>)
+// CHECK-DAG: %[[IDXS:.+]] = arith.index_cast %[[IDXVEC]]
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[OFF0]], %[[OFF1]]] by (4, 2)
+// CHECK: %[[IDX0:.+]] = vector.extract %[[IDXS]][0]
+// CHECK: %[[FLAT:.+]] = arith.addi %[[LIN]], %[[IDX0]]
+// CHECK: %[[DL:.+]]:2 = affine.delinearize_index %[[FLAT]] into (4, 2)
+// CHECK: scf.if
+// CHECK: vector.load %[[BASE]][%[[DL]]#0, %[[DL]]#1]
+func.func @gather_memref_2d_delinearize_nonzero_offsets(
+ %base: memref<4x2xf32>,
+ %off0: index, %off1: index,
+ %v: vector<2xi32>, %mask: vector<2xi1>,
+ %pass_thru: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.gather %base[%off0, %off1][%v], %mask, %pass_thru
+ : memref<4x2xf32>, vector<2xi32>,
+ vector<2xi1>, vector<2xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
index 148891f3f8d20..94205a6c26ba2 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
@@ -164,11 +164,14 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
// First shuffle + if ladder for row 0
// CHECK: %[[ROW0_INIT:.*]] = vector.shuffle %[[PASS_CAST]], %[[PASS_CAST]] [0, 1, 2]
+// CHECK: %[[DIM0:.*]] = memref.dim %[[BASE]], %[[C0]]
+// CHECK: %[[DIM1:.*]] = memref.dim %[[BASE]], %[[C1]]
// CHECK: %[[MASK_0_0:.*]] = vector.extract %[[MASK]][0, 0]
// CHECK: %[[IDX_0_0:.*]] = vector.extract %[[IDX]][0, 0]
// CHECK: %[[OFF_0_0:.*]] = arith.addi %[[IDX_0_0]], %[[C1]]
+// CHECK: %[[DL_0_0:.*]]:2 = affine.delinearize_index %[[OFF_0_0]] into (%[[DIM0]], %[[DIM1]])
// CHECK: %[[IF_0_0:.*]] = scf.if %[[MASK_0_0]] -> (vector<3xf32>) {
-// CHECK: %[[LOAD_0_0:.*]] = vector.load %[[BASE]][%[[C0]], %[[OFF_0_0]]] : memref<?x?xf32>, vector<1xf32>
+// CHECK: %[[LOAD_0_0:.*]] = vector.load %[[BASE]][%[[DL_0_0]]#0, %[[DL_0_0]]#1] : memref<?x?xf32>, vector<1xf32>
// CHECK: %[[ELEM_0_0:.*]] = vector.extract %[[LOAD_0_0]][0] : f32
// CHECK: %[[INS_0_0:.*]] = vector.insert %[[ELEM_0_0]], %[[ROW0_INIT]] [0] : f32 into vector<3xf32>
// CHECK: scf.yield %[[INS_0_0]] : vector<3xf32>
@@ -179,6 +182,7 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
// CHECK: %[[MASK_0_1:.*]] = vector.extract %[[MASK]][0, 1]
// CHECK: %[[IDX_0_1:.*]] = vector.extract %[[IDX]][0, 1]
// CHECK: %[[OFF_0_1:.*]] = arith.addi %[[IDX_0_1]], %[[C1]]
+// CHECK: %[[DL_0_1:.*]]:2 = affine.delinearize_index %[[OFF_0_1]] into (%[[DIM0]], %[[DIM1]])
// CHECK: %[[IF_0_1:.*]] = scf.if %[[MASK_0_1]] -> (vector<3xf32>)
// … (similar checks for the rest of row 0, then row 1)
@@ -190,6 +194,7 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
// CHECK: %[[MASK_1_0:.*]] = vector.extract %[[MASK]][1, 0]
// CHECK: %[[IDX_1_0:.*]] = vector.extract %[[IDX]][1, 0]
// CHECK: %[[OFF_1_0:.*]] = arith.addi %[[IDX_1_0]], %[[C1]]
+// CHECK: %[[DL_1_0:.*]]:2 = affine.delinearize_index %[[OFF_1_0]] into
// CHECK: %[[IF_1_0:.*]] = scf.if %[[MASK_1_0]] -> (vector<3xf32>)
// … (similar checks for remaining row 1 inserts)
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index dd2dfc4f3e441..ff3520a286cc8 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -783,8 +783,8 @@ struct TestVectorGatherLowering
"loads";
}
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<arith::ArithDialect, func::FuncDialect,
- memref::MemRefDialect, scf::SCFDialect,
+ registry.insert<affine::AffineDialect, arith::ArithDialect,
+ func::FuncDialect, memref::MemRefDialect, scf::SCFDialect,
tensor::TensorDialect, vector::VectorDialect>();
}
More information about the Mlir-commits
mailing list