[Mlir-commits] [mlir] ad793ed - Forward extent tensors through shape.broadcast.

Tres Popp llvmlistbot@llvm.org
Wed Jul 29 06:49:37 PDT 2020


Author: Tres Popp
Date: 2020-07-29T15:49:10+02:00
New Revision: ad793ed90370f0e99fa7ae0cc4d4e97081b5561a

URL: https://github.com/llvm/llvm-project/commit/ad793ed90370f0e99fa7ae0cc4d4e97081b5561a
DIFF: https://github.com/llvm/llvm-project/commit/ad793ed90370f0e99fa7ae0cc4d4e97081b5561a.diff

LOG: Forward extent tensors through shape.broadcast.

Differential Revision: https://reviews.llvm.org/D84832
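
In short: shape.broadcast previously produced only !shape.shape results; after
this change it can also produce an extent tensor when its operands are extent
tensors, and the result type is now spelled explicitly in the assembly format.
A minimal sketch of the syntax change, using placeholder SSA names and mirroring
the tests updated below:

  // Old form: the !shape.shape result type was implicit.
  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape

  // New form: the result type follows `->`.
  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape

  // New capability: extent tensors are forwarded directly.
  %2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>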

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/test/Dialect/Shape/canonicalize.mlir
    mlir/test/Dialect/Shape/invalid.mlir
    mlir/test/Dialect/Shape/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 6ea61376c34d..7b7da042834a 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -89,16 +89,23 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
 
     Op has an optional string attribute for the error case where there is no
     broadcastable output shape possible for the given inputs.
+
+    Op may also return an ExtentTensor, but this should only be done when this
+    is statically guaranteed to never fail, either because of a dependency on a
+    cstr_broadcastable operation or other details of the construction of the
+    program.
   }];
 
   let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
                        Shape_ShapeOrExtentTensorType:$rhs,
                        OptionalAttr<StrAttr>:$error);
-  let results = (outs Shape_ShapeType:$result);
+  let results = (outs Shape_ShapeOrExtentTensorType:$result);
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)";
 
   let hasFolder = 1;
+
+  let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
 }
 
 def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
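
The new verifier enforces the error-propagation rule stated in the op
description above. A small sketch of what it accepts and rejects, with
hypothetical SSA values (%s : !shape.shape, %t0 and %t1 : tensor<?xindex>):

  // Accepted: no operand can hold an error value, so an extent tensor result is fine.
  %a = shape.broadcast %t0, %t1 : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>

  // Rejected: a !shape.shape operand can carry an error value, so the result must
  // be of type !shape.shape to propagate it (see the invalid.mlir tests below).
  %b = shape.broadcast %s, %t1 : !shape.shape, tensor<?xindex> -> tensor<?xindex>

  // Accepted (assuming mixed operands remain legal): any error in %s can flow
  // into the !shape.shape result.
  %c = shape.broadcast %s, %t1 : !shape.shape, tensor<?xindex> -> !shape.shape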

diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index ed5bf6999cd9..e18ff14df304 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -54,7 +54,7 @@ func @f() -> !shape.shape {
   // CHECK: shape.const_shape [7, 2] : !shape.shape
   %0 = shape.const_shape [1, 2] : !shape.shape
   %1 = shape.const_shape [7, 1] : !shape.shape
-  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
+  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape
   return %2 : !shape.shape
 }
 
@@ -65,7 +65,7 @@ func @f() -> !shape.shape {
 func @f(%arg0 : !shape.shape) -> !shape.shape {
   // CHECK: return %arg0
   %0 = shape.const_shape [] : !shape.shape
-  %1 = shape.broadcast %arg0, %0 : !shape.shape, !shape.shape
+  %1 = shape.broadcast %arg0, %0 : !shape.shape, !shape.shape -> !shape.shape
   return %1 : !shape.shape
 }
 
@@ -76,7 +76,7 @@ func @f(%arg0 : !shape.shape) -> !shape.shape {
 func @f(%arg0 : !shape.shape) -> !shape.shape {
   // CHECK: return %arg0
   %0 = shape.const_shape [] : !shape.shape
-  %1 = shape.broadcast %0, %arg0 : !shape.shape, !shape.shape
+  %1 = shape.broadcast %0, %arg0 : !shape.shape, !shape.shape -> !shape.shape
   return %1 : !shape.shape
 }
 
@@ -89,7 +89,7 @@ func @f() -> !shape.shape {
   // CHECK: return %[[CST]]
   %0 = shape.const_shape [] : !shape.shape
   %1 = shape.const_shape [1, 2, 3] : !shape.shape
-  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
+  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape
   return %2 : !shape.shape
 }
 
@@ -101,7 +101,7 @@ func @f() -> !shape.shape {
   // CHECK: shape.broadcast
   %0 = shape.const_shape [2] : !shape.shape
   %1 = shape.const_shape [7] : !shape.shape
-  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
+  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape
   return %2 : !shape.shape
 }
 

diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index d804efbd9e4e..448bd84e754e 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -136,3 +136,19 @@ func @add(%lhs : !shape.size, %rhs : index) -> index {
   return %result : index
 }
 
+// -----
+
+func @broadcast_error_possible(%arg0 : !shape.shape, %arg1 : !shape.shape) -> tensor<?xindex> {
+  // expected-error @+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}}
+  %result = shape.broadcast %arg0, %arg1 : !shape.shape, !shape.shape -> tensor<?xindex>
+  return %result : tensor<?xindex>
+}
+
+
+// -----
+
+func @broadcast_error_possible(%arg0 : !shape.shape, %arg1 : tensor<?xindex>) -> tensor<?xindex> {
+  // expected-error @+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}}
+  %result = shape.broadcast %arg0, %arg1 : !shape.shape, tensor<?xindex> -> tensor<?xindex>
+  return %result : tensor<?xindex>
+}

diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 229f3948d31d..48b3805d0a3b 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -49,11 +49,18 @@ func @test_shape_num_elements_fixed() {
 func @test_broadcast_fixed() {
   %0 = shape.const_shape [10, 1, 57, 92] : !shape.shape
   %1 = shape.const_shape [4, 57, 92] : !shape.shape
-  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
+  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }
 
+func @test_broadcast_extents() -> tensor<?xindex> {
+  %0 = shape.const_shape [10, 1, 57, 92] : tensor<?xindex>
+  %1 = shape.const_shape [4, 57, 92] : tensor<?xindex>
+  %2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+  return %2 : tensor<?xindex>
+}
+
 func @test_shape_any_fixed() {
   %0 = shape.const_shape [4, 57, 92] : !shape.shape
   %1 = shape.const_shape [4, 57, 92] : !shape.shape


        

