[Mlir-commits] [mlir] 1f02ad7 - [mlir][shape] Update meet to handle all size & shape types
Jacques Pienaar
llvmlistbot at llvm.org
Wed Aug 10 05:08:32 PDT 2022
Author: Jacques Pienaar
Date: 2022-08-10T05:08:24-07:00
New Revision: 1f02ad71310b9e86d183abdedc75ca99ff1106f5
URL: https://github.com/llvm/llvm-project/commit/1f02ad71310b9e86d183abdedc75ca99ff1106f5
DIFF: https://github.com/llvm/llvm-project/commit/1f02ad71310b9e86d183abdedc75ca99ff1106f5.diff
LOG: [mlir][shape] Update meet to handle all size & shape types
Also tighten up return type inference & compatibility functions.
Differential Revision: https://reviews.llvm.org/D130866
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/invalid.mlir
mlir/test/Dialect/Shape/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
index 57792edaf17b8..d5a8a65b419d7 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
@@ -110,6 +110,11 @@ def Shape_ShapeOrExtentTensorType : AnyTypeOf<[Shape_ShapeType,
def Shape_SizeOrIndexType : AnyTypeOf<[Shape_SizeType, Index], "size or index">;
+// Any type representing a shape or size/dim.
+def Shape_AnyShapeOrSizeType : AnyTypeOf<
+ [Shape_SizeOrIndexType, Shape_ShapeOrExtentTensorType],
+ "any shape or size">;
+
def Shape_WitnessType : Shape_Type<"Witness", "witness"> {
let description = [{
A witness is a structural device in the compiler to maintain ordering of
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 2773883f57a2a..6d0d84dbbd0fd 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -406,11 +406,11 @@ def Shape_MaxOp : Shape_Op<"max",
def Shape_MeetOp : Shape_Op<"meet",
[Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
- let summary = "Returns the least general shape.shape of its operands";
+ let summary = "Returns the least general shape or size of its operands";
let description = [{
- An operation that computes the least general shape of input operands.
+ An operation that computes the least general shape or dim of input operands.
This effectively asserts that corresponding static dimensions are equal.
- The behavior is to match each element of the `shape.shape` and propagate the
+ The behavior is to match each element of the shape/size and propagate the
most restrictive information, returning an invalid shape if there are
contradictory requirements. E.g., using pseudo code
@@ -433,9 +433,11 @@ def Shape_MeetOp : Shape_Op<"meet",
```
}];
- let arguments = (ins Shape_ShapeOrSizeType:$arg0, Shape_ShapeOrSizeType:$arg1,
- OptionalAttr<StrAttr>:$error);
- let results = (outs Shape_ShapeOrSizeType:$result);
+ let arguments = (ins
+ Shape_AnyShapeOrSizeType:$arg0,
+ Shape_AnyShapeOrSizeType:$arg1,
+ OptionalAttr<StrAttr>:$error);
+ let results = (outs Shape_AnyShapeOrSizeType:$result);
let assemblyFormat = [{
$arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index ce1767ba8413e..ff7065474bdb4 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1309,7 +1309,63 @@ LogicalResult mlir::shape::MeetOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
- inferredReturnTypes.assign({operands[0].getType()});
+ if (operands.empty())
+ return failure();
+
+ auto isShapeType = [](Type arg) {
+ if (arg.isa<ShapeType>())
+ return true;
+ return isExtentTensorType(arg);
+ };
+
+ // Fold the operand types pairwise: `acc` holds the meet of all types seen
+ // so far. Sizes/indices and shapes/extent tensors must not be mixed.
+ ValueRange::type_range types = operands.getTypes();
+ Type acc = types.front();
+ for (auto t : drop_begin(types)) {
+ Type l = acc, r = t;
+ if (!l.isa<ShapeType, SizeType>())
+ std::swap(l, r);
+
+ // Handle sizes, propagate error type if present.
+ if (l.isa<SizeType>()) {
+ if (r.isa<SizeType, IndexType>())
+ acc = l;
+ else
+ return emitOptionalError(location, "requires all sizes or shapes");
+ } else if (l.isa<IndexType>()) {
+ if (r.isa<IndexType>())
+ acc = r;
+ else
+ return emitOptionalError(location, "requires all sizes or shapes");
+ } else if (l.isa<ShapeType>()) {
+ // Handle shapes, propagate error type if present.
+ if (isShapeType(r))
+ acc = l;
+ else
+ return emitOptionalError(location, "requires all sizes or shapes");
+ } else if (isExtentTensorType(l)) {
+ // The swap above can leave an extent tensor in `l` while `r` is an index
+ // (e.g. meet(index, tensor<2xindex>)); reject that mix here instead of
+ // asserting in the RankedTensorType cast below.
+ if (!isExtentTensorType(r))
+ return emitOptionalError(location, "requires all sizes or shapes");
+ auto rank1 = l.cast<RankedTensorType>().getShape()[0];
+ auto rank2 = r.cast<RankedTensorType>().getShape()[0];
+ if (ShapedType::isDynamic(rank1))
+ acc = l;
+ else if (ShapedType::isDynamic(rank2))
+ acc = r;
+ else if (rank1 != rank2)
+ return emitOptionalError(location, "unequal shape cardinality");
+ else
+ acc = l;
+ } else {
+ // No other type can be met; do not silently keep a stale `acc`.
+ return emitOptionalError(location, "requires all sizes or shapes");
+ }
+ }
+ inferredReturnTypes.assign({acc});
return success();
}
@@ -1322,11 +1368,13 @@ bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
Type lhs = l.front();
Type rhs = r.front();
- if (lhs != rhs)
- return false;
+ if (!lhs.isa<ShapeType, SizeType>())
+ std::swap(lhs, rhs);
- if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
- return true;
+ if (lhs.isa<SizeType>())
+ return rhs.isa<SizeType, IndexType>();
+ if (lhs.isa<ShapeType>())
+ return rhs.isa<ShapeType, TensorType>();
if (succeeded(verifyCompatibleShapes({lhs, rhs})))
return true;
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index 3b4059b1d6026..daed6a49a0e82 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -272,3 +272,20 @@ func.func @const_shape() {
%0 = shape.const_shape [4, 5, 6] : tensor<2xindex>
return
}
+
+// -----
+
+func.func @invalid_meet(%arg0 : !shape.shape, %arg1 : index) -> index {
+ // expected-error @+1 {{requires all sizes or shapes}}
+ %result = shape.meet %arg0, %arg1 : !shape.shape, index -> index
+ return %result : index
+}
+
+// -----
+
+func.func @invalid_meet(%arg0 : tensor<2xindex>, %arg1 : tensor<3xindex>) -> tensor<?xindex> {
+ // expected-error @+1 {{unequal shape cardinality}}
+ %result = shape.meet %arg0, %arg1 : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
+ return %result : tensor<?xindex>
+}
+
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 60ee2f4541a7c..0f442308b3f6a 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -325,3 +325,9 @@ func.func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
!shape.size, !shape.size -> !shape.size
return %2 : !shape.size
}
+
+func.func @meet_index(%arg0 : index, %arg1 : index) -> index {
+ %result = shape.meet %arg0, %arg1 : index, index -> index
+ return %result : index
+}
+
More information about the Mlir-commits mailing list