[Mlir-commits] [mlir] ccb40c8 - [MLIR][Shape] Allow `cstr_broadcastable` to accept extent tensors
Frederik Gossen
llvmlistbot at llvm.org
Mon Jul 20 07:40:08 PDT 2020
Author: Frederik Gossen
Date: 2020-07-20T14:39:44Z
New Revision: ccb40c84c57137fade6b8815cf3590a537a4b702
URL: https://github.com/llvm/llvm-project/commit/ccb40c84c57137fade6b8815cf3590a537a4b702
DIFF: https://github.com/llvm/llvm-project/commit/ccb40c84c57137fade6b8815cf3590a537a4b702.diff
LOG: [MLIR][Shape] Allow `cstr_broadcastable` to accept extent tensors
Differential Revision: https://reviews.llvm.org/D84155
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Shape/ops.mlir
mlir/test/Dialect/Shape/remove-shape-constraints.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 85082419b30e..46400c8e9eff 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -610,8 +610,9 @@ def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
let summary = "Determines if 2 shapes can be successfully broadcasted";
let description = [{
- Given 2 input shapes, return a witness specifying if they are broadcastable.
- This broadcastable follows the same logic as what shape.broadcast documents.
+ Given two input shapes or extent tensors, return a witness specifying if
+ they are broadcastable. This broadcastable follows the same logic as what
+ shape.broadcast documents.
"cstr" operations represent runtime assertions.
@@ -622,10 +623,11 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
```
}];
- let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
+ let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
+ Shape_ShapeOrExtentTensorType:$rhs);
let results = (outs Shape_WitnessType:$result);
- let assemblyFormat = "$lhs `,` $rhs attr-dict";
+ let assemblyFormat = "$lhs `,` $rhs `:` type($lhs) `,` type($rhs) attr-dict";
let hasCanonicalizer = 1;
let hasFolder = 1;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index a58b230f7981..156063ea002c 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -431,7 +431,7 @@ func @f() {
// CHECK-NEXT: return
%cs0 = shape.const_shape [3, 1]
%cs1 = shape.const_shape [1, 5]
- %0 = shape.cstr_broadcastable %cs0, %cs1
+ %0 = shape.cstr_broadcastable %cs0, %cs1 : !shape.shape, !shape.shape
"consume.witness"(%0) : (!shape.witness) -> ()
return
}
@@ -447,7 +447,7 @@ func @static_non_broadcastable() {
// CHECK-NEXT: return
%cs0 = shape.const_shape [1, 3]
%cs1 = shape.const_shape [1, 5]
- %0 = shape.cstr_broadcastable %cs0, %cs1
+ %0 = shape.cstr_broadcastable %cs0, %cs1 : !shape.shape, !shape.shape
"consume.witness"(%0) : (!shape.witness) -> ()
return
}
@@ -461,7 +461,7 @@ func @f(%arg0 : !shape.shape) {
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%cs0 = shape.const_shape [1,3]
- %0 = shape.cstr_broadcastable %arg0, %cs0
+ %0 = shape.cstr_broadcastable %arg0, %cs0 : !shape.shape, !shape.shape
"consume.witness"(%0) : (!shape.witness) -> ()
return
}
@@ -473,7 +473,20 @@ func @f(%arg0 : !shape.shape) {
// CHECK-NEXT: shape.const_witness true
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
- %0 = shape.cstr_broadcastable %arg0, %arg0
+ %0 = shape.cstr_broadcastable %arg0, %arg0 : !shape.shape, !shape.shape
+ "consume.witness"(%0) : (!shape.witness) -> ()
+ return
+}
+
+// -----
+
+// Broadcastable canonicalization also works on extent tensors.
+// CHECK-LABEL: func @broadcastable_on_extent_tensors
+func @broadcastable_on_extent_tensors(%arg : tensor<?xindex>) {
+ // CHECK-NEXT: shape.const_witness true
+ // CHECK-NEXT: consume.witness
+ // CHECK-NEXT: return
+ %0 = shape.cstr_broadcastable %arg, %arg : tensor<?xindex>, tensor<?xindex>
"consume.witness"(%0) : (!shape.witness) -> ()
return
}
@@ -560,7 +573,7 @@ func @cstr_broadcastable_scalar(%arg0 : tensor<?xf32>) {
// CHECK-NEXT: return
%0 = shape.const_shape []
%1 = shape.shape_of %arg0 : tensor<?xf32>
- %2 = shape.cstr_broadcastable %0, %1
+ %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
"consume.witness"(%2) : (!shape.witness) -> ()
return
}
@@ -577,7 +590,7 @@ func @cstr_broadcastable_unknown(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>) {
// CHECK-NEXT: return
%0 = shape.shape_of %arg0 : tensor<?xf32>
%1 = shape.shape_of %arg1 : tensor<?xf32>
- %2 = shape.cstr_broadcastable %0, %1
+ %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
"consume.witness"(%2) : (!shape.witness) -> ()
return
}
@@ -592,7 +605,7 @@ func @cstr_broadcastable_scalar_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<i
// CHECK-NEXT: return
%0 = shape.shape_of %arg1 : tensor<index>
%1 = shape.shape_of %arg0 : tensor<*xf32>
- %2 = shape.cstr_broadcastable %0, %1
+ %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
"consume.witness"(%2) : (!shape.witness) -> ()
return
}
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index c6f52519ad2e..30cf29a083ec 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -86,7 +86,7 @@ func @test_shape_of(%arg0: tensor<?xf32>) -> !shape.shape {
func @test_constraints() {
%0 = shape.const_shape []
%1 = shape.const_shape [1, 2, 3]
- %w0 = shape.cstr_broadcastable %0, %1
+ %w0 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
%w1 = shape.cstr_eq %0, %1
%w2 = shape.const_witness true
%w3 = shape.const_witness false
@@ -98,6 +98,12 @@ func @test_constraints() {
return
}
+func @broadcastable_on_extent_tensors(%lhs : tensor<?xindex>,
+ %rhs : tensor<?xindex>) {
+ %w0 = shape.cstr_broadcastable %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
+ return
+}
+
func @test_mul(%lhs: !shape.size, %rhs: !shape.size) -> !shape.size {
%product = shape.mul %lhs, %rhs
return %product: !shape.size
diff --git a/mlir/test/Dialect/Shape/remove-shape-constraints.mlir b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
index 69887c6994f4..31bb7bd56255 100644
--- a/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
+++ b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
@@ -11,7 +11,7 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
// REPLACE: shape.assuming %[[WITNESS]]
// CANON-NEXT: test.source
// CANON-NEXT: return
- %0 = shape.cstr_broadcastable %arg0, %arg1
+ %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
%1 = shape.assuming %0 -> index {
%2 = "test.source"() : () -> (index)
shape.assuming_yield %2 : index
@@ -45,7 +45,7 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
// CANON-NEXT: test.source
// CANON-NEXT: return
- %0 = shape.cstr_broadcastable %arg0, %arg1
+ %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
%1 = shape.cstr_eq %arg0, %arg1
%2 = shape.assuming_all %0, %1
%3 = shape.assuming %0 -> index {
More information about the Mlir-commits mailing list