[Mlir-commits] [mlir] c442025 - [mlir][sparse] support sparsification to coiterate operations. (#102546)
llvmlistbot at llvm.org
Tue Aug 20 11:13:41 PDT 2024
Author: Peiming Liu
Date: 2024-08-20T11:13:38-07:00
New Revision: c44202574ff9a8c0632aba30c2765b134557435f
URL: https://github.com/llvm/llvm-project/commit/c44202574ff9a8c0632aba30c2765b134557435f
DIFF: https://github.com/llvm/llvm-project/commit/c44202574ff9a8c0632aba30c2765b134557435f.diff
LOG: [mlir][sparse] support sparsification to coiterate operations. (#102546)
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 6e17f804993e2a..2803223354d5ee 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1749,6 +1749,10 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
let results = (outs Variadic<AnyType>:$results);
let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);
+ let builders = [
+ OpBuilder<(ins "ValueRange":$iterSpace, "ValueRange":$initArgs, "unsigned":$numCases)>,
+ ];
+
let extraClassDeclaration = [{
unsigned getSpaceDim() {
return llvm::cast<::mlir::sparse_tensor::IterSpaceType>(
@@ -1765,18 +1769,18 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
});
}
- // The block arguments starts with referenced coordinates, follows by
- // user-provided iteration arguments and ends with iterators.
+ // The block arguments start with the user-provided iteration arguments,
+ // followed by the referenced coordinates, and end with the iterators.
Block::BlockArgListType getCrds(unsigned regionIdx) {
return getRegion(regionIdx).getArguments()
- .take_front(getCrdUsedLvls().count());
+ .slice(getNumRegionIterArgs(), getCrdUsedLvls().count());
}
- unsigned getNumRegionIterArgs(unsigned regionIdx) {
+ unsigned getNumRegionIterArgs() {
return getInitArgs().size();
}
Block::BlockArgListType getRegionIterArgs(unsigned regionIdx) {
return getRegion(regionIdx).getArguments()
- .slice(getCrdUsedLvls().count(), getNumRegionIterArgs(regionIdx));
+ .take_front(getNumRegionIterArgs());
}
Block::BlockArgListType getRegionIterators(unsigned regionIdx) {
return getRegion(regionIdx).getArguments()
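The accessor changes above switch the case-region argument layout from [coords..., iter_args..., iterators...] to [iter_args..., coords..., iterators...]. A minimal standalone C++ sketch of how the three accessors now partition the argument list (plain std::vector stand-ins, not the MLIR API):

    #include <cassert>
    #include <string>
    #include <vector>

    int main() {
      // Case-region block arguments: [iter_args... | coords... | iterators...].
      std::vector<std::string> args = {"%sum", "%crd0", "%it1", "%it2"};
      unsigned numIterArgs = 1; // getNumRegionIterArgs()
      unsigned numCrds = 1;     // getCrdUsedLvls().count()
      // getRegionIterArgs(r): take_front(numIterArgs).
      assert(args.front() == "%sum");
      // getCrds(r): slice(numIterArgs, numCrds).
      assert(args[numIterArgs] == "%crd0");
      // getRegionIterators(r): whatever remains at the back.
      assert(args[numIterArgs + numCrds] == "%it1" && args.back() == "%it2");
      return 0;
    }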
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index a284aa2f1f020f..a143189c301a43 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2293,9 +2293,10 @@ parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren))
return failure();
- if (failed(parseUsedCoordList(parser, state, blockArgs)))
+ SmallVector<OpAsmParser::Argument> coords;
+ if (failed(parseUsedCoordList(parser, state, coords)))
return failure();
- size_t numCrds = blockArgs.size();
+ size_t numCrds = coords.size();
// Parse "iter_args(%arg = %init, ...)"
SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
@@ -2303,6 +2304,7 @@ parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
if (hasIterArgs)
if (parser.parseAssignmentList(blockArgs, initArgs))
return failure();
+ blockArgs.append(coords);
SmallVector<Type> iterSpaceTps;
// parse ": (sparse_tensor.iter_space, ...) -> ret"
@@ -2326,8 +2328,8 @@ parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
state.operands.append(spacesVals);
if (hasIterArgs) {
- // Strip off leading args that used for coordinates.
- MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
+ // Strip off trailing args that are used for coordinates.
+ MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
if (args.size() != initArgs.size() || args.size() != state.types.size()) {
return parser.emitError(
parser.getNameLoc(),
@@ -2602,6 +2604,24 @@ void IterateOp::getSuccessorRegions(RegionBranchPoint point,
regions.push_back(RegionSuccessor(getResults()));
}
+void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
+ ValueRange iterSpaces, ValueRange initArgs,
+ unsigned numCases) {
+ unsigned rank =
+ cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim();
+ // All ones.
+ I64BitSet set((1 << rank) - 1);
+ // Generate all-zero case bits (they only serve as placeholders), which are
+ // supposed to be overridden later. We need to preallocate all the regions,
+ // as an mlir::Region cannot be dynamically added after the operation is
+ // created.
+ SmallVector<int64_t> caseBits(numCases, 0);
+ ArrayAttr cases = builder.getI64ArrayAttr(caseBits);
+ return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
+ initArgs, set, cases,
+ /*caseRegionsCount=*/numCases);
+}
+
ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<Value> spaces;
@@ -2685,7 +2705,7 @@ ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
LogicalResult CoIterateOp::verifyRegions() {
for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
- if (getNumRegionIterArgs(r) != getNumResults())
+ if (getNumRegionIterArgs() != getNumResults())
return emitOpError(
"mismatch in number of basic block args and defined values");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 5fb009e3eebe66..cc372ed1be6217 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1395,7 +1395,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
// Note that reduc will be taken care of by loop emitter and get updated
// in place.
- loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls,
+ loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
reduc);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 08fc104fcbeead..bf12dc8ae05cca 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -842,11 +842,13 @@ static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
/// one sparse level in the list.
static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
ArrayRef<TensorLevel> tidLvls,
- bool tryParallel, bool needsUniv) {
+ unsigned numCases, bool tryParallel,
+ bool needsUniv) {
Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
// Construct while-loop with a parameter for each index.
return env.emitter().enterCoIterationOverTensorsAtLvls(
- builder, env.op().getLoc(), tidLvls, reduc, tryParallel, needsUniv);
+ builder, env.op().getLoc(), tidLvls, numCases, reduc, tryParallel,
+ needsUniv);
});
assert(loop);
return loop;
@@ -855,9 +857,11 @@ static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
- bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
+ unsigned numCases, bool needsUniv,
+ ArrayRef<TensorLevel> tidLvls) {
bool tryParallel = shouldTryParallize(env, curr, tidLvls);
- return genCoIteration(env, builder, tidLvls, tryParallel, needsUniv);
+ return genCoIteration(env, builder, tidLvls, numCases, tryParallel,
+ needsUniv);
}
/// Generates the induction structure for a while-loop.
@@ -900,6 +904,26 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
// basic block where scf::Yield should be inserted.
}
+/// Generates a case region in the coiterate operation.
+static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
+ unsigned caseIdx, LatPointId allCase,
+ LatPointId curCase,
+ MutableArrayRef<Value> reduc) {
+ assert(allCase == curCase || env.merger().latGT(allCase, curCase));
+ const BitVector &allCaseBits = env.merger().lat(allCase).simple;
+ const BitVector &curCaseBits = env.merger().lat(curCase).simple;
+
+ /// Computes the subset of iterators that are valid in the current case being
+ /// generated.
+ I64BitSet caseBit(0);
+ for (auto [idx, set] : llvm::enumerate(allCaseBits.set_bits()))
+ if (curCaseBits.test(set))
+ caseBit.set(idx);
+
+ env.emitter().enterCurrentCoIterationCase(builder, env.op().getLoc(), caseBit,
+ caseIdx, reduc);
+}
+
/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
LatPointId p) {
@@ -1175,7 +1199,10 @@ static bool translateBitsToTidLvlPairs(
/// Starts a single loop in current sequence.
static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
OpBuilder &builder, LoopId curr,
- LatPointId li, bool needsUniv) {
+ LatPointId li, unsigned numCases,
+ bool needsUniv) {
+ // TODO: numCases is only used when generating iterator-based loops. Clean up
+ // after the full migration.
// The set of tensors + lvls to generate loops on
SmallVector<TensorLevel> tidLvls;
@@ -1186,7 +1213,7 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);
// Emit the for/while-loop control.
- Operation *loop = genLoop(env, builder, curr, needsUniv, tidLvls);
+ Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls);
Location loc = env.op().getLoc();
for (auto [tidLvl, exp] : affineTidLvls) {
env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
@@ -1259,42 +1286,73 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
// Start a loop sequence.
bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);
- // Emit a loop for every lattice point L0 >= Li in this loop sequence.
- // We cannot change this to `for (const LatPointId li : env.set(lts))`
- // because the loop body causes data-movement which invalidates
- // the iterator.
+ // When using sparse-iterator-based loops, we only need one loop, as
+ // opposed to a loop sequence, to cover all the iterator spaces.
const unsigned lsize = env.set(lts).size();
- for (unsigned i = 0; i < lsize; i++) {
- const LatPointId li = env.set(lts)[i];
- // Start a loop.
- auto [loop, isSingleCond] = startLoop(env, rewriter, curr, li, needsUniv);
-
- // Visit all lattices points with Li >= Lj to generate the
- // loop-body, possibly with if statements for coiteration.
- Value redInput = env.getReduc();
- Value cntInput = env.getExpandCount();
- Value insInput = env.getInsertionChain();
- Value validIns = env.getValidLexInsert();
- // We cannot change this to `for (const LatPointId lj : env.set(lts))`
- // because the loop body causes data-movement which invalidates the
- // iterator.
+ if (env.generatingSparseIterator()) {
+ // Get the largest lattice point and start a loop.
+ const LatPointId li = env.set(lts)[0];
+ auto [loop, isSingleCond] =
+ startLoop(env, rewriter, curr, li, lsize, needsUniv);
+ assert(isSingleCond == llvm::isa<IterateOp>(loop));
+ // We cannot change this to `for (const LatPointId li : env.set(lts))`
+ // because the loop body causes data-movement which invalidates
+ // the iterator.
for (unsigned j = 0; j < lsize; j++) {
const LatPointId lj = env.set(lts)[j];
const ExprId ej = env.lat(lj).exp;
- if (li == lj || env.merger().latGT(li, lj)) {
- // Recurse into body of each branch.
- if (!isSingleCond) {
- scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
- genStmt(env, rewriter, ej, curr + 1);
- endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
- } else {
+ // Recurse into body of each branch.
+ if (!isSingleCond) {
+ env.genLoopBoundary([&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
+ genCoIterationCase(env, rewriter, /*caseIdx*/ j, li, lj, reduc);
genStmt(env, rewriter, ej, curr + 1);
- }
+ // TODO: handle yield values.
+ assert(reduc.empty() && "Not Implemented");
+ rewriter.create<sparse_tensor::YieldOp>(env.op().getLoc());
+ return std::nullopt;
+ });
+ // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
+ } else {
+ genStmt(env, rewriter, ej, curr + 1);
}
}
-
// End a loop.
needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
+ } else {
+ // Emit a loop for every lattice point L0 >= Li in this loop sequence.
+ for (unsigned i = 0; i < lsize; i++) {
+ const LatPointId li = env.set(lts)[i];
+ // Start a loop.
+ auto [loop, isSingleCond] =
+ startLoop(env, rewriter, curr, li, lsize, needsUniv);
+
+ // Visit all lattices points with Li >= Lj to generate the
+ // loop-body, possibly with if statements for coiteration.
+ Value redInput = env.getReduc();
+ Value cntInput = env.getExpandCount();
+ Value insInput = env.getInsertionChain();
+ Value validIns = env.getValidLexInsert();
+ // We cannot change this to `for (const LatPointId lj : env.set(lts))`
+ // because the loop body causes data-movement which invalidates the
+ // iterator.
+ for (unsigned j = 0; j < lsize; j++) {
+ const LatPointId lj = env.set(lts)[j];
+ const ExprId ej = env.lat(lj).exp;
+ if (li == lj || env.merger().latGT(li, lj)) {
+ // Recurse into body of each branch.
+ if (!isSingleCond) {
+ scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
+ genStmt(env, rewriter, ej, curr + 1);
+ endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
+ } else {
+ genStmt(env, rewriter, ej, curr + 1);
+ }
+ }
+ }
+
+ // End a loop.
+ needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
+ }
}
// End a loop sequence.
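Note the index remapping in genCoIterationCase: the case mask is indexed by the position of each iterator among the set bits of the largest lattice point, not by the raw lattice bit. A runnable sketch, with std::bitset standing in for llvm::BitVector and I64BitSet:

    #include <bitset>
    #include <cstdio>

    int main() {
      std::bitset<8> allCaseBits("1011"); // iterators of the largest lattice point
      std::bitset<8> curCaseBits("0011"); // subset valid in the current case
      std::bitset<8> caseBit;
      unsigned idx = 0;
      for (unsigned b = 0; b < allCaseBits.size(); ++b) {
        if (!allCaseBits.test(b))
          continue; // only enumerate bits set in the "all" lattice point
        if (curCaseBits.test(b))
          caseBit.set(idx); // remap to the iterator's position in the op
        ++idx;
      }
      std::printf("case mask = %s\n", caseBit.to_string().c_str()); // 00000011
      return 0;
    }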
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
index d69ae53fb0f298..34b793ee11e4ca 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
@@ -49,6 +49,10 @@ class CodegenEnv {
linalg::GenericOp op() const { return linalgOp; }
const SparsificationOptions &options() const { return sparseOptions; }
+ bool generatingSparseIterator() const {
+ return sparseOptions.sparseEmitStrategy ==
+ SparseEmitStrategy::kSparseIterator;
+ }
Merger &merger() { return latticeMerger; }
LoopEmitter &emitter() { return loopEmitter; }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 2be0193f0de83e..efb3295fb2a4bf 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -615,33 +615,106 @@ bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) {
return true;
}
+Region *LoopEmitter::enterCurrentCoIterationCase(OpBuilder &builder,
+ Location loc,
+ I64BitSet caseBit,
+ unsigned caseIdx,
+ MutableArrayRef<Value> reduc) {
+ auto coIterOp = cast<CoIterateOp>(loopStack.back().loop);
+ SmallVector<Attribute> cases(coIterOp.getCases().getAsRange<Attribute>());
+ cases[caseIdx] = builder.getI64IntegerAttr(caseBit);
+
+ coIterOp.setCasesAttr(builder.getArrayAttr(cases));
+ Region &caseRegion = coIterOp.getRegion(caseIdx);
+ assert(caseRegion.getBlocks().empty() &&
+ "re-initializing the same coiteration case region.");
+
+ // Each block starts with a list of user-provided iteration arguments.
+ TypeRange iterArgsTps = coIterOp.getInitArgs().getTypes();
+ // Followed by a list of used coordinates of index type.
+ SmallVector<Type> blockArgTps(coIterOp.getCrdUsedLvls().count(),
+ builder.getIndexType());
+
+ blockArgTps.append(iterArgsTps.begin(), iterArgsTps.end());
+ // Ends with a set of iterators that defines the actual iteration space.
+ for (auto i : caseBit.bits()) {
+ blockArgTps.push_back(
+ cast<IterSpaceType>(coIterOp.getIterSpaces()[i].getType())
+ .getIteratorType());
+ }
+ SmallVector<Location> locs(blockArgTps.size(), loc);
+ caseRegion.emplaceBlock().addArguments(blockArgTps, locs);
+
+ // Enter the new region scope and update the SSA chain.
+ builder.setInsertionPointToStart(&caseRegion.front());
+ // Update the coordinates.
+ loopStack.back().iv = coIterOp.getCrds(caseIdx).front();
+ // Updates loop iteration arguments.
+ ValueRange iterArgs = coIterOp.getRegionIterArgs(caseIdx);
+ llvm::copy(iterArgs, reduc.begin());
+ // Updates sparse iterator values.
+ ValueRange iters = coIterOp.getRegionIterators(caseIdx);
+ ArrayRef<TensorLevel> tidLvls = loopStack.back().tidLvls;
+ for (auto [i, tl] : llvm::enumerate(unpackTensorLevelRange(tidLvls))) {
+ if (caseBit[i]) {
+ spIterVals[tl.first][tl.second] = iters.front();
+ iters = iters.drop_front();
+ } else {
+ spIterVals[tl.first][tl.second] = nullptr;
+ }
+ }
+ // Must have consumed all iterator SSA values.
+ assert(iters.empty());
+ return &caseRegion;
+}
+
Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
- MutableArrayRef<Value> reduc, bool tryParallel, bool needsUniv) {
-
+ unsigned numCases, MutableArrayRef<Value> reduc, bool tryParallel,
+ bool needsUniv) {
+ // TODO: Argument `numCases` is only used when generating iterator-based
+ // sparse loops. Simplify the code once the feature is complete.
// TODO: handle coiteration with sparse iterator.
if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
- assert(tidLvls.size() == 1);
- auto [tid, lvl] = unpackTensorLevel(tidLvls.front());
- Value t = tensors[tid];
-
- // Extract and iterate over the iteration space.
- ExtractIterSpaceOp extractSpaceOp =
- lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
- : builder.create<ExtractIterSpaceOp>(
- loc, t, spIterVals[tid][lvl - 1], lvl);
-
- IterateOp iterOp = builder.create<IterateOp>(
- loc, extractSpaceOp.getExtractedSpace(), reduc);
- spIterVals[tid][lvl] = iterOp.getIterator();
+ if (tidLvls.size() == 1) {
+ auto [tid, lvl] = unpackTensorLevel(tidLvls.front());
+ Value t = tensors[tid];
+
+ // Extract and iterate over the iteration space.
+ ExtractIterSpaceOp extractSpaceOp =
+ lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
+ : builder.create<ExtractIterSpaceOp>(
+ loc, t, spIterVals[tid][lvl - 1], lvl);
+
+ IterateOp iterOp = builder.create<IterateOp>(
+ loc, extractSpaceOp.getExtractedSpace(), reduc);
+ spIterVals[tid][lvl] = iterOp.getIterator();
+
+ // Update the reduction variables.
+ llvm::copy(iterOp.getRegionIterArgs(), reduc.begin());
+ // Set the insertion point to loop body.
+ builder.setInsertionPointToStart(iterOp.getBody());
+ loopStack.emplace_back(tidLvls, iterOp, builder.getInsertionBlock(),
+ iterOp.getCrds().front(), loopTag);
+ return iterOp;
+ }
- // Update the reduction varaibles.
- llvm::copy(iterOp.getRegionIterArgs(), reduc.begin());
- // Set the insertion point to loop body.
- builder.setInsertionPointToStart(iterOp.getBody());
- loopStack.emplace_back(tidLvls, iterOp, builder.getInsertionBlock(),
- iterOp.getIterator(), loopTag);
- return iterOp;
+ // CoIteration Loops.
+ SmallVector<Value> spaces;
+ for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
+ Value t = tensors[tid];
+ ExtractIterSpaceOp extractSpaceOp =
+ lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
+ : builder.create<ExtractIterSpaceOp>(
+ loc, t, spIterVals[tid][lvl - 1], lvl);
+ spaces.push_back(extractSpaceOp.getExtractedSpace());
+ }
+ auto coIterOp = builder.create<CoIterateOp>(loc, spaces, reduc, numCases);
+ // The CoIterateOp has neither an insertion block nor an induction variable.
+ // TODO: the `struct LoopInfo` should be simplified after the full migration.
+ loopStack.emplace_back(tidLvls, coIterOp, /*insertion block*/ nullptr,
+ /*induction variable*/ nullptr, loopTag);
+ return coIterOp;
}
// TODO: support multiple return on parallel for?
@@ -866,6 +939,18 @@ void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
// Clean up the values; it would help us to discover potential bugs at an
// earlier stage (instead of silently using a wrong value).
const LoopInfo &loopInfo = loopStack.back();
+ if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
+ Operation *p = loopInfo.loop;
+ if (isa<IterateOp>(p))
+ rewriter.create<sparse_tensor::YieldOp>(loc, reduc);
+
+ // Exit the loop.
+ rewriter.setInsertionPointAfter(p);
+ // In-place update reduction variables.
+ llvm::copy(p->getResults(), reduc.begin());
+ loopStack.pop_back();
+ return;
+ }
// Sets the insertion point to the right position.
rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);
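In enterCurrentCoIterationCase above, the iterator block arguments are handed out in order: each tensor level whose bit is set in the case mask consumes the next iterator value, and the remaining levels are cleared. A sketch with integer ids standing in for SSA values:

    #include <cassert>
    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<bool> caseBit = {true, false, true}; // 3 levels, 2 active
      std::vector<int> iters = {101, 102};             // case's iterator args
      std::vector<int> spIterVals(caseBit.size(), -1); // -1 ~ nullptr
      size_t next = 0;
      for (size_t i = 0; i < caseBit.size(); ++i)
        if (caseBit[i])
          spIterVals[i] = iters[next++];
      assert(next == iters.size() && "must consume all iterator values");
      for (int v : spIterVals)
        std::printf("%d ", v); // prints: 101 -1 102
      return 0;
    }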
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index f3e73e4692c1fd..a9eb888c8b6bec 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -145,8 +145,12 @@ class LoopEmitter {
/// return the reduction variable used inside the generated loop.
Operation *enterCoIterationOverTensorsAtLvls(
OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
- MutableArrayRef<Value> reduc = {}, bool isParallel = false,
- bool needsUniv = false);
+ unsigned numCases, MutableArrayRef<Value> reduc = {},
+ bool isParallel = false, bool needsUniv = false);
+
+ Region *enterCurrentCoIterationCase(OpBuilder &builder, Location loc,
+ I64BitSet caseBit, unsigned caseIdx,
+ MutableArrayRef<Value> reduc);
/// Generates code to exit the current loop (e.g., generates yields, forwards
/// loop induction variables, etc).
@@ -260,9 +264,9 @@ class LoopEmitter {
// required for levels with non-trivial index expressions, which is
// maintained by the sliceDrivenInfo array below.
const llvm::SmallVector<TensorLevel> tidLvls;
- const Operation *loop; // the loop operation
+ Operation *loop; // the loop operation
Block *const userCodeBlock; // the block holding users' generated code.
- const Value iv; // the induction variable for the loop
+ Value iv; // the induction variable for the loop
};
void categorizeIterators(ArrayRef<TensorLevel> tidLvls,
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
index 268b3940418b71..2487156a9a2e48 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
@@ -1,4 +1,8 @@
-// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse --sparse-space-collapse --lower-sparse-iteration-to-scf --loop-invariant-code-motion | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse | FileCheck %s --check-prefix="ITER"
+
+// TODO: temporarily disabled since there are no lowering rules from `coiterate` to `scf`.
+// R_U_N: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse --sparse-space-collapse --lower-sparse-iteration-to-scf --loop-invariant-code-motion | FileCheck %s
+
#COO = #sparse_tensor.encoding<{
@@ -10,13 +14,18 @@
)
}>
+#VEC = #sparse_tensor.encoding<{
+ map = (d0) -> (d0 : compressed)
+}>
+
+
// CHECK-LABEL: func.func @sqsum(
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[POS_BUF:.*]] = sparse_tensor.positions %{{.*}} {level = 0 : index} : tensor<?x?x?x?xi32, #sparse> to memref<?xindex>
+// CHECK-DAG: %[[POS_BUF:.*]] = sparse_tensor.positions %{{.*}} {level = 0 : index} : tensor<?x?x?x?xi32, #sparse{{.*}}> to memref<?xindex>
// CHECK: %[[POS_LO:.*]] = memref.load %[[POS_BUF]]{{\[}}%[[C0]]] : memref<?xindex>
// CHECK: %[[POS_HI:.*]] = memref.load %[[POS_BUF]]{{\[}}%[[C1]]] : memref<?xindex>
-// CHECK: %[[VAL_BUF:.*]] = sparse_tensor.values %{{.*}} : tensor<?x?x?x?xi32, #sparse> to memref<?xi32>
+// CHECK: %[[VAL_BUF:.*]] = sparse_tensor.values %{{.*}} : tensor<?x?x?x?xi32, #sparse{{.*}}> to memref<?xi32>
// CHECK: %[[SQ_SUM:.*]] = scf.for %[[POS:.*]] = %[[POS_LO]] to %[[POS_HI]] step %[[C1]] {{.*}} {
// CHECK: %[[VAL:.*]] = memref.load %[[VAL_BUF]]{{\[}}%[[POS]]] : memref<?xi32>
// CHECK: %[[MUL:.*]] = arith.muli %[[VAL]], %[[VAL]] : i32
@@ -27,6 +36,12 @@
// CHECK: %[[RET:.*]] = bufferization.to_tensor
// CHECK: return %[[RET]] : tensor<i32>
// CHECK: }
+
+// ITER-LABEL: func.func @sqsum(
+// ITER: sparse_tensor.iterate
+// ITER: sparse_tensor.iterate
+// ITER: sparse_tensor.iterate
+// ITER: }
func.func @sqsum(%arg0: tensor<?x?x?x?xi32, #COO>) -> tensor<i32> {
%cst = arith.constant dense<0> : tensor<i32>
%0 = linalg.generic {
@@ -43,3 +58,42 @@ func.func @sqsum(%arg0: tensor<?x?x?x?xi32, #COO>) -> tensor<i32> {
} -> tensor<i32>
return %0 : tensor<i32>
}
+
+
+// ITER-LABEL: func.func @add(
+// ITER: sparse_tensor.coiterate
+// ITER: case %[[IT_1:.*]], %[[IT_2:.*]] {
+// ITER: %[[LHS:.*]] = sparse_tensor.extract_value %{{.*}} at %[[IT_1]]
+// ITER: %[[RHS:.*]] = sparse_tensor.extract_value %{{.*}} at %[[IT_2]]
+// ITER: %[[SUM:.*]] = arith.addi %[[LHS]], %[[RHS]] : i32
+// ITER: memref.store %[[SUM]]
+// ITER: }
+// ITER: case %[[IT_1:.*]], _ {
+// ITER: %[[LHS:.*]] = sparse_tensor.extract_value %{{.*}} at %[[IT_1]]
+// ITER: memref.store %[[LHS]]
+// ITER: }
+// ITER: case _, %[[IT_2:.*]] {
+// ITER: %[[RHS:.*]] = sparse_tensor.extract_value %{{.*}} at %[[IT_2]]
+// ITER: memref.store %[[RHS]]
+// ITER: }
+// ITER: bufferization.to_tensor
+// ITER: return
+// ITER: }
+func.func @add(%arg0: tensor<10xi32, #VEC>, %arg1: tensor<10xi32, #VEC>) -> tensor<10xi32> {
+ %cst = arith.constant dense<0> : tensor<10xi32>
+ %0 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>
+ ],
+ iterator_types = ["parallel"]
+ }
+ ins(%arg0, %arg1 : tensor<10xi32, #VEC>, tensor<10xi32, #VEC>)
+ outs(%cst : tensor<10xi32>) {
+ ^bb0(%in1: i32, %in2: i32, %out: i32):
+ %2 = arith.addi %in1, %in2 : i32
+ linalg.yield %2 : i32
+ } -> tensor<10xi32>
+ return %0 : tensor<10xi32>
+}