[Mlir-commits] [mlir] [mlir][sparse] Cleanup sparse_tensor::LvlOp's folder (PR #71085)
Peiming Liu
llvmlistbot at llvm.org
Thu Nov 2 11:14:45 PDT 2023
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/71085
(No pull request description provided.)
From 2071d87b560779babe5db93b807522c16679875b Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 2 Nov 2023 18:14:11 +0000
Subject: [PATCH] [mlir][sparse] Cleanup sparse_tensor::LvlOp's folder
---
.../SparseTensor/IR/SparseTensorDialect.cpp | 35 ++-----------------
1 file changed, 3 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6080317d07a64e0..99214fadf4ba3db 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1293,38 +1293,9 @@ OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
};
- // TODO: we can remove this after SparseTensorEncoding always returns non-null
- // dimToLvl map.
- ArrayRef<Size> shape = stt.getDimShape();
- if (stt.isPermutation()) {
- Dimension dim = toOrigDim(stt, lvl);
- if (!ShapedType::isDynamic(shape[dim])) {
- return getIndexAttr(shape[dim]);
- }
- return {};
- }
-
- // Non-permutation dim2lvl/lvl2dim maps.
- AffineExpr lvlExpr = stt.getDimToLvl().getResult(lvl);
- if (auto binExpr = lvlExpr.dyn_cast<AffineBinaryOpExpr>()) {
- if (lvlExpr.getKind() == AffineExprKind::Mod) {
- // j % block_sz, the level size equals to the block size.
- int64_t lvlSz = binExpr.getRHS().cast<AffineConstantExpr>().getValue();
- return getIndexAttr(lvlSz);
- }
- if (lvlExpr.getKind() == AffineExprKind::FloorDiv) {
- // j / block_sz, the level size equals to dim[j] / block_sz.
- Dimension dim = binExpr.getLHS().cast<AffineDimExpr>().getPosition();
- int64_t blockSz = binExpr.getRHS().cast<AffineConstantExpr>().getValue();
- if (ShapedType::isDynamic(shape[dim]))
- return {};
- return getIndexAttr(shape[dim] / blockSz);
- }
- }
-
- auto dim = lvlExpr.cast<AffineDimExpr>().getPosition();
- if (!ShapedType::isDynamic(dim))
- return getIndexAttr(shape[dim]);
+ SmallVector<Size> lvlShape = stt.getLvlShape();
+ if (!ShapedType::isDynamic(lvlShape[lvl]))
+ return getIndexAttr(lvlShape[lvl]);
return {};
}
More information about the Mlir-commits mailing list