[Mlir-commits] [mlir] c3839c0 - CombineContractBroadcast should not create dims unused in LHS+RHS
Benoit Jacob
llvmlistbot at llvm.org
Mon Jul 4 09:52:52 PDT 2022
Author: Benoit Jacob
Date: 2022-07-04T16:52:35Z
New Revision: c3839c0b46a902f96f9395ec15d89d0cb21c73b3
URL: https://github.com/llvm/llvm-project/commit/c3839c0b46a902f96f9395ec15d89d0cb21c73b3
DIFF: https://github.com/llvm/llvm-project/commit/c3839c0b46a902f96f9395ec15d89d0cb21c73b3.diff
LOG: CombineContractBroadcast should not create dims unused in LHS+RHS
Differential Revision: https://reviews.llvm.org/D129087
Added:
Modified:
mlir/include/mlir/IR/AffineMap.h
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/lib/IR/AffineMap.cpp
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index de94f43708fad..952fbde9b487f 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -18,6 +18,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/SmallBitVector.h"
namespace llvm {
class SmallBitVector;
@@ -584,6 +585,11 @@ inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
map.print(os);
return os;
}
+
+/// Returns a bit vector with one bit per dimension of `maps[0]`; a set bit
+/// marks a dimension that none of the maps in `maps` uses.
+llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef<AffineMap> maps);
+
} // namespace mlir
namespace llvm {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6c1ba2161b83c..a3874795571d5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -687,6 +687,9 @@ static LogicalResult verifyOutputShape(
MLIRContext *ctx = op.getContext();
AffineMap lhsMap = op.getIndexingMaps()[0];
AffineMap rhsMap = op.getIndexingMaps()[1];
+ if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
+ return op.emitOpError(
+ "expected all dimensions to be either a LHS or a RHS dimension");
SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
for (auto pair :
{std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
@@ -699,8 +702,8 @@ static LogicalResult verifyOutputShape(
}
}
if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
- return op.emitOpError("expected all input dimensions to be used by "
- "either the LHS or the RHS");
+ return op.emitOpError("expected all dimensions to get an extent as "
+ "either a LHS or a RHS dimension");
AffineMap resMap = op.getIndexingMaps()[2];
auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 80e2a0a3f7ea3..bb8cc2bfae396 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -32,7 +32,6 @@
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
@@ -1155,21 +1154,14 @@ struct CombineContractBroadcast
// Determine which dims are usused, now that the maps have been composed
// with the broadcast maps.
- unsigned numDims = maps[0].getNumDims();
- llvm::SmallBitVector unusedDims(numDims, true);
- for (const auto &m : maps) {
- for (unsigned i = 0; i < numDims; ++i) {
- if (m.isFunctionOfDim(i))
- unusedDims.reset(i);
- }
- }
+ llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
// Compress unused dims.
for (auto &m : maps)
- m = compressDims(m, unusedDims);
+ m = compressDims(m, unusedDimsBitVector);
// Compute the combined iterators.
SmallVector<Attribute, 4> iterators;
- for (unsigned i = 0; i < numDims; ++i) {
- if (!unusedDims.test(i))
+ for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
+ if (!unusedDimsBitVector.test(i))
iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
}
// Check that compressing unused dims isn't removing all reduction
@@ -1179,7 +1171,10 @@ struct CombineContractBroadcast
// a reduction iterator.
if (!llvm::any_of(iterators, isReductionIterator))
return failure();
-
+ // If the compressed maps have a dimension that is not used by either LHS or
+ // RHS then the ContractionOp verifier would fail.
+ if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
+ return failure();
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
contractOp, lhs, rhs, contractOp.getAcc(),
rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index c93b4c28d769a..a8de9579edae6 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -560,12 +560,7 @@ AffineMap mlir::compressDims(AffineMap map,
}
AffineMap mlir::compressUnusedDims(AffineMap map) {
- llvm::SmallBitVector unusedDims(map.getNumDims(), true);
- map.walkExprs([&](AffineExpr expr) {
- if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
- unusedDims.reset(dimExpr.getPosition());
- });
- return compressDims(map, unusedDims);
+ return compressDims(map, getUnusedDimsBitVector({map}));
}
static SmallVector<AffineMap>
@@ -722,6 +717,18 @@ AffineMap mlir::getProjectedMap(AffineMap map,
return compressUnusedSymbols(compressDims(map, unusedDims));
}
+/// Returns a bit vector with one bit per dimension of the maps in `maps`; a
+/// bit is set iff none of the maps references the corresponding dimension.
+/// All maps must have the same number of dimensions. An empty `maps` yields
+/// an empty bit vector.
+llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef<AffineMap> maps) {
+  // Indexing maps[0] on an empty ArrayRef would be undefined behavior.
+  if (maps.empty())
+    return llvm::SmallBitVector();
+  unsigned numDims = maps[0].getNumDims();
+  // Start with every dim marked unused and clear the bit of each dim that
+  // some map is a function of.
+  llvm::SmallBitVector unusedDims(numDims, true);
+  for (AffineMap m : maps) {
+    assert(m.getNumDims() == numDims &&
+           "expected all maps to have the same number of dimensions");
+    for (unsigned i = 0; i < numDims; ++i) {
+      if (m.isFunctionOfDim(i))
+        unusedDims.reset(i);
+    }
+  }
+  return unusedDims;
+}
+
//===----------------------------------------------------------------------===//
// MutableAffineMap.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 9a2afb9cc2774..87e5f9443807c 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -875,7 +875,7 @@ func.func @contraction(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: ve
// -----
func.func @contract_with_dim_unused_by_lhs_and_rhs(%arg0 : vector<1x2xi32>, %arg1 : vector<2xi32>, %arg2 : vector<1xi32>) -> vector<1xi32> {
-// expected-error@+1 {{'vector.contract' op expected all input dimensions to be used by either the LHS or the RHS}}
+// expected-error@+1 {{'vector.contract' op expected all dimensions to be either a LHS or a RHS dimension}}
%result = vector.contract {
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d2)>,
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
index 8549327326b80..ade539e278226 100644
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -159,6 +159,10 @@ func.func @contract_broadcast_non_unit_dim_reduction_with_permutation(%arg0 : ve
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+
// CHECK-LABEL: contract_broadcast_unit_dim_reduction_as_only_reduction
// CHECK-SAME: (%[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
// CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8xi32> to vector<1x8xi32>
@@ -178,6 +182,37 @@ func.func @contract_broadcast_unit_dim_reduction_as_only_reduction(%arg0 : vecto
return %result : vector<8x8xi32>
}
+// -----
+
+// Test that CombineContractBroadcast is not combining this case, as that would
+// result in a dimension being unused in the LHS and RHS maps, which is illegal.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL: contract_broadcast_dimension_would_go_unused_in_lhs_rhs
+// CHECK-SAME: (%[[ARG0:.+]]: vector<1x2xi32>, %[[ARG1:.+]]: vector<2xi32>, %[[ARG2:.+]]: vector<1xi32>)
+// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<2xi32> to vector<1x1x2xi32>
+// CHECK: vector.contract
+// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
+// CHECK-SAME: iterator_types = ["reduction", "parallel", "reduction"]
+// CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32>
+
+// %1 broadcasts %arg1 (vector<2xi32>) to vector<1x1x2xi32>, i.e. it adds the
+// leading dims d0 and d1 of #map1. Folding that broadcast into the contract
+// would presumably drop d0 and d1 from the RHS map, leaving d1 referenced only
+// by the accumulator map #map2 -- a dimension used by neither LHS nor RHS,
+// which the vector.contract verifier rejects. The CHECK lines above therefore
+// expect the broadcast and the original contract to be left unchanged.
+func.func @contract_broadcast_dimension_would_go_unused_in_lhs_rhs(%arg0 : vector<1x2xi32>, %arg1 : vector<2xi32>, %arg2 : vector<1xi32>) -> vector<1xi32> {
+ %1 = vector.broadcast %arg1 : vector<2xi32> to vector<1x1x2xi32>
+ %result = vector.contract {
+ indexing_maps = [#map0, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "reduction"],
+ kind = #vector.kind<add>
+ } %arg0, %1, %arg2 : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32>
+ return %result : vector<1xi32>
+}
+
//===----------------------------------------------------------------------===//
// Reorder casting ops and vector ops. The casting ops have almost identical
// pattern, so only arith.extsi op is tested.
More information about the Mlir-commits
mailing list