[Mlir-commits] [mlir] [mlir][sparse] unify sparse_tensor.out rewriting rules (PR #70518)

Peiming Liu llvmlistbot at llvm.org
Fri Oct 27 16:23:39 PDT 2023


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/70518

>From 056d8bc5724d184746c3c6e4d9ed7491e6f2d772 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 27 Oct 2023 22:41:05 +0000
Subject: [PATCH 1/5] [mlir][sparse] unify sparse_tensor.out rewriting rules

---
 .../Transforms/SparseTensorConversion.cpp     | 34 +------------------
 .../Transforms/SparseTensorRewriting.cpp      |  4 +--
 .../test/Dialect/SparseTensor/conversion.mlir | 28 ---------------
 .../SparseTensor/python/test_output.py        | 23 +++++++++----
 4 files changed, 19 insertions(+), 70 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index a92038ce7c98d4e..96006d6cb82c545 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -707,37 +707,6 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
   }
 };
 
-/// Sparse conversion rule for the output operator.
-class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(OutOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    const Location loc = op->getLoc();
-    const auto srcTp = getSparseTensorType(op.getTensor());
-    // Convert to default permuted COO.
-    Value src = adaptor.getOperands()[0];
-    SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
-    Value coo = NewCallParams(rewriter, loc)
-                    .genBuffers(srcTp.withoutDimToLvl(), dimSizes)
-                    .genNewCall(Action::kToCOO, src);
-    // Then output the tensor to external file with coordinates in the
-    // externally visible lexicographic coordinate order.  A sort is
-    // required if the source was not in that order yet (note that the
-    // sort can be dropped altogether if external format does not care
-    // about the order at all, but here we assume it does).
-    const Value sort = constantI1(rewriter, loc, !srcTp.isIdentity());
-    SmallVector<Value, 3> outParams{coo, adaptor.getOperands()[1], sort};
-    const Type elemTp = srcTp.getElementType();
-    SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(elemTp)};
-    createFuncCall(rewriter, loc, name, {}, outParams, EmitCInterface::Off);
-    genDelCOOCall(rewriter, loc, elemTp, coo);
-    rewriter.eraseOp(op);
-    return success();
-  }
-};
-
 /// Sparse conversion rule for the sparse_tensor.pack operator.
 class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
 public:
@@ -789,6 +758,5 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
            SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
            SparseTensorLoadConverter, SparseTensorInsertConverter,
            SparseTensorExpandConverter, SparseTensorCompressConverter,
-           SparseTensorOutConverter, SparseTensorAssembleConverter>(
-          typeConverter, patterns.getContext());
+           SparseTensorAssembleConverter>(typeConverter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 6dcac38eb4f357c..e9bcb5dc070ade9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1312,12 +1312,12 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
                ReshapeRewriter<tensor::CollapseShapeOp>,
                Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
                Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
-               SparseTensorDimOpRewriter, TensorReshapeRewriter>(
+               SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
       patterns.getContext());
   if (enableForeach)
     patterns.add<ForeachRewriter>(patterns.getContext());
   if (enableConvert)
     patterns.add<DirectConvertRewriter>(patterns.getContext());
   if (!enableRT)
-    patterns.add<NewRewriter, OutRewriter>(patterns.getContext());
+    patterns.add<NewRewriter>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 1d7599b3a4edb87..092ba6b8358b598 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -398,34 +398,6 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
   return %0 : tensor<8x8xf64, #CSR>
 }
 
