[Mlir-commits] [mlir] 5b2d850 - [mlir][Linalg] NFC - Expose helper function `substituteMin`.
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Mar 19 09:31:41 PDT 2021
Author: Nicolas Vasilache
Date: 2021-03-19T16:26:52Z
New Revision: 5b2d8503d1d4b925e30fd2b91f97bfd625f03157
URL: https://github.com/llvm/llvm-project/commit/5b2d8503d1d4b925e30fd2b91f97bfd625f03157
DIFF: https://github.com/llvm/llvm-project/commit/5b2d8503d1d4b925e30fd2b91f97bfd625f03157.diff
LOG: [mlir][Linalg] NFC - Expose helper function `substituteMin`.
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 54a4aec9f867..6d428384080b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -893,6 +893,30 @@ struct AffineMinSCFCanonicalizationPattern
PatternRewriter &rewriter) const override;
};
+/// Helper struct to return the results of `substituteMin`.
+struct AffineMapAndOperands {
+ AffineMap map; ///< The substituted and simplified map.
+ SmallVector<Value> dims; ///< Operands bound to the dims of `map`.
+ SmallVector<Value> symbols; ///< Operands bound to the symbols of `map`.
+};
+/// Traverse the dims of the AffineMap of `affineMinOp` and substitute scf loop
+/// induction variables with new expressions involving the lower or upper bound:
+/// - If the AffineDimExpr mapped to a loop IV has a positive sign, it is
+///   replaced by the loop upper bound.
+/// - If the AffineDimExpr mapped to a loop IV has a negative sign, it is
+///   replaced by the loop lower bound.
+/// All loop induction variables are iteratively replaced, unless a
+/// `substituteOperation` hook is passed; it receives the owning scf loop op
+/// and returns true if that loop's induction variables should be substituted.
+/// This is used as an intermediate step in computing bounding boxes and
+/// canonicalizing AffineMinOps. All dim and symbol operands are assumed to
+/// have positive values (positive orthant assumptions).
+/// Returns a new AffineMap, along with its dims and symbols, that have been
+/// canonicalized and simplified.
+AffineMapAndOperands substituteMin(
+ AffineMinOp affineMinOp,
+ llvm::function_ref<bool(Operation *)> substituteOperation = nullptr);
+
/// Converts Convolution op into vector contraction.
///
/// Conversion expects ConvOp to have dimensions marked in the *mask* as
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index c2e52c63eabd..fef6dd8f996f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -536,8 +536,10 @@ static AffineExpr substituteLoopInExpr(AffineExpr expr, AffineExpr dimExpr,
/// Traverse the `dims` and substitute known min or max expressions in place of
/// induction variables in `exprs`.
-static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
- SmallVectorImpl<Value> &symbols) {
+static AffineMap substitute(
+ AffineMap map, SmallVectorImpl<Value> &dims,
+ SmallVectorImpl<Value> &symbols,
+ llvm::function_ref<bool(Operation *)> substituteOperation = nullptr) {
auto exprs = llvm::to_vector<4>(map.getResults());
for (AffineExpr &expr : exprs) {
bool substituted = true;
@@ -549,17 +551,19 @@ static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
AffineExpr substitutedExpr;
if (auto forOp = scf::getForInductionVarOwner(dim))
- substitutedExpr = substituteLoopInExpr(
- expr, dimExpr, forOp.lowerBound(), forOp.upperBound(),
- forOp.step(), dims, symbols);
+ if (!substituteOperation || substituteOperation(forOp))
+ substitutedExpr = substituteLoopInExpr(
+ expr, dimExpr, forOp.lowerBound(), forOp.upperBound(),
+ forOp.step(), dims, symbols);
if (auto parallelForOp = scf::getParallelForInductionVarOwner(dim))
- for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e;
- ++idx)
- substitutedExpr = substituteLoopInExpr(
- expr, dimExpr, parallelForOp.lowerBound()[idx],
- parallelForOp.upperBound()[idx], parallelForOp.step()[idx],
- dims, symbols);
+ if (!substituteOperation || substituteOperation(parallelForOp))
+ for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e;
+ ++idx)
+ substitutedExpr = substituteLoopInExpr(
+ expr, dimExpr, parallelForOp.lowerBound()[idx],
+ parallelForOp.upperBound()[idx], parallelForOp.step()[idx],
+ dims, symbols);
if (!substitutedExpr)
continue;
@@ -578,6 +582,9 @@ static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
exprs.front().getContext());
LLVM_DEBUG(DBGS() << "Map to simplify: " << map << "\n");
+ LLVM_DEBUG(DBGS() << "Operands:\n");
+ for (Value v : operands)
+ LLVM_DEBUG(DBGS() << v << "\n");
// Pull in affine.apply operations and compose them fully into the
// result.
@@ -596,14 +603,38 @@ static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
}
+/// Traverse the dims of the AffineMap of `affineMinOp` and substitute scf loop
+/// induction variables with new expressions involving the lower or upper bound:
+/// - If the AffineDimExpr mapped to a loop IV has a positive sign, it is
+///   replaced by the loop upper bound.
+/// - If the AffineDimExpr mapped to a loop IV has a negative sign, it is
+///   replaced by the loop lower bound.
+/// All loop induction variables are iteratively replaced, unless a
+/// `substituteOperation` hook is passed; it receives the owning scf loop op
+/// and returns true if that loop's induction variables should be substituted.
+/// This is used as an intermediate step in computing bounding boxes and
+/// canonicalizing AffineMinOps. All dim and symbol operands are assumed to
+/// have positive values (positive orthant assumptions).
+/// Returns a new AffineMap, along with its dims and symbols, that have been
+/// canonicalized and simplified.
+AffineMapAndOperands mlir::linalg::substituteMin(
+ AffineMinOp affineMinOp,
+ llvm::function_ref<bool(Operation *)> substituteOperation) {
+ AffineMapAndOperands res{affineMinOp.getAffineMap(),
+ SmallVector<Value>(affineMinOp.getDimOperands()),
+ SmallVector<Value>(affineMinOp.getSymbolOperands())};
+ res.map = substitute(affineMinOp.getAffineMap(), res.dims, res.symbols,
+ substituteOperation);
+ return res;
+}
+
LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite(
AffineMinOp minOp, PatternRewriter &rewriter) const {
LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
<< "\n");
- SmallVector<Value, 4> dims(minOp.getDimOperands()),
- symbols(minOp.getSymbolOperands());
- AffineMap map = substitute(minOp.getAffineMap(), dims, symbols);
+ auto affineMapAndOperands = substituteMin(minOp);
+ AffineMap map = affineMapAndOperands.map;
LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");
@@ -638,8 +669,8 @@ LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite(
rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
} else {
auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx);
- SmallVector<Value, 4> resultOperands = dims;
- resultOperands.append(symbols.begin(), symbols.end());
+ SmallVector<Value> resultOperands = affineMapAndOperands.dims;
+ llvm::append_range(resultOperands, affineMapAndOperands.symbols);
canonicalizeMapAndOperands(&resultMap, &resultOperands);
resultMap = simplifyAffineMap(resultMap);
rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,
More information about the Mlir-commits
mailing list