[Mlir-commits] [mlir] a26883e - [MLIR] Add shape.witness type and ops

Tres Popp llvmlistbot at llvm.org
Fri May 15 05:34:08 PDT 2020


Author: Tres Popp
Date: 2020-05-15T14:33:54+02:00
New Revision: a26883e5aa14f9f0c6de312fb55ec1a13fdc762a

URL: https://github.com/llvm/llvm-project/commit/a26883e5aa14f9f0c6de312fb55ec1a13fdc762a
DIFF: https://github.com/llvm/llvm-project/commit/a26883e5aa14f9f0c6de312fb55ec1a13fdc762a.diff

LOG: [MLIR] Add shape.witness type and ops

Summary: These represent shape based preconditions on execution of code.

Differential Revision: https://reviews.llvm.org/D79717

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/Shape.h
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
index 0134ba9381ac..70d0a0da230d 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -30,7 +30,8 @@ enum Kind {
   Shape,
   Size,
   ValueShape,
-  LAST_SHAPE_TYPE = ValueShape
+  Witness,
+  LAST_SHAPE_TYPE = Witness
 };
 } // namespace ShapeTypes
 
@@ -105,6 +106,22 @@ class ValueShapeType : public Type::TypeBase<ValueShapeType, Type> {
   }
 };
 
+/// Parameterless shape-dialect type modelling a runtime constraint. Values of
+/// this type serve as shape related preconditions on code execution.
+class WitnessType : public Type::TypeBase<WitnessType, Type> {
+public:
+  using Base::Base;
+
+  /// Returns the (uniqued) WitnessType instance for `context`.
+  static WitnessType get(MLIRContext *context) {
+    return Base::get(context, ShapeTypes::Witness);
+  }
+
+  /// LLVM-style casting hook: true iff `kind` denotes a WitnessType.
+  static bool kindof(unsigned kind) { return kind == ShapeTypes::Witness; }
+};
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Shape/IR/ShapeOps.h.inc"
 

diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 4f588f0583a8..47825577921e 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -17,6 +17,32 @@ include "mlir/Dialect/Shape/IR/ShapeBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
+// ODS constraint/builder for `!shape.witness`; the C++ class is WitnessType.
+def Shape_WitnessType : DialectType<ShapeDialect,
+    CPred<"$_self.isa<::mlir::shape::WitnessType>()">, "witness">,
+    BuildableType<"$_builder.getType<::mlir::shape::WitnessType>()"> {
+  let typeDescription = [{
+    A witness is a structural device in the compiler to maintain ordering of
+    code relying on information obtained from passing assertions. Witnesses do
+    not represent any physical data.
+
+    "cstr_" operations will return witnesses and be lowered into assertion logic
+    when not resolvable at compile time.
+
+    "assuming" operations will take witnesses as input and represent only
+    information to the compiler, so they do not exist in executing code. Code
+    that depends on an "assuming" operation can assume that all "cstr_"
+    operations it transitively depends on hold.
+
+    These abstractions are intended to allow the compiler more freedom with
+    assertions by expressing each assertion through dataflow rather than as a
+    side-effecting operation that acts as a barrier. This can be viewed
+    similarly to a compiler representation of promises from asynchronous,
+    possibly crashing assertions. Reliant code will not be reordered before the
+    assertions it depends on, non-reliant code can be reordered freely, and
+    there are no guarantees on the final ordering of the assertions or their
+    related code.
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Shape op definitions
 //===----------------------------------------------------------------------===//
@@ -313,4 +339,123 @@ def Shape_ConcatOp : Shape_Op<"concat",
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Shape constraint related ops.
+//===----------------------------------------------------------------------===//
+
+// TODO(tpopp): Move the code below and witnesses to a different file.
+// NoSideEffect: the op is documented as side-effect free. Without the trait it
+// is also treated as having unknown effects when nested inside a
+// RecursiveSideEffects op such as shape.assuming.
+def Shape_AnyOp : Shape_Op<"any",
+    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Return any combination of the input shapes.";
+  let description = [{
+    This operation takes multiple input shapes and returns some combination of
+    their dimensions. This can be best seen with examples below.
+
+    The result is undefined, but still side-effect free, in cases where the
+    inputs have differing ranks or differ in extents of shared dimensions.
+
+    Example:
+    ```mlir
+      %s0 = shape.any([2,?], [?,3]) // [2,3]
+      %s1 = shape.any([?,?], [1,2]) // [1,2]
+    ```
+  }];
+
+  let arguments = (ins Variadic<Shape_ShapeType>:$inputs);
+  let results = (outs Shape_ShapeType:$result);
+}
+
+def Shape_AssumingAllOp : Shape_Op<"assuming_all", []> {
+  let summary = "Return a logical AND of all witnesses.";
+  let description = [{
+    Used to simplify constraints as any single failing precondition is enough
+    to prevent execution.
+
+    "assuming" operations represent an execution order restriction to the
+    compiler, information for dependent code to rely on (by assuming), and
+    nothing else. They should not exist after a program is fully lowered and
+    ready to execute.
+
+    Example:
+    ```mlir
+      %w0 = shape.cstr_broadcastable([2,2], [3,1,2]) // Success
+      %w1 = shape.cstr_broadcastable([2,2], [3,2])   // Failure
+      %w2 = shape.cstr_eq([1,2], [1,2], [1,2])       // Success
+      %wf = shape.assuming_all(%w0, %w1)             // Failure
+      %wt = shape.assuming_all(%w0, %w2)             // Success
+    ```
+  }];
+
+  let arguments = (ins Variadic<Shape_WitnessType>:$inputs);
+  let results = (outs Shape_WitnessType:$result);
+}
+
+def Shape_AssumingOp : Shape_Op<"assuming",
+                           [SingleBlockImplicitTerminator<"AssumingYieldOp">,
+                            RecursiveSideEffects]> {
+  let summary = "Execute the region.";
+  let description = [{
+    Executes the region assuming the given witness holds.
+
+    "assuming" operations represent an execution order restriction to the
+    compiler, information for dependent code to rely on (by assuming), and
+    nothing else. They should not exist after a program is fully lowered and
+    ready to execute.
+  }];
+  // Named so that ODS generates a `witness()` accessor for the operand.
+  let arguments = (ins Shape_WitnessType:$witness);
+  let regions = (region SizedRegion<1>:$thenRegion);
+  let results = (outs Variadic<AnyType>:$results);
+}
+
+// NoSideEffect: the parent shape.assuming derives its effects recursively from
+// the ops it contains. A terminator with unknown effects would make every
+// shape.assuming look side-effecting, so it could never be erased.
+def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
+                                     [NoSideEffect, Terminator]> {
+  let summary = "Yield operation";
+  let description = [{
+    This yield operation represents a return operation within the region of a
+    `shape.assuming` operation. The operation takes a variable number of
+    operands and produces no results. The operand number and types must match
+    the results of the enclosing `shape.assuming` operation.
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+}
+
+def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> {
+  let summary = "Determines if 2 shapes can be successfully broadcast.";
+  let description = [{
+    Given 2 input shapes, return a witness specifying if they are broadcastable.
+    Broadcastability follows the same rules as those documented for
+    shape.broadcast.
+
+    "cstr" operations represent runtime assertions.
+
+    Example:
+    ```mlir
+      %w0 = shape.cstr_broadcastable([2,2], [3,1,2])  // Success
+      %w1 = shape.cstr_broadcastable([2,2], [3,2])    // Failure
+    ```
+  }];
+
+  let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
+  let results = (outs Shape_WitnessType:$result);
+}
+
+def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
+  let arguments = (ins Variadic<Shape_ShapeType>:$inputs);
+  let results = (outs Shape_WitnessType:$result);
+
+  let summary = "Determines if all input shapes are equal.";
+  let description = [{
+    Takes one or more input shapes and returns a witness stating whether every
+    one of them is exactly the same shape.
+
+    "cstr" operations represent runtime assertions.
+
+    Example:
+    ```mlir
+      %w0 = shape.cstr_eq([1,2], [1,2], [1,2]) // Success
+      %w1 = shape.cstr_eq([2,2], [1,2])        // Failure
+    ```
+  }];
+}
+
+
+// Canonicalization patterns.
+
 #endif // SHAPE_OPS

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index a420fa73cfad..a66fa8a8128a 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -24,7 +24,8 @@ ShapeDialect::ShapeDialect(MLIRContext *context)
 #define GET_OP_LIST
 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
       >();
-  addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType>();
+  addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
+           WitnessType>();
   // Allow unknown operations during prototyping and testing. As the dialect is
   // still evolving it makes it simple to start with an unregistered ops and
   // try different variants before actually defining the op.
@@ -60,6 +61,8 @@ Type ShapeDialect::parseType(DialectAsmParser &parser) const {
     return SizeType::get(getContext());
   if (keyword == "value_shape")
     return ValueShapeType::get(getContext());
+  if (keyword == "witness")
+    return WitnessType::get(getContext());
 
   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
   return Type();
@@ -83,11 +86,27 @@ void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
   case ShapeTypes::ValueShape:
     os << "value_shape";
     return;
+  case ShapeTypes::Witness:
+    os << "witness";
+    return;
   default:
     llvm_unreachable("unexpected 'shape' type kind");
   }
 }
 
+//===----------------------------------------------------------------------===//
+// AnyOp
+//===----------------------------------------------------------------------===//
+
+/// Result type inference for `shape.any`: appends a single `!shape.shape`
+/// type, regardless of the operands, attributes, or regions supplied.
+LogicalResult
+AnyOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
+                        ValueRange operands, DictionaryAttr attributes,
+                        RegionRange regions,
+                        SmallVectorImpl<Type> &inferredReturnTypes) {
+  Type shapeTy = ShapeType::get(context);
+  inferredReturnTypes.emplace_back(shapeTy);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 81c9afceef5f..dc07e66f84e4 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -67,3 +67,16 @@ func @test_shape_of(%arg0: tensor<?xf32>) -> !shape.shape {
   %0 = shape.shape_of %arg0 : tensor<?xf32>
   return %0 : !shape.shape
 }
+
+// Round-trips the witness producers/consumers in generic op form.
+func @test_constraints() {
+  %empty = shape.const_shape []
+  %s123 = shape.const_shape [1, 2, 3]
+  %bcast_ok = "shape.cstr_broadcastable"(%empty, %s123) : (!shape.shape, !shape.shape) -> !shape.witness
+  %eq_ok = "shape.cstr_eq"(%empty, %s123) : (!shape.shape, !shape.shape) -> !shape.witness
+  %all_ok = "shape.assuming_all"(%bcast_ok, %eq_ok) : (!shape.witness, !shape.witness) -> !shape.witness
+  "shape.assuming"(%all_ok) ( {
+    %any = "shape.any"(%empty, %s123)  : (!shape.shape, !shape.shape) -> !shape.shape
+    "shape.assuming_yield"(%any) : (!shape.shape) -> ()
+  }) : (!shape.witness) -> !shape.shape
+  return
+}


        


More information about the Mlir-commits mailing list