[Mlir-commits] [mlir] 26864d8 - [mlir][tensor] Add pattern to drop redundant insert_slice rank expansion

Matthias Springer llvmlistbot at llvm.org
Wed May 31 23:54:58 PDT 2023


Author: Matthias Springer
Date: 2023-06-01T08:47:53+02:00
New Revision: 26864d8fb4c2c2f3f85cc0e1225f8c9596ef0b64

URL: https://github.com/llvm/llvm-project/commit/26864d8fb4c2c2f3f85cc0e1225f8c9596ef0b64
DIFF: https://github.com/llvm/llvm-project/commit/26864d8fb4c2c2f3f85cc0e1225f8c9596ef0b64.diff

LOG: [mlir][tensor] Add pattern to drop redundant insert_slice rank expansion

Drop insert_slice rank expansions if they are directly followed by an inverse rank reduction.

Differential Revision: https://reviews.llvm.org/D151800

Added: 
    mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir

Modified: 
    mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
    mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt
    mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
    mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
    mlir/lib/Dialect/Tensor/Utils/CMakeLists.txt
    mlir/lib/Dialect/Tensor/Utils/Utils.cpp
    mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 9922dc8358acb..fe8f6cc9ff286 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -48,6 +48,11 @@ void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns);
 void populateMergeConsecutiveInsertExtractSlicePatterns(
     RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns that drop redundant tensor.insert_slice
+/// rank expansions.
+void populateDropRedundantInsertSliceRankExpansionPatterns(
+    RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that fold `tensor.expand_shape` and
 /// `tensor.collapse_shape` into other ops.
 void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);

diff  --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index c0f33d15cb518..a037d40f901b0 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -42,6 +42,11 @@ FailureOr<RankedTensorType>
 computeTransposedType(RankedTensorType rankedTensorType,
                       ArrayRef<int64_t> transposeVector);
 
+/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
+/// source tensor or inserts the source tensor into a destination tensor with
+/// the same shape.
+bool isCastLikeInsertSliceOp(InsertSliceOp op);
+
 } // namespace tensor
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt
index ff603c950bb1a..113a29b31d0ac 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt
@@ -13,6 +13,6 @@ add_mlir_dialect_library(MLIRTensorTransformOps
   MLIRSCFDialect
   MLIRTensorDialect
   MLIRTensorTransforms
+  MLIRTensorUtils
   MLIRTransformDialect
-  MLIRValueBoundsOpInterface
 )

diff  --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 92f7dbd5ae95d..9b609a2f55f43 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -12,9 +12,9 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
-#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
@@ -24,29 +24,6 @@ using namespace tensor;
 // TrackingListener
 //===----------------------------------------------------------------------===//
 
-/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
-/// source tensor or inserts the source tensor into a destination tensor with
-/// the same shape.
-static bool isCastLikeInsertSliceOp(InsertSliceOp op) {
-  llvm::SmallBitVector droppedDims = op.getDroppedDims();
-  int64_t srcDim = 0;
-  // Source dims and destination dims (apart from dropped dims) must have the
-  // same size.
-  for (int64_t resultDim = 0; resultDim < op.getDestType().getRank();
-       ++resultDim) {
-    if (droppedDims.test(resultDim)) {
-      continue;
-    }
-    FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
-        op.getSource(), op.getResult(), srcDim, resultDim);
-    if (failed(equalDimSize) || !*equalDimSize)
-      return false;
-    ++srcDim;
-  }
-
-  return true;
-}
-
 Operation *
 tensor::TrackingListener::findReplacementOp(Operation *op,
                                             ValueRange newValues) const {

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index c41e9e9ce6839..083c9c936d4cf 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -29,6 +29,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   MLIRPass
   MLIRSCFDialect
   MLIRTensorDialect
+  MLIRTensorUtils
   MLIRTilingInterface
   MLIRTransforms
   MLIRVectorDialect

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
index 9b8853d123ea8..e32ddf08a769f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
@@ -76,6 +77,63 @@ struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> {
     return success();
   }
 };
