[Mlir-commits] [mlir] 372d88b - [mlir][sparse] code cleanup.

Peiming Liu llvmlistbot at llvm.org
Wed Aug 2 16:20:01 PDT 2023


Author: Peiming Liu
Date: 2023-08-02T23:19:55Z
New Revision: 372d88b051515f54264bf6568395bcb7f3db4de1

URL: https://github.com/llvm/llvm-project/commit/372d88b051515f54264bf6568395bcb7f3db4de1
DIFF: https://github.com/llvm/llvm-project/commit/372d88b051515f54264bf6568395bcb7f3db4de1.diff

LOG: [mlir][sparse] code cleanup.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D156941

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 315cc1d05e9266..30f4c1db3c4cd0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -184,24 +184,32 @@ class LoopEmitter {
   void exitCurrentLoop(RewriterBase &rewriter, Location loc,
                        MutableArrayRef<Value> reduc = {});
 
+  /// Get the range of values for all induction variables.
+  auto getLoopIVsRange() const {
+    return llvm::map_range(loopStack, [](const LoopInfo &li) { return li.iv; });
+  }
+
   /// Fills the out-parameter with the loop induction variables for all
   /// loops in the current loop-stack.  The variables are given in the
   /// same order as the loop-stack, hence `ivs` should be indexed into
   /// by `LoopOrd` (not `LoopId`).
-  void getLoopIVs(SmallVectorImpl<Value> &ivs) const {
-    ivs.clear();
-    ivs.reserve(getCurrentDepth());
-    for (auto &l : loopStack)
-      ivs.push_back(l.iv);
+  SmallVector<Value> getLoopIVs() const {
+    return llvm::to_vector(getLoopIVsRange());
   }
 
   /// Gets the current depth of the loop-stack.  The result is given
   /// the type `LoopOrd` for the same reason as one-past-the-end iterators.
-  LoopOrd getCurrentDepth() const { return loopStack.size(); }
+  LoopOrd getCurrentDepth() const {
+    return llvm::range_size(getLoopIVsRange());
+  }
 
   /// Gets loop induction variable for the given `LoopOrd`.
   Value getLoopIV(LoopOrd n) const {
-    return n < getCurrentDepth() ? loopStack[n].iv : Value();
+    if (n >= getCurrentDepth())
+      return Value();
+    auto it = getLoopIVsRange().begin();
+    std::advance(it, n);
+    return *it;
   }
 
   /// Gets the total number of manifest tensors (excluding the synthetic

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index ebbe88ee902948..2a290f202c70a5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1318,10 +1318,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
                                                     reduc);
     }
 
-    SmallVector<Value> lcvs;
-    lcvs.reserve(lvlRank);
-    loopEmitter.getLoopIVs(lcvs);
-
+    SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
     if (op.getOrder()) {
       // FIXME: There is some dim/lvl confusion here since `dimRank != lvlRank`
       SmallVector<Value> dcvs = lcvs; // keep a copy

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index b75cdba8449a1a..2450fd6c7d03f6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -977,17 +977,10 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
   // Direct insertion in lexicographic coordinate order.
   if (!env.isExpand()) {
     const LoopOrd numLoops = op.getRank(t);
-    // TODO: rewrite this to use `env.emitter().getLoopIVs(ivs)`
-    // instead.  We just need to either assert that `numLoops ==
-    // env.emitter().getCurrentDepth()`, or else update the `getLoopIVs`
-    // method to take an optional parameter to restrict to a smaller depth.
-    SmallVector<Value> ivs;
-    ivs.reserve(numLoops);
-    for (LoopOrd n = 0; n < numLoops; n++) {
-      const auto iv = env.emitter().getLoopIV(n);
-      assert(iv);
-      ivs.push_back(iv);
-    }
+    // Retrieves the first `numLoops` induction variables.
+    SmallVector<Value> ivs = llvm::to_vector(
+        llvm::drop_end(env.emitter().getLoopIVsRange(),
+                       env.emitter().getCurrentDepth() - numLoops));
     Value chain = env.getInsertionChain();
     if (!env.getValidLexInsert()) {
       env.updateInsertionChain(builder.create<InsertOp>(loc, rhs, chain, ivs));
@@ -1438,7 +1431,7 @@ static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
 
 /// Generates the induction structure for a while-loop.
 static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
-                            bool needsUniv, scf::WhileOp whileOp) {
+                            bool needsUniv) {
   Location loc = env.op().getLoc();
   // Finalize each else branch of all if statements.
   if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
@@ -1472,7 +1465,8 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
       builder.setInsertionPointAfter(ifOp);
     }
   }
-  builder.setInsertionPointToEnd(&whileOp.getAfter().front());
+  // No need to set the insertion point here as LoopEmitter keeps track of the
+  // basic block where scf::Yield should be inserted.
 }
 
 /// Generates a single if-statement within a while-loop.
@@ -1525,8 +1519,8 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
 
 /// Generates end of true branch of if-statement within a while-loop.
 static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
-                  Operation *loop, Value redInput, Value cntInput,
-                  Value insInput, Value validIns) {
+                  Value redInput, Value cntInput, Value insInput,
+                  Value validIns) {
   SmallVector<Value> operands;
   if (env.isReduc()) {
     operands.push_back(env.getReduc());
@@ -1800,7 +1794,7 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
       env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
   } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
     // End a while-loop.
-    finalizeWhileOp(env, rewriter, idx, needsUniv, whileOp);
+    finalizeWhileOp(env, rewriter, idx, needsUniv);
   } else {
     needsUniv = false;
   }
@@ -1875,8 +1869,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
         if (!isSingleCond) {
           scf::IfOp ifOp = genIf(env, rewriter, idx, lj);
           genStmt(env, rewriter, ej, at + 1);
-          endIf(env, rewriter, ifOp, loop, redInput, cntInput, insInput,
-                validIns);
+          endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
         } else {
           genStmt(env, rewriter, ej, at + 1);
         }


        


More information about the Mlir-commits mailing list