[Mlir-commits] [mlir] [mlir] Canonicalization pattern for 'shape.shape_of' (PR #98531)
Jacques Pienaar
llvmlistbot at llvm.org
Fri Jul 12 18:37:48 PDT 2024
================
@@ -1702,18 +1702,28 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
}
};
-struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
+// Canonicalize
+//
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+//
+// to
+//
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = %shape
+//
+struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
PatternRewriter &rewriter) const override {
- if (!llvm::isa<ShapedType>(op.getArg().getType()))
+ auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
+ if (!tensorReshapeOp)
return failure();
----------------
jpienaar wrote:
While here, let's use rewriter.notifyMatchFailure to make debugging easier.
https://github.com/llvm/llvm-project/pull/98531
More information about the Mlir-commits
mailing list