[Mlir-commits] [mlir] 52d6c5d - [MLIR] Generalize Affine dependence analysis using Affine Relations

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 15 23:50:56 PDT 2021


Author: Groverkss
Date: 2021-10-16T12:16:42+05:30
New Revision: 52d6c5df85faa7bfeb1887a75e3fbdb22efaaf94

URL: https://github.com/llvm/llvm-project/commit/52d6c5df85faa7bfeb1887a75e3fbdb22efaaf94
DIFF: https://github.com/llvm/llvm-project/commit/52d6c5df85faa7bfeb1887a75e3fbdb22efaaf94.diff

LOG: [MLIR] Generalize Affine dependence analysis using Affine Relations

This patch removes code very specific to affine dependence analysis and
refactors it as a FlatAffineRelation.

A FlatAffineRelation represents a set of ordered pairs (domain -> range) where
"domain" and "range" are tuples of identifiers. These relations are used to
represent an "access relation" for memory access on a memref.  An access
relation maps elements of an iteration domain to the element(s) of an array
domain accessed by that iteration of the associated statement through some
array reference.  The dependence relation representing the dependence
constraints between two memory accesses can be built by composing the access
relation of the destination access with the inverse of the access relation of
the source access.

This patch does not change the functionality of the existing dependence
analysis in checkMemrefAccessDependence, but refactors it to use
FlatAffineRelations to deduplicate code and enable code reuse for future
development of features like scheduling, value-based dependence analysis, etc.

Reviewed By: bondhugula

Differential Revision: https://reviews.llvm.org/D110563

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/AffineAnalysis.h
    mlir/include/mlir/Analysis/AffineStructures.h
    mlir/lib/Analysis/AffineAnalysis.cpp
    mlir/lib/Analysis/AffineStructures.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h
