[Mlir-commits] [mlir] 781eabe - [mlir][sparse] refactoring loop emitter into its own files.
Peiming Liu
llvmlistbot at llvm.org
Tue Dec 27 11:12:11 PST 2022
Author: Peiming Liu
Date: 2022-12-27T19:12:05Z
New Revision: 781eabeb40b8e47e3a46b0b927784e63f0aad9ab
URL: https://github.com/llvm/llvm-project/commit/781eabeb40b8e47e3a46b0b927784e63f0aad9ab
DIFF: https://github.com/llvm/llvm-project/commit/781eabeb40b8e47e3a46b0b927784e63f0aad9ab.diff
LOG: [mlir][sparse] refactoring loop emitter into its own files.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D140701
Added:
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Removed:
################################################################################
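
The class being moved is the driver that sparsification uses to build (co-)iteration loops over sparse tensors. For orientation before the diff, here is a minimal C++ sketch of the driving pattern under the new, shorter name, adapted from the usage example in the class comment that this commit removes from CodegenUtils.h; `builder`, `rewriter`, `loc`, `t1`, and `t2` are placeholders, and loop-sequence bookkeeping (enterNewLoopSeq/exitCurrentLoopSeq) as well as reduction arguments are omitted for brevity.

  // A minimal sketch (not from this commit); assumes the usual MLIR includes,
  // "LoopEmitter.h", and using-declarations for mlir and mlir::sparse_tensor.
  static void emitExampleLoops(OpBuilder &builder, RewriterBase &rewriter,
                               Location loc, Value t1, Value t2) {
    // Emits roughly:
    //   for i in T1_0 { for j in T2_0 { for k in T1_1 {} for k in T2_1 {} } }
    SmallVector<Value> tensors = {t1, t2};        // tensor ids: t1 -> 0, t2 -> 1
    LoopEmitter loopEmitter(tensors);
    loopEmitter.initializeLoopEmit(builder, loc); // create pointer/index/value buffers
    loopEmitter.enterLoopOverTensorAtDim(builder, loc, {0}, {0}); // i over T1, dim 0
    loopEmitter.enterLoopOverTensorAtDim(builder, loc, {1}, {0}); // j over T2, dim 0
    loopEmitter.enterLoopOverTensorAtDim(builder, loc, {0}, {1}); // k over T1, dim 1
    loopEmitter.exitCurrentLoop(rewriter, loc);                   // exit k
    loopEmitter.enterLoopOverTensorAtDim(builder, loc, {1}, {1}); // k over T2, dim 1
    loopEmitter.exitCurrentLoop(rewriter, loc);                   // exit k
    loopEmitter.exitCurrentLoop(rewriter, loc);                   // exit j
    loopEmitter.exitCurrentLoop(rewriter, loc);                   // exit i
  }
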
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 410bf343b8fc..8107f2472537 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
BufferizableOpInterfaceImpl.cpp
CodegenEnv.cpp
CodegenUtils.cpp
+ LoopEmitter.cpp
SparseBufferRewriting.cpp
SparseStorageSpecifierToLLVM.cpp
SparseTensorCodegen.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 13109137c007..60d356f3dc62 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -24,8 +24,7 @@ CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
expFilled(), expAdded(), expCount(), redVal(), redExp(-1u),
redCustom(-1u) {}
-void CodegenEnv::startEmit(OpOperand *so, unsigned lv,
- SparseTensorLoopEmitter *le) {
+void CodegenEnv::startEmit(OpOperand *so, unsigned lv, LoopEmitter *le) {
assert(sparseOut == nullptr && loopEmitter == nullptr &&
insChain == nullptr && "must only start emitting once");
sparseOut = so;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index 47ce70c0637c..cb483aa97b15 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -14,6 +14,7 @@
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENENV_H_
#include "CodegenUtils.h"
+#include "LoopEmitter.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -45,9 +46,9 @@ class CodegenEnv {
linalg::GenericOp op() const { return linalgOp; }
const SparsificationOptions &options() const { return sparseOptions; }
Merger &merger() { return latticeMerger; }
- SparseTensorLoopEmitter *emitter() { return loopEmitter; }
+ LoopEmitter *emitter() { return loopEmitter; }
- void startEmit(OpOperand *so, unsigned lv, SparseTensorLoopEmitter *le);
+ void startEmit(OpOperand *so, unsigned lv, LoopEmitter *le);
/// Generates loop boundary statements (entering/exiting loops). The function
/// passes and updates the passed-in parameters.
@@ -135,7 +136,7 @@ class CodegenEnv {
// Loop emitter helper class (keep reference in scope!).
// TODO: move emitter constructor up in time?
- SparseTensorLoopEmitter *loopEmitter;
+ LoopEmitter *loopEmitter;
// Topological sort.
std::vector<unsigned> topSort;
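
From CodegenEnv's perspective the change is purely a type rename: the environment still stores the pointer handed to startEmit and exposes it through emitter(). Below is a hedged sketch of a call site after the rename; the real wiring lives in Sparsification.cpp, which this commit also updates but which is not shown here, and `env`, `tensors`, `sparseOut`, `lv`, and `topSort` are illustrative names.

  // Construct the renamed emitter and register it with the codegen environment.
  LoopEmitter loopEmitter(tensors, /*loopTag=*/nullptr, /*hasOutput=*/true,
                          /*isSparseOut=*/false, topSort);
  env.startEmit(sparseOut, lv, &loopEmitter);       // parameter type was SparseTensorLoopEmitter *
  env.emitter()->initializeLoopEmit(builder, loc);  // same API, shorter type name
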
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index e3ab5ce1d040..83df60ed097f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -21,25 +21,6 @@
using namespace mlir;
using namespace mlir::sparse_tensor;
-/// Generates a pointer/index load from the sparse storage scheme. Narrower
-/// data types need to be zero extended before casting the value into the
-/// index type used for looping and indexing.
-static Value genIndexLoad(OpBuilder &builder, Location loc, Value ptr,
- Value s) {
- // For the scalar case, we simply zero extend narrower indices into 64-bit
- // values before casting to index without a performance penalty. Here too,
- // however, indices that already are 64-bit, in theory, cannot express the
- // full range as explained above.
- Value load = builder.create<memref::LoadOp>(loc, ptr, s);
- if (!load.getType().isa<IndexType>()) {
- if (load.getType().getIntOrFloatBitWidth() < 64)
- load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
- load =
- builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
- }
- return load;
-}
-
/// If the tensor is a sparse constant, generates and returns the pair of
/// the constants for the indices and the values.
static Optional<std::pair<Value, Value>>
@@ -90,652 +71,6 @@ static Value genIndexAndValueForDense(OpBuilder &builder, Location loc,
return val;
}
-//===----------------------------------------------------------------------===//
-// Sparse tensor loop emitter class implementations
-//===----------------------------------------------------------------------===//
-
-SparseTensorLoopEmitter::SparseTensorLoopEmitter(ValueRange tensors,
- StringAttr loopTag,
- bool hasOutput,
- bool isSparseOut,
- ArrayRef<unsigned> topSort) {
- initialize(tensors, loopTag, hasOutput, isSparseOut, topSort);
-}
-
-void SparseTensorLoopEmitter::initialize(ValueRange tensors, StringAttr loopTag,
- bool hasOutput, bool isSparseOut,
- ArrayRef<unsigned> topSort) {
- // First initializes fields.
- this->loopTag = loopTag;
- this->hasOutput = hasOutput;
- this->isSparseOut = isSparseOut;
- this->tensors.assign(tensors.begin(), tensors.end());
- this->dimTypes.assign(tensors.size(), std::vector<DimLevelType>());
- this->pidxs.assign(tensors.size(), std::vector<Value>());
- this->coord.assign(tensors.size(), std::vector<Value>());
- this->highs.assign(tensors.size(), std::vector<Value>());
- this->ptrBuffer.assign(tensors.size(), std::vector<Value>());
- this->idxBuffer.assign(tensors.size(), std::vector<Value>());
- this->valBuffer.assign(tensors.size(), nullptr);
- this->loopStack.reserve(topSort.size());
- this->sparsiferLoopLvlMap.assign(topSort.size(), 0);
-
- for (size_t tid = 0, e = tensors.size(); tid < e; tid++) {
- auto t = tensors[tid];
- // A scalar or 0-dimension tensor.
- if (isZeroRankedTensorOrScalar(t.getType()))
- continue;
- auto rtp = t.getType().cast<RankedTensorType>();
- auto rank = static_cast<size_t>(rtp.getRank());
- auto enc = getSparseTensorEncoding(rtp);
- // We always treat sparse output tensor as dense so that we always iterate
- // it based on dim size.
- if (enc && !(isOutputTensor(tid) && isSparseOut))
- for (auto dimTp : enc.getDimLevelType())
- dimTypes[tid].push_back(dimTp);
- else
- dimTypes[tid].assign(rank, DimLevelType::Dense);
-
- // Initialize using empty value.
- pidxs[tid].assign(rank, Value());
- coord[tid].assign(rank, Value());
- highs[tid].assign(rank, Value());
- ptrBuffer[tid].assign(rank, Value());
- idxBuffer[tid].assign(rank, Value());
- }
-
- // FIXME: This map should be maintained outside loop emitter.
- for (unsigned i = 0, e = topSort.size(); i < e; i++) {
- // This is an inverse map of the topologically sorted loop index from
- // sparsifier. This is needed to map the AffineDimExpr back to the loopStack
- // index used in loop emitter.
- sparsiferLoopLvlMap[topSort[i]] = i;
- }
-}
-
-void SparseTensorLoopEmitter::initializeLoopEmit(
- OpBuilder &builder, Location loc,
- SparseTensorLoopEmitter::OutputUpdater updater) {
- // For every tensor, find lower and upper bound on dimensions, set the
- // same bounds on loop indices, and obtain dense or sparse buffer(s).
- for (size_t t = 0, e = tensors.size(); t < e; t++) {
- auto tensor = tensors[t];
- auto rtp = tensor.getType().dyn_cast<RankedTensorType>();
- if (!rtp)
- // Skips only scalars; zero-ranked tensors still need to be bufferized and
- // (probably) filled with zeros by users.
- continue;
- auto rank = rtp.getRank();
- auto shape = rtp.getShape();
- auto enc = getSparseTensorEncoding(rtp);
- auto dynShape = {ShapedType::kDynamic};
- // Scan all dimensions of current tensor.
- for (int64_t d = 0; d < rank; d++) {
- // This should be called only once at beginning.
- assert(!ptrBuffer[t][d] && !idxBuffer[t][d] && !highs[t][d]);
- // Handle sparse storage schemes.
- if (isCompressedDLT(dimTypes[t][d])) {
- auto ptrTp =
- MemRefType::get(dynShape, getPointerOverheadType(builder, enc));
- auto indTp =
- MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
- auto dim = builder.getIndexAttr(d);
- // Generate sparse primitives to obtain pointers and indices.
- ptrBuffer[t][d] = builder.create<ToPointersOp>(loc, ptrTp, tensor, dim);
- idxBuffer[t][d] = builder.create<ToIndicesOp>(loc, indTp, tensor, dim);
- } else if (isSingletonDLT(dimTypes[t][d])) {
- // Singleton dimension, fetch indices.
- auto indTp =
- MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
- auto dim = builder.getIndexAttr(d);
- idxBuffer[t][d] = builder.create<ToIndicesOp>(loc, indTp, tensor, dim);
- } else {
- // Dense dimension, nothing to fetch.
- assert(isDenseDLT(dimTypes[t][d]));
- }
-
- // Find upper bound in current dimension.
- unsigned p = toOrigDim(enc, d);
- Value up = mlir::linalg::createOrFoldDimOp(builder, loc, tensor, p);
- highs[t][d] = up;
- }
-
- // Perform the required bufferization. Dense inputs materialize
- // from the input tensors. Sparse inputs use sparse primitives to obtain the
- // values.
- // Delegates extra output initialization to clients.
- bool isOutput = isOutputTensor(t);
- Type elementType = rtp.getElementType();
- if (!enc) {
- // Non-annotated dense tensors.
- auto denseTp = MemRefType::get(shape, elementType);
- Value denseVal =
- builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
- // Dense outputs need special handling.
- if (isOutput && updater)
- denseVal = updater(builder, loc, denseVal, tensor);
-
- valBuffer[t] = denseVal;
- } else {
- // Annotated sparse tensors.
- // We also need the value buffer for annotated all dense `sparse` tensor.
- auto dynShape = {ShapedType::kDynamic};
- auto sparseTp = MemRefType::get(dynShape, elementType);
- valBuffer[t] = builder.create<ToValuesOp>(loc, sparseTp, tensor);
- }
- // NOTE: we can also prepare for 0 dim here in advance, this will hoist
- // some loop preparation from tensor iteration, but will also (undesirably)
- // hoist the code outside if conditions.
- }
-}
-
-void SparseTensorLoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
- ArrayRef<size_t> tids,
- ArrayRef<size_t> dims) {
- // Universal Index start from 0
- assert(loopSeqStack.size() == loopStack.size());
- // Universal index starts from 0
- loopSeqStack.emplace_back(constantIndex(builder, loc, 0));
- // Prepares for all the tensors used in the current loop sequence.
- for (auto [tid, dim] : llvm::zip(tids, dims))
- prepareLoopOverTensorAtDim(builder, loc, tid, dim);
-}
-
-Value SparseTensorLoopEmitter::genAffine(OpBuilder &builder, AffineExpr a,
- Location loc) {
- switch (a.getKind()) {
- case AffineExprKind::DimId: {
- unsigned idx = a.cast<AffineDimExpr>().getPosition();
- return loopStack[sparsiferLoopLvlMap[idx]].iv;
- }
- case AffineExprKind::Add: {
- auto binOp = a.cast<AffineBinaryOpExpr>();
- return builder.create<arith::AddIOp>(
- loc, genAffine(builder, binOp.getLHS(), loc),
- genAffine(builder, binOp.getRHS(), loc));
- }
- case AffineExprKind::Mul: {
- auto binOp = a.cast<AffineBinaryOpExpr>();
- return builder.create<arith::MulIOp>(
- loc, genAffine(builder, binOp.getLHS(), loc),
- genAffine(builder, binOp.getRHS(), loc));
- }
- case AffineExprKind::Constant: {
- int64_t c = a.cast<AffineConstantExpr>().getValue();
- return constantIndex(builder, loc, c);
- }
- default:
- llvm_unreachable("unexpected affine subscript");
- }
-}
-
-Operation *SparseTensorLoopEmitter::enterLoopOverTensorAtDim(
- OpBuilder &builder, Location loc, ArrayRef<size_t> tids,
- ArrayRef<size_t> dims, MutableArrayRef<Value> reduc, bool isParallel) {
- // TODO: support multiple return on parallel for?
- assert(!isParallel || reduc.size() <= 1);
-
- bool isSparseInput = false;
- size_t tid = tids.front(), dim = dims.front();
- for (auto [t, d] : llvm::zip(tids, dims)) {
- assert(dimTypes[t].size() > d); // Must be a valid tid, dim pair
- assert(!coord[t][d]); // We cannot re-enter the same level
- auto dimType = dimTypes[t][d];
- // Must be a recognizable DLT.
- assert(isDenseDLT(dimType) || isCompressedDLT(dimType) ||
- isSingletonDLT(dimType));
- bool isSparse = isCompressedDLT(dimType) || isSingletonDLT(dimType);
- // We can at most have one sparse input, otherwise, a while loop is required
- // to co-iterate multiple sparse tensors.
- assert(!isSparseInput || !isSparse);
- if (isSparse) {
- tid = t;
- dim = d;
- }
- isSparseInput = isSparseInput || isSparse;
- }
-
- Value step = constantIndex(builder, loc, 1);
- Value lo = isSparseInput ? pidxs[tid][dim] // current offset
- : loopSeqStack.back(); // universal index
- Value hi = highs[tid][dim];
- Operation *loop = nullptr;
- Value iv;
- if (isParallel) {
- scf::ParallelOp parOp =
- builder.create<scf::ParallelOp>(loc, lo, hi, step, reduc);
- builder.setInsertionPointToStart(parOp.getBody());
- assert(parOp.getNumReductions() == reduc.size());
- iv = parOp.getInductionVars()[0];
-
- // In-place update on the reduction variable vector.
- // Note that the init vals are not the actual reduction variables but are
- // instead used as a `special handle` to (temporarily) represent them. The
- // expression on init vals will be moved into scf.reduce and replaced with
- // the block arguments when exiting the loop (see exitForLoop). This is
- // needed as we cannot build the actual reduction block and get the actual
- // reduction variable before users fill the parallel loop body.
- for (int i = 0, e = reduc.size(); i < e; i++)
- reduc[i] = parOp.getInitVals()[i];
- loop = parOp;
- } else {
- scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
- builder.setInsertionPointToStart(forOp.getBody());
- iv = forOp.getInductionVar();
-
- // In-place update on the reduction variable vector.
- assert(forOp.getNumRegionIterArgs() == reduc.size());
- for (int i = 0, e = reduc.size(); i < e; i++)
- reduc[i] = forOp.getRegionIterArg(i);
- loop = forOp;
- }
- assert(loop && iv);
-
- if (isSparseInput) {
- pidxs[tid][dim] = iv;
- // Generating a load on the indices array yields the coordinate.
- Value ptr = idxBuffer[tid][dim];
- coord[tid][dim] = genIndexLoad(builder, loc, ptr, iv);
- } else {
- // Dense tensor, the coordinate is the induction variable.
- coord[tid][dim] = iv;
- }
- // NOTE: we can also prepare for next dim here in advance
- // Push the loop into stack
- loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), loop,
- coord[tid][dim], loopTag);
- // Emit extra locals.
- emitExtraLocalsForTensorsAtDenseDims(builder, loc, tids, dims);
-
- return loop;
-}
-
-Operation *SparseTensorLoopEmitter::enterFilterLoopOverTensorAtDim(
- OpBuilder &builder, Location loc, size_t tid, size_t dim, AffineExpr affine,
- MutableArrayRef<Value> reduc) {
- assert(!affine.isa<AffineDimExpr>() && !isDenseDLT(dimTypes[tid][dim]));
- assert(dimTypes[tid].size() > dim);
- // We can not re-enter the same level.
- assert(!coord[tid][dim]);
-
- Value step = constantIndex(builder, loc, 1);
-
- Value lo = pidxs[tid][dim];
- Value hi = highs[tid][dim];
-
- // TODO: We should instead use a whileOp for filter loop to allow early
- // break when exceeding (for ordered dimensions).
- // TODO: There are many other potential opportunities that we might apply in
- // the future. E.g., we could use binary search to locate the pointer index.
- scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
-
- // In-place update on the reduction variable vector.
- assert(forOp.getNumRegionIterArgs() == reduc.size());
- for (int i = 0, e = reduc.size(); i < e; i++)
- reduc[i] = forOp.getRegionIterArg(i);
-
- builder.setInsertionPointToStart(forOp.getBody());
- Value iv = forOp.getInductionVar();
-
- pidxs[tid][dim] = iv;
- // Generating a load on the indices array yields the coordinate.
- Value ptr = idxBuffer[tid][dim];
- coord[tid][dim] = genIndexLoad(builder, loc, ptr, iv);
-
- // Generate an if condition to filter out indices that are not equal to the
- // result of the affine expression.
- Value expected = genAffine(builder, affine, loc);
- auto pred = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
- coord[tid][dim], expected);
- SmallVector<Type> types;
- for (Value red : reduc) {
- types.push_back(red.getType());
- }
-
- bool hasReduc = !types.empty();
- scf::IfOp ifOp =
- builder.create<scf::IfOp>(loc, types, pred, /*else*/ hasReduc);
- if (hasReduc) {
- // scf.for (a) -> v
- // %s = scf.if (a) -> v
- // user-generated code.
- // else
- // yield a
- // yield %s
- builder.create<scf::YieldOp>(loc, ifOp.getResults());
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- // On mismatch.
- builder.create<scf::YieldOp>(loc, reduc);
- }
- // Set the insert point to matched branch.
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-
- // NOTE: we can also prepare for next dim here in advance
- // Push the loop into stack
- loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), forOp,
- coord[tid][dim], nullptr);
- return forOp;
-}
-
-void SparseTensorLoopEmitter::genDenseAffineAddressAtCurLevel(
- OpBuilder &builder, Location loc, size_t tid, size_t dim,
- AffineExpr affine) {
- Value affineV = genAffine(builder, affine, loc);
- pidxs[tid][dim] = genAddress(builder, loc, tid, dim, affineV);
-}
-
-Operation *SparseTensorLoopEmitter::enterCoIterationOverTensorsAtDims(
- OpBuilder &builder, Location loc, ArrayRef<size_t> tids,
- ArrayRef<size_t> dims, bool needsUniv, MutableArrayRef<Value> reduc) {
- assert(tids.size() == dims.size());
- SmallVector<Type> types;
- SmallVector<Value> operands;
- // Construct the while-loop with a parameter for each index.
- Type indexType = builder.getIndexType();
- for (auto [tid, dim] : llvm::zip(tids, dims)) {
- if (isCompressedDLT(dimTypes[tid][dim]) ||
- isSingletonDLT(dimTypes[tid][dim])) {
- assert(pidxs[tid][dim]);
- types.push_back(indexType);
- operands.push_back(pidxs[tid][dim]);
- }
- }
- // The position where user-supplied reduction variable starts.
- for (Value rec : reduc) {
- types.push_back(rec.getType());
- operands.push_back(rec);
- }
- if (needsUniv) {
- types.push_back(indexType);
- // Update universal index.
- operands.push_back(loopSeqStack.back());
- }
- assert(types.size() == operands.size());
- scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
-
- SmallVector<Location> locs(types.size(), loc);
- Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
- Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
-
- // Build the "before" region, which effectively consists
- // of a conjunction of "i < upper" tests on all induction.
- builder.setInsertionPointToStart(&whileOp.getBefore().front());
- Value cond;
- unsigned o = 0;
- for (auto [tid, dim] : llvm::zip(tids, dims)) {
- if (isCompressedDLT(dimTypes[tid][dim]) ||
- isSingletonDLT(dimTypes[tid][dim])) {
- Value op1 = before->getArgument(o);
- Value op2 = highs[tid][dim];
- Value opc = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
- op1, op2);
- cond = cond ? builder.create<arith::AndIOp>(loc, cond, opc) : opc;
- // Update
- pidxs[tid][dim] = after->getArgument(o++);
- }
- }
- builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
-
- // Generates while body.
- builder.setInsertionPointToStart(&whileOp.getAfter().front());
- Value min;
- for (auto [tid, dim] : llvm::zip(tids, dims)) {
- // Prepares for next level.
- if (isCompressedDLT(dimTypes[tid][dim]) ||
- isSingletonDLT(dimTypes[tid][dim])) {
- Value ptr = idxBuffer[tid][dim];
- Value s = pidxs[tid][dim];
- Value load = genIndexLoad(builder, loc, ptr, s);
- coord[tid][dim] = load;
- if (!needsUniv) {
- if (min) {
- Value cmp = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, load, min);
- min = builder.create<arith::SelectOp>(loc, cmp, load, min);
- } else {
- min = load;
- }
- }
- }
- }
-
- if (needsUniv) {
- assert(!min);
- // Otherwise, universal index is the minimal pidx.
- min = after->getArguments().back();
- }
-
- // Sets up the loop stack.
- loopStack.emplace_back(tids, dims, whileOp, min, loopTag);
- assert(loopStack.size() == loopSeqStack.size());
-
- // Emits extra locals
- emitExtraLocalsForTensorsAtDenseDims(builder, loc, tids, dims);
-
- // Updates reduction variables
- assert(after->getNumArguments() == o + reduc.size() + (needsUniv ? 1 : 0));
- // In-place update on reduction variable.
- for (unsigned i = 0, e = reduc.size(); i < e; i++)
- reduc[i] = after->getArgument(o + i);
-
- return whileOp;
-}
-
-void SparseTensorLoopEmitter::prepareLoopOverTensorAtDim(OpBuilder &builder,
- Location loc,
- size_t tid,
- size_t dim) {
- assert(dimTypes[tid].size() > dim);
- auto dimType = dimTypes[tid][dim];
-
- if (isDenseDLT(dimType))
- return;
-
- // Either the first dimension, or the previous dimension has been set.
- assert(dim == 0 || pidxs[tid][dim - 1]);
- Value c0 = constantIndex(builder, loc, 0);
- Value c1 = constantIndex(builder, loc, 1);
- if (isCompressedDLT(dimType)) {
- Value ptr = ptrBuffer[tid][dim];
-
- Value pLo = dim == 0 ? c0 : pidxs[tid][dim - 1];
- pidxs[tid][dim] = genIndexLoad(builder, loc, ptr, pLo);
-
- Value pHi = builder.create<arith::AddIOp>(loc, pLo, c1);
- highs[tid][dim] = genIndexLoad(builder, loc, ptr, pHi);
- return;
- }
- if (isSingletonDLT(dimType)) {
- Value pLo = dim == 0 ? c0 : pidxs[tid][dim - 1];
- Value pHi = builder.create<arith::AddIOp>(loc, pLo, c1);
-
- pidxs[tid][dim] = pLo;
- highs[tid][dim] = pHi;
- return;
- }
-
- llvm_unreachable("Unrecognizable dimesion type!");
-}
-
-void SparseTensorLoopEmitter::emitExtraLocalsForTensorsAtDenseDims(
- OpBuilder &builder, Location loc, ArrayRef<size_t> tids,
- ArrayRef<size_t> dims) {
- // Initialize dense positions. Note that we generate dense indices of the
- // output tensor unconditionally, since they may not appear in the lattice,
- // but may be needed for linearized codegen.
- for (auto [tid, dim] : llvm::zip(tids, dims)) {
- if (isDenseDLT(dimTypes[tid][dim])) {
- auto enc = getSparseTensorEncoding(tensors[tid].getType());
- if (enc && !isSparseOutput(tid)) {
- bool validPidx = dim == 0 || pidxs[tid][dim - 1];
- if (!validPidx) {
- // We might not find the pidx for the sparse output tensor as it is
- // unconditionally required by the sparsification.
- assert(isOutputTensor(tid));
- continue;
- }
- pidxs[tid][dim] =
- genAddress(builder, loc, tid, dim, loopStack.back().iv);
- // NOTE: we can also prepare for next dim here in advance
- }
- }
- }
-}
-
-void SparseTensorLoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
- MutableArrayRef<Value> reduc) {
- LoopLevelInfo &loopInfo = loopStack.back();
- auto &dims = loopStack.back().dims;
- auto &tids = loopStack.back().tids;
- auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop);
- if (forOp) {
- if (!reduc.empty()) {
- assert(reduc.size() == forOp.getNumResults());
- rewriter.create<scf::YieldOp>(loc, reduc);
- }
- // Exit the loop.
- rewriter.setInsertionPointAfter(forOp);
- // In-place update reduction variables.
- for (unsigned i = 0, e = forOp.getResults().size(); i < e; i++)
- reduc[i] = forOp.getResult(i);
- } else {
- auto parOp = llvm::cast<scf::ParallelOp>(loopInfo.loop);
- if (!reduc.empty()) {
- assert(reduc.size() == parOp.getInitVals().size() && reduc.size() == 1);
- Operation *redExp = reduc.front().getDefiningOp();
- // Reduction expression should have no use.
- assert(redExp->getUses().empty());
- // This must be a binary operation.
- // NOTE: It is the user's responsibility to ensure the operation is
- // commutative.
- assert(redExp->getNumOperands() == 2 && redExp->getNumResults() == 1);
-
- Value redVal = parOp.getInitVals().front();
- Value curVal;
- if (redExp->getOperand(0) == redVal)
- curVal = redExp->getOperand(1);
- else if (redExp->getOperand(1) == redVal)
- curVal = redExp->getOperand(0);
- // One of the operands must be the init value (which is also the
- // previous reduction value).
- assert(curVal);
- // The reduction expression should be the only user of the reduction val
- // inside the parallel for.
- unsigned numUsers = 0;
- for (Operation *op : redVal.getUsers()) {
- if (op->getParentOp() == parOp)
- numUsers++;
- }
- assert(numUsers == 1);
- (void)numUsers; // to silence unused variable warning in release build
-
- rewriter.setInsertionPointAfter(redExp);
- auto redOp = rewriter.create<scf::ReduceOp>(loc, curVal);
- // Attach to the reduction op.
- Block *redBlock = &redOp.getRegion().getBlocks().front();
- rewriter.setInsertionPointToEnd(redBlock);
- Operation *newRed = rewriter.clone(*redExp);
- // Replaces arguments of the reduction expression by using the block
- // arguments from scf.reduce.
- rewriter.updateRootInPlace(
- newRed, [&]() { newRed->setOperands(redBlock->getArguments()); });
- // Erases the out-dated reduction expression.
- rewriter.eraseOp(redExp);
- rewriter.setInsertionPointToEnd(redBlock);
- rewriter.create<scf::ReduceReturnOp>(loc, newRed->getResult(0));
- }
- rewriter.setInsertionPointAfter(parOp);
- // In-place update reduction variables.
- for (unsigned i = 0, e = parOp.getResults().size(); i < e; i++)
- reduc[i] = parOp.getResult(i);
- }
-
- // Finished iterating a tensor; clean up.
- // We only do the clean-up for for loops, as while loops do not necessarily
- // finish the iteration over a sparse tensor.
- for (auto [tid, dim] : llvm::zip(tids, dims)) {
- // Reset to null.
- coord[tid][dim] = Value();
- pidxs[tid][dim] = Value();
- // Dense dimension, high is fixed.
- if (!isDenseDLT(dimTypes[tid][dim]))
- highs[tid][dim] = Value();
- }
-}
-
-void SparseTensorLoopEmitter::exitCoIterationLoop(
- OpBuilder &builder, Location loc, MutableArrayRef<Value> reduc) {
- auto whileOp = llvm::cast<scf::WhileOp>(loopStack.back().loop);
- auto &dims = loopStack.back().dims;
- auto &tids = loopStack.back().tids;
- Value iv = loopStack.back().iv;
- // Generates the while-loop induction variable update at the end.
- builder.setInsertionPointToEnd(&whileOp.getAfter().front());
- // Finalize the induction. Note that the induction could be performed
- // in the individual if-branches to avoid re-evaluating the conditions.
- // However, that would result in a rather elaborate forest of yield
- // instructions during code generation. Moreover, performing the induction
- // after the if-statements more closely resembles code generated by TACO.
- unsigned o = 0;
- SmallVector<Value> operands;
- Value one = constantIndex(builder, loc, 1);
- for (auto [tid, dim] : llvm::zip(tids, dims)) {
- if (isCompressedDLT(dimTypes[tid][dim]) ||
- isSingletonDLT(dimTypes[tid][dim])) {
- Value op1 = coord[tid][dim];
- Value op3 = pidxs[tid][dim];
- Value cmp =
- builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, op1, iv);
- Value add = builder.create<arith::AddIOp>(loc, op3, one);
- operands.push_back(builder.create<arith::SelectOp>(loc, cmp, add, op3));
- // Following loops continue iteration from the break point of the
- // current while loop.
- pidxs[tid][dim] = whileOp->getResult(o++);
- // The coordinates are invalid now.
- coord[tid][dim] = nullptr;
- // highs remains unchanged.
- }
- }
-
- // Reduction value from users.
- for (auto &i : reduc) {
- operands.push_back(i);
- // In place update reduction variable.
- i = whileOp->getResult(o++);
- }
-
- // An (optional) universal index.
- if (operands.size() < whileOp.getNumResults()) {
- assert(operands.size() + 1 == whileOp.getNumResults());
- // The last one is the universal index.
- operands.push_back(builder.create<arith::AddIOp>(loc, iv, one));
- // Update the loop starting point of the current loop sequence.
- loopSeqStack.back() = whileOp->getResult(o++);
- }
-
- assert(o == operands.size());
- builder.create<scf::YieldOp>(loc, operands);
- builder.setInsertionPointAfter(whileOp);
-}
-
-void SparseTensorLoopEmitter::exitCurrentLoop(RewriterBase &rewriter,
- Location loc,
- MutableArrayRef<Value> reduc) {
- // Clean up the values; it would help us discover potential bugs at an
- // earlier stage (instead of silently using a wrong value).
- LoopLevelInfo &loopInfo = loopStack.back();
- assert(loopInfo.tids.size() == loopInfo.dims.size());
- SmallVector<Value> red;
- if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
- exitCoIterationLoop(rewriter, loc, reduc);
- } else {
- exitForLoop(rewriter, loc, reduc);
- }
-
- assert(loopStack.size() == loopSeqStack.size());
- loopStack.pop_back();
-}
-
//===----------------------------------------------------------------------===//
// ExecutionEngine/SparseTensorUtils helper functions.
//===----------------------------------------------------------------------===//
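
Among the code removed above is genIndexLoad, which reappears as a file-local helper in LoopEmitter.cpp further down. Its comment describes the widening rule for narrow pointer/index loads; the following plain C++ analogue is shown only to make that rule concrete (the real helper builds memref.load, arith.extui, and arith.index_cast ops rather than doing host-side arithmetic):

  #include <cstddef>
  #include <cstdint>

  // Scalar sketch of the emitted sequence for a 32-bit index/pointer buffer.
  uint64_t indexLoadScalar(const uint32_t *buf, size_t pos) {
    uint32_t narrow = buf[pos];                    // memref.load
    uint64_t wide = static_cast<uint64_t>(narrow); // arith.extui: zero extend values narrower than 64 bits
    return wide;                                   // arith.index_cast to the index type
  }
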
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 46f02141bb82..da94b27c42d6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -313,274 +313,6 @@ inline bool isZeroRankedTensorOrScalar(Type type) {
return !rtp || rtp.getRank() == 0;
}
-//===----------------------------------------------------------------------===//
-// SparseTensorLoopEmitter class, manages sparse tensors and helps to
-// generate loop structure to (co)-iterate sparse tensors.
-//
-// An example usage:
-// To generate the following loops over T1<?x?> and T2<?x?>
-//
-// for i in TENSOR_1_0 {
-// for j : TENSOR_2_0 {
-// for k : TENSOR_1_1 {}
-// for k : TENSOR_2_1 {}
-// }
-// }
-//
-// One can use
-//
-// SparseTensorLoopEmitter loopEmitter({T1, T2});
-// loopEmitter.initializeLoopEmit();
-// loopEmitter.enterLoopOverTensorAtDim(T1, 0);
-// loopEmitter.enterLoopOverTensorAtDim(T2, 0);
-// loopEmitter.enterLoopOverTensorAtDim(T1, 1);
-// loopEmitter.exitCurrentLoop();
-// loopEmitter.enterLoopOverTensorAtDim(T2, 1);
-// loopEmitter.exitCurrentLoop(); // exit k
-// loopEmitter.exitCurrentLoop(); // exit j
-// loopEmitter.exitCurrentLoop(); // exit i
-//===----------------------------------------------------------------------===//
-
-class SparseTensorLoopEmitter {
-public:
- /// Optional callback function to setup dense output tensors when
- /// initializing the loop emitter (e.g., to fill a dense output with zeros).
- using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
- Value memref, Value tensor)>;
-
- SparseTensorLoopEmitter() = default;
-
- /// Takes an array of input tensors, over which the generated loops will
- /// iterate. The index of the tensor in the array is also the tensor id
- /// (tid) used in related functions. If isSparseOut is set, the loop emitter
- /// assumes that the sparse output tensor is empty, and will always generate
- /// loops on it based on the dim sizes. An optional array can be provided
- /// (by sparsification) to indicate the loop id sequence that will be
- /// generated. It is used to establish the mapping from an AffineDimExpr to
- /// the corresponding loop index in the loop stack maintained by the
- /// loop emitter.
- void initialize(ValueRange tensors, StringAttr loopTag = nullptr,
- bool hasOutput = false, bool isSparseOut = false,
- ArrayRef<unsigned> topSort = {});
-
- explicit SparseTensorLoopEmitter(ValueRange tensors,
- StringAttr loopTag = nullptr,
- bool hasOutput = false,
- bool isSparseOut = false,
- ArrayRef<unsigned> topSort = {});
-
- /// Starts a loop emitting session by generating all the buffers needed to
- /// iterate tensors.
- void initializeLoopEmit(OpBuilder &builder, Location loc,
- OutputUpdater updater = nullptr);
-
- /// Generates a list of operations to compute the affine expression.
- Value genAffine(OpBuilder &builder, AffineExpr a, Location loc);
-
- /// Enters a new loop sequence; the loops within the same sequence start
- /// from the break points of the previous loop instead of starting over from 0.
- /// e.g.,
- /// {
- /// // loop sequence start.
- /// p0 = while(xxx)
- /// ...
- /// break p0
- ///
- /// // Starts loop from p0
- /// for (i = p0; i < end; i++)
- /// ...
- /// // loop sequence end.
- /// }
- void enterNewLoopSeq(OpBuilder &builder, Location loc, ArrayRef<size_t> tids,
- ArrayRef<size_t> dims);
-
- // Exits the current loop sequence; this will reset the universal index to 0.
- void exitCurrentLoopSeq() {
- assert(loopSeqStack.size() == loopStack.size() + 1);
- loopSeqStack.pop_back();
- }
-
- // TODO: Gets rid of `dim` in the argument list? Track the dimension we
- // are currently at internally. Then it would be enterNextDimForTensor.
- // Still need a way to specify the dim for non-annotated dense tensors though,
- // as it can be accessed out of order.
- /// Emits loop over tensor_tid_dim, it assumes that loops between
- /// tensor_tid_[0, dim - 1] have already been generated.
- /// The function will also perform in-place update on the `reduc` vector to
- /// return the reduction variable used inside the generated loop.
- Operation *enterLoopOverTensorAtDim(OpBuilder &builder, Location loc,
- ArrayRef<size_t> tids,
- ArrayRef<size_t> dims,
- MutableArrayRef<Value> reduc = {},
- bool isParallel = false);
-
- Operation *enterFilterLoopOverTensorAtDim(OpBuilder &builder, Location loc,
- size_t tid, size_t dim,
- AffineExpr affine,
- MutableArrayRef<Value> reduc = {});
-
- void genDenseAffineAddressAtCurLevel(OpBuilder &builder, Location loc,
- size_t tid, size_t dim,
- AffineExpr affine);
-
- /// Emits a co-iteration loop over a set of tensors.
- Operation *enterCoIterationOverTensorsAtDims(
- OpBuilder &builder, Location loc, ArrayRef<size_t> tids,
- ArrayRef<size_t> dims, bool needsUniv, MutableArrayRef<Value> reduc = {});
-
- void exitCurrentLoop(RewriterBase &rewriter, Location loc,
- MutableArrayRef<Value> reduc = {});
-
- /// Returns the array of coordinates for all the loops generated so far.
- void getCoordinateArray(SmallVectorImpl<Value> &coords) const {
- for (auto &l : loopStack)
- coords.push_back(l.iv);
- }
-
- /// Gets the current depth of the loop stack.
- unsigned getCurrentDepth() const { return loopStack.size(); }
-
- /// Gets loop induction variable at the given level.
- Value getLoopIV(size_t level) const {
- if (level < loopStack.size())
- return loopStack[level].iv;
- return nullptr;
- }
-
- ///
- /// Getters.
- ///
- const std::vector<std::vector<Value>> &getPidxs() const { return pidxs; };
- const std::vector<std::vector<Value>> &getCoord() const { return coord; };
- const std::vector<std::vector<Value>> &getHighs() const { return highs; };
- const std::vector<std::vector<Value>> &getPtrBuffer() const {
- return ptrBuffer;
- };
- const std::vector<std::vector<Value>> &getIdxBuffer() const {
- return idxBuffer;
- };
- const std::vector<Value> &getValBuffer() const { return valBuffer; };
-
- constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName() {
- return llvm::StringLiteral("Emitted from");
- }
-
-private:
- struct LoopLevelInfo {
- LoopLevelInfo(ArrayRef<size_t> tids, ArrayRef<size_t> dims, Operation *loop,
- Value iv, StringAttr loopTag)
- : tids(tids), dims(dims), loop(loop), iv(iv) {
- // Attaches a special tag to loops generated by the loop emitter.
- if (loopTag)
- loop->setAttr(SparseTensorLoopEmitter::getLoopEmitterLoopAttrName(),
- loopTag);
- }
- // TODO: maybe use a vector<pair> for tid and dim?
- // The set of tensors that the loop is operating on
- const llvm::SmallVector<size_t> tids;
- // The corresponding dims for the tensors
- const llvm::SmallVector<size_t> dims;
- const Operation *loop; // the loop operation
- const Value iv; // the induction variable for the loop
- };
-
- /// Linearizes address for dense dimension (i.e., p = (i * d0) + j).
- Value genAddress(OpBuilder &builder, Location loc, size_t tid, size_t dim,
- Value iv) {
- Value p = dim == 0 ? constantIndex(builder, loc, 0) : pidxs[tid][dim - 1];
- Value mul = builder.create<arith::MulIOp>(loc, highs[tid][dim], p);
- Value add = builder.create<arith::AddIOp>(loc, mul, iv);
- return add;
- }
-
- bool isOutputTensor(size_t tid) {
- return hasOutput && tid == tensors.size() - 1;
- }
-
- bool isSparseOutput(size_t tid) { return isOutputTensor(tid) && isSparseOut; }
-
- /// Sets up [lo, hi] for iterating tensor[dim]; it assumes that tensor[0,
- /// ..., dim-1] have already been set up.
- void prepareLoopOverTensorAtDim(OpBuilder &builder, Location loc, size_t tid,
- size_t dim);
-
- /// Emits extra locals, since the locals might not be in the simplified
- /// lattice points used to generate the loops, but are still required to
- /// generate expressions.
- void emitExtraLocalsForTensorsAtDenseDims(OpBuilder &builder, Location loc,
- ArrayRef<size_t> tids,
- ArrayRef<size_t> dims);
-
- /// Exits a for loop, returns the reduction results, e.g.,
- /// For sequential for loops:
- /// %ret = for () {
- /// ...
- /// %val = addi %args, %c
- /// yield %val
- /// }
- /// For parallel loops, the following generated code by users:
- /// %ret = parallel () init(%args) {
- /// ...
- /// %val = op %args, %c
- /// }
- /// will be transformed into
- /// %ret = parallel () init(%args) {
- /// ...
- /// scf.reduce(%c) bb0(%0, %1){
- /// %val = op %0, %1
- /// scf.reduce.return %val
- /// }
- /// }
- /// NOTE: only one instruction will be moved into the reduce block; the
- /// transformation will fail if multiple instructions are used to compute
- /// the reduction value. %ret is returned to the user, while %val is
- /// provided by the user (`reduc`).
- void exitForLoop(RewriterBase &rewriter, Location loc,
- MutableArrayRef<Value> reduc);
-
- /// Exits a while loop, returns the reduction results.
- void exitCoIterationLoop(OpBuilder &builder, Location loc,
- MutableArrayRef<Value> reduc);
-
- /// An optional string attribute that should be attached to the loops
- /// generated by the loop emitter; it might help subsequent passes identify
- /// loops that operate on sparse tensors more easily.
- StringAttr loopTag;
- /// Whether the loop emitter needs to treat the last tensor as the output
- /// tensor.
- bool hasOutput;
- bool isSparseOut;
- /// Input and (optional) output tensors.
- std::vector<Value> tensors;
- /// The dim type array for each tensor.
- std::vector<std::vector<DimLevelType>> dimTypes;
- /// Sparse iteration information (by tensor and dim). These arrays
- /// are updated to remain current within the current loop.
- std::vector<std::vector<Value>> pidxs;
- std::vector<std::vector<Value>> coord;
- std::vector<std::vector<Value>> highs;
- std::vector<std::vector<Value>> ptrBuffer; // to_pointers
- std::vector<std::vector<Value>> idxBuffer; // to_indices
- std::vector<Value> valBuffer; // to_value
-
- // Loop Stack, stores the information of all the nested loops that are
- // alive.
- std::vector<LoopLevelInfo> loopStack;
-
- // Loop Sequence Stack, stores the universal index for the current loop
- // sequence.
- std::vector<Value> loopSeqStack;
-
- // Maps AffineDimExpr to the index of the loop in loopStack.
- // TODO: We should probably use a callback function here to make it more
- // general.
- std::vector<unsigned> sparsiferLoopLvlMap;
-
- // TODO: not yet used, it should track the current level for each tensor
- // to help eliminate `dim` parameters from the above APIs.
- // std::vector<size_t> curLv;
-};
-
} // namespace sparse_tensor
} // namespace mlir
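
One detail worth noting in the move: genAddress, defined inline in the class above, becomes an out-of-line member in LoopEmitter.cpp below with identical arithmetic. The following plain-integer sketch of the linearization it emits uses ordinary integers standing in for mlir::Value; here `dimSize` plays the role of highs[tid][dim] and `parentPos` the role of pidxs[tid][dim - 1], or 0 at the outermost dimension.

  #include <cstddef>

  // Mirrors: mul = highs[tid][dim] * p;  add = mul + iv.
  size_t linearizedAddress(size_t parentPos, size_t dimSize, size_t iv) {
    return parentPos * dimSize + iv;
  }
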
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
new file mode 100644
index 000000000000..d77b060d931d
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -0,0 +1,691 @@
+//===- LoopEmitter.cpp ----------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "LoopEmitter.h"
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+//===----------------------------------------------------------------------===//
+// File local helper functions.
+//===----------------------------------------------------------------------===//
+
+/// Generates a pointer/index load from the sparse storage scheme. Narrower
+/// data types need to be zero extended before casting the value into the
+/// index type used for looping and indexing.
+static Value genIndexLoad(OpBuilder &builder, Location loc, Value ptr,
+ Value s) {
+ // For the scalar case, we simply zero extend narrower indices into 64-bit
+ // values before casting to index without a performance penalty. Here too,
+ // however, indices that already are 64-bit, in theory, cannot express the
+ // full range as explained above.
+ Value load = builder.create<memref::LoadOp>(loc, ptr, s);
+ if (!load.getType().isa<IndexType>()) {
+ if (load.getType().getIntOrFloatBitWidth() < 64)
+ load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
+ load =
+ builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
+ }
+ return load;
+}
+
+//===----------------------------------------------------------------------===//
+// Sparse tensor loop emitter class implementations
+//===----------------------------------------------------------------------===//
+
+Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, size_t tid,
+ size_t dim, Value iv) {
+ Value p = dim == 0 ? constantIndex(builder, loc, 0) : pidxs[tid][dim - 1];
+ Value mul = builder.create<arith::MulIOp>(loc, highs[tid][dim], p);
+ Value add = builder.create<arith::AddIOp>(loc, mul, iv);
+ return add;
+}
+
+LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
+ bool isSparseOut, ArrayRef<unsigned> topSort) {
+ initialize(tensors, loopTag, hasOutput, isSparseOut, topSort);
+}
+
+void LoopEmitter::initialize(ValueRange tensors, StringAttr loopTag,
+ bool hasOutput, bool isSparseOut,
+ ArrayRef<unsigned> topSort) {
+ // First initializes fields.
+ this->loopTag = loopTag;
+ this->hasOutput = hasOutput;
+ this->isSparseOut = isSparseOut;
+ this->tensors.assign(tensors.begin(), tensors.end());
+ this->dimTypes.assign(tensors.size(), std::vector<DimLevelType>());
+ this->pidxs.assign(tensors.size(), std::vector<Value>());
+ this->coord.assign(tensors.size(), std::vector<Value>());
+ this->highs.assign(tensors.size(), std::vector<Value>());
+ this->ptrBuffer.assign(tensors.size(), std::vector<Value>());
+ this->idxBuffer.assign(tensors.size(), std::vector<Value>());
+ this->valBuffer.assign(tensors.size(), nullptr);
+ this->loopStack.reserve(topSort.size());
+ this->sparsiferLoopLvlMap.assign(topSort.size(), 0);
+
+ for (size_t tid = 0, e = tensors.size(); tid < e; tid++) {
+ auto t = tensors[tid];
+ // A scalar or 0-dimension tensor.
+ if (isZeroRankedTensorOrScalar(t.getType()))
+ continue;
+ auto rtp = t.getType().cast<RankedTensorType>();
+ auto rank = static_cast<size_t>(rtp.getRank());
+ auto enc = getSparseTensorEncoding(rtp);
+ // We always treat sparse output tensor as dense so that we always iterate
+ // it based on dim size.
+ if (enc && !(isOutputTensor(tid) && isSparseOut))
+ for (auto dimTp : enc.getDimLevelType())
+ dimTypes[tid].push_back(dimTp);
+ else
+ dimTypes[tid].assign(rank, DimLevelType::Dense);
+
+ // Initialize using empty value.
+ pidxs[tid].assign(rank, Value());
+ coord[tid].assign(rank, Value());
+ highs[tid].assign(rank, Value());
+ ptrBuffer[tid].assign(rank, Value());
+ idxBuffer[tid].assign(rank, Value());
+ }
+
+ // FIXME: This map should be maintained outside loop emitter.
+ for (unsigned i = 0, e = topSort.size(); i < e; i++) {
+ // This is an inverse map of the topologically sorted loop index from
+ // sparsifier. This is needed to map the AffineDimExpr back to the loopStack
+ // index used in loop emitter.
+ sparsiferLoopLvlMap[topSort[i]] = i;
+ }
+}
+
+void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
+ LoopEmitter::OutputUpdater updater) {
+ // For every tensor, find lower and upper bound on dimensions, set the
+ // same bounds on loop indices, and obtain dense or sparse buffer(s).
+ for (size_t t = 0, e = tensors.size(); t < e; t++) {
+ auto tensor = tensors[t];
+ auto rtp = tensor.getType().dyn_cast<RankedTensorType>();
+ if (!rtp)
+ // Skips only scalars; zero-ranked tensors still need to be bufferized and
+ // (probably) filled with zeros by users.
+ continue;
+ auto rank = rtp.getRank();
+ auto shape = rtp.getShape();
+ auto enc = getSparseTensorEncoding(rtp);
+ auto dynShape = {ShapedType::kDynamic};
+ // Scan all dimensions of current tensor.
+ for (int64_t d = 0; d < rank; d++) {
+ // This should be called only once at beginning.
+ assert(!ptrBuffer[t][d] && !idxBuffer[t][d] && !highs[t][d]);
+ // Handle sparse storage schemes.
+ if (isCompressedDLT(dimTypes[t][d])) {
+ auto ptrTp =
+ MemRefType::get(dynShape, getPointerOverheadType(builder, enc));
+ auto indTp =
+ MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
+ auto dim = builder.getIndexAttr(d);
+ // Generate sparse primitives to obtain pointers and indices.
+ ptrBuffer[t][d] = builder.create<ToPointersOp>(loc, ptrTp, tensor, dim);
+ idxBuffer[t][d] = builder.create<ToIndicesOp>(loc, indTp, tensor, dim);
+ } else if (isSingletonDLT(dimTypes[t][d])) {
+ // Singleton dimension, fetch indices.
+ auto indTp =
+ MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
+ auto dim = builder.getIndexAttr(d);
+ idxBuffer[t][d] = builder.create<ToIndicesOp>(loc, indTp, tensor, dim);
+ } else {
+ // Dense dimension, nothing to fetch.
+ assert(isDenseDLT(dimTypes[t][d]));
+ }
+
+ // Find upper bound in current dimension.
+ unsigned p = toOrigDim(enc, d);
+ Value up = mlir::linalg::createOrFoldDimOp(builder, loc, tensor, p);
+ highs[t][d] = up;
+ }
+
+ // Perform the required bufferization. Dense inputs materialize
+ // from the input tensors. Sparse inputs use sparse primitives to obtain the
+ // values.
+ // Delegates extra output initialization to clients.
+ bool isOutput = isOutputTensor(t);
+ Type elementType = rtp.getElementType();
+ if (!enc) {
+ // Non-annotated dense tensors.
+ auto denseTp = MemRefType::get(shape, elementType);
+ Value denseVal =
+ builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
+ // Dense outputs need special handling.
+ if (isOutput && updater)
+ denseVal = updater(builder, loc, denseVal, tensor);
+
+ valBuffer[t] = denseVal;
+ } else {
+ // Annotated sparse tensors.
+ // We also need the value buffer for annotated all dense `sparse` tensor.
+ auto dynShape = {ShapedType::kDynamic};
+ auto sparseTp = MemRefType::get(dynShape, elementType);
+ valBuffer[t] = builder.create<ToValuesOp>(loc, sparseTp, tensor);
+ }
+ // NOTE: we can also prepare for 0 dim here in advance, this will hoist
+ // some loop preparation from tensor iteration, but will also (undesirably)
+ // hoist the code outside if conditions.
+ }
+}
+
+void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
+ ArrayRef<size_t> tids,
+ ArrayRef<size_t> dims) {
+ // Universal Index start from 0
+ assert(loopSeqStack.size() == loopStack.size());
+ // Universal index starts from 0
+ loopSeqStack.emplace_back(constantIndex(builder, loc, 0));
+ // Prepares for all the tensors used in the current loop sequence.
+ for (auto [tid, dim] : llvm::zip(tids, dims))
+ prepareLoopOverTensorAtDim(builder, loc, tid, dim);
+}
+
+Value LoopEmitter::genAffine(OpBuilder &builder, AffineExpr a, Location loc) {
+ switch (a.getKind()) {
+ case AffineExprKind::DimId: {
+ unsigned idx = a.cast<AffineDimExpr>().getPosition();
+ return loopStack[sparsiferLoopLvlMap[idx]].iv;
+ }
+ case AffineExprKind::Add: {
+ auto binOp = a.cast<AffineBinaryOpExpr>();
+ return builder.create<arith::AddIOp>(
+ loc, genAffine(builder, binOp.getLHS(), loc),
+ genAffine(builder, binOp.getRHS(), loc));
+ }
+ case AffineExprKind::Mul: {
+ auto binOp = a.cast<AffineBinaryOpExpr>();
+ return builder.create<arith::MulIOp>(
+ loc, genAffine(builder, binOp.getLHS(), loc),
+ genAffine(builder, binOp.getRHS(), loc));
+ }
+ case AffineExprKind::Constant: {
+ int64_t c = a.cast<AffineConstantExpr>().getValue();
+ return constantIndex(builder, loc, c);
+ }
+ default:
+ llvm_unreachable("unexpected affine subscript");
+ }
+}
+
+Operation *LoopEmitter::enterLoopOverTensorAtDim(
+ OpBuilder &builder, Location loc, ArrayRef<size_t> tids,
+ ArrayRef<size_t> dims, MutableArrayRef<Value> reduc, bool isParallel) {
+ // TODO: support multiple return on parallel for?
+ assert(!isParallel || reduc.size() <= 1);
+
+ bool isSparseInput = false;
+ size_t tid = tids.front(), dim = dims.front();
+ for (auto [t, d] : llvm::zip(tids, dims)) {
+ assert(dimTypes[t].size() > d); // Must be a valid tid, dim pair
+ assert(!coord[t][d]); // We cannot re-enter the same level
+ auto dimType = dimTypes[t][d];
+ // Must be a recognizable DLT.
+ assert(isDenseDLT(dimType) || isCompressedDLT(dimType) ||
+ isSingletonDLT(dimType));
+ bool isSparse = isCompressedDLT(dimType) || isSingletonDLT(dimType);
+ // We can at most have one sparse input, otherwise, a while loop is required
+ // to co-iterate multiple sparse tensors.
+ assert(!isSparseInput || !isSparse);
+ if (isSparse) {
+ tid = t;
+ dim = d;
+ }
+ isSparseInput = isSparseInput || isSparse;
+ }
+
+ Value step = constantIndex(builder, loc, 1);
+ Value lo = isSparseInput ? pidxs[tid][dim] // current offset
+ : loopSeqStack.back(); // universal index
+ Value hi = highs[tid][dim];
+ Operation *loop = nullptr;
+ Value iv;
+ if (isParallel) {
+ scf::ParallelOp parOp =
+ builder.create<scf::ParallelOp>(loc, lo, hi, step, reduc);
+ builder.setInsertionPointToStart(parOp.getBody());
+ assert(parOp.getNumReductions() == reduc.size());
+ iv = parOp.getInductionVars()[0];
+
+ // In-place update on the reduction variable vector.
+ // Note that the init vals are not the actual reduction variables but are
+ // instead used as a `special handle` to (temporarily) represent them. The
+ // expression on init vals will be moved into scf.reduce and replaced with
+ // the block arguments when exiting the loop (see exitForLoop). This is
+ // needed as we cannot build the actual reduction block and get the actual
+ // reduction variable before users fill the parallel loop body.
+ for (int i = 0, e = reduc.size(); i < e; i++)
+ reduc[i] = parOp.getInitVals()[i];
+ loop = parOp;
+ } else {
+ scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
+ builder.setInsertionPointToStart(forOp.getBody());
+ iv = forOp.getInductionVar();
+
+ // In-place update on the reduction variable vector.
+ assert(forOp.getNumRegionIterArgs() == reduc.size());
+ for (int i = 0, e = reduc.size(); i < e; i++)
+ reduc[i] = forOp.getRegionIterArg(i);
+ loop = forOp;
+ }
+ assert(loop && iv);
+
+ if (isSparseInput) {
+ pidxs[tid][dim] = iv;
+ // Generating a load on the indices array yields the coordinate.
+ Value ptr = idxBuffer[tid][dim];
+ coord[tid][dim] = genIndexLoad(builder, loc, ptr, iv);
+ } else {
+ // Dense tensor, the coordinate is the induction variable.
+ coord[tid][dim] = iv;
+ }
+ // NOTE: we can also prepare for next dim here in advance
+ // Push the loop into stack
+ loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), loop,
+ coord[tid][dim], loopTag);
+ // Emit extra locals.
+ emitExtraLocalsForTensorsAtDenseDims(builder, loc, tids, dims);
+
+ return loop;
+}
+
+Operation *LoopEmitter::enterFilterLoopOverTensorAtDim(
+ OpBuilder &builder, Location loc, size_t tid, size_t dim, AffineExpr affine,
+ MutableArrayRef<Value> reduc) {
+ assert(!affine.isa<AffineDimExpr>() && !isDenseDLT(dimTypes[tid][dim]));
+ assert(dimTypes[tid].size() > dim);
+ // We can not re-enter the same level.
+ assert(!coord[tid][dim]);
+
+ Value step = constantIndex(builder, loc, 1);
+
+ Value lo = pidxs[tid][dim];
+ Value hi = highs[tid][dim];
+
+ // TODO: We should instead use a whileOp for filter loop to allow early
+ // break when exceeding (for ordered dimensions).
+ // TODO: There are many other potential opportunities that we might apply in
+ // the future. E.g., we could use binary search to locate the pointer index.
+ scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
+
+ // In-place update on the reduction variable vector.
+ assert(forOp.getNumRegionIterArgs() == reduc.size());
+ for (int i = 0, e = reduc.size(); i < e; i++)
+ reduc[i] = forOp.getRegionIterArg(i);
+
+ builder.setInsertionPointToStart(forOp.getBody());
+ Value iv = forOp.getInductionVar();
+
+ pidxs[tid][dim] = iv;
+ // Generating a load on the indices array yields the coordinate.
+ Value ptr = idxBuffer[tid][dim];
+ coord[tid][dim] = genIndexLoad(builder, loc, ptr, iv);
+
+ // Generate an if condition to filter out indices that are not equal to the
+ // result of the affine expression.
+ Value expected = genAffine(builder, affine, loc);
+ auto pred = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+ coord[tid][dim], expected);
+ SmallVector<Type> types;
+ for (Value red : reduc) {
+ types.push_back(red.getType());
+ }
+
+ bool hasReduc = !types.empty();
+ scf::IfOp ifOp =
+ builder.create<scf::IfOp>(loc, types, pred, /*else*/ hasReduc);
+ if (hasReduc) {
+ // scf.for (a) -> v
+ // %s = scf.if (a) -> v
+ // user-generated code.
+ // else
+ // yield a
+ // yield %s
+ builder.create<scf::YieldOp>(loc, ifOp.getResults());
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ // On mismatch.
+ builder.create<scf::YieldOp>(loc, reduc);
+ }
+ // Set the insert point to matched branch.
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
+ // NOTE: we can also prepare for next dim here in advance
+ // Push the loop into stack
+ loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), forOp,
+ coord[tid][dim], nullptr);
+ return forOp;
+}
+
+void LoopEmitter::genDenseAffineAddressAtCurLevel(OpBuilder &builder,
+ Location loc, size_t tid,
+ size_t dim,
+ AffineExpr affine) {
+ Value affineV = genAffine(builder, affine, loc);
+ pidxs[tid][dim] = genAddress(builder, loc, tid, dim, affineV);
+}
+
+Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
+ OpBuilder &builder, Location loc, ArrayRef<size_t> tids,
+ ArrayRef<size_t> dims, bool needsUniv, MutableArrayRef<Value> reduc) {
+ assert(tids.size() == dims.size());
+ SmallVector<Type> types;
+ SmallVector<Value> operands;
+ // Construct the while-loop with a parameter for each index.
+ Type indexType = builder.getIndexType();
+ for (auto [tid, dim] : llvm::zip(tids, dims)) {
+ if (isCompressedDLT(dimTypes[tid][dim]) ||
+ isSingletonDLT(dimTypes[tid][dim])) {
+ assert(pidxs[tid][dim]);
+ types.push_back(indexType);
+ operands.push_back(pidxs[tid][dim]);
+ }
+ }
+ // The position where user-supplied reduction variable starts.
+ for (Value rec : reduc) {
+ types.push_back(rec.getType());
+ operands.push_back(rec);
+ }
+ if (needsUniv) {
+ types.push_back(indexType);
+ // Update universal index.
+ operands.push_back(loopSeqStack.back());
+ }
+ assert(types.size() == operands.size());
+ scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
+
+ SmallVector<Location> locs(types.size(), loc);
+ Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
+ Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
+
+ // Build the "before" region, which effectively consists
+ // of a conjunction of "i < upper" tests on all induction.
+ builder.setInsertionPointToStart(&whileOp.getBefore().front());
+ Value cond;
+ unsigned o = 0;
+ for (auto [tid, dim] : llvm::zip(tids, dims)) {
+ if (isCompressedDLT(dimTypes[tid][dim]) ||
+ isSingletonDLT(dimTypes[tid][dim])) {
+ Value op1 = before->getArgument(o);
+ Value op2 = highs[tid][dim];
+ Value opc = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+ op1, op2);
+ cond = cond ? builder.create<arith::AndIOp>(loc, cond, opc) : opc;
+ // Update the pidx (from the "after" region block argument).
+ pidxs[tid][dim] = after->getArgument(o++);
+ }
+ }
+ builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
+
+ // Generates while body.
+ builder.setInsertionPointToStart(&whileOp.getAfter().front());
+ Value min;
+ for (auto [tid, dim] : llvm::zip(tids, dims)) {
+ // Prepares for next level.
+ if (isCompressedDLT(dimTypes[tid][dim]) ||
+ isSingletonDLT(dimTypes[tid][dim])) {
+ Value ptr = idxBuffer[tid][dim];
+ Value s = pidxs[tid][dim];
+ Value load = genIndexLoad(builder, loc, ptr, s);
+ coord[tid][dim] = load;
+ if (!needsUniv) {
+ if (min) {
+ Value cmp = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ult, load, min);
+ min = builder.create<arith::SelectOp>(loc, cmp, load, min);
+ } else {
+ min = load;
+ }
+ }
+ }
+ }
+
+ if (needsUniv) {
+ assert(!min);
+ // The universal index (the last while-loop argument) serves as the minimum.
+ min = after->getArguments().back();
+ }
+
+ // Sets up the loop stack.
+ loopStack.emplace_back(tids, dims, whileOp, min, loopTag);
+ assert(loopStack.size() == loopSeqStack.size());
+
+ // Emits extra locals
+ emitExtraLocalsForTensorsAtDenseDims(builder, loc, tids, dims);
+
+ // Updates reduction variables
+ assert(after->getNumArguments() == o + reduc.size() + (needsUniv ? 1 : 0));
+ // In-place update on reduction variable.
+ for (unsigned i = 0, e = reduc.size(); i < e; i++)
+ reduc[i] = after->getArgument(o + i);
+
+ return whileOp;
+}
+
+void LoopEmitter::prepareLoopOverTensorAtDim(OpBuilder &builder, Location loc,
+ size_t tid, size_t dim) {
+ assert(dimTypes[tid].size() > dim);
+ auto dimType = dimTypes[tid][dim];
+
+ if (isDenseDLT(dimType))
+ return;
+
+ // Either this is the first dimension, or the previous dimension has been set.
+ assert(dim == 0 || pidxs[tid][dim - 1]);
+ Value c0 = constantIndex(builder, loc, 0);
+ Value c1 = constantIndex(builder, loc, 1);
+ if (isCompressedDLT(dimType)) {
+ Value ptr = ptrBuffer[tid][dim];
+
+ Value pLo = dim == 0 ? c0 : pidxs[tid][dim - 1];
+ pidxs[tid][dim] = genIndexLoad(builder, loc, ptr, pLo);
+
+ Value pHi = builder.create<arith::AddIOp>(loc, pLo, c1);
+ highs[tid][dim] = genIndexLoad(builder, loc, ptr, pHi);
+ return;
+ }
+ if (isSingletonDLT(dimType)) {
+ Value pLo = dim == 0 ? c0 : pidxs[tid][dim - 1];
+ Value pHi = builder.create<arith::AddIOp>(loc, pLo, c1);
+
+ pidxs[tid][dim] = pLo;
+ highs[tid][dim] = pHi;
+ return;
+ }
+
+ llvm_unreachable("Unrecognizable dimension type!");
+}
+
+void LoopEmitter::emitExtraLocalsForTensorsAtDenseDims(OpBuilder &builder,
+ Location loc,
+ ArrayRef<size_t> tids,
+ ArrayRef<size_t> dims) {
+ // Initialize dense positions. Note that we generate dense indices of the
+ // output tensor unconditionally, since they may not appear in the lattice,
+ // but may be needed for linearized codegen.
+ for (auto [tid, dim] : llvm::zip(tids, dims)) {
+ if (isDenseDLT(dimTypes[tid][dim])) {
+ auto enc = getSparseTensorEncoding(tensors[tid].getType());
+ if (enc && !isSparseOutput(tid)) {
+ bool validPidx = dim == 0 || pidxs[tid][dim - 1];
+ if (!validPidx) {
+ // We might not find the pidx for the sparse output tensor, as its dense
+ // locals are requested unconditionally by the sparsification.
+ assert(isOutputTensor(tid));
+ continue;
+ }
+ pidxs[tid][dim] =
+ genAddress(builder, loc, tid, dim, loopStack.back().iv);
+ // NOTE: we can also prepare for the next dim here in advance.
+ }
+ }
+ }
+}
+
+void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
+ MutableArrayRef<Value> reduc) {
+ LoopLevelInfo &loopInfo = loopStack.back();
+ auto &dims = loopStack.back().dims;
+ auto &tids = loopStack.back().tids;
+ auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop);
+ if (forOp) {
+ if (!reduc.empty()) {
+ assert(reduc.size() == forOp.getNumResults());
+ rewriter.create<scf::YieldOp>(loc, reduc);
+ }
+ // Exit the loop.
+ rewriter.setInsertionPointAfter(forOp);
+ // Update the reduction variables in place.
+ for (unsigned i = 0, e = forOp.getResults().size(); i < e; i++)
+ reduc[i] = forOp.getResult(i);
+ } else {
+ auto parOp = llvm::cast<scf::ParallelOp>(loopInfo.loop);
+ if (!reduc.empty()) {
+ assert(reduc.size() == parOp.getInitVals().size() && reduc.size() == 1);
+ Operation *redExp = reduc.front().getDefiningOp();
+ // The reduction expression should have no uses.
+ assert(redExp->getUses().empty());
+ // This must be a binary operation.
+ // NOTE: It is the users' responsibility to ensure the operation is
+ // commutative.
+ assert(redExp->getNumOperands() == 2 && redExp->getNumResults() == 1);
+
+ Value redVal = parOp.getInitVals().front();
+ Value curVal;
+ if (redExp->getOperand(0) == redVal)
+ curVal = redExp->getOperand(1);
+ else if (redExp->getOperand(1) == redVal)
+ curVal = redExp->getOperand(0);
+ // One of the operands must be the init value (which is also the
+ // previous reduction value).
+ assert(curVal);
+ // The reduction expression should be the only user of the reduction val
+ // inside the parallel for.
+ unsigned numUsers = 0;
+ for (Operation *op : redVal.getUsers()) {
+ if (op->getParentOp() == parOp)
+ numUsers++;
+ }
+ assert(numUsers == 1);
+ (void)numUsers; // to silence unused variable warning in release build
+
+ rewriter.setInsertionPointAfter(redExp);
+ auto redOp = rewriter.create<scf::ReduceOp>(loc, curVal);
+ // Attach to the reduction op.
+ Block *redBlock = &redOp.getRegion().getBlocks().front();
+ rewriter.setInsertionPointToEnd(redBlock);
+ Operation *newRed = rewriter.clone(*redExp);
+ // Replaces arguments of the reduction expression by using the block
+ // arguments from scf.reduce.
+ rewriter.updateRootInPlace(
+ newRed, [&]() { newRed->setOperands(redBlock->getArguments()); });
+ // Erases the outdated reduction expression.
+ rewriter.eraseOp(redExp);
+ rewriter.setInsertionPointToEnd(redBlock);
+ rewriter.create<scf::ReduceReturnOp>(loc, newRed->getResult(0));
+ }
+ rewriter.setInsertionPointAfter(parOp);
+ // Update the reduction variables in place.
+ for (unsigned i = 0, e = parOp.getResults().size(); i < e; i++)
+ reduc[i] = parOp.getResult(i);
+ }
+
+ // Finished iterating a tensor; clean up.
+ // We only do the clean-up for for-loops, since while-loops do not
+ // necessarily finish iterating over a sparse tensor.
+ for (auto [tid, dim] : llvm::zip(tids, dims)) {
+ // Reset to null.
+ coord[tid][dim] = Value();
+ pidxs[tid][dim] = Value();
+ // For dense dimensions, the high bound is fixed and thus kept.
+ if (!isDenseDLT(dimTypes[tid][dim]))
+ highs[tid][dim] = Value();
+ }
+}
+
+void LoopEmitter::exitCoIterationLoop(OpBuilder &builder, Location loc,
+ MutableArrayRef<Value> reduc) {
+ auto whileOp = llvm::cast<scf::WhileOp>(loopStack.back().loop);
+ auto &dims = loopStack.back().dims;
+ auto &tids = loopStack.back().tids;
+ Value iv = loopStack.back().iv;
+ // Generate the while-loop induction at the end.
+ builder.setInsertionPointToEnd(&whileOp.getAfter().front());
+ // Finalize the induction. Note that the induction could be performed
+ // in the individual if-branches to avoid re-evaluating the conditions.
+ // However, that would result in a rather elaborate forest of yield
+ // instructions during code generation. Moreover, performing the induction
+ // after the if-statements more closely resembles code generated by TACO.
+ unsigned o = 0;
+ SmallVector<Value> operands;
+ Value one = constantIndex(builder, loc, 1);
+ for (auto [tid, dim] : llvm::zip(tids, dims)) {
+ if (isCompressedDLT(dimTypes[tid][dim]) ||
+ isSingletonDLT(dimTypes[tid][dim])) {
+ Value op1 = coord[tid][dim];
+ Value op3 = pidxs[tid][dim];
+ Value cmp =
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, op1, iv);
+ Value add = builder.create<arith::AddIOp>(loc, op3, one);
+ operands.push_back(builder.create<arith::SelectOp>(loc, cmp, add, op3));
+ // Following loops continue iteration from the break point of the
+ // current while loop.
+ pidxs[tid][dim] = whileOp->getResult(o++);
+ // The coordinates are invalid now.
+ coord[tid][dim] = nullptr;
+ // highs remains unchanged.
+ }
+ }
+
+ // Reduction value from users.
+ for (auto &i : reduc) {
+ operands.push_back(i);
+ // In-place update of the reduction variable.
+ i = whileOp->getResult(o++);
+ }
+
+ // An (optional) universal index.
+ if (operands.size() < whileOp.getNumResults()) {
+ assert(operands.size() + 1 == whileOp.getNumResults());
+ // The last one is the universal index.
+ operands.push_back(builder.create<arith::AddIOp>(loc, iv, one));
+ // Update the loop starting point of the current loop sequence.
+ loopSeqStack.back() = whileOp->getResult(o++);
+ }
+
+ assert(o == operands.size());
+ builder.create<scf::YieldOp>(loc, operands);
+ builder.setInsertionPointAfter(whileOp);
+}
+
+void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
+ MutableArrayRef<Value> reduc) {
+ // Clean up the values; this helps us discover potential bugs at an
+ // earlier stage (instead of silently using a wrong value).
+ LoopLevelInfo &loopInfo = loopStack.back();
+ assert(loopInfo.tids.size() == loopInfo.dims.size());
+ SmallVector<Value> red;
+ if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
+ exitCoIterationLoop(rewriter, loc, reduc);
+ } else {
+ exitForLoop(rewriter, loc, reduc);
+ }
+
+ assert(loopStack.size() == loopSeqStack.size());
+ loopStack.pop_back();
+}
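As a reading aid for the co-iteration path above (enterCoIterationOverTensorsAtDims, exitCoIterationLoop, exitCurrentLoop), here is a minimal client-side sketch. It is not part of the commit; the tensor ids, dims, and the `acc` reduction value are illustrative assumptions, and it presumes enterNewLoopSeq has set up the enclosing loop sequence as the assertions in the code require.

#include "LoopEmitter.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Co-iterate dimension 0 of two sparse tensors (tids 0 and 1) while carrying
// a scalar reduction; reduc[0] is updated in place with the loop-carried
// value on entry and with the while-loop result on exit.
static Value coiterateDim0(PatternRewriter &rewriter, Location loc,
                           LoopEmitter &emitter, Value acc) {
  SmallVector<Value> reduc = {acc};
  emitter.enterNewLoopSeq(rewriter, loc, /*tids=*/{0, 1}, /*dims=*/{0, 0});
  emitter.enterCoIterationOverTensorsAtDims(rewriter, loc, /*tids=*/{0, 1},
                                            /*dims=*/{0, 0},
                                            /*needsUniv=*/true, reduc);
  // ... emit the loop body here, combining loaded values into reduc[0] so
  // that the scf.yield created by exitCoIterationLoop picks them up ...
  emitter.exitCurrentLoop(rewriter, loc, reduc);
  emitter.exitCurrentLoopSeq();
  return reduc[0]; // the final reduction value carried out of the while-loop
}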
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
new file mode 100644
index 000000000000..a1db60c21195
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -0,0 +1,283 @@
+//===- LoopEmitter.h --------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORLOOPEMITTER_H_
+#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORLOOPEMITTER_H_
+
+#include <vector>
+
+#include "mlir/Dialect/SparseTensor/IR/Enums.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace sparse_tensor {
+
+//===----------------------------------------------------------------------===//
+// The LoopEmitter class manages sparse tensors and helps generate the loop
+// structure needed to (co)-iterate over sparse tensors.
+//
+// An example usage:
+// To generate the following loops over T1<?x?> and T2<?x?>
+//
+// for i in TENSOR_1_0 {
+// for j : TENSOR_2_0 {
+// for k : TENSOR_1_1 {}
+// for k : TENSOR_2_1 {}
+// }
+// }
+//
+// One can use
+//
+// LoopEmitter loopEmitter({T1, T2});
+// loopEmitter.initializeLoopEmit();
+// loopEmitter.enterLoopOverTensorAtDim(T1, 0);
+// loopEmitter.enterLoopOverTensorAtDim(T2, 0);
+// loopEmitter.enterLoopOverTensorAtDim(T1, 1);
+// loopEmitter.exitCurrentLoop();
+// loopEmitter.enterLoopOverTensorAtDim(T2, 1);
+// loopEmitter.exitCurrentLoop(); // exit k
+// loopEmitter.exitCurrentLoop(); // exit j
+// loopEmitter.exitCurrentLoop(); // exit i
+//===----------------------------------------------------------------------===//
+
+class LoopEmitter {
+public:
+ /// Optional callback function to set up dense output tensors when
+ /// initializing the loop emitter (e.g., to fill a dense output with zeros).
+ using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
+ Value memref, Value tensor)>;
+
+ LoopEmitter() = default;
+
+ /// Takes an array of input tensors, which the generated loops will
+ /// iterate over. The index of a tensor in the array is also its tensor id
+ /// (tid) used in related functions. If isSparseOut is set, the loop emitter
+ /// assumes that the sparse output tensor is empty, and will always generate
+ /// loops on it based on the dim sizes. An optional array can be provided
+ /// (by sparsification) to indicate the loop id sequence that will be
+ /// generated; it is used to establish the mapping from an AffineDimExpr to
+ /// the corresponding loop index in the loop stack maintained by the
+ /// loop emitter.
+ void initialize(ValueRange tensors, StringAttr loopTag = nullptr,
+ bool hasOutput = false, bool isSparseOut = false,
+ ArrayRef<unsigned> topSort = {});
+
+ explicit LoopEmitter(ValueRange tensors, StringAttr loopTag = nullptr,
+ bool hasOutput = false, bool isSparseOut = false,
+ ArrayRef<unsigned> topSort = {});
+
+ /// Starts a loop emitting session by generating all the buffers needed to
+ /// iterate over the tensors.
+ void initializeLoopEmit(OpBuilder &builder, Location loc,
+ OutputUpdater updater = nullptr);
+
+ /// Generates a list of operations to compute the affine expression.
+ Value genAffine(OpBuilder &builder, AffineExpr a, Location loc);
+
+ /// Enters a new loop sequence; the loops within the same sequence start
+ /// from the break points of the previous loop instead of starting over
+ /// from 0. For example:
+ /// {
+ /// // loop sequence start.
+ /// p0 = while(xxx)
+ /// ...
+ /// break p0
+ ///
+ /// // Starts loop from p0
+ /// for (i = p0; i < end; i++)
+ /// ...
+ /// // loop sequence end.
+ /// }
+ void enterNewLoopSeq(OpBuilder &builder, Location loc, ArrayRef<size_t> tids,
+ ArrayRef<size_t> dims);
+
+ /// Exits the current loop sequence; this will reset the universal index to 0.
+ void exitCurrentLoopSeq() {
+ assert(loopSeqStack.size() == loopStack.size() + 1);
+ loopSeqStack.pop_back();
+ }
+
+ // TODO: Get rid of `dim` in the argument list? Track the dimension we
+ // are currently at internally. Then it would be enterNextDimForTensor.
+ // We still need a way to specify the dim for non-annotated dense tensors,
+ // though, as they can be accessed out of order.
+ /// Emits a loop over tensor_tid_dim, assuming that loops over
+ /// tensor_tid_[0, dim - 1] have already been generated.
+ /// The function will also perform an in-place update on the `reduc` vector
+ /// to return the reduction variable used inside the generated loop.
+ Operation *enterLoopOverTensorAtDim(OpBuilder &builder, Location loc,
+ ArrayRef<size_t> tids,
+ ArrayRef<size_t> dims,
+ MutableArrayRef<Value> reduc = {},
+ bool isParallel = false);
+
+ Operation *enterFilterLoopOverTensorAtDim(OpBuilder &builder, Location loc,
+ size_t tid, size_t dim,
+ AffineExpr affine,
+ MutableArrayRef<Value> reduc = {});
+
+ void genDenseAffineAddressAtCurLevel(OpBuilder &builder, Location loc,
+ size_t tid, size_t dim,
+ AffineExpr affine);
+
+ /// Emits a co-iteration loop over a set of tensors.
+ Operation *enterCoIterationOverTensorsAtDims(
+ OpBuilder &builder, Location loc, ArrayRef<size_t> tids,
+ ArrayRef<size_t> dims, bool needsUniv, MutableArrayRef<Value> reduc = {});
+
+ void exitCurrentLoop(RewriterBase &rewriter, Location loc,
+ MutableArrayRef<Value> reduc = {});
+
+ /// Returns the array of coordinates for all the loops generated so far.
+ void getCoordinateArray(SmallVectorImpl<Value> &coords) const {
+ for (auto &l : loopStack)
+ coords.push_back(l.iv);
+ }
+
+ /// Gets the number of loops generated so far (the current loop depth).
+ unsigned getCurrentDepth() const { return loopStack.size(); }
+
+ /// Gets loop induction variable at the given level.
+ Value getLoopIV(size_t level) const {
+ if (level < loopStack.size())
+ return loopStack[level].iv;
+ return nullptr;
+ }
+
+ ///
+ /// Getters.
+ ///
+ const std::vector<std::vector<Value>> &getPidxs() const { return pidxs; };
+ const std::vector<std::vector<Value>> &getCoord() const { return coord; };
+ const std::vector<std::vector<Value>> &getHighs() const { return highs; };
+ const std::vector<std::vector<Value>> &getPtrBuffer() const {
+ return ptrBuffer;
+ };
+ const std::vector<std::vector<Value>> &getIdxBuffer() const {
+ return idxBuffer;
+ };
+ const std::vector<Value> &getValBuffer() const { return valBuffer; };
+
+ constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName() {
+ return llvm::StringLiteral("Emitted from");
+ }
+
+private:
+ struct LoopLevelInfo {
+ LoopLevelInfo(ArrayRef<size_t> tids, ArrayRef<size_t> dims, Operation *loop,
+ Value iv, StringAttr loopTag)
+ : tids(tids), dims(dims), loop(loop), iv(iv) {
+ // Attach a special tag to loops generated by the loop emitter.
+ if (loopTag)
+ loop->setAttr(LoopEmitter::getLoopEmitterLoopAttrName(), loopTag);
+ }
+ // TODO: maybe use a vector<pair> for tid and dim?
+ // The set of tensors that the loop is operating on
+ const llvm::SmallVector<size_t> tids;
+ // The corresponding dims for the tensors
+ const llvm::SmallVector<size_t> dims;
+ const Operation *loop; // the loop operation
+ const Value iv; // the induction variable for the loop
+ };
+
+ /// Linearizes address for dense dimension (i.e., p = (i * d0) + j).
+ Value genAddress(OpBuilder &builder, Location loc, size_t tid, size_t dim,
+ Value iv);
+
+ bool isOutputTensor(size_t tid) {
+ return hasOutput && tid == tensors.size() - 1;
+ }
+
+ bool isSparseOutput(size_t tid) { return isOutputTensor(tid) && isSparseOut; }
+
+ /// Sets up [lo, hi] for iterating over tensor[dim]; it assumes that
+ /// tensor[0...dim-1] has already been set up.
+ void prepareLoopOverTensorAtDim(OpBuilder &builder, Location loc, size_t tid,
+ size_t dim);
+
+ /// Emits extra locals, since the locals might not be in the simplified
+ /// lattice points used to generate the loops, but are still required to
+ /// generate expressions.
+ void emitExtraLocalsForTensorsAtDenseDims(OpBuilder &builder, Location loc,
+ ArrayRef<size_t> tids,
+ ArrayRef<size_t> dims);
+
+ /// Exits a for loop and returns the reduction results, e.g.,
+ /// For sequential for loops:
+ /// %ret = for () {
+ /// ...
+ /// %val = addi %args, %c
+ /// yield %val
+ /// }
+ /// For parallel loops, the following code generated by users:
+ /// %ret = parallel () init(%args) {
+ /// ...
+ /// %val = op %args, %c
+ /// }
+ /// will be transformed into
+ /// %ret = parallel () init(%args) {
+ /// ...
+ /// scf.reduce(%c) bb0(%0, %1){
+ /// %val = op %0, %1
+ /// scf.reduce.return %val
+ /// }
+ /// }
+ /// NOTE: only one instruction will be moved into the reduce block; the
+ /// transformation will fail if multiple instructions are used to compute
+ /// the reduction value. %ret is returned to the user, while %val is
+ /// provided by the user (`reduc`).
+ void exitForLoop(RewriterBase &rewriter, Location loc,
+ MutableArrayRef<Value> reduc);
+
+ /// Exits a while loop and returns the reduction results.
+ void exitCoIterationLoop(OpBuilder &builder, Location loc,
+ MutableArrayRef<Value> reduc);
+
+ /// An optional string attribute that is attached to the loops generated by
+ /// the loop emitter; it might help subsequent passes identify loops that
+ /// operate on sparse tensors more easily.
+ StringAttr loopTag;
+ /// Whether the loop emitter needs to treat the last tensor as the output
+ /// tensor.
+ bool hasOutput;
+ bool isSparseOut;
+ /// Input and (optional) output tensors.
+ std::vector<Value> tensors;
+ /// The dim type array for each tensor.
+ std::vector<std::vector<DimLevelType>> dimTypes;
+ /// Sparse iteration information (by tensor and dim). These arrays
+ /// are updated to remain current within the current loop.
+ std::vector<std::vector<Value>> pidxs;
+ std::vector<std::vector<Value>> coord;
+ std::vector<std::vector<Value>> highs;
+ std::vector<std::vector<Value>> ptrBuffer; // to_pointers
+ std::vector<std::vector<Value>> idxBuffer; // to_indices
+ std::vector<Value> valBuffer; // to_value
+
+ // Loop Stack, stores the information of all the nested loops that are
+ // alive.
+ std::vector<LoopLevelInfo> loopStack;
+
+ // Loop Sequence Stack, stores the universal index for the current loop
+ // sequence.
+ std::vector<Value> loopSeqStack;
+
+ // Maps AffineDimExpr to the index of the loop in loopStack.
+ // TODO: We should probably use a callback function here to make it more
+ // general.
+ std::vector<unsigned> sparsiferLoopLvlMap;
+
+ // TODO: not yet used, it should track the current level for each tensor
+ // to help eliminate `dim` parameters from the APIs above.
+ // std::vector<size_t> curLv;
+};
+
+} // namespace sparse_tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORLOOPEMITTER_H_
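To tie the header above to the call sites updated below, here is a minimal sketch of how client code might drive the renamed LoopEmitter for a single sparse input; it mirrors the ForeachOp rewriter in the next hunk. The tag string and the tid/dim choices are illustrative assumptions, not part of the commit.

#include "LoopEmitter.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

static void emitLoopsOverInput(PatternRewriter &rewriter, Location loc,
                               Value input) {
  // Tensor ids follow constructor order, so `input` is tid 0.
  LoopEmitter emitter(ValueRange{input},
                      StringAttr::get(rewriter.getContext(), "example.tag"));
  // Materialize the pointers/indices/values buffers for all tensors.
  emitter.initializeLoopEmit(rewriter, loc);

  // Each loop is entered within a loop sequence; iterate dim 0 of tid 0.
  emitter.enterNewLoopSeq(rewriter, loc, /*tids=*/{0}, /*dims=*/{0});
  emitter.enterLoopOverTensorAtDim(rewriter, loc, /*tids=*/{0}, /*dims=*/{0});

  // ... emit the loop body, e.g., using emitter.getCoord() and
  // emitter.getValBuffer() to access coordinates and values ...

  // Unwind in reverse order.
  emitter.exitCurrentLoop(rewriter, loc);
  emitter.exitCurrentLoopSeq();
}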
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index ebc4c8152b00..013b8f17ba49 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "CodegenUtils.h"
+#include "LoopEmitter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -893,7 +894,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
auto enc = getSparseTensorEncoding(rtp);
// 1. Generates loop for the sparse input.
- SparseTensorLoopEmitter loopEmitter(
+ LoopEmitter loopEmitter(
ValueRange{input},
StringAttr::get(getContext(), ForeachOp::getOperationName()));
loopEmitter.initializeLoopEmit(rewriter, loc);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index e652ebdff5cc..c27502ef8cc3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -17,6 +17,7 @@
//===----------------------------------------------------------------------===//
#include "CodegenUtils.h"
+#include "LoopEmitter.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -398,11 +399,11 @@ static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
}
return true;
} // An invariant or reduction. In both cases, we treat this as an
- // invariant value, and rely on later replacing and folding to
- // construct a proper reduction chain for the latter case.
- if (codegen)
- vexp = genVectorInvariantValue(rewriter, vl, exp);
- return true;
+ // invariant value, and rely on later replacing and folding to
+ // construct a proper reduction chain for the latter case.
+ if (codegen)
+ vexp = genVectorInvariantValue(rewriter, vl, exp);
+ return true;
}
// Something defined outside the loop-body is invariant.
Operation *def = exp.getDefiningOp();
@@ -540,9 +541,8 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
forOpNew = rewriter.create<scf::ForOp>(
loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
forOpNew->setAttr(
- SparseTensorLoopEmitter::getLoopEmitterLoopAttrName(),
- forOp->getAttr(
- SparseTensorLoopEmitter::getLoopEmitterLoopAttrName()));
+ LoopEmitter::getLoopEmitterLoopAttrName(),
+ forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
rewriter.setInsertionPointToStart(forOpNew.getBody());
} else {
forOp.setStep(step);
@@ -609,8 +609,8 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
ForOpRewriter(MLIRContext *context, unsigned vectorLength,
bool enableVLAVectorization, bool enableSIMDIndex32)
- : OpRewritePattern(context),
- vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}
+ : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization,
+ enableSIMDIndex32} {}
LogicalResult matchAndRewrite(scf::ForOp op,
PatternRewriter &rewriter) const override {
@@ -618,7 +618,7 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
// sparse compiler, which means no data dependence analysis is required,
// and its loop-body is very restricted in form.
if (!op.getRegion().hasOneBlock() || !isIntValue(op.getStep(), 1) ||
- !op->hasAttr(SparseTensorLoopEmitter::getLoopEmitterLoopAttrName()))
+ !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
return failure();
// Analyze (!codegen) and rewrite (codegen) loop-body.
if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
@@ -646,8 +646,7 @@ struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
Value inp = op.getSource();
if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
- if (forOp->hasAttr(
- SparseTensorLoopEmitter::getLoopEmitterLoopAttrName())) {
+ if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
rewriter.replaceOp(op, redOp.getVector());
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 7c558065b8d5..df82f351f336 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -12,6 +12,7 @@
#include "CodegenEnv.h"
#include "CodegenUtils.h"
+#include "LoopEmitter.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -1564,7 +1565,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
SmallVector<Value> tensors;
for (OpOperand &t : op->getOpOperands())
tensors.push_back(t.get());
- SparseTensorLoopEmitter lpe(
+ LoopEmitter lpe(
tensors,
StringAttr::get(op.getContext(), linalg::GenericOp::getOperationName()),
/*hasOutput=*/true, /*isSparseOut=*/sparseOut != nullptr,