[Mlir-commits] [mlir] 1c3e38d - [mlir] Add a shape op that returns a constant witness
Tres Popp
llvmlistbot at llvm.org
Fri Jun 5 02:00:51 PDT 2020
Author: Tres Popp
Date: 2020-06-05T11:00:19+02:00
New Revision: 1c3e38d98c916104c675afa30ad2dd4343e9e923
URL: https://github.com/llvm/llvm-project/commit/1c3e38d98c916104c675afa30ad2dd4343e9e923
DIFF: https://github.com/llvm/llvm-project/commit/1c3e38d98c916104c675afa30ad2dd4343e9e923.diff
LOG: [mlir] Add a shape op that returns a constant witness
This will later be used during canonicalization and folding steps to replace
statically known passing constraints.
Differential Revision: https://reviews.llvm.org/D80307
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index bfc4ecd66b76..6e00e5852a52 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -473,11 +473,11 @@ def Shape_AssumingAllOp : Shape_Op<"assuming_all", [NoSideEffect]> {
Example:
```mlir
- %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Success
+ %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Passing
%w1 = shape.cstr_broadcastable [2,2], [3,2] // Failure
- %w2 = shape.cstr_eq [1,2], [1,2], [1,2] // Success
+ %w2 = shape.cstr_eq [1,2], [1,2], [1,2] // Passing
%wf = shape.assuming_all %w0, %w1 // Failure
- %wt = shape.assuming_all %w0, %w2 // Success
+ %wt = shape.assuming_all %w0, %w2 // Passing
```
}];
@@ -537,7 +537,7 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> {
Example:
```mlir
- %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Success
+ %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Passing
%w1 = shape.cstr_broadcastable [2,2], [3,2] // Failure
```
}];
@@ -557,7 +557,7 @@ def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
Example:
```mlir
- %w0 = shape.cstr_eq [1,2], [1,2], [1,2] // Success
+ %w0 = shape.cstr_eq [1,2], [1,2], [1,2] // Passing
%w1 = shape.cstr_eq [2,2], [1,2] // Failure
```
}];
@@ -567,6 +567,28 @@ def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
let assemblyFormat = "$inputs attr-dict";
}
-// Canonicalization patterns.
+def Shape_ConstWitnessOp : Shape_Op<"const_witness", [ConstantLike, NoSideEffect]> {
+ let summary = "An operation that returns a statically known witness value";
+ let description = [{
+ This operation represents a statically known witness result. It can often
+ be used to canonicalize/fold constraint and assuming code that will always
+ pass.
+
+ ```mlir
+ %0 = shape.const_shape [1, 2, 3]
+ %1 = shape.const_shape [1, 2, 3]
+ %w0 = shape.cstr_eq %0, %1 // Can be folded to "const_witness true"
+ %w1 = shape.const_witness true
+ %w2 = shape.assuming_all %w0, %w1 // Can be folded to "const_witness true"
+ ```
+ }];
+ let arguments = (ins BoolAttr:$passing);
+ let results = (outs Shape_WitnessType:$result);
+
+ let assemblyFormat = "$passing attr-dict";
+
+ let hasFolder = 1;
+}
+
#endif // SHAPE_OPS
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 5f7301f29dab..26928f272f2a 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -42,6 +42,9 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
if (auto sizeType = type.dyn_cast<SizeType>()) {
return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
}
+ if (auto witnessType = type.dyn_cast<WitnessType>()) {
+ return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
+ }
return nullptr;
}
@@ -229,6 +232,12 @@ OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
+//===----------------------------------------------------------------------===//
+// ConstWitnessOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute> /*operands*/) {
+ // The witness is fully described by its `passing` BoolAttr, so folding
+ // simply yields that attribute; materializeConstant rebuilds the op from it.
+ return passingAttr();
+}
+
//===----------------------------------------------------------------------===//
// IndexToSizeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 5ccd3ffba3ff..5f316d9988b8 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -77,8 +77,10 @@ func @test_constraints() {
%1 = shape.const_shape [1, 2, 3]
%w0 = shape.cstr_broadcastable %0, %1
%w1 = shape.cstr_eq %0, %1
- %w3 = shape.assuming_all %w0, %w1
- shape.assuming %w3 -> !shape.shape {
+ %w2 = shape.const_witness true
+ %w3 = shape.const_witness false
+ %w4 = shape.assuming_all %w0, %w1, %w2, %w3
+ shape.assuming %w4 -> !shape.shape {
%2 = shape.any %0, %1
shape.assuming_yield %2 : !shape.shape
}
More information about the Mlir-commits
mailing list