[Mlir-commits] [mlir] [mlir][tensor] add tensor insert/extract op folders (PR #142458)

Mehdi Amini llvmlistbot at llvm.org
Tue Jun 3 15:09:58 PDT 2025


================
@@ -1534,6 +1624,76 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
 // InsertOp
 //===----------------------------------------------------------------------===//
 
+namespace {
+
+/// Pattern to fold an insert op of a constant destination and scalar to a new
+/// constant.
+///
+/// Example:
+/// ```
+///   %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+///   %c0 = arith.constant 0 : index
+///   %c4_f32 = arith.constant 4.0 : f32
+///   %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
+/// ```
+/// is rewritten into:
+/// ```
+///   %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+/// ```
+class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern<InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    // Requires a ranked tensor type.
+    auto destType =
+        llvm::dyn_cast<RankedTensorType>(insertOp.getDest().getType());
+    if (!destType)
+      return failure();
+
+    // Pattern requires constant indices
+    SmallVector<uint64_t, 8> indices;
+    for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
+      auto indiceAttr = dyn_cast<Attribute>(indice);
+      if (!indiceAttr)
+        return failure();
+      indices.push_back(llvm::cast<IntegerAttr>(indiceAttr).getInt());
+    }
+
+    // Requires a constant scalar to insert
+    OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar());
+    Attribute scalarAttr = dyn_cast<Attribute>(scalar);
+    if (!scalarAttr)
+      return failure();
+
+    if (auto constantOp = dyn_cast_or_null<arith::ConstantOp>(
+            insertOp.getDest().getDefiningOp())) {
+      if (auto sourceAttr =
+              llvm::dyn_cast<ElementsAttr>(constantOp.getValue())) {
+        // Update the attribute at the inserted index.
+        auto sourceValues = sourceAttr.getValues<Attribute>();
+        auto flattenedIndex = sourceAttr.getFlattenedIndex(indices);
+        std::vector<Attribute> updatedValues;
----------------
joker-eph wrote:

I'm saying that, in the common case, we should indeed be able to avoid materializing an individual attribute per element.
We may be lacking (possibly templated?) helpers to do this conveniently.

https://github.com/llvm/llvm-project/pull/142458


More information about the Mlir-commits mailing list