[Mlir-commits] [mlir] [mlir] Canonicalization pattern for 'shape.shape_of' (PR #98531)
Rafael Ubal
llvmlistbot at llvm.org
Sat Jul 13 08:52:37 PDT 2024
================
@@ -1702,18 +1702,28 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
}
};
-struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
+// Canonicalize
+//
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+//
+// to
+//
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = %shape
+//
+struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
PatternRewriter &rewriter) const override {
- if (!llvm::isa<ShapedType>(op.getArg().getType()))
+ auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
+ if (!tensorReshapeOp)
return failure();
- if (llvm::isa<ShapedType>(op.getType()))
+ if (op.getType() != tensorReshapeOp.getShape().getType())
----------------
rafaelubalmw wrote:
Yes, or when `shape.shape_of` uses a `!shape.shape` return type. I've split these two scenarios for broader support.
- If the return type is `!shape.shape`, do nothing. I considered folding the op into a `shape.from_extent_tensor` instead, but I don't want to overcomplicate things.
- If the return type is a compatible but not identical tensor type, introduce a `tensor.cast` op.
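
As a rough illustration of the two scenarios above, the IR might look like the sketch below. The value names (`%input`, `%shape`, `%static_shape`) and the concrete types are assumptions chosen for the example, not taken from the PR's tests, and the exact output of the pattern may differ.

```mlir
// Scenario 1: the result type is !shape.shape, so the pattern does not apply
// and the IR is left unchanged.
%0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
%1 = shape.shape_of %0 : tensor<*xf32> -> !shape.shape

// Scenario 2 (before): the result type tensor<?xindex> is compatible with, but
// not identical to, the type of the reshape's shape operand (tensor<2xindex>).
%2 = tensor.reshape %input(%static_shape) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
%3 = shape.shape_of %2 : tensor<?x?xf32> -> tensor<?xindex>

// Scenario 2 (after): shape.shape_of is replaced by a cast of the shape operand.
%2 = tensor.reshape %input(%static_shape) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
%3 = tensor.cast %static_shape : tensor<2xindex> to tensor<?xindex>
```

When the two types are identical, as in the comment on the pattern in the diff, no cast is needed and `%shape` can replace the result directly.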
https://github.com/llvm/llvm-project/pull/98531