index 1e5c5b8f87c6..98b02609724a 100644
--- a/mlir/include/mlir/Analysis/AffineAnalysis.h
+++ b/mlir/include/mlir/Analysis/AffineAnalysis.h
@@ -25,6 +25,7 @@ namespace mlir {
 class AffineApplyOp;
 class AffineForOp;
 class AffineValueMap;
+class FlatAffineRelation;
 class FlatAffineValueConstraints;
 class Operation;
 
@@ -89,6 +90,30 @@ struct MemRefAccess {
   // Returns true if this access is of a store op.
   bool isStore() const;
 
+  /// Populates `accessRel` with the access relation of this access. An access
+  /// relation maps elements of an iteration domain to the element(s) of an
+  /// array domain accessed by that iteration of the associated statement
+  /// through some array reference. For example, given the MLIR code:
+  ///
+  /// affine.for %i0 = 0 to 10 {
+  ///   affine.for %i1 = 0 to 10 {
+  ///     %a = affine.load %arr[%i0 + %i1, %i0 + 2 * %i1] : memref<100x100xf32>
+  ///   }
+  /// }
+  ///
+  /// The access relation, assuming that the indices of the accessed element
+  /// of %arr are represented as %m0, %m1, would be:
+  ///
+  ///   (%i0, %i1) -> (%m0, %m1)
+  ///   %m0 = %i0 + %i1
+  ///   %m1 = %i0 + 2 * %i1
+  ///   0  <= %i0 < 10
+  ///   0  <= %i1 < 10
+  ///
+  /// Returns failure for yet unimplemented/unsupported cases (see docs of
+  /// mlir::getIndexSet and mlir::getRelationFromMap for these cases).
+  LogicalResult getAccessRelation(FlatAffineRelation &accessRel) const;
+
   /// Populates 'accessMap' with composition of AffineApplyOps reachable from
   /// 'indices'.
   void getAccessMap(AffineValueMap *accessMap) const;

diff  --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
index 450fa0f53b44..2c2344145ffc 100644
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -426,6 +426,10 @@ class FlatAffineConstraints {
   /// O(VC) time.
   void removeRedundantConstraints();
 
+  /// Converts the dimension identifiers in the column range
+  /// [dimStart, dimLimit) to local variables.
+  void convertDimToLocal(unsigned dimStart, unsigned dimLimit);
+
   /// Merge local ids of `this` and `other`. This is done by appending local ids
   /// of `other` to `this` and inserting local ids of `this` to `other` at start
   /// of its local ids.
@@ -581,6 +585,16 @@ class FlatAffineValueConstraints : public FlatAffineConstraints {
                                        numLocals + 1,
                                    numDims, numSymbols, numLocals, valArgs) {}
 
+  /// Builds a value-annotated copy of the constraint system `fac`. If
+  /// `valArgs` is empty, no identifier gets an associated Value; otherwise
+  /// `valArgs` must hold exactly one (optional) Value per identifier.
+  FlatAffineValueConstraints(const FlatAffineConstraints &fac,
+                             ArrayRef<Optional<Value>> valArgs = {})
+      : FlatAffineConstraints(fac) {
+    assert(valArgs.empty() || valArgs.size() == numIds);
+    if (!valArgs.empty()) {
+      values.append(valArgs.begin(), valArgs.end());
+      return;
+    }
+    // No Values supplied: leave every identifier unassociated.
+    values.resize(numIds, None);
+  }
+
   /// Create a flat affine constraint system from an AffineValueMap or a list of
   /// these. The constructed system will only include equalities.
   explicit FlatAffineValueConstraints(const AffineValueMap &avm);
@@ -721,7 +735,8 @@ class FlatAffineValueConstraints : public FlatAffineConstraints {
   using FlatAffineConstraints::insertDimId;
   unsigned insertSymbolId(unsigned pos, ValueRange vals);
   using FlatAffineConstraints::insertSymbolId;
-  unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override;
+  unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override;
   unsigned insertId(IdKind kind, unsigned pos, ValueRange vals);
 
   /// Append identifiers of the specified kind after the last identifier of that
@@ -882,7 +897,7 @@ class FlatAffineValueConstraints : public FlatAffineConstraints {
   /// Removes identifiers in the column range [idStart, idLimit), and copies any
   /// remaining valid data into place, updates member variables, and resizes
   /// arrays as needed.
-  void removeIdRange(unsigned idStart, unsigned idLimit) override;
+  void removeIdRange(unsigned idStart, unsigned idLimit) override;
 
   /// Eliminates the identifier at the specified position using Fourier-Motzkin
   /// variable elimination, but uses Gaussian elimination if there is an
@@ -901,6 +916,83 @@ class FlatAffineValueConstraints : public FlatAffineConstraints {
   SmallVector<Optional<Value>, 8> values;
 };
 
+/// A FlatAffineRelation represents a set of ordered pairs (domain -> range)
+/// where "domain" and "range" are tuples of identifiers. The relation is
+/// represented as a FlatAffineValueConstraints with separation of dimension
+/// identifiers into domain and range. The identifiers are stored as:
+/// [domainIds, rangeIds, symbolIds, localIds, constant].
+class FlatAffineRelation : public FlatAffineValueConstraints {
+public:
+  /// Creates an empty relation with the given identifier counts, reserving
+  /// space for the given number of constraints/columns. The domain and range
+  /// identifiers together form the dimension identifiers of the system.
+  FlatAffineRelation(unsigned numReservedInequalities,
+                     unsigned numReservedEqualities, unsigned numReservedCols,
+                     unsigned numDomainDims, unsigned numRangeDims,
+                     unsigned numSymbols, unsigned numLocals,
+                     ArrayRef<Optional<Value>> valArgs = {})
+      : FlatAffineValueConstraints(
+            numReservedInequalities, numReservedEqualities, numReservedCols,
+            numDomainDims + numRangeDims, numSymbols, numLocals, valArgs),
+        numDomainDims(numDomainDims), numRangeDims(numRangeDims) {}
+
+  /// Creates an empty relation with the given identifier counts.
+  FlatAffineRelation(unsigned numDomainDims = 0, unsigned numRangeDims = 0,
+                     unsigned numSymbols = 0, unsigned numLocals = 0)
+      : FlatAffineValueConstraints(numDomainDims + numRangeDims, numSymbols,
+                                   numLocals),
+        numDomainDims(numDomainDims), numRangeDims(numRangeDims) {}
+
+  /// Creates a relation from a copy of `fac`, treating its first
+  /// `numDomainDims` dimension identifiers as the domain and the following
+  /// `numRangeDims` dimension identifiers as the range. `fac` is not modified.
+  /// NOTE(review): the counts are not checked against `fac`'s dimension count.
+  FlatAffineRelation(unsigned numDomainDims, unsigned numRangeDims,
+                     const FlatAffineValueConstraints &fac)
+      : FlatAffineValueConstraints(fac), numDomainDims(numDomainDims),
+        numRangeDims(numRangeDims) {}
+
+  /// Same as above, for a constraint system without associated Values; no
+  /// identifier of the resulting relation has a Value attached.
+  FlatAffineRelation(unsigned numDomainDims, unsigned numRangeDims,
+                     const FlatAffineConstraints &fac)
+      : FlatAffineValueConstraints(fac), numDomainDims(numDomainDims),
+        numRangeDims(numRangeDims) {}
+
+  /// Returns a set corresponding to the domain/range of the affine relation.
+  FlatAffineValueConstraints getDomainSet() const;
+  FlatAffineValueConstraints getRangeSet() const;
+
+  /// Returns the number of identifiers corresponding to domain/range of
+  /// relation.
+  unsigned getNumDomainDims() const { return numDomainDims; }
+  unsigned getNumRangeDims() const { return numRangeDims; }
+
+  /// Given affine relation `other: (domainOther -> rangeOther)`, this operation
+  /// takes the composition of `other` on `this: (domainThis -> rangeThis)`.
+  /// The resulting relation represents tuples of the form: `domainOther ->
+  /// rangeThis`.
+  void compose(const FlatAffineRelation &other);
+
+  /// Swap domain and range of the relation.
+  /// `(domain -> range)` is converted to `(range -> domain)`.
+  void inverse();
+
+  /// Insert `num` identifiers of the specified kind (domain or range) at
+  /// position `pos` among the identifiers of that kind. The coefficient
+  /// columns corresponding to the added identifiers are initialized to zero.
+  void insertDomainId(unsigned pos, unsigned num = 1);
+  void insertRangeId(unsigned pos, unsigned num = 1);
+
+  /// Append `num` identifiers of the specified kind (domain or range) after
+  /// the last identifier of that kind. The coefficient columns corresponding
+  /// to the added identifiers are initialized to zero.
+  void appendDomainId(unsigned num = 1);
+  void appendRangeId(unsigned num = 1);
+
+protected:
+  // Number of dimension identifiers corresponding to domain identifiers.
+  unsigned numDomainDims;
+
+  // Number of dimension identifiers corresponding to range identifiers.
+  unsigned numRangeDims;
+
+  /// Removes identifiers in the column range [idStart, idLimit), and copies any
+  /// remaining valid data into place, updates member variables, and resizes
+  /// arrays as needed.
+  void removeIdRange(unsigned idStart, unsigned idLimit) override;
+};
+
 /// Flattens 'expr' into 'flattenedExpr', which contains the coefficients of the
 /// dimensions, symbols, and additional variables that represent floor divisions
 /// of dimensions, symbols, and in turn other floor divisions.  Returns failure
@@ -957,6 +1049,26 @@ AffineMap alignAffineMapWithValues(AffineMap map, ValueRange operands,
                                    ValueRange dims, ValueRange syms,
                                    SmallVector<Value> *newSyms = nullptr);
 
+/// Builds a relation from the given AffineMap/AffineValueMap `map`, containing
+/// all pairs of the form `operands -> result` that satisfy `map`. `rel` is set
+/// to the relation built. For example, given the AffineMap:
+///
+///   (d0, d1)[s0] -> (d0 + s0, d1 - s0)
+///
+/// the resulting relation formed is:
+///
+///   (d0, d1) -> (r1, r2)
+///   [d0  d1  r1  r2  s0  const]
+///    1   0   -1   0  1     0     = 0
+///    0   1    0  -1  -1    0     = 0
+///
+/// For AffineValueMap, the domain and symbol identifiers of `rel` have their
+/// Values set to the corresponding operands of `map`. Returns failure if the
+/// AffineMap could not be flattened (i.e., semi-affine is not yet handled).
+LogicalResult getRelationFromMap(AffineMap &map, FlatAffineRelation &rel);
+LogicalResult getRelationFromMap(const AffineValueMap &map,
+                                 FlatAffineRelation &rel);
+
 } // end namespace mlir.
 
 #endif // MLIR_ANALYSIS_AFFINESTRUCTURES_H

diff  --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index db3edc95e582..fcc18508341e 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -241,459 +241,6 @@ static LogicalResult getOpIndexSet(Operation *op,
   return getIndexSet(ops, indexSet);
 }
 
-namespace {
-// ValuePositionMap manages the mapping from Values which represent dimension
-// and symbol identifiers from 'src' and 'dst' access functions to positions
-// in new space where some Values are kept separate (using addSrc/DstValue)
-// and some Values are merged (addSymbolValue).
-// Position lookups return the absolute position in the new space which
-// has the following format:
-//
-//   [src-dim-identifiers] [dst-dim-identifiers] [symbol-identifiers]
-//
-// Note: access function non-IV dimension identifiers (that have 'dimension'
-// positions in the access function position space) are assigned as symbols
-// in the output position space. Convenience access functions which lookup
-// an Value in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle
-// the common case of resolving positions for all access function operands.
-//
-// TODO: Generalize this: could take a template parameter for the number of maps
-// (3 in the current case), and lookups could take indices of maps to check. So
-// getSrcDimOrSymPos would be "getPos(value, {0, 2})".
-class ValuePositionMap {
-public:
-  void addSrcValue(Value value) {
-    if (addValueAt(value, &srcDimPosMap, numSrcDims))
-      ++numSrcDims;
-  }
-  void addDstValue(Value value) {
-    if (addValueAt(value, &dstDimPosMap, numDstDims))
-      ++numDstDims;
-  }
-  void addSymbolValue(Value value) {
-    if (addValueAt(value, &symbolPosMap, numSymbols))
-      ++numSymbols;
-  }
-  unsigned getSrcDimOrSymPos(Value value) const {
-    return getDimOrSymPos(value, srcDimPosMap, 0);
-  }
-  unsigned getDstDimOrSymPos(Value value) const {
-    return getDimOrSymPos(value, dstDimPosMap, numSrcDims);
-  }
-  unsigned getSymPos(Value value) const {
-    auto it = symbolPosMap.find(value);
-    assert(it != symbolPosMap.end());
-    return numSrcDims + numDstDims + it->second;
-  }
-
-  unsigned getNumSrcDims() const { return numSrcDims; }
-  unsigned getNumDstDims() const { return numDstDims; }
-  unsigned getNumDims() const { return numSrcDims + numDstDims; }
-  unsigned getNumSymbols() const { return numSymbols; }
-
-private:
-  bool addValueAt(Value value, DenseMap<Value, unsigned> *posMap,
-                  unsigned position) {
-    auto it = posMap->find(value);
-    if (it == posMap->end()) {
-      (*posMap)[value] = position;
-      return true;
-    }
-    return false;
-  }
-  unsigned getDimOrSymPos(Value value,
-                          const DenseMap<Value, unsigned> &dimPosMap,
-                          unsigned dimPosOffset) const {
-    auto it = dimPosMap.find(value);
-    if (it != dimPosMap.end()) {
-      return dimPosOffset + it->second;
-    }
-    it = symbolPosMap.find(value);
-    assert(it != symbolPosMap.end());
-    return numSrcDims + numDstDims + it->second;
-  }
-
-  unsigned numSrcDims = 0;
-  unsigned numDstDims = 0;
-  unsigned numSymbols = 0;
-  DenseMap<Value, unsigned> srcDimPosMap;
-  DenseMap<Value, unsigned> dstDimPosMap;
-  DenseMap<Value, unsigned> symbolPosMap;
-};
-} // namespace
-
-// Builds a map from Value to identifier position in a new merged identifier
-// list, which is the result of merging dim/symbol lists from src/dst
-// iteration domains, the format of which is as follows:
-//
-//   [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers, const_term]
-//
-// This method populates 'valuePosMap' with mappings from operand Values in
-// 'srcAccessMap'/'dstAccessMap' (as well as those in 'srcDomain'/'dstDomain')
-// to the position of these values in the merged list.
-static void buildDimAndSymbolPositionMaps(
-    const FlatAffineValueConstraints &srcDomain,
-    const FlatAffineValueConstraints &dstDomain,
-    const AffineValueMap &srcAccessMap, const AffineValueMap &dstAccessMap,
-    ValuePositionMap *valuePosMap,
-    FlatAffineValueConstraints *dependenceConstraints) {
-
-  // IsDimState is a tri-state boolean. It is used to distinguish three
-  // different cases of the values passed to updateValuePosMap.
-  // - When it is TRUE, we are certain that all values are dim values.
-  // - When it is FALSE, we are certain that all values are symbol values.
-  // - When it is UNKNOWN, we need to further check whether the value is from a
-  // loop IV to determine its type (dim or symbol).
-
-  // We need this enumeration because sometimes we cannot determine whether a
-  // Value is a symbol or a dim by the information from the Value itself. If a
-  // Value appears in an affine map of a loop, we can determine whether it is a
-  // dim or not by the function `isForInductionVar`. But when a Value is in the
-  // affine set of an if-statement, there is no way to identify its category
-  // (dim/symbol) by itself. Fortunately, the Values to be inserted into
-  // `valuePosMap` come from `srcDomain` and `dstDomain`, and they hold such
-  // information of Value category: `srcDomain` and `dstDomain` organize Values
-  // by their category, such that the position of each Value stored in
-  // `srcDomain` and `dstDomain` marks which category that a Value belongs to.
-  // Therefore, we can separate Values into dim and symbol groups before passing
-  // them to the function `updateValuePosMap`. Specifically, when passing the
-  // dim group, we set IsDimState to TRUE; otherwise, we set it to FALSE.
-  // However, Values from the operands of `srcAccessMap` and `dstAccessMap` are
-  // not explicitly categorized into dim or symbol, and we have to rely on
-  // `isForInductionVar` to make the decision. IsDimState is set to UNKNOWN in
-  // this case.
-  enum IsDimState { TRUE, FALSE, UNKNOWN };
-
-  // This function places each given Value (in `values`) under a respective
-  // category in `valuePosMap`. Specifically, the placement rules are:
-  // 1) If `isDim` is FALSE, then every value in `values` are inserted into
-  // `valuePosMap` as symbols.
-  // 2) If `isDim` is UNKNOWN and the value of the current iteration is NOT an
-  // induction variable of a for-loop, we treat it as symbol as well.
-  // 3) For other cases, we decide whether to add a value to the `src` or the
-  // `dst` section of the dim category simply by the boolean value `isSrc`.
-  auto updateValuePosMap = [&](ArrayRef<Value> values, bool isSrc,
-                               IsDimState isDim) {
-    for (unsigned i = 0, e = values.size(); i < e; ++i) {
-      auto value = values[i];
-      if (isDim == FALSE || (isDim == UNKNOWN && !isForInductionVar(value))) {
-        assert(isValidSymbol(value) &&
-               "access operand has to be either a loop IV or a symbol");
-        valuePosMap->addSymbolValue(value);
-      } else {
-        if (isSrc)
-          valuePosMap->addSrcValue(value);
-        else
-          valuePosMap->addDstValue(value);
-      }
-    }
-  };
-
-  // Collect values from the src and dst domains. For each domain, we separate
-  // the collected values into dim and symbol parts.
-  SmallVector<Value, 4> srcDimValues, dstDimValues, srcSymbolValues,
-      dstSymbolValues;
-  srcDomain.getValues(0, srcDomain.getNumDimIds(), &srcDimValues);
-  dstDomain.getValues(0, dstDomain.getNumDimIds(), &dstDimValues);
-  srcDomain.getValues(srcDomain.getNumDimIds(),
-                      srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues);
-  dstDomain.getValues(dstDomain.getNumDimIds(),
-                      dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues);
-
-  // Update value position map with dim values from src iteration domain.
-  updateValuePosMap(srcDimValues, /*isSrc=*/true, /*isDim=*/TRUE);
-  // Update value position map with dim values from dst iteration domain.
-  updateValuePosMap(dstDimValues, /*isSrc=*/false, /*isDim=*/TRUE);
-  // Update value position map with symbols from src iteration domain.
-  updateValuePosMap(srcSymbolValues, /*isSrc=*/true, /*isDim=*/FALSE);
-  // Update value position map with symbols from dst iteration domain.
-  updateValuePosMap(dstSymbolValues, /*isSrc=*/false, /*isDim=*/FALSE);
-  // Update value position map with identifiers from src access function.
-  updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true,
-                    /*isDim=*/UNKNOWN);
-  // Update value position map with identifiers from dst access function.
-  updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false,
-                    /*isDim=*/UNKNOWN);
-}
-
-// Sets up dependence constraints columns appropriately, in the format:
-// [src-dim-ids, dst-dim-ids, symbol-ids, local-ids, const_term]
-static void
-initDependenceConstraints(const FlatAffineValueConstraints &srcDomain,
-                          const FlatAffineValueConstraints &dstDomain,
-                          const AffineValueMap &srcAccessMap,
-                          const AffineValueMap &dstAccessMap,
-                          const ValuePositionMap &valuePosMap,
-                          FlatAffineValueConstraints *dependenceConstraints) {
-  // Calculate number of equalities/inequalities and columns required to
-  // initialize FlatAffineValueConstraints for 'dependenceDomain'.
-  unsigned numIneq =
-      srcDomain.getNumInequalities() + dstDomain.getNumInequalities();
-  AffineMap srcMap = srcAccessMap.getAffineMap();
-  assert(srcMap.getNumResults() == dstAccessMap.getAffineMap().getNumResults());
-  unsigned numEq = srcMap.getNumResults();
-  unsigned numDims = srcDomain.getNumDimIds() + dstDomain.getNumDimIds();
-  unsigned numSymbols = valuePosMap.getNumSymbols();
-  unsigned numLocals = srcDomain.getNumLocalIds() + dstDomain.getNumLocalIds();
-  unsigned numIds = numDims + numSymbols + numLocals;
-  unsigned numCols = numIds + 1;
-
-  // Set flat affine constraints sizes and reserving space for constraints.
-  dependenceConstraints->reset(numIneq, numEq, numCols, numDims, numSymbols,
-                               numLocals);
-
-  // Set values corresponding to dependence constraint identifiers.
-  SmallVector<Value, 4> srcLoopIVs, dstLoopIVs;
-  srcDomain.getValues(0, srcDomain.getNumDimIds(), &srcLoopIVs);
-  dstDomain.getValues(0, dstDomain.getNumDimIds(), &dstLoopIVs);
-
-  dependenceConstraints->setValues(0, srcLoopIVs.size(), srcLoopIVs);
-  dependenceConstraints->setValues(
-      srcLoopIVs.size(), srcLoopIVs.size() + dstLoopIVs.size(), dstLoopIVs);
-
-  // Set values for the symbolic identifier dimensions. `isSymbolDetermined`
-  // indicates whether we are certain that the `values` passed in are all
-  // symbols. If `isSymbolDetermined` is true, then we treat every Value in
-  // `values` as a symbol; otherwise, we let the function `isForInductionVar` to
-  // distinguish whether a Value in `values` is a symbol or not.
-  auto setSymbolIds = [&](ArrayRef<Value> values,
-                          bool isSymbolDetermined = true) {
-    for (auto value : values) {
-      if (isSymbolDetermined || !isForInductionVar(value)) {
-        assert(isValidSymbol(value) && "expected symbol");
-        dependenceConstraints->setValue(valuePosMap.getSymPos(value), value);
-      }
-    }
-  };
-
-  // We are uncertain about whether all operands in `srcAccessMap` and
-  // `dstAccessMap` are symbols, so we set `isSymbolDetermined` to false.
-  setSymbolIds(srcAccessMap.getOperands(), /*isSymbolDetermined=*/false);
-  setSymbolIds(dstAccessMap.getOperands(), /*isSymbolDetermined=*/false);
-
-  SmallVector<Value, 8> srcSymbolValues, dstSymbolValues;
-  srcDomain.getValues(srcDomain.getNumDimIds(),
-                      srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues);
-  dstDomain.getValues(dstDomain.getNumDimIds(),
-                      dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues);
-  // Since we only take symbol Values out of `srcDomain` and `dstDomain`,
-  // `isSymbolDetermined` is kept to its default value: true.
-  setSymbolIds(srcSymbolValues);
-  setSymbolIds(dstSymbolValues);
-
-  for (unsigned i = 0, e = dependenceConstraints->getNumDimAndSymbolIds();
-       i < e; i++)
-    assert(dependenceConstraints->hasValue(i));
-}
-
-// Adds iteration domain constraints from 'srcDomain' and 'dstDomain' into
-// 'dependenceDomain'.
-// Uses 'valuePosMap' to determine the position in 'dependenceDomain' to which a
-// srcDomain/dstDomain Value maps.
-static void addDomainConstraints(const FlatAffineValueConstraints &srcDomain,
-                                 const FlatAffineValueConstraints &dstDomain,
-                                 const ValuePositionMap &valuePosMap,
-                                 FlatAffineValueConstraints *dependenceDomain) {
-  unsigned depNumDimsAndSymbolIds = dependenceDomain->getNumDimAndSymbolIds();
-
-  SmallVector<int64_t, 4> cst(dependenceDomain->getNumCols());
-
-  auto addDomain = [&](bool isSrc, bool isEq, unsigned localOffset) {
-    const FlatAffineValueConstraints &domain = isSrc ? srcDomain : dstDomain;
-    unsigned numCsts =
-        isEq ? domain.getNumEqualities() : domain.getNumInequalities();
-    unsigned numDimAndSymbolIds = domain.getNumDimAndSymbolIds();
-    auto at = [&](unsigned i, unsigned j) -> int64_t {
-      return isEq ? domain.atEq(i, j) : domain.atIneq(i, j);
-    };
-    auto map = [&](unsigned i) -> int64_t {
-      return isSrc ? valuePosMap.getSrcDimOrSymPos(domain.getValue(i))
-                   : valuePosMap.getDstDimOrSymPos(domain.getValue(i));
-    };
-
-    for (unsigned i = 0; i < numCsts; ++i) {
-      // Zero fill.
-      std::fill(cst.begin(), cst.end(), 0);
-      // Set coefficients for identifiers corresponding to domain.
-      for (unsigned j = 0; j < numDimAndSymbolIds; ++j)
-        cst[map(j)] = at(i, j);
-      // Local terms.
-      for (unsigned j = 0, e = domain.getNumLocalIds(); j < e; j++)
-        cst[depNumDimsAndSymbolIds + localOffset + j] =
-            at(i, numDimAndSymbolIds + j);
-      // Set constant term.
-      cst[cst.size() - 1] = at(i, domain.getNumCols() - 1);
-      // Add constraint.
-      if (isEq)
-        dependenceDomain->addEquality(cst);
-      else
-        dependenceDomain->addInequality(cst);
-    }
-  };
-
-  // Add equalities from src domain.
-  addDomain(/*isSrc=*/true, /*isEq=*/true, /*localOffset=*/0);
-  // Add inequalities from src domain.
-  addDomain(/*isSrc=*/true, /*isEq=*/false, /*localOffset=*/0);
-  // Add equalities from dst domain.
-  addDomain(/*isSrc=*/false, /*isEq=*/true,
-            /*localOffset=*/srcDomain.getNumLocalIds());
-  // Add inequalities from dst domain.
-  addDomain(/*isSrc=*/false, /*isEq=*/false,
-            /*localOffset=*/srcDomain.getNumLocalIds());
-}
-
-// Adds equality constraints that equate src and dst access functions
-// represented by 'srcAccessMap' and 'dstAccessMap' for each result.
-// Requires that 'srcAccessMap' and 'dstAccessMap' have the same results count.
-// For example, given the following two accesses functions to a 2D memref:
-//
-//   Source access function:
-//     (a0 * d0 + a1 * s0 + a2, b0 * d0 + b1 * s0 + b2)
-//
-//   Destination access function:
-//     (c0 * d0 + c1 * s0 + c2, f0 * d0 + f1 * s0 + f2)
-//
-// This method constructs the following equality constraints in
-// 'dependenceDomain', by equating the access functions for each result
-// (i.e. each memref dim). Notice that 'd0' for the destination access function
-// is mapped into 'd0' in the equality constraint:
-//
-//   d0      d1      s0         c
-//   --      --      --         --
-//   a0     -c0      (a1 - c1)  (a2 - c2) = 0
-//   b0     -f0      (b1 - f1)  (b2 - f2) = 0
-//
-// Returns failure if any AffineExpr cannot be flattened (due to it being
-// semi-affine). Returns success otherwise.
-static LogicalResult
-addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
-                           const AffineValueMap &dstAccessMap,
-                           const ValuePositionMap &valuePosMap,
-                           FlatAffineValueConstraints *dependenceDomain) {
-  AffineMap srcMap = srcAccessMap.getAffineMap();
-  AffineMap dstMap = dstAccessMap.getAffineMap();
-  assert(srcMap.getNumResults() == dstMap.getNumResults());
-  unsigned numResults = srcMap.getNumResults();
-
-  unsigned srcNumIds = srcMap.getNumDims() + srcMap.getNumSymbols();
-  ArrayRef<Value> srcOperands = srcAccessMap.getOperands();
-
-  unsigned dstNumIds = dstMap.getNumDims() + dstMap.getNumSymbols();
-  ArrayRef<Value> dstOperands = dstAccessMap.getOperands();
-
-  std::vector<SmallVector<int64_t, 8>> srcFlatExprs;
-  std::vector<SmallVector<int64_t, 8>> destFlatExprs;
-  FlatAffineValueConstraints srcLocalVarCst, destLocalVarCst;
-  // Get flattened expressions for the source destination maps.
-  if (failed(getFlattenedAffineExprs(srcMap, &srcFlatExprs, &srcLocalVarCst)) ||
-      failed(getFlattenedAffineExprs(dstMap, &destFlatExprs, &destLocalVarCst)))
-    return failure();
-
-  unsigned domNumLocalIds = dependenceDomain->getNumLocalIds();
-  unsigned srcNumLocalIds = srcLocalVarCst.getNumLocalIds();
-  unsigned dstNumLocalIds = destLocalVarCst.getNumLocalIds();
-  unsigned numLocalIdsToAdd = srcNumLocalIds + dstNumLocalIds;
-  dependenceDomain->appendLocalId(numLocalIdsToAdd);
-
-  unsigned numDims = dependenceDomain->getNumDimIds();
-  unsigned numSymbols = dependenceDomain->getNumSymbolIds();
-  unsigned numSrcLocalIds = srcLocalVarCst.getNumLocalIds();
-  unsigned newLocalIdOffset = numDims + numSymbols + domNumLocalIds;
-
-  // Equality to add.
-  SmallVector<int64_t, 8> eq(dependenceDomain->getNumCols());
-  for (unsigned i = 0; i < numResults; ++i) {
-    // Zero fill.
-    std::fill(eq.begin(), eq.end(), 0);
-
-    // Flattened AffineExpr for src result 'i'.
-    const auto &srcFlatExpr = srcFlatExprs[i];
-    // Set identifier coefficients from src access function.
-    for (unsigned j = 0, e = srcOperands.size(); j < e; ++j)
-      eq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = srcFlatExpr[j];
-    // Local terms.
-    for (unsigned j = 0, e = srcNumLocalIds; j < e; j++)
-      eq[newLocalIdOffset + j] = srcFlatExpr[srcNumIds + j];
-    // Set constant term.
-    eq[eq.size() - 1] = srcFlatExpr[srcFlatExpr.size() - 1];
-
-    // Flattened AffineExpr for dest result 'i'.
-    const auto &destFlatExpr = destFlatExprs[i];
-    // Set identifier coefficients from dst access function.
-    for (unsigned j = 0, e = dstOperands.size(); j < e; ++j)
-      eq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] -= destFlatExpr[j];
-    // Local terms.
-    for (unsigned j = 0, e = dstNumLocalIds; j < e; j++)
-      eq[newLocalIdOffset + numSrcLocalIds + j] = -destFlatExpr[dstNumIds + j];
-    // Set constant term.
-    eq[eq.size() - 1] -= destFlatExpr[destFlatExpr.size() - 1];
-
-    // Add equality constraint.
-    dependenceDomain->addEquality(eq);
-  }
-
-  // Add equality constraints for any operands that are defined by constant ops.
-  auto addEqForConstOperands = [&](ArrayRef<Value> operands) {
-    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
-      if (isForInductionVar(operands[i]))
-        continue;
-      auto symbol = operands[i];
-      assert(isValidSymbol(symbol));
-      // Check if the symbol is a constant.
-      if (auto cOp = symbol.getDefiningOp<arith::ConstantIndexOp>())
-        dependenceDomain->addBound(FlatAffineConstraints::EQ,
-                                   valuePosMap.getSymPos(symbol), cOp.value());
-    }
-  };
-
-  // Add equality constraints for any src symbols defined by constant ops.
-  addEqForConstOperands(srcOperands);
-  // Add equality constraints for any dst symbols defined by constant ops.
-  addEqForConstOperands(dstOperands);
-
-  // By construction (see flattener), local var constraints will not have any
-  // equalities.
-  assert(srcLocalVarCst.getNumEqualities() == 0 &&
-         destLocalVarCst.getNumEqualities() == 0);
-  // Add inequalities from srcLocalVarCst and destLocalVarCst into the
-  // dependence domain.
-  SmallVector<int64_t, 8> ineq(dependenceDomain->getNumCols());
-  for (unsigned r = 0, e = srcLocalVarCst.getNumInequalities(); r < e; r++) {
-    std::fill(ineq.begin(), ineq.end(), 0);
-
-    // Set identifier coefficients from src local var constraints.
-    for (unsigned j = 0, e = srcOperands.size(); j < e; ++j)
-      ineq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] =
-          srcLocalVarCst.atIneq(r, j);
-    // Local terms.
-    for (unsigned j = 0, e = srcNumLocalIds; j < e; j++)
-      ineq[newLocalIdOffset + j] = srcLocalVarCst.atIneq(r, srcNumIds + j);
-    // Set constant term.
-    ineq[ineq.size() - 1] =
-        srcLocalVarCst.atIneq(r, srcLocalVarCst.getNumCols() - 1);
-    dependenceDomain->addInequality(ineq);
-  }
-
-  for (unsigned r = 0, e = destLocalVarCst.getNumInequalities(); r < e; r++) {
-    std::fill(ineq.begin(), ineq.end(), 0);
-    // Set identifier coefficients from dest local var constraints.
-    for (unsigned j = 0, e = dstOperands.size(); j < e; ++j)
-      ineq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] =
-          destLocalVarCst.atIneq(r, j);
-    // Local terms.
-    for (unsigned j = 0, e = dstNumLocalIds; j < e; j++)
-      ineq[newLocalIdOffset + numSrcLocalIds + j] =
-          destLocalVarCst.atIneq(r, dstNumIds + j);
-    // Set constant term.
-    ineq[ineq.size() - 1] =
-        destLocalVarCst.atIneq(r, destLocalVarCst.getNumCols() - 1);
-
-    dependenceDomain->addInequality(ineq);
-  }
-  return success();
-}
-
 // Returns the number of outer loop common to 'src/dstDomain'.
 // Loops common to 'src/dst' domains are added to 'commonLoops' if non-null.
 static unsigned
