[Mlir-commits] [mlir] 630afc6 - [MLIR][Shape] Canonicalize casted dynamic extent tensor

Frederik Gossen llvmlistbot at llvm.org
Mon Mar 29 04:59:35 PDT 2021


Author: Frederik Gossen
Date: 2021-03-29T13:59:19+02:00
New Revision: 630afc61a85429c2b0e6dbc9ef08e6013be4ad52

URL: https://github.com/llvm/llvm-project/commit/630afc61a85429c2b0e6dbc9ef08e6013be4ad52
DIFF: https://github.com/llvm/llvm-project/commit/630afc61a85429c2b0e6dbc9ef08e6013be4ad52.diff

LOG: [MLIR][Shape] Canonicalize casted dynamic extent tensor

Differential Revision: https://reviews.llvm.org/D99161

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index a1419322afb3..bb7ed5cf05ce 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -987,11 +987,43 @@ struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
     return success();
   }
 };
+
+// Replace `tensor.cast(shape.shape_of(%arg))` by a shape_of of the cast type:
+// ```
+// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
+// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
+// ```
+// to
+// ```
+// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
+// ```
+struct ShapeOfCastedExtentTensor : public OpRewritePattern<tensor::CastOp> {
+  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::CastOp op,
+                                PatternRewriter &rewriter) const override {
+    auto ty = op.getType().dyn_cast<RankedTensorType>();
+    if (!ty || ty.getRank() != 1)
+      return failure();
+
+    auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
+    if (!shapeOfOp)
+      return failure();
+
+    // Argument must be ranked and its rank must match a static result length.
+    auto argTy = shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
+    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.arg());
+    return success();
+  }
+};
 } // namespace
 
 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
-  patterns.add<ShapeOfWithTensor>(context);
+  patterns.add<ShapeOfCastedExtentTensor, ShapeOfWithTensor>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 39f17e9d253f..b0c12ea0b149 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -648,7 +648,7 @@ func @f() {
   // CHECK: shape.cstr_broadcastable
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
-  %cs0 = shape.const_shape [8, 1] : !shape.shape  
+  %cs0 = shape.const_shape [8, 1] : !shape.shape
   %cs1 = shape.const_shape [1, 8] : !shape.shape
   %cs2 = shape.const_shape [1, -1] : !shape.shape
   %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
@@ -1144,3 +1144,47 @@ func @broadcast_on_single_operand(%a : tensor<3xindex>) {
   "use"(%0) : (tensor<?xindex>) -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: @casted_extent_tensor
+// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
+func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
+  // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<?xindex>
+  // CHECK: return %[[RESULT]] : tensor<?xindex>
+  %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
+  %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
+  return %1 : tensor<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @casted_extent_tensor
+// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<3xindex>
+func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<3xindex> {
+  // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<3xindex>
+  // CHECK: return %[[RESULT]] : tensor<3xindex>
+  %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
+  %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
+  return %1 : tensor<3xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @casted_extent_tensor
+func @casted_extent_tensor(%arg : tensor<?x?x?x?xf32>) -> tensor<3xindex> {
+  // CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
+  %0 = shape.shape_of %arg : tensor<?x?x?x?xf32> -> tensor<?xindex>
+  %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
+  return %1 : tensor<3xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @casted_extent_tensor
+func @casted_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> {
+  // CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
+  %0 = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
+  %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
+  return %1 : tensor<3xindex>
+}


        


More information about the Mlir-commits mailing list