[Mlir-commits] [mlir] d216f98 - Revert "Revert "[mlir] Canonicalization and folding of shape.cstr_broadcastable""

Tres Popp llvmlistbot at llvm.org
Mon Jun 8 01:07:54 PDT 2020


Author: Tres Popp
Date: 2020-06-08T10:06:55+02:00
New Revision: d216f983e61980c26a6945280befb588ca5e0755

URL: https://github.com/llvm/llvm-project/commit/d216f983e61980c26a6945280befb588ca5e0755
DIFF: https://github.com/llvm/llvm-project/commit/d216f983e61980c26a6945280befb588ca5e0755.diff

LOG: Revert "Revert "[mlir] Canonicalization and folding of shape.cstr_broadcastable""

This reverts commit 4261b026ad5b97231be25f28fe2b0f8a84d82d13.

Added: 
    mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/CMakeLists.txt
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 40fde4837407..b9c63e3e48e8 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -546,7 +546,7 @@ def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
   let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
 }
 
-def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> {
+def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
   let summary = "Determines if 2 shapes can be successfully broadcasted";
   let description = [{
     Given 2 input shapes, return a witness specifying if they are broadcastable.
@@ -565,6 +565,9 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> {
   let results = (outs Shape_WitnessType:$result);
 
   let assemblyFormat = "$lhs `,` $rhs attr-dict";
+
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {

diff  --git a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
index da6c04524317..84e085fbafdf 100644
--- a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
@@ -1,3 +1,7 @@
+set(LLVM_TARGET_DEFINITIONS ShapeCanonicalization.td)
+mlir_tablegen(ShapeCanonicalization.inc -gen-rewriters)
+add_public_tablegen_target(MLIRShapeCanonicalizationIncGen)
+
 add_mlir_dialect_library(MLIRShape
   Shape.cpp
 

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index d29f48e7c51e..0bdbf17c8af5 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -18,6 +18,10 @@
 using namespace mlir;
 using namespace mlir::shape;
 
+namespace {
+#include "ShapeCanonicalization.inc"
+}
+
 ShapeDialect::ShapeDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context) {
   addOperations<
@@ -295,6 +299,32 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser,
 
 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
 
+//===----------------------------------------------------------------------===//
+// CstrBroadcastableOp
+//===----------------------------------------------------------------------===//
+
+void CstrBroadcastableOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  // If both operands are the same value, replace with a passing witness.
+  patterns.insert<CstrBroadcastableEqOps>(context);
+}
+
+OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[0] || !operands[1]) // Need an attribute for both operands.
+    return nullptr;
+  auto lhsShape = llvm::to_vector<6>(
+      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  auto rhsShape = llvm::to_vector<6>(
+      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
+  SmallVector<int64_t, 6> resultShape; // Passed below; never read here.
+  if (OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
+    return BoolAttr::get(true, getContext());
+
+  // Only the passing case folds. A failing witness represents an eventual
+  // assertion failure, so the op is kept instead of becoming a constant.
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // ConstSizeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
new file mode 100644
index 000000000000..9a73a8847779
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -0,0 +1,8 @@
+include "mlir/Dialect/Shape/IR/ShapeOps.td"
+
+def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>;
+
+// Canonicalization patterns: identical lhs/rhs => constant true witness.
+def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $lhs, $rhs),
+  (Shape_ConstWitnessOp ConstBoolAttrTrue),
+  [(EqualBinaryOperands $lhs, $rhs)]>;

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index b7e0d8672678..5ebca6784c0e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -303,3 +303,59 @@ func @f() {
   "test.sink"(%1) : (index) -> ()
   return
 }
+
+// -----
+// Broadcastable with broadcastable constant shapes can be removed.
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [3, 1]
+  %cs1 = shape.const_shape [1, 5]
+  %0 = shape.cstr_broadcastable %cs0, %cs1
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// Broadcastable with non-broadcastable constant shapes must not be removed.
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: shape.const_shape
+  // CHECK-NEXT: shape.const_shape
+  // CHECK-NEXT: shape.cstr_broadcastable
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [1, 3]
+  %cs1 = shape.const_shape [1, 5]
+  %0 = shape.cstr_broadcastable %cs0, %cs1
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// Broadcastable without guaranteed broadcastable shapes cannot be removed.
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) {
+  // CHECK-NEXT: shape.const_shape
+  // CHECK-NEXT: shape.cstr_broadcastable
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [1,3]
+  %0 = shape.cstr_broadcastable %arg0, %cs0
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// Broadcastable with non-constant but known equal shapes can be removed.
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %0 = shape.cstr_broadcastable %arg0, %arg0
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}


        


More information about the Mlir-commits mailing list