[Mlir-commits] [mlir] [mlir][vector] Skip redundant affine.apply when unrolling transfer ops. (PR #192700)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Fri Apr 17 10:08:15 PDT 2026
https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/192700
>From fd4dc8f9e53f7313277232564288701ab159fd6c Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 17 Apr 2026 12:28:23 -0400
Subject: [PATCH 1/2] [mlir][vector] Skip redundant affine.apply when unrolling
transfer ops.
Unrolling transfer ops generates an affine.apply op for every
non-broadcast dimension, even when the offset being added is zero.
Skip creating the affine.apply in that case and reuse the original index.
---
.../Vector/Transforms/VectorUnroll.cpp | 5 +--
.../Vector/vector-transfer-unroll.mlir | 33 +++++++++++++++++++
2 files changed, 36 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index ec08f01d2a4b9..cd54f56f48829 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -41,11 +41,12 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
// Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
SmallVector<Value> slicedIndices(indices);
for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
- if (isBroadcast(dim.value()))
+ int64_t elementOffset = elementOffsets[dim.index()];
+ if (isBroadcast(dim.value()) || elementOffset == 0)
continue;
unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
auto expr = getAffineDimExpr(0, builder.getContext()) +
- getAffineConstantExpr(elementOffsets[dim.index()], ctx);
+ getAffineConstantExpr(elementOffset, ctx);
auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
slicedIndices[pos] =
affine::AffineApplyOp::create(builder, loc, map, indices[pos]);
diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
index 44601a4a47dda..456277121c86c 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -383,3 +383,36 @@ func.func @vector_gather_unroll(%mem : memref<?x?x?xf32>,
%res = vector.gather %mem[%c0, %c0, %c0] [%indices], %mask, %pass_thru : memref<?x?x?xf32>, vector<6x4xindex>, vector<6x4xi1>, vector<6x4xf32> into vector<6x4xf32>
return %res : vector<6x4xf32>
}
+
+// -----
+
+// Verify that no redundant affine.apply ops are generated for zero offsets
+// when the base indices are dynamic.
+
+// ALL-LABEL: func @transfer_read_unroll_dynamic_index(
+// ALL-SAME: %[[MEM:.*]]: memref<4x4xf32>,
+// ALL-SAME: %[[IDX:.*]]: index
+// ALL: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// ALL-NOT: affine.apply
+// ALL: %[[MAP:.*]] = affine.apply {{.*}}(%[[IDX]])
+// ALL: vector.transfer_read %[[MEM]][%[[MAP]], %[[IDX]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+func.func @transfer_read_unroll_dynamic_index(%mem : memref<4x4xf32>, %idx : index) -> vector<4x2xf32> {
+ %cf0 = arith.constant 0.0 : f32
+ %res = vector.transfer_read %mem[%idx, %idx], %cf0 : memref<4x4xf32>, vector<4x2xf32>
+ return %res : vector<4x2xf32>
+}
+
+// -----
+
+// ALL-LABEL: func @transfer_write_unroll_dynamic_index(
+// ALL-SAME: %[[MEM:.*]]: memref<4x4xf32>,
+// ALL-SAME: %[[VEC:.*]]: vector<4x2xf32>,
+// ALL-SAME: %[[IDX:.*]]: index
+// ALL-NOT: affine.apply
+// ALL: vector.transfer_write %{{.*}}, %[[MEM]][%[[IDX]], %[[IDX]]] : vector<2x2xf32>, memref<4x4xf32>
+// ALL: %[[MAP:.*]] = affine.apply {{.*}}(%[[IDX]])
+// ALL: vector.transfer_write %{{.*}}, %[[MEM]][%[[MAP]], %[[IDX]]] : vector<2x2xf32>, memref<4x4xf32>
+func.func @transfer_write_unroll_dynamic_index(%mem : memref<4x4xf32>, %vec : vector<4x2xf32>, %idx : index) {
+ vector.transfer_write %vec, %mem[%idx, %idx] : vector<4x2xf32>, memref<4x4xf32>
+ return
+}
>From 2b36a244bf250e7864a38d7ebfc0e7c153d0cb39 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Fri, 17 Apr 2026 12:56:24 -0400
Subject: [PATCH 2/2] Export sliceTransferIndices
---
.../Dialect/Vector/Transforms/VectorRewritePatterns.h | 6 ++++++
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 9 +++------
2 files changed, 9 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 00b321fecc09f..7d6d565d5a4f4 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -322,6 +322,12 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
const UnrollVectorOptions &options,
PatternBenefit benefit = 1);
+/// Compute indices for a transfer op slice.
+SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
+ ArrayRef<Value> indices,
+ AffineMap permutationMap, Location loc,
+ OpBuilder &builder);
+
/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
/// outermost dimension of the operand.
void populateVectorToElementsUnrollPatterns(RewritePatternSet &patterns,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index cd54f56f48829..d7730cfd78ab7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -26,12 +26,9 @@
using namespace mlir;
using namespace mlir::vector;
-/// Compute the indices of the slice `index` for a transfer op.
-static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
- ArrayRef<Value> indices,
- AffineMap permutationMap,
- Location loc,
- OpBuilder &builder) {
+SmallVector<Value> mlir::vector::sliceTransferIndices(
+ ArrayRef<int64_t> elementOffsets, ArrayRef<Value> indices,
+ AffineMap permutationMap, Location loc, OpBuilder &builder) {
MLIRContext *ctx = builder.getContext();
auto isBroadcast = [](AffineExpr expr) {
if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
More information about the Mlir-commits
mailing list