[Mlir-commits] [mlir] 654e8aa - [MLIR] Consider AffineIfOp when getting the index set of an Op wrapped in nested loops
Uday Bondhugula
llvmlistbot at llvm.org
Sat Aug 8 14:46:35 PDT 2020
Author: Vincent Zhao
Date: 2020-08-09T03:16:03+05:30
New Revision: 654e8aadfdda97524c463bcf3552d2ecf2feda93
URL: https://github.com/llvm/llvm-project/commit/654e8aadfdda97524c463bcf3552d2ecf2feda93
DIFF: https://github.com/llvm/llvm-project/commit/654e8aadfdda97524c463bcf3552d2ecf2feda93.diff
LOG: [MLIR] Consider AffineIfOp when getting the index set of an Op wrapped in nested loops
This diff attempts to resolve the TODO in `getOpIndexSet` (formerly
known as `getInstIndexSet`), which states "Add support to handle IfInsts
surrounding `op`".
Major changes in this diff:
1. Overload `getIndexSet`. The overloaded version considers both
`AffineForOp` and `AffineIfOp`.
2. The `getInstIndexSet` is updated accordingly: its name is changed to
`getOpIndexSet` and its implementation is based on a new API `getIVs`
instead of `getLoopIVs`.
3. Add `addAffineIfOpDomain` to `FlatAffineConstraints`, which extracts
new constraints from the integer set of `AffineIfOp` and merges it to
the current constraint system.
4. Update how a `Value` is determined as dim or symbol for
`ValuePositionMap` in `buildDimAndSymbolPositionMaps`.
Differential Revision: https://reviews.llvm.org/D84698
Added:
Modified:
mlir/include/mlir/Analysis/AffineAnalysis.h
mlir/include/mlir/Analysis/AffineStructures.h
mlir/include/mlir/Analysis/Utils.h
mlir/lib/Analysis/AffineAnalysis.cpp
mlir/lib/Analysis/AffineStructures.cpp
mlir/lib/Analysis/Utils.cpp
mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
mlir/lib/Transforms/Utils/LoopUtils.cpp
mlir/test/Transforms/memref-dependence-check.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h
index 3322f3da6a09..c3aaa40bda9a 100644
--- a/mlir/include/mlir/Analysis/AffineAnalysis.h
+++ b/mlir/include/mlir/Analysis/AffineAnalysis.h
@@ -34,12 +34,14 @@ void getReachableAffineApplyOps(ArrayRef<Value> operands,
SmallVectorImpl<Operation *> &affineApplyOps);
/// Builds a system of constraints with dimensional identifiers corresponding to
-/// the loop IVs of the forOps appearing in that order. Bounds of the loop are
-/// used to add appropriate inequalities. Any symbols founds in the bound
-/// operands are added as symbols in the system. Returns failure for the yet
-/// unimplemented cases.
+/// the loop IVs of the forOps and AffineIfOp's operands appearing in
+/// that order. Bounds of the loop are used to add appropriate inequalities.
+/// Constraints from the index sets of AffineIfOp are also added. Any symbols
+/// found in the bound operands are added as symbols in the system. Returns
+/// failure for the yet unimplemented cases. `ops` accepts both AffineForOp and
+/// AffineIfOp.
// TODO: handle non-unit strides.
-LogicalResult getIndexSet(MutableArrayRef<AffineForOp> forOps,
+LogicalResult getIndexSet(MutableArrayRef<Operation *> ops,
FlatAffineConstraints *domain);
/// Encapsulates a memref load or store access information.
diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
index 0424e0bb7d33..14a19827902f 100644
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -21,6 +21,7 @@ namespace mlir {
class AffineCondition;
class AffineForOp;
+class AffineIfOp;
class AffineMap;
class AffineValueMap;
class IntegerSet;
@@ -215,6 +216,15 @@ class FlatAffineConstraints {
// TODO: add support for non-unit strides.
LogicalResult addAffineForOpDomain(AffineForOp forOp);
+ /// Adds constraints imposed by the `affine.if` operation. These constraints
+ /// are collected from the IntegerSet attached to the given `affine.if`
+ /// instance argument (`ifOp`). It is asserted that:
+ /// 1) The IntegerSet of the given `affine.if` instance should not contain
+ /// semi-affine expressions,
+ /// 2) The columns of the constraint system created from `ifOp` should match
+ /// the columns in the current one regarding numbers and values.
+ void addAffineIfOpDomain(AffineIfOp ifOp);
+
/// Adds a lower or an upper bound for the identifier at the specified
/// position with constraints being drawn from the specified bound map and
/// operands. If `eq` is true, add a single equality equal to the bound map's
diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h
index 943a2f125b7d..b502d909d5c0 100644
--- a/mlir/include/mlir/Analysis/Utils.h
+++ b/mlir/include/mlir/Analysis/Utils.h
@@ -39,6 +39,12 @@ class Value;
// TODO: handle 'affine.if' ops.
void getLoopIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops);
+/// Populates 'ops' with IVs of the loops surrounding `op`, along with
+/// `affine.if` operations interleaved between these loops, ordered from the
+/// outermost `affine.for` or `affine.if` operation to the innermost one.
+void getEnclosingAffineForAndIfOps(Operation &op,
+ SmallVectorImpl<Operation *> *ops);
+
/// Returns the nesting depth of this operation, i.e., the number of loops
/// surrounding this operation.
unsigned getNestingDepth(Operation *op);
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index 5bdfcda8533a..aeaa5cb482ee 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -85,33 +85,42 @@ void mlir::getReachableAffineApplyOps(
// FlatAffineConstraints. (For eg., by using iv - lb % step = 0 and/or by
// introducing a method in FlatAffineConstraints setExprStride(ArrayRef<int64_t>
// expr, int64_t stride)
-LogicalResult mlir::getIndexSet(MutableArrayRef<AffineForOp> forOps,
+LogicalResult mlir::getIndexSet(MutableArrayRef<Operation *> ops,
FlatAffineConstraints *domain) {
SmallVector<Value, 4> indices;
+ SmallVector<AffineForOp, 8> forOps;
+
+ for (Operation *op : ops) {
+ assert((isa<AffineForOp, AffineIfOp>(op)) &&
+ "ops should have either AffineForOp or AffineIfOp");
+ if (AffineForOp forOp = dyn_cast<AffineForOp>(op))
+ forOps.push_back(forOp);
+ }
extractForInductionVars(forOps, &indices);
// Reset while associated Values in 'indices' to the domain.
domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
- for (auto forOp : forOps) {
+ for (Operation *op : ops) {
// Add constraints from forOp's bounds.
- if (failed(domain->addAffineForOpDomain(forOp)))
- return failure();
+ if (AffineForOp forOp = dyn_cast<AffineForOp>(op)) {
+ if (failed(domain->addAffineForOpDomain(forOp)))
+ return failure();
+ } else if (AffineIfOp ifOp = dyn_cast<AffineIfOp>(op)) {
+ domain->addAffineIfOpDomain(ifOp);
+ }
}
return success();
}
-// Computes the iteration domain for 'opInst' and populates 'indexSet', which
-// encapsulates the constraints involving loops surrounding 'opInst' and
-// potentially involving any Function symbols. The dimensional identifiers in
-// 'indexSet' correspond to the loops surrounding 'op' from outermost to
-// innermost.
-// TODO: Add support to handle IfInsts surrounding 'op'.
-static LogicalResult getInstIndexSet(Operation *op,
- FlatAffineConstraints *indexSet) {
- // TODO: Extend this to gather enclosing IfInsts and consider
- // factoring it out into a utility function.
- SmallVector<AffineForOp, 4> loops;
- getLoopIVs(*op, &loops);
- return getIndexSet(loops, indexSet);
+/// Computes the iteration domain for 'op' and populates 'indexSet', which
+/// encapsulates the constraints involving loops surrounding 'op' and
+/// potentially involving any Function symbols. The dimensional identifiers in
+/// 'indexSet' correspond to the loops surrounding 'op' from outermost to
+/// innermost.
+static LogicalResult getOpIndexSet(Operation *op,
+ FlatAffineConstraints *indexSet) {
+ SmallVector<Operation *, 4> ops;
+ getEnclosingAffineForAndIfOps(*op, &ops);
+ return getIndexSet(ops, indexSet);
}
namespace {
@@ -209,32 +218,83 @@ static void buildDimAndSymbolPositionMaps(
const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap,
const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap,
FlatAffineConstraints *dependenceConstraints) {
- auto updateValuePosMap = [&](ArrayRef<Value> values, bool isSrc) {
+
+ // IsDimState is a tri-state boolean. It is used to distinguish three
+  // different cases of the values passed to updateValuePosMap.
+ // - When it is TRUE, we are certain that all values are dim values.
+ // - When it is FALSE, we are certain that all values are symbol values.
+ // - When it is UNKNOWN, we need to further check whether the value is from a
+ // loop IV to determine its type (dim or symbol).
+
+ // We need this enumeration because sometimes we cannot determine whether a
+ // Value is a symbol or a dim by the information from the Value itself. If a
+ // Value appears in an affine map of a loop, we can determine whether it is a
+ // dim or not by the function `isForInductionVar`. But when a Value is in the
+ // affine set of an if-statement, there is no way to identify its category
+ // (dim/symbol) by itself. Fortunately, the Values to be inserted into
+ // `valuePosMap` come from `srcDomain` and `dstDomain`, and they hold such
+ // information of Value category: `srcDomain` and `dstDomain` organize Values
+ // by their category, such that the position of each Value stored in
+ // `srcDomain` and `dstDomain` marks which category that a Value belongs to.
+ // Therefore, we can separate Values into dim and symbol groups before passing
+ // them to the function `updateValuePosMap`. Specifically, when passing the
+ // dim group, we set IsDimState to TRUE; otherwise, we set it to FALSE.
+ // However, Values from the operands of `srcAccessMap` and `dstAccessMap` are
+ // not explicitly categorized into dim or symbol, and we have to rely on
+ // `isForInductionVar` to make the decision. IsDimState is set to UNKNOWN in
+ // this case.
+ enum IsDimState { TRUE, FALSE, UNKNOWN };
+
+ // This function places each given Value (in `values`) under a respective
+ // category in `valuePosMap`. Specifically, the placement rules are:
+ // 1) If `isDim` is FALSE, then every value in `values` are inserted into
+ // `valuePosMap` as symbols.
+ // 2) If `isDim` is UNKNOWN and the value of the current iteration is NOT an
+ // induction variable of a for-loop, we treat it as symbol as well.
+ // 3) For other cases, we decide whether to add a value to the `src` or the
+ // `dst` section of the dim category simply by the boolean value `isSrc`.
+ auto updateValuePosMap = [&](ArrayRef<Value> values, bool isSrc,
+ IsDimState isDim) {
for (unsigned i = 0, e = values.size(); i < e; ++i) {
auto value = values[i];
- if (!isForInductionVar(values[i])) {
- assert(isValidSymbol(values[i]) &&
+ if (isDim == FALSE || (isDim == UNKNOWN && !isForInductionVar(value))) {
+ assert(isValidSymbol(value) &&
"access operand has to be either a loop IV or a symbol");
valuePosMap->addSymbolValue(value);
- } else if (isSrc) {
- valuePosMap->addSrcValue(value);
} else {
- valuePosMap->addDstValue(value);
+ if (isSrc)
+ valuePosMap->addSrcValue(value);
+ else
+ valuePosMap->addDstValue(value);
}
}
};
- SmallVector<Value, 4> srcValues, destValues;
- srcDomain.getIdValues(0, srcDomain.getNumDimAndSymbolIds(), &srcValues);
- dstDomain.getIdValues(0, dstDomain.getNumDimAndSymbolIds(), &destValues);
- // Update value position map with identifiers from src iteration domain.
- updateValuePosMap(srcValues, /*isSrc=*/true);
- // Update value position map with identifiers from dst iteration domain.
- updateValuePosMap(destValues, /*isSrc=*/false);
+ // Collect values from the src and dst domains. For each domain, we separate
+ // the collected values into dim and symbol parts.
+ SmallVector<Value, 4> srcDimValues, dstDimValues, srcSymbolValues,
+ dstSymbolValues;
+ srcDomain.getIdValues(0, srcDomain.getNumDimIds(), &srcDimValues);
+ dstDomain.getIdValues(0, dstDomain.getNumDimIds(), &dstDimValues);
+ srcDomain.getIdValues(srcDomain.getNumDimIds(),
+ srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues);
+ dstDomain.getIdValues(dstDomain.getNumDimIds(),
+ dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues);
+
+ // Update value position map with dim values from src iteration domain.
+ updateValuePosMap(srcDimValues, /*isSrc=*/true, /*isDim=*/TRUE);
+ // Update value position map with dim values from dst iteration domain.
+ updateValuePosMap(dstDimValues, /*isSrc=*/false, /*isDim=*/TRUE);
+ // Update value position map with symbols from src iteration domain.
+ updateValuePosMap(srcSymbolValues, /*isSrc=*/true, /*isDim=*/FALSE);
+ // Update value position map with symbols from dst iteration domain.
+ updateValuePosMap(dstSymbolValues, /*isSrc=*/false, /*isDim=*/FALSE);
// Update value position map with identifiers from src access function.
- updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true);
+ updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true,
+ /*isDim=*/UNKNOWN);
// Update value position map with identifiers from dst access function.
- updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false);
+ updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false,
+ /*isDim=*/UNKNOWN);
}
// Sets up dependence constraints columns appropriately, in the format:
@@ -270,24 +330,33 @@ static void initDependenceConstraints(
dependenceConstraints->setIdValues(
srcLoopIVs.size(), srcLoopIVs.size() + dstLoopIVs.size(), dstLoopIVs);
- // Set values for the symbolic identifier dimensions.
- auto setSymbolIds = [&](ArrayRef<Value> values) {
+ // Set values for the symbolic identifier dimensions. `isSymbolDetermined`
+ // indicates whether we are certain that the `values` passed in are all
+ // symbols. If `isSymbolDetermined` is true, then we treat every Value in
+ // `values` as a symbol; otherwise, we let the function `isForInductionVar` to
+ // distinguish whether a Value in `values` is a symbol or not.
+ auto setSymbolIds = [&](ArrayRef<Value> values,
+ bool isSymbolDetermined = true) {
for (auto value : values) {
- if (!isForInductionVar(value)) {
+ if (isSymbolDetermined || !isForInductionVar(value)) {
assert(isValidSymbol(value) && "expected symbol");
dependenceConstraints->setIdValue(valuePosMap.getSymPos(value), value);
}
}
};
- setSymbolIds(srcAccessMap.getOperands());
- setSymbolIds(dstAccessMap.getOperands());
+ // We are uncertain about whether all operands in `srcAccessMap` and
+ // `dstAccessMap` are symbols, so we set `isSymbolDetermined` to false.
+ setSymbolIds(srcAccessMap.getOperands(), /*isSymbolDetermined=*/false);
+ setSymbolIds(dstAccessMap.getOperands(), /*isSymbolDetermined=*/false);
SmallVector<Value, 8> srcSymbolValues, dstSymbolValues;
srcDomain.getIdValues(srcDomain.getNumDimIds(),
srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues);
dstDomain.getIdValues(dstDomain.getNumDimIds(),
dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues);
+ // Since we only take symbol Values out of `srcDomain` and `dstDomain`,
+ // `isSymbolDetermined` is kept to its default value: true.
setSymbolIds(srcSymbolValues);
setSymbolIds(dstSymbolValues);
@@ -530,22 +599,50 @@ getNumCommonLoops(const FlatAffineConstraints &srcDomain,
return numCommonLoops;
}
-// Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
+/// Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
static Block *getCommonBlock(const MemRefAccess &srcAccess,
const MemRefAccess &dstAccess,
const FlatAffineConstraints &srcDomain,
unsigned numCommonLoops) {
+ // Get the chain of ancestor blocks to the given `MemRefAccess` instance. The
+ // search terminates when either an op with the `AffineScope` trait or
+ // `endBlock` is reached.
+ auto getChainOfAncestorBlocks = [&](const MemRefAccess &access,
+ SmallVector<Block *, 4> &ancestorBlocks,
+ Block *endBlock = nullptr) {
+ Block *currBlock = access.opInst->getBlock();
+ // Loop terminates when the currBlock is nullptr or equals to the endBlock,
+ // or its parent operation holds an affine scope.
+ while (currBlock && currBlock != endBlock &&
+ !currBlock->getParentOp()->hasTrait<OpTrait::AffineScope>()) {
+ ancestorBlocks.push_back(currBlock);
+ currBlock = currBlock->getParentOp()->getBlock();
+ }
+ };
+
if (numCommonLoops == 0) {
- auto *block = srcAccess.opInst->getBlock();
+ Block *block = srcAccess.opInst->getBlock();
while (!llvm::isa<FuncOp>(block->getParentOp())) {
block = block->getParentOp()->getBlock();
}
return block;
}
- auto commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
- auto forOp = getForInductionVarOwner(commonForValue);
+ Value commonForIV = srcDomain.getIdValue(numCommonLoops - 1);
+ AffineForOp forOp = getForInductionVarOwner(commonForIV);
assert(forOp && "commonForValue was not an induction variable");
- return forOp.getBody();
+
+ // Find the closest common block including those in AffineIf.
+ SmallVector<Block *, 4> srcAncestorBlocks, dstAncestorBlocks;
+ getChainOfAncestorBlocks(srcAccess, srcAncestorBlocks, forOp.getBody());
+ getChainOfAncestorBlocks(dstAccess, dstAncestorBlocks, forOp.getBody());
+
+ Block *commonBlock = forOp.getBody();
+ for (int i = srcAncestorBlocks.size() - 1, j = dstAncestorBlocks.size() - 1;
+ i >= 0 && j >= 0 && srcAncestorBlocks[i] == dstAncestorBlocks[j];
+ i--, j--)
+ commonBlock = srcAncestorBlocks[i];
+
+ return commonBlock;
}
// Returns true if the ancestor operation of 'srcAccess' appears before the
@@ -788,12 +885,12 @@ DependenceResult mlir::checkMemrefAccessDependence(
// Get iteration domain for the 'srcAccess' operation.
FlatAffineConstraints srcDomain;
- if (failed(getInstIndexSet(srcAccess.opInst, &srcDomain)))
+ if (failed(getOpIndexSet(srcAccess.opInst, &srcDomain)))
return DependenceResult::Failure;
// Get iteration domain for 'dstAccess' operation.
FlatAffineConstraints dstDomain;
- if (failed(getInstIndexSet(dstAccess.opInst, &dstDomain)))
+ if (failed(getOpIndexSet(dstAccess.opInst, &dstDomain)))
return DependenceResult::Failure;
// Return 'NoDependence' if loopDepth > numCommonLoops and if the ancestor
@@ -814,7 +911,6 @@ DependenceResult mlir::checkMemrefAccessDependence(
buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap,
dstAccessMap, &valuePosMap,
dependenceConstraints);
-
initDependenceConstraints(srcDomain, dstDomain, srcAccessMap, dstAccessMap,
valuePosMap, dependenceConstraints);
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index c09a9af45712..d3a4c2ed216e 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -723,6 +723,18 @@ LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) {
/*eq=*/false, /*lower=*/false);
}
+void FlatAffineConstraints::addAffineIfOpDomain(AffineIfOp ifOp) {
+ // Create the base constraints from the integer set attached to ifOp.
+ FlatAffineConstraints cst(ifOp.getIntegerSet());
+
+ // Bind ids in the constraints to ifOp operands.
+ SmallVector<Value, 4> operands = ifOp.getOperands();
+ cst.setIdValues(0, cst.getNumDimAndSymbolIds(), operands);
+
+ // Merge the constraints from ifOp to the current domain.
+ mergeAndAlignIdsWithOther(0, &cst);
+}
+
// Searches for a constraint with a non-zero coefficient at 'colIdx' in
// equality (isEq=true) or inequality (isEq=false) constraints.
// Returns true and sets row found in search in 'rowIdx'.
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index ead45491c159..b02212a09bba 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -44,6 +44,23 @@ void mlir::getLoopIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops) {
std::reverse(loops->begin(), loops->end());
}
+/// Populates 'ops' with IVs of the loops surrounding `op`, along with
+/// `affine.if` operations interleaved between these loops, ordered from the
+/// outermost `affine.for` operation to the innermost one.
+void mlir::getEnclosingAffineForAndIfOps(Operation &op,
+ SmallVectorImpl<Operation *> *ops) {
+ ops->clear();
+ Operation *currOp = op.getParentOp();
+
+ // Traverse up the hierarchy collecting all `affine.for` and `affine.if`
+ // operations.
+ while (currOp && (isa<AffineIfOp, AffineForOp>(currOp))) {
+ ops->push_back(currOp);
+ currOp = currOp->getParentOp();
+ }
+ std::reverse(ops->begin(), ops->end());
+}
+
// Populates 'cst' with FlatAffineConstraints which represent slice bounds.
LogicalResult
ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
index 5223cb83b10a..1889711cbf7a 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
@@ -289,7 +289,11 @@ mlir::tilePerfectlyNested(MutableArrayRef<AffineForOp> input,
extractForInductionVars(input, &origLoopIVs);
FlatAffineConstraints cst;
- getIndexSet(input, &cst);
+ SmallVector<Operation *, 8> ops;
+ ops.reserve(input.size());
+ for (AffineForOp forOp : input)
+ ops.push_back(forOp);
+ getIndexSet(ops, &cst);
if (!cst.isHyperRectangular(0, width)) {
rootAffineForOp.emitError("tiled code generation unimplemented for the "
"non-hyperrectangular case");
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index ecb478adbbdf..db6a071367d6 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -2335,7 +2335,11 @@ static AffineIfOp createSeparationCondition(MutableArrayRef<AffineForOp> loops,
auto *context = loops[0].getContext();
FlatAffineConstraints cst;
- getIndexSet(loops, &cst);
+ SmallVector<Operation *, 8> ops;
+ ops.reserve(loops.size());
+ for (AffineForOp forOp : loops)
+ ops.push_back(forOp);
+ getIndexSet(ops, &cst);
// Remove constraints that are independent of these loop IVs.
cst.removeIndependentConstraints(/*pos=*/0, /*num=*/loops.size());
@@ -2419,7 +2423,8 @@ createFullTiles(MutableArrayRef<AffineForOp> inputNest,
<< "[tile separation] non-unit stride not implemented\n");
return failure();
}
- getIndexSet({loop}, &cst);
+ SmallVector<Operation *, 1> loopOp{loop.getOperation()};
+ getIndexSet(loopOp, &cst);
// We will mark everything other than this loop IV as symbol for getting a
+  // pair of <lb, ub> with a constant difference.
cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - 1);
diff --git a/mlir/test/Transforms/memref-dependence-check.mlir b/mlir/test/Transforms/memref-dependence-check.mlir
index 154dcf79c114..fc1c72cf1664 100644
--- a/mlir/test/Transforms/memref-dependence-check.mlir
+++ b/mlir/test/Transforms/memref-dependence-check.mlir
@@ -904,3 +904,163 @@ func @test_dep_store_depth2_load_depth1() {
}
return
}
+
+// -----
+
+// Test the case that `affine.if` changes the domain for both load/store simultaneously.
+#set = affine_set<(d0): (d0 - 50 >= 0)>
+
+// CHECK-LABEL: func @test_affine_for_if_same_block() {
+func @test_affine_for_if_same_block() {
+ %0 = alloc() : memref<100xf32>
+ %cf7 = constant 7.0 : f32
+
+ affine.for %i0 = 0 to 100 {
+ affine.if #set(%i0) {
+ %1 = affine.load %0[%i0] : memref<100xf32>
+ // expected-remark at above {{dependence from 0 to 0 at depth 1 = false}}
+ // expected-remark at above {{dependence from 0 to 0 at depth 2 = false}}
+ // expected-remark at above {{dependence from 0 to 1 at depth 1 = false}}
+ // expected-remark at above {{dependence from 0 to 1 at depth 2 = true}}
+ affine.store %cf7, %0[%i0] : memref<100xf32>
+ // expected-remark at above {{dependence from 1 to 0 at depth 1 = false}}
+ // expected-remark at above {{dependence from 1 to 0 at depth 2 = false}}
+ // expected-remark at above {{dependence from 1 to 1 at depth 1 = false}}
+ // expected-remark at above {{dependence from 1 to 1 at depth 2 = false}}
+ }
+ }
+
+ return
+}
+
+// -----
+
+// Test the case that the domain that load/store access is completely separated by `affine.if`.
+#set = affine_set<(d0): (d0 - 50 >= 0)>
+
+// CHECK-LABEL: func @test_affine_for_if_separated() {
+func @test_affine_for_if_separated() {
+ %0 = alloc() : memref<100xf32>
+ %cf7 = constant 7.0 : f32
+
+ affine.for %i0 = 0 to 10 {
+ affine.if #set(%i0) {
+ %1 = affine.load %0[%i0] : memref<100xf32>
+ // expected-remark at above {{dependence from 0 to 0 at depth 1 = false}}
+ // expected-remark at above {{dependence from 0 to 0 at depth 2 = false}}
+ // expected-remark at above {{dependence from 0 to 1 at depth 1 = false}}
+ // expected-remark at above {{dependence from 0 to 1 at depth 2 = false}}
+ } else {
+ affine.store %cf7, %0[%i0] : memref<100xf32>
+ // expected-remark at above {{dependence from 1 to 0 at depth 1 = false}}
+ // expected-remark at above {{dependence from 1 to 0 at depth 2 = false}}
+ // expected-remark at above {{dependence from 1 to 1 at depth 1 = false}}
+ // expected-remark at above {{dependence from 1 to 1 at depth 2 = false}}
+ }
+ }
+
+ return
+}
+
+// -----
+
+// Test the case that the domain that load/store access has non-empty union set.
+#set1 = affine_set<(d0): ( d0 - 25 >= 0)>
+#set2 = affine_set<(d0): (- d0 + 75 >= 0)>
+
+// CHECK-LABEL: func @test_affine_for_if_partially_joined() {
+func @test_affine_for_if_partially_joined() {
+ %0 = alloc() : memref<100xf32>
+ %cf7 = constant 7.0 : f32
+
+ affine.for %i0 = 0 to 100 {
+ affine.if #set1(%i0) {
+ %1 = affine.load %0[%i0] : memref<100xf32>
+ // expected-remark at above {{dependence from 0 to 0 at depth 1 = false}}
+ // expected-remark at above {{dependence from 0 to 0 at depth 2 = false}}
+ // expected-remark at above {{dependence from 0 to 1 at depth 1 = false}}
+ // expected-remark at above {{dependence from 0 to 1 at depth 2 = true}}
+ }
+ affine.if #set2(%i0) {
+ affine.store %cf7, %0[%i0] : memref<100xf32>
+ // expected-remark at above {{dependence from 1 to 0 at depth 1 = false}}
+ // expected-remark at above {{dependence from 1 to 0 at depth 2 = false}}
+ // expected-remark at above {{dependence from 1 to 1 at depth 1 = false}}
+ // expected-remark at above {{dependence from 1 to 1 at depth 2 = false}}
+ }
+ }
+
+ return
+}
+
+// -----
+
+// Test whether interleaved affine.for/affine.if can be properly handled.
+#set1 = affine_set<(d0): (d0 - 50 >= 0)>
+#set2 = affine_set<(d0, d1): (d0 - 75 >= 0, d1 - 50 >= 0)>
+
+// CHECK-LABEL: func @test_interleaved_affine_for_if() {
+func @test_interleaved_affine_for_if() {
+ %0 = alloc() : memref<100x100xf32>
+ %cf7 = constant 7.0 : f32
+
+ affine.for %i0 = 0 to 100 {
+ affine.if #set1(%i0) {
+ affine.for %i1 = 0 to 100 {
+ %1 = affine.load %0[%i0, %i1] : memref<100x100xf32>
+ // expected-remark at above {{dependence from 0 to 0 at depth 1 = false}}
+ // expected-remark at above {{dependence from 0 to 0 at depth 2 = false}}
+ // expected-remark at above {{dependence from 0 to 0 at depth 3 = false}}
+ // expected-remark at above {{dependence from 0 to 1 at depth 1 = false}}
+ // expected-remark at above {{dependence from 0 to 1 at depth 2 = false}}
+ // expected-remark at above {{dependence from 0 to 1 at depth 3 = true}}
+
+ affine.if #set2(%i0, %i1) {
+ affine.store %cf7, %0[%i0, %i1] : memref<100x100xf32>
+ // expected-remark at above {{dependence from 1 to 0 at depth 1 = false}}
+ // expected-remark at above {{dependence from 1 to 0 at depth 2 = false}}
+ // expected-remark at above {{dependence from 1 to 0 at depth 3 = false}}
+ // expected-remark at above {{dependence from 1 to 1 at depth 1 = false}}
+ // expected-remark at above {{dependence from 1 to 1 at depth 2 = false}}
+ // expected-remark at above {{dependence from 1 to 1 at depth 3 = false}}
+ }
+ }
+ }
+ }
+
+ return
+}
+
+// -----
+
+// Test whether symbols can be handled.
+#set1 = affine_set<(d0)[s0]: ( d0 - s0 floordiv 2 >= 0)>
+#set2 = affine_set<(d0): (- d0 + 51 >= 0)>
+
+// CHECK-LABEL: func @test_interleaved_affine_for_if() {
+func @test_interleaved_affine_for_if() {
+ %0 = alloc() : memref<101xf32>
+ %c0 = constant 0 : index
+ %N = dim %0, %c0 : memref<101xf32>
+ %cf7 = constant 7.0 : f32
+
+ affine.for %i0 = 0 to 100 {
+ affine.if #set1(%i0)[%N] {
+ %1 = affine.load %0[%i0] : memref<101xf32>
+ // expected-remark at above {{dependence from 0 to 0 at depth 1 = false}}
+ // expected-remark at above {{dependence from 0 to 0 at depth 2 = false}}
+ // expected-remark at above {{dependence from 0 to 1 at depth 1 = false}}
+ // expected-remark at above {{dependence from 0 to 1 at depth 2 = true}}
+ }
+
+ affine.if #set2(%i0) {
+ affine.store %cf7, %0[%i0] : memref<101xf32>
+ // expected-remark at above {{dependence from 1 to 0 at depth 1 = false}}
+ // expected-remark at above {{dependence from 1 to 0 at depth 2 = false}}
+ // expected-remark at above {{dependence from 1 to 1 at depth 1 = false}}
+ // expected-remark at above {{dependence from 1 to 1 at depth 2 = false}}
+ }
+ }
+
+ return
+}
More information about the Mlir-commits
mailing list