[Mlir-commits] [mlir] [mlir][sparse] minor refactoring of sparsification file (PR #74403)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 4 18:47:27 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse
Author: Aart Bik (aartbik)
<details>
<summary>Changes</summary>
Removed obsolete TODOs and NOTEs, cleaned up formatting, and removed an unused parameter (`ldx` in `relinkBranch` and `genExp`)
---
Full diff: https://github.com/llvm/llvm-project/pull/74403.diff
1 File Affected:
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+19-38)
``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index e0d3ce241e454..d171087f56ab1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -34,6 +34,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"
+
#include <optional>
using namespace mlir;
@@ -43,11 +44,6 @@ using namespace mlir::sparse_tensor;
// Sparsifier analysis methods.
//===----------------------------------------------------------------------===//
-// TODO: the "idx"-vs-"ldx" naming convention is not self-explanatory,
-// and those letters are too easy to confuse visually. We should switch
-// to a more self-explanatory naming convention like "curLoop"-vs-"prevLoop"
-// (assuming that's the actual meaning behind the "idx"-vs-"ldx" convention).
-
/// Determines if affine expression is invariant.
static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
bool &isAtLoop) {
@@ -56,11 +52,9 @@ static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
const LoopId i = cast<AffineDimExpr>(a).getPosition();
if (i == ldx) {
isAtLoop = true;
- // Must be invariant if we are at the given loop.
- return true;
+ return true; // invariant at given loop
}
- // The DimExpr is invariant the loop has already been generated.
- return i < loopDepth;
+ return i < loopDepth; // invariant when already generated
}
case AffineExprKind::Add:
case AffineExprKind::Mul: {
@@ -85,7 +79,6 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
if (!isUndefLT(merger.getLvlType(tid, idx)))
return false; // used more than once
-
if (setLvlFormat)
merger.setLevelAndType(tid, idx, lvl, lt);
return true;
@@ -195,7 +188,7 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
}
}
-/// Get the total number of compound affine expressions in the
+/// Gets the total number of compound affine expressions in the
/// `getMatchingIndexingMap` for the given tensor. For the following inputs:
///
/// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
@@ -225,7 +218,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
return num;
}
-/// Get the total number of sparse levels with compound affine
+/// Gets the total number of sparse levels with compound affine
/// expressions, summed over all operands of the `GenericOp`.
static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
unsigned num = 0;
@@ -235,6 +228,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
return num;
}
+// Returns true iff output has nontrivial affine indices.
static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
OpOperand *out = op.getDpsInitOperand(0);
if (getSparseTensorType(out->get()).isAllDense())
@@ -260,11 +254,9 @@ static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
const auto enc = getSparseTensorEncoding(t.get().getType());
if (enc)
annotated = true;
-
const Level lvlRank = map.getNumResults();
assert(!enc || lvlRank == enc.getLvlRank());
assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
-
// We only need to do index reduction if there is at least one non-trivial
// index expression on sparse levels.
// If all non-trivial index expression is on dense levels, we can
@@ -343,9 +335,6 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
}
/// Generates index for load/store on sparse tensor.
-// FIXME: It's not entirely clear what "index" means here (i.e., is it
-// a "coordinate", or "Ldx", or what). So the function should be renamed
-// and/or the documentation expanded in order to clarify.
static Value genIndex(CodegenEnv &env, OpOperand *t) {
const auto map = env.op().getMatchingIndexingMap(t);
const auto stt = getSparseTensorType(t->get());
@@ -495,7 +484,6 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
Value val = env.exp(exp).val;
if (val)
return val;
-
// Load during insertion.
linalg::GenericOp op = env.op();
OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
@@ -574,7 +562,7 @@ inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
/// exception of index computations, which need to be relinked to actual
/// inlined cloned code.
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
- Value e, LoopId ldx) {
+ Value e) {
if (auto arg = dyn_cast<BlockArgument>(e)) {
// Direct arguments of the original linalg op must be converted
// into dense tensor loads. Note that we should not encounter
@@ -598,7 +586,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
rewriter.updateRootInPlace(def, [&]() {
def->setOperand(
- i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
+ i, relinkBranch(env, rewriter, block, def->getOperand(i)));
});
}
}
@@ -607,8 +595,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
}
/// Recursively generates tensor expression.
-static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
- LoopId ldx) {
+static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
if (e == ::mlir::sparse_tensor::detail::kInvalidId)
return Value();
@@ -631,15 +618,15 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
// based on the type of the other operand.
if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
- v1 = genExp(env, rewriter, exp.children.e1, ldx);
+ v1 = genExp(env, rewriter, exp.children.e1);
v0 = constantZero(rewriter, loc, v1.getType());
} else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
- v0 = genExp(env, rewriter, exp.children.e0, ldx);
+ v0 = genExp(env, rewriter, exp.children.e0);
v1 = constantZero(rewriter, loc, v0.getType());
} else {
- v0 = genExp(env, rewriter, exp.children.e0, ldx);
- v1 = genExp(env, rewriter, exp.children.e1, ldx);
+ v0 = genExp(env, rewriter, exp.children.e0);
+ v1 = genExp(env, rewriter, exp.children.e1);
}
Value ee;
@@ -653,7 +640,7 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
kind == TensorExp::Kind::kReduce ||
kind == TensorExp::Kind::kSelect)) {
OpBuilder::InsertionGuard guard(rewriter);
- ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
+ ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
}
}
@@ -806,7 +793,6 @@ static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter,
const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, ldx);
return isCompressedLT(lt) || isSingletonLT(lt);
});
-
return isParallelFor(env, isOuter, isSparse);
}
@@ -1112,11 +1098,6 @@ static bool translateBitsToTidLvlPairs(
// level. We need to generate the address according to the
// affine expression. This is also the best place we can do it
// to avoid putting it inside inner loops.
- // NOTE: It assumes that the levels of the input tensor are
- // initialized in order (and it is also currently guaranteed by
- // computeIterationGraph), another more admissible approach
- // might be accepting out-of-order access between consecutive
- // dense levels.
affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
}
}
@@ -1221,7 +1202,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
LoopOrd at) {
// At each leaf, assign remaining tensor (sub)expression to output tensor.
if (at == env.getLoopNum()) {
- Value rhs = genExp(env, rewriter, exp, at - 1);
+ Value rhs = genExp(env, rewriter, exp);
genTensorStore(env, rewriter, exp, rhs);
return;
}
@@ -1235,8 +1216,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
bool needsUniv = startLoopSeq(env, rewriter, exp, at, ldx, lts);
// Emit a loop for every lattice point L0 >= Li in this loop sequence.
- //
- // NOTE: We cannot change this to `for (const LatPointId li : env.set(lts))`
+ // We cannot change this to `for (const LatPointId li : env.set(lts))`
// because the loop body causes data-movement which invalidates
// the iterator.
const unsigned lsize = env.set(lts).size();
@@ -1251,7 +1231,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
Value cntInput = env.getExpandCount();
Value insInput = env.getInsertionChain();
Value validIns = env.getValidLexInsert();
- // NOTE: We cannot change this to `for (const LatPointId lj : env.set(lts))`
+ // We cannot change this to `for (const LatPointId lj : env.set(lts))`
// because the loop body causes data-movement which invalidates the
// iterator.
for (unsigned j = 0; j < lsize; j++) {
@@ -1323,6 +1303,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
if (hasNonTrivialAffineOnSparseOut(op))
return failure();
+ // Only accept scheduled loops.
if (!op->hasAttr("sorted")) {
return rewriter.notifyMatchFailure(
op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
@@ -1348,9 +1329,9 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
}
}
- CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
// Detects sparse annotations and translates the per-level sparsity
// information for all tensors to loop indices in the kernel.
+ CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
if (!findSparseAnnotations(env, needIdxRed))
return failure();
``````````
</details>
https://github.com/llvm/llvm-project/pull/74403
More information about the Mlir-commits
mailing list