[Mlir-commits] [mlir] 6b3a5bf - [mlir] Folding of shape.assuming_all

Tres Popp llvmlistbot at llvm.org
Fri Jun 5 02:00:53 PDT 2020


Author: Tres Popp
Date: 2020-06-05T11:00:19+02:00
New Revision: 6b3a5bff93cd9779f1e82a2d6896f35cbd1a44bc

URL: https://github.com/llvm/llvm-project/commit/6b3a5bff93cd9779f1e82a2d6896f35cbd1a44bc
DIFF: https://github.com/llvm/llvm-project/commit/6b3a5bff93cd9779f1e82a2d6896f35cbd1a44bc.diff

LOG: [mlir] Folding of shape.assuming_all

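This allows shape.assuming_all to be replaced by a constant witness when all
of its inputs are statically known passing witnesses. Statically known passing
inputs are also removed from the operand list when other inputs are unknown.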

Differential Revision: https://reviews.llvm.org/D80306

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 6e00e5852a52..b66ea38d03da 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -460,7 +460,7 @@ def Shape_AnyOp : Shape_Op<"any", [NoSideEffect]> {
   let assemblyFormat = "$inputs attr-dict";
 }
 
-def Shape_AssumingAllOp : Shape_Op<"assuming_all", [NoSideEffect]> {
+def Shape_AssumingAllOp : Shape_Op<"assuming_all", [Commutative, NoSideEffect]> {
   let summary = "Return a logical AND of all witnesses";
   let description = [{
     Used to simplify constraints as any single failing precondition is enough
@@ -485,6 +485,8 @@ def Shape_AssumingAllOp : Shape_Op<"assuming_all", [NoSideEffect]> {
   let results = (outs Shape_WitnessType:$result);
 
   let assemblyFormat = "$inputs attr-dict";
+
+  let hasFolder = 1; // Implemented by AssumingAllOp::fold in Shape.cpp.
 }
 
 def Shape_AssumingOp : Shape_Op<"assuming",

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 26928f272f2a..a4a8b2de59fd 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -145,6 +145,54 @@ static void print(OpAsmPrinter &p, AssumingOp op) {
   p.printOptionalAttrDict(op.getAttrs());
 }
 
+//===----------------------------------------------------------------------===//
+// AssumingAllOp
+//===----------------------------------------------------------------------===//
+/// Folds `shape.assuming_all`:
+///   * if any input is a statically known failing witness, the op folds to
+///     that failing witness;
+///   * if every input is a statically known passing witness, the op folds to
+///     a passing witness;
+///   * otherwise, statically known passing witnesses are erased from the
+///     operand list in place, which is reported by returning the op's own
+///     result.
+/// Returns null, leaving the op untouched, when nothing can be simplified.
+OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
+  // Classify the inputs before touching the op, so that it is never mutated
+  // when it is about to be replaced by a constant anyway.
+  bool allConstant = true;
+  for (Attribute a : operands) {
+    auto witness = a.dyn_cast_or_null<BoolAttr>();
+    if (!witness) {
+      allConstant = false;
+      continue;
+    }
+    // Always false if any input is statically known false.
+    if (!witness.getValue())
+      return witness;
+  }
+  if (allConstant)
+    return BoolAttr::get(true, getContext());
+
+  // Some inputs are unknown: drop the statically passing ones. Iterate in
+  // reverse so erasing operand `idx` does not shift the operands still to be
+  // visited. Constants are usually trailing because this op is Commutative,
+  // but this loop does not depend on that ordering.
+  bool erasedAny = false;
+  for (int idx = static_cast<int>(operands.size()) - 1; idx >= 0; --idx) {
+    if (!operands[idx].dyn_cast_or_null<BoolAttr>())
+      continue;
+    getOperation()->eraseOperand(idx);
+    erasedAny = true;
+  }
+
+  // An in-place fold must be reported by returning the op's own result;
+  // returning null would claim that the op was left unchanged.
+  if (erasedAny)
+    return getResult();
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 69c312e6dad7..646700f8d6bf 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -canonicalize <%s | FileCheck %s --dump-input=fail
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize <%s | FileCheck %s --dump-input=fail
 
 // -----
 // CHECK-LABEL: func @f
@@ -212,3 +212,36 @@ func @not_const(%arg0: !shape.shape) -> !shape.size {
   %0 = shape.get_extent %arg0, 3
   return %0 : !shape.size
 }
+
+// -----
+// assuming_all with known passing witnesses can be folded
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %0 = shape.const_witness true
+  %1 = shape.const_witness true
+  %2 = shape.const_witness true
+  %3 = shape.assuming_all %0, %1, %2
+  "consume.witness"(%3) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// assuming_all should not be removed if not all witnesses are statically passing.
+//
+// Additionally check that the statically passing input is dropped even though
+// it is not the trailing operand (this op is commutative).
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: %[[UNKNOWN:.*]] = "test.source"
+  // CHECK-NEXT: shape.assuming_all %[[UNKNOWN]]
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %0 = shape.const_witness true
+  %1 = "test.source"() : () -> !shape.witness
+  %2 = shape.assuming_all %0, %1
+  "consume.witness"(%2) : (!shape.witness) -> ()
+  return
+}


        


More information about the Mlir-commits mailing list