@@ -872,6 +419,43 @@ static void computeDirectionVector(
   }
 }
 
+// Builds into `rel` the access relation of this memref access: a relation
+// from the iteration domain of `opInst` (as computed by getOpIndexSet) to the
+// memref elements it accesses. The equalities of the result come from the
+// composed access map; the iteration domain constraints are appended to it.
+// Returns failure if either the iteration domain or the access map cannot be
+// represented as flat affine constraints.
+LogicalResult MemRefAccess::getAccessRelation(FlatAffineRelation &rel) const {
+  // Create set corresponding to domain of access.
+  FlatAffineValueConstraints domain;
+  if (failed(getOpIndexSet(opInst, &domain)))
+    return failure();
+
+  // Get access relation from access map.
+  AffineValueMap accessValueMap;
+  getAccessMap(&accessValueMap);
+  if (failed(getRelationFromMap(accessValueMap, rel)))
+    return failure();
+
+  // View the iteration domain as a relation whose domain is every dim of
+  // `domain` and whose range is (for now) empty. Using the access map's dim
+  // count here would leave numDomainDims + numRangeDims != getNumDimIds()
+  // whenever the access map does not use every dim of `domain`.
+  FlatAffineRelation domainRel(domain.getNumDimIds(), /*numRangeDims=*/0,
+                               domain);
+
+  // Merge and align domain ids of `rel` and ids of `domain`. Since the domain
+  // of the access map is a subset of the domain of access, the domain ids of
+  // `rel` are guaranteed to be a subset of ids of `domain`. Ids of `domain`
+  // that `rel` does not mention are inserted as new domain dims of `rel`.
+  for (unsigned i = 0, e = domain.getNumDimIds(); i < e; ++i) {
+    unsigned loc;
+    if (rel.findId(domain.getValue(i), &loc)) {
+      rel.swapId(i, loc);
+    } else {
+      rel.insertDomainId(i);
+      rel.setValue(i, domain.getValue(i));
+    }
+  }
+  assert(rel.getNumDomainDims() == domain.getNumDimIds() &&
+         "access map uses dims outside the iteration domain");
+
+  // Append domain constraints to `rel`. `domainRel` first gets the same range
+  // dims, symbols and locals as `rel` so that their columns line up.
+  domainRel.appendRangeId(rel.getNumRangeDims());
+  domainRel.mergeLocalIds(rel);
+  domainRel.mergeSymbolIds(rel);
+  rel.append(domainRel);
+
+  return success();
+}
+
 // Populates 'accessMap' with composition of AffineApplyOps reachable from
 // indices of MemRefAccess.
 void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