+
+/// Drop redundant rank expansion. I.e., rank expansions that are directly
+/// followed by inverse rank reductions. E.g.:
+/// %0 = tensor.insert_slice ... : tensor<5x10xf32> into tensor<1x1x5x10xf32>
+/// %1 = tensor.extract_slice %0[0, 0, 2, 3] [1, 1, 2, 2] [1, 1, 1, 1]
+///     : tensor<1x1x5x10xf32> to tensor<2x2xf32>
+struct DropRedundantInsertSliceRankExpansion
+    : public OpRewritePattern<ExtractSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractSliceOp extractSliceOp,
+                                PatternRewriter &rewriter) const override {
+    // Nothing to do if no dims are dropped; none() checks bits, empty() size.
+    llvm::SmallBitVector droppedDims = extractSliceOp.getDroppedDims();
+    if (droppedDims.none())
+      return failure();
+
+    // Look for tensor.insert_slice op that has an inverse rank expansion.
+    auto insertSliceOp =
+        extractSliceOp.getSource().getDefiningOp<InsertSliceOp>();
+    if (!insertSliceOp)
+      return failure();
+    llvm::SmallBitVector expandedDims = insertSliceOp.getDroppedDims();
+
+    // TODO: Support dropped dims that are a subset of the expanded dims.
+    if (expandedDims != droppedDims)
+      return failure();
+
+    // The tensor.insert_slice may not be redundant if it has multiple users.
+    if (!insertSliceOp->hasOneUse())
+      return failure();
+
+    // Only consider pure rank expansions: no element comes from the dest.
+    if (!isCastLikeInsertSliceOp(insertSliceOp))
+      return failure();
+
+    // Extract directly from the source, skipping the dropped dims.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(extractSliceOp);
+    SmallVector<OpFoldResult> offsets = extractSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = extractSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> strides = extractSliceOp.getMixedStrides();
+    SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
+    for (int64_t i = 0, e = offsets.size(); i < e; ++i) {
+      if (droppedDims.test(i))
+        continue;
+      newOffsets.push_back(offsets[i]);
+      newSizes.push_back(sizes[i]);
+      newStrides.push_back(strides[i]);
+    }
+    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
+        extractSliceOp, /*source=*/insertSliceOp.getSource(), newOffsets,
+        newSizes, newStrides);
+    rewriter.eraseOp(insertSliceOp);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
@@ -85,3 +143,8 @@ void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
                MergeConsecutiveInsertSlice<ParallelInsertSliceOp>>(
       patterns.getContext());
 }
+
+void mlir::tensor::populateDropRedundantInsertSliceRankExpansionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<DropRedundantInsertSliceRankExpansion>(patterns.getContext());
+} // Only DropRedundantInsertSliceRankExpansion is added.

diff  --git a/mlir/lib/Dialect/Tensor/Utils/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Utils/CMakeLists.txt
index b7848b1a44229..6de229b2fe141 100644
--- a/mlir/lib/Dialect/Tensor/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Utils/CMakeLists.txt
@@ -10,4 +10,5 @@ add_mlir_dialect_library(MLIRTensorUtils
   MLIRArithUtils
   MLIRIR
   MLIRTensorDialect
+  MLIRValueBoundsOpInterface
 )

diff  --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 4ecb800caab42..165cf9b0b2f7c 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
 using namespace mlir;
 using namespace mlir::tensor;
@@ -102,3 +103,23 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
       RTTBuilder(rankedTensorType).setShape(transposedShape);
   return transposedTensorType;
 }
+
+bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
+  // Walk the destination dims. Dims dropped by the rank expansion have no
+  // source counterpart and are skipped; every other dim must be provably the
+  // same size as the next source dim.
+  llvm::SmallBitVector dropped = op.getDroppedDims();
+  int64_t destRank = op.getDestType().getRank();
+  int64_t nextSrcDim = 0;
+  for (int64_t destDim = 0; destDim < destRank; ++destDim) {
+    if (dropped.test(destDim))
+      continue;
+    // Treat an inconclusive (failed) comparison like a proven mismatch.
+    FailureOr<bool> sameSize = ValueBoundsConstraintSet::areEqual(
+        op.getSource(), op.getResult(), nextSrcDim++, destDim);
+    if (failed(sameSize) || !*sameSize)
+      return false;
+  }
+  // Every non-dropped dim matched its source dim.
+  return true;
+}

