[Mlir-commits] [mlir] [mlir][sparse] minor refactoring of sparsification file (PR #74403)

Aart Bik llvmlistbot at llvm.org
Mon Dec 4 18:47:00 PST 2023


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/74403

Removed obsolete TODOs and NOTEs, cleaned up formatting, and removed an unused parameter.

>From 1f506bb85775f72c6ad963ec541492eb701f42e2 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Mon, 4 Dec 2023 18:44:25 -0800
Subject: [PATCH] [mlir][sparse] minor refactoring of sparsification file

Removed obsoleted TODOs and NOTEs, formatting, removed unused parameter
---
 .../Transforms/Sparsification.cpp             | 57 +++++++------------
 1 file changed, 19 insertions(+), 38 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index e0d3ce241e454..d171087f56ab1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -34,6 +34,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/TensorEncoding.h"
 #include "llvm/ADT/SmallBitVector.h"
+
 #include <optional>
 
 using namespace mlir;
@@ -43,11 +44,6 @@ using namespace mlir::sparse_tensor;
 // Sparsifier analysis methods.
 //===----------------------------------------------------------------------===//
 
-// TODO: the "idx"-vs-"ldx" naming convention is not self-explanatory,
-// and those letters are too easy to confuse visually.  We should switch
-// to a more self-explanatory naming convention like "curLoop"-vs-"prevLoop"
-// (assuming that's the actual meaning behind the "idx"-vs-"ldx" convention).
-
 /// Determines if affine expression is invariant.
 static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
                               bool &isAtLoop) {
@@ -56,11 +52,9 @@ static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
     const LoopId i = cast<AffineDimExpr>(a).getPosition();
     if (i == ldx) {
       isAtLoop = true;
-      // Must be invariant if we are at the given loop.
-      return true;
+      return true; // invariant at given loop
     }
-    // The DimExpr is invariant the loop has already been generated.
-    return i < loopDepth;
+    return i < loopDepth; // invariant when already generated
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
@@ -85,7 +79,6 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
     const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
     if (!isUndefLT(merger.getLvlType(tid, idx)))
       return false; // used more than once
-
     if (setLvlFormat)
       merger.setLevelAndType(tid, idx, lvl, lt);
     return true;
@@ -195,7 +188,7 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
   }
 }
 
-/// Get the total number of compound affine expressions in the
+/// Gets the total number of compound affine expressions in the
 /// `getMatchingIndexingMap` for the given tensor.  For the following inputs:
 ///
 /// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
@@ -225,7 +218,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
   return num;
 }
 
-/// Get the total number of sparse levels with compound affine
+/// Gets the total number of sparse levels with compound affine
 /// expressions, summed over all operands of the `GenericOp`.
 static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
   unsigned num = 0;
@@ -235,6 +228,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
   return num;
 }
 
+/// Returns true iff output has nontrivial affine indices.
 static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
   OpOperand *out = op.getDpsInitOperand(0);
   if (getSparseTensorType(out->get()).isAllDense())
@@ -260,11 +254,9 @@ static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
     const auto enc = getSparseTensorEncoding(t.get().getType());
     if (enc)
       annotated = true;
-
     const Level lvlRank = map.getNumResults();
     assert(!enc || lvlRank == enc.getLvlRank());
     assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
-
     // We only need to do index reduction if there is at least one non-trivial
     // index expression on sparse levels.
     // If all non-trivial index expression is on dense levels, we can
@@ -343,9 +335,6 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
 }
 
 /// Generates index for load/store on sparse tensor.
-// FIXME: It's not entirely clear what "index" means here (i.e., is it
-// a "coordinate", or "Ldx", or what).  So the function should be renamed
-// and/or the documentation expanded in order to clarify.
 static Value genIndex(CodegenEnv &env, OpOperand *t) {
   const auto map = env.op().getMatchingIndexingMap(t);
   const auto stt = getSparseTensorType(t->get());
@@ -495,7 +484,6 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
   Value val = env.exp(exp).val;
   if (val)
     return val;
-
   // Load during insertion.
   linalg::GenericOp op = env.op();
   OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
@@ -574,7 +562,7 @@ inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
 /// exception of index computations, which need to be relinked to actual
 /// inlined cloned code.
 static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
-                          Value e, LoopId ldx) {
+                          Value e) {
   if (auto arg = dyn_cast<BlockArgument>(e)) {
     // Direct arguments of the original linalg op must be converted
     // into dense tensor loads. Note that we should not encounter
@@ -598,7 +586,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
       for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
         rewriter.updateRootInPlace(def, [&]() {
           def->setOperand(
-              i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
+              i, relinkBranch(env, rewriter, block, def->getOperand(i)));
         });
       }
     }
