[Mlir-commits] [mlir] [mlir][tensor] add tensor insert/extract op folders (PR #142458)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 2 12:13:13 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (asraa)

<details>
<summary>Changes</summary>

Adds a few canonicalizers, folders, and rewrite patterns to tensor ops:

* tensor.insert canonicalization pattern: inserting a constant scalar at constant indices into a constant tensor is replaced with a new constant
* tensor.extract folder: an extract from a tensor produced by a tensor.insert at the same indices folds to the inserted scalar
* a rewrite pattern (exposed via `populateFoldCollapseExtractPatterns`, not part of canonicalization) that replaces an extract of a collapse_shape result with an extract of the source tensor (requires a static source shape)


---
Full diff: https://github.com/llvm/llvm-project/pull/142458.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tensor/IR/Tensor.h (+4) 
- (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+1) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+165) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+26-3) 
- (added) mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir (+31) 
- (modified) mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp (+13) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index eb550bb469b9f..e8e1342ef36fd 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -176,6 +176,10 @@ void populateFoldConstantExtractSlicePatterns(
           return false;
         });
 
+/// Adds patterns rewriting a tensor.extract of a tensor.collapse_shape result
+/// into a tensor.extract of its source; the source shape must be static.
+void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns);
+
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 35d0b16628417..c0885a3763827 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -827,6 +827,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
 
   let hasFolder = 1;
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 30ca20fc0d883..f2a7220b4bedc 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -22,6 +22,7 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
@@ -33,10 +34,12 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include <algorithm>
 #include <optional>
+#include <vector>
 
 using namespace mlir;
 using namespace mlir::tensor;
@@ -1288,6 +1291,68 @@ struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
   }
 };
 
+/// Rewrites an extract from the result of a tensor.collapse_shape into an
+/// extract from the collapse_shape source, e.g.
+///
+/// %val = tensor.collapse_shape %src [[0, 1]] : tensor<3x4xf64> into
+///     tensor<12xf64>
+/// %extracted_element = tensor.extract %val[%c10] : tensor<12xf64>
+///
+/// becomes (indices shown after delinearizing 10 over the basis (3, 4))
+///
+/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
+///
+/// Only fires when the source shape is fully static. This is not registered as
+/// a canonicalization; use populateFoldCollapseExtractPatterns.
+struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const final {
+    auto collapseOp =
+        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapseOp)
+      return failure();
+    // Delinearization below needs every source dimension to be static.
+    RankedTensorType srcType = collapseOp.getSrcType();
+    if (!srcType.hasStaticShape())
+      return failure();
+
+    Location loc = extractOp.getLoc();
+    ArrayRef<int64_t> sourceSizes = srcType.getShape();
+    // One reassociation group (a list of source dims) per result dimension.
+    auto groups = collapseOp.getReassociationIndices();
+    SmallVector<Value> sourceIndices;
+    for (auto [index, group] : llvm::zip(extractOp.getIndices(), groups)) {
+      assert(!group.empty() && "reassociation groups cannot be empty");
+      // A singleton group passes its index through unchanged.
+      if (group.size() == 1) {
+        sourceIndices.push_back(index);
+        continue;
+      }
+
+      // Split the flat index over the group's static source sizes.
+      SmallVector<int64_t> basis =
+          llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+      auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+          loc, index, basis, /*hasOuterBound=*/true);
+      llvm::append_range(sourceIndices, delinearize.getResults());
+    }
+
+    // An empty reassociation means a rank-0 result, so the extract has no
+    // indices. Index every source dim at 0 (presumably all unit dims, per the
+    // collapse_shape verifier); a single constant is shared by all of them.
+    if (groups.empty()) {
+      Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+      sourceIndices.append(srcType.getRank(), zero);
+    }
+
+    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+        extractOp, collapseOp.getSrc(), sourceIndices);
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getAsmResultNames(
@@ -1303,6 +1368,23 @@ LogicalResult ExtractOp::verify() {
   return success();
 }
 
+/// Folds an extract whose tensor operand is produced by a tensor.insert to the
+/// inserted scalar when both ops use the same indices (each index compared as
+/// a constant attribute when possible, otherwise as an SSA value).
+// TODO: Only the immediate producer is inspected; walk further up the insert
+// chain when the indices are provably disjoint.
+static Value foldExtractAfterInsert(ExtractOp extractOp) {
+  auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();
+  if (!insertOp || insertOp.getScalar().getType() != extractOp.getType())
+    return {};
+  auto sameIndex = [](Value lhs, Value rhs) {
+    return getAsOpFoldResult(lhs) == getAsOpFoldResult(rhs);
+  };
+  if (!llvm::equal(insertOp.getIndices(), extractOp.getIndices(), sameIndex))
+    return {};
+  return insertOp.getScalar();
+}
+
 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   if (Attribute tensor = adaptor.getTensor()) {
     // If this is a splat elements attribute, simply return the value.
@@ -1350,6 +1432,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
       return elementsAttr.getValues<Attribute>()[indices];
   }
 
+  if (Value result = foldExtractAfterInsert(*this))
+    return result;
+
   return {};
 }
 