@@ -900,17 +484,16 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
 // common to both accesses (see Dependence in AffineAnalysis.h for details).
 //
 // The memref access dependence check is comprised of the following steps:
-// *) Compute access functions for each access. Access functions are computed
-//    using AffineValueMaps initialized with the indices from an access, then
-//    composed with AffineApplyOps reachable from operands of that access,
-//    until operands of the AffineValueMap are loop IVs or symbols.
-// *) Build iteration domain constraints for each access. Iteration domain
-//    constraints are pairs of inequality constraints representing the
-//    upper/lower loop bounds for each AffineForOp in the loop nest associated
-//    with each access.
-// *) Build dimension and symbol position maps for each access, which map
-//    Values from access functions and iteration domains to their position
-//    in the merged constraint system built by this method.
+// *) Build access relation for each access. An access relation maps elements
+//    of an iteration domain to the element(s) of an array domain accessed by
+//    that iteration of the associated statement through some array reference.
+// *) Compute the dependence relation by composing access relation of
+//    `srcAccess` with the inverse of access relation of `dstAccess`.
+//    Doing this builds a relation from the iteration domain of `srcAccess`
+//    to the iteration domain of `dstAccess`, pairing iterations that access
+//    the same memory location.
+// *) Add ordering constraints requiring `srcAccess` to happen before
+//    `dstAccess`.
 //
 // This method builds a constraint system with the following column format:
 //
