[Mlir-commits] [mlir] [mlir][Vector][NFC] Run `extractInsertFoldConstantOp` earlier in the folder (PR #140814)
Diego Caballero
llvmlistbot at llvm.org
Tue May 20 15:54:55 PDT 2025
https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/140814
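This PR moves `extractInsertFoldConstantOp` earlier in the folder lists of `vector.extract` and `vector.insert`. Many of these folders only match when all indices are static (non-dynamic), and `extractInsertFoldConstantOp` is the fold that turns `arith.constant` index operands into static positions, so it must run first for those folders to trigger. For `vector.extract`, `foldPoisonSrcExtractOp` is also moved so that it runs before the constant-index fold.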
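The ordering argument can be illustrated with a toy model. This is a minimal sketch under stated assumptions, not MLIR code: `ToyExtract`, `FoldResult`, `foldConstantIndex`, `foldFromConstantSource`, `fold` and `foldToFixpoint` are all hypothetical names, the op is reduced to a single 1-D position, and the re-invocation loop only loosely approximates a rewrite driver calling `fold` again after an in-place update. As in the patch, the constant-index step updates the op in place and returns immediately.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Toy model only: ToyExtract, FoldResult and every function below are
// hypothetical and are NOT MLIR's API. A 1-D "extract" reads source[position],
// where the position is either static or supplied by an index operand.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

struct ToyExtract {
  int64_t staticPosition;                // kDynamic when an operand supplies it
  std::optional<int64_t> dynamicOperand; // operand value, if a known constant
  std::vector<int> source;               // constant source vector
};

enum class FoldKind { None, InPlace, Replaced };

struct FoldResult {
  FoldKind kind = FoldKind::None;
  int value = 0; // only meaningful when kind == FoldKind::Replaced
};

// Stand-in for extractInsertFoldConstantOp: move a constant dynamic index into
// the static position and report an in-place update.
FoldResult foldConstantIndex(ToyExtract &op) {
  if (op.staticPosition != kDynamic || !op.dynamicOperand)
    return {};
  op.staticPosition = *op.dynamicOperand;
  op.dynamicOperand.reset();
  return {FoldKind::InPlace};
}

// Stand-in for a folder that only matches static indices, such as folding an
// extract from a constant source. Out-of-range positions are left unfolded.
FoldResult foldFromConstantSource(const ToyExtract &op) {
  if (op.staticPosition == kDynamic || op.staticPosition < 0 ||
      op.staticPosition >= static_cast<int64_t>(op.source.size()))
    return {};
  return {FoldKind::Replaced,
          op.source[static_cast<std::size_t>(op.staticPosition)]};
}

// One fold call: try the folders in order and stop at the first success,
// mirroring the early `return` after each folder in the patch. staticMisses
// counts attempts where the static-index folder ran but could not match.
FoldResult fold(ToyExtract &op, bool constantIndexFirst, int &staticMisses) {
  auto tryStatic = [&]() {
    FoldResult r = foldFromConstantSource(op);
    if (r.kind == FoldKind::None)
      ++staticMisses;
    return r;
  };
  if (constantIndexFirst) {
    FoldResult r = foldConstantIndex(op);
    if (r.kind != FoldKind::None)
      return r;
    return tryStatic();
  }
  FoldResult r = tryStatic();
  if (r.kind != FoldKind::None)
    return r;
  return foldConstantIndex(op);
}

// Re-run fold while it keeps updating the op in place. This terminates because
// an in-place update consumes the only dynamic operand.
std::optional<int> foldToFixpoint(ToyExtract op, bool constantIndexFirst,
                                  int &staticMisses) {
  staticMisses = 0;
  while (true) {
    FoldResult r = fold(op, constantIndexFirst, staticMisses);
    if (r.kind == FoldKind::Replaced)
      return r.value;
    if (r.kind == FoldKind::None)
      return std::nullopt;
  }
}
```

For an op whose dynamic index is the constant 2 over the source `{10, 20, 30, 40}`, both orders fold to 30 in this model. With the constant-index step first, however, the static-index folder is never attempted while the position is still dynamic (0 misses, versus 1 with the old order). That the toy model reaches the same result either way is consistent with the `[NFC]` tag in the PR title; it says nothing about the other folders in the real lists.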
From 103342afcdc52d2e10d06b80d2bd3cfdeb1a9a6d Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Tue, 20 May 2025 22:52:38 +0000
Subject: [PATCH] [mlir][Vector] Run `extractInsertFoldConstantOp` earlier in the folder
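This PR moves `extractInsertFoldConstantOp` earlier in the folder lists of
`vector.extract` and `vector.insert`. Many of these folders only match when
all indices are static (non-dynamic), and `extractInsertFoldConstantOp` is
the fold that turns `arith.constant` index operands into static positions,
so it must run first for those folders to trigger. For `vector.extract`,
`foldPoisonSrcExtractOp` is also moved so that it runs before the
constant-index fold.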
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bbb366b01fa6e..cf2df1f24f91f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2143,11 +2143,16 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   // mismatch).
   if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
     return getVector();
+  if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
+    return res;
+  // Fold `arith.constant` indices into the `vector.extract` operation. Make
+  // sure that patterns requiring constant indices are added after this fold.
+  SmallVector<Value> operands = {getVector()};
+  if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
+    return val;
   if (auto res = foldPoisonIndexInsertExtractOp(
           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
     return res;
-  if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
-    return res;
   if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
     return res;
   if (succeeded(foldExtractOpFromExtractChain(*this)))
@@ -2166,9 +2171,6 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
     return val;
   if (auto val = foldScalarExtractFromFromElements(*this))
     return val;
-  SmallVector<Value> operands = {getVector()};
-  if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
-    return val;
   return OpFoldResult();
 }
 
@@ -3145,6 +3147,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   // (type mismatch).
   if (getNumIndices() == 0 && getValueToStoreType() == getType())
     return getValueToStore();
+  // Fold `arith.constant` indices into the `vector.insert` operation. Make
+  // sure that patterns requiring constant indices are added after this fold.
   SmallVector<Value> operands = {getValueToStore(), getDest()};
   if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
     return val;
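The diff shows only the call sites of `extractInsertFoldConstantOp`, not its body. Those call sites seed `operands` with the op's non-index operands (`getVector()` for `vector.extract`; `getValueToStore()` and `getDest()` for `vector.insert`). This suggests that the helper appends the index operands it cannot fold and installs the result as the new operand list. The sketch below models that reading only. `ToyPositionedOp`, `IndexOperand` and `foldConstantIndices` are hypothetical, operands are plain strings, and the real helper may differ in its details.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical model, not MLIR's API. kDynamic marks a position entry that is
// supplied by an index operand rather than stored statically.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

struct IndexOperand {
  std::string name;                // e.g. "%c0" or "%i"
  std::optional<int64_t> constant; // set when the operand is a known constant
};

struct ToyPositionedOp {
  std::vector<int64_t> staticPosition;      // kDynamic for operand-supplied entries
  std::vector<IndexOperand> dynamicIndices; // one per kDynamic entry, in order
  std::vector<std::string> operands;        // full operand list of the op
};

// Hypothetical analogue of folding constant indices into an op. `seed` holds
// the non-index operands (e.g. {"%v"} for extract, {"%v", "%d"} for insert);
// non-constant indices are appended after it. Returns true if the op changed.
bool foldConstantIndices(ToyPositionedOp &op, std::vector<std::string> seed) {
  if (op.dynamicIndices.empty())
    return false;
  std::vector<IndexOperand> remaining;
  std::size_t next = 0;
  bool changed = false;
  for (int64_t &pos : op.staticPosition) {
    if (pos != kDynamic)
      continue;
    const IndexOperand &idx = op.dynamicIndices.at(next++);
    if (idx.constant) {
      pos = *idx.constant; // constant index becomes a static position
      changed = true;
      continue;
    }
    seed.push_back(idx.name); // unknown index stays an operand
    remaining.push_back(idx);
  }
  if (!changed)
    return false;
  op.operands = std::move(seed);
  op.dynamicIndices = std::move(remaining);
  return true;
}
```

For example, take static position `[?, 1, ?]` (with `?` standing for `kDynamic`), where the first dynamic index is the known constant 0 and the second is not. One call yields position `[0, 1, ?]` with operands `{%v, %d, %i}`. A second call reports no change because the remaining index is still unknown.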