[Mlir-commits] [mlir] b55f424 - [MLIR] Add canonicalization for `shape.broadcast`

Frederik Gossen <llvmlistbot@llvm.org>
Mon Mar 15 02:11:48 PDT 2021


Author: Frederik Gossen
Date: 2021-03-15T10:11:28+01:00
New Revision: b55f424ffcaca1d639ded583b9dc8ba151d92e2d

URL: https://github.com/llvm/llvm-project/commit/b55f424ffcaca1d639ded583b9dc8ba151d92e2d
DIFF: https://github.com/llvm/llvm-project/commit/b55f424ffcaca1d639ded583b9dc8ba151d92e2d.diff

LOG: [MLIR] Add canonicalization for `shape.broadcast`

Remove redundant operands and fold to the remaining operand if only one is left.

Differential Revision: https://reviews.llvm.org/D98402
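
For illustration, a minimal sketch of the resulting IR rewrites (function
names are hypothetical, not part of the commit):

    // Duplicate operands are removed by the new canonicalization pattern:
    //   shape.broadcast %a, %b, %a  ==>  shape.broadcast %a, %b
    func @dedup(%a : !shape.shape, %b : !shape.shape) -> !shape.shape {
      %0 = shape.broadcast %a, %b, %a : !shape.shape, !shape.shape,
          !shape.shape -> !shape.shape
      return %0 : !shape.shape
    }

    // Once deduplication leaves a single operand, the new fold replaces
    // the op with that operand directly:
    //   shape.broadcast %a  ==>  %a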

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir
    mlir/test/Dialect/Shape/invalid.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index ae14d81b9e45..a17be3834dc2 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -89,6 +89,7 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> {
   ];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
   let verifier = [{ return ::verify(*this); }];
 }
 
@@ -277,10 +278,10 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable",
     };
   }];
 
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 
   let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
-  let verifier = [{ return ::verify(*this); }];
 }
 
 def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 741d065822db..472197c52f4e 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -354,13 +354,16 @@ static LogicalResult verify(AssumingAllOp op) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
-  if (!operands[1])
-    return nullptr;
+  if (operands.size() == 1)
+    return shapes().front();
 
   // TODO: Support folding with more than 2 input shapes
   if (shapes().size() > 2)
     return nullptr;
 
+  if (!operands[1])
+    return nullptr;
+
   auto rhsShape = llvm::to_vector<6>(
       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
   if (rhsShape.empty())
@@ -384,13 +387,40 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
 }
 
 static LogicalResult verify(BroadcastOp op) {
-  // Ensure that AssumingAllOp contains at least one operand
-  if (op.getNumOperands() < 2)
-    return op.emitOpError("required at least 2 input shapes");
-
   return verifyShapeOrExtentTensorOp(op);
 }
 
+namespace {
+template <typename OpTy>
+struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Find unique operands.
+    SmallVector<Value, 2> unique;
+    for (Value v : op.getOperands()) {
+      if (!llvm::is_contained(unique, v))
+        unique.push_back(v);
+    }
+
+    // Reduce op to equivalent with unique operands.
+    if (unique.size() < op.getNumOperands()) {
+      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
+                                        op.getAttrs());
+      return success();
+    }
+
+    return failure();
+  }
+};
+} // namespace
+
+void BroadcastOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ConcatOp
 //===----------------------------------------------------------------------===//
@@ -772,49 +802,18 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
 // IsBroadcastableOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(IsBroadcastableOp op) {
-  // Ensure that AssumingAllOp contains at least one operand
-  if (op.getNumOperands() < 2)
-    return op.emitOpError("required at least 2 input shapes");
-  return success();
+void IsBroadcastableOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
 }
 
-namespace {
-struct IsBroadcastableCanonicalizationPattern
-    : public OpRewritePattern<IsBroadcastableOp> {
-  using OpRewritePattern<IsBroadcastableOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(IsBroadcastableOp op,
-                                PatternRewriter &rewriter) const override {
-    // Find unique operands.
-    SmallVector<Value, 2> unique;
-    for (Value v : op.getOperands()) {
-      if (!llvm::is_contained(unique, v))
-        unique.push_back(v);
-    }
-
-    // Can always broadcast fewer than two shapes.
-    if (unique.size() < 2) {
-      rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op,
-                                                    rewriter.getBoolAttr(true));
-      return success();
-    }
-
-    // Reduce op to equivalent with unique operands.
-    if (unique.size() < op.getNumOperands()) {
-      rewriter.replaceOpWithNewOp<IsBroadcastableOp>(op, rewriter.getI1Type(),
-                                                     unique);
-      return success();
-    }
-
-    return failure();
+OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+  // Can always broadcast fewer than two shapes.
+  if (operands.size() < 2) {
+    return BoolAttr::get(getContext(), true);
   }
-};
-} // namespace
 
-void IsBroadcastableOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<IsBroadcastableCanonicalizationPattern>(context);
+  return nullptr;
 }
 
 //===----------------------------------------------------------------------===//

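A minimal sketch of what the new IsBroadcastableOp fold enables (the
function name is hypothetical): broadcasting fewer than two shapes always
succeeds, so the op folds to a true constant. This is only legal IR because
the verifier requiring at least 2 input shapes is dropped above.

    func @is_broadcastable_on_single_shape(%a : !shape.shape) -> i1 {
      // Folds to a constant true; a single shape is trivially
      // broadcastable with itself.
      %0 = shape.is_broadcastable %a : !shape.shape
      return %0 : i1
    }
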
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 5589221eba82..53f27e4839cf 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1088,9 +1088,34 @@ func @is_broadcastable_on_same_shape(%shape : !shape.shape) -> i1 {
 // CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: !shape.shape)
 func @is_broadcastable_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
     -> i1 {
-  // CHECK: %[[RES:.*]] = shape.is_broadcastable %[[A]], %[[B]]
+  // CHECK: %[[RES:.*]] = shape.is_broadcastable %[[A]], %[[B]] :
   // CHECK: return %[[RES]]
   %0 = shape.is_broadcastable %a, %b, %a, %a, %a, %b : !shape.shape,
       !shape.shape, !shape.shape, !shape.shape, !shape.shape, !shape.shape
   return %0 : i1
 }
+
+// -----
+
+// CHECK-LABEL: @broadcast_on_same_shape
+// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape)
+func @broadcast_on_same_shape(%shape : !shape.shape) -> !shape.shape {
+  // CHECK-NOT: broadcast
+  // CHECK: return %[[SHAPE]]
+  %0 = shape.broadcast %shape, %shape, %shape : !shape.shape, !shape.shape,
+      !shape.shape -> !shape.shape
+  return %0 : !shape.shape
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_on_duplicate_shapes
+// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: !shape.shape)
+func @broadcast_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
+    -> !shape.shape {
+  // CHECK: %[[RES:.*]] = shape.broadcast %[[A]], %[[B]] :
+  // CHECK: return %[[RES]]
+  %0 = shape.broadcast %a, %b, %a, %a, %a, %b : !shape.shape, !shape.shape,
+      !shape.shape, !shape.shape, !shape.shape, !shape.shape -> !shape.shape
+  return %0 : !shape.shape
+}

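The canonicalization tests above are typically driven by a RUN line along
these lines (check the file's actual header for the exact flags):

    // RUN: mlir-opt -split-input-file -canonicalize %s | FileCheck %s
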
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index d685e6766072..d42f0fac4b5c 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -249,18 +249,8 @@ module attributes {shape.lib = @fn} { }
 
 // -----
 
-func @fn(%arg: !shape.shape) -> i1 {
-  // expected-error @+1 {{required at least 2 input shapes}}
-  %0 = shape.is_broadcastable %arg : !shape.shape
-  return %0 : i1
-}
-
-// -----
-
 func @fn(%arg: !shape.shape) -> !shape.witness {
  // expected-error @+1 {{required at least 2 input shapes}}
   %0 = shape.cstr_broadcastable %arg : !shape.shape
   return %0 : !shape.witness
 }
-
-


        

