[Mlir-commits] [mlir] [mlir] Fold expand of cast (PR #112265)

Ian Wood llvmlistbot at llvm.org
Mon Oct 14 14:29:33 PDT 2024


https://github.com/IanWood1 created https://github.com/llvm/llvm-project/pull/112265

Sink `tensor.cast` ops through `tensor.expand_shape` ops when doing so makes the expand op more static. This allows other ops further down to infer their shapes; a rough before/after sketch follows.
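For illustration, a sketch of the rewrite based on the test added in this patch (assuming the `output_shape` operands fold to constants; the exact IR the canonicalizer prints may differ slightly):

Before:
  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
      : tensor<?x?xf32> into tensor<?x?x?xf32>

After:
  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [10, 1, 10]
      : tensor<10x10xf32> into tensor<10x1x10xf32>
  %1 = tensor.cast %0 : tensor<10x1x10xf32> to tensor<?x?x?xf32>

The new expand_shape consumes the cast's source directly, and its result type is rebuilt from the constant output_shape, so only the trailing cast remains dynamic.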

From 67649194b893a9a017082964d285056f4c6656fa Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Mon, 14 Oct 2024 13:29:42 -0500
Subject: [PATCH] [mlir] Fold expand of cast

Sink tensor.cast ops through tensor.expand_shape ops when it makes the
expand op more static. This allows other ops further down to infer
their shapes.
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   | 31 +++++++++++++++++++++-
 mlir/test/Dialect/Tensor/canonicalize.mlir | 14 ++++++++++
 2 files changed, 44 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4d6c5965c4fcc3..9be647f687e600 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1982,6 +1982,35 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+struct FoldExpandOfCast : public OpRewritePattern<ExpandShapeOp> {
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
+    if (!canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    SmallVector<OpFoldResult> outputOfr =
+        getMixedValues(expandOp.getResultType().getShape(),
+                       expandOp.getOutputShape(), rewriter);
+    std::optional<SmallVector<int64_t>> constantOutputShape =
+        getConstantIntValues(outputOfr);
+    if (!constantOutputShape.has_value()) {
+      return failure();
+    }
+    auto newType = RankedTensorType::get(
+        constantOutputShape.value(), expandOp.getSrcType().getElementType());
+
+    auto newExpand = rewriter.create<ExpandShapeOp>(
+        castOp.getLoc(), newType, castOp.getSource(),
+        expandOp.getReassociationIndices());
+    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
+                                        newExpand.getResult());
+    return success();
+  }
+};
 } // namespace
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -1989,7 +2018,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<
       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
-      FoldReshapeWithConstant<ExpandShapeOp>,
+      FoldExpandOfCast, FoldReshapeWithConstant<ExpandShapeOp>,
       FoldReshapeWithSplat<ExpandShapeOp>,
       FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
       FoldDimOfCollapseShape>(context);
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0aa2d33ef17ed4..1509d26151119d 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2718,3 +2718,17 @@ func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128
   %pack = tensor.pack %arg0 padding_value(%cst : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %arg1 {test_attr} : tensor<?x?x?xf16> -> tensor<128x?x100x16x1xf16>
   return %pack : tensor<128x?x100x16x1xf16>
 }
+
+// -----
+
+func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
+    -> tensor<?x?x?xf32> {
+  %c1 = arith.constant 1 : index 
+  %c10 = arith.constant 10 : index 
+  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL:  func.func @fold_expand_of_cast
+//       CHECK:   tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]


