[Mlir-commits] [mlir] f92c750 - Revert "[mlir][tensor] Fold rank-reducing extract_slice with inverse expand_shape"

Matthias Springer llvmlistbot at llvm.org
Fri Dec 2 12:26:30 PST 2022


Author: Matthias Springer
Date: 2022-12-02T21:22:20+01:00
New Revision: f92c7506e3270dfef516e1ef5cac9a9ef1ff0adb

URL: https://github.com/llvm/llvm-project/commit/f92c7506e3270dfef516e1ef5cac9a9ef1ff0adb
DIFF: https://github.com/llvm/llvm-project/commit/f92c7506e3270dfef516e1ef5cac9a9ef1ff0adb.diff

LOG: Revert "[mlir][tensor] Fold rank-reducing extract_slice with inverse expand_shape"

This reverts commit a076f57a1a6b6d775aa4f11ac678d1c43ab33fb1.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
    mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

Removed: 
    mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
    mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 267972e771688..13ff67e3af433 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -36,10 +36,6 @@ FailureOr<Value> replaceExtractSliceWithTiledProducer(
 void populateMergeConsecutiveInsertExtractSlicePatterns(
     RewritePatternSet &patterns);
 
-/// Populates `patterns` with patterns that fold `tensor.expand_shape` and
-/// `tensor.collapse_shape` into other ops.
-void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
-
 } // namespace tensor
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index 08a0d5a96b91d..75216e79df371 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -3,7 +3,6 @@ add_mlir_dialect_library(MLIRTensorTransforms
   Bufferize.cpp
   ExtractSliceFromReshapeUtils.cpp
   MergeConsecutiveInsertExtractSlicePatterns.cpp
-  ReshapePatterns.cpp
   SplitPaddingPatterns.cpp
   SwapExtractSliceWithProducerPatterns.cpp
 
@@ -27,4 +26,4 @@ add_mlir_dialect_library(MLIRTensorTransforms
   MLIRTensorDialect
   MLIRTilingInterface
   MLIRTransforms
-)
+  )

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
deleted file mode 100644
index c1166c5eb5ec6..0000000000000
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ /dev/null
@@ -1,57 +0,0 @@
-//===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
-#include "mlir/IR/PatternMatch.h"
-#include "llvm/Support/Debug.h"
-
-#define DEBUG_TYPE "mlir-tensor-split-padding"
-
-using namespace mlir;
-using namespace mlir::tensor;
-
-namespace {
-/// Fold expand_shape(extract_slice) ops that cancel itself out.
-struct FoldExpandOfRankReducingExtract
-    : public OpRewritePattern<ExpandShapeOp> {
-  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
-                                PatternRewriter &rewriter) const override {
-    RankedTensorType resultType = expandShapeOp.getResultType();
-    auto extractSliceOp =
-        expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
-    if (!extractSliceOp)
-      return failure();
-    RankedTensorType srcType = extractSliceOp.getSourceType();
-
-    // Only cases where the ExpandShapeOp can be folded away entirely are
-    // supported. Moreover, only simple cases where the resulting ExtractSliceOp
-    // has no rank-reduction anymore are supported at the moment.
-    RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
-        srcType, extractSliceOp.getStaticOffsets(),
-        extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
-    if (nonReducingExtractType != resultType)
-      return failure();
-
-    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
-    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
-    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
-        expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
-        mixedStrides);
-    return success();
-  }
-};
-} // namespace
-
-void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<FoldExpandOfRankReducingExtract>(patterns.getContext());
-}

diff  --git a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
deleted file mode 100644
index c81e531507a28..0000000000000
--- a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
+++ /dev/null
@@ -1,19 +0,0 @@
-// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-reassociative-reshape-folding %s | FileCheck %s
-
-// CHECK-LABEL: func @expand_shape_of_rank_reducing_extract(
-//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?x?xf32>
-//   CHECK-DAG:   %[[extract1:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
-//   CHECK-DAG:   %[[extract2:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
-//       CHECK:   return %[[extract1]], %[[extract2]]
-func.func @expand_shape_of_rank_reducing_extract(
-    %t: tensor<?x?x?x?xf32>, %idx: index)
-  -> (tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>)
-{
-  %0 = tensor.extract_slice %t[0, 0, 0, 0][%idx, 1, 1, 5][1, 1, 1, 1]
-      : tensor<?x?x?x?xf32> to tensor<?x1x5xf32>
-  %1 = tensor.expand_shape %0 [[0], [1, 2], [3]]
-      : tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
-  %2 = tensor.expand_shape %0 [[0, 1], [2], [3]]
-      : tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
-  return %1, %2 : tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>
-}

diff  --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 1802387c3a56a..fa3cdb7887e48 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -65,11 +65,6 @@ struct TestTensorTransforms
                      "with loop nest"),
       llvm::cl::init(false)};
 
-  Option<bool> testReassociativeReshapeFolding{
-      *this, "test-reassociative-reshape-folding",
-      llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
-      llvm::cl::init(false)};
-
   Option<bool> useForeach{
       *this, "use-foreach",
       llvm::cl::desc(
@@ -79,12 +74,6 @@ struct TestTensorTransforms
 };
 } // namespace
 
-static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) {
-  RewritePatternSet patterns(rootOp->getContext());
-  tensor::populateReassociativeReshapeFoldingPatterns(patterns);
-  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
-}
-
 static void applySplitPaddingPatterns(Operation *rootOp) {
   RewritePatternSet patterns(rootOp->getContext());
   tensor::populateSplitPaddingPatterns(patterns);
@@ -262,8 +251,6 @@ void TestTensorTransforms::runOnOperation() {
     applyFoldConstantExtractSlicePatterns(rootOp);
   if (testFoldConsecutiveInsertExtractSlice)
     applyFoldConsecutiveInsertExtractSlicePatterns(rootOp);
-  if (testReassociativeReshapeFolding)
-    applyReassociativeReshapeFoldingPatterns(rootOp);
   if (testRewriteExtractSliceWithTiledCollapseShape) {
     if (failed(
             applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))


        


More information about the Mlir-commits mailing list