[Mlir-commits] [mlir] [mlir][sparse] support sparsifying sparse kernels to sparse-iterator-based loop (PR #95858)
llvmlistbot at llvm.org
Mon Jun 17 15:16:18 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-sparse
Author: Peiming Liu (PeimingLiu)
---
Patch is 195.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95858.diff
40 Files Affected:
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h (+1-1)
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+24-4)
- (modified) mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h (+16-1)
- (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h (+1)
- (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td (+3-1)
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+38)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp (+2-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+3-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp (+6)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp (+96-41)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h (+7)
- (modified) mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir (+7-7)
- (modified) mlir/test/Dialect/SparseTensor/sparse_2d.mlir (+20-20)
- (modified) mlir/test/Dialect/SparseTensor/sparse_3d.mlir (+17-17)
- (modified) mlir/test/Dialect/SparseTensor/sparse_affine.mlir (+15-15)
- (modified) mlir/test/Dialect/SparseTensor/sparse_batch.mlir (+11-11)
- (modified) mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir (+6-6)
- (modified) mlir/test/Dialect/SparseTensor/sparse_concat.mlir (+50-50)
- (modified) mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir (+10-10)
- (modified) mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir (+44-44)
- (modified) mlir/test/Dialect/SparseTensor/sparse_fusion.mlir (+3-3)
- (modified) mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir (+50-50)
- (modified) mlir/test/Dialect/SparseTensor/sparse_kernels.mlir (+6-6)
- (added) mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir (+43)
- (modified) mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir (+13-13)
- (modified) mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir (+12-12)
- (modified) mlir/test/Dialect/SparseTensor/sparse_nd.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_out.mlir (+30-30)
- (modified) mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir (+17-17)
- (modified) mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir (+5-5)
- (modified) mlir/test/Dialect/SparseTensor/sparse_perm.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_reshape.mlir (+24-24)
- (modified) mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_sddmm_org.mlir (+8-8)
- (modified) mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir (+6-6)
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir (+7-7)
- (modified) mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir (+10-10)
- (modified) mlir/test/Dialect/SparseTensor/spy_sddmm_bsr.mlir (+8-8)
- (modified) mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir (+42-42)
- (added) mlir/test/Integration/Dialect/SparseTensor/CPU/iterator-based-sqsum.mlir (+80)
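For context before the diff: the patch adds a new `sparse-iterator` value to the existing `sparse-emit-strategy` option (exposed both on the `sparsification` pass and on the `sparsifier` pipeline, as shown below), under which sparsification emits `sparse_tensor.extract_iteration_space` / `sparse_tensor.iterate` loops instead of fully inlined `scf.for`/`scf.while` loops. As a rough illustration, here is a minimal hand-written sketch of that loop form, loosely following the op examples in SparseTensorOps.td; the `#COO` encoding, function, and SSA names are illustrative, not taken from the patch:

```mlir
#COO = #sparse_tensor.encoding<{
  map = (i, j) -> (i : compressed(nonunique), j : singleton)
}>

// Sums up the level-1 coordinates of all stored entries.
func.func @crd_sum(%sp : tensor<4x8xf32, #COO>, %i0 : index) -> index {
  // Extract the root iteration space over level 0.
  %space0 = sparse_tensor.extract_iteration_space %sp lvls = 0
      : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
  %r = sparse_tensor.iterate %it0 in %space0 at(%i) iter_args(%acc0 = %i0)
      : !sparse_tensor.iter_space<#COO, lvls = 0> -> index {
    // Extract the level-1 space nested under the level-0 iterator.
    %space1 = sparse_tensor.extract_iteration_space %sp at %it0 lvls = 1
        : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
        -> !sparse_tensor.iter_space<#COO, lvls = 1>
    %r1 = sparse_tensor.iterate %it1 in %space1 at(%j) iter_args(%acc1 = %acc0)
        : !sparse_tensor.iter_space<#COO, lvls = 1> -> index {
      %s = arith.addi %acc1, %j : index
      sparse_tensor.yield %s : index
    }
    sparse_tensor.yield %r1 : index
  }
  return %r : index
}
```

Under the new strategy, these loops are subsequently collapsed and lowered back to `scf` by `createSparseSpaceCollapsePass` and `createLowerSparseIterationToSCFPass`, as wired up in SparsificationAndBufferizationPass.cpp in the diff below.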
``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 04a6386a199de..68ca036121520 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -89,7 +89,7 @@ class LevelSet {
assert(i < 64);
return (bits & (1 << i)) != 0;
}
-
+ unsigned max() const { return 64 - llvm::countl_zero(bits); }
unsigned count() const { return llvm::popcount(bits); }
bool empty() const { return bits == 0; }
};
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index b2089924291cd..f31df080d7811 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1493,6 +1493,10 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
```
}];
+ let arguments = (ins AnySparseTensor:$tensor,
+ Optional<AnySparseIterator>:$parentIter,
+ LevelAttr:$loLvl, LevelAttr:$hiLvl);
+ let results = (outs AnySparseIterSpace:$extractedSpace);
let extraClassDeclaration = [{
std::pair<Level, Level> getLvlRange() {
@@ -1506,10 +1510,20 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
}
}];
- let arguments = (ins AnySparseTensor:$tensor,
- Optional<AnySparseIterator>:$parentIter,
- LevelAttr:$loLvl, LevelAttr:$hiLvl);
- let results = (outs AnySparseIterSpace:$extractedSpace);
+ let builders = [
+ // Construct a 1-D iteration space.
+ OpBuilder<(ins "Value":$tensor, "Value":$parentIter,
+ "sparse_tensor::Level":$loLvl),
+ [{
+ build($_builder, $_state, tensor, parentIter, loLvl, loLvl + 1);
+ }]>,
+ // Construct a 1-D root iteration space.
+ OpBuilder<(ins "Value":$tensor),
+ [{
+ build($_builder, $_state, tensor, nullptr, 0);
+ }]>
+ ];
+
let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
" attr-dict `:` type($tensor) (`,` type($parentIter)^)? "
"`->` qualified(type($extractedSpace))";
@@ -1594,6 +1608,12 @@ def IterateOp : SparseTensor_Op<"iterate",
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs)>,
+ OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs, "LevelSet" :$crdUsedLvls)>
+ ];
+
let extraClassDeclaration = [{
unsigned getSpaceDim() {
return getIterSpace().getType().getSpaceDim();
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
index 85bfee336f848..90021ffa7c380 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -51,6 +51,20 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop,
"any-storage-any-loop",
"Enable sparse parallelization for any storage and loop."))};
+ PassOptions::Option<mlir::SparseEmitStrategy> emitStrategy{
+ *this, "sparse-emit-strategy",
+ ::llvm::cl::desc(
+ "Emit functional code or interfaces (to debug) for sparse loops"),
+ ::llvm::cl::init(mlir::SparseEmitStrategy::kFunctional),
+ llvm::cl::values(
+ clEnumValN(mlir::SparseEmitStrategy::kFunctional, "functional",
+ "Emit functional code (with scf.for/while)."),
+ clEnumValN(mlir::SparseEmitStrategy::kSparseIterator,
+ "sparse-iterator",
+ "Emit (experimental) loops (with sparse.iterate)."),
+ clEnumValN(
+ mlir::SparseEmitStrategy::kDebugInterface, "debug-interface",
+ "Emit non-functional but easy-to-read interfaces to debug."))};
PassOptions::Option<bool> enableRuntimeLibrary{
*this, "enable-runtime-library",
@@ -143,7 +157,8 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
/// Projects out the options for `createSparsificationPass`.
SparsificationOptions sparsificationOptions() const {
- return SparsificationOptions(parallelization, enableRuntimeLibrary);
+ return SparsificationOptions(parallelization, emitStrategy,
+ enableRuntimeLibrary);
}
/// Projects out the options for `createConvertVectorToLLVMPass`.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index c9164e39a3a75..2edb9a61d1876 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -51,6 +51,7 @@ enum class ReinterpretMapScope {
/// Defines a scope for reinterpret map pass.
enum class SparseEmitStrategy {
kFunctional, // generate fully inlined (and functional) sparse iteration
+ kSparseIterator, // generate (experimental) loops using sparse iterators.
kDebugInterface, // generate only place-holder for sparse iteration
};
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index b18c975105b75..7173cde67384b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -163,7 +163,9 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
"mlir::SparseEmitStrategy::kFunctional",
"Emit functional code or interfaces (to debug) for sparse loops", [{llvm::cl::values(
clEnumValN(mlir::SparseEmitStrategy::kFunctional, "functional",
- "Emit functional code."),
+ "Emit functional code (with scf.for/while)."),
+ clEnumValN(mlir::SparseEmitStrategy::kSparseIterator, "sparse-iterator",
+ "Emit (experimental) loops (with sparse.iterate)."),
clEnumValN(mlir::SparseEmitStrategy::kDebugInterface, "debug-interface",
"Emit non-functional but easy-to-read interfaces to debug."))}]>,
Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index ac711769ed2ea..504888be4c85f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2300,6 +2300,41 @@ void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
results.add<RemoveUnusedLvlCrds>(context);
}
+void IterateOp::build(OpBuilder &builder, OperationState &odsState,
+ Value iterSpace, ValueRange initArgs) {
+ unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
+ // Use all levels' coordinates by default (an all-ones set).
+ LevelSet set((1 << rank) - 1);
+ return build(builder, odsState, iterSpace, initArgs, set);
+}
+
+void IterateOp::build(OpBuilder &builder, OperationState &odsState,
+ Value iterSpace, ValueRange initArgs,
+ LevelSet crdUsedLvls) {
+ OpBuilder::InsertionGuard guard(builder);
+
+ odsState.addOperands(iterSpace);
+ odsState.addOperands(initArgs);
+ odsState.getOrAddProperties<Properties>().crdUsedLvls =
+ builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
+ Region *bodyRegion = odsState.addRegion();
+ odsState.addTypes(initArgs.getTypes());
+ Block *bodyBlock = builder.createBlock(bodyRegion);
+
+ // First argument: the sparse iterator.
+ bodyBlock->addArgument(
+ llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
+ odsState.location);
+
+ // Followed by a list of used coordinates.
+ for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
+ bodyBlock->addArgument(builder.getIndexType(), odsState.location);
+
+ // Followed by a list of user-provided loop arguments.
+ for (Value v : initArgs)
+ bodyBlock->addArgument(v.getType(), v.getLoc());
+}
+
ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::Argument iterator;
OpAsmParser::UnresolvedOperand iterSpace;
@@ -2384,6 +2419,9 @@ LogicalResult IterateOp::verify() {
return emitOpError(
"mismatch in number of loop-carried values and defined values");
}
+ if (getCrdUsedLvls().max() > getSpaceDim())
+ return emitOpError("required out-of-bound coordinates");
+
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index 4224925147c84..1d614b7b29361 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -164,7 +164,6 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
// replace sparse_tensor.yield with scf.yield.
rewriter.eraseOp(yieldOp);
rewriter.create<scf::YieldOp>(loc, yields);
-
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
rewriter.replaceOp(
op, whileOp.getResults().drop_front(it->getCursor().size()),
@@ -192,6 +191,8 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
void mlir::populateLowerSparseIterationToSCFPatterns(
TypeConverter &converter, RewritePatternSet &patterns) {
+
+ IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>(
converter, patterns.getContext());
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 0c8e431d8c996..e8f85aab887ca 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1071,7 +1071,9 @@ static bool getAllTidLvlsInLatPoints(
}
// If we only need one loop condition and the condition is not imposed on a
// non-unique level, the loop can be generated by a for loop.
- return numloopCond == 1 && !hasNonUnique;
+ return numloopCond == 1 &&
+ (!hasNonUnique || env.options().sparseEmitStrategy ==
+ SparseEmitStrategy::kSparseIterator);
}
/// Starts a loop sequence at given level. Returns true if
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 13c750e83d045..3ae6732f900fe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -159,6 +159,12 @@ class SparsificationAndBufferizationPass
pm.addPass(createSparseGPUCodegenPass(0, enableRuntimeLibrary));
pm.addPass(createSparseReinterpretMapPass(ReinterpretMapScope::kAll));
pm.addPass(createSparsificationPass(sparsificationOptions));
+ if (sparsificationOptions.sparseEmitStrategy ==
+ SparseEmitStrategy::kSparseIterator) {
+ pm.addNestedPass<func::FuncOp>(createSparseSpaceCollapsePass());
+ pm.addNestedPass<func::FuncOp>(createLowerSparseIterationToSCFPass());
+ }
+
pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
/*enableConvert=*/true));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index fe0e515a2d180..2be0193f0de83 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -139,6 +139,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
this->valBuffer.assign(numTensors, nullptr);
this->lvls.resize(numTensors);
this->iters.resize(numTensors);
+ this->spIterVals.resize(numTensors);
// These zeros will be overwritten below, but we need to initialize
// them to something since we'll need random-access assignment.
@@ -173,6 +174,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
lvls[tid].resize(lvlRank);
iters[tid].resize(lvlRank);
+ spIterVals[tid].resize(lvlRank);
loopHighs.assign(numLoops, nullptr);
// Slice-driven loops related initialization.
@@ -229,6 +231,57 @@ LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
void LoopEmitter::initializeLoopEmit(
OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater,
LoopEmitter::SynTensorBoundSetter synSetter) {
+
+ // For every manifest tensor, set up the values buffer.
+ for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
+ t++) {
+ // TODO: this should be done through a folding pass after switching to
+ // `sparse_tensor.iterate`-based sparsification.
+ const Value tensor = tryFoldTensors(tensors[t]);
+ const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
+ // Skip only scalars; zero-ranked tensors still need to be bufferized and
+ // (probably) filled with zeros by users.
+ if (!rtp)
+ continue;
+
+ auto stt = getSparseTensorType(tensor);
+ const auto shape = rtp.getShape();
+
+ // Perform the required bufferization. Dense inputs materialize from the
+ // input tensors. Sparse inputs use sparse primitives to obtain the values.
+ // Delegates extra output initialization to clients.
+ bool isOutput = isOutputTensor(t);
+ Type elementType = stt.getElementType();
+ if (!stt.hasEncoding()) {
+ // Non-annotated dense tensors.
+ BaseMemRefType denseTp = MemRefType::get(shape, elementType);
+
+ // TODO: if we unconditionally use fully dynamic layout here, it breaks
+ // some vectorization passes which require static stride = 1.
+ // Is it possible to run the vectorization pass after bufferization?
+ if (llvm::isa_and_nonnull<tensor::ExtractSliceOp>(tensor.getDefiningOp()))
+ denseTp = bufferization::getMemRefTypeWithFullyDynamicLayout(rtp);
+
+ Value denseVal =
+ builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
+ // Dense outputs need special handling.
+ if (isOutput && updater)
+ denseVal = updater(builder, loc, denseVal, tensor);
+
+ valBuffer[t] = denseVal;
+ } else {
+ // Annotated sparse tensors.
+ // We also need the value buffer for all-dense annotated "sparse"
+ // tensors.
+ valBuffer[t] = builder.create<ToValuesOp>(loc, tensor);
+ }
+ }
+
+ // The sparse iterator values will only be available after the loop is
+ // constructed.
+ if (emitStrategy == SparseEmitStrategy::kSparseIterator)
+ return;
+
// For every synthetic tensor, set the high bound by calling the callback.
if (synSetter) {
TensorId synId = getSynTensorId();
@@ -241,7 +294,6 @@ void LoopEmitter::initializeLoopEmit(
}
// For every manifest tensor:
- // * get the values buffer.
// * For every level:
// * get the positions and coordinates buffers
// * get/compute the level-size, which is also used as the upper-bound
@@ -256,12 +308,9 @@ void LoopEmitter::initializeLoopEmit(
// Skip only scalars; zero-ranked tensors still need to be bufferized and
// (probably) filled with zeros by users.
continue;
- // FIXME: the definition of `lvlRank` looks more like a dim-rank;
- // but the variable is used as a level everywhere below, which
- // suggests there may be some dim/lvl confusion going on here.
+
auto stt = getSparseTensorType(tensor);
const Level lvlRank = stt.getLvlRank();
- const auto shape = rtp.getShape();
// Scan all levels of current tensor.
for (Level l = 0; l < lvlRank; l++) {
@@ -273,36 +322,6 @@ void LoopEmitter::initializeLoopEmit(
auto it = makeLevelIterator(builder, loc, t, l);
iters[t][l].emplace_back(std::move(it));
}
-
- // Perform the required bufferization. Dense inputs materialize
- // from the input tensors. Sparse inputs use sparse primitives to obtain the
- // values.
- // Delegates extra output initialization to clients.
- bool isOutput = isOutputTensor(t);
- Type elementType = stt.getElementType();
- if (!stt.hasEncoding()) {
- // Non-annotated dense tensors.
- BaseMemRefType denseTp = MemRefType::get(shape, elementType);
-
- // TODO: if we unconditionally use fully dynamic layout here, it breaks
- // some vectorization passes which requires static stride = 1.
- // Is it possible to call vectorization pass after bufferization?
- if (llvm::isa_and_nonnull<tensor::ExtractSliceOp>(tensor.getDefiningOp()))
- denseTp = bufferization::getMemRefTypeWithFullyDynamicLayout(rtp);
-
- Value denseVal =
- builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
- // Dense outputs need special handling.
- if (isOutput && updater)
- denseVal = updater(builder, loc, denseVal, tensor);
-
- valBuffer[t] = denseVal;
- } else {
- // Annotated sparse tensors.
- // We also need the value buffer for all-dense annotated "sparse"
- // tensors.
- valBuffer[t] = builder.create<ToValuesOp>(loc, tensor);
- }
// NOTE: we can also prepare for 0 lvl here in advance, this will hoist
// some loop preparation from tensor iteration, but will also (undesirably)
// hoist the code outside if-conditions.
@@ -396,11 +415,13 @@ void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
ArrayRef<TensorLevel> tidLvls) {
// TODO: sort
assert(loopSeqStack.size() == loopStack.size());
- // Prepares for all the tensors used in the current loop sequence.
- for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
- levelReducedDep[tid][lvl]++;
- prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
+ if (emitStrategy != SparseEmitStrategy::kSparseIterator) {
+ // Prepares for all the tensors used in the current loop sequence.
+ for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
+ levelReducedDep[tid][lvl]++;
+ prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
+ }
}
// Universal Index starts from 0.
@@ -598,6 +619,31 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
MutableArrayRef<Value> reduc, bool tryParallel, bool needsUniv) {
+ // TODO: handle coiteration with sparse iterator.
+ if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
+ assert(tidLvls.size() == 1);
+ auto [tid, lvl] = unpackTensorLevel(tidLvls.front());
+ Value t = tensors[tid];
+
+ // Extract and iterate over the iteration space.
+ ExtractIterSpaceOp extractSpaceOp =
+ lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
+ : builder.create<ExtractIterSpaceOp>(
+ loc, t, spIterVals[tid][lvl - 1], lvl);
+
+ IterateOp iterOp = builder.create<IterateOp>(
+ loc, extractSpaceOp.getExtractedSpace(), reduc);
+ spIterVals[tid][lvl] = iterOp.getIterator();
+
+ // Update the reduction variables.
+ llvm::copy(iterOp.getRegionIterArgs(), reduc.begin());
+ // Set the insertion point to the start of the loop body.
+ builder.setInsertionPointToStart(iterOp.getBody());
+ loopStack.emplace_back(tidLvls, iterOp, builder.getInsertionBlock(),
+ iterOp.getIterator(), loopTag);
+ return iterOp;
+ }
+
// TODO: support multiple return on parallel for?
tryParallel = tryParallel && reduc.size() <= 1;
@@ -685,6 +731,16 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
void LoopEmitter::exitFo...
[truncated]
``````````
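As an end-to-end illustration: the new integration test name (iterator-based-sqsum.mlir) suggests a squared-sum reduction kernel. A hypothetical reconstruction of such a kernel is sketched below; the encoding, shape, and names are assumptions, not the actual test contents. With the `getAllTidLvlsInLatPoints` change above, a single-condition loop over a non-unique level can now also be emitted as a simple loop, so a kernel like this can be compiled via something like `mlir-opt --sparsification="sparse-emit-strategy=sparse-iterator"`:

```mlir
#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>

#trait_sqsum = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a (input)
    affine_map<(i) -> ()>    // x (scalar output)
  ],
  iterator_types = ["reduction"]
}

// x += a(i) * a(i), with a stored sparsely.
func.func @sqsum(%a: tensor<1024xi32, #SparseVector>,
                 %x: tensor<i32>) -> tensor<i32> {
  %0 = linalg.generic #trait_sqsum
      ins(%a : tensor<1024xi32, #SparseVector>)
      outs(%x : tensor<i32>) {
    ^bb0(%va: i32, %vx: i32):
      %sq = arith.muli %va, %va : i32
      %s = arith.addi %vx, %sq : i32
      linalg.yield %s : i32
  } -> tensor<i32>
  return %0 : tensor<i32>
}
```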
https://github.com/llvm/llvm-project/pull/95858