[Mlir-commits] [mlir] [mlir][sparse] implement non-permutation MapRef encoding (PR #69406)
Yinying Li
llvmlistbot at llvm.org
Wed Oct 18 09:12:23 PDT 2023
================
@@ -688,25 +688,70 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
return dimSizesBuffer;
}
// Otherwise, some code needs to be generated to set up the buffers.
- // TODO: use the lvl2dim once available and deal with non-permutations!
+ // This code deals with permutations as well as non-permutations that
+ // arise from rank changing blocking.
const auto dimToLvl = stt.getDimToLvl();
- assert(dimToLvl.isPermutation());
- SmallVector<Value> dim2lvlValues(dimRank);
- SmallVector<Value> lvl2dimValues(lvlRank);
+ SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
+ SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
SmallVector<Value> lvlSizesValues(lvlRank);
+ // Generate dim2lvl.
+ assert(lvlRank == dimToLvl.getNumResults());
for (Level l = 0; l < lvlRank; l++) {
- // The `d`th source variable occurs in the `l`th result position.
- Dimension d = dimToLvl.getDimPosition(l);
- Value lvl = constantIndex(builder, loc, l);
- Value dim = constantIndex(builder, loc, d);
- dim2lvlValues[d] = lvl;
- lvl2dimValues[l] = dim;
- if (stt.isDynamicDim(d))
- lvlSizesValues[l] =
- builder.create<memref::LoadOp>(loc, dimSizesBuffer, dim);
- else
- lvlSizesValues[l] = dimShapesValues[d];
+ AffineExpr exp = dimToLvl.getResult(l);
+ // We expect:
+ // (1) l = d
+ // (2) l = d / c
+ // (3) l = d % c
+ Dimension d = 0;
+ uint64_t cf = 0, cm = 0;
+ switch (exp.getKind()) {
+ case AffineExprKind::DimId:
+ d = exp.cast<AffineDimExpr>().getPosition();
+ break;
+ case AffineExprKind::FloorDiv:
+ d = exp.cast<AffineBinaryOpExpr>()
+ .getLHS()
+ .cast<AffineDimExpr>()
+ .getPosition();
+ cf = exp.cast<AffineBinaryOpExpr>()
+ .getRHS()
+ .cast<AffineConstantExpr>()
+ .getValue();
+ break;
+ case AffineExprKind::Mod:
+ d = exp.cast<AffineBinaryOpExpr>()
+ .getLHS()
+ .cast<AffineDimExpr>()
+ .getPosition();
+ cm = exp.cast<AffineBinaryOpExpr>()
+ .getRHS()
+ .cast<AffineConstantExpr>()
+ .getValue();
+ break;
----------------
yinying-lisa-li wrote:
This is nice! I might "borrow" this and refactor my code. ;)
https://github.com/llvm/llvm-project/pull/69406
More information about the Mlir-commits mailing list