[Mlir-commits] [mlir] 7d6ef5c - [mlir][tensor] Fold `tensor.cast` into `tensor.collapse_shape` op

Prashant Kumar llvmlistbot at llvm.org
Thu Jul 28 00:41:22 PDT 2022


Author: Gaurav Shukla
Date: 2022-07-28T13:11:43+05:30
New Revision: 7d6ef5caef80a24d170dee0f1fec54f3bc7fd979

URL: https://github.com/llvm/llvm-project/commit/7d6ef5caef80a24d170dee0f1fec54f3bc7fd979
DIFF: https://github.com/llvm/llvm-project/commit/7d6ef5caef80a24d170dee0f1fec54f3bc7fd979.diff

LOG: [mlir][tensor] Fold `tensor.cast` into `tensor.collapse_shape` op

This commit folds a `tensor.cast` op into a `tensor.collapse_shape` op
when the following two conditions are met:
1. the `tensor.collapse_shape` op consumes the result of the `tensor.cast` op.
2. the `tensor.cast` op casts to a more dynamic version of the source tensor.
This is added as a canonicalization pattern for the `tensor.collapse_shape` op.

Signed-Off-By: Gaurav Shukla <gaurav at nod-labs.com>

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D130650

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index a9437634b285d..2d91f45205e6c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -928,6 +928,36 @@ struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
   }
 };
 
+// Fold a tensor.cast to a more dynamic type into a consuming collapse_shape.
+struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
+  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
+                                PatternRewriter &rewriter) const override {
+    auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
+    if (!tensor::canFoldIntoConsumerOp(castOp))
+      return failure();
+    // Recompute the collapsed type from the cast's more static source type.
+    RankedTensorType srcType =
+        castOp.getSource().getType().cast<RankedTensorType>();
+    RankedTensorType newResultType = computeTensorReshapeCollapsedType(
+        srcType, collapseShapeOp.getReassociationMaps());
+    // Same type: rewire operand in place; else build a new op and cast back.
+    if (newResultType == collapseShapeOp.getResultType()) {
+      rewriter.updateRootInPlace(collapseShapeOp, [&]() {
+        collapseShapeOp.getSrcMutable().assign(castOp.getSource());
+      });
+    } else {
+      auto newOp = rewriter.create<CollapseShapeOp>(
+          collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
+          collapseShapeOp.getReassociation());
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(
+          collapseShapeOp, collapseShapeOp.getResultType(), newOp);
+    }
+    return success();
+  }
+};
+
 } // namespace
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -940,10 +970,12 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
-              FoldReshapeWithConstant<CollapseShapeOp>,
-              FoldReshapeWithFromElements<CollapseShapeOp>>(context);
+  // FoldCollapseOfCastOp absorbs a static->dynamic tensor.cast producer.
+  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
+              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
+              FoldReshapeWithConstant<CollapseShapeOp>,
+              FoldReshapeWithFromElements<CollapseShapeOp>,
+              FoldCollapseOfCastOp>(context);
 }
 
 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index d07f3e894e242..1eb1a5d7beca7 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -673,6 +673,20 @@ func.func @compose_expand_of_expand_of_zero_dim(%arg0 : tensor<f32>)
 
 // -----
 
+// CHECK-LABEL: func.func @collapse_of_cast(
+// CHECK-SAME:         %[[IN:.*]]: tensor<8x12x32xf32>) -> tensor<?x32xf32> {
+// CHECK-NEXT:    %[[COLLAPSE:.*]] = tensor.collapse_shape %[[IN]] {{\[}}[0, 1], [2]] : tensor<8x12x32xf32> into tensor<96x32xf32>
+// CHECK-NEXT:    %[[CAST:.*]] = tensor.cast %[[COLLAPSE]] : tensor<96x32xf32> to tensor<?x32xf32>
+// CHECK-NEXT:    return %[[CAST]] : tensor<?x32xf32>
+func.func @collapse_of_cast(%t: tensor<8x12x32xf32>) -> tensor<?x32xf32> {
+  %0 = tensor.cast %t : tensor<8x12x32xf32> to tensor<?x?x?xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
+  %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x32xf32>
+  return %2 : tensor<?x32xf32>
+}
+
+// -----
+
 func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
   %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
       : tensor<12x4xf32> into tensor<3x4x4xf32>


        


More information about the Mlir-commits mailing list