[Mlir-commits] [mlir] [mlir] Canonicalization pattern for 'shape.shape_of' (PR #98531)

Rafael Ubal llvmlistbot at llvm.org
Sat Jul 13 08:52:37 PDT 2024


================
@@ -1702,18 +1702,28 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
   }
 };
 
-struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
+// Canonicalize
+//
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+//
+// to
+//
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = %shape
+//
+struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!llvm::isa<ShapedType>(op.getArg().getType()))
+    auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
+    if (!tensorReshapeOp)
       return failure();
-    if (llvm::isa<ShapedType>(op.getType()))
+    if (op.getType() != tensorReshapeOp.getShape().getType())
----------------
rafaelubalmw wrote:

Yes, or when `shape.shape_of` returns a `!shape.shape` instead of an extent tensor. I've split these two scenarios so the pattern covers more cases:

- If the return type is `!shape.shape`, do nothing. I considered folding into a `shape.from_extent_tensor` instead, but didn't want to overcomplicate the pattern.

- If the return type is a tensor type that is compatible with, but not identical to, the type of the `tensor.reshape` shape operand, insert a `tensor.cast` from that operand to the result type.
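The two cases above boil down to a small decision on the `shape.shape_of` result type. Below is a hypothetical, self-contained C++ model of that decision, not MLIR code: `ExtentType`, `Rewrite`, `chooseRewrite`, and `kDynamic` are illustrative names only. It assumes both types are rank-1 `index` tensors, so element types are not modelled.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-ins for MLIR types; NOT MLIR APIs. They only exist so
// the decision logic can be checked in isolation.
constexpr int64_t kDynamic = -1;  // models the '?' in tensor<?xindex>

// Models either the shape.shape_of result type or the type of the
// tensor.reshape shape operand (assumed to be a rank-1 index tensor).
struct ExtentType {
  bool isShapeType;  // true models !shape.shape
  int64_t length;    // static length of the rank-1 tensor, or kDynamic
};

enum class Rewrite {
  None,              // leave shape.shape_of untouched
  ReplaceWithShape,  // forward the reshape's shape operand directly
  ReplaceWithCast    // forward the shape operand through a tensor.cast
};

// Same kind of type and, for tensors, the same static length.
constexpr bool identical(ExtentType a, ExtentType b) {
  return a.isShapeType == b.isShapeType &&
         (a.isShapeType || a.length == b.length);
}

// Two rank-1 tensors are cast-compatible if their lengths agree wherever
// both are statically known.
constexpr bool castCompatible(ExtentType a, ExtentType b) {
  return !a.isShapeType && !b.isShapeType &&
         (a.length == kDynamic || b.length == kDynamic || a.length == b.length);
}

constexpr Rewrite chooseRewrite(ExtentType shapeOfResult,
                                ExtentType reshapeShape) {
  if (shapeOfResult.isShapeType)
    return Rewrite::None;  // !shape.shape result: do nothing
  if (identical(shapeOfResult, reshapeShape))
    return Rewrite::ReplaceWithShape;
  if (castCompatible(shapeOfResult, reshapeShape))
    return Rewrite::ReplaceWithCast;
  return Rewrite::None;  // incompatible types: bail out
}

// tensor<?xindex> result, tensor<?xindex> operand: forward %shape as-is.
static_assert(chooseRewrite({false, kDynamic}, {false, kDynamic}) ==
                  Rewrite::ReplaceWithShape, "");
// tensor<3xindex> result, tensor<?xindex> operand: forward through a cast.
static_assert(chooseRewrite({false, 3}, {false, kDynamic}) ==
                  Rewrite::ReplaceWithCast, "");
```

In the real pattern, the `ReplaceWithCast` branch would presumably build a `tensor::CastOp` from `tensorReshapeOp.getShape()` to `op.getType()` and replace `op` with it. The model only captures which branch is taken.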

https://github.com/llvm/llvm-project/pull/98531

