[Mlir-commits] [mlir] 1c3e38d - [mlir] Add a shape op that returns a constant witness

Tres Popp llvmlistbot at llvm.org
Fri Jun 5 02:00:51 PDT 2020


Author: Tres Popp
Date: 2020-06-05T11:00:19+02:00
New Revision: 1c3e38d98c916104c675afa30ad2dd4343e9e923

URL: https://github.com/llvm/llvm-project/commit/1c3e38d98c916104c675afa30ad2dd4343e9e923
DIFF: https://github.com/llvm/llvm-project/commit/1c3e38d98c916104c675afa30ad2dd4343e9e923.diff

LOG: [mlir] Add a shape op that returns a constant witness

This will later be used during canonicalization and folding steps to replace
statically known passing constraints.

Differential Revision: https://reviews.llvm.org/D80307

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index bfc4ecd66b76..6e00e5852a52 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -473,11 +473,11 @@ def Shape_AssumingAllOp : Shape_Op<"assuming_all", [NoSideEffect]> {
 
     Example:
     ```mlir
-      %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Success
+      %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Passing
       %w1 = shape.cstr_broadcastable [2,2], [3,2] // Failure
-      %w2 = shape.cstr_eq [1,2], [1,2], [1,2] // Success
+      %w2 = shape.cstr_eq [1,2], [1,2], [1,2] // Passing
       %wf = shape.assuming_all %w0, %w1 // Failure
-      %wt = shape.assuming_all %w0, %w2 // Success
+      %wt = shape.assuming_all %w0, %w2 // Passing
     ```
   }];
 
@@ -537,7 +537,7 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> {
 
     Example:
     ```mlir
-      %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Success
+      %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Passing
       %w1 = shape.cstr_broadcastable [2,2], [3,2] // Failure
     ```
   }];
@@ -557,7 +557,7 @@ def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
 
     Example:
     ```mlir
-      %w0 = shape.cstr_eq [1,2], [1,2], [1,2] // Success
+      %w0 = shape.cstr_eq [1,2], [1,2], [1,2] // Passing
       %w1 = shape.cstr_eq [2,2], [1,2] // Failure
     ```
   }];
@@ -567,6 +567,28 @@ def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
   let assemblyFormat = "$inputs attr-dict";
 }
 
-// Canonicalization patterns.
+def Shape_ConstWitnessOp : Shape_Op<"const_witness", [ConstantLike, NoSideEffect]> {
+  let summary = "An operation that returns a statically known witness value";
+  let description = [{
+  This operation represents a statically known witness result. It can often
+  be used to canonicalize/fold constraint and assuming code that will always
+  pass.
+
+  ```mlir
+  %0 = shape.const_shape [1, 2, 3]
+  %1 = shape.const_shape [1, 2, 3]
+  %w0 = shape.cstr_eq %0, %1 // Can be folded to "const_witness true"
+  %w1 = shape.const_witness true
+  %w2 = shape.assuming_all %w0, %w1 // Can be folded to "const_witness true"
+  ```
+  }];
+  let arguments = (ins BoolAttr:$passing);
+  let results = (outs Shape_WitnessType:$result);
+
+  let assemblyFormat = "$passing attr-dict";
+
+  let hasFolder = 1;
+}
+
 
 #endif // SHAPE_OPS

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 5f7301f29dab..26928f272f2a 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -42,6 +42,9 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
   if (auto sizeType = type.dyn_cast<SizeType>()) {
     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
   }
+  if (auto witnessType = type.dyn_cast<WitnessType>()) {
+    return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
+  }
   return nullptr;
 }
 
@@ -229,6 +232,12 @@ OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
 
 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
 
+//===----------------------------------------------------------------------===//
+// ConstWitnessOp
+//===----------------------------------------------------------------------===//
+// Always folds to the op's own `passing` attribute; operands are unused.
+OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
+
 //===----------------------------------------------------------------------===//
 // IndexToSizeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 5ccd3ffba3ff..5f316d9988b8 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -77,8 +77,10 @@ func @test_constraints() {
   %1 = shape.const_shape [1, 2, 3]
   %w0 = shape.cstr_broadcastable %0, %1
   %w1 = shape.cstr_eq %0, %1
-  %w3 = shape.assuming_all %w0, %w1
-  shape.assuming %w3 -> !shape.shape {
+  %w2 = shape.const_witness true
+  %w3 = shape.const_witness false
+  %w4 = shape.assuming_all %w0, %w1, %w2, %w3
+  shape.assuming %w4 -> !shape.shape {
     %2 = shape.any %0, %1
     shape.assuming_yield %2 : !shape.shape
   }


        


More information about the Mlir-commits mailing list