[Mlir-commits] [mlir] 7d608ee - [mlir][sparse] unify sparse_tensor.out rewriting rules (#70518)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 27 16:47:01 PDT 2023


Author: Peiming Liu
Date: 2023-10-27T16:46:58-07:00
New Revision: 7d608ee2bb0a9511387491eef2209bea9fdcbb03

URL: https://github.com/llvm/llvm-project/commit/7d608ee2bb0a9511387491eef2209bea9fdcbb03
DIFF: https://github.com/llvm/llvm-project/commit/7d608ee2bb0a9511387491eef2209bea9fdcbb03.diff

LOG: [mlir][sparse] unify sparse_tensor.out rewriting rules (#70518)

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir
    mlir/test/Integration/Dialect/SparseTensor/python/test_output.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index a92038ce7c98d4e..570be951cab845e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -270,13 +270,6 @@ static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
       .getResult(0);
 }
 
-/// Generates a call to release/delete a `SparseTensorCOO`.
-static void genDelCOOCall(OpBuilder &builder, Location loc, Type elemTp,
-                          Value coo) {
-  SmallString<21> name{"delSparseTensorCOO", primaryTypeFunctionSuffix(elemTp)};
-  createFuncCall(builder, loc, name, {}, coo, EmitCInterface::Off);
-}
-
 //===----------------------------------------------------------------------===//
 // Conversion rules.
 //===----------------------------------------------------------------------===//
@@ -707,37 +700,6 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
   }
 };
 
-/// Sparse conversion rule for the output operator.
-class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(OutOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    const Location loc = op->getLoc();
-    const auto srcTp = getSparseTensorType(op.getTensor());
-    // Convert to default permuted COO.
-    Value src = adaptor.getOperands()[0];
-    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
-    Value coo = NewCallParams(rewriter, loc)
-                    .genBuffers(srcTp.withoutDimToLvl(), dimSizes)
-                    .genNewCall(Action::kToCOO, src);
-    // Then output the tensor to external file with coordinates in the
-    // externally visible lexicographic coordinate order.  A sort is
-    // required if the source was not in that order yet (note that the
-    // sort can be dropped altogether if external format does not care
-    // about the order at all, but here we assume it does).
-    const Value sort = constantI1(rewriter, loc, !srcTp.isIdentity());
-    SmallVector<Value, 3> outParams{coo, adaptor.getOperands()[1], sort};
-    const Type elemTp = srcTp.getElementType();
-    SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(elemTp)};
-    createFuncCall(rewriter, loc, name, {}, outParams, EmitCInterface::Off);
-    genDelCOOCall(rewriter, loc, elemTp, coo);
-    rewriter.eraseOp(op);
-    return success();
-  }
-};
-
 /// Sparse conversion rule for the sparse_tensor.pack operator.
 class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
 public:
@@ -789,6 +751,5 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
            SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
            SparseTensorLoadConverter, SparseTensorInsertConverter,
            SparseTensorExpandConverter, SparseTensorCompressConverter,
-           SparseTensorOutConverter, SparseTensorAssembleConverter>(
-          typeConverter, patterns.getContext());
+           SparseTensorAssembleConverter>(typeConverter, patterns.getContext());
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 6dcac38eb4f357c..e9bcb5dc070ade9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1312,12 +1312,12 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
                ReshapeRewriter<tensor::CollapseShapeOp>,
                Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
                Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
-               SparseTensorDimOpRewriter, TensorReshapeRewriter>(
+               SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
       patterns.getContext());
   if (enableForeach)
     patterns.add<ForeachRewriter>(patterns.getContext());
   if (enableConvert)
     patterns.add<DirectConvertRewriter>(patterns.getContext());
   if (!enableRT)
-    patterns.add<NewRewriter, OutRewriter>(patterns.getContext());
+    patterns.add<NewRewriter>(patterns.getContext());
 }

diff  --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 1d7599b3a4edb87..092ba6b8358b598 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -398,34 +398,6 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
   return %0 : tensor<8x8xf64, #CSR>
 }
 
