[Mlir-commits] [mlir] 9ddb464 - [mlir] refactor common idiom into AffineMap method
Aart Bik
llvmlistbot at llvm.org
Fri Nov 13 19:18:29 PST 2020
Author: Aart Bik
Date: 2020-11-13T19:18:13-08:00
New Revision: 9ddb464d37b05b993734c62511576a947b4542df
URL: https://github.com/llvm/llvm-project/commit/9ddb464d37b05b993734c62511576a947b4542df
DIFF: https://github.com/llvm/llvm-project/commit/9ddb464d37b05b993734c62511576a947b4542df.diff
LOG: [mlir] refactor common idiom into AffineMap method
Motivated by a refactoring in the new sparse code (yet to be merged), this avoids some lengthy code duplication.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D91465
Added:
Modified:
mlir/include/mlir/IR/AffineMap.h
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/lib/IR/AffineMap.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index c450024dcb57..f1f267ff0fc2 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -125,6 +125,10 @@ class AffineMap {
ArrayRef<AffineExpr> getResults() const;
AffineExpr getResult(unsigned idx) const;
+ /// Returns the position of the dimension expression at result `idx`. That
+ /// result must be an AffineDimExpr; only cast<>'s assertion checks this.
+ unsigned getDimPosition(unsigned idx) const;
+
/// Walk all of the AffineExpr's in this mapping. Each node in an expression
/// tree is visited in postorder.
void walkExprs(std::function<void(AffineExpr)> callback) const;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index abc10e8f486a..8e1dbf17d3f1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -466,9 +466,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
AffineMap foldedDims = reassociationMaps[resultExpr.index()];
numFoldedDims[pos] = foldedDims.getNumResults();
- ArrayRef<int64_t> shape = expandedShape.slice(
- foldedDims.getResult(0).cast<AffineDimExpr>().getPosition(),
- numFoldedDims[pos]);
+ ArrayRef<int64_t> shape =
+ expandedShape.slice(foldedDims.getDimPosition(0), numFoldedDims[pos]);
expandedDimsShape[pos].assign(shape.begin(), shape.end());
}
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 39aed718de0a..0cc1e7c07aba 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -336,7 +336,7 @@ static LogicalResult verifyOutputShape(
VectorType v = pair.first;
auto map = pair.second;
for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
- unsigned pos = map.getResult(idx).cast<AffineDimExpr>().getPosition();
+ unsigned pos = map.getDimPosition(idx);
if (!extents[pos])
extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
}
@@ -785,8 +785,7 @@ static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
if (insertedPos.size() == extractedPos.size()) {
bool fold = true;
for (unsigned idx = 0, sz = extractedPos.size(); idx < sz; ++idx) {
- auto pos =
- permutationMap.getResult(idx).cast<AffineDimExpr>().getPosition();
+ auto pos = permutationMap.getDimPosition(idx);
if (pos >= sz || insertedPos[pos] != extractedPos[idx]) {
fold = false;
break;
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 49865fddba4c..e488db677fe5 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -50,7 +50,7 @@ using llvm::dbgs;
// Helper to find an index in an affine map.
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
- int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
+ int64_t idx = map.getDimPosition(i);
if (idx == index)
return i;
}
@@ -76,7 +76,7 @@ static AffineMap adjustMap(AffineMap map, int64_t index,
auto *ctx = rewriter.getContext();
SmallVector<AffineExpr, 4> results;
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
- int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
+ int64_t idx = map.getDimPosition(i);
if (idx == index)
continue;
// Re-insert remaining indices, but renamed when occurring
@@ -2016,16 +2016,13 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
int64_t iterIndex = -1;
int64_t dimSize = -1;
if (lhsIndex >= 0) {
- iterIndex = iMap[0].getResult(lhsIndex).cast<AffineDimExpr>().getPosition();
- assert(
- (rhsIndex < 0 ||
- iterIndex ==
- iMap[1].getResult(rhsIndex).cast<AffineDimExpr>().getPosition()) &&
- "parallel index should be free in LHS or batch in LHS/RHS");
+ iterIndex = iMap[0].getDimPosition(lhsIndex);
+ assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
+ "parallel index should be free in LHS or batch in LHS/RHS");
dimSize = lhsType.getDimSize(lhsIndex);
} else {
assert(rhsIndex >= 0 && "missing parallel index");
- iterIndex = iMap[1].getResult(rhsIndex).cast<AffineDimExpr>().getPosition();
+ iterIndex = iMap[1].getDimPosition(rhsIndex);
dimSize = rhsType.getDimSize(rhsIndex);
}
assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 1f73d07cc8ff..cc2cb8be4f3c 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -227,6 +227,10 @@ AffineExpr AffineMap::getResult(unsigned idx) const {
return map->results[idx];
}
+unsigned AffineMap::getDimPosition(unsigned idx) const {
+ return getResult(idx).cast<AffineDimExpr>().getPosition();
+}
+
/// Folds the results of the application of an affine map on the provided
/// operands to a constant if possible. Returns false if the folding happens,
/// true otherwise.
More information about the Mlir-commits
mailing list