@@ -937,34 +520,34 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
 //     }
 //   }
 //
-// The access functions would be the following:
-//
-//   src: (%i0 * 2 - %i1 * 4 + %N, %i1 * 3 - %M)
-//   dst: (%i2 * 7 + %i3 * 9 - %M, %i3 * 11 - %K)
-//
-// The iteration domains for the src/dst accesses would be the following:
+// The access relation for `srcAccess` would be the following:
 //
-//   src: 0 <= %i0 <= 100, 0 <= %i1 <= 50
-//   dst: 0 <= %i2 <= 100, 0 <= %i3 <= 50
+//   [src_dim0, src_dim1, mem_dim0, mem_dim1,  %N,   %M,  const]
+//       2        -4       -1         0         1     0     0     = 0
+//       0         3        0        -1         0    -1     0     = 0
+//       1         0        0         0         0     0     0    >= 0
+//      -1         0        0         0         0     0     100  >= 0
+//       0         1        0         0         0     0     0    >= 0
+//       0        -1        0         0         0     0     50   >= 0
 //
-// The symbols by both accesses would be assigned to a canonical position order
-// which will be used in the dependence constraint system:
+//  The access relation for `dstAccess` would be the following:
 //
-//   symbol name: %M  %N  %K
-//   symbol  pos:  0   1   2
+//   [dst_dim0, dst_dim1, mem_dim0, mem_dim1,  %M,   %K,  const]
+//       7         9       -1         0        -1     0     0     = 0
+//       0         11       0        -1         0    -1     0     = 0
+//       1         0        0         0         0     0     0    >= 0
+//      -1         0        0         0         0     0     100  >= 0
+//       0         1        0         0         0     0     0    >= 0
+//       0        -1        0         0         0     0     50   >= 0
 //
