[Mlir-commits] [mlir] ad793ed - Forward extent tensors through shape.broadcast.
Tres Popp
llvmlistbot at llvm.org
Wed Jul 29 06:49:37 PDT 2020
Author: Tres Popp
Date: 2020-07-29T15:49:10+02:00
New Revision: ad793ed90370f0e99fa7ae0cc4d4e97081b5561a
URL: https://github.com/llvm/llvm-project/commit/ad793ed90370f0e99fa7ae0cc4d4e97081b5561a
DIFF: https://github.com/llvm/llvm-project/commit/ad793ed90370f0e99fa7ae0cc4d4e97081b5561a.diff
LOG: Forward extent tensors through shape.broadcast.
Differential Revision: https://reviews.llvm.org/D84832
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Shape/invalid.mlir
mlir/test/Dialect/Shape/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 6ea61376c34d..7b7da042834a 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -89,16 +89,23 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
Op has an optional string attribute for the error case where there is no
broadcastable output shape possible for the given inputs.
+
+ Op may also return an ExtentTensor, but this should only be done when this
+ is statically guaranteed to never fail, either because of a dependency on a
+ cstr_broadcastable operation or other details of the construction of the
+ program.
}];
let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
Shape_ShapeOrExtentTensorType:$rhs,
OptionalAttr<StrAttr>:$error);
- let results = (outs Shape_ShapeType:$result);
+ let results = (outs Shape_ShapeOrExtentTensorType:$result);
- let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+ let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)";
let hasFolder = 1;
+
+ let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
}
def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index ed5bf6999cd9..e18ff14df304 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -54,7 +54,7 @@ func @f() -> !shape.shape {
// CHECK: shape.const_shape [7, 2] : !shape.shape
%0 = shape.const_shape [1, 2] : !shape.shape
%1 = shape.const_shape [7, 1] : !shape.shape
- %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
+ %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape
return %2 : !shape.shape
}
@@ -65,7 +65,7 @@ func @f() -> !shape.shape {
func @f(%arg0 : !shape.shape) -> !shape.shape {
// CHECK: return %arg0
%0 = shape.const_shape [] : !shape.shape
- %1 = shape.broadcast %arg0, %0 : !shape.shape, !shape.shape
+ %1 = shape.broadcast %arg0, %0 : !shape.shape, !shape.shape -> !shape.shape
return %1 : !shape.shape
}
@@ -76,7 +76,7 @@ func @f(%arg0 : !shape.shape) -> !shape.shape {
func @f(%arg0 : !shape.shape) -> !shape.shape {
// CHECK: return %arg0
%0 = shape.const_shape [] : !shape.shape
- %1 = shape.broadcast %0, %arg0 : !shape.shape, !shape.shape
+ %1 = shape.broadcast %0, %arg0 : !shape.shape, !shape.shape -> !shape.shape
return %1 : !shape.shape
}
@@ -89,7 +89,7 @@ func @f() -> !shape.shape {
// CHECK: return %[[CST]]
%0 = shape.const_shape [] : !shape.shape
%1 = shape.const_shape [1, 2, 3] : !shape.shape
- %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
+ %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape
return %2 : !shape.shape
}
@@ -101,7 +101,7 @@ func @f() -> !shape.shape {
// CHECK: shape.broadcast
%0 = shape.const_shape [2] : !shape.shape
%1 = shape.const_shape [7] : !shape.shape
- %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
+ %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape
return %2 : !shape.shape
}
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index d804efbd9e4e..448bd84e754e 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -136,3 +136,19 @@ func @add(%lhs : !shape.size, %rhs : index) -> index {
return %result : index
}
+// -----
+
+func @broadcast_error_possible(%arg0 : !shape.shape, %arg1 : !shape.shape) -> tensor<?xindex> {
+ // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}}
+ %result = shape.broadcast %arg0, %arg1 : !shape.shape, !shape.shape -> tensor<?xindex>
+ return %result : tensor<?xindex>
+}
+
+
+// -----
+
+func @broadcast_error_possible(%arg0 : !shape.shape, %arg1 : tensor<?xindex>) -> tensor<?xindex> {
+ // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}}
+ %result = shape.broadcast %arg0, %arg1 : !shape.shape, tensor<?xindex> -> tensor<?xindex>
+ return %result : tensor<?xindex>
+}
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 229f3948d31d..48b3805d0a3b 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -49,11 +49,18 @@ func @test_shape_num_elements_fixed() {
func @test_broadcast_fixed() {
%0 = shape.const_shape [10, 1, 57, 92] : !shape.shape
%1 = shape.const_shape [4, 57, 92] : !shape.shape
- %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
+ %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
+func @test_broadcast_extents() -> tensor<?xindex> {
+ %0 = shape.const_shape [10, 1, 57, 92] : tensor<?xindex>
+ %1 = shape.const_shape [4, 57, 92] : tensor<?xindex>
+ %2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+ return %2 : tensor<?xindex>
+}
+
func @test_shape_any_fixed() {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.const_shape [4, 57, 92] : !shape.shape
More information about the Mlir-commits mailing list