[Mlir-commits] [mlir] [mlir][vector] Support complete folding in single pass for vector.insert/vector.extract (PR #142124)
Yang Bai
llvmlistbot at llvm.org
Sat May 31 08:34:09 PDT 2025
https://github.com/yangtetris updated https://github.com/llvm/llvm-project/pull/142124
>From 32b9f8cc3b07770ca409f93206485e949d856de8 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Fri, 30 May 2025 03:42:09 -0700
Subject: [PATCH 1/3] [mlir][vector] Support complete folding in single pass
for vector.insert/vector.extract
After the fold has converted constant dynamic indices into static indices,
continue with the remaining folds instead of returning early, so that folds
which require static indices can run within the same fold invocation.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 14 ++++++++------
1 file changed, 8 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 890a5e9e5c9b4..2e0c917b2139d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2062,6 +2062,7 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
if (opChange) {
op.setStaticPosition(staticPosition);
op.getOperation()->setOperands(operands);
+ // Return the original result to indicate an in-place folding happened.
return op.getResult();
}
return {};
@@ -2148,8 +2149,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Fold `arith.constant` indices into the `vector.extract` operation. Make
// sure that patterns requiring constant indices are added after this fold.
SmallVector<Value> operands = {getVector()};
- if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
- return val;
+ auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
+
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
@@ -2171,7 +2172,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
return val;
if (auto val = foldScalarExtractFromFromElements(*this))
return val;
- return OpFoldResult();
+
+ return inplaceFolded;
}
namespace {
@@ -3150,8 +3152,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
// Fold `arith.constant` indices into the `vector.insert` operation. Make
// sure that patterns requiring constant indices are added after this fold.
SmallVector<Value> operands = {getValueToStore(), getDest()};
- if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
- return val;
+ auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
+
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
@@ -3161,7 +3163,7 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
return res;
}
- return {};
+ return inplaceFolded;
}
//===----------------------------------------------------------------------===//
>From 7586aa1727bcba311ed4a6c0d1092da428c7f163 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Sat, 31 May 2025 08:28:43 -0700
Subject: [PATCH 2/3] refine comment
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2e0c917b2139d..efcd75c5c1bb7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2146,8 +2146,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
return getVector();
if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
return res;
- // Fold `arith.constant` indices into the `vector.extract` operation. Make
- // sure that patterns requiring constant indices are added after this fold.
+ // Fold `arith.constant` indices into the `vector.extract` operation.
+ // Do not stop here as this fold may enable subsequent folds that require
+ // constant indices.
SmallVector<Value> operands = {getVector()};
auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
@@ -3149,8 +3150,9 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
// (type mismatch).
if (getNumIndices() == 0 && getValueToStoreType() == getType())
return getValueToStore();
- // Fold `arith.constant` indices into the `vector.insert` operation. Make
- // sure that patterns requiring constant indices are added after this fold.
+ // Fold `arith.constant` indices into the `vector.insert` operation.
+ // Do not stop here as this fold may enable subsequent folds that require
+ // constant indices.
SmallVector<Value> operands = {getValueToStore(), getDest()};
auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
>From 22453c9f6309b721bbd747c2194c88b126ee78ae Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Sat, 31 May 2025 08:33:37 -0700
Subject: [PATCH 3/3] add test
---
mlir/test/Dialect/Vector/constant-fold.mlir | 26 +++++++++++++++++++++
1 file changed, 26 insertions(+)
diff --git a/mlir/test/Dialect/Vector/constant-fold.mlir b/mlir/test/Dialect/Vector/constant-fold.mlir
index 66c91d6b2041b..b23e76f590d78 100644
--- a/mlir/test/Dialect/Vector/constant-fold.mlir
+++ b/mlir/test/Dialect/Vector/constant-fold.mlir
@@ -11,3 +11,29 @@ func.func @fold_extract_transpose_negative(%arg0: vector<4x4xf16>) -> vector<4x4
%2 = vector.extract %1[0] : vector<4x4xf16> from vector<1x4x4xf16>
return %2 : vector<4x4xf16>
}
+
+// -----
+
+// CHECK-LABEL: fold_extract_in_single_pass
+// CHECK-SAME: (%{{.*}}: vector<4xf16>, %[[ARG1:.+]]: f16)
+func.func @fold_extract_in_single_pass(%arg0: vector<4xf16>, %arg1: f16) -> f16 {
+ %0 = vector.insert %arg1, %arg0 [1] : f16 into vector<4xf16>
+ %c1 = arith.constant 1 : index
+ // Verify that the fold is finished in a single pass even if the index is dynamic.
+ %1 = vector.extract %0[%c1] : f16 from vector<4xf16>
+ // CHECK: return %[[ARG1]] : f16
+ return %1 : f16
+}
+
+// -----
+
+// CHECK-LABEL: fold_insert_in_single_pass
+func.func @fold_insert_in_single_pass() -> vector<2xf16> {
+ %cst = arith.constant dense<0.000000e+00> : vector<2xf16>
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2.5 : f16
+ // Verify that the fold is finished in a single pass even if the index is dynamic.
+ // CHECK: arith.constant dense<[0.000000e+00, 2.500000e+00]> : vector<2xf16>
+ %0 = vector.insert %c2, %cst [%c1] : f16 into vector<2xf16>
+ return %0 : vector<2xf16>
+}
More information about the Mlir-commits mailing list