[Mlir-commits] [mlir] 9072f1b - [mlir][linalg] Add isPermutation helper (NFC).
Tobias Gysi
llvmlistbot at llvm.org
Tue Sep 21 08:07:57 PDT 2021
Author: Tobias Gysi
Date: 2021-09-21T15:07:39Z
New Revision: 9072f1b5f81347b36f0668e8cc10802fedbc6cfd
URL: https://github.com/llvm/llvm-project/commit/9072f1b5f81347b36f0668e8cc10802fedbc6cfd
DIFF: https://github.com/llvm/llvm-project/commit/9072f1b5f81347b36f0668e8cc10802fedbc6cfd.diff
LOG: [mlir][linalg] Add isPermutation helper (NFC).
Add a helper method that checks whether an index vector is a permutation of the range `[0, size)`. Additionally, refactor applyPermutationToVector to take the permutation as `ArrayRef<int64_t>` instead of `ArrayRef<unsigned>`.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D110135
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index a8d4edf6c072d..67e2fc35367af 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -29,16 +29,20 @@ class LinalgDependenceGraph;
// General utilities
//===----------------------------------------------------------------------===//
+/// Check if `permutation` is a permutation of the range
+/// `[0, permutation.size())`.
+bool isPermutation(ArrayRef<int64_t> permutation);
+
/// Apply the permutation defined by `permutation` to `inVec`.
/// Element `permutation[i]` of `inVec` is moved to location `i`.
/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
template <typename T, unsigned N>
void applyPermutationToVector(SmallVector<T, N> &inVec,
- ArrayRef<unsigned> permutation) {
+ ArrayRef<int64_t> permutation) {
SmallVector<T, N> auxVec(inVec.size());
- for (unsigned i = 0; i < permutation.size(); ++i)
- auxVec[i] = inVec[permutation[i]];
+ for (auto en : enumerate(permutation))
+ auxVec[en.index()] = inVec[en.value()];
inVec = auxVec;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 0e65cb23964ca..13a6fe4fcc903 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -367,6 +367,8 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
ArrayRef<int64_t> tileInterchange) {
assert(tileSizes.size() == tileInterchange.size() &&
"expect the number of tile sizes and interchange dims to match");
+ assert(isPermutation(tileInterchange) &&
+ "expect tile interchange is a permutation");
// Create an empty tile loop nest.
TileLoopNest tileLoopNest(consumerOp);
@@ -375,9 +377,7 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
// inner reduction dimensions.
SmallVector<StringAttr> iterTypes =
llvm::to_vector<6>(consumerOp.iterator_types().getAsRange<StringAttr>());
- applyPermutationToVector(
- iterTypes,
- SmallVector<unsigned>(tileInterchange.begin(), tileInterchange.end()));
+ applyPermutationToVector(iterTypes, tileInterchange);
auto *it = find_if(iterTypes, [&](StringAttr iterType) {
return !isParallelIterator(iterType);
});
@@ -459,14 +459,10 @@ struct LinalgTileAndFuseTensorOps
tileInterchange.begin() +
rootOp.getNumLoops());
- // As a tiling can only tile a loop dimension once, `rootInterchange` has to
- // be a permutation of the `rootOp` loop dimensions.
- SmallVector<AffineExpr> rootInterchangeExprs;
- transform(rootInterchange, std::back_inserter(rootInterchangeExprs),
- [&](int64_t dim) { return b.getAffineDimExpr(dim); });
- AffineMap rootInterchangeMap = AffineMap::get(
- rootOp.getNumLoops(), 0, rootInterchangeExprs, funcOp.getContext());
- if (!rootInterchangeMap.isPermutation())
+ // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
+ // It has to be a permutation since the tiling cannot tile the same loop
+ // dimension multiple times.
+ if (!isPermutation(rootInterchange))
return notifyFailure(
"expect the tile interchange permutes the root loops");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
index d94fa30500b75..a42c204a389ee 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
@@ -69,7 +69,9 @@ void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter,
ArrayRef<Attribute> itTypes = genericOp.iterator_types().getValue();
SmallVector<Attribute, 4> itTypesVector;
llvm::append_range(itTypesVector, itTypes);
- applyPermutationToVector(itTypesVector, interchangeVector);
+ SmallVector<int64_t> permutation(interchangeVector.begin(),
+ interchangeVector.end());
+ applyPermutationToVector(itTypesVector, permutation);
genericOp->setAttr(getIteratorTypesAttrName(),
ArrayAttr::get(context, itTypesVector));
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index f3b9516bacf00..e36915d60108f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -206,8 +206,10 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
invPermutationMap = inversePermutation(
AffineMap::getPermutationMap(interchangeVector, b.getContext()));
assert(invPermutationMap);
- applyPermutationToVector(loopRanges, interchangeVector);
- applyPermutationToVector(iteratorTypes, interchangeVector);
+ SmallVector<int64_t> permutation(interchangeVector.begin(),
+ interchangeVector.end());
+ applyPermutationToVector(loopRanges, permutation);
+ applyPermutationToVector(iteratorTypes, permutation);
}
// 2. Create the tiled loops.
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index e0b1a4840b41f..b7a2becefa77f 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -138,6 +138,19 @@ static void unpackRanges(ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
namespace mlir {
namespace linalg {
+bool isPermutation(ArrayRef<int64_t> permutation) {
+ // Count the number of appearances for all indices.
+ SmallVector<int64_t> indexCounts(permutation.size(), 0);
+ for (auto index : permutation) {
+ // Exit if the index is out-of-range.
+ if (index < 0 || index >= static_cast<int64_t>(permutation.size()))
+ return false;
+ indexCounts[index]++;
+ }
+ // Return true if all indices appear once.
+ return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
+}
+
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {
More information about the Mlir-commits mailing list