[Mlir-commits] [mlir] a26883e - [MLIR] Add shape.witness type and ops
Tres Popp
llvmlistbot at llvm.org
Fri May 15 05:34:08 PDT 2020
Author: Tres Popp
Date: 2020-05-15T14:33:54+02:00
New Revision: a26883e5aa14f9f0c6de312fb55ec1a13fdc762a
URL: https://github.com/llvm/llvm-project/commit/a26883e5aa14f9f0c6de312fb55ec1a13fdc762a
DIFF: https://github.com/llvm/llvm-project/commit/a26883e5aa14f9f0c6de312fb55ec1a13fdc762a.diff
LOG: [MLIR] Add shape.witness type and ops
Summary: These represent shape based preconditions on execution of code.
Differential Revision: https://reviews.llvm.org/D79717
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/Shape.h
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
index 0134ba9381ac..70d0a0da230d 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -30,7 +30,8 @@ enum Kind {
Shape,
Size,
ValueShape,
- LAST_SHAPE_TYPE = ValueShape
+ Witness,
+ LAST_SHAPE_TYPE = Witness
};
} // namespace ShapeTypes
@@ -105,6 +106,22 @@ class ValueShapeType : public Type::TypeBase<ValueShapeType, Type> {
}
};
+/// Type of a witness: a compiler-only token standing for a runtime shape
+/// constraint that must hold before dependent code may execute.
+class WitnessType : public Type::TypeBase<WitnessType, Type> {
+public:
+ using Base::Base;
+ /// Returns the witness type in the given `context`.
+ static WitnessType get(MLIRContext *context) {
+ return Base::get(context, ShapeTypes::Witness);
+ }
+
+ /// LLVM-style RTTI hook: true iff `kind` identifies a witness type.
+ static bool kindof(unsigned kind) {
+ return ShapeTypes::Witness == kind;
+ }
+};
+
#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.h.inc"
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 4f588f0583a8..47825577921e 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -17,6 +17,32 @@ include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+def Shape_WitnessType : DialectType<ShapeDialect,
+ CPred<"$_self.isa<::mlir::shape::WitnessType>()">, "witness">,
+ BuildableType<"$_builder.getType<::mlir::shape::WitnessType>()"> {
+ let typeDescription = [{
+ A witness is a structural device in the compiler to maintain ordering of
+ code relying on information obtained from passing assertions. Witnesses do
+ not represent any physical data.
+
+ "cstr_" operations will return witnesses and be lowered into assertion logic
+ when not resolvable at compile time.
+
+ "assuming" operations (shape.assuming, shape.assuming_all) take witnesses as
+ input and represent only information to the compiler, so they do not exist
+ in executing code. Code that depends on an "assuming" operation can assume
+ that all cstr operations it transitively depends on were honored as true.
+
+ These abstractions are intended to allow the compiler more freedom with
+ assertions by merely showing the assertion through dataflow at this time
+ rather than a side effecting operation that acts as a barrier. This can be
+ viewed similarly to a compiler representation of promises from asynchronous,
+ possibly crashing assertions. Reliant code will not be reordered before the
+ assertions it depends on, non-reliant code can be reordered freely, and there
+ are no guarantees on the final ordering of assertions or their related code.
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Shape op definitions
//===----------------------------------------------------------------------===//
@@ -313,4 +339,123 @@ def Shape_ConcatOp : Shape_Op<"concat",
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// Shape constraint related ops.
+//===----------------------------------------------------------------------===//
+
+// TODO(tpopp): Move the code below and witnesses to a different file.
+def Shape_AnyOp : Shape_Op<"any",
+ [DeclareOpInterfaceMethods<InferTypeOpInterface>, NoSideEffect]> {
+ let summary = "Return any combination of the input shapes.";
+ let description = [{
+ This operation takes multiple input shapes and returns some combination of
+ their dimensions. This can be best seen with examples below.
+
+ The result is undefined, but still side-effect free, in cases where the
+ inputs have differing ranks or differ in extents of shared dimensions.
+
+ Example:
+ ```mlir
+ %s0 = shape.any([2,?], [?,3]) // [2,3]
+ %s1 = shape.any([?,?], [1,2]) // [1,2]
+ ```
+ }];
+
+ let arguments = (ins Variadic<Shape_ShapeType>:$inputs);
+ let results = (outs Shape_ShapeType:$result);
+}
+
+def Shape_AssumingAllOp : Shape_Op<"assuming_all", [NoSideEffect]> {
+ let summary = "Return a logical AND of all witnesses.";
+ let description = [{
+ Used to simplify constraints as any single failing precondition is enough
+ to prevent execution.
+
+ "assuming" operations represent an execution order restriction to the
+ compiler, information for dependent code to rely on (by assuming), and
+ nothing else. They should not exist after a program is fully lowered and
+ ready to execute.
+
+ Example:
+ ```mlir
+ %w0 = shape.cstr_broadcastable([2,2], [3,1,2]) // Success
+ %w1 = shape.cstr_broadcastable([2,2], [3,2]) // Failure
+ %w2 = shape.cstr_eq([1,2], [1,2], [1,2]) // Success
+ %wf = shape.assuming_all(%w0, %w1) // Failure
+ %wt = shape.assuming_all(%w0, %w2) // Success
+ ```
+ }];
+
+ let arguments = (ins Variadic<Shape_WitnessType>:$inputs);
+ let results = (outs Shape_WitnessType:$result);
+}
+
+def Shape_AssumingOp : Shape_Op<"assuming",
+ [SingleBlockImplicitTerminator<"AssumingYieldOp">,
+ RecursiveSideEffects]> {
+ let summary = "Execute the region.";
+ let description = [{
+ Executes the region assuming the (single) input witness holds.
+
+ "assuming" operations represent an execution order restriction to the
+ compiler, information for dependent code to rely on (by assuming), and
+ nothing else. They should not exist after a program is fully lowered and
+ ready to execute.
+ }];
+ let arguments = (ins Shape_WitnessType:$witness);
+ let regions = (region SizedRegion<1>:$thenRegion);
+ let results = (outs Variadic<AnyType>:$results);
+}
+
+def Shape_AssumingYieldOp : Shape_Op<"assuming_yield", [Terminator]> {
+ let summary = "Yield operation";
+ let description = [{
+ This yield operation terminates the region of a shape.assuming operation.
+ It takes a variable number of operands and produces no results. The
+ operand number and types must match the results of the enclosing
+ shape.assuming operation.
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$operands);
+}
+
+def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> {
+ let summary = "Determines if 2 shapes can be successfully broadcasted.";
+ let description = [{
+ Given 2 input shapes, return a witness that holds only if they are
+ broadcastable, following the same rules that shape.broadcast documents.
+
+ "cstr" operations represent runtime assertions; see !shape.witness.
+
+ Example:
+ ```mlir
+ %w0 = shape.cstr_broadcastable([2,2], [3,1,2]) // Success
+ %w1 = shape.cstr_broadcastable([2,2], [3,2]) // Failure
+ ```
+ }];
+
+ let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
+ let results = (outs Shape_WitnessType:$result);
+}
+
+def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
+ let summary = "Determines if all input shapes are equal.";
+ let description = [{
+ Given 1 or more input shapes, return a witness that holds only if all of
+ the input shapes are exactly the same.
+ "cstr" operations represent runtime assertions; see !shape.witness.
+
+ Example:
+ ```mlir
+ %w0 = shape.cstr_eq([1,2], [1,2], [1,2]) // Success
+ %w1 = shape.cstr_eq([2,2], [1,2]) // Failure
+ ```
+ }];
+ let arguments = (ins Variadic<Shape_ShapeType>:$inputs);
+ let results = (outs Shape_WitnessType:$result);
+}
+
+
+// Canonicalization patterns.
+
#endif // SHAPE_OPS
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index a420fa73cfad..a66fa8a8128a 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -24,7 +24,8 @@ ShapeDialect::ShapeDialect(MLIRContext *context)
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
>();
- addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType>();
+ addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
+ WitnessType>();
// Allow unknown operations during prototyping and testing. As the dialect is
// still evolving it makes it simple to start with an unregistered ops and
// try different variants before actually defining the op.
@@ -60,6 +61,8 @@ Type ShapeDialect::parseType(DialectAsmParser &parser) const {
return SizeType::get(getContext());
if (keyword == "value_shape")
return ValueShapeType::get(getContext());
+ if (keyword == "witness")
+ return WitnessType::get(getContext());
parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
return Type();
@@ -83,11 +86,27 @@ void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
case ShapeTypes::ValueShape:
os << "value_shape";
return;
+ case ShapeTypes::Witness:
+ os << "witness";
+ return;
default:
llvm_unreachable("unexpected 'shape' type kind");
}
}
+//===----------------------------------------------------------------------===//
+// AnyOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AnyOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ // shape.any always yields exactly one ShapeType value, whatever its inputs.
+ inferredReturnTypes.push_back(ShapeType::get(context));
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 81c9afceef5f..dc07e66f84e4 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -67,3 +67,16 @@ func @test_shape_of(%arg0: tensor<?xf32>) -> !shape.shape {
%0 = shape.shape_of %arg0 : tensor<?xf32>
return %0 : !shape.shape
}
+
+func @test_constraints() {
+ %s0 = shape.const_shape []
+ %s1 = shape.const_shape [1, 2, 3]
+ %bcast = "shape.cstr_broadcastable"(%s0, %s1) : (!shape.shape, !shape.shape) -> !shape.witness
+ %eq = "shape.cstr_eq"(%s0, %s1) : (!shape.shape, !shape.shape) -> !shape.witness
+ %all = "shape.assuming_all"(%bcast, %eq) : (!shape.witness, !shape.witness) -> !shape.witness
+ "shape.assuming"(%all) ( {
+ %any = "shape.any"(%s0, %s1) : (!shape.shape, !shape.shape) -> !shape.shape
+ "shape.assuming_yield"(%any) : (!shape.shape) -> ()
+ }) : (!shape.witness) -> !shape.shape
+ return
+}
More information about the Mlir-commits
mailing list