[Mlir-commits] [mlir] 51cbe4e - [mlir] Fix broadcasting check with 1 values

Jacques Pienaar llvmlistbot at llvm.org
Sun Jul 11 20:41:53 PDT 2021


Author: Jacques Pienaar
Date: 2021-07-11T20:41:33-07:00
New Revision: 51cbe4e58797d85c0f17d4a9cad1bcb11743afae

URL: https://github.com/llvm/llvm-project/commit/51cbe4e58797d85c0f17d4a9cad1bcb11743afae
DIFF: https://github.com/llvm/llvm-project/commit/51cbe4e58797d85c0f17d4a9cad1bcb11743afae.diff

LOG: [mlir] Fix broadcasting check with 1 values

The trait check was inconsistent with the other broadcasting logic here: an
inferred dimension of 1 is broadcast compatible with any existing dimension,
but an existing dimension of 1 is not compatible with a larger inferred one.
Also fix the error message to print dynamic dimensions as ? rather than -1.

Differential Revision: https://reviews.llvm.org/D105748
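
For context, the check being introduced is asymmetric. Below is a minimal
standalone sketch of the rule, using the -1 encoding for dynamic dims that
this commit still relies on (isCompatibleDim is a hypothetical name, not the
MLIR API):

#include <cstdint>

// Dims are compatible if they are equal, if either is dynamic (-1), or if
// the inferred dim is 1, since 1 broadcasts to any existing dim. An existing
// dim of 1 does NOT make a larger inferred dim compatible.
static bool isCompatibleDim(int64_t inferred, int64_t existing) {
  return inferred == existing || inferred == -1 || existing == -1 ||
         inferred == 1;
}

For example, isCompatibleDim(1, 6) holds while isCompatibleDim(6, 1) does
not; that asymmetry is exactly what the old symmetric areCompatibleShapes
missed.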

Added: 
    

Modified: 
    mlir/lib/Dialect/Traits.cpp
    mlir/test/Dialect/traits.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index 50f203644741..ce2feff441bc 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -192,14 +192,18 @@ static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
       llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }));
 }
 
-static bool areCompatibleShapes(ArrayRef<int64_t> shape1,
-                                ArrayRef<int64_t> shape2) {
+static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
+                                            ArrayRef<int64_t> existing) {
   auto isCompatible = [](int64_t dim1, int64_t dim2) {
-    return dim1 == dim2 || dim1 == -1 || dim2 == -1;
+    // If the inferred and existing dims are equal, or either one is unknown,
+    // then they are compatible. An inferred dim of 1 is also compatible (it
+    // broadcasts to the existing dim), but an existing dim of 1 with an
+    // inferred dim greater than 1 is flagged as incompatible.
+    return dim1 == dim2 || dim1 == -1 || dim2 == -1 || dim1 == 1;
   };
-  if (shape1.size() != shape2.size())
+  if (inferred.size() != existing.size())
     return false;
-  for (auto p : llvm::zip(shape1, shape2))
+  for (auto p : llvm::zip(inferred, existing))
     if (!isCompatible(std::get<0>(p), std::get<1>(p)))
       return false;
   return true;
@@ -208,8 +212,20 @@ static bool areCompatibleShapes(ArrayRef<int64_t> shape1,
 static std::string getShapeString(ArrayRef<int64_t> shape) {
   // TODO: should replace with printing shape more uniformly across here and
   // when in type.
-  return std::string(
-      formatv("'{0:$[x]}'", llvm::make_range(shape.begin(), shape.end())));
+  std::string ret;
+  llvm::raw_string_ostream ss(ret);
+  ss << '\'';
+  llvm::interleave(
+      shape, ss,
+      [&](int64_t dim) {
+        if (ShapedType::isDynamic(dim))
+          ss << '?';
+        else
+          ss << dim;
+      },
+      "x");
+  ss << '\'';
+  return ss.str();
 }
 
 LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
@@ -252,7 +268,7 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
   for (auto type : rankedResults) {
     ArrayRef<int64_t> actualSuffix =
         getShape(type).take_back(resultShape.size());
-    if (!areCompatibleShapes(actualSuffix, resultShape))
+    if (!isCompatibleInferredReturnShape(resultShape, actualSuffix))
       return op->emitOpError()
              << "result type " << getShapeString(getShape(type))
              << " not broadcast compatible with broadcasted operands's shapes "

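As a rough illustration of the printing change: the old formatv-based helper
rendered a shape like {2, -1, 4} as '2x-1x4'; the new code prints '2x?x4'. A
standalone sketch of the same llvm::interleave pattern (shapeString is a
hypothetical name, not the exact helper):

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <string>

// Render a shape the way the patched getShapeString does: quoted, 'x'
// separated, with dynamic dims printed as '?'.
static std::string shapeString(llvm::ArrayRef<int64_t> shape) {
  std::string ret;
  llvm::raw_string_ostream ss(ret);
  ss << '\'';
  llvm::interleave(
      shape, ss,
      [&](int64_t dim) {
        if (dim == -1) // stand-in for ShapedType::isDynamic(dim)
          ss << '?';
        else
          ss << dim;
      },
      "x");
  ss << '\'';
  return ss.str();
}
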
diff --git a/mlir/test/Dialect/traits.mlir b/mlir/test/Dialect/traits.mlir
index aaea63d14361..daf09ebd79a2 100644
--- a/mlir/test/Dialect/traits.mlir
+++ b/mlir/test/Dialect/traits.mlir
@@ -111,6 +111,13 @@ func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<?xi32>) -> tensor
 
 // -----
 
+func @broadcast_tensor_tensor_tensor(%arg0: tensor<?x6x1xi32>, %arg1: tensor<*xi32>) -> tensor<?x6x6xi32> {
+  %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<?x6x1xi32>, tensor<*xi32>) -> tensor<?x6x6xi32>
+  return %0 : tensor<?x6x6xi32>
+}
+
+// -----
+
 // Unranked operands but ranked result
 func @broadcast_tensor_tensor_tensor(tensor<*xi32>, tensor<*xi32>) -> tensor<2xi32> {
 ^bb0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>):

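Walking through the new test to see why it must now verify: the ranked
operand contributes the suffix ?x6x1 and the unranked operand constrains
nothing, so the inferred result shape is ?x6x1, which is checked against the
declared result ?x6x6. A quick check with the hypothetical isCompatibleDim
sketch from above (? encoded as -1):

#include <cassert>

int main() {
  assert(isCompatibleDim(-1, -1)); // ? vs ?: compatible
  assert(isCompatibleDim(6, 6));   // 6 vs 6: compatible
  assert(isCompatibleDim(1, 6));   // inferred 1 broadcasts to 6
  assert(!isCompatibleDim(6, 1));  // existing 1 cannot absorb inferred 6
}

Under the old symmetric check the 1-vs-6 comparison failed, so this valid IR
was rejected; with the fix it verifies.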