[Mlir-commits] [mlir] [mlir][sparse] support type conversion from batched sparse tensors to… (PR #83163)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 27 10:06:02 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse
Author: Peiming Liu (PeimingLiu)
<details>
<summary>Changes</summary>
Support type conversion from batched sparse tensors to memrefs.
---
Full diff: https://github.com/llvm/llvm-project/pull/83163.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td (+2)
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h (+7-7)
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h (+8)
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+15-6)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+2-3)
- (modified) mlir/test/Dialect/SparseTensor/codegen.mlir (+15)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index ca98665256be5a..5d1db2323f95f0 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -374,6 +374,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
/// is non-null (since no fixed result is valid for every dense-tensor).
::mlir::sparse_tensor::Level getLvlRank() const;
+ uint64_t getBatchLvlRank() const;
+
//
// lvlTypes methods.
//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
index 27dc39609cdadd..ce34ae43d1c181 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
@@ -30,15 +30,15 @@ namespace sparse_tensor {
/// ; if dense:
/// <nothing>
/// ; if compressed:
-/// memref<? x pos> positions ; positions for level l
-/// memref<? x crd> coordinates ; coordinates for level l
-/// ; if loose-compressed:
-/// memref<? x pos> positions ; lo/hi position pairs for level l
-/// memref<? x crd> coordinates ; coordinates for level l
+/// memref<[batch] x ? x pos> positions ; positions for level l
+/// memref<[batch] x ? x crd> coordinates ; coordinates for level l
+/// ; if loose-compressed:
+/// memref<[batch] x ? x pos> positions ; lo/hi pos pairs for level l
+/// memref<[batch] x ? x crd> coordinates ; coordinates for level l
/// ; if singleton/2-out-of-4:
-/// memref<? x crd> coordinates ; coordinates for level l
+/// memref<[batch] x ? x crd> coordinates ; coordinates for level l
///
-/// memref<? x eltType> values ; values
+/// memref<[batch] x ? x eltType> values ; values
///
/// struct sparse_tensor.storage_specifier {
/// array<rank x int> lvlSizes ; sizes/cardinalities for each level
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 1a090ddb782fdb..c93a4fcd922c28 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -253,6 +253,14 @@ class SparseTensorType {
CrdTransDirectionKind::dim2lvl);
}
+ /// Returns the shape of the leading batch levels only.
+ SmallVector<Size> getBatchLvlShape() const {
+ auto lvlShape = getEncoding().tranlateShape(getDimShape(),
+ CrdTransDirectionKind::dim2lvl);
+ lvlShape.truncate(getEncoding().getBatchLvlRank());
+ return lvlShape;
+ }
+
/// Returns the type with an identity mapping.
RankedTensorType getDemappedType() const {
return RankedTensorType::get(getLvlShape(), getElementType(),
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index fd0ed26fbde072..69c3413f35ea9c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -126,13 +126,16 @@ void sparse_tensor::foreachFieldAndTypeInSparseTensor(
const Type posType = stt.getPosType();
const Type eltType = stt.getElementType();
+ SmallVector<int64_t> memrefShape = stt.getBatchLvlShape();
+ memrefShape.push_back(ShapedType::kDynamic);
+
const Type specType = StorageSpecifierType::get(stt.getEncoding());
- // memref<? x pos> positions
- const Type posMemType = MemRefType::get({ShapedType::kDynamic}, posType);
- // memref<? x crd> coordinates
- const Type crdMemType = MemRefType::get({ShapedType::kDynamic}, crdType);
- // memref<? x eltType> values
- const Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
+ // memref<[batch] x ? x pos> positions
+ const Type posMemType = MemRefType::get(memrefShape, posType);
+ // memref<[batch] x ? x crd> coordinates
+ const Type crdMemType = MemRefType::get(memrefShape, crdType);
+ // memref<[batch] x ? x eltType> values
+ const Type valMemType = MemRefType::get(memrefShape, eltType);
StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
callback](FieldIndex fieldIdx,
@@ -336,6 +339,12 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
}
+uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
+ ArrayRef<LevelType> lvlTypes = getLvlTypes();
+ auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
+ return std::distance(lastBatch, lvlTypes.rend());
+}
+
bool SparseTensorEncodingAttr::isAllDense() const {
return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 0ccb11f3a6b858..d5eec4ae67e798 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1293,7 +1293,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
Value tensor = fKind == SparseTensorFieldKind::ValMemRef
? op.getValues()
: op.getLevels()[fIdx];
-
+ // TODO: handle batch.
TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
if (mem.getType().getRank() > 1) {
// Flattens the buffer to rank 1.
@@ -1322,9 +1322,8 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
- // FIXME: dim/lvl confusion!
// Sets up the level size.
- auto lvlSize = constantIndex(rewriter, loc, stt.getDimShape()[lvl]);
+ auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
desc.setLvlSize(rewriter, loc, lvl, lvlSize);
// We use a single AOS array to store the trailing COO, so there is only
// one memory size to set for the entire COO section.
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index c1a976c84fecca..64a515a38588a2 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -34,6 +34,10 @@
map = (d0, d1) -> (d1 : dense, d0 : compressed)
}>
+#BCSR = #sparse_tensor.encoding<{
+ map = (d0, d1, d2, d3) -> (d0: batch, d1: batch, d2 : dense, d3 : compressed)
+}>
+
#DCSR = #sparse_tensor.encoding<{
map = (d0, d1) -> (d0 : compressed, d1 : compressed),
crdWidth = 64,
@@ -182,6 +186,17 @@ func.func @sparse_csr(%arg0: tensor<?x?xf64, #CSR>) {
return
}
+// CHECK-LABEL: func @sparse_bcsr(
+// CHECK-SAME: %[[A0:.*0]]: memref<?x2x?xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?x2x?xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?x2x?xindex>,
+// CHECK-SAME: %[[A3:.*]]: memref<?x2x?xf64>,
+// CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier
+// CHECK: return
+func.func @sparse_bcsr(%arg0: tensor<?x2x?x?xf64, #BCSR>) {
+ return
+}
+
// CHECK-LABEL: func @sparse_dcsr(
// CHECK-SAME: %[[A0:.*0]]: memref<?xi32>,
// CHECK-SAME: %[[A1:.*1]]: memref<?xi64>,
``````````
</details>
https://github.com/llvm/llvm-project/pull/83163
More information about the Mlir-commits mailing list