[Mlir-commits] [mlir] [mlir][sparse] implements sparse_tensor.reinterpret_map (PR #70388)

Peiming Liu llvmlistbot at llvm.org
Thu Oct 26 15:50:01 PDT 2023


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/70388

>From 434160a15cde8c6e9c5518031f89e01df996da64 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 26 Oct 2023 22:27:57 +0000
Subject: [PATCH 1/3] [mlir][sparse] implements sparse_tensor.reinterpret_map

---
 .../Transforms/SparseTensorCodegen.cpp             | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index e855a6e19a717a4..5cdf8cd7ccc9d8b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -725,6 +725,18 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
   }
 };
 
+class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Lowered as a no-op: all uses of the result get the converted source.
+    rewriter.replaceOp(op, adaptor.getSource());
+    return success();
+  }
+};
+
 /// Sparse codegen rule for the alloc operator.
 /// TODO(springerm): remove when bufferization.alloc_tensor is gone
 class SparseTensorAllocConverter
@@ -1564,7 +1576,7 @@ void mlir::populateSparseTensorCodegenPatterns(
                SparseCastConverter, SparseExtractSliceConverter,
                SparseTensorLoadConverter, SparseExpandConverter,
                SparseCompressConverter, SparseInsertConverter,
-               SparseReorderCOOConverter,
+               SparseReorderCOOConverter, SparseReMapConverter,
                SparseSliceGetterOpConverter<ToSliceOffsetOp,
                                             StorageSpecifierKind::DimOffset>,
                SparseSliceGetterOpConverter<ToSliceStrideOp,

>From 37350181a55cb0f5d4f84317223e6d7af7f3af91 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 26 Oct 2023 22:46:35 +0000
Subject: [PATCH 2/3] add test cases

---
 .../SparseTensor/IR/SparseTensorDialect.cpp        | 13 +++++++++++--
 .../Transforms/SparseTensorConversion.cpp          | 14 +++++++++++++-
 .../Dialect/SparseTensor/CPU/block.mlir            | 12 ++++++++++++
 3 files changed, 36 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index be44a0d31c92a7d..d069eab1f81aea1 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -453,11 +453,20 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
     // Do constant propagation on the affine map.
     AffineExpr evalExp =
         simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
-    if (auto c = evalExp.dyn_cast<AffineConstantExpr>())
+    if (auto c = evalExp.dyn_cast<AffineConstantExpr>()) {
       ret.push_back(c.getValue() + 1);
-    else
+    } else {
+      if (auto mod = evalExp.dyn_cast<AffineBinaryOpExpr>();
+          mod && mod.getKind() == AffineExprKind::Mod) {
+        if (auto bound = mod.getRHS().dyn_cast<AffineConstantExpr>()) {
+          ret.push_back(bound.getValue());
+          continue;
+        }
+      }
       ret.push_back(ShapedType::kDynamic);
+    }
   }
+  assert(ret.size() == rank);
   return ret;
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index c3f046e52fd6790..a92038ce7c98d4e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -336,6 +336,18 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
   }
 };
 
+class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Lowered as a no-op: all uses of the result get the converted source.
+    rewriter.replaceOp(op, adaptor.getSource());
+    return success();
+  }
+};
+
 /// Sparse conversion rule for the new operator.
 class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
 public:
@@ -770,7 +782,7 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                   RewritePatternSet &patterns) {
   patterns
       .add<SparseReturnConverter, SparseTensorLvlOpConverter,
-           SparseCastConverter, SparseTensorNewConverter,
+           SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
            SparseTensorAllocConverter, SparseTensorEmptyConverter,
            SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
            SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir
index d0b5e77bd4a724e..78d35ada6acc11c 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir
@@ -34,6 +34,11 @@
     )
 }>
 
+// 4-D view of #BSR storage, used by the reinterpret_map check below.
+#DSDD = #sparse_tensor.encoding<{
+  map = (i, j, k, l) -> ( i  : dense, j  : compressed, k  : dense, l  : dense)
+}>
+
 !Filename = !llvm.ptr<i8>
 
 //
@@ -77,6 +82,13 @@ module {
     %vecv = vector.transfer_read %val[%c0], %f0 : memref<?xf64>, vector<12xf64>
     vector.print %vecv : vector<12xf64>
 
+    // CHECK-NEXT: ( 1, 2, 0, 3, 4, 0, 0, 5, 6, 7, 8, 0 )
+    %t1 = sparse_tensor.reinterpret_map %A : tensor<?x?xf64, #BSR>
+                                          to tensor<?x?x2x2xf64, #DSDD>
+    %vdsdd = sparse_tensor.values %t1 : tensor<?x?x2x2xf64, #DSDD> to memref<?xf64>
+    %vecdsdd = vector.transfer_read %vdsdd[%c0], %f0 : memref<?xf64>, vector<12xf64>
+    vector.print %vecdsdd : vector<12xf64>
+
     // Release the resources.
     bufferization.dealloc_tensor %A: tensor<?x?xf64, #BSR>
 

>From 725f4c572c966e3b72afbc3c358c92462474b8ce Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 26 Oct 2023 22:49:47 +0000
Subject: [PATCH 3/3] add some comments.

---
 mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index d069eab1f81aea1..fe25ad848c7449c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -458,6 +458,8 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
     } else {
       if (auto mod = evalExp.dyn_cast<AffineBinaryOpExpr>();
           mod && mod.getKind() == AffineExprKind::Mod) {
+        // We can still infer a static bound for expressions in form
+        // "d % constant" since d % constant \in [0, constant).
         if (auto bound = mod.getRHS().dyn_cast<AffineConstantExpr>()) {
           ret.push_back(bound.getValue());
           continue;



More information about the Mlir-commits mailing list