[Mlir-commits] [mlir] 86a7854 - [mlir] Add shape.with_shape op
Jacques Pienaar
llvmlistbot at llvm.org
Fri Jul 31 14:47:02 PDT 2020
Author: Jacques Pienaar
Date: 2020-07-31T14:46:48-07:00
New Revision: 86a78546b97950dfacd44ab77f17f4ce055d16e5
URL: https://github.com/llvm/llvm-project/commit/86a78546b97950dfacd44ab77f17f4ce055d16e5
DIFF: https://github.com/llvm/llvm-project/commit/86a78546b97950dfacd44ab77f17f4ce055d16e5.diff
LOG: [mlir] Add shape.with_shape op
This is an operation that returns a new ValueShape with a different shape. It is useful for composing shape function calls and reusing existing shape transfer functions.
This change only adds the op.
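As a rough illustration of the semantics documented in the diff, here is a Python model. It is not MLIR code and not an MLIR API. It assumes a shape is a tuple of dimensions with `None` for an unknown dimension, or `None` for a fully unknown shape. It treats the value as opaque, so conformance is checked only against the ValueShape's recorded shape. The names `ValueShape`, `ShapeError`, `join` and `with_shape` are hypothetical.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union

# Hypothetical model (not an MLIR API): a shape is a tuple of dimensions,
# where None marks an unknown dimension; a shape of None is fully unknown.
Shape = Optional[Tuple[Optional[int], ...]]


class ShapeError:
    """Models a ValueShape/shape that is in an error state."""

    def __init__(self, msg: str) -> None:
        self.msg = msg

    def __repr__(self) -> str:
        return f"ShapeError({self.msg!r})"


@dataclass(frozen=True)
class ValueShape:
    value: object  # opaque; may itself be unknown
    shape: Shape


def join(a: Shape, b: Shape) -> Union[Shape, ShapeError]:
    """Most specific shape compatible with both inputs, or an error."""
    if a is None:
        return b
    if b is None:
        return a
    if len(a) != len(b):
        return ShapeError("rank mismatch")
    dims = []
    for x, y in zip(a, b):
        if x is None:
            dims.append(y)
        elif y is None or x == y:
            dims.append(x)
        else:
            return ShapeError(f"dimension mismatch: {x} vs {y}")
    return tuple(dims)


def with_shape(operand: Union[ValueShape, ShapeError],
               shape: Union[Shape, ShapeError]) -> Union[ValueShape, ShapeError]:
    # Errors in either input are propagated unchanged.
    if isinstance(operand, ShapeError):
        return operand
    if isinstance(shape, ShapeError):
        return shape
    # Non-conformant inputs produce an error describing the mismatch.
    if isinstance(join(operand.shape, shape), ShapeError):
        return ShapeError("operand and shape are non-conformant")
    # Otherwise keep the value and take `shape` as-is (no join, no refinement).
    return ValueShape(operand.value, shape)
```

In this model, `with_shape(ValueShape("b", (2, None)), (None, 3))` yields `ValueShape("b", (None, 3))`. The known leading dimension is dropped rather than joined, which mirrors the note that the op need not be a refinement. Calling `join((2, None), (None, 3))` instead gives `(2, 3)`.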
Differential Revision: https://reviews.llvm.org/D84217
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/test/Dialect/Shape/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
index 8f64e3c081e6..3e0177bca50e 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
@@ -100,7 +100,11 @@ def Shape_ValueShapeType : DialectType<ShapeDialect,
corresponds to `Value` in the compiler) and a shape. Conceptually this is a
tuple of a value (potentially unknown) and `shape.type`. The value and shape
can either or both be unknown. If both the `value` and `shape` are known,
- then the shape of `value` is conformant with `shape`.
+ then the shape of `value` is conformant with `shape`. That is, the shape of
+ the value conforms to the shape of the ValueShape, so that for a ValueShape
+ `(value, shape)` the expression `join(shape_of(value), shape)` would be
+ error free. In particular, if both shapes are statically known, then they
+ are equal.
}];
}
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index bc7b6048e28f..ac077439be3c 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -432,6 +432,49 @@ def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
let hasCanonicalizer = 1;
}
+def Shape_WithOp : Shape_Op<"with_shape", [NoSideEffect]> {
+ let summary = "Returns ValueShape with given shape";
+ let description = [{
+ Returns a ValueShape with the shape updated to match the shape operand.
+ That is, a new ValueShape tuple is created with value equal to `operand`'s
+ value and shape equal to `shape`. If the ValueShape and the given `shape`
+ are non-conformant, then the returned ValueShape represents an error for
+ this mismatch. Similarly, if either input is in an error state, then the
+ error is propagated.
+
+ Usage:
+ %0 = shape.with_shape %1, %2 : tensor<...>, !shape.shape
+
+ This is used, for example, when one combines shape function calculations
+ and/or calls one shape function from another. E.g.,
+
+ ```mlir
+ func @shape_foobah(%a: !shape.value_shape,
+ %b: !shape.value_shape,
+ %c: !shape.value_shape) -> !shape.shape {
+ %0 = call @shape_foo(%a, %b) :
+ (!shape.value_shape, !shape.value_shape) -> !shape.shape
+ %1 = shape.with_shape %b, %0 : !shape.value_shape, !shape.shape
+ %2 = call @shape_bah(%c, %1) :
+ (!shape.value_shape, !shape.value_shape) -> !shape.shape
+ return %2 : !shape.shape
+ }
+ ```
+
+ This op need not be a refinement of the shape. In non-error cases the input
+ ValueShape's value and shape are conformant, and so are the output's, but
+ the result may be less specified than `operand`'s shape because `shape` is
+ merely used to construct the new ValueShape. If join behavior is desired,
+ then a join op should be used.
+ }];
+
+ let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$operand,
+ Shape_ShapeType:$shape);
+ let results = (outs Shape_ValueShapeType:$result);
+
+ let assemblyFormat = "operands attr-dict `:` type($operand) `,` type($shape)";
+}
+
def Shape_YieldOp : Shape_Op<"yield",
[HasParent<"ReduceOp">,
NoSideEffect,
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 48b3805d0a3b..172835a2c6d5 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -221,4 +221,17 @@ func @num_elements_shape(%arg : !shape.shape) -> !shape.size {
return %result : !shape.size
}
-
+// Test invoking a shape function from another. shape_equal_shapes is merely
+// a trivial helper function to invoke elsewhere.
+func @shape_equal_shapes(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
+ %0 = shape.shape_of %a : !shape.value_shape -> !shape.shape
+ %1 = shape.shape_of %b : !shape.value_shape -> !shape.shape
+ %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+ return %2 : !shape.shape
+}
+func @shape_with_shape(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
+ %0 = shape.shape_of %a : !shape.value_shape -> !shape.shape
+ %1 = shape.with_shape %b, %0 : !shape.value_shape, !shape.shape
+ %2 = call @shape_equal_shapes(%a, %1) : (!shape.value_shape, !shape.value_shape) -> !shape.shape
+ return %2 : !shape.shape
+}