[Mlir-commits] [mlir] 25a20b8 - [mlir] Correct verifyCompatibleShapes
Tres Popp
llvmlistbot at llvm.org
Thu Mar 11 04:04:23 PST 2021
Author: Tres Popp
Date: 2021-03-11T13:04:10+01:00
New Revision: 25a20b8aa68e6fa6129f1ce4d0125365f399b59d
URL: https://github.com/llvm/llvm-project/commit/25a20b8aa68e6fa6129f1ce4d0125365f399b59d
DIFF: https://github.com/llvm/llvm-project/commit/25a20b8aa68e6fa6129f1ce4d0125365f399b59d.diff
LOG: [mlir] Correct verifyCompatibleShapes
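Shape compatibility as checked by verifyCompatibleShape is not transitive
(`?` is compatible with both 3 and 10, but 3 and 10 are not compatible with
each other), so comparing every type against only the first one can accept
incompatible shapes. Create an n-ary verifyCompatibleShapes and update the
SameOperandsShape and SameOperandsAndResultShape traits to use it.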
Differential Revision: https://reviews.llvm.org/D98331
Added:
Modified:
mlir/include/mlir/IR/TypeUtilities.h
mlir/lib/IR/Operation.cpp
mlir/lib/IR/TypeUtilities.cpp
mlir/test/IR/traits.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/TypeUtilities.h b/mlir/include/mlir/IR/TypeUtilities.h
index 52a4e497e621..b790f99a5ccc 100644
--- a/mlir/include/mlir/IR/TypeUtilities.h
+++ b/mlir/include/mlir/IR/TypeUtilities.h
@@ -59,6 +59,17 @@ LogicalResult verifyCompatibleShape(Type type1, Type type2);
/// each pair wise entries have compatible shape.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2);
+/// Returns success if all given types have compatible shapes. That is, they are
+/// all scalars (not shaped), or they are all shaped types and all ranked shapes
+/// have the same rank and compatible dimensions. The element type does not
+/// matter. Unlike comparing every type against one reference type with
+/// verifyCompatibleShape, this accounts for shape compatibility not being
+/// transitive (`?` is compatible with both 3 and 10, but 3 and 10 are not).
+LogicalResult verifyCompatibleShapes(TypeRange types);
+
+/// Returns success if the given dimensions are mutually compatible, i.e. all
+/// non-dynamic dims are equal. Dynamic dims are compatible with anything.
+LogicalResult verifyCompatibleDims(ArrayRef<int64_t> dims);
//===----------------------------------------------------------------------===//
// Utility Iterators
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index c214e83b4f2d..0614a7bd63d3 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -834,11 +834,9 @@ LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) {
if (failed(verifyAtLeastNOperands(op, 1)))
return failure();
- auto type = op->getOperand(0).getType();
- for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
- if (failed(verifyCompatibleShape(opType, type)))
- return op->emitOpError() << "requires the same shape for all operands";
- }
+  // Shape compatibility is not transitive, so check all operands together.
+  if (verifyCompatibleShapes(op->getOperandTypes()).failed())
+    return op->emitOpError() << "requires the same shape for all operands";
return success();
}
@@ -847,17 +845,17 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
failed(verifyAtLeastNResults(op, 1)))
return failure();
- auto type = op->getOperand(0).getType();
- for (auto resultType : op->getResultTypes()) {
- if (failed(verifyCompatibleShape(resultType, type)))
- return op->emitOpError()
- << "requires the same shape for all operands and results";
- }
- for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
- if (failed(verifyCompatibleShape(opType, type)))
- return op->emitOpError()
- << "requires the same shape for all operands and results";
- }
+  // Check operands and results as a single group: shape compatibility is not
+  // transitive, so comparing each type with the first operand alone would
+  // accept e.g. (tensor<?xf32>, tensor<10xf32>) -> tensor<3xf32>.
+  SmallVector<Type, 8> types(op->getOperandTypes());
+  auto resultTypes = op->getResultTypes();
+  types.append(resultTypes.begin(), resultTypes.end());
+
+  if (failed(verifyCompatibleShapes(types)))
+    return op->emitOpError()
+           << "requires the same shape for all operands and results";
+
return success();
}
diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp
index f5391e5ff521..bc6a0b9d9af3 100644
--- a/mlir/lib/IR/TypeUtilities.cpp
+++ b/mlir/lib/IR/TypeUtilities.cpp
@@ -11,6 +11,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/TypeUtilities.h"
+
+#include <numeric>
+
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
@@ -97,6 +100,65 @@ LogicalResult mlir::verifyCompatibleShapes(TypeRange types1, TypeRange types2) {
return success();
}
+/// Returns success if the given dimensions are mutually compatible, i.e. all
+/// non-dynamic dims are equal. Dynamic dims are compatible with anything.
+LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) {
+  if (dims.empty())
+    return success();
+  // Pick a representative static dim (the last static one, or dims.front()
+  // when every dim is dynamic); each static dim must then be equal to it.
+  auto staticDim = std::accumulate(
+      dims.begin(), dims.end(), dims.front(), [](auto fold, auto dim) {
+        return ShapedType::isDynamic(dim) ? fold : dim;
+      });
+  return success(llvm::all_of(dims, [&](auto dim) {
+    return ShapedType::isDynamic(dim) || dim == staticDim;
+  }));
+}
+
+/// Returns success if all given types have compatible shapes. That is, they are
+/// all scalars (not shaped), or they are all shaped types and all ranked shapes
+/// have the same rank and compatible dimensions. Dimensions are compatible if
+/// all non-dynamic dims are equal. The element type does not matter.
+LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
+  auto shapedTypes = llvm::to_vector<8>(llvm::map_range(
+      types, [](Type type) { return type.dyn_cast<ShapedType>(); }));
+  // Return failure if some, but not all, are shaped. Return early if none are
+  // shaped, since there is no shape to compare.
+  if (llvm::none_of(shapedTypes, [](ShapedType t) { return bool(t); }))
+    return success();
+  if (!llvm::all_of(shapedTypes, [](ShapedType t) { return bool(t); }))
+    return failure();
+
+  // Unranked shapes constrain nothing; only the ranked ones are checked.
+  auto shapes = llvm::to_vector<8>(llvm::make_filter_range(
+      shapedTypes, [](ShapedType type) { return type.hasRank(); }));
+  if (shapes.empty())
+    return success();
+
+  // All ranks should be equal.
+  int64_t firstRank = shapes.front().getRank();
+  if (llvm::any_of(shapes, [&](ShapedType shape) {
+        return shape.getRank() != firstRank;
+      }))
+    return failure();
+
+  // Check each dimension across all shapes together rather than against one
+  // reference shape, since compatibility is not transitive: `?` is compatible
+  // with both 3 and 10, but 3 and 10 are not. Every shape here has exactly
+  // `firstRank` dims, so index `i` is in range for all of them.
+  SmallVector<int64_t, 8> dims;
+  for (unsigned i = 0; i < firstRank; ++i) {
+    dims.clear();
+    for (ShapedType shape : shapes)
+      dims.push_back(shape.getDimSize(i));
+    if (failed(verifyCompatibleDims(dims)))
+      return failure();
+  }
+
+  return success();
+}
+
OperandElementTypeIterator::OperandElementTypeIterator(
Operation::operand_iterator it)
: llvm::mapped_iterator<Operation::operand_iterator, Type (*)(Value)>(
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index e3604b7b5387..858f601c0211 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -133,6 +133,15 @@ func @failedSameOperandAndResultShape_operand_result_mismatch(%t10x10 : tensor<1
// -----
+func @failedSameOperandAndResultShape_operand_result_not_transitive(%t10 : tensor<10xf32>, %tdyn : tensor<?xf32>) {
+  // Each type is compatible with the dynamic operand, but 10 and 3 are not
+  // compatible with each other, so the whole set must be rejected.
+  // expected-error@+1 {{requires the same shape for all operands and results}}
+  "test.same_operand_and_result_shape"(%tdyn, %t10) : (tensor<?xf32>, tensor<10xf32>) -> tensor<3xf32>
+}
+
+// -----
+
func @failedSameOperandAndResultShape_no_operands() {
// expected-error@+1 {{expected 1 or more operands}}
"test.same_operand_and_result_shape"() : () -> (tensor<1xf32>)
@@ -347,7 +356,7 @@ func @failedSingleBlockImplicitTerminator_missing_terminator() {
func private @foo()
"test.finish" () : () -> ()
}) : () -> ()
-func private @foo()
+func private @foo()
// -----
// -----
More information about the Mlir-commits
mailing list