@@ -607,8 +595,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
 }
 
 /// Recursively generates tensor expression.
-static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
-                    LoopId ldx) {
+static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
   if (e == ::mlir::sparse_tensor::detail::kInvalidId)
     return Value();
 
@@ -631,15 +618,15 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
   // based on the type of the other operand.
   if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
       env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
-    v1 = genExp(env, rewriter, exp.children.e1, ldx);
+    v1 = genExp(env, rewriter, exp.children.e1);
     v0 = constantZero(rewriter, loc, v1.getType());
   } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
              env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
-    v0 = genExp(env, rewriter, exp.children.e0, ldx);
+    v0 = genExp(env, rewriter, exp.children.e0);
     v1 = constantZero(rewriter, loc, v0.getType());
   } else {
-    v0 = genExp(env, rewriter, exp.children.e0, ldx);
-    v1 = genExp(env, rewriter, exp.children.e1, ldx);
+    v0 = genExp(env, rewriter, exp.children.e0);
+    v1 = genExp(env, rewriter, exp.children.e1);
   }
 
   Value ee;
@@ -653,7 +640,7 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
          kind == TensorExp::Kind::kReduce ||
          kind == TensorExp::Kind::kSelect)) {
       OpBuilder::InsertionGuard guard(rewriter);
-      ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
+      ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
     }
   }
 
@@ -806,7 +793,6 @@ static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter,
     const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, ldx);
     return isCompressedLT(lt) || isSingletonLT(lt);
   });
-
   return isParallelFor(env, isOuter, isSparse);
 }
 
@@ -1112,11 +1098,6 @@ static bool translateBitsToTidLvlPairs(
             // level. We need to generate the address according to the
             // affine expression. This is also the best place we can do it
             // to avoid putting it inside inner loops.
-            // NOTE: It assumes that the levels of the input tensor are
-            // initialized in order (and it is also currently guaranteed by
-            // computeIterationGraph), another more admissible approach
-            // might be accepting out-of-order access between consecutive
-            // dense levels.
             affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
           }
         }
@@ -1221,7 +1202,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
                     LoopOrd at) {
   // At each leaf, assign remaining tensor (sub)expression to output tensor.
   if (at == env.getLoopNum()) {
-    Value rhs = genExp(env, rewriter, exp, at - 1);
+    Value rhs = genExp(env, rewriter, exp);
     genTensorStore(env, rewriter, exp, rhs);
     return;
   }
@@ -1235,8 +1216,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
   bool needsUniv = startLoopSeq(env, rewriter, exp, at, ldx, lts);
 
   // Emit a loop for every lattice point L0 >= Li in this loop sequence.
-  //
-  // NOTE: We cannot change this to `for (const LatPointId li : env.set(lts))`
+  // We cannot change this to `for (const LatPointId li : env.set(lts))`
   // because the loop body causes data-movement which invalidates
   // the iterator.
   const unsigned lsize = env.set(lts).size();
@@ -1251,7 +1231,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
     Value cntInput = env.getExpandCount();
     Value insInput = env.getInsertionChain();
     Value validIns = env.getValidLexInsert();
-    // NOTE: We cannot change this to `for (const LatPointId lj : env.set(lts))`
+    // We cannot change this to `for (const LatPointId lj : env.set(lts))`
     // because the loop body causes data-movement which invalidates the
     // iterator.
     for (unsigned j = 0; j < lsize; j++) {
@@ -1323,6 +1303,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     if (hasNonTrivialAffineOnSparseOut(op))
       return failure();
 
+    // Only accept scheduled loops.
     if (!op->hasAttr("sorted")) {
       return rewriter.notifyMatchFailure(
           op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
@@ -1348,9 +1329,9 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
       }
     }
 
-    CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
     // Detects sparse annotations and translates the per-level sparsity
     // information for all tensors to loop indices in the kernel.
+    CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
     if (!findSparseAnnotations(env, needIdxRed))
       return failure();
 



More information about the Mlir-commits mailing list