[Mlir-commits] [mlir] 654e8aa - [MLIR] Consider AffineIfOp when getting the index set of an Op wrapped in nested loops

Uday Bondhugula llvmlistbot at llvm.org
Sat Aug 8 14:46:35 PDT 2020


Author: Vincent Zhao
Date: 2020-08-09T03:16:03+05:30
New Revision: 654e8aadfdda97524c463bcf3552d2ecf2feda93

URL: https://github.com/llvm/llvm-project/commit/654e8aadfdda97524c463bcf3552d2ecf2feda93
DIFF: https://github.com/llvm/llvm-project/commit/654e8aadfdda97524c463bcf3552d2ecf2feda93.diff

LOG: [MLIR] Consider AffineIfOp when getting the index set of an Op wrapped in nested loops

This diff attempts to resolve the TODO in `getOpIndexSet` (formerly
known as `getInstIndexSet`), which states "Add support to handle IfInsts
surrounding `op`".

Major changes in this diff:

1. Change `getIndexSet` to take a list of operations instead of
`AffineForOp`s, so that it considers both `AffineForOp` and `AffineIfOp`.
2. `getInstIndexSet` is updated accordingly: it is renamed to
`getOpIndexSet`, and its implementation is now based on a new API,
`getEnclosingAffineForAndIfOps`, instead of `getLoopIVs`.
3. Add `addAffineIfOpDomain` to `FlatAffineConstraints`, which extracts
new constraints from the integer set of `AffineIfOp` and merges them into
the current constraint system.
4. Update how a `Value` is determined as dim or symbol for
`ValuePositionMap` in `buildDimAndSymbolPositionMaps`.

Differential Revision: https://reviews.llvm.org/D84698

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/AffineAnalysis.h
    mlir/include/mlir/Analysis/AffineStructures.h
    mlir/include/mlir/Analysis/Utils.h
    mlir/lib/Analysis/AffineAnalysis.cpp
    mlir/lib/Analysis/AffineStructures.cpp
    mlir/lib/Analysis/Utils.cpp
    mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
    mlir/lib/Transforms/Utils/LoopUtils.cpp
    mlir/test/Transforms/memref-dependence-check.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h
index 3322f3da6a09..c3aaa40bda9a 100644
--- a/mlir/include/mlir/Analysis/AffineAnalysis.h
+++ b/mlir/include/mlir/Analysis/AffineAnalysis.h
@@ -34,12 +34,14 @@ void getReachableAffineApplyOps(ArrayRef<Value> operands,
                                 SmallVectorImpl<Operation *> &affineApplyOps);
 
 /// Builds a system of constraints with dimensional identifiers corresponding to
-/// the loop IVs of the forOps appearing in that order. Bounds of the loop are
-/// used to add appropriate inequalities. Any symbols founds in the bound
-/// operands are added as symbols in the system. Returns failure for the yet
-/// unimplemented cases.
+/// the loop IVs of the forOps and AffineIfOp's operands appearing in
+/// that order. Bounds of the loop are used to add appropriate inequalities.
+/// Constraints from the index sets of AffineIfOp are also added. Any symbols
+/// found in the bound operands are added as symbols in the system. Returns
+/// failure for the yet unimplemented cases. `ops` accepts both AffineForOp and
+/// AffineIfOp.
 //  TODO: handle non-unit strides.
