[Mlir-commits] [mlir] ef22298 - [mlir][sparse] implements sparse_tensor.reinterpret_map (#70388)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 26 16:00:36 PDT 2023


Author: Peiming Liu
Date: 2023-10-26T16:00:32-07:00
New Revision: ef222988b477bf91fe3b6e9cd7d881d19af2d605

URL: https://github.com/llvm/llvm-project/commit/ef222988b477bf91fe3b6e9cd7d881d19af2d605
DIFF: https://github.com/llvm/llvm-project/commit/ef222988b477bf91fe3b6e9cd7d881d19af2d605.diff

LOG: [mlir][sparse] implements sparse_tensor.reinterpret_map (#70388)
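
The op reinterprets the dimension-to-level mapping recorded in a sparse
tensor's type without touching the underlying storage. A minimal sketch of
its use (the #BSR and #DSDD encodings here mirror the integration test
updated below):

    #BSR = #sparse_tensor.encoding<{
      map = (i, j) -> (i floordiv 2 : dense, j floordiv 2 : compressed,
                       i mod 2 : dense, j mod 2 : dense)
    }>
    #DSDD = #sparse_tensor.encoding<{
      map = (i, j, k, l) -> (i : dense, j : compressed, k : dense, l : dense)
    }>

    // Reinterpret a 2x2 BSR matrix as an explicit 4-D blocked tensor; only
    // the type changes, the positions/coordinates/values buffers are shared.
    %t = sparse_tensor.reinterpret_map %A : tensor<?x?xf64, #BSR>
                                         to tensor<?x?x2x2xf64, #DSDD>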

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 4f1c446faec3714..9a6d3161be3d6e4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -453,11 +453,22 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
     // Do constant propagation on the affine map.
     AffineExpr evalExp =
         simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
-    if (auto c = evalExp.dyn_cast<AffineConstantExpr>())
+    if (auto c = evalExp.dyn_cast<AffineConstantExpr>()) {
       ret.push_back(c.getValue() + 1);
-    else
+    } else {
+      if (auto mod = evalExp.dyn_cast<AffineBinaryOpExpr>();
+          mod && mod.getKind() == AffineExprKind::Mod) {
+        // We can still infer a static bound for expressions in form
+        // "d % constant" since d % constant \in [0, constant).
+        if (auto bound = mod.getRHS().dyn_cast<AffineConstantExpr>()) {
+          ret.push_back(bound.getValue());
+          continue;
+        }
+      }
       ret.push_back(ShapedType::kDynamic);
+    }
   }
+  assert(ret.size() == rank);
   return ret;
 }
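
A worked example of the new bound inference (assuming the 2x2 block map used
in the test below): translating a ?x? dimension shape through
(i, j) -> (i floordiv 2, j floordiv 2, i mod 2, j mod 2) now yields the level
shape ?x?x2x2. The floordiv components remain dynamic because the source
dimensions are dynamic, but each "d mod 2" component gets the static bound 2,
since d mod 2 always lies in [0, 2).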
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index e855a6e19a717a4..5cdf8cd7ccc9d8b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -725,6 +725,18 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
   }
 };
 
+class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Simply fold the operation.
+    rewriter.replaceOp(op, adaptor.getSource());
+    return success();
+  }
+};
+
 /// Sparse codegen rule for the alloc operator.
 /// TODO(springerm): remove when bufferization.alloc_tensor is gone
 class SparseTensorAllocConverter
@@ -1564,7 +1576,7 @@ void mlir::populateSparseTensorCodegenPatterns(
                SparseCastConverter, SparseExtractSliceConverter,
                SparseTensorLoadConverter, SparseExpandConverter,
                SparseCompressConverter, SparseInsertConverter,
-               SparseReorderCOOConverter,
+               SparseReorderCOOConverter, SparseReMapConverter,
                SparseSliceGetterOpConverter<ToSliceOffsetOp,
                                             StorageSpecifierKind::DimOffset>,
                SparseSliceGetterOpConverter<ToSliceStrideOp,

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index c3f046e52fd6790..a92038ce7c98d4e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -336,6 +336,18 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
   }
 };
 
+class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Simply fold the operation.
+    rewriter.replaceOp(op, adaptor.getSource());
+    return success();
+  }
+};
+
 /// Sparse conversion rule for the new operator.
 class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
 public:
@@ -770,7 +782,7 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                   RewritePatternSet &patterns) {
   patterns
       .add<SparseReturnConverter, SparseTensorLvlOpConverter,
-           SparseCastConverter, SparseTensorNewConverter,
+           SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
            SparseTensorAllocConverter, SparseTensorEmptyConverter,
            SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
            SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
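
Both the codegen and the conversion path lower the op the same way; a sketch
of the effect, using the types from the test below:

    // %t = sparse_tensor.reinterpret_map %A
    //     : tensor<?x?xf64, #BSR> to tensor<?x?x2x2xf64, #DSDD>
    //
    // After codegen/conversion, every use of %t refers directly to the
    // lowered storage of %A; the reinterpretation emits no runtime work.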

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir
index d0b5e77bd4a724e..78d35ada6acc11c 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir
@@ -34,6 +34,11 @@
     )
 }>
 
+#DSDD = #sparse_tensor.encoding<{
+  map = (i, j, k, l) -> ( i  : dense, j  : compressed, k  : dense, l  : dense)
+}>
+
+
 !Filename = !llvm.ptr<i8>
 
 //
@@ -77,6 +82,13 @@ module {
     %vecv = vector.transfer_read %val[%c0], %f0 : memref<?xf64>, vector<12xf64>
     vector.print %vecv : vector<12xf64>
 
+    // CHECK-NEXT: ( 1, 2, 0, 3, 4, 0, 0, 5, 6, 7, 8, 0 )
+    %t1 = sparse_tensor.reinterpret_map %A : tensor<?x?xf64, #BSR>
+                                          to tensor<?x?x2x2xf64, #DSDD>
+    %vdsdd = sparse_tensor.values %t1 : tensor<?x?x2x2xf64, #DSDD> to memref<?xf64>
+    %vecdsdd = vector.transfer_read %vdsdd[%c0], %f0 : memref<?xf64>, vector<12xf64>
+    vector.print %vecdsdd : vector<12xf64>
+
     // Release the resources.
     bufferization.dealloc_tensor %A: tensor<?x?xf64, #BSR>
 

More information about the Mlir-commits mailing list