[Mlir-commits] [mlir] ccb40c8 - [MLIR][Shape] Allow `cstr_broadcastable` to accept extent tensors

Frederik Gossen llvmlistbot at llvm.org
Mon Jul 20 07:40:08 PDT 2020


Author: Frederik Gossen
Date: 2020-07-20T14:39:44Z
New Revision: ccb40c84c57137fade6b8815cf3590a537a4b702

URL: https://github.com/llvm/llvm-project/commit/ccb40c84c57137fade6b8815cf3590a537a4b702
DIFF: https://github.com/llvm/llvm-project/commit/ccb40c84c57137fade6b8815cf3590a537a4b702.diff

LOG: [MLIR][Shape] Allow `cstr_broadcastable` to accept extent tensors

Differential Revision: https://reviews.llvm.org/D84155

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/test/Dialect/Shape/canonicalize.mlir
    mlir/test/Dialect/Shape/ops.mlir
    mlir/test/Dialect/Shape/remove-shape-constraints.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 85082419b30e..46400c8e9eff 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -610,8 +610,9 @@ def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
 def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
   let summary = "Determines if 2 shapes can be successfully broadcasted";
   let description = [{
-    Given 2 input shapes, return a witness specifying if they are broadcastable.
-    This broadcastable follows the same logic as what shape.broadcast documents.
+    Given two input shapes or extent tensors, return a witness specifying if
+    they are broadcastable. This broadcastable follows the same logic as what
+    shape.broadcast documents.
 
     "cstr" operations represent runtime assertions.
 
@@ -622,10 +623,11 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
     ```
   }];
 
-  let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
+                       Shape_ShapeOrExtentTensorType:$rhs);
   let results = (outs Shape_WitnessType:$result);
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict";
+  let assemblyFormat = "$lhs `,` $rhs `:` type($lhs) `,` type($rhs) attr-dict";
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;

diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index a58b230f7981..156063ea002c 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -431,7 +431,7 @@ func @f() {
   // CHECK-NEXT: return
   %cs0 = shape.const_shape [3, 1]
   %cs1 = shape.const_shape [1, 5]
-  %0 = shape.cstr_broadcastable %cs0, %cs1
+  %0 = shape.cstr_broadcastable %cs0, %cs1 : !shape.shape, !shape.shape
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
 }
@@ -447,7 +447,7 @@ func @static_non_broadcastable() {
   // CHECK-NEXT: return
   %cs0 = shape.const_shape [1, 3]
   %cs1 = shape.const_shape [1, 5]
-  %0 = shape.cstr_broadcastable %cs0, %cs1
+  %0 = shape.cstr_broadcastable %cs0, %cs1 : !shape.shape, !shape.shape
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
 }
@@ -461,7 +461,7 @@ func @f(%arg0 : !shape.shape) {
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
   %cs0 = shape.const_shape [1,3]
-  %0 = shape.cstr_broadcastable %arg0, %cs0
+  %0 = shape.cstr_broadcastable %arg0, %cs0 : !shape.shape, !shape.shape
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
 }
@@ -473,7 +473,20 @@ func @f(%arg0 : !shape.shape) {
   // CHECK-NEXT: shape.const_witness true
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
-  %0 = shape.cstr_broadcastable %arg0, %arg0
+  %0 = shape.cstr_broadcastable %arg0, %arg0 : !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+
+// Broadcastable canonicalization also works on extent tensors.
+// CHECK-LABEL: func @broadcastable_on_extent_tensors
+func @broadcastable_on_extent_tensors(%arg : tensor<?xindex>) {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %0 = shape.cstr_broadcastable %arg, %arg : tensor<?xindex>, tensor<?xindex>
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
 }
@@ -560,7 +573,7 @@ func @cstr_broadcastable_scalar(%arg0 : tensor<?xf32>) {
   // CHECK-NEXT: return
   %0 = shape.const_shape []
   %1 = shape.shape_of %arg0 : tensor<?xf32>
-  %2 = shape.cstr_broadcastable %0, %1
+  %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
   "consume.witness"(%2) : (!shape.witness) -> ()
   return
 }
@@ -577,7 +590,7 @@ func @cstr_broadcastable_unknown(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>) {
   // CHECK-NEXT: return
   %0 = shape.shape_of %arg0 : tensor<?xf32>
   %1 = shape.shape_of %arg1 : tensor<?xf32>
-  %2 = shape.cstr_broadcastable %0, %1
+  %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
   "consume.witness"(%2) : (!shape.witness) -> ()
   return
 }
@@ -592,7 +605,7 @@ func @cstr_broadcastable_scalar_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<i
   // CHECK-NEXT: return
   %0 = shape.shape_of %arg1 : tensor<index>
   %1 = shape.shape_of %arg0 : tensor<*xf32>
-  %2 = shape.cstr_broadcastable %0, %1
+  %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
   "consume.witness"(%2) : (!shape.witness) -> ()
   return
 }

diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index c6f52519ad2e..30cf29a083ec 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -86,7 +86,7 @@ func @test_shape_of(%arg0: tensor<?xf32>) -> !shape.shape {
 func @test_constraints() {
   %0 = shape.const_shape []
   %1 = shape.const_shape [1, 2, 3]
-  %w0 = shape.cstr_broadcastable %0, %1
+  %w0 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
   %w1 = shape.cstr_eq %0, %1
   %w2 = shape.const_witness true
   %w3 = shape.const_witness false
@@ -98,6 +98,12 @@ func @test_constraints() {
   return
 }
 
+func @broadcastable_on_extent_tensors(%lhs : tensor<?xindex>,
+                                      %rhs : tensor<?xindex>) {
+  %w0 = shape.cstr_broadcastable %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
+  return
+}
+
 func @test_mul(%lhs: !shape.size, %rhs: !shape.size) -> !shape.size {
   %product = shape.mul %lhs, %rhs
   return %product: !shape.size

diff --git a/mlir/test/Dialect/Shape/remove-shape-constraints.mlir b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
index 69887c6994f4..31bb7bd56255 100644
--- a/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
+++ b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
@@ -11,7 +11,7 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
   // REPLACE: shape.assuming %[[WITNESS]]
   // CANON-NEXT: test.source
   // CANON-NEXT: return
-  %0 = shape.cstr_broadcastable %arg0, %arg1
+  %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
   %1 = shape.assuming %0 -> index {
     %2 = "test.source"() : () -> (index)
     shape.assuming_yield %2 : index
@@ -45,7 +45,7 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
 func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
   // CANON-NEXT: test.source
   // CANON-NEXT: return
-  %0 = shape.cstr_broadcastable %arg0, %arg1
+  %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
   %1 = shape.cstr_eq %arg0, %arg1
   %2 = shape.assuming_all %0, %1
   %3 = shape.assuming %0 -> index {


        


More information about the Mlir-commits mailing list