[Mlir-commits] [mlir] f0c3fd1 - Don't combine if there would remain no true reduction dim.
Benoit Jacob
llvmlistbot at llvm.org
Tue Jul 19 12:59:07 PDT 2022
Author: Benoit Jacob
Date: 2022-07-19T19:58:53Z
New Revision: f0c3fd185e059a855c1bd4a4779975d3a5c5681f
URL: https://github.com/llvm/llvm-project/commit/f0c3fd185e059a855c1bd4a4779975d3a5c5681f
DIFF: https://github.com/llvm/llvm-project/commit/f0c3fd185e059a855c1bd4a4779975d3a5c5681f.diff
LOG: Don't combine if there would remain no true reduction dim.
Differential Revision: https://reviews.llvm.org/D130109
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index a423cd2eca8f2..baffee208ddb4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1161,13 +1161,23 @@ struct CombineContractBroadcast
if (!unusedDimsBitVector.test(i))
iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
}
- // Check that compressing unused dims isn't removing all reduction
- // iterators. For example, if the vector.contract had only one reduction
+ // Check that compressing unused dims isn't removing all reduction dimension
+ // pairs. For example, if the vector.contract had only one reduction
// iterator and that was a unit-dimension created by a broadcast,
// then we should bail here, otherwise we would create a contract without
- // a reduction iterator.
- if (!llvm::any_of(iterators, isReductionIterator))
+ // a reduction dimension pair.
+ bool hasReductionIteratorApplyingOnBothSides = false;
+ for (unsigned i = 0; i < iterators.size(); ++i) {
+ if (!isReductionIterator(iterators[i]))
+ continue;
+ if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
+ hasReductionIteratorApplyingOnBothSides = true;
+ break;
+ }
+ }
+ if (!hasReductionIteratorApplyingOnBothSides)
return failure();
+
// If the compressed maps have a dimension that is not used by either LHS or
// RHS then the ContractionOp verifier would fail.
if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
index f1587c2e2f3d6..87b30e6116b03 100644
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -213,6 +213,38 @@ func.func @contract_broadcast_dimension_would_go_unused_in_lhs_rhs(%arg0 : vecto
return %result : vector<1xi32>
}
+// -----
+
+// Test that CombineContractBroadcast does not combine this case: doing so
+// would produce a vector.contract without a reduction dimension pair, because
+// the only reduction dimension would be used by only one of LHS and RHS.
+// This is arguably a convoluted edge case (the affine_maps here look weird!)
+// but it is something that we actually ran into from linalg.matmul tests that
+// were exercising 1x1 shapes, and using various drop-unit-dims patterns.
+
+#map0 = affine_map<(d0, d1) -> (d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+#map2 = affine_map<(d0, d1) -> (d0)>
+
+// CHECK-LABEL: contract_broadcast_would_have_no_reduction_dim_pair
+// CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>, %[[ARG2:.+]]: vector<1xf32>)
+// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<1xf32> to vector<1x1xf32>
+// CHECK: vector.contract
+// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
+// CHECK-SAME: iterator_types = ["parallel", "reduction"]
+// CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1xf32>, vector<1x1xf32> into vector<1xf32>
+
+func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1xf32>, %arg1 : vector<1xf32>, %arg2 : vector<1xf32>) -> vector<1xf32> {
+ %1 = vector.broadcast %arg1 : vector<1xf32> to vector<1x1xf32>
+ %result = vector.contract {
+ indexing_maps = [#map0, #map1, #map2],
+ iterator_types = ["parallel", "reduction"],
+ kind = #vector.kind<add>
+ } %arg0, %1, %arg2 : vector<1xf32>, vector<1x1xf32> into vector<1xf32>
+ return %result : vector<1xf32>
+}
+
+
//===----------------------------------------------------------------------===//
// Reorder casting ops and vector ops. The casting ops have almost identical
// pattern, so only arith.extsi op is tested.
More information about the Mlir-commits mailing list