[Mlir-commits] [mlir] cb393f4 - [MLIR][Shape] Canonicalize casted extent tensor operands

Frederik Gossen llvmlistbot at llvm.org
Wed Apr 28 02:52:17 PDT 2021


Author: Frederik Gossen
Date: 2021-04-28T11:51:58+02:00
New Revision: cb393f4c99c1e35d79a1017e59b00014f01daf3e

URL: https://github.com/llvm/llvm-project/commit/cb393f4c99c1e35d79a1017e59b00014f01daf3e
DIFF: https://github.com/llvm/llvm-project/commit/cb393f4c99c1e35d79a1017e59b00014f01daf3e.diff

LOG: [MLIR][Shape] Canonicalize casted extent tensor operands

Both `shape.broadcast` and `shape.cstr_broadcastable` accept dynamic and static
extent tensors. If an operand is the result of a cast that discards static shape
information, we can use the original value instead.
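
For illustration, a minimal before/after sketch of the new pattern (hypothetical
IR, not part of this commit):

```
// Before canonicalization: the cast discards the static extent.
%0 = tensor.cast %arg0 : tensor<3xindex> to tensor<?xindex>
%1 = shape.cstr_broadcastable %0, %arg1 : tensor<?xindex>, tensor<?xindex>

// After canonicalization: the operand is used with its original, more
// precise type.
%1 = shape.cstr_broadcastable %arg0, %arg1 : tensor<3xindex>, tensor<?xindex>
```

Casts in the other direction, e.g. `tensor<?xindex>` to `tensor<3xindex>`, add
shape information and are left in place.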

Differential Revision: https://reviews.llvm.org/D101376

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 47fd322ba47c..fd012aa84d1c 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -628,12 +628,45 @@ struct BroadcastFoldConstantOperandsPattern
     return success();
   }
 };
+
+template <typename OpTy>
+struct CanonicalizeCastExtentTensorOperandsPattern
+    : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Canonicalize operands.
+    bool anyChange = false;
+    auto canonicalizeOperand = [&](Value operand) {
+      if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
+        // Only eliminate the cast if it holds no shape information.
+        bool isInformationLosingCast =
+            castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
+        if (isInformationLosingCast) {
+          anyChange = true;
+          return castOp.source();
+        }
+      }
+      return operand;
+    };
+    auto newOperands = llvm::to_vector<8>(
+        llvm::map_range(op.getOperands(), canonicalizeOperand));
+
+    // Rewrite the op only if any operand changed.
+    if (!anyChange)
+      return failure();
+    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
   patterns.add<BroadcastFoldConstantOperandsPattern,
                BroadcastForwardSingleOperandPattern,
+               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
                RemoveDuplicateOperandsPattern<BroadcastOp>,
                RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
 }
@@ -716,7 +749,8 @@ void CstrBroadcastableOp::getCanonicalizationPatterns(
   // Canonicalization patterns have overlap with the considerations during
   // folding in case additional shape information is inferred at some point that
   // does not result in folding.
-  patterns.add<CstrBroadcastableEqOps,
+  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
+               CstrBroadcastableEqOps,
                RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
                RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
 }
@@ -1188,7 +1222,7 @@ struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
 // ```
 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
 // ```
-struct ShapeOfCastedExtentTensor : public OpRewritePattern<tensor::CastOp> {
+struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(tensor::CastOp op,
@@ -1214,7 +1248,7 @@ struct ShapeOfCastedExtentTensor : public OpRewritePattern<tensor::CastOp> {
 
 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
-  patterns.add<ShapeOfCastedExtentTensor, ShapeOfWithTensor>(context);
+  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 3876e9ca34fc..367ce7f6ba1a 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1115,8 +1115,8 @@ func @fold_div_mixed() -> !shape.size {
 // CHECK-LABEL: @fold_index_cast_on_index
 func @fold_index_cast_on_index(%arg: index) -> index {
   // CHECK-NOT: size_to_index
-  %casted = shape.size_to_index %arg : index
-  return %casted : index
+  %0 = shape.size_to_index %arg : index
+  return %0 : index
 }
 
 // -----
@@ -1125,8 +1125,8 @@ func @fold_index_cast_on_index(%arg: index) -> index {
 // CHECK-LABEL: @fold_to_extent_tensor_on_tensor
 func @fold_to_extent_tensor_on_tensor(%arg: tensor<?xindex>) -> tensor<?xindex> {
   // CHECK-NOT: to_extent_tensor
-  %casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<?xindex>
-  return %casted : tensor<?xindex>
+  %0 = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<?xindex>
+  return %0 : tensor<?xindex>
 }
 
 // -----
@@ -1264,9 +1264,9 @@ func @broadcast_as_from_extent_tensor(%a : tensor<?xindex>) -> !shape.shape {
 
 // -----
 
-// CHECK-LABEL: @casted_extent_tensor
+// CHECK-LABEL: @cast_extent_tensor
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
-func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
+func @cast_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
   // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<?xindex>
   // CHECK: return %[[RESULT]] : tensor<?xindex>
   %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
@@ -1276,9 +1276,9 @@ func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
 
 // -----
 
-// CHECK-LABEL: @casted_extent_tensor
+// CHECK-LABEL: @cast_extent_tensor
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<3xindex>
-func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<3xindex> {
+func @cast_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<3xindex> {
   // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<3xindex>
   // CHECK: return %[[RESULT]] : tensor<3xindex>
   %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
@@ -1288,8 +1288,8 @@ func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<3xindex> {
 
 // -----
 
-// CHECK-LABEL: @casted_extent_tensor
-func @casted_extent_tensor(%arg : tensor<?x?x?x?xf32>) -> tensor<3xindex> {
+// CHECK-LABEL: @cast_extent_tensor
+func @cast_extent_tensor(%arg : tensor<?x?x?x?xf32>) -> tensor<3xindex> {
   // CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
   %0 = shape.shape_of %arg : tensor<?x?x?x?xf32> -> tensor<?xindex>
   %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
@@ -1298,8 +1298,8 @@ func @casted_extent_tensor(%arg : tensor<?x?x?x?xf32>) -> tensor<3xindex> {
 
 // -----
 
-// CHECK-LABEL: @casted_extent_tensor
-func @casted_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> {
+// CHECK-LABEL: @cast_extent_tensor
+func @cast_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> {
   // CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
   %0 = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
   %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
@@ -1335,3 +1335,21 @@ func @cstr_broadcastable_folding(%arg : tensor<?x4xf32>) {
   %2 = shape.cstr_broadcastable %0, %1: tensor<2xindex>, tensor<1xindex>
   "use"(%2) : (!shape.witness) -> ()
 }
+
+// -----
+
+// CHECK-LABEL: @cast_extent_tensor_operands
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<3xindex>)
+func @cast_extent_tensor_operands(%arg0 : tensor<?xindex>,
+    %arg1 : tensor<3xindex>) -> (!shape.witness, tensor<?xindex>) {
+  // CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?xindex> to tensor<3xindex>
+  // CHECK: %[[WIT:.*]] = shape.cstr_broadcastable %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
+  // CHECK: %[[RES:.*]] = shape.broadcast %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
+  // CHECK: return %[[WIT]], %[[RES]]
+  %0 = tensor.cast %arg0 : tensor<?xindex> to tensor<3xindex>
+  %1 = tensor.cast %arg1 : tensor<3xindex> to tensor<?xindex>
+  %2 = shape.cstr_broadcastable %0, %1 : tensor<3xindex>, tensor<?xindex>
+  %3 = shape.broadcast %0, %1 : tensor<3xindex>,
+      -> tensor<?xindex>
+  return %2, %3 : !shape.witness, tensor<?xindex>
+}

More information about the Mlir-commits mailing list