diff  --git a/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir b/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir
new file mode 100644
index 0000000000000..e337fdd932142
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-drop-redundant-insert-slice-rank-expansion %s | FileCheck %s
+
+// CHECK-LABEL: func @test_drop_rank_expansion(
+//  CHECK-SAME:     %[[src:.*]]: tensor<128x480xf32>,
+//       CHECK:   %[[extract:.*]] = tensor.extract_slice %[[src]][0, 0] [123, 456] [1, 1] : tensor<128x480xf32> to tensor<123x456xf32>
+//       CHECK:   return %[[extract]]
+func.func @test_drop_rank_expansion(%src: tensor<128x480xf32>, %dest: tensor<1x1x128x480xf32>) -> tensor<123x456xf32> {
+  %inserted_slice = tensor.insert_slice %src into %dest[0, 0, 0, 0] [1, 1, 128, 480] [1, 1, 1, 1] : tensor<128x480xf32> into tensor<1x1x128x480xf32>
+  %extracted_slice = tensor.extract_slice %inserted_slice[0, 0, 0, 0] [1, 1, 123, 456] [1, 1, 1, 1] : tensor<1x1x128x480xf32> to tensor<123x456xf32>
+  return %extracted_slice : tensor<123x456xf32>
+}
+
+// -----
+
+// The insert_slice has a second user, so it must be kept.
+// CHECK-LABEL: func @test_no_drop_rank_expansion_multiple_uses(
+//       CHECK:   %[[insert:.*]] = tensor.insert_slice
+//       CHECK:   %[[extract:.*]] = tensor.extract_slice %[[insert]]
+//       CHECK:   return %[[extract]], %[[insert]]
+func.func @test_no_drop_rank_expansion_multiple_uses(%src: tensor<128x480xf32>, %dest: tensor<1x1x128x480xf32>) -> (tensor<123x456xf32>, tensor<1x1x128x480xf32>) {
+  %inserted_slice = tensor.insert_slice %src into %dest[0, 0, 0, 0] [1, 1, 128, 480] [1, 1, 1, 1] : tensor<128x480xf32> into tensor<1x1x128x480xf32>
+  %extracted_slice = tensor.extract_slice %inserted_slice[0, 0, 0, 0] [1, 1, 123, 456] [1, 1, 1, 1] : tensor<1x1x128x480xf32> to tensor<123x456xf32>
+  return %extracted_slice, %inserted_slice : tensor<123x456xf32>, tensor<1x1x128x480xf32>
+}

diff  --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index f28c9fda4c8f0..1263550f2e06b 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -62,6 +62,11 @@ struct TestTensorTransforms
                      "with loop nest"),
       llvm::cl::init(false)};
 
+  Option<bool> testDropRedundantInsertSliceRankExpansion{
+      *this, "test-drop-redundant-insert-slice-rank-expansion",
+      llvm::cl::desc("Test dropping redundant insert_slice rank expansions"),
+      llvm::cl::init(false)};
+
   Option<bool> testReassociativeReshapeFolding{
       *this, "test-reassociative-reshape-folding",
       llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
@@ -135,6 +140,13 @@ static void applyFoldConsecutiveInsertExtractSlicePatterns(Operation *rootOp) {
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
+static void applyDropRedundantInsertSliceRankExpansionPatterns(Operation *op) {
+  // The greedy driver's LogicalResult is intentionally discarded.
+  RewritePatternSet set(op->getContext());
+  tensor::populateDropRedundantInsertSliceRankExpansionPatterns(set);
+  (void)applyPatternsAndFoldGreedily(op, std::move(set));
+}
+
 static void applySimplifyPackPatterns(Operation *rootOp) {
   RewritePatternSet patterns(rootOp->getContext());
   tensor::populateSimplifyTensorPack(patterns);
@@ -367,6 +379,8 @@ void TestTensorTransforms::runOnOperation() {
     applyFoldConstantExtractSlicePatterns(rootOp);
   if (testFoldConsecutiveInsertExtractSlice)
     applyFoldConsecutiveInsertExtractSlicePatterns(rootOp);
+  if (testDropRedundantInsertSliceRankExpansion)
+    applyDropRedundantInsertSliceRankExpansionPatterns(rootOp);
   if (testReassociativeReshapeFolding)
     applyReassociativeReshapeFoldingPatterns(rootOp);
   if (testEmptyOpFolding)

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index b9b07b5d705fa..3451adc079566 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5935,6 +5935,7 @@ cc_library(
         ":ArithUtils",
         ":DialectUtils",
         ":TensorDialect",
+        ":ValueBoundsOpInterface",
         "//llvm:Support",
     ],
 )
@@ -5988,6 +5989,7 @@ cc_library(
         ":SCFDialect",
         ":TensorDialect",
         ":TensorPassIncGen",
+        ":TensorUtils",
         ":TilingInterface",
         ":Transforms",
         ":ValueBoundsOpInterface",
@@ -6039,6 +6041,7 @@ cc_library(
         ":TensorDialect",
         ":TensorTransformOpsIncGen",
         ":TensorTransforms",
+        ":TensorUtils",
         ":TransformDialect",
         ":ValueBoundsOpInterface",
         "//llvm:Support",


        


More information about the Mlir-commits mailing list