[Mlir-commits] [mlir] 6aab709 - [mlir] Canonicalization and folding of shape.cstr_broadcastable
Tres Popp
llvmlistbot at llvm.org
Fri Jun 5 02:00:58 PDT 2020
Author: Tres Popp
Date: 2020-06-05T11:00:19+02:00
New Revision: 6aab70945915ef1d565f1146734416029549a5a9
URL: https://github.com/llvm/llvm-project/commit/6aab70945915ef1d565f1146734416029549a5a9
DIFF: https://github.com/llvm/llvm-project/commit/6aab70945915ef1d565f1146734416029549a5a9.diff
LOG: [mlir] Canonicalization and folding of shape.cstr_broadcastable
This allows replacing this op with a true witness when both inputs are
const_shapes that are found to be broadcastable, or when both inputs are the
same value.
Differential Revision: https://reviews.llvm.org/D80304
Added:
mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/CMakeLists.txt
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 075050a8c2b1..a05273f2b3f8 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -531,7 +531,7 @@ def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}
-def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> {
+def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
let summary = "Determines if 2 shapes can be successfully broadcasted";
let description = [{
Given 2 input shapes, return a witness specifying if they are broadcastable.
@@ -550,6 +550,9 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> {
let results = (outs Shape_WitnessType:$result);
let assemblyFormat = "$lhs `,` $rhs attr-dict";
+
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
}
def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
diff --git a/mlir/lib/Dialect/Shape/CMakeLists.txt b/mlir/lib/Dialect/Shape/CMakeLists.txt
index 2af3de896568..0a03849722cb 100644
--- a/mlir/lib/Dialect/Shape/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/CMakeLists.txt
@@ -1,3 +1,7 @@
+# Generate the DRR rewrite patterns declared in IR/ShapeCanonicalization.td
+# into IR/ShapeCanonicalization.inc, which IR/Shape.cpp #includes.
+set(LLVM_TARGET_DEFINITIONS IR/ShapeCanonicalization.td)
+mlir_tablegen(IR/ShapeCanonicalization.inc -gen-rewriters)
+add_public_tablegen_target(MLIRShapeCanonicalizationIncGen)
+
add_mlir_dialect_library(MLIRShape
IR/Shape.cpp
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 2b05c4c65bab..3a8831c3d14f 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -18,6 +18,10 @@
using namespace mlir;
using namespace mlir::shape;
+// TableGen-generated (-gen-rewriters) patterns from ShapeCanonicalization.td,
+// e.g. CstrBroadcastableEqOps. Wrapped in an anonymous namespace so the
+// generated pattern classes stay local to this translation unit.
+namespace {
+#include "IR/ShapeCanonicalization.inc"
+} // namespace
+
ShapeDialect::ShapeDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addOperations<
@@ -260,6 +264,32 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser,
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
+//===----------------------------------------------------------------------===//
+// CstrBroadcastableOp
+//===----------------------------------------------------------------------===//
+
+/// Registers the canonicalizations of shape.cstr_broadcastable.
+void CstrBroadcastableOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  // cstr_broadcastable(%x, %x) always holds: the DRR pattern rewrites it to a
+  // passing shape.const_witness.
+  results.insert<CstrBroadcastableEqOps>(context);
+}
+
+/// Folds shape.cstr_broadcastable to a passing witness (`true`) when both
+/// operands are constant shapes that are broadcast-compatible. Returns null
+/// (no fold) otherwise.
+OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+  // Only fold when both shapes are known constants. A checked cast makes an
+  // unexpected constant attribute kind bail out instead of asserting.
+  auto lhsAttr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  auto rhsAttr = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (!lhsAttr || !rhsAttr)
+    return nullptr;
+
+  auto lhsShape = llvm::to_vector<6>(lhsAttr.getValues<int64_t>());
+  auto rhsShape = llvm::to_vector<6>(rhsAttr.getValues<int64_t>());
+  // Only the success/failure of broadcasting matters here; the broadcasted
+  // shape itself is discarded.
+  SmallVector<int64_t, 6> resultShape;
+  if (OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
+    return BoolAttr::get(true, getContext());
+
+  // Because a failing witness result here represents an eventual assertion
+  // failure, we do not replace it with a constant witness.
+  return nullptr;
+}
+
//===----------------------------------------------------------------------===//
// ConstSizeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
new file mode 100644
index 000000000000..9a73a8847779
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -0,0 +1,8 @@
+include "mlir/Dialect/Shape/IR/ShapeOps.td"
+
+// Holds when the two bound operands are the same SSA value.
+def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>;
+
+// Canonicalization patterns.
+// A shape is always broadcastable with itself, so a cstr_broadcastable whose
+// two operands are the same value is replaced by a passing (true) witness.
+def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $lhs, $rhs),
+ (Shape_ConstWitnessOp ConstBoolAttrTrue),
+ [(EqualBinaryOperands $lhs, $rhs)]>;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 0f92c18c25a8..93ce36ae8198 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -267,3 +267,59 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
%1 = shape.any %arg0, %arg1
return %1 : !shape.shape
}
+
+// -----
+// cstr_broadcastable on two constant shapes that are broadcastable ([3, 1] and
+// [1, 5]) folds to a passing witness, leaving the const_shape ops dead.
+// CHECK-LABEL: func @f
+func @f() {
+ // CHECK-NEXT: shape.const_witness true
+ // CHECK-NEXT: consume.witness
+ // CHECK-NEXT: return
+ %cs0 = shape.const_shape [3, 1]
+ %cs1 = shape.const_shape [1, 5]
+ %0 = shape.cstr_broadcastable %cs0, %cs1
+ "consume.witness"(%0) : (!shape.witness) -> ()
+ return
+}
+
+// -----
+// Broadcastable with non-broadcastable constant shapes is always false, but the
+// op is intentionally NOT folded: a failing witness represents an eventual
+// assertion failure, so it is kept rather than replaced with a constant.
+// CHECK-LABEL: func @f
+func @f() {
+ // CHECK-NEXT: shape.const_shape
+ // CHECK-NEXT: shape.const_shape
+ // CHECK-NEXT: shape.cstr_broadcastable
+ // CHECK-NEXT: consume.witness
+ // CHECK-NEXT: return
+ %cs0 = shape.const_shape [1, 3]
+ %cs1 = shape.const_shape [1, 5]
+ %0 = shape.cstr_broadcastable %cs0, %cs1
+ "consume.witness"(%0) : (!shape.witness) -> ()
+ return
+}
+
+// -----
+// Broadcastable without guaranteed broadcastable shapes cannot be removed: one
+// operand is a function argument, not a constant, so the op is kept.
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) {
+ // CHECK-NEXT: shape.const_shape
+ // CHECK-NEXT: shape.cstr_broadcastable
+ // CHECK-NEXT: consume.witness
+ // CHECK-NEXT: return
+ %cs0 = shape.const_shape [1,3]
+ %0 = shape.cstr_broadcastable %arg0, %cs0
+ "consume.witness"(%0) : (!shape.witness) -> ()
+ return
+}
+
+// -----
+// Broadcastable with non-constant but known equal shapes (the same SSA value
+// for both operands) is canonicalized to a passing witness.
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) {
+ // CHECK-NEXT: shape.const_witness true
+ // CHECK-NEXT: consume.witness
+ // CHECK-NEXT: return
+ %0 = shape.cstr_broadcastable %arg0, %arg0
+ "consume.witness"(%0) : (!shape.witness) -> ()
+ return
+}
More information about the Mlir-commits
mailing list