[Mlir-commits] [mlir] d2abbc1 - [mlir] Add shape.is_broadcastable.

Tres Popp llvmlistbot at llvm.org
Fri Oct 30 01:46:45 PDT 2020


Author: Tres Popp
Date: 2020-10-30T09:46:35+01:00
New Revision: d2abbc17b2c03fda16feb1b52aa11440680e3887

URL: https://github.com/llvm/llvm-project/commit/d2abbc17b2c03fda16feb1b52aa11440680e3887
DIFF: https://github.com/llvm/llvm-project/commit/d2abbc17b2c03fda16feb1b52aa11440680e3887.diff

LOG: [mlir] Add shape.is_broadcastable.

This op returns a boolean value indicating whether two shapes are
broadcastable. It follows the same logic as the other ops in the shape
dialect that have "broadcast" in their names.

Concretely, if shape.is_broadcastable returns true, then shape.broadcast
will not produce an error and shape.cstr_broadcastable will not result in
an assertion failure. Conversely, a false result implies that
shape.broadcast will produce an error and shape.cstr_broadcastable will
fail its assertion.
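
To make the relationship between the predicate and shape.broadcast concrete,
here is a minimal Python sketch. It assumes fully static extents and the
NumPy-style rule that shape.broadcast documents: right-align the two shapes,
pad the shorter one with 1s on the left, and require each extent pair to be
equal or to contain a 1. The helpers `broadcast` and `is_broadcastable` are
hypothetical illustrations, not MLIR APIs, and unknown (`?`) extents are not
modeled.

```python
from itertools import zip_longest
from typing import List, Sequence


def broadcast(lhs: Sequence[int], rhs: Sequence[int]) -> List[int]:
    """Hypothetical model of shape.broadcast on static shapes.

    Raises ValueError when the shapes are incompatible, standing in for
    the error that shape.broadcast would report.
    """
    result = []
    # Walk both shapes from the innermost (rightmost) dimension outward;
    # the shorter shape is implicitly padded with 1s on the left.
    for a, b in zip_longest(reversed(lhs), reversed(rhs), fillvalue=1):
        if a == b or b == 1:
            result.append(a)
        elif a == 1:
            result.append(b)
        else:
            raise ValueError(f"incompatible extents {a} and {b}")
    return result[::-1]


def is_broadcastable(lhs: Sequence[int], rhs: Sequence[int]) -> bool:
    """Hypothetical model of shape.is_broadcastable.

    True exactly when broadcast() succeeds, mirroring the guarantee that a
    true result means shape.broadcast will not produce an error.
    """
    try:
        broadcast(lhs, rhs)
        return True
    except ValueError:
        return False


if __name__ == "__main__":
    # The same inputs as the examples in the op description.
    print(is_broadcastable([2, 2], [3, 1, 2]))  # True
    print(is_broadcastable([2, 2], [3, 2]))     # False
```

Defining the predicate in terms of the broadcast itself keeps the two in
lockstep, and because the rule is symmetric in its operands the predicate is
commutative, consistent with the op's Commutative trait.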

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/test/Dialect/Shape/ops.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 2bc4329db192..9998788568fc 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -181,6 +181,27 @@ def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", [NoSideEffect]> {
   let assemblyFormat = "attr-dict $input `:` type($input)";
 }
 
+def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", [Commutative]> {
+  let summary = "Determines if 2 shapes can be successfully broadcasted";
+  let description = [{
+    Given two input shapes or extent tensors, return a predicate specifying if
+    they are broadcastable. This broadcastable follows the same logic as what
+    shape.broadcast documents.
+
+    Example:
+    ```mlir
+    %true = shape.is_broadcastable [2,2], [3,1,2]
+    %false = shape.is_broadcastable [2,2], [3,2]
+    ```
+  }];
+
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
+                       Shape_ShapeOrExtentTensorType:$rhs);
+  let results = (outs I1:$result);
+
+  let assemblyFormat = "$lhs `,` $rhs `:` type($lhs) `,` type($rhs) attr-dict";
+}
+
 def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
   let summary = "Gets the rank of a shape";
   let description = [{

diff  --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index ad7d42efaa52..57195baad7c0 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -260,3 +260,17 @@ func @any_on_extent_tensors(%a : tensor<?xindex>,
       : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
   return %result : tensor<?xindex>
 }
+
+func @is_broadcastable_on_extent_tensors(%a : tensor<?xindex>,
+                                         %b : tensor<?xindex>) -> i1 {
+  %result = shape.is_broadcastable %a, %b
+      : tensor<?xindex>, tensor<?xindex>
+  return %result : i1
+}
+
+func @is_broadcastable_on_shapes(%a : !shape.shape,
+                                 %b : !shape.shape) -> i1 {
+  %result = shape.is_broadcastable %a, %b
+      : !shape.shape, !shape.shape
+  return %result : i1
+}