[Mlir-commits] [mlir] 858d488 - [MLIR][Shape] Ensure to preserve op type of `shape.broadcast`

Frederik Gossen llvmlistbot at llvm.org
Mon Apr 26 08:56:20 PDT 2021


Author: Frederik Gossen
Date: 2021-04-26T17:55:39+02:00
New Revision: 858d4885dcc2242910971607bf46428d5b563d95

URL: https://github.com/llvm/llvm-project/commit/858d4885dcc2242910971607bf46428d5b563d95
DIFF: https://github.com/llvm/llvm-project/commit/858d4885dcc2242910971607bf46428d5b563d95.diff

LOG: [MLIR][Shape] Ensure to preserve op type of `shape.broadcast`

Ensure that the correct result type is preserved during folding and canonicalization.
A `shape.broadcast` of a single operand can only be folded away if the operand's
type matches the op's result type.

Differential Revision: https://reviews.llvm.org/D101158

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 1bec838dfbb2..392bcddb920a 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -470,8 +470,12 @@ void AssumingAllOp::build(OpBuilder &b, OperationState &state,
 //===----------------------------------------------------------------------===//
 
 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
-  if (shapes().size() == 1)
+  if (shapes().size() == 1) {
+    // Otherwise, we need a cast which would be a canonicalization, not folding.
+    if (shapes().front().getType() != getType())
+      return nullptr;
     return shapes().front();
+  }
 
   // TODO: Support folding with more than 2 input shapes
   if (shapes().size() > 2)
@@ -556,8 +560,10 @@ struct BroadcastForwardSingleOperandPattern
                                 PatternRewriter &rewriter) const override {
     if (op.getNumOperands() == 1) {
       Value uniqueShapeOperand = op.shapes().front();
-      rewriter.replaceOp(op, uniqueShapeOperand);
-      return success();
+      if (uniqueShapeOperand.getType() == op.getType()) {
+        rewriter.replaceOp(op, uniqueShapeOperand);
+        return success();
+      }
     }
     return failure();
   }

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 9c698c493446..95d0de628b75 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1217,10 +1217,21 @@ func @broadcast_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
 // -----
 
 // CHECK-LABEL: @broadcast_on_single_operand
-// CHECK-SAME: (%[[A:.*]]: tensor<3xindex>)
-func @broadcast_on_single_operand(%a : tensor<3xindex>) {
+// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>)
+func @broadcast_on_single_operand(%a : tensor<?xindex>) {
   // CHECK-NOT: broadcast
   // CHECK: "use"(%[[A]])
+  %0 = shape.broadcast %a : tensor<?xindex> -> tensor<?xindex>
+  "use"(%0) : (tensor<?xindex>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_on_single_operand
+// CHECK-SAME: (%[[A:.*]]: tensor<3xindex>)
+func @broadcast_on_single_operand(%a : tensor<3xindex>) {
+  // CHECK: broadcast %[[A]]
   %0 = shape.broadcast %a : tensor<3xindex> -> tensor<?xindex>
   "use"(%0) : (tensor<?xindex>) -> ()
   return


        


More information about the Mlir-commits mailing list