[Mlir-commits] [mlir] [mlir][sparse] implements sparse_tensor.reinterpret_map (PR #70388)
Peiming Liu
llvmlistbot at llvm.org
Thu Oct 26 15:50:01 PDT 2023
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/70388
>From 434160a15cde8c6e9c5518031f89e01df996da64 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 26 Oct 2023 22:27:57 +0000
Subject: [PATCH 1/3] [mlir][sparse] implements sparse_tensor.reinterpret_map
---
.../Transforms/SparseTensorCodegen.cpp | 14 +++++++++++++-
1 file changed, 13 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index e855a6e19a717a4..5cdf8cd7ccc9d8b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -725,6 +725,18 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
}
};
+class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Simply fold the operation.
+ rewriter.replaceOp(op, adaptor.getSource());
+ return success();
+ }
+};
+
/// Sparse codegen rule for the alloc operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
@@ -1564,7 +1576,7 @@ void mlir::populateSparseTensorCodegenPatterns(
SparseCastConverter, SparseExtractSliceConverter,
SparseTensorLoadConverter, SparseExpandConverter,
SparseCompressConverter, SparseInsertConverter,
- SparseReorderCOOConverter,
+ SparseReorderCOOConverter, SparseReMapConverter,
SparseSliceGetterOpConverter<ToSliceOffsetOp,
StorageSpecifierKind::DimOffset>,
SparseSliceGetterOpConverter<ToSliceStrideOp,
>From 37350181a55cb0f5d4f84317223e6d7af7f3af91 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 26 Oct 2023 22:46:35 +0000
Subject: [PATCH 2/3] add test cases
---
.../SparseTensor/IR/SparseTensorDialect.cpp | 13 +++++++++++--
.../Transforms/SparseTensorConversion.cpp | 14 +++++++++++++-
.../Dialect/SparseTensor/CPU/block.mlir | 12 ++++++++++++
3 files changed, 36 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index be44a0d31c92a7d..d069eab1f81aea1 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -453,11 +453,20 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
// Do constant propagation on the affine map.
AffineExpr evalExp =
simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
- if (auto c = evalExp.dyn_cast<AffineConstantExpr>())
+ if (auto c = evalExp.dyn_cast<AffineConstantExpr>()) {
ret.push_back(c.getValue() + 1);
- else
+ } else {
+ if (auto mod = evalExp.dyn_cast<AffineBinaryOpExpr>();
+ mod && mod.getKind() == AffineExprKind::Mod) {
+ if (auto bound = mod.getRHS().dyn_cast<AffineConstantExpr>()) {
+ ret.push_back(bound.getValue());
+ continue;
+ }
+ }
ret.push_back(ShapedType::kDynamic);
+ }
}
+ assert(ret.size() == rank);
return ret;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index c3f046e52fd6790..a92038ce7c98d4e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -336,6 +336,18 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
}
};
+class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Simply fold the operation.
+ rewriter.replaceOp(op, adaptor.getSource());
+ return success();
+ }
+};
+
/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
@@ -770,7 +782,7 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns
.add<SparseReturnConverter, SparseTensorLvlOpConverter,
- SparseCastConverter, SparseTensorNewConverter,
+ SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
SparseTensorAllocConverter, SparseTensorEmptyConverter,
SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir
index d0b5e77bd4a724e..78d35ada6acc11c 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir
@@ -34,6 +34,11 @@
)
}>
+#DSDD = #sparse_tensor.encoding<{
+ map = (i, j, k, l) -> ( i : dense, j : compressed, k : dense, l : dense)
+}>
+
+
!Filename = !llvm.ptr<i8>
//
@@ -77,6 +82,13 @@ module {
%vecv = vector.transfer_read %val[%c0], %f0 : memref<?xf64>, vector<12xf64>
vector.print %vecv : vector<12xf64>
+ // CHECK-NEXT: ( 1, 2, 0, 3, 4, 0, 0, 5, 6, 7, 8, 0 )
+ %t1 = sparse_tensor.reinterpret_map %A : tensor<?x?xf64, #BSR>
+ to tensor<?x?x2x2xf64, #DSDD>
+ %vdsdd = sparse_tensor.values %t1 : tensor<?x?x2x2xf64, #DSDD> to memref<?xf64>
+ %vecdsdd = vector.transfer_read %vdsdd[%c0], %f0 : memref<?xf64>, vector<12xf64>
+ vector.print %vecdsdd : vector<12xf64>
+
// Release the resources.
bufferization.dealloc_tensor %A: tensor<?x?xf64, #BSR>
>From 725f4c572c966e3b72afbc3c358c92462474b8ce Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 26 Oct 2023 22:49:47 +0000
Subject: [PATCH 3/3] add some comments.
---
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 2 ++
1 file changed, 2 insertions(+)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index d069eab1f81aea1..fe25ad848c7449c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -458,6 +458,8 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
} else {
if (auto mod = evalExp.dyn_cast<AffineBinaryOpExpr>();
mod && mod.getKind() == AffineExprKind::Mod) {
+      // We can still infer a static bound for expressions in the form
+      // "d % constant", since d % constant \in [0, constant).
if (auto bound = mod.getRHS().dyn_cast<AffineConstantExpr>()) {
ret.push_back(bound.getValue());
continue;
More information about the Mlir-commits
mailing list