-// Equality constraints are built by equating each result of src/destination
-// access functions. For this example, the following two equality constraints
-// will be added to the dependence constraint system:
+//  The equalities in the above relations correspond to the access maps while
+//  the inequalities correspond to the iteration domain constraints.
 //
-//   [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const]
-//      2         -4        -7        -9       1      1     0     0    = 0
-//      0          3         0        -11     -1      0     1     0    = 0
+// The dependence relation formed:
 //
-// Inequality constraints from the iteration domain will be meged into
-// the dependence constraint system
-//
-//   [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const]
+//   [src_dim0, src_dim1, dst_dim0, dst_dim1,  %M,   %N,   %K,  const]
+//      2         -4        -7        -9        1     1     0     0    = 0
+//      0          3         0        -11      -1     0     1     0    = 0
 //       1         0         0         0        0     0     0     0    >= 0
 //      -1         0         0         0        0     0     0     100  >= 0
 //       0         1         0         0        0     0     0     0    >= 0
@@ -995,24 +578,16 @@ DependenceResult mlir::checkMemrefAccessDependence(
       !isa<AffineWriteOpInterface>(dstAccess.opInst))
     return DependenceResult::NoDependence;
 
-  // Get composed access function for 'srcAccess'.
-  AffineValueMap srcAccessMap;
-  srcAccess.getAccessMap(&srcAccessMap);
-
-  // Get composed access function for 'dstAccess'.
-  AffineValueMap dstAccessMap;
-  dstAccess.getAccessMap(&dstAccessMap);
-
-  // Get iteration domain for the 'srcAccess' operation.
-  FlatAffineValueConstraints srcDomain;
-  if (failed(getOpIndexSet(srcAccess.opInst, &srcDomain)))
+  // Create access relation from each MemRefAccess.
+  FlatAffineRelation srcRel, dstRel;
+  if (failed(srcAccess.getAccessRelation(srcRel)))
     return DependenceResult::Failure;
