[Mlir-commits] [mlir] 27aabca - [mlir][sparse] make resolve cycle works with affine expressions.

Peiming Liu llvmlistbot at llvm.org
Tue Nov 22 16:09:38 PST 2022


Author: Peiming Liu
Date: 2022-11-23T00:09:33Z
New Revision: 27aabca0581f16a96faa1a452c6e7716f118aed4

URL: https://github.com/llvm/llvm-project/commit/27aabca0581f16a96faa1a452c6e7716f118aed4
DIFF: https://github.com/llvm/llvm-project/commit/27aabca0581f16a96faa1a452c6e7716f118aed4.diff

LOG: [mlir][sparse] make resolve cycle works with affine expressions.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D138173

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 7e6a350b7117..bafe752b03d5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -430,6 +430,9 @@ class SparseTensorLoopEmitter {
       coords.push_back(l.iv);
   }
 
+  /// Returns the number of loops currently on the loop stack.
+  unsigned getCurrentDepth() const { return loopStack.size(); }
+
   /// Gets loop induction variable at the given level.
   Value getLoopIV(size_t level) const {
     if (level < loopStack.size())

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 46cc2182c5c1..1367baaa02b5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -90,6 +90,11 @@ struct CodeGen {
   // Topsort (reference should remain in scope).
   std::vector<unsigned> &topSort;
 
+  ArrayRef<unsigned> getLoopCurStack() const {
+    ArrayRef<unsigned> topSortRef = topSort;
+    return topSortRef.slice(0, loopEmitter.getCurrentDepth());
+  }
+
   Value getLoopIdxValue(size_t loopIdx) const {
     for (unsigned lv = 0; lv < topSort.size(); lv++)
       if (topSort[lv] == loopIdx)
@@ -134,21 +139,83 @@ class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
 // Sparse compiler analysis methods.
 //===----------------------------------------------------------------------===//
 
+/// Determines if affine expression only uses loops in `loopStack` or `ldx`.
+static bool isInvariantAffine(AffineExpr a, ArrayRef<unsigned> loopStack,
+                              unsigned ldx, bool &atLevel) {
+  switch (a.getKind()) {
+  case AffineExprKind::DimId: {
+    unsigned idx = a.cast<AffineDimExpr>().getPosition();
+    if (idx == ldx) {
+      atLevel = true;
+      // Must be invariant if we are at the level.
+      return true;
+    }
+    bool isInvariant = false;
+    for (unsigned loop : loopStack) {
+      isInvariant = (loop == idx);
+      if (isInvariant)
+        break;
+    }
+    return isInvariant;
+  }
+  case AffineExprKind::Add:
+  case AffineExprKind::Mul: {
+    auto binOp = a.cast<AffineBinaryOpExpr>();
+    return isInvariantAffine(binOp.getLHS(), loopStack, ldx, atLevel) &&
+           isInvariantAffine(binOp.getRHS(), loopStack, ldx, atLevel);
+  }
+  default: {
+    assert(a.isa<AffineConstantExpr>());
+    return true;
+  }
+  }
+}
+
+/// Determines if affine expression is invariant w.r.t. the current loop stack.
+static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
+                              unsigned ldx, bool &atLevel) {
+  return isInvariantAffine(a, codegen.getLoopCurStack(), ldx, atLevel);
+}
+
 /// Helper method to construct a permuted dimension ordering
 /// that adheres to the given topological sort.
-static AffineMap permute(MLIRContext *context, AffineMap m,
-                         std::vector<unsigned> &topSort) {
+static AffineMap permute(const Merger &merger, MLIRContext *context,
+                         AffineMap m, ArrayRef<unsigned> topSort) {
   unsigned sz = topSort.size();
-  assert(m.getNumResults() == sz && "TopoSort/AffineMap size mismatch");
+  assert(m.getNumDims() + merger.getNumFilterLoops() == sz &&
+         "TopoSort/AffineMap size mismatch");
   // Construct the inverse of `m`; to avoid the asymptotic complexity
   // of calling `m.getPermutedPosition` repeatedly.
-  SmallVector<unsigned> inv(sz);
-  for (unsigned i = 0; i < sz; i++)
-    inv[i] = m.getDimPosition(i);
+  SmallVector<unsigned> perm;
+  unsigned numResults = m.getNumResults();
+  BitVector worklist(numResults, true);
+  unsigned loopDepth = 1;
+
   // Construct the permutation.
-  SmallVector<unsigned> perm(sz);
-  for (unsigned i = 0; i < sz; i++)
-    perm[i] = inv[topSort[i]];
+  while (worklist.any() && loopDepth <= topSort.size()) {
+    unsigned preSize = perm.size();
+    for (auto dim : worklist.set_bits()) {
+      bool atLevel = false;
+      if (m.getResult(dim).isa<AffineConstantExpr>() ||
+          (isInvariantAffine(m.getResult(dim), topSort.slice(0, loopDepth),
+                             topSort[loopDepth - 1], atLevel) &&
+           atLevel)) {
+        // If the matching affine expression is a constant or has just become
+        // invariant, we can visit the dimension now without breaking the
+        // topSort constraint.
+        perm.push_back(dim);
+      }
+    }
+
+    // Remove the resolved dimensions from the worklist.
+    for (unsigned i = preSize, e = perm.size(); i < e; i++)
+      worklist.reset(perm[i]);
+
+    // Try to enter the next loop level.
+    loopDepth += 1;
+  }
+
+  assert(perm.size() == numResults);
   return AffineMap::getPermutationMap(perm, context);
 }
 
@@ -422,9 +489,6 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
   auto iteratorTypes = op.getIteratorTypesArray();
   // Iterate over the indexing maps of every tensor in the tensor expression.
   for (OpOperand &t : op->getOpOperands()) {
-    // Skip tensor during cycle resolution.
-    if (&t == skip)
-      continue;
     // Get map and encoding.
     auto map = op.getMatchingIndexingMap(&t);
     auto enc = getSparseTensorEncoding(t.get().getType());
@@ -453,6 +517,11 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
         ta = AffineExpr();
       }
 
+      // Skip tensor during cycle resolution, though the order between the
+      // filter loop and its dependent loops must be guaranteed unconditionally.
+      if (&t == skip)
+        continue;
+
       if (d > 0) {
         AffineExpr fa = map.getResult(toOrigDim(enc, d - 1));
         Optional<unsigned> fldx =
@@ -945,30 +1014,6 @@ static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
   return ee;
 }
 
-/// Determines if affine expression is invariant.
-static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
-                              unsigned ldx, bool &atLevel) {
-  switch (a.getKind()) {
-  case AffineExprKind::DimId: {
-    unsigned idx = a.cast<AffineDimExpr>().getPosition();
-    if (idx == ldx) {
-      atLevel = true;
-      // Must be invariant if we are at the level.
-      return true;
-    }
-    return codegen.getLoopIdxValue(idx) != nullptr; // no longer in play?
-  }
-  case AffineExprKind::Add:
-  case AffineExprKind::Mul: {
-    auto binOp = a.cast<AffineBinaryOpExpr>();
-    return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) &&
-           isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel);
-  }
-  default:
-    return true;
-  }
-}
-
 /// Hoists loop invariant tensor loads for which indices have been exhausted.
 static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                           linalg::GenericOp op, unsigned exp, unsigned ldx,
@@ -1428,7 +1473,6 @@ static void translateBitsToTidDimPairs(
     // Note that we generate dense indices of the output tensor
     // unconditionally, since they may not appear in the lattice, but may be
     // needed for linearized codegen.
-    // Only dense dimensions should be optimized from conditions.
     auto dim = merger.getDimNum(merger.getOutTensorID(), idx).value();
     extraTids.push_back(merger.getOutTensorID());
     extraDims.push_back(dim);
@@ -1698,7 +1742,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
       auto srcTp = tval.getType().cast<RankedTensorType>();
       auto dstEnc = SparseTensorEncodingAttr::get(
           op->getContext(), srcEnc.getDimLevelType(),
-          permute(getContext(), op.getMatchingIndexingMap(t),
+          permute(merger, getContext(), op.getMatchingIndexingMap(t),
                   topSort), // new order
           srcEnc.getHigherOrdering(), srcEnc.getPointerBitWidth(),
           srcEnc.getIndexBitWidth());


        


More information about the Mlir-commits mailing list