[Mlir-commits] [mlir] [mlir][sparse] cleanup of CodegenEnv reduction API (PR #75243)

Aart Bik llvmlistbot at llvm.org
Tue Dec 12 12:33:51 PST 2023


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/75243

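(No description was provided with this pull request. In summary, the patch below replaces `setValidLexInsert`/`clearValidLexInsert` with `startValidLexInsert`/`updateValidLexInsert`/`endValidLexInsert`, adds an `isValidLexInsert()` query, tightens the related assertions, and makes `getCustomRedId()` const.)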

From e274a58e30a527849ff5de6ffb0c0846fab2ecc3 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 12 Dec 2023 12:30:57 -0800
Subject: [PATCH] [mlir][sparse] cleanup of CodegenEnv reduction API

---
 .../SparseTensor/Transforms/CodegenEnv.cpp    | 27 +++++++++++--------
 .../SparseTensor/Transforms/CodegenEnv.h      |  9 ++++---
 .../Transforms/Sparsification.cpp             | 26 +++++++++---------
 3 files changed, 36 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 312aefc0936c28..4bd3af2d3f2f6a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -115,10 +115,10 @@ std::optional<Operation *> CodegenEnv::genLoopBoundary(
   SmallVector<Value> params;
   if (isReduc()) {
     params.push_back(redVal);
-    if (redValidLexInsert)
+    if (isValidLexInsert())
       params.push_back(redValidLexInsert);
   } else {
-    assert(!redValidLexInsert);
+    assert(!isValidLexInsert());
   }
   if (isExpand())
     params.push_back(expCount);
@@ -128,8 +128,8 @@ std::optional<Operation *> CodegenEnv::genLoopBoundary(
   unsigned i = 0;
   if (isReduc()) {
     updateReduc(params[i++]);
-    if (redValidLexInsert)
-      setValidLexInsert(params[i++]);
+    if (isValidLexInsert())
+      updateValidLexInsert(params[i++]);
   }
   if (isExpand())
     updateExpandCount(params[i++]);
@@ -235,14 +235,14 @@ void CodegenEnv::endExpand() {
 //===----------------------------------------------------------------------===//
 
 void CodegenEnv::startReduc(ExprId exp, Value val) {
-  assert(!isReduc() && exp != detail::kInvalidId);
+  assert(!isReduc() && exp != detail::kInvalidId && val);
   redExp = exp;
   redVal = val;
   latticeMerger.setExprValue(exp, val);
 }
 
 void CodegenEnv::updateReduc(Value val) {
-  assert(isReduc());
+  assert(isReduc() && val);
   redVal = val;
   latticeMerger.clearExprValue(redExp);
   latticeMerger.setExprValue(redExp, val);
@@ -257,13 +257,18 @@ Value CodegenEnv::endReduc() {
   return val;
 }
 
-void CodegenEnv::setValidLexInsert(Value val) {
-  assert(isReduc() && val);
+void CodegenEnv::startValidLexInsert(Value val) {
+  assert(!isValidLexInsert() && isReduc() && val);
+  redValidLexInsert = val;
+}
+
+void CodegenEnv::updateValidLexInsert(Value val) {
+  assert(redValidLexInsert && isReduc() && val);
   redValidLexInsert = val;
 }
 
-void CodegenEnv::clearValidLexInsert() {
-  assert(!isReduc());
+void CodegenEnv::endValidLexInsert() {
+  assert(isValidLexInsert() && !isReduc());
   redValidLexInsert = Value();
 }
 
@@ -272,7 +277,7 @@ void CodegenEnv::startCustomReduc(ExprId exp) {
   redCustom = exp;
 }
 
-Value CodegenEnv::getCustomRedId() {
+Value CodegenEnv::getCustomRedId() const {
   assert(isCustomReduc());
   return dyn_cast<sparse_tensor::ReduceOp>(exp(redCustom).op).getIdentity();
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index a1947f48393ef9..cd626041834b12 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -150,13 +150,16 @@ class CodegenEnv {
   void updateReduc(Value val);
   Value getReduc() const { return redVal; }
   Value endReduc();
-  void setValidLexInsert(Value val);
-  void clearValidLexInsert();
+
+  void startValidLexInsert(Value val);
+  bool isValidLexInsert() const { return redValidLexInsert != nullptr; }
+  void updateValidLexInsert(Value val);
   Value getValidLexInsert() const { return redValidLexInsert; }
+  void endValidLexInsert();
 
   void startCustomReduc(ExprId exp);
   bool isCustomReduc() const { return redCustom != detail::kInvalidId; }
-  Value getCustomRedId();
+  Value getCustomRedId() const;
   void endCustomReduc();
 
 private:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 992be434fc6231..2367d3b5f37ade 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -415,9 +415,7 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
     SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end(
         env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops));
     Value chain = env.getInsertionChain();
-    if (!env.getValidLexInsert()) {
-      env.updateInsertionChain(builder.create<InsertOp>(loc, rhs, chain, ivs));
-    } else {
+    if (env.isValidLexInsert()) {
       // Generates runtime check for a valid lex during reduction,
       // to avoid inserting the identity value for empty reductions.
       //   if (validLexInsert) then
@@ -438,6 +436,9 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
       // Value assignment.
       builder.setInsertionPointAfter(ifValidLexInsert);
       env.updateInsertionChain(ifValidLexInsert.getResult(0));
+    } else {
+      // Generates regular insertion chain.
+      env.updateInsertionChain(builder.create<InsertOp>(loc, rhs, chain, ivs));
     }
     return;
   }
@@ -688,12 +689,13 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
           env.startReduc(exp, genTensorLoad(env, builder, exp));
         }
         if (env.hasSparseOutput())
-          env.setValidLexInsert(constantI1(builder, env.op().getLoc(), false));
+          env.startValidLexInsert(
+              constantI1(builder, env.op().getLoc(), false));
       } else {
         if (!env.isCustomReduc() || env.isReduc())
           genTensorStore(env, builder, exp, env.endReduc());
         if (env.hasSparseOutput())
-          env.clearValidLexInsert();
+          env.endValidLexInsert();
       }
     } else {
       // Start or end loop invariant hoisting of a tensor load.
@@ -846,9 +848,9 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
       if (env.isReduc()) {
         yields.push_back(env.getReduc());
         env.updateReduc(ifOp.getResult(y++));
-        if (env.getValidLexInsert()) {
+        if (env.isValidLexInsert()) {
           yields.push_back(env.getValidLexInsert());
-          env.setValidLexInsert(ifOp.getResult(y++));
+          env.updateValidLexInsert(ifOp.getResult(y++));
         }
       }
       if (env.isExpand()) {
@@ -904,7 +906,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
       });
   if (env.isReduc()) {
     types.push_back(env.getReduc().getType());
-    if (env.getValidLexInsert())
+    if (env.isValidLexInsert())
       types.push_back(env.getValidLexInsert().getType());
   }
   if (env.isExpand())
@@ -924,10 +926,10 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
   if (env.isReduc()) {
     operands.push_back(env.getReduc());
     env.updateReduc(redInput);
-    if (env.getValidLexInsert()) {
+    if (env.isValidLexInsert()) {
       // Any overlapping indices during a reduction creates a valid lex insert.
       operands.push_back(constantI1(builder, env.op().getLoc(), true));
-      env.setValidLexInsert(validIns);
+      env.updateValidLexInsert(validIns);
     }
   }
   if (env.isExpand()) {
@@ -1174,8 +1176,8 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
   // Either a for-loop or a while-loop that iterates over a slice.
   if (isSingleCond) {
     // Any iteration creates a valid lex insert.
-    if (env.isReduc() && env.getValidLexInsert())
-      env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
+    if (env.isReduc() && env.isValidLexInsert())
+      env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
   } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
     // End a while-loop.
     finalizeWhileOp(env, rewriter, needsUniv);



More information about the Mlir-commits mailing list