[Mlir-commits] [mlir] b1d3afc - [mlir] Factor more common utils to IndexingUtils
Hanhan Wang
llvmlistbot at llvm.org
Fri Dec 2 13:27:13 PST 2022
Author: Hanhan Wang
Date: 2022-12-02T13:27:01-08:00
New Revision: b1d3afc93e0e6bdfbc1105b48bc4caed0c880be2
URL: https://github.com/llvm/llvm-project/commit/b1d3afc93e0e6bdfbc1105b48bc4caed0c880be2
DIFF: https://github.com/llvm/llvm-project/commit/b1d3afc93e0e6bdfbc1105b48bc4caed0c880be2.diff
LOG: [mlir] Factor more common utils to IndexingUtils
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D139159
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
mlir/include/mlir/Dialect/Utils/IndexingUtils.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/lib/Dialect/Utils/IndexingUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index e231bddfcc414..28c75fcfa6530 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -70,10 +70,6 @@ AffineMap extractOrIdentityMap(Optional<AffineMap> maybeMap, unsigned rank,
SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
ArrayRef<AffineExpr> b);
-/// Check if `permutation` is a permutation of the range
-/// `[0, permutation.size())`.
-bool isPermutation(ArrayRef<int64_t> permutation);
-
} // namespace linalg
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index ee1d4550e953c..bc58c127e1624 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -75,6 +75,12 @@ void applyPermutationToVector(SmallVector<T, N> &inVec,
inVec = auxVec;
}
+/// Helper method to invert a permutation: returns `inv` such that
+/// `inv[permutation[i]] == i`.
+SmallVector<int64_t> invertPermutationVector(ArrayRef<int64_t> permutation);
+
+/// Method to check if an interchange vector is a permutation.
+bool isPermutationVector(ArrayRef<int64_t> interchange);
+
/// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
unsigned dropBack = 0);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1f6a8fd91efb5..a6f42c9577ef7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -22,6 +22,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
@@ -1392,7 +1393,7 @@ void TransposeOp::print(OpAsmPrinter &p) {
LogicalResult TransposeOp::verify() {
ArrayRef<int64_t> permutationRef = getPermutation();
- if (!isPermutation(permutationRef))
+ if (!isPermutationVector(permutationRef))
return emitOpError("permutation is not valid");
auto inputType = getInput().getType();
@@ -1683,19 +1684,6 @@ SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
return llvm::to_vector<4>(concatRanges);
}
-bool mlir::linalg::isPermutation(ArrayRef<int64_t> permutation) {
- // Count the number of appearances for all indices.
- SmallVector<int64_t> indexCounts(permutation.size(), 0);
- for (auto index : permutation) {
- // Exit if the index is out-of-range.
- if (index < 0 || index >= static_cast<int64_t>(permutation.size()))
- return false;
- ++indexCounts[index];
- }
- // Return true if all indices appear once.
- return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
-}
-
static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
if (auto memref = t.dyn_cast<MemRefType>()) {
ss << "view";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index bff2d5406f5c2..a83305d013bcb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -151,7 +151,7 @@ computeTransposedType(RankedTensorType rankedTensorType,
ArrayRef<int64_t> transposeVector) {
if (transposeVector.empty())
return rankedTensorType;
- if (!isPermutation(transposeVector) ||
+ if (!isPermutationVector(transposeVector) ||
transposeVector.size() != static_cast<size_t>(rankedTensorType.getRank()))
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 712d4c2de32f1..3267b2aff075f 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -25,6 +25,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
@@ -409,7 +410,7 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
auto resultTensorType = outputTensor.getType().cast<RankedTensorType>();
Type elementType = resultTensorType.getElementType();
- assert(isPermutation(transposeVector) &&
+ assert(isPermutationVector(transposeVector) &&
"expect transpose vector to be a permutation");
assert(transposeVector.size() ==
static_cast<size_t>(resultTensorType.getRank()) &&
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index aea0f07a2272b..092f853b11d10 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
@@ -59,34 +60,6 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
return filledVector;
}
-/// Helper method to apply permutation to a vector
-template <typename T>
-static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector,
- ArrayRef<int64_t> interchange) {
- assert(interchange.size() == vector.size());
- return llvm::to_vector(
- llvm::map_range(interchange, [&](int64_t val) { return vector[val]; }));
-}
-/// Helper method to apply to invert a permutation.
-static SmallVector<int64_t>
-invertPermutationVector(ArrayRef<int64_t> interchange) {
- SmallVector<int64_t> inversion(interchange.size());
- for (const auto &pos : llvm::enumerate(interchange)) {
- inversion[pos.value()] = pos.index();
- }
- return inversion;
-}
-/// Method to check if an interchange vector is a permutation.
-static bool isPermutation(ArrayRef<int64_t> interchange) {
- llvm::SmallDenseSet<int64_t, 4> seenVals;
- for (auto val : interchange) {
- if (seenVals.count(val))
- return false;
- seenVals.insert(val);
- }
- return seenVals.size() == interchange.size();
-}
-
//===----------------------------------------------------------------------===//
// tileUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
@@ -321,16 +294,14 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
iterationDomain.size());
}
if (!interchangeVector.empty()) {
- if (!isPermutation(interchangeVector)) {
+ if (!isPermutationVector(interchangeVector)) {
return rewriter.notifyMatchFailure(
op, "invalid intechange vector, not a permutation of the entire "
"iteration space");
}
- iterationDomain =
- applyPermutationToVector(iterationDomain, interchangeVector);
- tileSizeVector =
- applyPermutationToVector(tileSizeVector, interchangeVector);
+ applyPermutationToVector(iterationDomain, interchangeVector);
+ applyPermutationToVector(tileSizeVector, interchangeVector);
}
// 3. Materialize an empty loop nest that iterates over the tiles. These
@@ -341,8 +312,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
if (!interchangeVector.empty()) {
auto inversePermutation = invertPermutationVector(interchangeVector);
- offsets = applyPermutationToVector(offsets, inversePermutation);
- sizes = applyPermutationToVector(sizes, inversePermutation);
+ applyPermutationToVector(offsets, inversePermutation);
+ applyPermutationToVector(sizes, inversePermutation);
}
}
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 1c7a89cd755f3..ed92ade8ab092 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -86,6 +86,25 @@ int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
std::multiplies<int64_t>());
}
+/// Computes the inverse of `permutation`, i.e. a vector `inversion` such that
+/// `inversion[permutation[i]] == i` for every position `i`.
+///
+/// `permutation` is expected to contain each index of
+/// `[0, permutation.size())` exactly once. Entry ranges are checked with an
+/// assert (debug builds only); duplicate entries are not diagnosed here.
+llvm::SmallVector<int64_t>
+mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
+  SmallVector<int64_t> inversion(permutation.size());
+  for (const auto &pos : llvm::enumerate(permutation)) {
+    // Each entry is used as an index into `inversion`; catch out-of-range
+    // entries before they turn into an out-of-bounds write.
+    assert(pos.value() >= 0 &&
+           pos.value() < static_cast<int64_t>(permutation.size()) &&
+           "permutation entry out of range");
+    inversion[pos.value()] = pos.index();
+  }
+  return inversion;
+}
+
+/// Returns true iff `interchange` contains every index of
+/// `[0, interchange.size())` exactly once. The empty vector is a (trivial)
+/// permutation.
+bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
+  // Uniqueness alone is not enough: e.g. [0, 5] has no repeats but is not a
+  // permutation of [0, 2). Callers index with these values (for example
+  // invertPermutationVector), and transpose verification relies on this
+  // check to reject out-of-range entries, so the range must be checked too.
+  const auto size = static_cast<int64_t>(interchange.size());
+  SmallVector<bool> seen(interchange.size(), false);
+  for (int64_t val : interchange) {
+    if (val < 0 || val >= size || seen[val])
+      return false;
+    seen[val] = true;
+  }
+  // N in-range values with no repeats cover [0, N) exactly.
+  return true;
+}
+
llvm::SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
unsigned dropFront,
unsigned dropBack) {
More information about the Mlir-commits
mailing list