[Mlir-commits] [mlir] d2abbc1 - [mlir] Add shape.is_broadcastable.
Tres Popp
llvmlistbot at llvm.org
Fri Oct 30 01:46:45 PDT 2020
Author: Tres Popp
Date: 2020-10-30T09:46:35+01:00
New Revision: d2abbc17b2c03fda16feb1b52aa11440680e3887
URL: https://github.com/llvm/llvm-project/commit/d2abbc17b2c03fda16feb1b52aa11440680e3887
DIFF: https://github.com/llvm/llvm-project/commit/d2abbc17b2c03fda16feb1b52aa11440680e3887.diff
LOG: [mlir] Add shape.is_broadcastable.
This op returns a boolean value indicating whether two shapes are
broadcastable. It follows the same logic as the other ops in the shape
dialect that have "broadcast" in their names.
Concretely, if shape.is_broadcastable returns true, then shape.broadcast
will not produce an error and shape.cstr_broadcastable will not result in
an assertion failure. Similarly, a false result implies that
shape.broadcast would produce an error and shape.cstr_broadcastable would
result in an assertion failure.
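The relationship described above (is_broadcastable is true exactly when broadcast would succeed) can be sketched in Python for fully static shapes. This is a minimal illustration, not the MLIR implementation. It assumes the NumPy-style rule: shapes are aligned from the right, missing leading extents are treated as 1, and each pair of extents must be equal or contain a 1. It does not model dynamic extents or shape values that carry errors. The helper names `broadcast` and `is_broadcastable` are hypothetical.

```python
from itertools import zip_longest
from typing import Sequence, Tuple


def broadcast(lhs: Sequence[int], rhs: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: return the broadcast shape of two static shapes.

    Assumes the NumPy-style rule: align from the right, pad the shorter
    shape with leading 1s, and require each extent pair to be equal or
    to contain a 1. Raises ValueError when the shapes are incompatible.
    """
    result = []
    # Walk both shapes from the innermost (rightmost) dimension outwards;
    # a missing extent in the shorter shape behaves like a 1.
    for l, r in zip_longest(reversed(lhs), reversed(rhs), fillvalue=1):
        if l == r or r == 1:
            result.append(l)
        elif l == 1:
            result.append(r)
        else:
            raise ValueError(f"extents {l} and {r} are not broadcastable")
    return tuple(reversed(result))


def is_broadcastable(lhs: Sequence[int], rhs: Sequence[int]) -> bool:
    """Hypothetical helper: True exactly when broadcast(lhs, rhs) succeeds."""
    try:
        broadcast(lhs, rhs)
    except ValueError:
        return False
    return True


# The two examples from the op description in ShapeOps.td:
print(is_broadcastable([2, 2], [3, 1, 2]))  # True: result shape is (3, 2, 2)
print(is_broadcastable([2, 2], [3, 2]))     # False: extents 2 and 3 clash
```

The rule is symmetric in its two operands, so the result does not depend on argument order. This is consistent with the op being declared `Commutative` in the diff.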
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/test/Dialect/Shape/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 2bc4329db192..9998788568fc 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -181,6 +181,27 @@ def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", [NoSideEffect]> {
let assemblyFormat = "attr-dict $input `:` type($input)";
}
+def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", [Commutative]> {
+ let summary = "Determines if 2 shapes can be successfully broadcasted";
+ let description = [{
+ Given two input shapes or extent tensors, return a predicate specifying if
+ they are broadcastable. This broadcastable follows the same logic as what
+ shape.broadcast documents.
+
+ Example:
+ ```mlir
+ %true = shape.is_broadcastable [2,2], [3,1,2]
+ %false = shape.is_broadcastable [2,2], [3,2]
+ ```
+ }];
+
+ let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
+ Shape_ShapeOrExtentTensorType:$rhs);
+ let results = (outs I1:$result);
+
+ let assemblyFormat = "$lhs `,` $rhs `:` type($lhs) `,` type($rhs) attr-dict";
+}
+
def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
let summary = "Gets the rank of a shape";
let description = [{
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index ad7d42efaa52..57195baad7c0 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -260,3 +260,17 @@ func @any_on_extent_tensors(%a : tensor<?xindex>,
: tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
return %result : tensor<?xindex>
}
+
+func @is_broadcastable_on_extent_tensors(%a : tensor<?xindex>,
+ %b : tensor<?xindex>) -> i1 {
+ %result = shape.is_broadcastable %a, %b
+ : tensor<?xindex>, tensor<?xindex>
+ return %result : i1
+}
+
+func @is_broadcastable_on_shapes(%a : !shape.shape,
+ %b : !shape.shape) -> i1 {
+ %result = shape.is_broadcastable %a, %b
+ : !shape.shape, !shape.shape
+ return %result : i1
+}