[Mlir-commits] [mlir] 511ffe1 - Revert "[MLIR][Shape] Concretize broadcast result type if possible"

Frederik Gossen llvmlistbot at llvm.org
Wed Apr 28 08:17:01 PDT 2021


Author: Frederik Gossen
Date: 2021-04-28T17:16:02+02:00
New Revision: 511ffe17edec6010de9c4a6e1ccc6a8d66e043d3

URL: https://github.com/llvm/llvm-project/commit/511ffe17edec6010de9c4a6e1ccc6a8d66e043d3
DIFF: https://github.com/llvm/llvm-project/commit/511ffe17edec6010de9c4a6e1ccc6a8d66e043d3.diff

LOG: Revert "[MLIR][Shape] Concretize broadcast result type if possible"

This reverts commit dca536103592cf1e92aa8316ed23f33d75da25bc.
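
The reverted pattern canonicalized shape.broadcast ops that produce a dynamic extent tensor into a broadcast with a statically ranked result (the maximum operand rank, when all operand ranks are known) followed by a tensor.cast back to the original type. A minimal before/after sketch in MLIR, mirroring the @concretize_broadcast_result_type test removed below (value names are illustrative):

  // Before canonicalization: the result is a dynamic extent tensor.
  %0 = shape.broadcast %arg0, %arg1 : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>

  // After the (now reverted) pattern: the result rank is concretized to the
  // maximum operand rank, and a cast restores the original dynamic type.
  %concrete = shape.broadcast %arg0, %arg1 : tensor<2xindex>, tensor<3xindex> -> tensor<3xindex>
  %0 = tensor.cast %concrete : tensor<3xindex> to tensor<?xindex>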

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/Shape.h
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
index 08c5d5ddbc82..570719eff64d 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -29,8 +29,7 @@ class PatternRewriter;
 namespace shape {
 
 /// Alias type for extent tensors.
-RankedTensorType getExtentTensorType(MLIRContext *ctx,
-                                     int64_t rank = ShapedType::kDynamicSize);
+RankedTensorType getExtentTensorType(MLIRContext *ctx);
 
 // Check if a type is an extent tensor, e.g., tensor<?xindex>.
 bool isExtentTensorType(Type);

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index ac67a62a0aef..fd012aa84d1c 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -27,8 +27,8 @@ namespace {
 #include "ShapeCanonicalization.inc"
 }
 
-RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
-  return RankedTensorType::get({rank}, IndexType::get(ctx));
+RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
+  return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
 }
 
 bool shape::isExtentTensorType(Type type) {
@@ -660,42 +660,11 @@ struct CanonicalizeCastExtentTensorOperandsPattern
     return success();
   }
 };
-
-struct BroadcastConcretizeResultTypePattern
-    : public OpRewritePattern<BroadcastOp> {
-  using OpRewritePattern<BroadcastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(BroadcastOp op,
-                                PatternRewriter &rewriter) const override {
-    // Only concretize dynamic extent tensor result types.
-    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
-    if (!resultTy || !resultTy.isDynamicDim(0))
-      return failure();
-
-    // Infer resulting shape rank if possible.
-    int64_t maxRank = 0;
-    for (Value shape : op.shapes()) {
-      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
-        // Cannot infer resulting shape rank if any operand is dynamically
-        // ranked.
-        if (extentTensorTy.isDynamicDim(0))
-          return failure();
-        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
-      }
-    }
-
-    auto newOp = rewriter.create<BroadcastOp>(
-        op.getLoc(), getExtentTensorType(getContext(), maxRank), op.shapes());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
-    return success();
-  }
-};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<BroadcastConcretizeResultTypePattern,
-               BroadcastFoldConstantOperandsPattern,
+  patterns.add<BroadcastFoldConstantOperandsPattern,
                BroadcastForwardSingleOperandPattern,
                CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
                RemoveDuplicateOperandsPattern<BroadcastOp>,

diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 6e0243839132..367ce7f6ba1a 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1344,8 +1344,7 @@ func @cast_extent_tensor_operands(%arg0 : tensor<?xindex>,
     %arg1 : tensor<3xindex>) -> (!shape.witness, tensor<?xindex>) {
   // CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?xindex> to tensor<3xindex>
   // CHECK: %[[WIT:.*]] = shape.cstr_broadcastable %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
-  // CHECK: %[[UNCAST_RES:.*]] = shape.broadcast %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
-  // CHECK: %[[RES:.*]] = tensor.cast %[[UNCAST_RES]] : tensor<3xindex> to tensor<?xindex>
+  // CHECK: %[[RES:.*]] = shape.broadcast %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
   // CHECK: return %[[WIT]], %[[RES]]
   %0 = tensor.cast %arg0 : tensor<?xindex> to tensor<3xindex>
   %1 = tensor.cast %arg1 : tensor<3xindex> to tensor<?xindex>
@@ -1354,17 +1353,3 @@ func @cast_extent_tensor_operands(%arg0 : tensor<?xindex>,
       -> tensor<?xindex>
   return %2, %3 : !shape.witness, tensor<?xindex>
 }
-
-// -----
-
-// CHECK-LABEL: @concretize_broadcast_result_type
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2xindex>, %[[ARG1:.*]]: tensor<3xindex>)
-func @concretize_broadcast_result_type(%arg0 : tensor<2xindex>,
-    %arg1 : tensor<3xindex>) -> tensor<?xindex> {
-  // CHECK: %[[CONCR:.*]] = shape.broadcast %[[ARG0]], %[[ARG1]] : tensor<2xindex>, tensor<3xindex> -> tensor<3xindex>
-  // CHECK: %[[RES:.*]] = tensor.cast %[[CONCR]] : tensor<3xindex> to tensor<?xindex>
-  // CHECK: return %[[RES]]
-  %0 = shape.broadcast %arg0, %arg1 : tensor<2xindex>, tensor<3xindex>
-      -> tensor<?xindex>
-  return %0 : tensor<?xindex>
-}


        