-
-  // Get iteration domain for 'dstAccess' operation.
-  FlatAffineValueConstraints dstDomain;
-  if (failed(getOpIndexSet(dstAccess.opInst, &dstDomain)))
+  if (failed(dstAccess.getAccessRelation(dstRel)))
     return DependenceResult::Failure;
 
+  FlatAffineValueConstraints srcDomain = srcRel.getDomainSet();
+  FlatAffineValueConstraints dstDomain = dstRel.getDomainSet();
+
   // Return 'NoDependence' if loopDepth > numCommonLoops and if the ancestor
   // operation of 'srcAccess' does not properly dominate the ancestor
   // operation of 'dstAccess' in the same common operation block.
@@ -1025,42 +600,27 @@ DependenceResult mlir::checkMemrefAccessDependence(
                                            numCommonLoops)) {
     return DependenceResult::NoDependence;
   }
-  // Build dim and symbol position maps for each access from access operand
-  // Value to position in merged constraint system.
-  ValuePositionMap valuePosMap;
-  buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap,
-                                dstAccessMap, &valuePosMap,
-                                dependenceConstraints);
-  initDependenceConstraints(srcDomain, dstDomain, srcAccessMap, dstAccessMap,
-                            valuePosMap, dependenceConstraints);
-
-  assert(valuePosMap.getNumDims() ==
-         srcDomain.getNumDimIds() + dstDomain.getNumDimIds());
 
-  // Create memref access constraint by equating src/dst access functions.
-  // Note that this check is conservative, and will fail in the future when
-  // local variables for mod/div exprs are supported.
-  if (failed(addMemRefAccessConstraints(srcAccessMap, dstAccessMap, valuePosMap,
-                                        dependenceConstraints)))
-    return DependenceResult::Failure;
+  // Compute the dependence relation by composing `srcRel` with the inverse of
+  // `dstRel`. Doing this builds a relation between iteration domain of
+  // `srcAccess` to the iteration domain of `dstAccess` which access the same
+  // memory locations.
+  dstRel.inverse();
+  dstRel.compose(srcRel);
+  *dependenceConstraints = dstRel;
 
   // Add 'src' happens before 'dst' ordering constraints.
   addOrderingConstraints(srcDomain, dstDomain, loopDepth,
                          dependenceConstraints);
-  // Add src and dst domain constraints.
-  addDomainConstraints(srcDomain, dstDomain, valuePosMap,
-                       dependenceConstraints);
 
   // Return 'NoDependence' if the solution space is empty: no dependence.
-  if (dependenceConstraints->isEmpty()) {
+  if (dependenceConstraints->isEmpty())
     return DependenceResult::NoDependence;
-  }
 
   // Compute dependence direction vector and return true.
-  if (dependenceComponents != nullptr) {
+  if (dependenceComponents != nullptr)
     computeDirectionVector(srcDomain, dstDomain, loopDepth,
                            dependenceConstraints, dependenceComponents);
-  }
 
   LLVM_DEBUG(llvm::dbgs() << "Dependence polyhedron:\n");
   LLVM_DEBUG(dependenceConstraints->dump());

diff  --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index 1688a65bcfd6..74b1a8960094 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1894,6 +1894,26 @@ void FlatAffineConstraints::removeRedundantLocalVars() {
   }
 }
 
+// Turns the dimension identifiers at positions [dimStart, dimLimit) into local
+// identifiers, moving each column unchanged. The converted ids end up after
+// the existing locals, in their original relative order. An empty range is a
+// no-op.
+void FlatAffineConstraints::convertDimToLocal(unsigned dimStart,
+                                              unsigned dimLimit) {
+  assert(dimLimit <= getNumDimIds() && "Invalid dim pos range");
+
+  if (dimLimit <= dimStart)
+    return;
+
+  // Open one fresh local column per dimension being converted.
+  const unsigned numToConvert = dimLimit - dimStart;
+  const unsigned firstNewLocal = getNumIds();
+  appendLocalId(numToConvert);
+
+  // Move every dimension column into its matching fresh local slot; the fresh
+  // columns take the dimensions' old positions.
+  for (unsigned offset = 0; offset != numToConvert; ++offset)
+    swapId(dimStart + offset, firstNewLocal + offset);
+
+  // Drop the fresh columns that now sit where the dimensions were.
+  removeIdRange(dimStart, dimLimit);
+}
+
 std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
     unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
     ArrayRef<AffineExpr> localExprs, MLIRContext *context) const {
@@ -3585,3 +3605,168 @@ AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands,
   return map.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                    dims.size(), numSymbols);
 }