-// CHECK-LABEL: func @sparse_out1(
-//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>,
-//  CHECK-SAME: %[[B:.*]]: !llvm.ptr<i8>)
-//  CHECK-DAG:  %[[ToCOO:.*]] = arith.constant 5 : i32
-//  CHECK-DAG:  %[[Sort:.*]] = arith.constant false
-//       CHECK: %[[COO:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[ToCOO]], %[[A]])
-//       CHECK: call @outSparseTensorF64(%[[COO]], %[[B]], %[[Sort]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i1) -> ()
-//       CHECK: call @delSparseTensorCOOF64(%[[COO]])
-//       CHECK: return
-func.func @sparse_out1(%arg0: tensor<?x?xf64, #CSR>, %arg1: !llvm.ptr<i8>) {
-  sparse_tensor.out %arg0, %arg1 : tensor<?x?xf64, #CSR>, !llvm.ptr<i8>
-  return
-}
-
-// CHECK-LABEL: func @sparse_out2(
-//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>,
-//  CHECK-SAME: %[[B:.*]]: !llvm.ptr<i8>)
-//  CHECK-DAG:  %[[ToCOO:.*]] = arith.constant 5 : i32
-//  CHECK-DAG:  %[[Sort:.*]] = arith.constant true
-//       CHECK: %[[COO:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[ToCOO]], %[[A]])
-//       CHECK: call @outSparseTensorF32(%[[COO]], %[[B]], %[[Sort]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i1) -> ()
-//       CHECK: call @delSparseTensorCOOF32(%[[COO]])
-//       CHECK: return
-func.func @sparse_out2(%arg0: tensor<?x?x?xf32, #SparseTensor>, %arg1: !llvm.ptr<i8>) {
-  sparse_tensor.out %arg0, %arg1 : tensor<?x?x?xf32, #SparseTensor>, !llvm.ptr<i8>
-  return
-}
-
 // CHECK-LABEL: func @sparse_and_dense_init(
 //       CHECK: %[[S:.*]] = call @newSparseTensor
 //       CHECK: %[[D:.*]] = tensor.empty
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
index 1afac10be3adb45..4ef524591d0d44f 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -6,6 +6,8 @@
 import sys
 import tempfile
 
+sys.path.append("/usr/local/google/home/peiming/projects/llvm-project/build-static/tools/mlir/python_packages/mlir_core")
+
 from mlir import ir
 from mlir import runtime as rt
 from mlir.dialects import builtin
@@ -29,14 +31,14 @@ def boilerplate(attr: st.EncodingAttr):
 """
 
 
-def expected():
+def expected(id_map):
     """Returns expected contents of output.
 
     Regardless of the dimension ordering, compression, and bitwidths that are
     used in the sparse tensor, the output is always lexicographically sorted
     by natural index order.
     """
-    return f"""; extended FROSTT format
+    return f"""# extended FROSTT format
 2 5
 10 10
 1 1 1
@@ -44,6 +46,14 @@ def expected():
 2 2 2
 5 5 5
 10 1 4
+""" if id_map else f"""# extended FROSTT format
+2 5
+10 10
+1 1 1
+10 1 4
+2 2 2
+5 5 5
+1 10 3
 """
 
 
@@ -51,7 +61,6 @@ def build_compile_and_run_output(attr: st.EncodingAttr, compiler, expected):
     # Build and Compile.
     module = ir.Module.parse(boilerplate(attr))
     engine = compiler.compile_and_jit(module)
-
     # Invoke the kernel and compare output.
     with tempfile.TemporaryDirectory() as test_dir:
         out = os.path.join(test_dir, "out.tns")
