[Mlir-commits] [mlir] f16cb0e - [mlir][sparse] connect MapRef's lvl2dim with latest AffineMap computation (#69540)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 18 18:01:01 PDT 2023
Author: Aart Bik
Date: 2023-10-18T18:00:56-07:00
New Revision: f16cb0eade08035fea5e8310bd4a64c8f286c929
URL: https://github.com/llvm/llvm-project/commit/f16cb0eade08035fea5e8310bd4a64c8f286c929
DIFF: https://github.com/llvm/llvm-project/commit/f16cb0eade08035fea5e8310bd4a64c8f286c929.diff
LOG: [mlir][sparse] connect MapRef's lvl2dim with latest AffineMap computation (#69540)
This change ensures that the following map (dimToLvl above the `--` separator, lvlToDim below it):
- GEN MAP dim=2 lvl=4
(d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
--
(d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 2 + d3)
is now correctly encoded in the MapRef as:
MAP-REF (dim=2, lvl=4) isperm=0
d2l = [ d0/2 d1/2 d0%2 d1%2 ]
ld2 = [ l2+2*l0 l3+2*l1 ]
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 98b412c8ec9eb5b..b1b1d67ac2d420d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -691,6 +691,7 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
// This code deals with permutations as well as non-permutations that
// arise from rank changing blocking.
const auto dimToLvl = stt.getDimToLvl();
+ const auto lvlToDim = stt.getLvlToDim();
SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
SmallVector<Value> lvlSizesValues(lvlRank);
@@ -705,34 +706,26 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
Dimension d = 0;
uint64_t cf = 0, cm = 0;
switch (exp.getKind()) {
- case AffineExprKind::DimId:
+ case AffineExprKind::DimId: {
d = exp.cast<AffineDimExpr>().getPosition();
break;
- case AffineExprKind::FloorDiv:
- d = exp.cast<AffineBinaryOpExpr>()
- .getLHS()
- .cast<AffineDimExpr>()
- .getPosition();
- cf = exp.cast<AffineBinaryOpExpr>()
- .getRHS()
- .cast<AffineConstantExpr>()
- .getValue();
+ }
+ case AffineExprKind::FloorDiv: {
+ auto floor = exp.cast<AffineBinaryOpExpr>();
+ d = floor.getLHS().cast<AffineDimExpr>().getPosition();
+ cf = floor.getRHS().cast<AffineConstantExpr>().getValue();
break;
- case AffineExprKind::Mod:
- d = exp.cast<AffineBinaryOpExpr>()
- .getLHS()
- .cast<AffineDimExpr>()
- .getPosition();
- cm = exp.cast<AffineBinaryOpExpr>()
- .getRHS()
- .cast<AffineConstantExpr>()
- .getValue();
+ }
+ case AffineExprKind::Mod: {
+ auto mod = exp.cast<AffineBinaryOpExpr>();
+ d = mod.getLHS().cast<AffineDimExpr>().getPosition();
+ cm = mod.getRHS().cast<AffineConstantExpr>().getValue();
break;
+ }
default:
llvm::report_fatal_error("unsupported dim2lvl in sparse tensor type");
}
dim2lvlValues[l] = constantIndex(builder, loc, encodeDim(d, cf, cm));
- lvl2dimValues[d] = constantIndex(builder, loc, l); // FIXME, use lvlToDim
// Compute the level sizes.
// (1) l = d : size(d)
// (2) l = d / c : size(d) / c
@@ -751,6 +744,35 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
}
lvlSizesValues[l] = lvlSz;
}
+ // Generate lvl2dim.
+ assert(dimRank == lvlToDim.getNumResults());
+ for (Dimension d = 0; d < dimRank; d++) {
+ AffineExpr exp = lvlToDim.getResult(d);
+ // We expect:
+ // (1) d = l
+ // (2) d = l' * c + l
+ Level l = 0, ll = 0;
+ uint64_t c = 0;
+ switch (exp.getKind()) {
+ case AffineExprKind::DimId: {
+ l = exp.cast<AffineDimExpr>().getPosition();
+ break;
+ }
+ case AffineExprKind::Add: {
+ // Always mul on lhs, symbol/constant on rhs.
+ auto add = exp.cast<AffineBinaryOpExpr>();
+ assert(add.getLHS().getKind() == AffineExprKind::Mul);
+ auto mul = add.getLHS().cast<AffineBinaryOpExpr>();
+ ll = mul.getLHS().cast<AffineDimExpr>().getPosition();
+ c = mul.getRHS().cast<AffineConstantExpr>().getValue();
+ l = add.getRHS().cast<AffineDimExpr>().getPosition();
+ break;
+ }
+ default:
+ llvm::report_fatal_error("unsupported lvl2dim in sparse tensor type");
+ }
+ lvl2dimValues[d] = constantIndex(builder, loc, encodeLvl(l, c, ll));
+ }
// Return buffers.
dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);
More information about the Mlir-commits
mailing list