[Mlir-commits] [mlir] ef22298 - [mlir][sparse] implements sparse_tensor.reinterpret_map (#70388)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 26 16:00:36 PDT 2023
Author: Peiming Liu
Date: 2023-10-26T16:00:32-07:00
New Revision: ef222988b477bf91fe3b6e9cd7d881d19af2d605
URL: https://github.com/llvm/llvm-project/commit/ef222988b477bf91fe3b6e9cd7d881d19af2d605
DIFF: https://github.com/llvm/llvm-project/commit/ef222988b477bf91fe3b6e9cd7d881d19af2d605.diff
LOG: [mlir][sparse] implements sparse_tensor.reinterpret_map (#70388)
Added:
Modified:
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 4f1c446faec3714..9a6d3161be3d6e4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -453,11 +453,22 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
// Do constant propagation on the affine map.
AffineExpr evalExp =
simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
- if (auto c = evalExp.dyn_cast<AffineConstantExpr>())
+ if (auto c = evalExp.dyn_cast<AffineConstantExpr>()) {
ret.push_back(c.getValue() + 1);
- else
+ } else {
+ if (auto mod = evalExp.dyn_cast<AffineBinaryOpExpr>();
+ mod && mod.getKind() == AffineExprKind::Mod) {
+ // We can still infer a static bound for expressions in form
+ // "d % constant" since d % constant \in [0, constant).
+ if (auto bound = mod.getRHS().dyn_cast<AffineConstantExpr>()) {
+ ret.push_back(bound.getValue());
+ continue;
+ }
+ }
ret.push_back(ShapedType::kDynamic);
+ }
}
+ assert(ret.size() == rank);
return ret;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index e855a6e19a717a4..5cdf8cd7ccc9d8b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -725,6 +725,18 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
}
};
+/// Codegen conversion pattern for `sparse_tensor.reinterpret_map`
+/// (ReinterpretMapOp); registered in populateSparseTensorCodegenPatterns.
+///
+/// The op is eliminated without emitting new IR: its result is replaced by
+/// the already type-converted source operand.
+/// NOTE(review): this assumes the source and result tensor types lower to the
+/// same underlying storage (the op only changes how dimensions map to levels,
+/// not the level storage itself) -- confirm the op verifier guarantees this.
+/// NOTE(review): an identical pattern exists in SparseTensorConversion.cpp;
+/// keep the two in sync.
+class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ /// Replaces `op` with `adaptor.getSource()`; always reports success.
+ LogicalResult
+ matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Simply fold the operation: every use of the result is redirected to the
+ // converted source value, and the op itself is erased.
+ rewriter.replaceOp(op, adaptor.getSource());
+ return success();
+ }
+};
+
/// Sparse codegen rule for the alloc operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
@@ -1564,7 +1576,7 @@ void mlir::populateSparseTensorCodegenPatterns(
SparseCastConverter, SparseExtractSliceConverter,
SparseTensorLoadConverter, SparseExpandConverter,
SparseCompressConverter, SparseInsertConverter,
- SparseReorderCOOConverter,
+ SparseReorderCOOConverter, SparseReMapConverter,
SparseSliceGetterOpConverter<ToSliceOffsetOp,
StorageSpecifierKind::DimOffset>,
SparseSliceGetterOpConverter<ToSliceStrideOp,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index c3f046e52fd6790..a92038ce7c98d4e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -336,6 +336,18 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
}
};
+/// Conversion pattern for `sparse_tensor.reinterpret_map` (ReinterpretMapOp);
+/// registered in populateSparseTensorConversionPatterns.
+///
+/// The op is eliminated without emitting new IR: its result is replaced by
+/// the already type-converted source operand.
+/// NOTE(review): this assumes the converted source value is equally valid as
+/// the converted result (the op only reinterprets the dimension-to-level
+/// mapping) -- confirm that queries made through the result type (e.g. on
+/// dimension sizes) remain correct for the forwarded value.
+/// NOTE(review): an identical pattern exists in SparseTensorCodegen.cpp;
+/// keep the two in sync.
+class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ /// Replaces `op` with `adaptor.getSource()`; always reports success.
+ LogicalResult
+ matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Simply fold the operation: every use of the result is redirected to the
+ // converted source value, and the op itself is erased.
+ rewriter.replaceOp(op, adaptor.getSource());
+ return success();
+ }
+};
+
/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
public:
@@ -770,7 +782,7 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns
.add<SparseReturnConverter, SparseTensorLvlOpConverter,
- SparseCastConverter, SparseTensorNewConverter,
+ SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
SparseTensorAllocConverter, SparseTensorEmptyConverter,
SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir
index d0b5e77bd4a724e..78d35ada6acc11c 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir
@@ -34,6 +34,11 @@
)
}>
+// 4-D encoding whose dimensions (i, j, k, l) map one-to-one onto levels
+// stored as (dense, compressed, dense, dense). Used below as the result type
+// of `sparse_tensor.reinterpret_map` to view the 2-D #BSR tensor %A as a
+// ?x?x2x2 tensor.
+#DSDD = #sparse_tensor.encoding<{
+ map = (i, j, k, l) -> ( i : dense, j : compressed, k : dense, l : dense)
+}>
+
+
!Filename = !llvm.ptr<i8>
//
@@ -77,6 +82,13 @@ module {
%vecv = vector.transfer_read %val[%c0], %f0 : memref<?xf64>, vector<12xf64>
vector.print %vecv : vector<12xf64>
+ // CHECK-NEXT: ( 1, 2, 0, 3, 4, 0, 0, 5, 6, 7, 8, 0 )
+ %t1 = sparse_tensor.reinterpret_map %A : tensor<?x?xf64, #BSR>
+ to tensor<?x?x2x2xf64, #DSDD>
+ %vdsdd = sparse_tensor.values %t1 : tensor<?x?x2x2xf64, #DSDD> to memref<?xf64>
+ %vecdsdd = vector.transfer_read %vdsdd[%c0], %f0 : memref<?xf64>, vector<12xf64>
+ vector.print %vecdsdd : vector<12xf64>
+
// Release the resources.
bufferization.dealloc_tensor %A: tensor<?x?xf64, #BSR>
More information about the Mlir-commits
mailing list