[Mlir-commits] [mlir] [mlir][tensor] add tensor insert/extract op folders (PR #142458)
Mehdi Amini
llvmlistbot at llvm.org
Tue Jun 3 10:13:21 PDT 2025
================
@@ -1534,6 +1624,76 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
// InsertOp
//===----------------------------------------------------------------------===//
+namespace {
+
+/// Pattern to fold an insert op of a constant destination and scalar to a new
+/// constant.
+///
+/// Example:
+/// ```
+/// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+/// %c0 = arith.constant 0 : index
+/// %c4_f32 = arith.constant 4.0 : f32
+/// %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
+/// ```
+/// is rewritten into:
+/// ```
+/// %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+/// ```
+class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
+public:
+ using OpRewritePattern<InsertOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(InsertOp insertOp,
+ PatternRewriter &rewriter) const override {
+ // Requires a ranked tensor type.
+ auto destType =
+ llvm::dyn_cast<RankedTensorType>(insertOp.getDest().getType());
+ if (!destType)
+ return failure();
+
+ // Pattern requires constant indices
+ SmallVector<uint64_t, 8> indices;
+ for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
+ auto indiceAttr = dyn_cast<Attribute>(indice);
+ if (!indiceAttr)
+ return failure();
+ indices.push_back(llvm::cast<IntegerAttr>(indiceAttr).getInt());
+ }
+
+ // Requires a constant scalar to insert
+ OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar());
+ Attribute scalarAttr = dyn_cast<Attribute>(scalar);
+ if (!scalarAttr)
+ return failure();
+
+ if (auto constantOp = dyn_cast_or_null<arith::ConstantOp>(
+ insertOp.getDest().getDefiningOp())) {
+ if (auto sourceAttr =
+ llvm::dyn_cast<ElementsAttr>(constantOp.getValue())) {
+ // Update the attribute at the inserted index.
+ auto sourceValues = sourceAttr.getValues<Attribute>();
+ auto flattenedIndex = sourceAttr.getFlattenedIndex(indices);
+ std::vector<Attribute> updatedValues;
----------------
joker-eph wrote:
That is an unfortunate slow path (materializing every element as a generic `Attribute` in a `std::vector`) for data that will ultimately be converted to a vector of ints or floats.
https://github.com/llvm/llvm-project/pull/142458
More information about the Mlir-commits mailing list