-// CHECK-LABEL: func @sparse_out1(
-//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>,
-//  CHECK-SAME: %[[B:.*]]: !llvm.ptr<i8>)
-//  CHECK-DAG:  %[[ToCOO:.*]] = arith.constant 5 : i32
-//  CHECK-DAG:  %[[Sort:.*]] = arith.constant false
-//       CHECK: %[[COO:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[ToCOO]], %[[A]])
-//       CHECK: call @outSparseTensorF64(%[[COO]], %[[B]], %[[Sort]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i1) -> ()
-//       CHECK: call @delSparseTensorCOOF64(%[[COO]])
-//       CHECK: return
-func.func @sparse_out1(%arg0: tensor<?x?xf64, #CSR>, %arg1: !llvm.ptr<i8>) {
-  sparse_tensor.out %arg0, %arg1 : tensor<?x?xf64, #CSR>, !llvm.ptr<i8>
-  return
-}
-
-// CHECK-LABEL: func @sparse_out2(
-//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>,
-//  CHECK-SAME: %[[B:.*]]: !llvm.ptr<i8>)
-//  CHECK-DAG:  %[[ToCOO:.*]] = arith.constant 5 : i32
-//  CHECK-DAG:  %[[Sort:.*]] = arith.constant true
-//       CHECK: %[[COO:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[ToCOO]], %[[A]])
-//       CHECK: call @outSparseTensorF32(%[[COO]], %[[B]], %[[Sort]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i1) -> ()
-//       CHECK: call @delSparseTensorCOOF32(%[[COO]])
-//       CHECK: return
-func.func @sparse_out2(%arg0: tensor<?x?x?xf32, #SparseTensor>, %arg1: !llvm.ptr<i8>) {
-  sparse_tensor.out %arg0, %arg1 : tensor<?x?x?xf32, #SparseTensor>, !llvm.ptr<i8>
-  return
-}
-
 // CHECK-LABEL: func @sparse_and_dense_init(
 //       CHECK: %[[S:.*]] = call @newSparseTensor
 //       CHECK: %[[D:.*]] = tensor.empty

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
index 1afac10be3adb45..5a8c92f7cd21fc5 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -29,14 +29,14 @@ def boilerplate(attr: st.EncodingAttr):
 """
 
 
-def expected():
+def expected(id_map):
     """Returns expected contents of output.
 
-    Regardless of the dimension ordering, compression, and bitwidths that are
-    used in the sparse tensor, the output is always lexicographically sorted
-    by natural index order.
+    Output appears as dimension coordinates but lexicographically
+    sorted by level coordinates.
     """
-    return f"""; extended FROSTT format
+    return (
+        f"""# extended FROSTT format
 2 5
 10 10
 1 1 1
@@ -45,13 +45,23 @@ def expected():
 5 5 5
 10 1 4
 """
+        if id_map
+        else f"""# extended FROSTT format
+2 5
+10 10
+1 1 1
+10 1 4
+2 2 2
+5 5 5
+1 10 3
+"""
+    )
 
 
 def build_compile_and_run_output(attr: st.EncodingAttr, compiler, expected):
     # Build and Compile.
     module = ir.Module.parse(boilerplate(attr))
     engine = compiler.compile_and_jit(module)
-
     # Invoke the kernel and compare output.
     with tempfile.TemporaryDirectory() as test_dir:
         out = os.path.join(test_dir, "out.tns")
@@ -83,20 +93,20 @@ def main():
             [st.DimLevelType.compressed, st.DimLevelType.compressed],
         ]
         orderings = [
-            ir.AffineMap.get_permutation([0, 1]),
-            ir.AffineMap.get_permutation([1, 0]),
+            (ir.AffineMap.get_permutation([0, 1]), True),
+            (ir.AffineMap.get_permutation([1, 0]), False),
         ]
         bitwidths = [8, 16, 32, 64]
         compiler = sparse_compiler.SparseCompiler(
             options="", opt_level=2, shared_libs=[support_lib]
         )
         for level in levels:
-            for ordering in orderings:
+            for ordering, id_map in orderings:
                 for bwidth in bitwidths:
                     attr = st.EncodingAttr.get(
                         level, ordering, ordering, bwidth, bwidth
                     )
-                    build_compile_and_run_output(attr, compiler, expected())
+                    build_compile_and_run_output(attr, compiler, expected(id_map))
                     count = count + 1
 
         # Now do the same for BSR.


        


More information about the Mlir-commits mailing list