[Mlir-commits] [mlir] 6983cf3 - [MLIR][Shape] Allow unsafe `shape.broadcast`

Frederik Gossen llvmlistbot at llvm.org
Fri Jul 31 07:18:31 PDT 2020


Author: Frederik Gossen
Date: 2020-07-31T14:18:06Z
New Revision: 6983cf3a57aa6d8619eb39e1625eed5340ba05c7

URL: https://github.com/llvm/llvm-project/commit/6983cf3a57aa6d8619eb39e1625eed5340ba05c7
DIFF: https://github.com/llvm/llvm-project/commit/6983cf3a57aa6d8619eb39e1625eed5340ba05c7.diff

LOG: [MLIR][Shape] Allow unsafe `shape.broadcast`

In a context in which `shape.broadcast` is known not to produce an error value,
we want it to operate solely on extent tensors. The operation's behavior is
then undefined in the error case as the result type cannot hold this value.

Differential Revision: https://reviews.llvm.org/D84933

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/test/Dialect/Shape/canonicalize.mlir
    mlir/test/Dialect/Shape/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 72e392b256db..bc7b6048e28f 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -49,25 +49,24 @@ def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
 def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
   let summary = "Returns the broadcasted output shape of two inputs";
   let description = [{
-    Computes the broadcasted output shape following:
-    1. If any inputs are unranked, output is unranked;
-    2. Else the input array with number of dimensions smaller than the max
-       input dimension, has 1’s prepended to its shapes and the output shape is
-       calculated as follows:
-
-           output[i] = lhs[i] if lhs[i] == rhs[i] or rhs[i] is unknown/undefined
-                     = rhs[i] if lhs[i] is unknown/undefined
-                     = lhs[i] if rhs[i] == 1
-                     = rhs[i] if lhs[i] == 1
-                     = error  if lhs[i] != rhs[i]
-
-    Op has an optional string attribute for the error case where there is no
-    broadcastable output shape possible for the given inputs.
-
-    Op may also return an ExtentTensor, but this should only be done when this
-    is statically guaranteed to never fail, either because of a dependency on a
-    cstr_broadcastable operation or other details of the construction of the
-    program.
+    Returns the broadcasted shape for two input shapes or extent tensors. Both
+    operands can be of type `shape.shape` or `tensor<?xindex>`. The result is of
+    type `shape.shape` and, if both operands are tensors, may be of type
+    `tensor<?xindex>`.
+
+    If the two operand shapes are of different rank the smaller one is padded
+    with 1's from the left. The resulting broadcasted shape is then defined as
+
+        result[i] = lhs[i] if lhs[i] == rhs[i]
+                  = lhs[i] if rhs[i] == 1
+                  = rhs[i] if lhs[i] == 1.
+
+    In case the resulting shape is undefined, i.e. if corresponding extents are
+    different from each other but none is 1, the result is an error shape.
+    Likewise error values are propagated if any of the operands holds an error
+    value. If the result type is an extent tensor (and can therefore not hold
+    the error value) the behavior may be undefined. The optional string
+    attribute can be used to describe the error case.
   }];
 
   let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
@@ -75,8 +74,11 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
                        OptionalAttr<StrAttr>:$error);
   let results = (outs Shape_ShapeOrExtentTensorType:$result);
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)";
+  let assemblyFormat = [{
+    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+  }];
 
+  let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
   let hasFolder = 1;
 
   let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index e18ff14df304..21c5a68c3adc 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -60,6 +60,31 @@ func @f() -> !shape.shape {
 
 // -----
 
+// Basic case including extent tensors.
+// CHECK-LABEL: @broadcast
+func @broadcast() -> tensor<?xindex> {
+  // CHECK: shape.const_shape [7, 2] : tensor<?xindex>
+  %0 = shape.const_shape [1, 2] : tensor<?xindex>
+  %1 = shape.const_shape [7, 1] : tensor<?xindex>
+  %2 = shape.broadcast %0, %1
+      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+  return %2 : tensor<?xindex>
+}
+
+// -----
+
+// Basic case including extent tensors.
+// CHECK-LABEL: @broadcast
+func @broadcast() -> !shape.shape {
+  // CHECK: shape.const_shape [7, 2] : !shape.shape
+  %0 = shape.const_shape [1, 2] : tensor<?xindex>
+  %1 = shape.const_shape [7, 1] : tensor<?xindex>
+  %2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> !shape.shape
+  return %2 : !shape.shape
+}
+
+// -----
+
 // Rhs is a scalar.
 // CHECK-LABEL: func @f
 func @f(%arg0 : !shape.shape) -> !shape.shape {

diff  --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index 448bd84e754e..eb0ae5ae05a9 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -138,17 +138,19 @@ func @add(%lhs : !shape.size, %rhs : index) -> index {
 
 // -----
 
-func @broadcast_error_possible(%arg0 : !shape.shape, %arg1 : !shape.shape) -> tensor<?xindex> {
+func @broadcast(%arg0 : !shape.shape, %arg1 : !shape.shape) -> tensor<?xindex> {
   // expected-error @+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}}
-  %result = shape.broadcast %arg0, %arg1 : !shape.shape, !shape.shape -> tensor<?xindex>
+  %result = shape.broadcast %arg0, %arg1
+      : !shape.shape, !shape.shape -> tensor<?xindex>
   return %result : tensor<?xindex>
 }
 
 
 // -----
 
-func @broadcast_error_possible(%arg0 : !shape.shape, %arg1 : tensor<?xindex>) -> tensor<?xindex> {
+func @broadcast(%arg0 : !shape.shape, %arg1 : tensor<?xindex>) -> tensor<?xindex> {
   // expected-error @+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}}
-  %result = shape.broadcast %arg0, %arg1 : !shape.shape, tensor<?xindex> -> tensor<?xindex>
+  %result = shape.broadcast %arg0, %arg1
+      : !shape.shape, tensor<?xindex> -> tensor<?xindex>
   return %result : tensor<?xindex>
 }


        


More information about the Mlir-commits mailing list