[Mlir-commits] [mlir] bae6374 - [mlir][shape] Add `shape.cstr_require %bool`

Sean Silva llvmlistbot at llvm.org
Thu Sep 17 16:57:04 PDT 2020


Author: Sean Silva
Date: 2020-09-17T16:56:43-07:00
New Revision: bae63742057785e03732f58d6ed1ec7bda090cc1

URL: https://github.com/llvm/llvm-project/commit/bae63742057785e03732f58d6ed1ec7bda090cc1
DIFF: https://github.com/llvm/llvm-project/commit/bae63742057785e03732f58d6ed1ec7bda090cc1.diff

LOG: [mlir][shape] Add `shape.cstr_require %bool`

This op is a catch-all for creating witnesses from arbitrary kinds of
constraints. In particular, when dealing with extents directly, which
are of `index` type, one can directly use std ops for calculating the
predicates, and then use cstr_require for the final conversion to a
witness.

Differential Revision: https://reviews.llvm.org/D87871

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir
    mlir/test/Dialect/Shape/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 2e8f03237039..ed89ce36fb8a 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -738,5 +738,27 @@ def Shape_ConstWitnessOp : Shape_Op<"const_witness", [ConstantLike, NoSideEffect
   let hasFolder = 1;
 }
 
+def Shape_CstrRequireOp : Shape_Op<"cstr_require", []> {
+  let summary = "Represents a runtime assertion that an i1 is `true`";
+  let description = [{
+    Represents a runtime assertion that an i1 is true. It returns a
+    `!shape.witness` to order this assertion.
+
+    For simplicity, prefer using other cstr_* ops if they are available for a
+    given constraint.
+
+    Example:
+    ```mlir
+    %bool = ...
+    %w0 = shape.cstr_require %bool // Passing if `%bool` is true.
+    ```
+  }];
+  let arguments = (ins I1:$pred);
+  let results = (outs Shape_WitnessType:$result);
+
+  let assemblyFormat = "$pred attr-dict";
+
+  let hasFolder = 1;
+}
 
 #endif // SHAPE_OPS

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 3be53ee2a833..70621295e39c 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -490,6 +490,14 @@ void ConstSizeOp::getAsmResultNames(
 
 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
 
+//===----------------------------------------------------------------------===//
+// CstrRequireOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
+  return operands.front(); // $pred's constant value; null (no fold) if unknown.
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeEqOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 670d207a5f47..9c45f254ba6d 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -386,7 +386,31 @@ func @f(%arg0: !shape.shape, %arg1: !shape.shape) {
 }
 
 // -----
+// cstr_require with a constant `true` predicate folds to `shape.const_witness true`.
+// CHECK-LABEL: func @cstr_require_fold
+func @cstr_require_fold() {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %true = constant true
+  %0 = shape.cstr_require %true
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// cstr_require with a non-constant (function argument) predicate is left unfolded.
+// CHECK-LABEL: func @cstr_require_no_fold
+func @cstr_require_no_fold(%arg0: i1) {
+  // CHECK-NEXT: shape.cstr_require
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %0 = shape.cstr_require %arg0
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
 
+// -----
 // assuming_all with known passing witnesses can be folded
 // CHECK-LABEL: func @f
 func @f() {

diff  --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 58f2a61841e2..1a431d2dbd2f 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -100,12 +100,14 @@ func @test_shape_of(%arg0: tensor<?xf32>) -> tensor<?xindex> {
 func @test_constraints() {
   %0 = shape.const_shape [] : !shape.shape
   %1 = shape.const_shape [1, 2, 3] : !shape.shape
+  %true = constant true
   %w0 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
   %w1 = shape.cstr_eq %0, %1
   %w2 = shape.const_witness true
   %w3 = shape.const_witness false
-  %w4 = shape.assuming_all %w0, %w1, %w2, %w3
-  shape.assuming %w4 -> !shape.shape {
+  %w4 = shape.cstr_require %true
+  %w_all = shape.assuming_all %w0, %w1, %w2, %w3, %w4
+  shape.assuming %w_all -> !shape.shape {
     %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
     shape.assuming_yield %2 : !shape.shape
   }


        


More information about the Mlir-commits mailing list