-LogicalResult getIndexSet(MutableArrayRef<AffineForOp> forOps,
+LogicalResult getIndexSet(MutableArrayRef<Operation *> ops,
                           FlatAffineConstraints *domain);
 
 /// Encapsulates a memref load or store access information.

diff  --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
index 0424e0bb7d33..14a19827902f 100644
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -21,6 +21,7 @@ namespace mlir {
 
 class AffineCondition;
 class AffineForOp;
+class AffineIfOp;
 class AffineMap;
 class AffineValueMap;
 class IntegerSet;
@@ -215,6 +216,15 @@ class FlatAffineConstraints {
   //  TODO: add support for non-unit strides.
   LogicalResult addAffineForOpDomain(AffineForOp forOp);
 
+  /// Adds constraints imposed by the `affine.if` operation. These constraints
+  /// are collected from the IntegerSet attached to the given `affine.if`
+  /// instance argument (`ifOp`). It is asserted that:
+  /// 1) The IntegerSet of the given `affine.if` instance should not contain
+  /// semi-affine expressions,
+  /// 2) The columns of the constraint system created from `ifOp` should match
+  /// the columns in the current one regarding numbers and values.
+  void addAffineIfOpDomain(AffineIfOp ifOp);
+
   /// Adds a lower or an upper bound for the identifier at the specified
   /// position with constraints being drawn from the specified bound map and
   /// operands. If `eq` is true, add a single equality equal to the bound map's

diff  --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h
index 943a2f125b7d..b502d909d5c0 100644
--- a/mlir/include/mlir/Analysis/Utils.h
+++ b/mlir/include/mlir/Analysis/Utils.h
@@ -39,6 +39,12 @@ class Value;
 //  TODO: handle 'affine.if' ops.
 void getLoopIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops);
 
+/// Populates 'ops' with the `affine.for` ops surrounding `op`, along with
+/// `affine.if` operations interleaved between these loops, ordered from the
+/// outermost `affine.for` or `affine.if` operation to the innermost one.
+void getEnclosingAffineForAndIfOps(Operation &op,
+                                   SmallVectorImpl<Operation *> *ops);
+
 /// Returns the nesting depth of this operation, i.e., the number of loops
 /// surrounding this operation.
 unsigned getNestingDepth(Operation *op);

diff  --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index 5bdfcda8533a..aeaa5cb482ee 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -85,33 +85,42 @@ void mlir::getReachableAffineApplyOps(
 // FlatAffineConstraints. (For eg., by using iv - lb % step = 0 and/or by
 // introducing a method in FlatAffineConstraints setExprStride(ArrayRef<int64_t>
 // expr, int64_t stride)
-LogicalResult mlir::getIndexSet(MutableArrayRef<AffineForOp> forOps,
+LogicalResult mlir::getIndexSet(MutableArrayRef<Operation *> ops,
                                 FlatAffineConstraints *domain) {
   SmallVector<Value, 4> indices;
+  SmallVector<AffineForOp, 8> forOps;
+
+  for (Operation *op : ops) {
+    assert((isa<AffineForOp, AffineIfOp>(op)) &&
+           "ops should have either AffineForOp or AffineIfOp");
+    if (AffineForOp forOp = dyn_cast<AffineForOp>(op))
+      forOps.push_back(forOp);
+  }
   extractForInductionVars(forOps, &indices);
   // Reset while associated Values in 'indices' to the domain.
   domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
-  for (auto forOp : forOps) {
+  for (Operation *op : ops) {
     // Add constraints from forOp's bounds.
-    if (failed(domain->addAffineForOpDomain(forOp)))
-      return failure();
+    if (AffineForOp forOp = dyn_cast<AffineForOp>(op)) {
+      if (failed(domain->addAffineForOpDomain(forOp)))
+        return failure();
+    } else if (AffineIfOp ifOp = dyn_cast<AffineIfOp>(op)) {
+      domain->addAffineIfOpDomain(ifOp);
+    }
   }
   return success();
 }
 
-// Computes the iteration domain for 'opInst' and populates 'indexSet', which
-// encapsulates the constraints involving loops surrounding 'opInst' and
-// potentially involving any Function symbols. The dimensional identifiers in
-// 'indexSet' correspond to the loops surrounding 'op' from outermost to
-// innermost.
-// TODO: Add support to handle IfInsts surrounding 'op'.
-static LogicalResult getInstIndexSet(Operation *op,
-                                     FlatAffineConstraints *indexSet) {
-  // TODO: Extend this to gather enclosing IfInsts and consider
-  // factoring it out into a utility function.
-  SmallVector<AffineForOp, 4> loops;
-  getLoopIVs(*op, &loops);
-  return getIndexSet(loops, indexSet);
+/// Computes the iteration domain for 'op' and populates 'indexSet', which
+/// encapsulates the constraints involving loops surrounding 'op' and
+/// potentially involving any Function symbols. The dimensional identifiers in
+/// 'indexSet' correspond to the loops surrounding 'op' from outermost to
+/// innermost.
+static LogicalResult getOpIndexSet(Operation *op,
+                                   FlatAffineConstraints *indexSet) {
+  SmallVector<Operation *, 4> ops;
+  getEnclosingAffineForAndIfOps(*op, &ops);
+  return getIndexSet(ops, indexSet);
 }
 
 namespace {
@@ -209,32 +218,83 @@ static void buildDimAndSymbolPositionMaps(
     const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap,
     const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap,
     FlatAffineConstraints *dependenceConstraints) {
-  auto updateValuePosMap = [&](ArrayRef<Value> values, bool isSrc) {
+
+  // IsDimState is a tri-state boolean. It is used to distinguish three
+  // different cases of the values passed to updateValuePosMap.
+  // - When it is TRUE, we are certain that all values are dim values.
+  // - When it is FALSE, we are certain that all values are symbol values.
+  // - When it is UNKNOWN, we need to further check whether the value is from a
+  // loop IV to determine its type (dim or symbol).
+
+  // We need this enumeration because sometimes we cannot determine whether a
+  // Value is a symbol or a dim by the information from the Value itself. If a
+  // Value appears in an affine map of a loop, we can determine whether it is a
+  // dim or not by the function `isForInductionVar`. But when a Value is in the
+  // affine set of an if-statement, there is no way to identify its category
+  // (dim/symbol) by itself. Fortunately, the Values to be inserted into
+  // `valuePosMap` come from `srcDomain` and `dstDomain`, and they hold such
+  // information of Value category: `srcDomain` and `dstDomain` organize Values
+  // by their category, such that the position of each Value stored in
+  // `srcDomain` and `dstDomain` marks which category that a Value belongs to.
+  // Therefore, we can separate Values into dim and symbol groups before passing
+  // them to the function `updateValuePosMap`. Specifically, when passing the
+  // dim group, we set IsDimState to TRUE; otherwise, we set it to FALSE.
+  // However, Values from the operands of `srcAccessMap` and `dstAccessMap` are
+  // not explicitly categorized into dim or symbol, and we have to rely on
+  // `isForInductionVar` to make the decision. IsDimState is set to UNKNOWN in
+  // this case.
+  enum IsDimState { TRUE, FALSE, UNKNOWN };
+
+  // This function places each given Value (in `values`) under a respective
+  // category in `valuePosMap`. Specifically, the placement rules are:
+  // 1) If `isDim` is FALSE, then every value in `values` are inserted into
+  // `valuePosMap` as symbols.
+  // 2) If `isDim` is UNKNOWN and the value of the current iteration is NOT an
+  // induction variable of a for-loop, we treat it as symbol as well.
+  // 3) For other cases, we decide whether to add a value to the `src` or the
+  // `dst` section of the dim category simply by the boolean value `isSrc`.
+  auto updateValuePosMap = [&](ArrayRef<Value> values, bool isSrc,
+                               IsDimState isDim) {
     for (unsigned i = 0, e = values.size(); i < e; ++i) {
       auto value = values[i];
-      if (!isForInductionVar(values[i])) {
-        assert(isValidSymbol(values[i]) &&
+      if (isDim == FALSE || (isDim == UNKNOWN && !isForInductionVar(value))) {
+        assert(isValidSymbol(value) &&
                "access operand has to be either a loop IV or a symbol");
         valuePosMap->addSymbolValue(value);
-      } else if (isSrc) {
-        valuePosMap->addSrcValue(value);
       } else {
-        valuePosMap->addDstValue(value);
+        if (isSrc)
+          valuePosMap->addSrcValue(value);
+        else
+          valuePosMap->addDstValue(value);
       }
     }
   };
 
-  SmallVector<Value, 4> srcValues, destValues;
-  srcDomain.getIdValues(0, srcDomain.getNumDimAndSymbolIds(), &srcValues);
-  dstDomain.getIdValues(0, dstDomain.getNumDimAndSymbolIds(), &destValues);
-  // Update value position map with identifiers from src iteration domain.
-  updateValuePosMap(srcValues, /*isSrc=*/true);
-  // Update value position map with identifiers from dst iteration domain.
-  updateValuePosMap(destValues, /*isSrc=*/false);
+  // Collect values from the src and dst domains. For each domain, we separate
+  // the collected values into dim and symbol parts.
+  SmallVector<Value, 4> srcDimValues, dstDimValues, srcSymbolValues,
+      dstSymbolValues;
+  srcDomain.getIdValues(0, srcDomain.getNumDimIds(), &srcDimValues);
+  dstDomain.getIdValues(0, dstDomain.getNumDimIds(), &dstDimValues);
+  srcDomain.getIdValues(srcDomain.getNumDimIds(),
+                        srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues);
+  dstDomain.getIdValues(dstDomain.getNumDimIds(),
+                        dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues);
+
+  // Update value position map with dim values from src iteration domain.
+  updateValuePosMap(srcDimValues, /*isSrc=*/true, /*isDim=*/TRUE);
+  // Update value position map with dim values from dst iteration domain.
+  updateValuePosMap(dstDimValues, /*isSrc=*/false, /*isDim=*/TRUE);
+  // Update value position map with symbols from src iteration domain.
+  updateValuePosMap(srcSymbolValues, /*isSrc=*/true, /*isDim=*/FALSE);
+  // Update value position map with symbols from dst iteration domain.
+  updateValuePosMap(dstSymbolValues, /*isSrc=*/false, /*isDim=*/FALSE);
   // Update value position map with identifiers from src access function.
-  updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true);
+  updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true,
+                    /*isDim=*/UNKNOWN);
   // Update value position map with identifiers from dst access function.
-  updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false);
+  updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false,
+                    /*isDim=*/UNKNOWN);
 }
 
 // Sets up dependence constraints columns appropriately, in the format:
@@ -270,24 +330,33 @@ static void initDependenceConstraints(
   dependenceConstraints->setIdValues(
       srcLoopIVs.size(), srcLoopIVs.size() + dstLoopIVs.size(), dstLoopIVs);
 
-  // Set values for the symbolic identifier dimensions.
-  auto setSymbolIds = [&](ArrayRef<Value> values) {
+  // Set values for the symbolic identifier dimensions. `isSymbolDetermined`
+  // indicates whether we are certain that the `values` passed in are all
+  // symbols. If `isSymbolDetermined` is true, then we treat every Value in
+  // `values` as a symbol; otherwise, we let the function `isForInductionVar` to
+  // distinguish whether a Value in `values` is a symbol or not.
+  auto setSymbolIds = [&](ArrayRef<Value> values,
+                          bool isSymbolDetermined = true) {
     for (auto value : values) {
-      if (!isForInductionVar(value)) {
+      if (isSymbolDetermined || !isForInductionVar(value)) {
         assert(isValidSymbol(value) && "expected symbol");
         dependenceConstraints->setIdValue(valuePosMap.getSymPos(value), value);
       }
     }
   };
 
-  setSymbolIds(srcAccessMap.getOperands());
-  setSymbolIds(dstAccessMap.getOperands());
+  // We are uncertain about whether all operands in `srcAccessMap` and
+  // `dstAccessMap` are symbols, so we set `isSymbolDetermined` to false.
+  setSymbolIds(srcAccessMap.getOperands(), /*isSymbolDetermined=*/false);
+  setSymbolIds(dstAccessMap.getOperands(), /*isSymbolDetermined=*/false);
 
   SmallVector<Value, 8> srcSymbolValues, dstSymbolValues;
   srcDomain.getIdValues(srcDomain.getNumDimIds(),
                         srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues);
   dstDomain.getIdValues(dstDomain.getNumDimIds(),
                         dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues);
+  // Since we only take symbol Values out of `srcDomain` and `dstDomain`,
+  // `isSymbolDetermined` is kept to its default value: true.
   setSymbolIds(srcSymbolValues);
   setSymbolIds(dstSymbolValues);
 
@@ -530,22 +599,50 @@ getNumCommonLoops(const FlatAffineConstraints &srcDomain,
   return numCommonLoops;
 }
 
-// Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
+/// Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
 static Block *getCommonBlock(const MemRefAccess &srcAccess,
                              const MemRefAccess &dstAccess,
                              const FlatAffineConstraints &srcDomain,
                              unsigned numCommonLoops) {
+  // Get the chain of ancestor blocks to the given `MemRefAccess` instance. The
+  // search terminates when either an op with the `AffineScope` trait or
+  // `endBlock` is reached.
+  auto getChainOfAncestorBlocks = [&](const MemRefAccess &access,
+                                      SmallVector<Block *, 4> &ancestorBlocks,
+                                      Block *endBlock = nullptr) {
+    Block *currBlock = access.opInst->getBlock();
+    // Loop terminates when the currBlock is nullptr or equals to the endBlock,
+    // or its parent operation holds an affine scope.
+    while (currBlock && currBlock != endBlock &&
+           !currBlock->getParentOp()->hasTrait<OpTrait::AffineScope>()) {
+      ancestorBlocks.push_back(currBlock);
+      currBlock = currBlock->getParentOp()->getBlock();
+    }
+  };
+
   if (numCommonLoops == 0) {
-    auto *block = srcAccess.opInst->getBlock();
+    Block *block = srcAccess.opInst->getBlock();
     while (!llvm::isa<FuncOp>(block->getParentOp())) {
       block = block->getParentOp()->getBlock();
     }
     return block;
   }
-  auto commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
-  auto forOp = getForInductionVarOwner(commonForValue);
+  Value commonForIV = srcDomain.getIdValue(numCommonLoops - 1);
+  AffineForOp forOp = getForInductionVarOwner(commonForIV);
   assert(forOp && "commonForValue was not an induction variable");
-  return forOp.getBody();
+
+  // Find the closest common block including those in AffineIf.
+  SmallVector<Block *, 4> srcAncestorBlocks, dstAncestorBlocks;
+  getChainOfAncestorBlocks(srcAccess, srcAncestorBlocks, forOp.getBody());
+  getChainOfAncestorBlocks(dstAccess, dstAncestorBlocks, forOp.getBody());
+
+  Block *commonBlock = forOp.getBody();
+  for (int i = srcAncestorBlocks.size() - 1, j = dstAncestorBlocks.size() - 1;
+       i >= 0 && j >= 0 && srcAncestorBlocks[i] == dstAncestorBlocks[j];
+       i--, j--)
+    commonBlock = srcAncestorBlocks[i];
+
+  return commonBlock;
 }
 
 // Returns true if the ancestor operation of 'srcAccess' appears before the
@@ -788,12 +885,12 @@ DependenceResult mlir::checkMemrefAccessDependence(
 
   // Get iteration domain for the 'srcAccess' operation.
   FlatAffineConstraints srcDomain;
-  if (failed(getInstIndexSet(srcAccess.opInst, &srcDomain)))
+  if (failed(getOpIndexSet(srcAccess.opInst, &srcDomain)))
     return DependenceResult::Failure;
 
   // Get iteration domain for 'dstAccess' operation.
   FlatAffineConstraints dstDomain;
-  if (failed(getInstIndexSet(dstAccess.opInst, &dstDomain)))
+  if (failed(getOpIndexSet(dstAccess.opInst, &dstDomain)))
     return DependenceResult::Failure;
 
   // Return 'NoDependence' if loopDepth > numCommonLoops and if the ancestor
@@ -814,7 +911,6 @@ DependenceResult mlir::checkMemrefAccessDependence(
   buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap,
                                 dstAccessMap, &valuePosMap,
                                 dependenceConstraints);
-
   initDependenceConstraints(srcDomain, dstDomain, srcAccessMap, dstAccessMap,
                             valuePosMap, dependenceConstraints);
 

diff  --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index c09a9af45712..d3a4c2ed216e 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -723,6 +723,18 @@ LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) {
                               /*eq=*/false, /*lower=*/false);
 }
 
+void FlatAffineConstraints::addAffineIfOpDomain(AffineIfOp ifOp) {
+  // Create the base constraints from the integer set attached to ifOp.
+  FlatAffineConstraints cst(ifOp.getIntegerSet());
+
+  // Bind ids in the constraints to ifOp operands.
+  SmallVector<Value, 4> operands = ifOp.getOperands();
+  cst.setIdValues(0, cst.getNumDimAndSymbolIds(), operands);
+
+  // Merge the constraints from ifOp to the current domain.
+  mergeAndAlignIdsWithOther(0, &cst);
+}
+
 // Searches for a constraint with a non-zero coefficient at 'colIdx' in
 // equality (isEq=true) or inequality (isEq=false) constraints.
 // Returns true and sets row found in search in 'rowIdx'.

diff  --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index ead45491c159..b02212a09bba 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -44,6 +44,23 @@ void mlir::getLoopIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops) {
   std::reverse(loops->begin(), loops->end());
 }
 
+/// Populates 'ops' with the `affine.for` ops surrounding `op`, along with
+/// `affine.if` operations interleaved between these loops, ordered from the
+/// outermost `affine.for` or `affine.if` operation to the innermost one.
+void mlir::getEnclosingAffineForAndIfOps(Operation &op,
+                                         SmallVectorImpl<Operation *> *ops) {
+  ops->clear();
+  Operation *currOp = op.getParentOp();
+
+  // Traverse up the hierarchy collecting all `affine.for` and `affine.if`
+  // operations.
+  while (currOp && (isa<AffineIfOp, AffineForOp>(currOp))) {
+    ops->push_back(currOp);
+    currOp = currOp->getParentOp();
+  }
+  std::reverse(ops->begin(), ops->end());
+}
+
 // Populates 'cst' with FlatAffineConstraints which represent slice bounds.
 LogicalResult
 ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {

diff  --git a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
index 5223cb83b10a..1889711cbf7a 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
@@ -289,7 +289,11 @@ mlir::tilePerfectlyNested(MutableArrayRef<AffineForOp> input,
   extractForInductionVars(input, &origLoopIVs);
 
   FlatAffineConstraints cst;
-  getIndexSet(input, &cst);
+  SmallVector<Operation *, 8> ops;
+  ops.reserve(input.size());
+  for (AffineForOp forOp : input)
+    ops.push_back(forOp);
+  getIndexSet(ops, &cst);
   if (!cst.isHyperRectangular(0, width)) {
     rootAffineForOp.emitError("tiled code generation unimplemented for the "
                               "non-hyperrectangular case");

diff  --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index ecb478adbbdf..db6a071367d6 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -2335,7 +2335,11 @@ static AffineIfOp createSeparationCondition(MutableArrayRef<AffineForOp> loops,
   auto *context = loops[0].getContext();
 
   FlatAffineConstraints cst;
-  getIndexSet(loops, &cst);
+  SmallVector<Operation *, 8> ops;
+  ops.reserve(loops.size());
+  for (AffineForOp forOp : loops)
+    ops.push_back(forOp);
+  getIndexSet(ops, &cst);
 
   // Remove constraints that are independent of these loop IVs.
   cst.removeIndependentConstraints(/*pos=*/0, /*num=*/loops.size());
@@ -2419,7 +2423,8 @@ createFullTiles(MutableArrayRef<AffineForOp> inputNest,
                  << "[tile separation] non-unit stride not implemented\n");
       return failure();
     }
-    getIndexSet({loop}, &cst);
+    SmallVector<Operation *, 1> loopOp{loop.getOperation()};
+    getIndexSet(loopOp, &cst);
     // We will mark everything other than this loop IV as symbol for getting a
     // pair of <lb, ub> with a constant difference.
     cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - 1);

diff  --git a/mlir/test/Transforms/memref-dependence-check.mlir b/mlir/test/Transforms/memref-dependence-check.mlir
index 154dcf79c114..fc1c72cf1664 100644
--- a/mlir/test/Transforms/memref-dependence-check.mlir
+++ b/mlir/test/Transforms/memref-dependence-check.mlir
@@ -904,3 +904,163 @@ func @test_dep_store_depth2_load_depth1() {
   }
   return
 }
+
+// -----
+
+// Test the case that `affine.if` changes the domain for both load/store simultaneously.
+#set = affine_set<(d0): (d0 - 50 >= 0)>
+
+// CHECK-LABEL: func @test_affine_for_if_same_block() {
+func @test_affine_for_if_same_block() {
+  %0 = alloc() : memref<100xf32>
+  %cf7 = constant 7.0 : f32
+
+  affine.for %i0 = 0 to 100 {
+    affine.if #set(%i0) {
+      %1 = affine.load %0[%i0] : memref<100xf32>
+      // expected-remark@above {{dependence from 0 to 0 at depth 1 = false}}
+      // expected-remark@above {{dependence from 0 to 0 at depth 2 = false}}
+      // expected-remark@above {{dependence from 0 to 1 at depth 1 = false}}
+      // expected-remark@above {{dependence from 0 to 1 at depth 2 = true}}
+      affine.store %cf7, %0[%i0] : memref<100xf32>
+      // expected-remark@above {{dependence from 1 to 0 at depth 1 = false}}
+      // expected-remark@above {{dependence from 1 to 0 at depth 2 = false}}
+      // expected-remark@above {{dependence from 1 to 1 at depth 1 = false}}
+      // expected-remark@above {{dependence from 1 to 1 at depth 2 = false}}
+    }
+  }
+
+  return
+}
+
+// -----
+
+// Test the case where the load/store access domains are completely separated by `affine.if`.
+#set = affine_set<(d0): (d0 - 50 >= 0)>
+
+// CHECK-LABEL: func @test_affine_for_if_separated() {
+func @test_affine_for_if_separated() {
+  %0 = alloc() : memref<100xf32>
+  %cf7 = constant 7.0 : f32
+
+  affine.for %i0 = 0 to 10 {
+    affine.if #set(%i0) {
+      %1 = affine.load %0[%i0] : memref<100xf32>
+      // expected-remark@above {{dependence from 0 to 0 at depth 1 = false}}
+      // expected-remark@above {{dependence from 0 to 0 at depth 2 = false}}
+      // expected-remark@above {{dependence from 0 to 1 at depth 1 = false}}
+      // expected-remark@above {{dependence from 0 to 1 at depth 2 = false}}
+    } else {
+      affine.store %cf7, %0[%i0] : memref<100xf32>
+      // expected-remark@above {{dependence from 1 to 0 at depth 1 = false}}
+      // expected-remark@above {{dependence from 1 to 0 at depth 2 = false}}
+      // expected-remark@above {{dependence from 1 to 1 at depth 1 = false}}
+      // expected-remark@above {{dependence from 1 to 1 at depth 2 = false}}
+    }
+  }
+
+  return
+}
+
+// -----
+
+// Test the case where the load/store access domains have a non-empty intersection.
+#set1 = affine_set<(d0): (  d0 - 25 >= 0)>
+#set2 = affine_set<(d0): (- d0 + 75 >= 0)>
+
+// CHECK-LABEL: func @test_affine_for_if_partially_joined() {
+func @test_affine_for_if_partially_joined() {
+  %0 = alloc() : memref<100xf32>
+  %cf7 = constant 7.0 : f32
+
+  affine.for %i0 = 0 to 100 {
+    affine.if #set1(%i0) {
+      %1 = affine.load %0[%i0] : memref<100xf32>
+      // expected-remark@above {{dependence from 0 to 0 at depth 1 = false}}
+      // expected-remark@above {{dependence from 0 to 0 at depth 2 = false}}
+      // expected-remark@above {{dependence from 0 to 1 at depth 1 = false}}
+      // expected-remark@above {{dependence from 0 to 1 at depth 2 = true}}
+    }
+    affine.if #set2(%i0) {
+      affine.store %cf7, %0[%i0] : memref<100xf32>
+      // expected-remark@above {{dependence from 1 to 0 at depth 1 = false}}
+      // expected-remark@above {{dependence from 1 to 0 at depth 2 = false}}
+      // expected-remark@above {{dependence from 1 to 1 at depth 1 = false}}
+      // expected-remark@above {{dependence from 1 to 1 at depth 2 = false}}
+    }
+  }
+
+  return
+}
+
+// -----
+
+// Test whether interleaved affine.for/affine.if can be properly handled.
+#set1 = affine_set<(d0): (d0 - 50 >= 0)>
+#set2 = affine_set<(d0, d1): (d0 - 75 >= 0, d1 - 50 >= 0)>
+
+// CHECK-LABEL: func @test_interleaved_affine_for_if() {
+func @test_interleaved_affine_for_if() {
+  %0 = alloc() : memref<100x100xf32>
+  %cf7 = constant 7.0 : f32
+
+  affine.for %i0 = 0 to 100 {
+    affine.if #set1(%i0) {
+      affine.for %i1 = 0 to 100 {
+        %1 = affine.load %0[%i0, %i1] : memref<100x100xf32>
+        // expected-remark@above {{dependence from 0 to 0 at depth 1 = false}}
+        // expected-remark@above {{dependence from 0 to 0 at depth 2 = false}}
+        // expected-remark@above {{dependence from 0 to 0 at depth 3 = false}}
+        // expected-remark@above {{dependence from 0 to 1 at depth 1 = false}}
+        // expected-remark@above {{dependence from 0 to 1 at depth 2 = false}}
+        // expected-remark@above {{dependence from 0 to 1 at depth 3 = true}}
+
+        affine.if #set2(%i0, %i1) {
+          affine.store %cf7, %0[%i0, %i1] : memref<100x100xf32>
+          // expected-remark@above {{dependence from 1 to 0 at depth 1 = false}}
+          // expected-remark@above {{dependence from 1 to 0 at depth 2 = false}}
+          // expected-remark@above {{dependence from 1 to 0 at depth 3 = false}}
+          // expected-remark@above {{dependence from 1 to 1 at depth 1 = false}}
+          // expected-remark@above {{dependence from 1 to 1 at depth 2 = false}}
+          // expected-remark@above {{dependence from 1 to 1 at depth 3 = false}}
+        }
+      }
+    }
+  }
+
+  return
+}
+
+// -----
+
+// Test whether symbols can be handled.
+#set1 = affine_set<(d0)[s0]: (  d0 - s0 floordiv 2 >= 0)>
+#set2 = affine_set<(d0):     (- d0 +            51 >= 0)>
+
+// CHECK-LABEL: func @test_interleaved_affine_for_if() {
+func @test_interleaved_affine_for_if() {
+  %0 = alloc() : memref<101xf32>
+  %c0 = constant 0 : index
+  %N = dim %0, %c0 : memref<101xf32>
+  %cf7 = constant 7.0 : f32
+
+  affine.for %i0 = 0 to 100 {
+    affine.if #set1(%i0)[%N] {
+      %1 = affine.load %0[%i0] : memref<101xf32>
+      // expected-remark@above {{dependence from 0 to 0 at depth 1 = false}}
+      // expected-remark@above {{dependence from 0 to 0 at depth 2 = false}}
+      // expected-remark@above {{dependence from 0 to 1 at depth 1 = false}}
+      // expected-remark@above {{dependence from 0 to 1 at depth 2 = true}}
+    }
+
+    affine.if #set2(%i0) {
+      affine.store %cf7, %0[%i0] : memref<101xf32>
+      // expected-remark@above {{dependence from 1 to 0 at depth 1 = false}}
+      // expected-remark@above {{dependence from 1 to 0 at depth 2 = false}}
+      // expected-remark@above {{dependence from 1 to 1 at depth 1 = false}}
+      // expected-remark@above {{dependence from 1 to 1 at depth 2 = false}}
+    }
+  }
+
+  return
+}


        


More information about the Mlir-commits mailing list