[Mlir-commits] [mlir] c714b44 - [mlir][Shape] Make cstr_eq more like cstr_broadcastable

Benjamin Kramer llvmlistbot at llvm.org
Wed Mar 3 07:20:54 PST 2021


Author: Benjamin Kramer
Date: 2021-03-03T16:20:05+01:00
New Revision: c714b441ef01523a1ccd86d92376c1b1103e6a90

URL: https://github.com/llvm/llvm-project/commit/c714b441ef01523a1ccd86d92376c1b1103e6a90
DIFF: https://github.com/llvm/llvm-project/commit/c714b441ef01523a1ccd86d92376c1b1103e6a90.diff

LOG: [mlir][Shape] Make cstr_eq more like cstr_broadcastable

This includes allowing extent tensors (e.g. `tensor<?xindex>`) as operands, not just `!shape.shape` values.

Differential Revision: https://reviews.llvm.org/D97716

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/test/Dialect/Shape/canonicalize.mlir
    mlir/test/Dialect/Shape/ops.mlir
    mlir/test/Dialect/Shape/remove-shape-constraints.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 756d0ec706b8..0a6122801835 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -783,7 +783,7 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable",
   let verifier = [{ return ::verify(*this); }];
 }
 
-def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> {
+def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative, InferTypeOpInterface]> {
   let summary = "Determines if all input shapes are equal";
   let description = [{
     Given 1 or more input shapes, determine if all shapes are the exact same.
@@ -796,10 +796,21 @@ def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> {
     %w1 = shape.cstr_eq [2,2], [1,2] // Failure
     ```
   }];
-  let arguments = (ins Variadic<Shape_ShapeType>:$inputs);
+  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
   let results = (outs Shape_WitnessType:$result);
 
-  let assemblyFormat = "$inputs attr-dict";
+  let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
+
+  let extraClassDeclaration = [{
+    // TODO: This should really be automatic. Figure out how to not need this defined.
+    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+    ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+    ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+    ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) {
+      inferredReturnTypes.push_back(::mlir::shape::WitnessType::get(context));
+      return success();
+    };
+  }];
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index ea895015b1e5..b5828fe53bd8 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -360,7 +360,7 @@ func @f(%arg0 : !shape.shape) {
   // CHECK-NEXT: shape.const_witness true
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
-  %0 = shape.cstr_eq %arg0, %arg0, %arg0
+  %0 = shape.cstr_eq %arg0, %arg0, %arg0 : !shape.shape, !shape.shape, !shape.shape
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
 }
@@ -375,7 +375,7 @@ func @f() {
   %cs0 = shape.const_shape [0, 1] : !shape.shape
   %cs1 = shape.const_shape [0, 1] : !shape.shape
   %cs2 = shape.const_shape [0, 1] : !shape.shape
-  %0 = shape.cstr_eq %cs0, %cs1, %cs2
+  %0 = shape.cstr_eq %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
 }
@@ -391,7 +391,7 @@ func @f() {
   // CHECK-NEXT: return
   %cs0 = shape.const_shape [0, 1] : !shape.shape
   %cs1 = shape.const_shape [3, 1] : !shape.shape
-  %0 = shape.cstr_eq %cs0, %cs1
+  %0 = shape.cstr_eq %cs0, %cs1 : !shape.shape, !shape.shape
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
 }
@@ -403,7 +403,7 @@ func @f(%arg0: !shape.shape, %arg1: !shape.shape) {
   // CHECK-NEXT: shape.cstr_eq
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
-  %0 = shape.cstr_eq %arg0, %arg1
+  %0 = shape.cstr_eq %arg0, %arg1 : !shape.shape, !shape.shape
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
 }

diff  --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 1a5735ccbb5e..ca838e7f8dc7 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -102,7 +102,7 @@ func @test_constraints() {
   %1 = shape.const_shape [1, 2, 3] : !shape.shape
   %true = constant true
   %w0 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
-  %w1 = shape.cstr_eq %0, %1
+  %w1 = shape.cstr_eq %0, %1 : !shape.shape, !shape.shape
   %w2 = shape.const_witness true
   %w3 = shape.const_witness false
   %w4 = shape.cstr_require %true, "msg"
@@ -114,6 +114,12 @@ func @test_constraints() {
   return
 }
 
+func @eq_on_extent_tensors(%lhs : tensor<?xindex>,
+                                      %rhs : tensor<?xindex>) {
+  %w0 = shape.cstr_eq %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
+  return
+}
+
 func @broadcastable_on_extent_tensors(%lhs : tensor<?xindex>,
                                       %rhs : tensor<?xindex>) {
   %w0 = shape.cstr_broadcastable %lhs, %rhs : tensor<?xindex>, tensor<?xindex>

diff  --git a/mlir/test/Dialect/Shape/remove-shape-constraints.mlir b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
index 31bb7bd56255..e52cd36107ef 100644
--- a/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
+++ b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
@@ -29,7 +29,7 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
   // REPLACE: shape.assuming %[[WITNESS]]
   // CANON-NEXT: test.source
   // CANON-NEXT: return
-  %0 = shape.cstr_eq %arg0, %arg1
+  %0 = shape.cstr_eq %arg0, %arg1 : !shape.shape, !shape.shape
   %1 = shape.assuming %0 -> index {
     %2 = "test.source"() : () -> (index)
     shape.assuming_yield %2 : index
@@ -46,7 +46,7 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
   // CANON-NEXT: test.source
   // CANON-NEXT: return
   %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
-  %1 = shape.cstr_eq %arg0, %arg1
+  %1 = shape.cstr_eq %arg0, %arg1 : !shape.shape, !shape.shape
   %2 = shape.assuming_all %0, %1
   %3 = shape.assuming %0 -> index {
     %4 = "test.source"() : () -> (index)


        


More information about the Mlir-commits mailing list