[Mlir-commits] [mlir] 26864d8 - [mlir][tensor] Add pattern to drop redundant insert_slice rank expansion
Matthias Springer
llvmlistbot at llvm.org
Wed May 31 23:54:58 PDT 2023
Author: Matthias Springer
Date: 2023-06-01T08:47:53+02:00
New Revision: 26864d8fb4c2c2f3f85cc0e1225f8c9596ef0b64
URL: https://github.com/llvm/llvm-project/commit/26864d8fb4c2c2f3f85cc0e1225f8c9596ef0b64
DIFF: https://github.com/llvm/llvm-project/commit/26864d8fb4c2c2f3f85cc0e1225f8c9596ef0b64.diff
LOG: [mlir][tensor] Add pattern to drop redundant insert_slice rank expansion
Drop insert_slice rank expansions if they are directly followed by an inverse rank reduction.
Differential Revision: https://reviews.llvm.org/D151800
Added:
mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir
Modified:
mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt
mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
mlir/lib/Dialect/Tensor/Utils/CMakeLists.txt
mlir/lib/Dialect/Tensor/Utils/Utils.cpp
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 9922dc8358acb..fe8f6cc9ff286 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -48,6 +48,11 @@ void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns);
void populateMergeConsecutiveInsertExtractSlicePatterns(
RewritePatternSet &patterns);
+/// Populates `patterns` with patterns that drop redundant tensor.insert_slice
+/// rank expansions, i.e., a rank-expanding tensor.insert_slice that takes no
+/// elements from its destination and whose only user is a
+/// tensor.extract_slice that drops exactly the expanded dims again. The
+/// tensor.extract_slice is rewritten to read directly from the source of the
+/// tensor.insert_slice.
+void populateDropRedundantInsertSliceRankExpansionPatterns(
+ RewritePatternSet &patterns);
+
/// Populates `patterns` with patterns that fold `tensor.expand_shape` and
/// `tensor.collapse_shape` into other ops.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index c0f33d15cb518..a037d40f901b0 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -42,6 +42,11 @@ FailureOr<RankedTensorType>
computeTransposedType(RankedTensorType rankedTensorType,
ArrayRef<int64_t> transposeVector);
+/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
+/// source tensor or inserts the source tensor into a destination tensor with
+/// the same shape.
+///
+/// Source and destination dimension sizes are compared with
+/// ValueBoundsConstraintSet. If the equality of a source/destination dimension
+/// pair cannot be proven, the op is conservatively reported as not cast-like
+/// (returns false).
+bool isCastLikeInsertSliceOp(InsertSliceOp op);
+
} // namespace tensor
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt
index ff603c950bb1a..113a29b31d0ac 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt
@@ -13,6 +13,6 @@ add_mlir_dialect_library(MLIRTensorTransformOps
MLIRSCFDialect
MLIRTensorDialect
MLIRTensorTransforms
+ MLIRTensorUtils
MLIRTransformDialect
- MLIRValueBoundsOpInterface
)
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 92f7dbd5ae95d..9b609a2f55f43 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -12,9 +12,9 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
-#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
@@ -24,29 +24,6 @@ using namespace tensor;
// TrackingListener
//===----------------------------------------------------------------------===//
-/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
-/// source tensor or inserts the source tensor into a destination tensor with
-/// the same shape.
-static bool isCastLikeInsertSliceOp(InsertSliceOp op) {
- llvm::SmallBitVector droppedDims = op.getDroppedDims();
- int64_t srcDim = 0;
- // Source dims and destination dims (apart from dropped dims) must have the
- // same size.
- for (int64_t resultDim = 0; resultDim < op.getDestType().getRank();
- ++resultDim) {
- if (droppedDims.test(resultDim)) {
- continue;
- }
- FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
- op.getSource(), op.getResult(), srcDim, resultDim);
- if (failed(equalDimSize) || !*equalDimSize)
- return false;
- ++srcDim;
- }
-
- return true;
-}
-
Operation *
tensor::TrackingListener::findReplacementOp(Operation *op,
ValueRange newValues) const {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index c41e9e9ce6839..083c9c936d4cf 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -29,6 +29,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
MLIRPass
MLIRSCFDialect
MLIRTensorDialect
+ MLIRTensorUtils
MLIRTilingInterface
MLIRTransforms
MLIRVectorDialect
diff --git a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
index 9b8853d123ea8..e32ddf08a769f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
@@ -76,6 +77,63 @@ struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> {
return success();
}
};
+
+/// Drop redundant rank expansions, i.e., tensor.insert_slice rank expansions
+/// that are directly followed by a tensor.extract_slice rank reduction that
+/// drops exactly the same dims. E.g.:
+/// %0 = tensor.insert_slice ... : tensor<5x10xf32> into tensor<1x1x5x10xf32>
+/// %1 = tensor.extract_slice %0[0, 0, 2, 3] [1, 1, 2, 2] [1, 1, 1, 1]
+///     : tensor<1x1x5x10xf32> to tensor<2x2xf32>
+/// The tensor.extract_slice is rewritten to read from the insert_slice source
+/// using only the offsets/sizes/strides of the non-dropped dims (here:
+/// [2, 3] [2, 2] [1, 1] : tensor<5x10xf32> to tensor<2x2xf32>), and the now
+/// dead tensor.insert_slice is erased.
+struct DropRedundantInsertSliceRankExpansion
+    : public OpRewritePattern<ExtractSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractSliceOp extractSliceOp,
+                                PatternRewriter &rewriter) const override {
+    // Nothing to do if no dims are dropped. Note: `empty()` would only check
+    // whether the bit vector has size zero, not whether any bit is set.
+    llvm::SmallBitVector droppedDims = extractSliceOp.getDroppedDims();
+    if (droppedDims.none())
+      return failure();
+
+    // Look for a tensor.insert_slice op that performs the inverse rank
+    // expansion.
+    auto insertSliceOp =
+        extractSliceOp.getSource().getDefiningOp<InsertSliceOp>();
+    if (!insertSliceOp)
+      return failure();
+    llvm::SmallBitVector expandedDims = insertSliceOp.getDroppedDims();
+
+    // TODO: This could be extended to support cases where the dropped dims are
+    // a subset of the expanded dims.
+    if (expandedDims != droppedDims)
+      return failure();
+
+    // The tensor.insert_slice may not be redundant if it has multiple users.
+    if (!insertSliceOp->hasOneUse())
+      return failure();
+
+    // Only consider tensor.insert_slice ops that are pure rank expansions.
+    // I.e., no elements are taken from the destination.
+    if (!isCastLikeInsertSliceOp(insertSliceOp))
+      return failure();
+
+    // Keep the offsets/sizes/strides of all dims that are not dropped. The
+    // mixed lists are materialized once instead of once per loop iteration.
+    SmallVector<OpFoldResult> offsets = extractSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = extractSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> strides = extractSliceOp.getMixedStrides();
+    SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
+    for (int64_t i = 0, e = extractSliceOp.getSourceType().getRank(); i < e;
+         ++i) {
+      if (droppedDims.test(i))
+        continue;
+      newOffsets.push_back(offsets[i]);
+      newSizes.push_back(sizes[i]);
+      newStrides.push_back(strides[i]);
+    }
+
+    // Extract directly from the source.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(extractSliceOp);
+    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
+        extractSliceOp, /*source=*/insertSliceOp.getSource(), newOffsets,
+        newSizes, newStrides);
+    // The only user of the tensor.insert_slice (checked above) was just
+    // replaced, so the op is dead.
+    rewriter.eraseOp(insertSliceOp);
+    return success();
+  }
+};
} // namespace
void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
@@ -85,3 +143,8 @@ void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
MergeConsecutiveInsertSlice<ParallelInsertSliceOp>>(
patterns.getContext());
}
+
+/// Registers the DropRedundantInsertSliceRankExpansion pattern, bound to the
+/// context of `patterns`, in `patterns`.
+void mlir::tensor::populateDropRedundantInsertSliceRankExpansionPatterns(
+    RewritePatternSet &patterns) {
+  auto *context = patterns.getContext();
+  patterns.add<DropRedundantInsertSliceRankExpansion>(context);
+}
diff --git a/mlir/lib/Dialect/Tensor/Utils/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Utils/CMakeLists.txt
index b7848b1a44229..6de229b2fe141 100644
--- a/mlir/lib/Dialect/Tensor/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Utils/CMakeLists.txt
@@ -10,4 +10,5 @@ add_mlir_dialect_library(MLIRTensorUtils
MLIRArithUtils
MLIRIR
MLIRTensorDialect
+ MLIRValueBoundsOpInterface
)
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 4ecb800caab42..165cf9b0b2f7c 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
using namespace mlir;
using namespace mlir::tensor;
@@ -102,3 +103,23 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
RTTBuilder(rankedTensorType).setShape(transposedShape);
return transposedTensorType;
}
+
+bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
+  llvm::SmallBitVector droppedDims = op.getDroppedDims();
+  RankedTensorType destType = op.getDestType();
+  int64_t srcDim = 0;
+  // Source dims and destination dims (apart from dropped dims) must have the
+  // same size.
+  for (int64_t resultDim = 0, e = destType.getRank(); resultDim < e;
+       ++resultDim) {
+    if (droppedDims.test(resultDim)) {
+      // A dropped (rank-expanded) dim is inserted as a size-1 slice. Unless
+      // the destination dim itself has size 1 (static), the remaining
+      // elements along that dim still come from the destination, so the op is
+      // not cast-like.
+      if (destType.getDimSize(resultDim) != 1)
+        return false;
+      continue;
+    }
+    // Conservatively treat unprovable equality as "not equal".
+    FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
+        op.getSource(), op.getResult(), srcDim, resultDim);
+    if (failed(equalDimSize) || !*equalDimSize)
+      return false;
+    ++srcDim;
+  }
+
+  return true;
+}
diff --git a/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir b/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir
new file mode 100644
index 0000000000000..e337fdd932142
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-drop-redundant-insert-slice-rank-expansion %s | FileCheck %s
+
+// Rank expansion immediately undone by the extract_slice: read from %src.
+// CHECK-LABEL: func @test_drop_rank_expansion(
+// CHECK-SAME: %[[src:.*]]: tensor<128x480xf32>,
+// CHECK: %[[extract:.*]] = tensor.extract_slice %[[src]][0, 0] [123, 456] [1, 1] : tensor<128x480xf32> to tensor<123x456xf32>
+// CHECK: return %[[extract]]
+func.func @test_drop_rank_expansion(%src: tensor<128x480xf32>, %dest: tensor<1x1x128x480xf32>) -> tensor<123x456xf32> {
+  %inserted_slice = tensor.insert_slice %src into %dest[0, 0, 0, 0] [1, 1, 128, 480] [1, 1, 1, 1] : tensor<128x480xf32> into tensor<1x1x128x480xf32>
+  %extracted_slice = tensor.extract_slice %inserted_slice[0, 0, 0, 0] [1, 1, 123, 456] [1, 1, 1, 1] : tensor<1x1x128x480xf32> to tensor<123x456xf32>
+  return %extracted_slice : tensor<123x456xf32>
+}
+
+// -----
+
+// The rank expansion must be kept if the tensor.insert_slice has other users.
+// CHECK-LABEL: func @test_drop_rank_expansion_multiple_uses(
+// CHECK: %[[inserted:.*]] = tensor.insert_slice
+// CHECK: %[[extract:.*]] = tensor.extract_slice %[[inserted]]
+// CHECK: return %[[extract]], %[[inserted]]
+func.func @test_drop_rank_expansion_multiple_uses(%src: tensor<128x480xf32>, %dest: tensor<1x1x128x480xf32>) -> (tensor<123x456xf32>, tensor<1x1x128x480xf32>) {
+  %inserted_slice = tensor.insert_slice %src into %dest[0, 0, 0, 0] [1, 1, 128, 480] [1, 1, 1, 1] : tensor<128x480xf32> into tensor<1x1x128x480xf32>
+  %extracted_slice = tensor.extract_slice %inserted_slice[0, 0, 0, 0] [1, 1, 123, 456] [1, 1, 1, 1] : tensor<1x1x128x480xf32> to tensor<123x456xf32>
+  return %extracted_slice, %inserted_slice : tensor<123x456xf32>, tensor<1x1x128x480xf32>
+}
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index f28c9fda4c8f0..1263550f2e06b 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -62,6 +62,11 @@ struct TestTensorTransforms
"with loop nest"),
llvm::cl::init(false)};
+ /// When set (default: false), runOnOperation applies the patterns from
+ /// populateDropRedundantInsertSliceRankExpansionPatterns to the root op.
+ Option<bool> testDropRedundantInsertSliceRankExpansion{
+ *this, "test-drop-redundant-insert-slice-rank-expansion",
+ llvm::cl::desc("Test dropping redundant insert_slice rank expansions"),
+ llvm::cl::init(false)};
+
Option<bool> testReassociativeReshapeFolding{
*this, "test-reassociative-reshape-folding",
llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
@@ -135,6 +140,13 @@ static void applyFoldConsecutiveInsertExtractSlicePatterns(Operation *rootOp) {
(void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}
+/// Greedily applies only the "drop redundant insert_slice rank expansion"
+/// patterns (plus op folders) to `rootOp` and the ops nested in it.
+static void
+applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) {
+  auto *context = rootOp->getContext();
+  RewritePatternSet rankExpansionPatterns(context);
+  tensor::populateDropRedundantInsertSliceRankExpansionPatterns(
+      rankExpansionPatterns);
+  // The result is deliberately ignored, like in the sibling apply* helpers.
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(rankExpansionPatterns));
+}
+
static void applySimplifyPackPatterns(Operation *rootOp) {
RewritePatternSet patterns(rootOp->getContext());
tensor::populateSimplifyTensorPack(patterns);
@@ -367,6 +379,8 @@ void TestTensorTransforms::runOnOperation() {
applyFoldConstantExtractSlicePatterns(rootOp);
if (testFoldConsecutiveInsertExtractSlice)
applyFoldConsecutiveInsertExtractSlicePatterns(rootOp);
+ if (testDropRedundantInsertSliceRankExpansion)
+ applyDropRedundantInsertSliceRankExpansionPatterns(rootOp);
if (testReassociativeReshapeFolding)
applyReassociativeReshapeFoldingPatterns(rootOp);
if (testEmptyOpFolding)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index b9b07b5d705fa..3451adc079566 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5935,6 +5935,7 @@ cc_library(
":ArithUtils",
":DialectUtils",
":TensorDialect",
+ ":ValueBoundsOpInterface",
"//llvm:Support",
],
)
@@ -5988,6 +5989,7 @@ cc_library(
":SCFDialect",
":TensorDialect",
":TensorPassIncGen",
+ ":TensorUtils",
":TilingInterface",
":Transforms",
":ValueBoundsOpInterface",
@@ -6039,6 +6041,7 @@ cc_library(
":TensorDialect",
":TensorTransformOpsIncGen",
":TensorTransforms",
+ ":TensorUtils",
":TransformDialect",
":ValueBoundsOpInterface",
"//llvm:Support",
More information about the Mlir-commits
mailing list