[Mlir-commits] [mlir] [mlir][sparse] unify sparse_tensor.out rewriting rules (PR #70518)

Peiming Liu llvmlistbot at llvm.org
Fri Oct 27 15:43:06 PDT 2023


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/70518

>From 056d8bc5724d184746c3c6e4d9ed7491e6f2d772 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 27 Oct 2023 22:41:05 +0000
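Subject: [PATCH 1/2] [mlir][sparse] unify sparse_tensor.out rewriting rules

Remove the runtime-library-specific SparseTensorOutConverter from the
conversion patterns. Register OutRewriter unconditionally in
populatePostSparsificationRewriting (previously only when !enableRT), so
sparse_tensor.out is lowered by the same rule in both configurations.

The sparse_out1/sparse_out2 conversion tests are dropped accordingly.
The Python output test now expects a "#" (not ";") header comment line.
For the [1, 0] ordering it no longer expects lexicographically sorted
entries: the expected entries are ordered by the second coordinate.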
---
 .../Transforms/SparseTensorConversion.cpp     | 34 +------------------
 .../Transforms/SparseTensorRewriting.cpp      |  4 +--
 .../test/Dialect/SparseTensor/conversion.mlir | 28 ---------------
 .../SparseTensor/python/test_output.py        | 23 +++++++++----
 4 files changed, 19 insertions(+), 70 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index a92038ce7c98d4e..96006d6cb82c545 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -707,37 +707,6 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
   }
 };
 
-/// Sparse conversion rule for the output operator.
-class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(OutOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    const Location loc = op->getLoc();
-    const auto srcTp = getSparseTensorType(op.getTensor());
-    // Convert to default permuted COO.
-    Value src = adaptor.getOperands()[0];
-    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
-    Value coo = NewCallParams(rewriter, loc)
-                    .genBuffers(srcTp.withoutDimToLvl(), dimSizes)
-                    .genNewCall(Action::kToCOO, src);
-    // Then output the tensor to external file with coordinates in the
-    // externally visible lexicographic coordinate order.  A sort is
-    // required if the source was not in that order yet (note that the
-    // sort can be dropped altogether if external format does not care
-    // about the order at all, but here we assume it does).
-    const Value sort = constantI1(rewriter, loc, !srcTp.isIdentity());
-    SmallVector<Value, 3> outParams{coo, adaptor.getOperands()[1], sort};
-    const Type elemTp = srcTp.getElementType();
-    SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(elemTp)};
-    createFuncCall(rewriter, loc, name, {}, outParams, EmitCInterface::Off);
-    genDelCOOCall(rewriter, loc, elemTp, coo);
-    rewriter.eraseOp(op);
-    return success();
-  }
-};
-
 /// Sparse conversion rule for the sparse_tensor.pack operator.
 class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
 public:
@@ -789,6 +758,5 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
            SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
            SparseTensorLoadConverter, SparseTensorInsertConverter,
            SparseTensorExpandConverter, SparseTensorCompressConverter,
-           SparseTensorOutConverter, SparseTensorAssembleConverter>(
-          typeConverter, patterns.getContext());
+           SparseTensorAssembleConverter>(typeConverter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 6dcac38eb4f357c..e9bcb5dc070ade9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1312,12 +1312,12 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
                ReshapeRewriter<tensor::CollapseShapeOp>,
                Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
                Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
-               SparseTensorDimOpRewriter, TensorReshapeRewriter>(
+               SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
       patterns.getContext());
   if (enableForeach)
     patterns.add<ForeachRewriter>(patterns.getContext());
   if (enableConvert)
     patterns.add<DirectConvertRewriter>(patterns.getContext());
   if (!enableRT)
-    patterns.add<NewRewriter, OutRewriter>(patterns.getContext());
+    patterns.add<NewRewriter>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 1d7599b3a4edb87..092ba6b8358b598 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -398,34 +398,6 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
   return %0 : tensor<8x8xf64, #CSR>
 }
 
-// CHECK-LABEL: func @sparse_out1(
-//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>,
-//  CHECK-SAME: %[[B:.*]]: !llvm.ptr<i8>)
-//  CHECK-DAG:  %[[ToCOO:.*]] = arith.constant 5 : i32
-//  CHECK-DAG:  %[[Sort:.*]] = arith.constant false
-//       CHECK: %[[COO:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[ToCOO]], %[[A]])
-//       CHECK: call @outSparseTensorF64(%[[COO]], %[[B]], %[[Sort]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i1) -> ()
-//       CHECK: call @delSparseTensorCOOF64(%[[COO]])
-//       CHECK: return
-func.func @sparse_out1(%arg0: tensor<?x?xf64, #CSR>, %arg1: !llvm.ptr<i8>) {
-  sparse_tensor.out %arg0, %arg1 : tensor<?x?xf64, #CSR>, !llvm.ptr<i8>
-  return
-}
-
-// CHECK-LABEL: func @sparse_out2(
-//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>,
-//  CHECK-SAME: %[[B:.*]]: !llvm.ptr<i8>)
-//  CHECK-DAG:  %[[ToCOO:.*]] = arith.constant 5 : i32
-//  CHECK-DAG:  %[[Sort:.*]] = arith.constant true
-//       CHECK: %[[COO:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[ToCOO]], %[[A]])
-//       CHECK: call @outSparseTensorF32(%[[COO]], %[[B]], %[[Sort]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i1) -> ()
-//       CHECK: call @delSparseTensorCOOF32(%[[COO]])
-//       CHECK: return
-func.func @sparse_out2(%arg0: tensor<?x?x?xf32, #SparseTensor>, %arg1: !llvm.ptr<i8>) {
-  sparse_tensor.out %arg0, %arg1 : tensor<?x?x?xf32, #SparseTensor>, !llvm.ptr<i8>
-  return
-}
-
 // CHECK-LABEL: func @sparse_and_dense_init(
 //       CHECK: %[[S:.*]] = call @newSparseTensor
 //       CHECK: %[[D:.*]] = tensor.empty
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
index 1afac10be3adb45..4ef524591d0d44f 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -6,6 +6,8 @@
 import sys
 import tempfile
 
+sys.path.append("/usr/local/google/home/peiming/projects/llvm-project/build-static/tools/mlir/python_packages/mlir_core")
+
 from mlir import ir
 from mlir import runtime as rt
 from mlir.dialects import builtin
@@ -29,14 +31,14 @@ def boilerplate(attr: st.EncodingAttr):
 """
 
 
-def expected():
+def expected(id_map):
     """Returns expected contents of output.
 
     Regardless of the dimension ordering, compression, and bitwidths that are
     used in the sparse tensor, the output is always lexicographically sorted
     by natural index order.
     """
-    return f"""; extended FROSTT format
+    return f"""# extended FROSTT format
 2 5
 10 10
 1 1 1
@@ -44,6 +46,14 @@ def expected():
 2 2 2
 5 5 5
 10 1 4
+""" if id_map else f"""# extended FROSTT format
+2 5
+10 10
+1 1 1
+10 1 4
+2 2 2
+5 5 5
+1 10 3
 """
 
 
@@ -51,7 +61,6 @@ def build_compile_and_run_output(attr: st.EncodingAttr, compiler, expected):
     # Build and Compile.
     module = ir.Module.parse(boilerplate(attr))
     engine = compiler.compile_and_jit(module)
-
     # Invoke the kernel and compare output.
     with tempfile.TemporaryDirectory() as test_dir:
         out = os.path.join(test_dir, "out.tns")
@@ -83,8 +92,8 @@ def main():
             [st.DimLevelType.compressed, st.DimLevelType.compressed],
         ]
         orderings = [
-            ir.AffineMap.get_permutation([0, 1]),
-            ir.AffineMap.get_permutation([1, 0]),
+            (ir.AffineMap.get_permutation([0, 1]), True),
+            (ir.AffineMap.get_permutation([1, 0]), False),
         ]
         bitwidths = [8, 16, 32, 64]
         compiler = sparse_compiler.SparseCompiler(
@@ -94,9 +103,9 @@ def main():
             for ordering in orderings:
                 for bwidth in bitwidths:
                     attr = st.EncodingAttr.get(
-                        level, ordering, ordering, bwidth, bwidth
+                        level, ordering[0], ordering[0], bwidth, bwidth
                     )
-                    build_compile_and_run_output(attr, compiler, expected())
+                    build_compile_and_run_output(attr, compiler, expected(ordering[1]))
                     count = count + 1
 
         # Now do the same for BSR.

>From f3f39076bccf5b5587d5ebbd905e7871e1c383fd Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 27 Oct 2023 22:42:51 +0000
Subject: [PATCH 2/2] revert unintended local sys.path change in test_output.py

---
 .../test/Integration/Dialect/SparseTensor/python/test_output.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
index 4ef524591d0d44f..96470bec219bfad 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -6,8 +6,6 @@
 import sys
 import tempfile
 
-sys.path.append("/usr/local/google/home/peiming/projects/llvm-project/build-static/tools/mlir/python_packages/mlir_core")
-
 from mlir import ir
 from mlir import runtime as rt
 from mlir.dialects import builtin



More information about the Mlir-commits mailing list