[Mlir-commits] [mlir] e74e6af - [shape] Add min and max ops

Jacques Pienaar llvmlistbot at llvm.org
Tue Apr 6 17:58:54 PDT 2021


Author: Jacques Pienaar
Date: 2021-04-06T17:58:12-07:00
New Revision: e74e6afcf13aeb7d0a30e55b2eda89f5910d6e68

URL: https://github.com/llvm/llvm-project/commit/e74e6afcf13aeb7d0a30e55b2eda89f5910d6e68
DIFF: https://github.com/llvm/llvm-project/commit/e74e6afcf13aeb7d0a30e55b2eda89f5910d6e68.diff

LOG: [shape] Add min and max ops

These are element-wise operations that operate on shapes with equal ranks.
Also add missing printer/parser for join operator.

Differential Revision: https://reviews.llvm.org/D99986

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/test/Dialect/Shape/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 34a12275eabeb..0b8c26dc91565 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -387,13 +387,52 @@ def Shape_JoinOp : Shape_Op<"join", [Commutative]> {
     used to return an error to the user upon mismatch of dimensions.
 
     ```mlir
-    %c = shape.join %a, %b, error="<reason>" : !shape.shape
+    %c = shape.join %a, %b, error="<reason>" : !shape.shape, !shape.shape -> !shape.shape
     ```
   }];
 
   let arguments = (ins Shape_ShapeOrSizeType:$arg0, Shape_ShapeOrSizeType:$arg1,
                    OptionalAttr<StrAttr>:$error);
   let results = (outs Shape_ShapeOrSizeType:$result);
+
+  let assemblyFormat = [{
+    $arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
+      type($arg0) `,` type($arg1) `->` type($result)
+  }];
+}
+
+def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
+  let summary = "Elementwise maximum";
+  let description = [{
+    Computes the elementwise maximum of two shapes with equal ranks. If either
+    operand is an error, then an error will be propagated to the result. If the
+    input types mismatch or the ranks do not match, then the result is an
+    error.
+  }];
+
+  let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
+  let results = (outs Shape_ShapeOrSizeType:$result);
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+  }];
+}
+
+def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
+  let summary = "Elementwise minimum";
+  let description = [{
+    Computes the elementwise minimum of two shapes with equal ranks. If either
+    operand is an error, then an error will be propagated to the result. If the
+    input types mismatch or the ranks do not match, then the result is an
+    error.
+  }];
+
+  let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
+  let results = (outs Shape_ShapeOrSizeType:$result);
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+  }];
 }
 
 def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {

diff  --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index ca838e7f8dc7b..b9ae301d55799 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -115,7 +115,7 @@ func @test_constraints() {
 }
 
 func @eq_on_extent_tensors(%lhs : tensor<?xindex>,
-                                      %rhs : tensor<?xindex>) {
+                           %rhs : tensor<?xindex>) {
   %w0 = shape.cstr_eq %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
   return
 }
@@ -183,7 +183,6 @@ func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> index {
   return %rank : index
 }
 
-
 func @shape_eq_on_shapes(%a : !shape.shape, %b : !shape.shape) -> i1 {
   %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
   return %result : i1
@@ -289,3 +288,35 @@ func @is_broadcastable_on_shapes(%a : !shape.shape,
       : !shape.shape, !shape.shape
   return %result : i1
 }
+
+func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
+  %0 = shape.const_shape [4, 57, 92] : !shape.shape
+  %1 = shape.max %a, %0 : !shape.shape, !shape.shape -> !shape.shape
+  %2 = shape.join %0, %1, error="exceeded element-wise upper bound" :
+    !shape.shape, !shape.shape -> !shape.shape
+  return %2 : !shape.shape
+}
+
+func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
+  %0 = shape.const_shape [4, 57, 92] : !shape.shape
+  %1 = shape.min %a, %0 : !shape.shape, !shape.shape -> !shape.shape
+  %2 = shape.join %0, %1, error="lower bound element-wise exceeded" :
+    !shape.shape, !shape.shape -> !shape.shape
+  return %2 : !shape.shape
+}
+
+func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size {
+  %0 = shape.const_size 5
+  %1 = shape.max %a, %0 : !shape.size, !shape.size -> !shape.size
+  %2 = shape.join %0, %1, error="exceeded element-wise upper bound" :
+    !shape.size, !shape.size -> !shape.size
+  return %2 : !shape.size
+}
+
+func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
+  %0 = shape.const_size 9
+  %1 = shape.min %a, %0 : !shape.size, !shape.size -> !shape.size
+  %2 = shape.join %0, %1, error="lower bound element-wise exceeded" :
+    !shape.size, !shape.size -> !shape.size
+  return %2 : !shape.size
+}


        


More information about the Mlir-commits mailing list