[Mlir-commits] [mlir] [mlir][sparse] code cleanup (using inferred type to construct to_[buffer] op) (PR #83361)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 28 16:36:22 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

<details>
<summary>Changes</summary>

Code cleanup (using inferred type to construct to_[buffer] op).

---
Full diff: https://github.com/llvm/llvm-project/pull/83361.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp (+4-19) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h (-11) 


``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index b888dfadb9c714..370818e0de2578 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -556,37 +556,22 @@ sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
 
 Value sparse_tensor::genToPositions(OpBuilder &builder, Location loc,
                                     Value tensor, Level lvl) {
-  const auto srcTp = getSparseTensorType(tensor);
-  const Type posTp = srcTp.getPosType();
-  const Type memTp = get1DMemRefType(posTp, /*withLayout=*/false);
-  return builder.create<ToPositionsOp>(loc, memTp, tensor,
-                                       builder.getIndexAttr(lvl));
+  return builder.create<ToPositionsOp>(loc, tensor, lvl);
 }
 
 Value sparse_tensor::genToCoordinates(OpBuilder &builder, Location loc,
                                       Value tensor, Level lvl) {
-  const auto srcTp = getSparseTensorType(tensor);
-  const Type crdTp = srcTp.getCrdType();
-  const Type memTp =
-      get1DMemRefType(crdTp, /*withLayout=*/lvl >= srcTp.getAoSCOOStart());
-  return builder.create<ToCoordinatesOp>(loc, memTp, tensor,
-                                         builder.getIndexAttr(lvl));
+  return builder.create<ToCoordinatesOp>(loc, tensor, lvl);
 }
 
 Value sparse_tensor::genToCoordinatesBuffer(OpBuilder &builder, Location loc,
                                             Value tensor) {
-  const auto srcTp = getSparseTensorType(tensor);
-  const Type crdTp = srcTp.getCrdType();
-  const Type memTp = get1DMemRefType(crdTp, /*withLayout=*/false);
-  return builder.create<ToCoordinatesBufferOp>(loc, memTp, tensor);
+  return builder.create<ToCoordinatesBufferOp>(loc, tensor);
 }
 
 Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
                                  Value tensor) {
-  RankedTensorType srcTp = getRankedTensorType(tensor);
-  Type valTp = get1DMemRefType(srcTp.getElementType(),
-                               /*withLayout=*/false);
-  return builder.create<ToValuesOp>(loc, valTp, tensor);
+  return builder.create<ToValuesOp>(loc, tensor);
 }
 
 Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index cc119bc7045595..eb246fc7c45bbb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -228,17 +228,6 @@ void deallocDenseTensor(OpBuilder &builder, Location loc, Value buffer);
 void sizesFromSrc(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
                   Location loc, Value src);
 
-/// Generates a 1D MemRefType with a dynamic size. When withLayout is set, the
-/// returned memref has a layout has unknown strides and offsets. Otherwise,
-/// a memref with a standard unit stride zero offset layout is returned.
-inline MemRefType get1DMemRefType(Type etp, bool withLayout) {
-  auto layout = withLayout ? StridedLayoutAttr::StridedLayoutAttr::get(
-                                 etp.getContext(), ShapedType::kDynamic,
-                                 {ShapedType::kDynamic})
-                           : StridedLayoutAttr();
-  return MemRefType::get(ShapedType::kDynamic, etp, layout);
-}
-
 /// Scans to top of generated loop.
 Operation *getTop(Operation *op);
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/83361


More information about the Mlir-commits mailing list