[Mlir-commits] [mlir] [mlir][sparse] cleanup of CodegenEnv reduction API (PR #75243)
Aart Bik
llvmlistbot at llvm.org
Tue Dec 12 12:33:51 PST 2023
https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/75243
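(No PR description was provided.) Summary of the patch below: renames CodegenEnv::setValidLexInsert/clearValidLexInsert to startValidLexInsert/updateValidLexInsert/endValidLexInsert, adds an isValidLexInsert() query used in place of null checks on getValidLexInsert(), adds non-null assertions on the value passed to startReduc/updateReduc, makes getCustomRedId() const, and reorders the branches in genInsertionStore so the valid-lex-insert case comes first.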
>From e274a58e30a527849ff5de6ffb0c0846fab2ecc3 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 12 Dec 2023 12:30:57 -0800
Subject: [PATCH] [mlir][sparse] cleanup of CodegenEnv reduction API
---
.../SparseTensor/Transforms/CodegenEnv.cpp | 27 +++++++++++--------
.../SparseTensor/Transforms/CodegenEnv.h | 9 ++++---
.../Transforms/Sparsification.cpp | 26 +++++++++---------
3 files changed, 36 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 312aefc0936c28..4bd3af2d3f2f6a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -115,10 +115,10 @@ std::optional<Operation *> CodegenEnv::genLoopBoundary(
SmallVector<Value> params;
if (isReduc()) {
params.push_back(redVal);
- if (redValidLexInsert)
+ if (isValidLexInsert())
params.push_back(redValidLexInsert);
} else {
- assert(!redValidLexInsert);
+ assert(!isValidLexInsert());
}
if (isExpand())
params.push_back(expCount);
@@ -128,8 +128,8 @@ std::optional<Operation *> CodegenEnv::genLoopBoundary(
unsigned i = 0;
if (isReduc()) {
updateReduc(params[i++]);
- if (redValidLexInsert)
- setValidLexInsert(params[i++]);
+ if (isValidLexInsert())
+ updateValidLexInsert(params[i++]);
}
if (isExpand())
updateExpandCount(params[i++]);
@@ -235,14 +235,14 @@ void CodegenEnv::endExpand() {
//===----------------------------------------------------------------------===//
void CodegenEnv::startReduc(ExprId exp, Value val) {
- assert(!isReduc() && exp != detail::kInvalidId);
+ assert(!isReduc() && exp != detail::kInvalidId && val);
redExp = exp;
redVal = val;
latticeMerger.setExprValue(exp, val);
}
void CodegenEnv::updateReduc(Value val) {
- assert(isReduc());
+ assert(isReduc() && val);
redVal = val;
latticeMerger.clearExprValue(redExp);
latticeMerger.setExprValue(redExp, val);
@@ -257,13 +257,18 @@ Value CodegenEnv::endReduc() {
return val;
}
-void CodegenEnv::setValidLexInsert(Value val) {
- assert(isReduc() && val);
+void CodegenEnv::startValidLexInsert(Value val) {
+ assert(!isValidLexInsert() && isReduc() && val);
+ redValidLexInsert = val;
+}
+
+void CodegenEnv::updateValidLexInsert(Value val) {
+ assert(redValidLexInsert && isReduc() && val);
redValidLexInsert = val;
}
-void CodegenEnv::clearValidLexInsert() {
- assert(!isReduc());
+void CodegenEnv::endValidLexInsert() {
+ assert(isValidLexInsert() && !isReduc());
redValidLexInsert = Value();
}
@@ -272,7 +277,7 @@ void CodegenEnv::startCustomReduc(ExprId exp) {
redCustom = exp;
}
-Value CodegenEnv::getCustomRedId() {
+Value CodegenEnv::getCustomRedId() const {
assert(isCustomReduc());
return dyn_cast<sparse_tensor::ReduceOp>(exp(redCustom).op).getIdentity();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index a1947f48393ef9..cd626041834b12 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -150,13 +150,16 @@ class CodegenEnv {
void updateReduc(Value val);
Value getReduc() const { return redVal; }
Value endReduc();
- void setValidLexInsert(Value val);
- void clearValidLexInsert();
+
+ void startValidLexInsert(Value val);
+ bool isValidLexInsert() const { return redValidLexInsert != nullptr; }
+ void updateValidLexInsert(Value val);
Value getValidLexInsert() const { return redValidLexInsert; }
+ void endValidLexInsert();
void startCustomReduc(ExprId exp);
bool isCustomReduc() const { return redCustom != detail::kInvalidId; }
- Value getCustomRedId();
+ Value getCustomRedId() const;
void endCustomReduc();
private:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 992be434fc6231..2367d3b5f37ade 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -415,9 +415,7 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end(
env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops));
Value chain = env.getInsertionChain();
- if (!env.getValidLexInsert()) {
- env.updateInsertionChain(builder.create<InsertOp>(loc, rhs, chain, ivs));
- } else {
+ if (env.isValidLexInsert()) {
// Generates runtime check for a valid lex during reduction,
// to avoid inserting the identity value for empty reductions.
// if (validLexInsert) then
@@ -438,6 +436,9 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
// Value assignment.
builder.setInsertionPointAfter(ifValidLexInsert);
env.updateInsertionChain(ifValidLexInsert.getResult(0));
+ } else {
+ // Generates regular insertion chain.
+ env.updateInsertionChain(builder.create<InsertOp>(loc, rhs, chain, ivs));
}
return;
}
@@ -688,12 +689,13 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
env.startReduc(exp, genTensorLoad(env, builder, exp));
}
if (env.hasSparseOutput())
- env.setValidLexInsert(constantI1(builder, env.op().getLoc(), false));
+ env.startValidLexInsert(
+ constantI1(builder, env.op().getLoc(), false));
} else {
if (!env.isCustomReduc() || env.isReduc())
genTensorStore(env, builder, exp, env.endReduc());
if (env.hasSparseOutput())
- env.clearValidLexInsert();
+ env.endValidLexInsert();
}
} else {
// Start or end loop invariant hoisting of a tensor load.
@@ -846,9 +848,9 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
if (env.isReduc()) {
yields.push_back(env.getReduc());
env.updateReduc(ifOp.getResult(y++));
- if (env.getValidLexInsert()) {
+ if (env.isValidLexInsert()) {
yields.push_back(env.getValidLexInsert());
- env.setValidLexInsert(ifOp.getResult(y++));
+ env.updateValidLexInsert(ifOp.getResult(y++));
}
}
if (env.isExpand()) {
@@ -904,7 +906,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
});
if (env.isReduc()) {
types.push_back(env.getReduc().getType());
- if (env.getValidLexInsert())
+ if (env.isValidLexInsert())
types.push_back(env.getValidLexInsert().getType());
}
if (env.isExpand())
@@ -924,10 +926,10 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
if (env.isReduc()) {
operands.push_back(env.getReduc());
env.updateReduc(redInput);
- if (env.getValidLexInsert()) {
+ if (env.isValidLexInsert()) {
// Any overlapping indices during a reduction creates a valid lex insert.
operands.push_back(constantI1(builder, env.op().getLoc(), true));
- env.setValidLexInsert(validIns);
+ env.updateValidLexInsert(validIns);
}
}
if (env.isExpand()) {
@@ -1174,8 +1176,8 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
// Either a for-loop or a while-loop that iterates over a slice.
if (isSingleCond) {
// Any iteration creates a valid lex insert.
- if (env.isReduc() && env.getValidLexInsert())
- env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
+ if (env.isReduc() && env.isValidLexInsert())
+ env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
} else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
// End a while-loop.
finalizeWhileOp(env, rewriter, needsUniv);
More information about the Mlir-commits mailing list