[Mlir-commits] [mlir] 33267f4 - [mlir][sparse] convert a sparse tensor slice to sparse tensor correctly.

Peiming Liu llvmlistbot at llvm.org
Tue Mar 28 14:39:37 PDT 2023


Author: Peiming Liu
Date: 2023-03-28T21:39:31Z
New Revision: 33267f4007c7525331d24bc0e91ecba2e520679f

URL: https://github.com/llvm/llvm-project/commit/33267f4007c7525331d24bc0e91ecba2e520679f
DIFF: https://github.com/llvm/llvm-project/commit/33267f4007c7525331d24bc0e91ecba2e520679f.diff

LOG: [mlir][sparse] convert a sparse tensor slice to sparse tensor correctly.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D147074

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
    mlir/test/Dialect/SparseTensor/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 3bf11189d3805..11e6ac81b7e14 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -701,6 +701,11 @@ LogicalResult ConvertOp::verify() {
     if (auto tp2 = getDest().getType().dyn_cast<RankedTensorType>()) {
       if (tp1.getRank() != tp2.getRank())
         return emitError("unexpected conversion mismatch in rank");
+      auto dstEnc =
+          tp2.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
+      if (dstEnc && dstEnc.isSlice())
+        return emitError("cannot convert to a sparse tensor slice");
+
       auto shape1 = tp1.getShape();
       auto shape2 = tp2.getShape();
       // Accept size matches between the source and the destination type

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index cc7524cf55be6..e2a2fcc986b82 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1058,9 +1058,14 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
     SparseTensorEncodingAttr encSrc =
         getSparseTensorEncoding(op.getSource().getType());
+    // The output tensor can not be a slice and those cases should have been
+    // rejected by ConvertOp::verify() already.
+    assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slices.");
     // Different encoding (except for different bitwidth) should be handled by
     // rewriting.
-    if (encDst.withoutBitWidths() != encSrc.withoutBitWidths()) {
+    // We need further rewrites if the input tensor is a slice too.
+    if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
+        encSrc.isSlice()) {
       return failure();
     }
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index b08ca4961f467..4a3e62ffaca04 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -618,7 +618,7 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
                                 PatternRewriter &rewriter) const override {
     auto encDst = getSparseTensorEncoding(op.getType());
     auto encSrc = getSparseTensorEncoding(op.getSource().getType());
-    if (encDst && encSrc &&
+    if (encDst && encSrc && !encSrc.isSlice() &&
         encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
       // Trivial tensor conversion and simple element type conversion is handled
       // in codegen.

diff  --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
index 5c8ebe325d6d2..21f3b2faf35ee 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -25,6 +25,10 @@
   dimLevelType = ["compressed"]
 }>
 
+#SortedCOO2D = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed-nu", "singleton" ],
+}>
+
 #SortedCOO3D = #sparse_tensor.encoding<{
   dimLevelType = [ "compressed-nu", "singleton-nu", "singleton" ]
 
@@ -35,6 +39,11 @@
   dimOrdering = affine_map<(i,j,k) -> (k,i,j)>
 }>
 
+#COOSlice = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed-nu", "singleton" ],
+  slice = [ (2, 2, 1), (12, 13, 1) ]
+}>
+
 // CHECK-LABEL: func @sparse_nop_convert(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
 //       CHECK: return %[[A]] : !llvm.ptr<i8>
@@ -185,3 +194,20 @@ func.func @sparse_convert_permuted(%arg0: tensor<?x?x?xf32, #SortedCOO3D>) -> te
   %0 = sparse_tensor.convert %arg0 : tensor<?x?x?xf32, #SortedCOO3D> to tensor<?x?x?xf32, #TsssPermuted>
   return %0 : tensor<?x?x?xf32, #TsssPermuted>
 }
+
+// CHECK-RWT-LABEL: func.func @sparse_convert_slice(
+//  CHECK-RWT-SAME: %[[VAL_0:.*]]: tensor<2x13xi32, #{{.*}}>) -> tensor<2x13xi32, #{{.*}}> {
+//       CHECK-RWT: %[[VAL_1:.*]] = sparse_tensor.number_of_entries %[[VAL_0]] : tensor<2x13xi32, #{{.*}}>
+//       CHECK-RWT: %[[VAL_2:.*]] = bufferization.alloc_tensor() size_hint=%[[VAL_1]] : tensor<2x13xi32, #{{.*}}>
+//       CHECK-RWT: %[[VAL_3:.*]] = sparse_tensor.foreach in %[[VAL_0]] init(%[[VAL_2]]) : tensor<2x13xi32, #{{.*}}>, tensor<2x13xi32, #{{.*}}> -> tensor<2x13xi32, #{{.*}}> do {
+//       CHECK-RWT: ^bb0(%[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index, %[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: tensor<2x13xi32, #{{.*}}>):
+//       CHECK-RWT:   %[[VAL_8:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]{{\[}}%[[VAL_4]], %[[VAL_5]]] : tensor<2x13xi32, #{{.*}}>
+//       CHECK-RWT:   sparse_tensor.yield %[[VAL_8]] : tensor<2x13xi32, #{{.*}}>
+//       CHECK-RWT: }
+//       CHECK-RWT: %[[VAL_9:.*]] = sparse_tensor.load %[[VAL_10:.*]] hasInserts : tensor<2x13xi32, #{{.*}}>
+//       CHECK-RWT: %[[VAL_11:.*]] = sparse_tensor.convert %[[VAL_9]] : tensor<2x13xi32, #{{.*}}> to tensor<2x13xi32, #{{.*}}>
+//       CHECK-RWT: return %[[VAL_11]] : tensor<2x13xi32, #{{.*}}>
+func.func @sparse_convert_slice(%arg0: tensor<2x13xi32, #COOSlice>) -> (tensor<2x13xi32, #SortedCOO2D>)  {
+  %0 = sparse_tensor.convert %arg0 : tensor<2x13xi32, #COOSlice> to tensor<2x13xi32, #SortedCOO2D>
+  return %0 : tensor<2x13xi32, #SortedCOO2D>
+}

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 6c0f13ae1b564..b9f45fb97334c 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -423,6 +423,19 @@ func.func @invalid_out_dense(%arg0: tensor<10xf64>, %arg1: !llvm.ptr<i8>) {
 
 // -----
 
+#CSR = #sparse_tensor.encoding<{
+  dimLevelType = ["dense", "compressed"],
+  slice = [ (1, 4, 1), (1, 4, 2) ]
+}>
+
+func.func @sparse_convert_to_slice(%arg0: tensor<10x?xf32>) -> tensor<10x10xf32, #CSR> {
+  // expected-error @+1 {{cannot convert to a sparse tensor slice}}
+  %0 = sparse_tensor.convert %arg0 : tensor<10x?xf32> to tensor<10x10xf32, #CSR>
+  return %0 : tensor<10x10xf32, #CSR>
+}
+
+// -----
+
 func.func @invalid_binary_num_args_mismatch_overlap(%arg0: f64, %arg1: f64) -> f64 {
   // expected-error @+1 {{overlap region must have exactly 2 arguments}}
   %r = sparse_tensor.binary %arg0, %arg1 : f64, f64 to f64


        


More information about the Mlir-commits mailing list