[Mlir-commits] [mlir] f0c3fd1 - Don't combine if there would remain no true reduction dim.

Benoit Jacob llvmlistbot at llvm.org
Tue Jul 19 12:59:07 PDT 2022


Author: Benoit Jacob
Date: 2022-07-19T19:58:53Z
New Revision: f0c3fd185e059a855c1bd4a4779975d3a5c5681f

URL: https://github.com/llvm/llvm-project/commit/f0c3fd185e059a855c1bd4a4779975d3a5c5681f
DIFF: https://github.com/llvm/llvm-project/commit/f0c3fd185e059a855c1bd4a4779975d3a5c5681f.diff

LOG: Don't combine if there would remain no true reduction dim.

Differential Revision: https://reviews.llvm.org/D130109

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index a423cd2eca8f2..baffee208ddb4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1161,13 +1161,23 @@ struct CombineContractBroadcast
       if (!unusedDimsBitVector.test(i))
         iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
     }
-    // Check that compressing unused dims isn't removing all reduction
-    // iterators. For example, if the vector.contract had only one reduction
+    // Check that compressing unused dims isn't removing all reduction dimension
+    // pairs. For example, if the vector.contract had only one reduction
     // iterator and that was a unit-dimension created by a broadcast,
     // then we should bail here, otherwise we would create a contract without
-    // a reduction iterator.
-    if (!llvm::any_of(iterators, isReductionIterator))
+    // a reduction dimension pair.
+    bool hasReductionIteratorApplyingOnBothSides = false;
+    for (unsigned i = 0; i < iterators.size(); ++i) {
+      if (!isReductionIterator(iterators[i]))
+        continue;
+      if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
+        hasReductionIteratorApplyingOnBothSides = true;
+        break;
+      }
+    }
+    if (!hasReductionIteratorApplyingOnBothSides)
       return failure();
+
     // If the compressed maps have a dimension that is not used by either LHS or
     // RHS then the ContractionOp verifier would fail.
     if (getUnusedDimsBitVector({maps[0], maps[1]}).any())

diff  --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
index f1587c2e2f3d6..87b30e6116b03 100644
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -213,6 +213,38 @@ func.func @contract_broadcast_dimension_would_go_unused_in_lhs_rhs(%arg0 : vecto
   return  %result : vector<1xi32>
 }
 
+// -----
+
+// Test that CombineContractBroadcast does not combine this case: doing so would
+// result in a vector.contract without a reduction dimension pair, because the
+// only reduction dimension would be used by just one of LHS and RHS.
+// This is arguably a convoluted edge case (the affine_maps here look weird!)
+// but it is something that we actually ran into in linalg.matmul tests that
+// were exercising 1x1 shapes and using various drop-unit-dims patterns.
+
+#map0 = affine_map<(d0, d1) -> (d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+#map2 = affine_map<(d0, d1) -> (d0)>
+
+// CHECK-LABEL: contract_broadcast_would_have_no_reduction_dim_pair
+//  CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>, %[[ARG2:.+]]: vector<1xf32>)
+//  CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<1xf32> to vector<1x1xf32>
+//  CHECK: vector.contract
+//  CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
+//  CHECK-SAME: iterator_types = ["parallel", "reduction"]
+//  CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1xf32>, vector<1x1xf32> into vector<1xf32>
+
+func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1xf32>, %arg1 : vector<1xf32>, %arg2 : vector<1xf32>) -> vector<1xf32> {
+  %1 = vector.broadcast %arg1 : vector<1xf32> to vector<1x1xf32>  // The broadcast introduces only the leading unit dim of %1.
+  %result = vector.contract {
+    indexing_maps = [#map0, #map1, #map2],  // LHS reads only d1; RHS reads (d1, d0), with d1 on the broadcast dim.
+    iterator_types = ["parallel", "reduction"],  // d1 is the only reduction dim.
+    kind = #vector.kind<add>
+  } %arg0, %1, %arg2 : vector<1xf32>, vector<1x1xf32> into vector<1xf32>
+  return %result : vector<1xf32>
+}
+
+
 //===----------------------------------------------------------------------===//
 // Reorder casting ops and vector ops. The casting ops have almost identical
 // pattern, so only arith.extsi op is tested.


        


More information about the Mlir-commits mailing list