@@ -1358,6 +1443,11 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ExtractFromTensorCast>(context);
 }
 
+void tensor::populateFoldCollapseExtractPatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ExtractFromCollapseShape>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // FromElementsOp
 //===----------------------------------------------------------------------===//
@@ -1534,6 +1624,76 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
 // InsertOp
 //===----------------------------------------------------------------------===//
 
+namespace {
+
+/// Pattern to fold an insert of a constant scalar at constant indices into a
+/// constant destination tensor into a new constant.
+///
+/// Example:
+/// ```
+///   %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+///   %c0 = arith.constant 0 : index
+///   %c4_f32 = arith.constant 4.0 : f32
+///   %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
+/// ```
+/// is rewritten into:
+/// ```
+///   %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+/// ```
+/// Non-integer or out-of-bounds constant indices are left untouched.
+class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern<InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    // Requires a ranked tensor type.
+    if (!llvm::isa<RankedTensorType>(insertOp.getDest().getType()))
+      return failure();
+
+    // Requires constant integer indices.
+    SmallVector<uint64_t, 8> indices;
+    for (OpFoldResult ofr : getAsOpFoldResult(insertOp.getIndices())) {
+      std::optional<int64_t> constIndex = getConstantIntValue(ofr);
+      if (!constIndex)
+        return failure();
+      indices.push_back(*constIndex);
+    }
+
+    // Requires a constant scalar to insert.
+    OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar());
+    Attribute scalarAttr = dyn_cast<Attribute>(scalar);
+    if (!scalarAttr)
+      return failure();
+
+    // Requires a constant destination holding an elements attribute.
+    auto constantOp = insertOp.getDest().getDefiningOp<arith::ConstantOp>();
+    if (!constantOp)
+      return failure();
+    auto sourceAttr = dyn_cast<ElementsAttr>(constantOp.getValue());
+    // Don't fold out-of-bounds inserts: getFlattenedIndex expects a valid
+    // index. Negative indices became huge uint64_t values and fail here too.
+    if (!sourceAttr || !sourceAttr.isValidIndex(indices))
+      return failure();
+
+    // Copy the elements, substituting the scalar at the inserted position.
+    auto sourceValues = sourceAttr.getValues<Attribute>();
+    uint64_t flattenedIndex = sourceAttr.getFlattenedIndex(indices);
+    std::vector<Attribute> updatedValues;
+    updatedValues.reserve(sourceAttr.getNumElements());
+    for (uint64_t i = 0, e = sourceAttr.getNumElements(); i < e; ++i)
+      updatedValues.push_back(i == flattenedIndex ? scalarAttr
+                                                  : sourceValues[i]);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+        insertOp, sourceAttr.getType(),
+        DenseElementsAttr::get(cast<ShapedType>(sourceAttr.getType()),
+                               updatedValues));
+    return success();
+  }
+};
+
+} // namespace
+
 void InsertOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
   setNameFn(getResult(), "inserted");
@@ -1557,6 +1717,11 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+void InsertOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                           MLIRContext *ctx) {
+  patterns.add<InsertOpConstantFold>(ctx);
+}
+
 //===----------------------------------------------------------------------===//
 // GenerateOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index cdcd7f305d2d9..0abec7e01d184 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -163,7 +163,7 @@ func.func @infer_concat_return_type(%arg0: tensor<5x12xi32>, %arg1: tensor<?x12x
 // -----
 
 // CHECK-LABEL: func @fold_extract