@@ -83,8 +92,8 @@ def main():
             [st.DimLevelType.compressed, st.DimLevelType.compressed],
         ]
         orderings = [
-            ir.AffineMap.get_permutation([0, 1]),
-            ir.AffineMap.get_permutation([1, 0]),
+            (ir.AffineMap.get_permutation([0, 1]), True),
+            (ir.AffineMap.get_permutation([1, 0]), False),
         ]
         bitwidths = [8, 16, 32, 64]
         compiler = sparse_compiler.SparseCompiler(
@@ -94,9 +103,9 @@ def main():
             for ordering in orderings:
                 for bwidth in bitwidths:
                     attr = st.EncodingAttr.get(
-                        level, ordering, ordering, bwidth, bwidth
+                        level, ordering[0], ordering[0], bwidth, bwidth
                     )
-                    build_compile_and_run_output(attr, compiler, expected())
+                    build_compile_and_run_output(attr, compiler, expected(ordering[1]))
                     count = count + 1
 
         # Now do the same for BSR.

>From f3f39076bccf5b5587d5ebbd905e7871e1c383fd Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 27 Oct 2023 22:42:51 +0000
Subject: [PATCH 2/5] revert unintended local sys.path change in test_output.py

---
 .../test/Integration/Dialect/SparseTensor/python/test_output.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
index 4ef524591d0d44f..96470bec219bfad 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -6,8 +6,6 @@
 import sys
 import tempfile
 
-sys.path.append("/usr/local/google/home/peiming/projects/llvm-project/build-static/tools/mlir/python_packages/mlir_core")
-
 from mlir import ir
 from mlir import runtime as rt
 from mlir.dialects import builtin

>From 9b176815982352f346326acadfd2d111e8fbbb44 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 27 Oct 2023 22:46:14 +0000
Subject: [PATCH 3/5] unpack (ordering, id_map) pairs directly in test_output.py loop

---
 .../Integration/Dialect/SparseTensor/python/test_output.py  | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
index 96470bec219bfad..24b83b42b9d1c52 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -98,12 +98,12 @@ def main():
             options="", opt_level=2, shared_libs=[support_lib]
         )
         for level in levels:
-            for ordering in orderings:
+            for (ordering, id_map) in orderings:
                 for bwidth in bitwidths:
                     attr = st.EncodingAttr.get(
-                        level, ordering[0], ordering[0], bwidth, bwidth
+                        level, ordering, ordering, bwidth, bwidth
                     )
-                    build_compile_and_run_output(attr, compiler, expected(ordering[1]))
+                    build_compile_and_run_output(attr, compiler, expected(id_map))
                     count = count + 1
 
         # Now do the same for BSR.

>From 7825a22906713ce04ee1132aeab8e589e8019998 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 27 Oct 2023 23:16:49 +0000
Subject: [PATCH 4/5] address review comments: reword expected() docstring

---
 .../Integration/Dialect/SparseTensor/python/test_output.py   | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
index 24b83b42b9d1c52..ee700850c0ac81a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -32,9 +32,8 @@ def boilerplate(attr: st.EncodingAttr):
 def expected(id_map):
     """Returns expected contents of output.
 
-    Regardless of the dimension ordering, compression, and bitwidths that are
-    used in the sparse tensor, the output is always lexicographically sorted
-    by natural index order.
+    output appears as dimension coordinates but lexicographically
+    sorted by level coordinates.
     """
     return f"""# extended FROSTT format
 2 5

>From 8dcfed9532342f9b4202bf34af70eb372dc7d4e3 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 27 Oct 2023 23:23:23 +0000
Subject: [PATCH 5/5] format python code in test_output.py

---
 .../Dialect/SparseTensor/python/test_output.py         | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
index ee700850c0ac81a..7bfe78fc43112e0 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -35,7 +35,8 @@ def expected(id_map):
     output appears as dimension coordinates but lexicographically
     sorted by level coordinates.
     """
-    return f"""# extended FROSTT format
+    return (
+        f"""# extended FROSTT format
 2 5
 10 10
 1 1 1
@@ -43,7 +44,9 @@ def expected(id_map):
 2 2 2
 5 5 5
 10 1 4
-""" if id_map else f"""# extended FROSTT format
+"""
+        if id_map
+        else f"""# extended FROSTT format
 2 5
 10 10
 1 1 1
@@ -52,6 +55,7 @@ def expected(id_map):
 5 5 5
 1 10 3
 """
+    )
 
 
 def build_compile_and_run_output(attr: st.EncodingAttr, compiler, expected):
@@ -97,7 +101,7 @@ def main():
             options="", opt_level=2, shared_libs=[support_lib]
         )
         for level in levels:
-            for (ordering, id_map) in orderings:
+            for ordering, id_map in orderings:
                 for bwidth in bitwidths:
                     attr = st.EncodingAttr.get(
                         level, ordering, ordering, bwidth, bwidth



More information about the Mlir-commits mailing list