+
+// Returns the domain of this relation as a set: every range dimension is
+// turned into a local identifier, so only the domain dims remain as dims.
+FlatAffineValueConstraints FlatAffineRelation::getDomainSet() const {
+  const unsigned rangeBegin = getNumDomainDims();
+  const unsigned rangeEnd = rangeBegin + getNumRangeDims();
+  FlatAffineValueConstraints domainSet = *this;
+  domainSet.convertDimToLocal(/*dimStart=*/rangeBegin, /*dimLimit=*/rangeEnd);
+  return domainSet;
+}
+
+// Returns the range of this relation as a set: every domain dimension is
+// turned into a local identifier, so only the range dims remain as dims.
+FlatAffineValueConstraints FlatAffineRelation::getRangeSet() const {
+  const unsigned numDomain = getNumDomainDims();
+  FlatAffineValueConstraints rangeSet = *this;
+  rangeSet.convertDimToLocal(/*dimStart=*/0, /*dimLimit=*/numDomain);
+  return rangeSet;
+}
+
+// Composes `other` into this relation. If `other` relates A -> B and `this`
+// relates B -> C, `this` is replaced by the relation A -> C. The intermediate
+// B dims become local identifiers, and redundant locals are removed at the
+// end. Requires the domain of `this` to match the range of `other`, both in
+// number of dims and in attached Values (checked by the asserts below).
+void FlatAffineRelation::compose(const FlatAffineRelation &other) {
+  assert(getNumDomainDims() == other.getNumRangeDims() &&
+         "Domain of this and range of other do not match");
+  assert(std::equal(values.begin(), values.begin() + getNumDomainDims(),
+                    other.values.begin() + other.getNumDomainDims()) &&
+         "Domain of this and range of other do not match");
+
+  // Work on a copy of `other`; give both systems the same symbols and the
+  // same combined set of local ids.
+  FlatAffineRelation rel = other;
+  mergeSymbolIds(rel);
+  mergeLocalIds(rel);
+
+  // Convert domain of `this` and range of `rel` (the shared B dims) to local
+  // identifiers. convertDimToLocal appends the converted dims after the
+  // existing locals, in order. Since the locals were just merged, the B ids
+  // of both systems should presumably land in matching columns.
+  convertDimToLocal(0, getNumDomainDims());
+  rel.convertDimToLocal(rel.getNumDomainDims(), rel.getNumDimIds());
+  // Add dimensions such that both relations become `domainRel -> rangeThis`.
+  appendDomainId(rel.getNumDomainDims());
+  rel.appendRangeId(getNumRangeDims());
+
+  auto thisMaybeValues = getMaybeDimValues();
+  auto relMaybeValues = rel.getMaybeDimValues();
+
+  // Copy the Values attached to the domain (A) of `rel` onto the newly
+  // added domain dims of `this`.
+  for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i)
+    if (relMaybeValues[i].hasValue())
+      setValue(i, relMaybeValues[i].getValue());
+  // Copy the Values attached to the range (C) of `this` onto the newly added
+  // range dims of `rel`.
+  for (unsigned i = 0, e = getNumRangeDims(); i < e; ++i) {
+    unsigned rangeIdx = rel.getNumDomainDims() + i;
+    if (thisMaybeValues[rangeIdx].hasValue())
+      rel.setValue(rangeIdx, thisMaybeValues[rangeIdx].getValue());
+  }
+
+  // Append the constraints of `this` to `rel`, so the result holds both
+  // constraint sets over A, C and the locals. Then drop redundant locals.
+  rel.append(*this);
+  rel.removeRedundantLocalVars();
+
+  *this = rel;
+}
+
+// Reverses this relation in place: a relation A -> B becomes B -> A. The old
+// domain columns are moved behind the old range columns; the constraints are
+// otherwise untouched.
+void FlatAffineRelation::inverse() {
+  const unsigned numOldDomain = getNumDomainDims();
+  const unsigned numOldRange = getNumRangeDims();
+  // Fresh dims appended after the current range will receive the old domain.
+  const unsigned firstFresh = numOldDomain + numOldRange;
+  appendRangeId(numOldDomain);
+  for (unsigned k = 0; k != numOldDomain; ++k)
+    swapId(k, firstFresh + k);
+  // The leading dims are now the unused fresh ones; drop them.
+  removeIdRange(0, numOldDomain);
+  // Relabel the two blocks: the old range is the new domain and vice versa.
+  numDomainDims = numOldRange;
+  numRangeDims = numOldDomain;
+}
+
+// Inserts `num` new domain dims starting at domain position `pos`.
+void FlatAffineRelation::insertDomainId(unsigned pos, unsigned num) {
+  assert(pos <= getNumDomainDims() &&
+         "Id cannot be inserted at invalid position");
+  // Domain dims lead the dim block, so `pos` is already an absolute position.
+  const unsigned absolutePos = pos;
+  insertDimId(absolutePos, num);
+  numDomainDims += num;
+}
+
+// Inserts `num` new range dims starting at range position `pos`. Range dims
+// follow the domain dims, so the absolute position is offset by the domain.
+void FlatAffineRelation::insertRangeId(unsigned pos, unsigned num) {
+  assert(pos <= getNumRangeDims() &&
+         "Id cannot be inserted at invalid position");
+  const unsigned absolutePos = getNumDomainDims() + pos;
+  insertDimId(absolutePos, num);
+  numRangeDims += num;
+}
+
+// Appends `num` new dims at the end of the domain block, i.e. just before the
+// first range dim.
+void FlatAffineRelation::appendDomainId(unsigned num) {
+  const unsigned domainEnd = getNumDomainDims();
+  insertDimId(domainEnd, num);
+  numDomainDims += num;
+}
+
+// Appends `num` new range dims after all existing dim identifiers.
+void FlatAffineRelation::appendRangeId(unsigned num) {
+  const unsigned dimsEnd = getNumDimIds();
+  insertDimId(dimsEnd, num);
+  numRangeDims += num;
+}
+
+// Removes identifiers [idStart, idLimit) and shrinks the domain and range dim
+// counts by however many of the removed ids were domain or range dims.
+void FlatAffineRelation::removeIdRange(unsigned idStart, unsigned idLimit) {
+  if (idLimit <= idStart)
+    return;
+
+  // Snapshot the block boundaries before removing anything: domain dims are
+  // [0, domainEnd) and range dims are [domainEnd, dimEnd).
+  const unsigned domainEnd = getNumDomainDims();
+  const unsigned dimEnd = getNumDimIds();
+
+  // Overlap of [idStart, idLimit) with each of the two blocks.
+  const unsigned domainHi = std::min(idLimit, domainEnd);
+  const unsigned rangeLo = std::max(idStart, domainEnd);
+  const unsigned rangeHi = std::min(idLimit, dimEnd);
+  const unsigned removedDomain = domainHi > idStart ? domainHi - idStart : 0;
+  const unsigned removedRange = rangeHi > rangeLo ? rangeHi - rangeLo : 0;
+
+  FlatAffineValueConstraints::removeIdRange(idStart, idLimit);
+
+  numDomainDims -= removedDomain;
+  numRangeDims -= removedRange;
+}
+
+// Builds in `rel` the relation (d) -> (r) with r = map(d): the map's dims form
+// the domain, one range dim is created per map result, and the map's symbols
+// stay symbols. Local ids introduced while flattening are kept along with the
+// constraints that come with them. Returns failure if `map` cannot be
+// flattened.
+LogicalResult mlir::getRelationFromMap(AffineMap &map,
+                                       FlatAffineRelation &rel) {
+  std::vector<SmallVector<int64_t, 8>> flatExprs;
+  FlatAffineConstraints cst;
+  if (failed(getFlattenedAffineExprs(map, &flatExprs, &cst)))
+    return failure();
+
+  // Column layout of `cst` (and of each flattened expression) at this point:
+  //   [dims | symbols | locals | const]
+  const unsigned numFlatDims = cst.getNumDimIds();
+  const unsigned numFlatCols = cst.getNumCols();
+  const unsigned numDomain = map.getNumDims();
+  const unsigned numRange = map.getNumResults();
+
+  // Insert the range dims right after the existing dims:
+  //   [dims | range dims | symbols | locals | const]
+  cst.appendDimId(numRange);
+
+  SmallVector<int64_t, 8> row(cst.getNumCols());
+  for (unsigned res = 0; res != numRange; ++res) {
+    const SmallVector<int64_t, 8> &expr = flatExprs[res];
+    std::fill(row.begin(), row.end(), 0);
+    // Dim coefficients keep their columns; every later column shifts right
+    // by the number of inserted range dims.
+    std::copy(expr.begin(), expr.begin() + numFlatDims, row.begin());
+    std::copy(expr.begin() + numFlatDims, expr.begin() + numFlatCols,
+              row.begin() + numFlatDims + numRange);
+    // Encode expr(d) - r_res == 0.
+    row[numDomain + res] = -1;
+    cst.addEquality(row);
+  }
+
+  rel = FlatAffineRelation(numDomain, numRange, cst);
+  return success();
+}
+
+// AffineValueMap overload: builds the relation for the underlying AffineMap,
+// then attaches the map's operand Values to the domain dims and the symbols.
+LogicalResult mlir::getRelationFromMap(const AffineValueMap &map,
+                                       FlatAffineRelation &rel) {
+  AffineMap affineMap = map.getAffineMap();
+  if (failed(getRelationFromMap(affineMap, rel)))
+    return failure();
+
+  // Domain dim `pos` takes operand `pos`. The range dims sit between the
+  // domain dims and the symbols in `rel`, so symbol id `pos` takes operand
+  // `pos - numRange`.
+  const unsigned numDomain = rel.getNumDomainDims();
+  const unsigned numRange = rel.getNumRangeDims();
+  for (unsigned pos = 0; pos < numDomain; ++pos)
+    rel.setValue(pos, map.getOperand(pos));
+  for (unsigned pos = rel.getNumDimIds(), end = rel.getNumDimAndSymbolIds();
+       pos < end; ++pos)
+    rel.setValue(pos, map.getOperand(pos - numRange));
+
+  return success();
+}


        


More information about the Mlir-commits mailing list