[Mlir-commits] [mlir] 0a554e6 - [mlir] Folding and canonicalization of shape.cstr_eq
Tres Popp
llvmlistbot at llvm.org
Fri Jun 5 02:01:00 PDT 2020
Author: Tres Popp
Date: 2020-06-05T11:00:20+02:00
New Revision: 0a554e607ff6247b79d1c4f184999750e5ad53b9
URL: https://github.com/llvm/llvm-project/commit/0a554e607ff6247b79d1c4f184999750e5ad53b9
DIFF: https://github.com/llvm/llvm-project/commit/0a554e607ff6247b79d1c4f184999750e5ad53b9.diff
LOG: [mlir] Folding and canonicalization of shape.cstr_eq
When all inputs are constant and equal, or are all the same SSA value,
cstr_eq is replaced with a passing witness (shape.const_witness true).
Differential Revision: https://reviews.llvm.org/D80303
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index a05273f2b3f8..6fb7cbf1f3b7 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -555,7 +555,7 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
let hasFolder = 1;
}
-def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
+// Commutative: whether all input shapes are the same does not depend on the
+// order of the operands.
+def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> {
let summary = "Determines if all input shapes are equal";
let description = [{
Given 1 or more input shapes, determine if all shapes are the exact same.
@@ -572,6 +572,9 @@ def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
let results = (outs Shape_WitnessType:$result);
let assemblyFormat = "$inputs attr-dict";
+
+ // Both hooks only ever produce a passing witness. The folder fires when all
+ // operands are equal constants. The canonicalizer fires when all operands
+ // are the same SSA value.
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
}
def Shape_ConstWitnessOp : Shape_Op<"const_witness", [ConstantLike, NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 3a8831c3d14f..e12e23ba128c 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -290,6 +290,27 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// CstrEqOp
+//===----------------------------------------------------------------------===//
+
+/// Registers the canonicalization patterns for shape.cstr_eq.
+void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+ MLIRContext *context) {
+ // CstrEqEqOps (ShapeCanonicalization.td) replaces the op with a passing
+ // witness (shape.const_witness true) when every operand is the same SSA
+ // value. Unlike the folder, this does not require constant inputs.
+ patterns.insert<CstrEqEqOps>(context);
+}
+
+/// Folds shape.cstr_eq to a passing witness when every operand has a known
+/// constant value and all of those values compare equal to the first one.
+/// Returns nullptr (no fold) otherwise.
+OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
+ // A null Attribute marks an operand without a known constant value. The
+ // `a &&` check therefore makes any such operand block the fold.
+ // NOTE(review): the true BoolAttr is presumably materialized as
+ // `shape.const_witness true` by the dialect's constant materializer (the
+ // tests expect that op), confirm.
+ if (llvm::all_of(operands,
+ [&](Attribute a) { return a && a == operands[0]; }))
+ return BoolAttr::get(true, getContext());
+
+ // Because a failing witness result here represents an eventual assertion
+ // failure, we do not try to replace it with a constant witness. Similarly, we
+ // cannot if there are any non-const inputs.
+ return nullptr;
+}
+
//===----------------------------------------------------------------------===//
// ConstSizeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index 9a73a8847779..78c9119f1292 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -2,7 +2,17 @@ include "mlir/Dialect/Shape/IR/ShapeOps.td"
def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>;
+// Holds when every value in the variadic operand list $0 is the same SSA value
+// as its first element. Values are compared by identity (`==` on
+// mlir::Value), not by the shapes they hold.
+def AllInputShapesEq : Constraint<CPred< [{
+ llvm::all_of($0, [&](mlir::Value val) {
+ return $0[0] == val;
+ })
+}]>>;
+
// Canonicalization patterns.
def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $lhs, $rhs),
(Shape_ConstWitnessOp ConstBoolAttrTrue),
[(EqualBinaryOperands $lhs, $rhs)]>;
+
+// Replaces shape.cstr_eq with a passing witness (shape.const_witness true)
+// when all of its operands are the same SSA value.
+def CstrEqEqOps : Pat<(Shape_CstrEqOp:$op $shapes),
+ (Shape_ConstWitnessOp ConstBoolAttrTrue),
+ [(AllInputShapesEq $shapes)]>;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 93ce36ae8198..32fa496e7347 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -213,6 +213,62 @@ func @not_const(%arg0: !shape.shape) -> !shape.size {
return %0 : !shape.size
}
+
+// -----
+// cstr_eq whose operands are all the same non-constant SSA value is
+// canonicalized to a passing witness.
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) {
+ // CHECK-NEXT: shape.const_witness true
+ // CHECK-NEXT: consume.witness
+ // CHECK-NEXT: return
+ %0 = shape.cstr_eq %arg0, %arg0, %arg0
+ "consume.witness"(%0) : (!shape.witness) -> ()
+ return
+}
+
+// -----
+// cstr_eq whose operands are distinct but equal const_shapes is folded to a
+// passing witness.
+// CHECK-LABEL: func @f
+func @f() {
+ // CHECK-NEXT: shape.const_witness true
+ // CHECK-NEXT: consume.witness
+ // CHECK-NEXT: return
+ %cs0 = shape.const_shape [0, 1]
+ %cs1 = shape.const_shape [0, 1]
+ %cs2 = shape.const_shape [0, 1]
+ %0 = shape.cstr_eq %cs0, %cs1, %cs2
+ "consume.witness"(%0) : (!shape.witness) -> ()
+ return
+}
+
+// -----
+// cstr_eq with unequal const_shapes is left unfolded: a failing witness is
+// never replaced with a constant witness.
+// CHECK-LABEL: func @f
+func @f() {
+ // CHECK-NEXT: shape.const_shape
+ // CHECK-NEXT: shape.const_shape
+ // CHECK-NEXT: shape.cstr_eq
+ // CHECK-NEXT: consume.witness
+ // CHECK-NEXT: return
+ %cs0 = shape.const_shape [0, 1]
+ %cs1 = shape.const_shape [3, 1]
+ %0 = shape.cstr_eq %cs0, %cs1
+ "consume.witness"(%0) : (!shape.witness) -> ()
+ return
+}
+
+// -----
+// cstr_eq over distinct non-constant shapes is neither folded nor
+// canonicalized.
+// CHECK-LABEL: func @f
+func @f(%arg0: !shape.shape, %arg1: !shape.shape) {
+ // CHECK-NEXT: shape.cstr_eq
+ // CHECK-NEXT: consume.witness
+ // CHECK-NEXT: return
+ %0 = shape.cstr_eq %arg0, %arg1
+ "consume.witness"(%0) : (!shape.witness) -> ()
+ return
+}
+
// -----
// assuming_all with known passing witnesses can be folded
// CHECK-LABEL: func @f
More information about the Mlir-commits mailing list