[Mlir-commits] [mlir] e74e6af - [shape] Add min and max ops
Jacques Pienaar
llvmlistbot at llvm.org
Tue Apr 6 17:58:54 PDT 2021
Author: Jacques Pienaar
Date: 2021-04-06T17:58:12-07:00
New Revision: e74e6afcf13aeb7d0a30e55b2eda89f5910d6e68
URL: https://github.com/llvm/llvm-project/commit/e74e6afcf13aeb7d0a30e55b2eda89f5910d6e68
DIFF: https://github.com/llvm/llvm-project/commit/e74e6afcf13aeb7d0a30e55b2eda89f5910d6e68.diff
LOG: [shape] Add min and max ops
These are element-wise operations that operate on shapes with equal ranks.
Also add missing printer/parser for join operator.
Differential Revision: https://reviews.llvm.org/D99986
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/test/Dialect/Shape/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 34a12275eabeb..0b8c26dc91565 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -387,13 +387,52 @@ def Shape_JoinOp : Shape_Op<"join", [Commutative]> {
used to return an error to the user upon mismatch of dimensions.
```mlir
- %c = shape.join %a, %b, error="<reason>" : !shape.shape
+ %c = shape.join %a, %b, error="<reason>" : !shape.shape, !shape.shape -> !shape.shape
```
}];
let arguments = (ins Shape_ShapeOrSizeType:$arg0, Shape_ShapeOrSizeType:$arg1,
OptionalAttr<StrAttr>:$error);
let results = (outs Shape_ShapeOrSizeType:$result);
+
+ let assemblyFormat = [{
+ $arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
+ type($arg0) `,` type($arg1) `->` type($result)
+ }];
+}
+
+def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
+ let summary = "Elementwise maximum";
+ let description = [{
+ Computes the elementwise maximum of two shapes with equal ranks, or the
+ maximum of two sizes. If either operand is an error, then an error will be
+ propagated to the result. If the input types mismatch or the ranks do not
+ match, then the result is an error.
+ }];
+
+ let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
+ let results = (outs Shape_ShapeOrSizeType:$result);
+
+ let assemblyFormat = [{
+ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+ }];
+}
+
+def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
+ let summary = "Elementwise minimum";
+ let description = [{
+ Computes the elementwise minimum of two shapes with equal ranks, or the
+ minimum of two sizes. If either operand is an error, then an error will be
+ propagated to the result. If the input types mismatch or the ranks do not
+ match, then the result is an error.
+ }];
+
+ let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
+ let results = (outs Shape_ShapeOrSizeType:$result);
+
+ let assemblyFormat = [{
+ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+ }];
}
def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index ca838e7f8dc7b..b9ae301d55799 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -115,7 +115,7 @@ func @test_constraints() {
}
+// Exercises shape.cstr_eq on two extent-tensor (tensor<?xindex>) operands.
func @eq_on_extent_tensors(%lhs : tensor<?xindex>,
- %rhs : tensor<?xindex>) {
+ %rhs : tensor<?xindex>) {
%w0 = shape.cstr_eq %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
return
}
@@ -183,7 +183,6 @@ func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> index {
return %rank : index
}
-
func @shape_eq_on_shapes(%a : !shape.shape, %b : !shape.shape) -> i1 {
%result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
return %result : i1
@@ -289,3 +288,35 @@ func @is_broadcastable_on_shapes(%a : !shape.shape,
: !shape.shape, !shape.shape
return %result : i1
}
+
+func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
+ %bound = shape.const_shape [4, 57, 92] : !shape.shape
+ %clamped = shape.max %a, %bound : !shape.shape, !shape.shape -> !shape.shape
+ %checked = shape.join %bound, %clamped, error="exceeded element-wise upper bound"
+ : !shape.shape, !shape.shape -> !shape.shape
+ return %checked : !shape.shape
+}
+
+func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
+ %bound = shape.const_shape [4, 57, 92] : !shape.shape
+ %clamped = shape.min %a, %bound : !shape.shape, !shape.shape -> !shape.shape
+ %checked = shape.join %bound, %clamped, error="lower bound element-wise exceeded"
+ : !shape.shape, !shape.shape -> !shape.shape
+ return %checked : !shape.shape
+}
+
+func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size {
+ %bound = shape.const_size 5
+ %clamped = shape.max %a, %bound : !shape.size, !shape.size -> !shape.size
+ %checked = shape.join %bound, %clamped, error="exceeded element-wise upper bound"
+ : !shape.size, !shape.size -> !shape.size
+ return %checked : !shape.size
+}
+
+func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
+ %bound = shape.const_size 9
+ %clamped = shape.min %a, %bound : !shape.size, !shape.size -> !shape.size
+ %checked = shape.join %bound, %clamped, error="lower bound element-wise exceeded"
+ : !shape.size, !shape.size -> !shape.size
+ return %checked : !shape.size
+}
More information about the Mlir-commits
mailing list