[Mlir-commits] [mlir] deedf55 - [mlir][sparse] Cleanup sparse_tensor::LvlOp's folder (#71085)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 2 11:31:20 PDT 2023


Author: Peiming Liu
Date: 2023-11-02T11:31:16-07:00
New Revision: deedf554fbaf54aa908e4aa3ccea5977a08354d3

URL: https://github.com/llvm/llvm-project/commit/deedf554fbaf54aa908e4aa3ccea5977a08354d3
DIFF: https://github.com/llvm/llvm-project/commit/deedf554fbaf54aa908e4aa3ccea5977a08354d3.diff

LOG: [mlir][sparse] Cleanup sparse_tensor::LvlOp's folder (#71085)

Reuse the existing `getLvlShape()` utility instead of recomputing each level's size by hand in the folder.

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6080317d07a64e0..99214fadf4ba3db 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1293,38 +1293,9 @@ OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
     return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
   };
 
-  // TODO: we can remove this after SparseTensorEncoding always returns non-null
-  // dimToLvl map.
-  ArrayRef<Size> shape = stt.getDimShape();
-  if (stt.isPermutation()) {
-    Dimension dim = toOrigDim(stt, lvl);
-    if (!ShapedType::isDynamic(shape[dim])) {
-      return getIndexAttr(shape[dim]);
-    }
-    return {};
-  }
-
-  // Non-permutation dim2lvl/lvl2dim maps.
-  AffineExpr lvlExpr = stt.getDimToLvl().getResult(lvl);
-  if (auto binExpr = lvlExpr.dyn_cast<AffineBinaryOpExpr>()) {
-    if (lvlExpr.getKind() == AffineExprKind::Mod) {
-      // j % block_sz, the level size equals to the block size.
-      int64_t lvlSz = binExpr.getRHS().cast<AffineConstantExpr>().getValue();
-      return getIndexAttr(lvlSz);
-    }
-    if (lvlExpr.getKind() == AffineExprKind::FloorDiv) {
-      // j / block_sz, the level size equals to dim[j] / block_sz.
-      Dimension dim = binExpr.getLHS().cast<AffineDimExpr>().getPosition();
-      int64_t blockSz = binExpr.getRHS().cast<AffineConstantExpr>().getValue();
-      if (ShapedType::isDynamic(shape[dim]))
-        return {};
-      return getIndexAttr(shape[dim] / blockSz);
-    }
-  }
-
-  auto dim = lvlExpr.cast<AffineDimExpr>().getPosition();
-  if (!ShapedType::isDynamic(dim))
-    return getIndexAttr(shape[dim]);
+  SmallVector<Size> lvlShape = stt.getLvlShape();
+  if (!ShapedType::isDynamic(lvlShape[lvl]))
+    return getIndexAttr(lvlShape[lvl]);
 
   return {};
 }


        


More information about the Mlir-commits mailing list