[Mlir-commits] [mlir] 34d8275 - [mlir][tensor] add tensor insert/extract op folders (#142458)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 3 09:16:06 PDT 2025
Author: asraa
Date: 2025-06-03T09:16:03-07:00
New Revision: 34d8275e4fcd619226e2872ea0ee07f8a1634ff7
URL: https://github.com/llvm/llvm-project/commit/34d8275e4fcd619226e2872ea0ee07f8a1634ff7
DIFF: https://github.com/llvm/llvm-project/commit/34d8275e4fcd619226e2872ea0ee07f8a1634ff7.diff
LOG: [mlir][tensor] add tensor insert/extract op folders (#142458)
Adds a few canonicalizers, folders, and rewrite patterns to tensor ops:
* tensor.insert canonicalization: an insert of a constant scalar at constant
indices into a constant tensor is replaced with a new constant
* tensor.extract folder: an extract from a tensor produced by a tensor.insert
at the same indices is folded to the inserted scalar
* rewrite pattern added that replaces an extract of a collapse shape
with an extract of the source tensor (requires static source dimensions)
Signed-off-by: Asra Ali <asraa at google.com>
Added:
mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir
Modified:
mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index eb550bb469b9f..e8e1342ef36fd 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -176,6 +176,10 @@ void populateFoldConstantExtractSlicePatterns(
return false;
});
+/// Patterns to fold extracts of a collapse_shaped tensor to an extract of the
+/// source tensor.
+void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns);
+
} // namespace tensor
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 35d0b16628417..c0885a3763827 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -827,6 +827,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
let hasFolder = 1;
let hasVerifier = 1;
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 30ca20fc0d883..f2a7220b4bedc 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
@@ -33,10 +34,12 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <optional>
+#include <vector>
using namespace mlir;
using namespace mlir::tensor;
@@ -1288,6 +1291,68 @@ struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
}
};
+/// Rewrites a `tensor.extract` whose operand is produced by a
+/// `tensor.collapse_shape` into an extract that reads the source of the
+/// collapse directly. For example
+///
+///   %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
+///                                                tensor<12xf64>
+///   %extracted_element = tensor.extract %val[%c10] :
+///                                                tensor<12xf64>
+///
+/// becomes
+///
+///   %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
+///
+/// Every collapsed index is expanded with `affine.delinearize_index` over the
+/// static sizes of its reassociation group; groups of size one reuse the
+/// original index unchanged. The pattern only applies when the source type of
+/// the collapse has a fully static shape.
+struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const final {
+    auto collapseOp =
+        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapseOp)
+      return failure();
+    auto srcType = collapseOp.getSrcType();
+    if (!srcType.hasStaticShape())
+      return failure();
+
+    Location loc = extractOp.getLoc();
+    auto sourceSizes = srcType.getShape();
+    // Compute the reassociation once; it is needed both for the expansion
+    // loop and for the empty-reassociation check below.
+    auto reassociation = collapseOp.getReassociationIndices();
+
+    SmallVector<Value> sourceIndices;
+    for (auto [index, group] :
+         llvm::zip(extractOp.getIndices(), reassociation)) {
+      assert(!group.empty() && "association indices groups cannot be empty");
+
+      // A group of one dimension maps the collapsed index straight through.
+      if (group.size() == 1) {
+        sourceIndices.push_back(index);
+        continue;
+      }
+
+      // Otherwise split the collapsed index over the sizes of the grouped
+      // source dimensions.
+      SmallVector<int64_t> basis =
+          llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+      auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+          loc, index, basis, /*hasOuterBound=*/true);
+      llvm::append_range(sourceIndices, delinearize.getResults());
+    }
+
+    // With an empty reassociation the extract carries no indices, so every
+    // source dimension is addressed with index zero. A single zero constant
+    // is created and shared by all dimensions.
+    if (reassociation.empty()) {
+      int64_t srcRank = srcType.getRank();
+      if (srcRank > 0) {
+        OpFoldResult zeroAttr = rewriter.getIndexAttr(0);
+        Value zero = getValueOrCreateConstantIndexOp(rewriter, loc, zeroAttr);
+        sourceIndices.append(srcRank, zero);
+      }
+    }
+
+    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+        extractOp, collapseOp.getSrc(), sourceIndices);
+    return success();
+  }
+};
+
} // namespace
void ExtractOp::getAsmResultNames(
@@ -1303,6 +1368,23 @@ LogicalResult ExtractOp::verify() {
return success();
}
+/// Returns the scalar of the InsertOp that directly produces the tensor read
+/// by `extractOp` when both ops use the same indices; returns a null Value
+/// otherwise.
+// TODO: Only the immediate producer is inspected; this could walk further up
+// the insert/extract chain if the slices are disjoint.
+static Value foldExtractAfterInsert(ExtractOp extractOp) {
+  auto producer = extractOp.getTensor().getDefiningOp<InsertOp>();
+  if (!producer)
+    return {};
+  if (producer.getScalar().getType() != extractOp.getType())
+    return {};
+
+  // Indices are compared pairwise as OpFoldResults.
+  auto sameIndex = [](Value lhs, Value rhs) {
+    return getAsOpFoldResult(lhs) == getAsOpFoldResult(rhs);
+  };
+  if (!llvm::equal(producer.getIndices(), extractOp.getIndices(), sameIndex))
+    return {};
+
+  return producer.getScalar();
+}
+
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
if (Attribute tensor = adaptor.getTensor()) {
// If this is a splat elements attribute, simply return the value.
@@ -1350,6 +1432,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
return elementsAttr.getValues<Attribute>()[indices];
}
+ if (Value result = foldExtractAfterInsert(*this))
+ return result;
+
return {};
}
@@ -1358,6 +1443,11 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ExtractFromTensorCast>(context);
}
+/// Adds the pattern that rewrites an extract of a collapse_shape result into
+/// an extract of the collapse's source tensor.
+void mlir::tensor::populateFoldCollapseExtractPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ExtractFromCollapseShape>(context);
+}
+
//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//
@@ -1534,6 +1624,76 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
// InsertOp
//===----------------------------------------------------------------------===//
+namespace {
+
+/// Pattern to fold an insert op of a constant destination and scalar to a new
+/// constant.
+///
+/// Example:
+/// ```
+///   %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+///   %c0 = arith.constant 0 : index
+///   %c4_f32 = arith.constant 4.0 : f32
+///   %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
+/// ```
+/// is rewritten into:
+/// ```
+///   %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+/// ```
+///
+/// The pattern applies only when the destination is defined by an
+/// `arith.constant` holding an `ElementsAttr`, the inserted scalar is a
+/// constant, and every index is a constant integer that is in bounds for the
+/// destination. In all other cases it fails to match and leaves the IR
+/// untouched.
+class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern<InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    // Requires a ranked tensor type.
+    auto destType =
+        llvm::dyn_cast<RankedTensorType>(insertOp.getDest().getType());
+    if (!destType)
+      return failure();
+
+    // Requires constant integer indices. A checked cast is used so that a
+    // constant index that is not an IntegerAttr bails out instead of
+    // asserting.
+    SmallVector<uint64_t, 8> indices;
+    for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
+      auto indiceAttr = dyn_cast<Attribute>(indice);
+      if (!indiceAttr)
+        return failure();
+      auto intAttr = llvm::dyn_cast<IntegerAttr>(indiceAttr);
+      if (!intAttr)
+        return failure();
+      indices.push_back(intAttr.getInt());
+    }
+
+    // Requires a constant scalar to insert.
+    OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar());
+    Attribute scalarAttr = dyn_cast<Attribute>(scalar);
+    if (!scalarAttr)
+      return failure();
+
+    // Requires a constant destination backed by an ElementsAttr.
+    auto constantOp = insertOp.getDest().getDefiningOp<arith::ConstantOp>();
+    if (!constantOp)
+      return failure();
+    auto sourceAttr = llvm::dyn_cast<ElementsAttr>(constantOp.getValue());
+    if (!sourceAttr)
+      return failure();
+
+    // Out-of-bounds constant indices must not reach getFlattenedIndex, which
+    // asserts on invalid indices.
+    if (!sourceAttr.isValidIndex(indices))
+      return failure();
+
+    // Rebuild the element list with the value at the inserted position
+    // replaced by the scalar.
+    auto sourceValues = sourceAttr.getValues<Attribute>();
+    const int64_t numElements = sourceAttr.getNumElements();
+    const int64_t flattenedIndex =
+        static_cast<int64_t>(sourceAttr.getFlattenedIndex(indices));
+    std::vector<Attribute> updatedValues;
+    updatedValues.reserve(numElements);
+    for (int64_t i = 0; i < numElements; ++i) {
+      updatedValues.push_back(i == flattenedIndex ? scalarAttr
+                                                  : sourceValues[i]);
+    }
+
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+        insertOp, sourceAttr.getType(),
+        DenseElementsAttr::get(cast<ShapedType>(sourceAttr.getType()),
+                               updatedValues));
+    return success();
+  }
+};
+
+} // namespace
+
void InsertOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "inserted");
@@ -1557,6 +1717,11 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
return {};
}
+/// Adds the canonicalization patterns of InsertOp to `results`; currently
+/// this is only InsertOpConstantFold.
+void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<InsertOpConstantFold>(context);
+}
+
//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3eaf824b99115..646b2197d9aa6 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -163,7 +163,7 @@ func.func @infer_concat_return_type(%arg0: tensor<5x12xi32>, %arg1: tensor<?x12x
// -----
// CHECK-LABEL: func @fold_extract
-func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
+func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>, i32) {
%const_0 = arith.constant 0 : index
%const_1 = arith.constant 1 : index
%const_3 = arith.constant 3 : index
@@ -193,8 +193,15 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
%4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
%ext_5 = tensor.extract %4[] : tensor<complex<f32>>
- // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
- return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
+ // Fold an extract after an insert.
+ // CHECK-DAG: [[C6:%.+]] = arith.constant 4 : i32
+ %c4_i32 = arith.constant 4 : i32
+ %5 = arith.constant dense<[[1, 3], [0, 2]]> : tensor<2x2xi32>
+ %inserted = tensor.insert %c4_i32 into %5[%const_1, %const_0] : tensor<2x2xi32>
+ %ext_6 = tensor.extract %inserted[%const_1, %const_0] : tensor<2x2xi32>
+
+ // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]], [[C6]]
+ return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5, %ext_6 : f32, f16, f16, i32, complex<f32>, i32
}
// -----
@@ -224,6 +231,22 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
return %ins_1 : tensor<4xf32>
}
+
+// -----
+
+func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
+  // Fold an insert of a constant scalar into a (non-splat) dense constant.
+ // CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
+ // CHECK-LITERAL:
+ // CHECK-NEXT: return %[[C4]]
+ %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4_i32 = arith.constant 4 : i32
+ %inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
+ return %inserted : tensor<2x2xi32>
+}
+
// -----
// CHECK-LABEL: func @extract_from_tensor.cast
diff --git a/mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir
new file mode 100644
index 0000000000000..c301f494a7c87
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-extract-from-collapse-shape %s | FileCheck %s
+
+// CHECK-LABEL: @extract_from_collapse_shape
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x1x8xi8>)
+func.func @extract_from_collapse_shape(%arg0: tensor<1x1x8xi8>) -> (i8, i8) {
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<1x1x8xi8> into tensor<8xi8>
+ %extracted = tensor.extract %collapsed[%c0] : tensor<8xi8>
+ %extracted_0 = tensor.extract %collapsed[%c1] : tensor<8xi8>
+ func.return %extracted, %extracted_0 : i8, i8
+}
+
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[RESULT0:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] : tensor<1x1x8xi8>
+// CHECK-DAG: %[[RESULT1:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[C0]], %[[C1]]] : tensor<1x1x8xi8>
+// CHECK-NEXT: return %[[RESULT0]], %[[RESULT1]] : i8, i8
+
+// -----
+
+// CHECK-LABEL: @extract_from_static_shape
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_from_static_shape(%arg0 : tensor<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<2x6x32xf32> into tensor<12x32xf32>
+ %1 = tensor.extract %0[%arg1, %arg2] : tensor<12x32xf32>
+ return %1 : f32
+}
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = tensor.extract %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : tensor<2x6x32xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index e435130c2a417..0e191c32f009e 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -77,6 +77,11 @@ struct TestTensorTransforms
llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
llvm::cl::init(false)};
+ Option<bool> testFoldExtractFromCollapseShape{
+ *this, "test-fold-extract-from-collapse-shape",
+ llvm::cl::desc("Test folding of extract from collapse_shape"),
+ llvm::cl::init(false)};
+
Option<bool> useForeach{
*this, "use-foreach",
llvm::cl::desc(
@@ -132,6 +137,12 @@ applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) {
(void)applyPatternsGreedily(rootOp, std::move(patterns));
}
+/// Runs the extract-from-collapse_shape folding patterns on `rootOp` with the
+/// greedy pattern driver. The driver's result is deliberately discarded.
+static void applyFoldExtractFromCollapseShapePatterns(Operation *rootOp) {
+  MLIRContext *context = rootOp->getContext();
+  RewritePatternSet patterns(context);
+  tensor::populateFoldCollapseExtractPatterns(patterns);
+  (void)applyPatternsGreedily(rootOp, std::move(patterns));
+}
+
namespace {
/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
/// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -380,6 +391,8 @@ void TestTensorTransforms::runOnOperation() {
applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
return signalPassFailure();
}
+ if (testFoldExtractFromCollapseShape)
+ applyFoldExtractFromCollapseShapePatterns(rootOp);
if (testTrackingListener)
if (failed(testTrackingListenerReplacements(rootOp)))
return signalPassFailure();
More information about the Mlir-commits
mailing list