[Mlir-commits] [mlir] 1ce70c1 - [MLIR] Canonicalize broadcast operations on single shapes

Frederik Gossen llvmlistbot at llvm.org
Thu Mar 18 01:00:19 PDT 2021


Author: Frederik Gossen
Date: 2021-03-18T08:59:50+01:00
New Revision: 1ce70c15ed3b9c84d6d73abd74f6605bccdf2e7b

URL: https://github.com/llvm/llvm-project/commit/1ce70c15ed3b9c84d6d73abd74f6605bccdf2e7b
DIFF: https://github.com/llvm/llvm-project/commit/1ce70c15ed3b9c84d6d73abd74f6605bccdf2e7b.diff

LOG: [MLIR] Canonicalize broadcast operations on single shapes

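This covers cases that are not folded away because forwarding the single
operand would make the extent tensor type more concrete than the op's result
type (e.g. replacing a `tensor<?xindex>` result with a `tensor<3xindex>` operand).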

Differential Revision: https://reviews.llvm.org/D98782

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index ed8dcfc13549..33719951f3e9 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -414,11 +414,49 @@ struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
     return failure();
   }
 };
+
+/// Canonicalizes a `shape.broadcast` with exactly one operand by forwarding
+/// that operand. Such ops are not always folded away because the operand
+/// type may be more concrete than the result type (e.g. `tensor<3xindex>`
+/// vs. `tensor<?xindex>`). In that case a cast back to the result type is
+/// inserted so that existing users keep receiving the type they expect.
+struct BroadcastForwardSingleOperandPattern
+    : public OpRewritePattern<BroadcastOp> {
+  using OpRewritePattern<BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getNumOperands() != 1)
+      return failure();
+    Value replacement = op.shapes().front();
+
+    // Replacing a value with one of a different type would leave users with
+    // mistyped operands, so bridge the gap with an explicit cast.
+    if (replacement.getType() != op.getType()) {
+      Location loc = op.getLoc();
+      if (op.getType().isa<ShapeType>()) {
+        // Extent tensor operand, `!shape.shape` result.
+        replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
+      } else {
+        // Both sides are extent tensors that differ in their static extent,
+        // e.g. `tensor<3xindex>` vs. `tensor<?xindex>`.
+        assert(!replacement.getType().isa<ShapeType>() &&
+               "expected an extent tensor cast");
+        replacement =
+            rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
+      }
+    }
+
+    rewriter.replaceOp(op, replacement);
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
+  patterns.insert<BroadcastForwardSingleOperandPattern,
+                  RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
 }

 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 53f27e4839cf..3399fe0f4e23 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1119,3 +1119,16 @@ func @broadcast_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
       !shape.shape, !shape.shape, !shape.shape, !shape.shape -> !shape.shape
   return %0 : !shape.shape
 }
+
+// -----
+
+// CHECK-LABEL: @broadcast_on_single_operand
+// CHECK-SAME: (%[[A:.*]]: tensor<3xindex>)
+func @broadcast_on_single_operand(%a : tensor<3xindex>) {
+  // CHECK-NOT: broadcast
+  // CHECK: %[[CAST:.*]] = tensor.cast %[[A]] : tensor<3xindex> to tensor<?xindex>
+  // CHECK: "use"(%[[CAST]])
+  %0 = shape.broadcast %a : tensor<3xindex> -> tensor<?xindex>
+  "use"(%0) : (tensor<?xindex>) -> ()
+  return
+}


        


More information about the Mlir-commits mailing list