[Mlir-commits] [mlir] 51cbe4e - [mlir] Fix broadcasting check with 1 values
Jacques Pienaar
llvmlistbot at llvm.org
Sun Jul 11 20:41:53 PDT 2021
Author: Jacques Pienaar
Date: 2021-07-11T20:41:33-07:00
New Revision: 51cbe4e58797d85c0f17d4a9cad1bcb11743afae
URL: https://github.com/llvm/llvm-project/commit/51cbe4e58797d85c0f17d4a9cad1bcb11743afae
DIFF: https://github.com/llvm/llvm-project/commit/51cbe4e58797d85c0f17d4a9cad1bcb11743afae.diff
LOG: [mlir] Fix broadcasting check with 1 values
The trait's result-shape check was inconsistent with the other broadcasting logic here. This
also fixes the error message to print dynamic dimensions as ? rather than -1.
Differential Revision: https://reviews.llvm.org/D105748
Added:
Modified:
mlir/lib/Dialect/Traits.cpp
mlir/test/Dialect/traits.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index 50f203644741..ce2feff441bc 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -192,14 +192,18 @@ static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }));
}
-static bool areCompatibleShapes(ArrayRef<int64_t> shape1,
- ArrayRef<int64_t> shape2) {
+static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
+ ArrayRef<int64_t> existing) {
auto isCompatible = [](int64_t dim1, int64_t dim2) {
- return dim1 == dim2 || dim1 == -1 || dim2 == -1;
+ // If the inferred and existing dim is the same, or one of them is unknown
+ // then it is compatible, else if the inferred dim is 1 then it is also
+ // compatible. But if the existing dim is 1 and the inferred is greater than
+ // 1 then flag.
+ return dim1 == dim2 || dim1 == -1 || dim2 == -1 || dim1 == 1;
};
- if (shape1.size() != shape2.size())
+ if (inferred.size() != existing.size())
return false;
- for (auto p : llvm::zip(shape1, shape2))
+ for (auto p : llvm::zip(inferred, existing))
if (!isCompatible(std::get<0>(p), std::get<1>(p)))
return false;
return true;
@@ -208,8 +212,20 @@ static bool areCompatibleShapes(ArrayRef<int64_t> shape1,
static std::string getShapeString(ArrayRef<int64_t> shape) {
// TODO: should replace with printing shape more uniformly across here and
// when in type.
- return std::string(
- formatv("'{0:$[x]}'", llvm::make_range(shape.begin(), shape.end())));
+ std::string ret;
+ llvm::raw_string_ostream ss(ret);
+ ss << '\'';
+ llvm::interleave(
+ shape, ss,
+ [&](int64_t dim) {
+ if (ShapedType::isDynamic(dim))
+ ss << '?';
+ else
+ ss << dim;
+ },
+ "x");
+ ss << '\'';
+ return ss.str();
}
LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
@@ -252,7 +268,7 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
for (auto type : rankedResults) {
ArrayRef<int64_t> actualSuffix =
getShape(type).take_back(resultShape.size());
- if (!areCompatibleShapes(actualSuffix, resultShape))
+ if (!isCompatibleInferredReturnShape(resultShape, actualSuffix))
return op->emitOpError()
<< "result type " << getShapeString(getShape(type))
<< " not broadcast compatible with broadcasted operands's shapes "
diff --git a/mlir/test/Dialect/traits.mlir b/mlir/test/Dialect/traits.mlir
index aaea63d14361..daf09ebd79a2 100644
--- a/mlir/test/Dialect/traits.mlir
+++ b/mlir/test/Dialect/traits.mlir
@@ -111,6 +111,13 @@ func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<?xi32>) -> tensor
// -----
+func @broadcast_tensor_tensor_tensor(%arg0: tensor<?x6x1xi32>, %arg1: tensor<*xi32>) -> tensor<?x6x6xi32> {
+ %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<?x6x1xi32>, tensor<*xi32>) -> tensor<?x6x6xi32>
+ return %0 : tensor<?x6x6xi32>
+}
+
+// -----
+
// Unranked operands but ranked result
func @broadcast_tensor_tensor_tensor(tensor<*xi32>, tensor<*xi32>) -> tensor<2xi32> {
^bb0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>):
More information about the Mlir-commits
mailing list