[Mlir-commits] [mlir] [mlir][Vector] Support `vector.extract(xfer_read)` folding with dynamic indices (PR #143269)
Diego Caballero
llvmlistbot at llvm.org
Mon Jun 16 11:59:49 PDT 2025
https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/143269
>From 6bc6e3a3b714485db9d8a39c61300603eb31bf68 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Sat, 7 Jun 2025 12:45:44 +0000
Subject: [PATCH 1/2] [mlir][Vector] Support `vector.extract(xfer_read)`
folding with dynamic indices
This PR is part of the effort to remove the `vector.extractelement` and
`vector.insertelement` ops. It adds support for folding
`vector.extract(vector.transfer_read) -> memref.load` when the extract
indices are dynamic, a case currently only supported for
`vector.extractelement`.
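For illustration, this is the kind of rewrite the pattern now performs for a
dynamic extract position (a 1-D sketch mirroring the new test below; `%m`,
`%idx` and `%offset` are placeholder names, and the actual output uses a named
affine map as shown in the FileCheck lines):

  // Before: scalar extract of a vector transfer_read.
  %cst = arith.constant 0.0 : f32
  %vec = vector.transfer_read %m[%idx], %cst {in_bounds = [true]} : memref<?xf32>, vector<5xf32>
  %elem = vector.extract %vec[%offset] : f32 from vector<5xf32>

  // After: the transfer index and the extract position are composed into a
  // single affine.apply and the read becomes a scalar memref.load.
  %new_idx = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%idx, %offset]
  %elem = memref.load %m[%new_idx] : memref<?xf32>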
---
.../Transforms/VectorTransferOpTransforms.cpp | 24 +++++++++++----
.../scalar-vector-transfer-to-memref.mlir | 30 +++++++++++++++++++
2 files changed, 49 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 7dbb7a334fe62..36197eb1caeb1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -886,12 +886,26 @@ class RewriteScalarExtractOfTransferRead
SmallVector<Value> newIndices(xferOp.getIndices().begin(),
xferOp.getIndices().end());
for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
- assert(isa<Attribute>(pos) && "Unexpected non-constant index");
- int64_t offset = cast<IntegerAttr>(cast<Attribute>(pos)).getInt();
int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, extractOp.getLoc(),
- rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
+
+ // Compute affine expression `newIndices[idx] + pos` where `pos` can be
+ // either a constant or a value.
+ OpFoldResult ofr;
+ if (auto attr = dyn_cast<Attribute>(pos)) {
+ int64_t offset = cast<IntegerAttr>(attr).getInt();
+ ofr = affine::makeComposedFoldedAffineApply(
+ rewriter, extractOp.getLoc(),
+ rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
+ } else {
+ Value dynamicOffset = cast<Value>(pos);
+ AffineExpr sym0, sym1;
+ bindSymbols(rewriter.getContext(), sym0, sym1);
+ ofr = affine::makeComposedFoldedAffineApply(
+ rewriter, extractOp.getLoc(), sym0 + sym1,
+ {newIndices[idx], dynamicOffset});
+ }
+
+ // Update the corresponding index with the folded result.
if (auto value = dyn_cast<Value>(ofr)) {
newIndices[idx] = value;
} else {
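Note: with the change above, the composed load index is built with
`makeComposedFoldedAffineApply` in both branches. Roughly, in terms of the
emitted IR (placeholder names; the constant case may fold away entirely when
the base index is itself constant):

  // Constant extract position, e.g. 3 (existing behavior):
  %new_idx_cst = affine.apply affine_map<()[s0] -> (s0 + 3)>()[%xfer_idx]
  // Dynamic extract position (added by this patch):
  %new_idx_dyn = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%xfer_idx, %extract_pos]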
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index 52b0fdee184f6..9f10063a75092 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -148,3 +148,33 @@ func.func @subvector_extract(%m: memref<?x?xf32>, %idx: index) -> vector<16xf32>
return %1 : vector<16xf32>
}
+// -----
+
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @transfer_read_1d_extract_dynamic(
+// CHECK-SAME: %[[MEMREF:.*]]: memref<?xf32>, %[[M_IDX:.*]]: index, %[[E_IDX:.*]]: index
+// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[E_IDX]]]
+// CHECK: %[[RES:.*]] = memref.load %[[MEMREF]][%[[APPLY]]]
+func.func @transfer_read_1d_extract_dynamic(%m: memref<?xf32>, %idx: index,
+ %offset: index) -> f32 {
+ %cst = arith.constant 0.0 : f32
+ %vec = vector.transfer_read %m[%idx], %cst {in_bounds = [true]} : memref<?xf32>, vector<5xf32>
+ %elem = vector.extract %vec[%offset] : f32 from vector<5xf32>
+ return %elem : f32
+}
+
+// -----
+
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @transfer_read_2d_extract_dynamic(
+// CHECK-SAME: %[[MEMREF:.*]]: memref<?x?xf32>, %[[M_IDX:.*]]: index, %[[ROW:.*]]: index, %[[COL:.*]]: index
+// CHECK: %[[ROW_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[ROW]]]
+// CHECK: %[[COL_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[COL]]]
+// CHECK: %[[RES:.*]] = memref.load %[[MEMREF]][%[[ROW_APPLY]], %[[COL_APPLY]]]
+func.func @transfer_read_2d_extract_dynamic(%m: memref<?x?xf32>, %idx: index,
+ %row_offset: index, %col_offset: index) -> f32 {
+ %cst = arith.constant 0.0 : f32
+ %vec = vector.transfer_read %m[%idx, %idx], %cst {in_bounds = [true, true]} : memref<?x?xf32>, vector<10x5xf32>
+ %elem = vector.extract %vec[%row_offset, %col_offset] : f32 from vector<10x5xf32>
+ return %elem : f32
+}
>From f98e3480519924dd8f72b4155b3bc5b2807ecc02 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Mon, 16 Jun 2025 18:58:02 +0000
Subject: [PATCH 2/2] Feedback
---
.../Vector/Transforms/VectorTransferOpTransforms.cpp | 10 +++++-----
.../Vector/scalar-vector-transfer-to-memref.mlir | 10 +++++-----
2 files changed, 10 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 36197eb1caeb1..384717aeca665 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -890,27 +890,27 @@ class RewriteScalarExtractOfTransferRead
// Compute affine expression `newIndices[idx] + pos` where `pos` can be
// either a constant or a value.
- OpFoldResult ofr;
+ OpFoldResult composedIdx;
if (auto attr = dyn_cast<Attribute>(pos)) {
int64_t offset = cast<IntegerAttr>(attr).getInt();
- ofr = affine::makeComposedFoldedAffineApply(
+ composedIdx = affine::makeComposedFoldedAffineApply(
rewriter, extractOp.getLoc(),
rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
} else {
Value dynamicOffset = cast<Value>(pos);
AffineExpr sym0, sym1;
bindSymbols(rewriter.getContext(), sym0, sym1);
- ofr = affine::makeComposedFoldedAffineApply(
+ composedIdx = affine::makeComposedFoldedAffineApply(
rewriter, extractOp.getLoc(), sym0 + sym1,
{newIndices[idx], dynamicOffset});
}
// Update the corresponding index with the folded result.
- if (auto value = dyn_cast<Value>(ofr)) {
+ if (auto value = dyn_cast<Value>(composedIdx)) {
newIndices[idx] = value;
} else {
newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
- extractOp.getLoc(), *getConstantIntValue(ofr));
+ extractOp.getLoc(), *getConstantIntValue(composedIdx));
}
}
if (isa<MemRefType>(xferOp.getBase().getType())) {
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index 9f10063a75092..7a1d6b3a8344a 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -167,14 +167,14 @@ func.func @transfer_read_1d_extract_dynamic(%m: memref<?xf32>, %idx: index,
// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK-LABEL: func @transfer_read_2d_extract_dynamic(
-// CHECK-SAME: %[[MEMREF:.*]]: memref<?x?xf32>, %[[M_IDX:.*]]: index, %[[ROW:.*]]: index, %[[COL:.*]]: index
-// CHECK: %[[ROW_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[ROW]]]
-// CHECK: %[[COL_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[COL]]]
+// CHECK-SAME: %[[MEMREF:.*]]: memref<?x?xf32>, %[[ROW_IDX:.*]]: index, %[[COL_IDX:.*]]: index, %[[ROW_OFFSET:.*]]: index, %[[COL_OFFSET:.*]]: index
+// CHECK: %[[ROW_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[ROW_IDX]], %[[ROW_OFFSET]]]
+// CHECK: %[[COL_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[COL_IDX]], %[[COL_OFFSET]]]
// CHECK: %[[RES:.*]] = memref.load %[[MEMREF]][%[[ROW_APPLY]], %[[COL_APPLY]]]
-func.func @transfer_read_2d_extract_dynamic(%m: memref<?x?xf32>, %idx: index,
+func.func @transfer_read_2d_extract_dynamic(%m: memref<?x?xf32>, %row_idx: index, %col_idx: index,
%row_offset: index, %col_offset: index) -> f32 {
%cst = arith.constant 0.0 : f32
- %vec = vector.transfer_read %m[%idx, %idx], %cst {in_bounds = [true, true]} : memref<?x?xf32>, vector<10x5xf32>
+ %vec = vector.transfer_read %m[%row_idx, %col_idx], %cst {in_bounds = [true, true]} : memref<?x?xf32>, vector<10x5xf32>
%elem = vector.extract %vec[%row_offset, %col_offset] : f32 from vector<10x5xf32>
return %elem : f32
}