[Mlir-commits] [mlir] b975e3b - [MLIR] Add canonicalization for `shape.is_broadcastable`
Frederik Gossen
llvmlistbot at llvm.org
Thu Mar 11 01:10:52 PST 2021
Author: Frederik Gossen
Date: 2021-03-11T10:10:34+01:00
New Revision: b975e3b5aa8c6c8b608302997a3bf0fda06bf8d8
URL: https://github.com/llvm/llvm-project/commit/b975e3b5aa8c6c8b608302997a3bf0fda06bf8d8
DIFF: https://github.com/llvm/llvm-project/commit/b975e3b5aa8c6c8b608302997a3bf0fda06bf8d8.diff
LOG: [MLIR] Add canonicalization for `shape.is_broadcastable`
Canonicalize `is_broadcastable` to constant true if it has fewer than 2 unique shape
operands. Otherwise, eliminate redundant operands.
Differential Revision: https://reviews.llvm.org/D98361
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index a176e6d87673..ae14d81b9e45 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -277,9 +277,10 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable",
};
}];
+ let hasCanonicalizer = 1;
+
let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
let verifier = [{ return ::verify(*this); }];
-
}
def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 719f4bddb58d..741d065822db 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -779,6 +779,44 @@ static LogicalResult verify(IsBroadcastableOp op) {
return success();
}
+namespace {
+/// Canonicalizes `shape.is_broadcastable` by deduplicating its operands.
+///
+/// Repeating a shape operand does not change whether the set of shapes is
+/// broadcastable, so:
+///   * with fewer than two unique operands the op is replaced by a constant
+///     `true`;
+///   * otherwise, if any operand is repeated, the op is rebuilt over the
+///     unique operands only, kept in order of first occurrence.
+struct IsBroadcastableCanonicalizationPattern
+    : public OpRewritePattern<IsBroadcastableOp> {
+  using OpRewritePattern<IsBroadcastableOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IsBroadcastableOp op,
+                                PatternRewriter &rewriter) const override {
+    // Collect the unique operands, preserving first-occurrence order.
+    SmallVector<Value, 2> unique;
+    for (Value v : op.getOperands()) {
+      if (!llvm::is_contained(unique, v))
+        unique.push_back(v);
+    }
+
+    // Can always broadcast fewer than two shapes.
+    if (unique.size() < 2) {
+      rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op,
+                                                    rewriter.getBoolAttr(true));
+      return success();
+    }
+
+    // Reduce the op to an equivalent one over the unique operands. Forward the
+    // original result types and attributes so that attributes attached to the
+    // op are not silently dropped by the rewrite.
+    if (unique.size() < op.getNumOperands()) {
+      rewriter.replaceOpWithNewOp<IsBroadcastableOp>(
+          op, op->getResultTypes(), unique, op->getAttrs());
+      return success();
+    }
+
+    // All operands are already unique; nothing to canonicalize.
+    return failure();
+  }
+};
+} // namespace
+
+/// Registers the canonicalization patterns for `shape.is_broadcastable`.
+void IsBroadcastableOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  // A single pattern handles both folding to `true` and operand dedup.
+  patterns.insert<IsBroadcastableCanonicalizationPattern>(context);
+}
+
//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 5ee495d66f18..5589221eba82 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1069,3 +1069,28 @@ func @fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xind
%1 = tensor.cast %0 : tensor<1xindex> to tensor<?xindex>
return %1 : tensor<?xindex>
}
+
+// -----
+
+// A single shape repeated several times has fewer than two unique operands,
+// so the op is expected to be replaced by a constant `true`.
+// CHECK-LABEL: @is_broadcastable_on_same_shape
+func @is_broadcastable_on_same_shape(%shape : !shape.shape) -> i1 {
+ // CHECK-NOT: is_broadcastable
+ // CHECK: %[[RES:.*]] = constant true
+ // CHECK: return %[[RES]]
+ %0 = shape.is_broadcastable %shape, %shape, %shape
+ : !shape.shape, !shape.shape, !shape.shape
+ return %0 : i1
+}
+
+// -----
+
+// Repeated operands are dropped; the remaining unique shapes are expected in
+// order of first occurrence (%a, then %b).
+// CHECK-LABEL: @is_broadcastable_on_duplicate_shapes
+// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: !shape.shape)
+func @is_broadcastable_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
+ -> i1 {
+ // CHECK: %[[RES:.*]] = shape.is_broadcastable %[[A]], %[[B]]
+ // CHECK: return %[[RES]]
+ %0 = shape.is_broadcastable %a, %b, %a, %a, %a, %b : !shape.shape,
+ !shape.shape, !shape.shape, !shape.shape, !shape.shape, !shape.shape
+ return %0 : i1
+}
More information about the Mlir-commits
mailing list