[Mlir-commits] [mlir] 1f02ad7 - [mlir][shape] Update meet to handle all size & shape types

Jacques Pienaar <llvmlistbot@llvm.org>
Wed Aug 10 05:08:32 PDT 2022


Author: Jacques Pienaar
Date: 2022-08-10T05:08:24-07:00
New Revision: 1f02ad71310b9e86d183abdedc75ca99ff1106f5

URL: https://github.com/llvm/llvm-project/commit/1f02ad71310b9e86d183abdedc75ca99ff1106f5
DIFF: https://github.com/llvm/llvm-project/commit/1f02ad71310b9e86d183abdedc75ca99ff1106f5.diff

LOG: [mlir][shape] Update meet to handle all size & shape types

Also tighten up return type inference & compatibility functions.

Differential Revision: https://reviews.llvm.org/D130866

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/invalid.mlir
    mlir/test/Dialect/Shape/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
index 57792edaf17b8..d5a8a65b419d7 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
@@ -110,6 +110,11 @@ def Shape_ShapeOrExtentTensorType : AnyTypeOf<[Shape_ShapeType,
 
 def Shape_SizeOrIndexType : AnyTypeOf<[Shape_SizeType, Index], "size or index">;
 
+// Any type representing a shape or size/dim.
+def Shape_AnyShapeOrSizeType : AnyTypeOf<
+    [Shape_SizeOrIndexType, Shape_ShapeOrExtentTensorType],
+  "any shape or size">;
+
 def Shape_WitnessType : Shape_Type<"Witness", "witness"> {
   let description = [{
     A witness is a structural device in the compiler to maintain ordering of

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 2773883f57a2a..6d0d84dbbd0fd 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -406,11 +406,11 @@ def Shape_MaxOp : Shape_Op<"max",
 
 def Shape_MeetOp : Shape_Op<"meet",
     [Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
-  let summary = "Returns the least general shape.shape of its operands";
+  let summary = "Returns the least general shape or size of its operands";
   let description = [{
-    An operation that computes the least general shape of input operands.
+    An operation that computes the least general shape or dim of input operands.
     This effectively asserts that corresponding static dimensions are equal.
-    The behavior is to match each element of the `shape.shape` and propagate the
+    The behavior is to match each element of the shape/size and propagate the
     most restrictive information, returning an invalid shape if there are
     contradictory requirements. E.g., using pseudo code
 
@@ -433,9 +433,11 @@ def Shape_MeetOp : Shape_Op<"meet",
     ```
   }];
 
-  let arguments = (ins Shape_ShapeOrSizeType:$arg0, Shape_ShapeOrSizeType:$arg1,
-                   OptionalAttr<StrAttr>:$error);
-  let results = (outs Shape_ShapeOrSizeType:$result);
+  let arguments = (ins
+    Shape_AnyShapeOrSizeType:$arg0,
+    Shape_AnyShapeOrSizeType:$arg1,
+    OptionalAttr<StrAttr>:$error);
+  let results = (outs Shape_AnyShapeOrSizeType:$result);
 
   let assemblyFormat = [{
     $arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index ce1767ba8413e..ff7065474bdb4 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1309,7 +1309,53 @@ LogicalResult mlir::shape::MeetOp::inferReturnTypes(
     MLIRContext *context, Optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<Type> &inferredReturnTypes) {
-  inferredReturnTypes.assign({operands[0].getType()});
+  if (operands.empty())
+    return failure();
+
+  auto isShapeType = [](Type arg) {
+    if (arg.isa<ShapeType>())
+      return true;
+    return isExtentTensorType(arg);
+  };
+
+  ValueRange::type_range types = operands.getTypes();
+  Type acc = types.front();
+  for (auto t : drop_begin(types)) {
+    Type l = acc, r = t;
+    if (!l.isa<ShapeType, SizeType>())
+      std::swap(l, r);
+
+    // Handle sizes, propagate error type if present.
+    if (l.isa<SizeType>()) {
+      if (r.isa<SizeType, IndexType>())
+        acc = l;
+      else
+        return emitOptionalError(location, "requires all sizes or shapes");
+    } else if (l.isa<IndexType>()) {
+      if (r.isa<IndexType>())
+        acc = r;
+      else
+        return emitOptionalError(location, "requires all sizes or shapes");
+    } else if (l.isa<ShapeType>()) {
+      // Handle shapes, propagate error type if present.
+      if (isShapeType(r))
+        acc = l;
+      else
+        return emitOptionalError(location, "requires all sizes or shapes");
+    } else if (isExtentTensorType(l)) {
+      auto rank1 = l.cast<RankedTensorType>().getShape()[0];
+      auto rank2 = r.cast<RankedTensorType>().getShape()[0];
+      if (ShapedType::isDynamic(rank1))
+        acc = l;
+      else if (ShapedType::isDynamic(rank2))
+        acc = r;
+      else if (rank1 != rank2)
+        return emitOptionalError(location, "unequal shape cardinality");
+      else
+        acc = l;
+    }
+  }
+  inferredReturnTypes.assign({acc});
   return success();
 }
 
@@ -1322,11 +1368,13 @@ bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
   Type lhs = l.front();
   Type rhs = r.front();
 
-  if (lhs != rhs)
-    return false;
+  if (!lhs.isa<ShapeType, SizeType>())
+    std::swap(lhs, rhs);
 
-  if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
-    return true;
+  if (lhs.isa<SizeType>())
+    return rhs.isa<SizeType, IndexType>();
+  if (lhs.isa<ShapeType>())
+    return rhs.isa<ShapeType, TensorType>();
 
   if (succeeded(verifyCompatibleShapes({lhs, rhs})))
     return true;

diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index 3b4059b1d6026..daed6a49a0e82 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -272,3 +272,20 @@ func.func @const_shape() {
   %0 = shape.const_shape [4, 5, 6] : tensor<2xindex>
   return
 }
+
+// -----
+
+func.func @invalid_meet(%arg0 : !shape.shape, %arg1 : index) -> index {
+  // expected-error @+1 {{requires all sizes or shapes}}
+  %result = shape.meet %arg0, %arg1 : !shape.shape, index -> index
+  return %result : index
+}
+
+// -----
+
+func.func @invalid_meet(%arg0 : tensor<2xindex>, %arg1 : tensor<3xindex>) -> tensor<?xindex> {
+  // expected-error at +1 {{unequal shape cardinality}}
+  %result = shape.meet %arg0, %arg1 : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
+  return %result : tensor<?xindex>
+}
+

diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 60ee2f4541a7c..0f442308b3f6a 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -325,3 +325,9 @@ func.func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
     !shape.size, !shape.size -> !shape.size
   return %2 : !shape.size
 }
+
+func.func @meet_index(%arg0 : index, %arg1 : index) -> index {
+  %result = shape.meet %arg0, %arg1 : index, index -> index
+  return %result : index
+}
+


        


More information about the Mlir-commits mailing list