[Mlir-commits] [mlir] 1328bb6 - [mlir][sparse] extend loop emitter and optimize lattices with the awareness of slice based iteration
Peiming Liu
llvmlistbot at llvm.org
Mon Mar 20 15:20:02 PDT 2023
Author: Peiming Liu
Date: 2023-03-20T22:19:57Z
New Revision: 1328bb6ef1645951606ee3e8fa6acbbff6b2438f
URL: https://github.com/llvm/llvm-project/commit/1328bb6ef1645951606ee3e8fa6acbbff6b2438f
DIFF: https://github.com/llvm/llvm-project/commit/1328bb6ef1645951606ee3e8fa6acbbff6b2438f.diff
LOG: [mlir][sparse] extend loop emitter and optimize lattices with the awareness of slice based iteration
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D142929
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 0e6c2f1553f1c..4a83237fb1634 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -399,11 +399,17 @@ class Merger {
/// to sparse level-type.
bool hasAnySparse(const BitVector &bits) const;
+ /// Returns true if bits contains a dependent index reduction condition on
+ /// sparse levels.
+ bool hasSparseIdxReduction(const BitVector &bits) const;
+
/// Gets the level-type of the `t`th tensor on `i`th loop.
DimLevelType getDimLevelType(TensorId t, LoopId i) const {
assert(t < numTensors && i < numLoops);
return lvlTypes[t][i];
}
+
+ /// Gets the level-type of the TensorLoopId.
DimLevelType getDimLevelType(TensorLoopId b) const {
return getDimLevelType(tensor(b), loop(b));
}
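
As a quick orientation before the implementation (which appears in Merger.cpp
at the end of this patch): the existing hasAnySparse predicate keys on level
types (compressed/singleton), while the new predicate keys on whether a set
bit carries a non-trivial index expression. A hedged sketch, with a
hypothetical bit layout:

  // Suppose bit 2 corresponds to a level accessed through a compound affine
  // expression (e.g. d0 + d1) on a sparse tensor.
  llvm::BitVector bits(3);
  bits.set(2);
  bool idxReduc = merger.hasSparseIdxReduction(bits); // true: bit 2 carries
                                                      // an index reduction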
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 8e4904ad3a592..f326d5b950a31 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -28,6 +28,23 @@ static bool isMaterializing(Value val) {
val.getDefiningOp<bufferization::AllocTensorOp>();
}
+/// Sorts the target array's elements according to the `order` array.
+static void sortArrayBasedOnOrder(std::vector<LoopId> &target,
+ ArrayRef<LoopId> order) {
+ std::sort(target.begin(), target.end(), [&order](LoopId l, LoopId r) {
+ assert(l != r);
+ int idxL = -1, idxR = -1;
+ for (int i = 0, e = order.size(); i < e; i++) {
+ if (order[i] == l)
+ idxL = i;
+ if (order[i] == r)
+ idxR = i;
+ }
+ assert(idxL >= 0 && idxR >= 0);
+ return idxL < idxR;
+ });
+}
+
//===----------------------------------------------------------------------===//
// Code generation environment constructor and general methods
//===----------------------------------------------------------------------===//
@@ -57,15 +74,42 @@ void CodegenEnv::startEmit() {
insChain = sparseOut->get();
latticeMerger.setHasSparseOut(true);
}
+
+ // Sort the related loop arrays such that they appear in the same order as
+ // they do in the topological order (topoOrder).
+ // TODO: since we only handle affine addition for slice-based codegen, and
+ // addition is associative, the order in which we evaluate the expression
+ // does not matter. However, to support multiplication, the order of the
+ // loop indices should match the evaluation order of the affine expression
+ // AST.
+
// Initialize loop emitter.
- SmallVector<Value> tensors;
- for (OpOperand &t : linalgOp->getOpOperands())
+ SmallVector<Value> tensors; // input tensors passed to loop emitter
+ for (OpOperand &t : linalgOp->getOpOperands()) {
tensors.push_back(t.get());
- loopEmitter.initialize(tensors,
- StringAttr::get(linalgOp.getContext(),
- linalg::GenericOp::getOperationName()),
- /*hasOutput=*/true,
- /*isSparseOut=*/sparseOut != nullptr, topSort);
+ Level rank = linalgOp.getMatchingIndexingMap(&t).getNumResults();
+ for (Level lvl = 0; lvl < rank; lvl++) {
+ sortArrayBasedOnOrder(
+ latticeMerger.getDependentLoops(t.getOperandNumber(), lvl), topSort);
+ }
+ }
+
+ loopEmitter.initialize(
+ tensors,
+ StringAttr::get(linalgOp.getContext(),
+ linalg::GenericOp::getOperationName()),
+ /*hasOutput=*/true,
+ /*isSparseOut=*/sparseOut != nullptr, topSort,
+ // TODO: compute the map and pass it to loop emitter directly instead of
+ // passing in a callback.
+ [this](TensorId t, Level lvl) -> std::vector<std::pair<TensorId, Level>> {
+ // Translates a list of loop indices to a list of [tid, dim] pairs.
+ std::vector<LoopId> rLoops = this->merger().getDependentLoops(t, lvl);
+ std::vector<std::pair<TensorId, Level>> ret;
+ ret.reserve(rLoops.size());
+ for (LoopId l : rLoops)
+ ret.emplace_back(this->merger().getLoopDefiningLvl(l));
+ return ret;
+ });
}
std::optional<Operation *> CodegenEnv::genLoopBoundary(
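
To make the comparator concrete, here is a self-contained sketch of the
helper's behavior (LoopId is assumed to be an unsigned alias, matching the
sparse tensor headers; the values are hypothetical):

  #include <algorithm>
  #include <cassert>
  #include <vector>

  using LoopId = unsigned;

  int main() {
    // Dependent loops {2, 0} sorted against topological order {1, 0, 2}:
    // loop 0 precedes loop 2 in `order`, so the result is {0, 2}.
    std::vector<LoopId> target = {2, 0};
    std::vector<LoopId> order = {1, 0, 2};
    std::sort(target.begin(), target.end(), [&order](LoopId l, LoopId r) {
      int idxL = -1, idxR = -1;
      for (int i = 0, e = order.size(); i < e; i++) {
        if (order[i] == l)
          idxL = i;
        if (order[i] == r)
          idxR = i;
      }
      return idxL < idxR; // earlier position in `order` sorts first
    });
    assert(target[0] == 0 && target[1] == 2);
    return 0;
  }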
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index 776d7f7f47ece..8c6a7bd6433db 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -99,7 +99,6 @@ class CodegenEnv {
topSort.reserve(capacity);
}
- ArrayRef<LoopId> getTopSort() const { return topSort; };
ArrayRef<LoopId> getTopSortSlice(LoopOrd n, LoopOrd m) const;
ArrayRef<LoopId> getLoopStackUpTo(LoopOrd n) const;
ArrayRef<LoopId> getCurrentLoopStack() const;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index c3823c0f204d9..459a1b38e03de 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -208,12 +208,14 @@ Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
}
LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
- bool isSparseOut, ArrayRef<LoopId> topSort) {
- initialize(tensors, loopTag, hasOutput, isSparseOut, topSort);
+ bool isSparseOut, ArrayRef<LoopId> topSort,
+ DependentLvlGetter dimGetter) {
+ initialize(tensors, loopTag, hasOutput, isSparseOut, topSort, dimGetter);
}
void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
- bool isSparseOut, ArrayRef<LoopId> topSort) {
+ bool isSparseOut, ArrayRef<LoopId> topSort,
+ DependentLvlGetter dimGetter) {
// First initialize the top-level type of the fields.
this->loopTag = loopTag;
this->hasOutput = hasOutput;
@@ -242,6 +244,9 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
this->loopStack.reserve(numLoops);
this->loopSeqStack.reserve(numLoops);
+ this->dependentLvlMap.assign(
+ numTensors, std::vector<std::vector<std::pair<TensorId, Level>>>());
+
// Initialize nested types of `TensorId`-indexed fields.
for (TensorId tid = 0; tid < numTensors; tid++) {
const Value t = tensors[tid];
@@ -283,6 +288,12 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
coordinatesBuffers[tid].assign(lvlRank, Value());
sliceOffsets[tid].assign(lvlRank, Value());
sliceStrides[tid].assign(lvlRank, Value());
+
+ dependentLvlMap[tid].assign(lvlRank,
+ std::vector<std::pair<TensorId, Level>>());
+ if (dimGetter)
+ for (Level l = 0; l < lvlRank; l++)
+ dependentLvlMap[tid][l] = dimGetter(tid, l);
}
// Construct the inverse of the `topSort` from the sparsifier.
@@ -997,8 +1008,8 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
}
}
-void LoopEmitter::exitCoIterationLoop(OpBuilder &builder, Location loc,
- MutableArrayRef<Value> reduc) {
+void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
+ MutableArrayRef<Value> reduc) {
const LoopInfo &loopInfo = loopStack.back();
auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
builder.setInsertionPointToEnd(loopInfo.userCodeBlock);
@@ -1082,7 +1093,7 @@ void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
assert(loopInfo.tids.size() == loopInfo.lvls.size());
SmallVector<Value> red;
if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
- exitCoIterationLoop(rewriter, loc, reduc);
+ exitWhileLoop(rewriter, loc, reduc);
} else {
exitForLoop(rewriter, loc, reduc);
}
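
For orientation, the shape of the new per-tensor map (values below are
hypothetical; an empty inner list means the level has a trivial index):

  // dependentLvlMap[tid][lvl] holds the [tid, level] pairs defining the loops
  // in lvl's compound affine index. E.g., if input tensor 0 is indexed by
  // (d0 + d1) at level 0:
  //   dependentLvlMap[0][0].size() == 2  // the two defining levels
  //   dependentLvlMap[0][1].empty()      // trivial index at level 1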
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 8e6c65fd96c92..8cfe00100eba8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -76,6 +76,14 @@ class LoopEmitter {
/// initializing the loop emitter (e.g., to fill a dense output with zeros).
using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
Value memref, Value tensor)>;
+ // Map from [tid, dim] to a list of dependent [tid, dim] pairs used to
+ // evaluate an affine index expression on sparse tensors.
+ // E.g., for the affine index (d0 + d1), it depends on the two [tid, dim]
+ // pairs that define d0 and d1 (for affine expression reduction).
+ // If the list is empty, there is no affine expression on the input
+ // [tid, dim].
+ using DependentLvlGetter =
+ function_ref<std::vector<std::pair<TensorId, Level>>(TensorId, Level)>;
LoopEmitter() = default;
@@ -89,11 +97,13 @@ class LoopEmitter {
/// to `LoopId`.
void initialize(ValueRange tensors, StringAttr loopTag = nullptr,
bool hasOutput = false, bool isSparseOut = false,
- ArrayRef<LoopId> topSort = {});
+ ArrayRef<LoopId> topSort = {},
+ DependentLvlGetter getter = nullptr);
explicit LoopEmitter(ValueRange tensors, StringAttr loopTag = nullptr,
bool hasOutput = false, bool isSparseOut = false,
- ArrayRef<LoopId> topSort = {});
+ ArrayRef<LoopId> topSort = {},
+ DependentLvlGetter getter = nullptr);
/// Starts a loop emitting session by generating all the buffers needed
/// for iterating over the tensors.
@@ -295,8 +305,8 @@ class LoopEmitter {
MutableArrayRef<Value> reduc);
/// Exits a while loop, returns the reduction results.
- void exitCoIterationLoop(OpBuilder &builder, Location loc,
- MutableArrayRef<Value> reduc);
+ void exitWhileLoop(OpBuilder &builder, Location loc,
+ MutableArrayRef<Value> reduc);
//
// View-based-reshape methods.
@@ -380,6 +390,15 @@ class LoopEmitter {
std::vector<std::vector<Value>> sliceOffsets;
std::vector<std::vector<Value>> sliceStrides;
+ // Map from [tid, level] to a list of dependent [tid, level] pairs.
+ // See the comments for `DependentLvlGetter`.
+ std::vector<std::vector<std::vector<std::pair<TensorId, Level>>>>
+ dependentLvlMap;
+
+ //
+ // View-based reshape-related fields and methods
+ //
+
/// Collapse Reassociations related to a specific tensor
// TODO: support expand.
std::vector<ArrayAttr> collapseReassoc;
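
A hedged sketch of what a caller could pass for the new parameter (the tensor
and level numbers are hypothetical; in this patch, CodegenEnv::startEmit
builds the real callback from the Merger, as shown earlier):

  // For a kernel whose input tensor 0 is indexed by (d0 + d1) at level 0,
  // report the [tid, level] pairs defining d0 and d1; an empty list marks a
  // trivial index.
  auto getter = [](TensorId t,
                   Level l) -> std::vector<std::pair<TensorId, Level>> {
    if (t == 0 && l == 0)
      return {{1, 0}, {2, 0}}; // hypothetical levels defining d0 and d1
    return {};
  };
  LoopEmitter emitter(tensors, /*loopTag=*/nullptr, /*hasOutput=*/true,
                      /*isSparseOut=*/false, topSort, getter);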
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index d7ce2b7f63f5c..f189b14c60c7e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -593,23 +593,6 @@ static void tryRelaxAffineConstraints(linalg::GenericOp op,
}
}
-/// Makes target array's elements appear in the same order as the `order` array.
-static void sortArrayBasedOnOrder(std::vector<LoopId> &target,
- ArrayRef<LoopId> order) {
- std::sort(target.begin(), target.end(), [&order](LoopId l, LoopId r) {
- assert(l != r);
- int idxL = -1, idxR = -1;
- for (int i = 0, e = order.size(); i < e; i++) {
- if (order[i] == l)
- idxL = i;
- if (order[i] == r)
- idxR = i;
- }
- assert(idxL >= 0 && idxR >= 0);
- return idxL < idxR;
- });
-}
-
static void addFilterLoopBasedConstraints(CodegenEnv &env, OpOperand &t,
OpOperand *skip, SortMask mask,
std::vector<std::vector<bool>> &adjM,
@@ -1484,9 +1467,10 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
SmallVector<Level> lvls;
env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
std::optional<Level> lvl,
- DimLevelType dlt, bool) {
+ DimLevelType dlt, bool isIdxReduc) {
assert(env.merger().loop(b) == idx);
- if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
+ // FIXME: Dense index reduction can reuse the universal index as well.
+ if (!isIdxReduc && (isDenseDLT(dlt) || isUndefDLT(dlt))) {
needsUniv = true;
} else {
// sparse/singleton levels.
@@ -1503,7 +1487,8 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
unsigned lsize = env.set(lts).size();
for (unsigned i = 1; i < lsize; i++) {
const LatPointId li = env.set(lts)[i];
- if (!env.merger().hasAnySparse(env.lat(li).simple))
+ if (!env.merger().hasAnySparse(env.lat(li).simple) &&
+ !env.merger().hasSparseIdxReduction(env.lat(li).simple))
return true;
}
}
@@ -1557,75 +1542,82 @@ static bool translateBitsToTidLvlPairs(
unsigned numloopCond = 0;
bool hasNonUnique = false;
- env.merger().foreachTensorLoopId(li, [&, ldx](TensorLoopId b, TensorId tid,
- std::optional<Level> lvl,
- DimLevelType dlt, bool) {
- if (simple.test(b)) {
- if (isUndefDLT(dlt)) {
- // An undefined dlt in the lattices, we probably mean to
- // iterate based on the level of output tensor. E.g., this
- // could be a synthetic tensor (for invariants and sparse
- // output tensor).
- // out[i][j] = invariant; or a broadcast
- // out[i][j] = in[i] (j is undef for input)
- tid = outTid;
- lvl = outLvl;
- // Skips invalid lvl (e.g., when this is a zero ranked tensor).
- if (!lvl)
- return;
- }
- hasNonUnique = !isUniqueDLT(dlt) || hasNonUnique;
- tids.push_back(tid);
- lvls.push_back(*lvl);
- numloopCond++;
- } else if (isDenseDLT(dlt)) {
- tids.push_back(tid);
- lvls.push_back(*lvl);
- } else {
- assert(isUndefDLT(dlt));
- linalg::GenericOp op = env.op();
- if (tid >= op.getNumDpsInputs())
- // We only handle affine expression on input tensors (for now).
- return;
- OpOperand *operand = &op->getOpOperand(tid);
- const auto stt = getSparseTensorType(operand->get());
- // Non-annotated dense tensors requires no special handling.
- if (!stt.hasEncoding())
- return;
-
- ArrayRef<AffineExpr> affines =
- op.getMatchingIndexingMap(operand).getResults();
- const Level lvlRank = stt.getLvlRank();
- assert(affines.size() == static_cast<size_t>(lvlRank));
- for (Level l = 0; l < lvlRank; l++) {
- // FIXME: `toOrigDim` is deprecated.
- AffineExpr exp = affines[toOrigDim(stt.getEncoding(), l)];
- // Skip simple affine expression and non-dense levels (which
- // have their own filter loop).
- if (exp.isa<AffineDimExpr>() || !stt.isDenseLvl(l))
- continue;
- // Constant affine expression are handled in genLoop
- if (!exp.isa<AffineConstantExpr>()) {
- bool isAtLoop = false;
- if (isInvariantAffine(env, exp, ldx, isAtLoop) && isAtLoop) {
- // If the compound affine is invariant and we are right at the
- // level. We need to generate the address according to the
- // affine expression. This is also the best place we can do it
- // to avoid putting it inside inner loops.
- // NOTE: It assumes that the levels of the input tensor are
- // initialized in order (and it is also currently guaranteed by
- // computeIterationGraph), another more admissible approach
- // might be accepting out-of-order access between consecutive
- // dense levels.
- affineTids.push_back(tid);
- affineLvls.push_back(l);
- exps.push_back(exp);
+ env.merger().foreachTensorLoopId(
+ li, [&, ldx](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
+ DimLevelType dlt, bool isIdxReduc) {
+ if (simple.test(b)) {
+ if (isIdxReduc) {
+ tids.push_back(tid);
+ lvls.push_back(*lvl);
+ numloopCond++;
+ return;
+ }
+ if (isUndefDLT(dlt)) {
+ // An undefined dlt in the lattices, we probably mean to
+ // iterate based on the level of output tensor. E.g., this
+ // could be a synthetic tensor (for invariants and sparse
+ // output tensor).
+ // out[i][j] = invariant; or a broadcast
+ // out[i][j] = in[i] (j is undef for input)
+ tid = outTid;
+ lvl = outLvl;
+ // Skips invalid lvl (e.g., when this is a zero ranked tensor).
+ if (!lvl)
+ return;
+ }
+ hasNonUnique = !isUniqueDLT(dlt) || hasNonUnique;
+ tids.push_back(tid);
+ lvls.push_back(*lvl);
+ numloopCond++;
+ } else if (isDenseDLT(dlt)) {
+ tids.push_back(tid);
+ lvls.push_back(*lvl);
+ } else {
+ assert(isUndefDLT(dlt));
+ linalg::GenericOp op = env.op();
+ if (tid >= op.getNumDpsInputs())
+ // We only handle affine expression on input tensors (for now).
+ return;
+ OpOperand *operand = &op->getOpOperand(tid);
+ const auto stt = getSparseTensorType(operand->get());
+ // Non-annotated dense tensors requires no special handling.
+ if (!stt.hasEncoding())
+ return;
+
+ ArrayRef<AffineExpr> affines =
+ op.getMatchingIndexingMap(operand).getResults();
+ const Level lvlRank = stt.getLvlRank();
+ assert(affines.size() == static_cast<size_t>(lvlRank));
+ for (Level l = 0; l < lvlRank; l++) {
+ // FIXME: `toOrigDim` is deprecated.
+ AffineExpr exp = affines[toOrigDim(stt.getEncoding(), l)];
+ // Skip simple affine expression and non-dense levels (which
+ // have their own filter loop).
+ if (exp.isa<AffineDimExpr>() || !stt.isDenseLvl(l))
+ continue;
+
+ // Constant affine expression are handled in genLoop
+ if (!exp.isa<AffineConstantExpr>()) {
+ bool isAtLoop = false;
+ if (isInvariantAffine(env, exp, ldx, isAtLoop) && isAtLoop) {
+ // If the compound affine is invariant and we are right at the
+ // level. We need to generate the address according to the
+ // affine expression. This is also the best place we can do it
+ // to avoid putting it inside inner loops.
+ // NOTE: It assumes that the levels of the input tensor are
+ // initialized in order (and it is also currently guaranteed by
+ // computeIterationGraph), another more admissible approach
+ // might be accepting out-of-order access between consecutive
+ // dense levels.
+ affineTids.push_back(tid);
+ affineLvls.push_back(l);
+ exps.push_back(exp);
+ }
+ }
}
}
- }
- }
- });
+ });
if (isDenseDLT(env.dlt(outTid, ldx))) {
// Note that we generate dense indices of the output tensor
@@ -1642,8 +1634,9 @@ static bool translateBitsToTidLvlPairs(
}
/// Starts a single loop in current sequence.
-static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
- LatPointId li, bool needsUniv) {
+static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
+ OpBuilder &builder, unsigned at,
+ unsigned li, bool needsUniv) {
// The set of tensors + lvls to generate loops on
SmallVector<TensorId> tids, affineTids;
SmallVector<Level> lvls, affineLvls;
@@ -1651,11 +1644,12 @@ static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
// becomes invariant and the address shall now be generated at the current
// level.
SmallVector<AffineExpr> affines;
- bool isFor = translateBitsToTidLvlPairs(
+ bool isSingleCond = translateBitsToTidLvlPairs(
env, li, env.topSortAt(at), tids, lvls, affineTids, affineLvls, affines);
// Emit the for/while-loop control.
- Operation *loop = genLoop(env, builder, at, needsUniv, tids, lvls, isFor);
+ Operation *loop =
+ genLoop(env, builder, at, needsUniv, tids, lvls, isSingleCond);
Location loc = env.op().getLoc();
for (auto [tid, lvl, exp] : llvm::zip(affineTids, affineLvls, affines)) {
env.emitter().genDenseAffineAddress(builder, loc, tid, lvl, exp);
@@ -1671,7 +1665,7 @@ static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
}
- return loop;
+ return std::make_pair(loop, isSingleCond);
}
/// Ends a single loop in current sequence. Returns new values for needsUniv.
@@ -1734,20 +1728,19 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
for (unsigned i = 0; i < lsize; i++) {
// Start a loop.
const LatPointId li = env.set(lts)[i];
- Operation *loop = startLoop(env, rewriter, at, li, needsUniv);
+ auto [loop, isSingleCond] = startLoop(env, rewriter, at, li, needsUniv);
// Visit all lattices points with Li >= Lj to generate the
// loop-body, possibly with if statements for coiteration.
Value redInput = env.getReduc();
Value cntInput = env.getExpandCount();
Value insInput = env.getInsertionChain();
- bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
for (unsigned j = 0; j < lsize; j++) {
const LatPointId lj = env.set(lts)[j];
const ExprId ej = env.lat(lj).exp;
if (li == lj || env.merger().latGT(li, lj)) {
// Recurse into body of each branch.
- if (isWhile) {
+ if (!isSingleCond) {
scf::IfOp ifOp = genIf(env, rewriter, idx, env.lat(lj).simple);
genStmt(env, rewriter, ej, at + 1);
endIf(env, rewriter, ifOp, loop, redInput, cntInput, insInput);
@@ -1866,18 +1859,12 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
if (!isAdmissible)
return failure(); // inadmissible expression, reject
- for (OpOperand &t : env.op()->getOpOperands()) {
- Level rank = env.op().getMatchingIndexingMap(&t).getNumResults();
- for (Level lvl = 0; lvl < rank; lvl++) {
- sortArrayBasedOnOrder(
- env.merger().getDependentLoops(t.getOperandNumber(), lvl),
- env.getTopSort());
- }
- }
-
// Recursively generates code if admissible.
env.startEmit();
genBuffers(env, rewriter);
+ // TODO: Constant affine expressions should be handled differently when
+ // using slice-based codegen; it does not matter now because we already
+ // reject the constant expression at an earlier stage.
genInitConstantDenseAddress(env, rewriter);
genStmt(env, rewriter, env.getExprId(), 0);
genResult(env, rewriter);
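
One subtlety in the genStmt change above: the if-guard decision used to key
on the loop kind (any scf.while got guards), but with slice-driven iteration
a single-condition loop may itself lower to an scf.while, so the decision now
keys on whether the lattice point carries a single condition. Schematically
(a paraphrase of the diff, not a drop-in):

  // Before: guards for every while loop.
  //   if (isWhile)       { genIf(...); genStmt(...); endIf(...); }
  // After: guards only under true co-iteration, independent of whether the
  // single-condition loop lowers to scf.for or scf.while.
  //   if (!isSingleCond) { genIf(...); genStmt(...); endIf(...); }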
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 7f4400188cf14..40db5411132b4 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -362,7 +362,8 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
}
BitVector simple(latPoints[p0].bits);
- bool reset = isSingleton && hasAnySparse(simple);
+ bool reset =
+ isSingleton && (hasAnySparse(simple) || hasSparseIdxReduction(simple));
const TensorLoopId be = simple.size();
TensorLoopId offset = 0; // relative to the end
if (!reset)
@@ -379,7 +380,9 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
// keep the rightmost bit (which could possibly be a synthetic tensor).
for (TensorLoopId b = be - 1 - offset, i = 0; i < be;
b = b == 0 ? be - 1 : b - 1, i++) {
- if (simple[b]) {
+ // FIXME: better name? A slice on a dense level has the locate property
+ // as well; handle it correctly!
+ if (simple[b] && !isLvlWithNonTrivialIdxExp(b)) {
const auto dlt = getDimLevelType(b);
if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt)) {
if (reset)
@@ -407,7 +410,7 @@ bool Merger::latGT(LatPointId i, LatPointId j) const {
bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
BitVector tmp(latPoints[j].bits);
tmp ^= latPoints[i].bits;
- return !hasAnySparse(tmp);
+ return !hasAnySparse(tmp) && !hasSparseIdxReduction(tmp);
}
bool Merger::expContainsTensor(ExprId e, TensorId t) const {
@@ -555,6 +558,14 @@ bool Merger::hasAnySparse(const BitVector &bits) const {
return false;
}
+bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
+ // TODO: return false on dense levels.
+ for (unsigned b = 0, be = bits.size(); b < be; b++)
+ if (bits[b] && isLvlWithNonTrivialIdxExp(b))
+ return true;
+ return false;
+}
+
#ifndef NDEBUG
//===----------------------------------------------------------------------===//
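
A hedged worked example of the Merger changes (bit indices hypothetical):
consider lattice points i and j whose bits differ only in b2, a level
accessed through a non-trivial index expression on a sparse tensor:

  // tmp = bits(j) ^ bits(i)      -> only b2 is set
  // hasAnySparse(tmp)            -> may be false (keys on level types)
  // hasSparseIdxReduction(tmp)   -> true (b2 carries an index reduction)
  // onlyDenseDiff(i, j)          -> now false, so such points co-iterate
  //                                 instead of being treated as dense-only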