[Mlir-commits] [mlir] c714b44 - [mlir][Shape] Make cstr_eq more like cstr_broadcastable
Benjamin Kramer
llvmlistbot at llvm.org
Wed Mar 3 07:20:54 PST 2021
Author: Benjamin Kramer
Date: 2021-03-03T16:20:05+01:00
New Revision: c714b441ef01523a1ccd86d92376c1b1103e6a90
URL: https://github.com/llvm/llvm-project/commit/c714b441ef01523a1ccd86d92376c1b1103e6a90
DIFF: https://github.com/llvm/llvm-project/commit/c714b441ef01523a1ccd86d92376c1b1103e6a90.diff
LOG: [mlir][Shape] Make cstr_eq more like cstr_broadcastable
This includes allowing extent tensors, not just `!shape.shape` values, as operands.
Differential Revision: https://reviews.llvm.org/D97716
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Shape/ops.mlir
mlir/test/Dialect/Shape/remove-shape-constraints.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 756d0ec706b8..0a6122801835 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -783,7 +783,7 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable",
let verifier = [{ return ::verify(*this); }];
}
-def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> {
+def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative, InferTypeOpInterface]> {
let summary = "Determines if all input shapes are equal";
let description = [{
Given 1 or more input shapes, determine if all shapes are the exact same.
@@ -796,10 +796,21 @@ def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> {
%w1 = shape.cstr_eq [2,2], [1,2] // Failure
```
}];
- let arguments = (ins Variadic<Shape_ShapeType>:$inputs);
+ let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
let results = (outs Shape_WitnessType:$result);
- let assemblyFormat = "$inputs attr-dict";
+ let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
+
+ let extraClassDeclaration = [{
+ // TODO: This should really be automatic. Figure out how to not need this defined.
+ static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+ ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+ ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+ ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) {
+ inferredReturnTypes.push_back(::mlir::shape::WitnessType::get(context));
+ return success();
+ };
+ }];
let hasCanonicalizer = 1;
let hasFolder = 1;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index ea895015b1e5..b5828fe53bd8 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -360,7 +360,7 @@ func @f(%arg0 : !shape.shape) {
// CHECK-NEXT: shape.const_witness true
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
- %0 = shape.cstr_eq %arg0, %arg0, %arg0
+ %0 = shape.cstr_eq %arg0, %arg0, %arg0 : !shape.shape, !shape.shape, !shape.shape
"consume.witness"(%0) : (!shape.witness) -> ()
return
}
@@ -375,7 +375,7 @@ func @f() {
%cs0 = shape.const_shape [0, 1] : !shape.shape
%cs1 = shape.const_shape [0, 1] : !shape.shape
%cs2 = shape.const_shape [0, 1] : !shape.shape
- %0 = shape.cstr_eq %cs0, %cs1, %cs2
+ %0 = shape.cstr_eq %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
"consume.witness"(%0) : (!shape.witness) -> ()
return
}
@@ -391,7 +391,7 @@ func @f() {
// CHECK-NEXT: return
%cs0 = shape.const_shape [0, 1] : !shape.shape
%cs1 = shape.const_shape [3, 1] : !shape.shape
- %0 = shape.cstr_eq %cs0, %cs1
+ %0 = shape.cstr_eq %cs0, %cs1 : !shape.shape, !shape.shape
"consume.witness"(%0) : (!shape.witness) -> ()
return
}
@@ -403,7 +403,7 @@ func @f(%arg0: !shape.shape, %arg1: !shape.shape) {
// CHECK-NEXT: shape.cstr_eq
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
- %0 = shape.cstr_eq %arg0, %arg1
+ %0 = shape.cstr_eq %arg0, %arg1 : !shape.shape, !shape.shape
"consume.witness"(%0) : (!shape.witness) -> ()
return
}
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 1a5735ccbb5e..ca838e7f8dc7 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -102,7 +102,7 @@ func @test_constraints() {
%1 = shape.const_shape [1, 2, 3] : !shape.shape
%true = constant true
%w0 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
- %w1 = shape.cstr_eq %0, %1
+ %w1 = shape.cstr_eq %0, %1 : !shape.shape, !shape.shape
%w2 = shape.const_witness true
%w3 = shape.const_witness false
%w4 = shape.cstr_require %true, "msg"
@@ -114,6 +114,12 @@ func @test_constraints() {
return
}
+func @eq_on_extent_tensors(%lhs : tensor<?xindex>,
+ %rhs : tensor<?xindex>) {
+ %w0 = shape.cstr_eq %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
+ return
+}
+
func @broadcastable_on_extent_tensors(%lhs : tensor<?xindex>,
%rhs : tensor<?xindex>) {
%w0 = shape.cstr_broadcastable %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
diff --git a/mlir/test/Dialect/Shape/remove-shape-constraints.mlir b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
index 31bb7bd56255..e52cd36107ef 100644
--- a/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
+++ b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
@@ -29,7 +29,7 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
// REPLACE: shape.assuming %[[WITNESS]]
// CANON-NEXT: test.source
// CANON-NEXT: return
- %0 = shape.cstr_eq %arg0, %arg1
+ %0 = shape.cstr_eq %arg0, %arg1 : !shape.shape, !shape.shape
%1 = shape.assuming %0 -> index {
%2 = "test.source"() : () -> (index)
shape.assuming_yield %2 : index
@@ -46,7 +46,7 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
// CANON-NEXT: test.source
// CANON-NEXT: return
%0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
- %1 = shape.cstr_eq %arg0, %arg1
+ %1 = shape.cstr_eq %arg0, %arg1 : !shape.shape, !shape.shape
%2 = shape.assuming_all %0, %1
%3 = shape.assuming %0 -> index {
%4 = "test.source"() : () -> (index)
More information about the Mlir-commits mailing list