-func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
+func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>, i32) {
   %const_0 = arith.constant 0 : index
   %const_1 = arith.constant 1 : index
   %const_3 = arith.constant 3 : index
@@ -193,8 +193,15 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
   %4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
   %ext_5 = tensor.extract %4[] : tensor<complex<f32>>
 
-  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
-  return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
+  // Fold an extract after an insert.
+  // CHECK-DAG: [[C6:%.+]] = arith.constant 4 : i32
+  %c4_i32 = arith.constant 4 : i32
+  %5 = arith.constant dense<[[1, 3], [0, 2]]> : tensor<2x2xi32>
+  %inserted = tensor.insert %c4_i32 into %5[%const_1, %const_0] : tensor<2x2xi32>
+  %ext_6 = tensor.extract %inserted[%const_1, %const_0] : tensor<2x2xi32>
+
+  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]], [[C6]]
+  return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5, %ext_6 : f32, f16, f16, i32, complex<f32>, i32
 }
 
 // -----
@@ -224,6 +231,22 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
   return %ins_1 : tensor<4xf32>
 }
 
+
+// -----
+
+func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
+  // CHECK-LABEL: func @canonicalize_insert_after_constant
+  // Fold an insert of a constant scalar into a non-splat constant tensor.
+  // CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
+  // CHECK-NEXT: return %[[C4]]
+  %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4_i32 = arith.constant 4 : i32
+  %inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
+  return %inserted : tensor<2x2xi32>
+}
+
 // -----
 
 // CHECK-LABEL: func @extract_from_tensor.cast
diff --git a/mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir
new file mode 100644
index 0000000000000..c301f494a7c87
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-extract-from-collapse-shape %s | FileCheck %s
+
+// CHECK-LABEL: @extract_from_collapse_shape
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x1x8xi8>)
+func.func @extract_from_collapse_shape(%arg0: tensor<1x1x8xi8>) -> (i8, i8) {
+  %one = arith.constant 1 : index
+  %zero = arith.constant 0 : index
+  %flat = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<1x1x8xi8> into tensor<8xi8>
+  %elem0 = tensor.extract %flat[%zero] : tensor<8xi8>
+  %elem1 = tensor.extract %flat[%one] : tensor<8xi8>
+  func.return %elem0, %elem1 : i8, i8
+}
+
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[RESULT0:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] : tensor<1x1x8xi8>
+// CHECK-DAG: %[[RESULT1:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[C0]], %[[C1]]] : tensor<1x1x8xi8>
+// CHECK-NEXT: return %[[RESULT0]], %[[RESULT1]] : i8, i8
+
+// -----
+
+// CHECK-LABEL: @extract_from_static_shape
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_from_static_shape(%arg0 : tensor<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<2x6x32xf32> into tensor<12x32xf32>
+  %element = tensor.extract %collapsed[%arg1, %arg2] : tensor<12x32xf32>
+  return %element : f32
+}
+// CHECK-NEXT: %[[DELIN:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = tensor.extract %[[ARG0]][%[[DELIN]]#0, %[[DELIN]]#1, %[[ARG2]]] : tensor<2x6x32xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index e435130c2a417..0e191c32f009e 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -77,6 +77,11 @@ struct TestTensorTransforms
       llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
       llvm::cl::init(false)};
 
+  Option<bool> testFoldExtractFromCollapseShape{
+      *this, "test-fold-extract-from-collapse-shape",
+      llvm::cl::desc("Test folding of extract from collapse_shape"),
+      llvm::cl::init(false)};
+
   Option<bool> useForeach{
       *this, "use-foreach",
       llvm::cl::desc(
@@ -132,6 +137,12 @@ applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) {
   (void)applyPatternsGreedily(rootOp, std::move(patterns));
 }
 
+static void applyFoldExtractFromCollapseShapePatterns(Operation *rootOp) {
+  RewritePatternSet foldPatterns(rootOp->getContext());
+  tensor::populateFoldCollapseExtractPatterns(foldPatterns);
+  (void)applyPatternsGreedily(rootOp, std::move(foldPatterns));
+}
+
 namespace {
 /// Base pattern to rewrite  a `tensor.collapse_shape -> tensor.extract_slice`.
 /// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -380,6 +391,8 @@ void TestTensorTransforms::runOnOperation() {
             applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
       return signalPassFailure();
   }
+  if (testFoldExtractFromCollapseShape)
+    applyFoldExtractFromCollapseShapePatterns(rootOp);
   if (testTrackingListener)
     if (failed(testTrackingListenerReplacements(rootOp)))
       return signalPassFailure();

``````````

</details>


https://github.com/llvm/llvm-project/pull/142458


More information about the Mlir-commits mailing list