[Mlir-commits] [mlir] [mlir][sparse] refine sparse assembler strategy (PR #80521)

Aart Bik llvmlistbot at llvm.org
Fri Feb 2 16:52:44 PST 2024


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/80521

Rewrite *all* public methods that take or return sparse tensors, making the original methods internal and private, and exposing wrappers under the original names. This works a bit better in practice (for example, when combined with the c-interface mechanism of torch-mlir).

>From 548a4eab0b2e1d9c22567023b29e212d83198714 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Fri, 2 Feb 2024 16:51:38 -0800
Subject: [PATCH] [mlir][sparse] refine sparse assembler strategy

Rewrite *all* public methods that take or return sparse tensors,
making the original methods internal and private, and exposing
wrappers under the original names. This works a bit better in
practice (for example, when combined with the c-interface
mechanism of torch-mlir).
---
 .../Dialect/SparseTensor/Transforms/Passes.td |  2 +-
 .../Transforms/SparseAssembler.cpp            | 55 ++++++++++---------
 mlir/test/Dialect/SparseTensor/external.mlir  | 49 +++++++++--------
 .../Dialect/SparseTensor/torch_linalg.mlir    | 55 +++++++++++++++++++
 4 files changed, 113 insertions(+), 48 deletions(-)
 create mode 100644 mlir/test/Dialect/SparseTensor/torch_linalg.mlir

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 8772d5f127949..58e2d6f32386c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -15,7 +15,7 @@ def SparseAssembler : Pass<"sparse-assembler", "ModuleOp"> {
   let summary = "Add [dis]assemble operations on external sparse tensors";
   let description = [{
     A pass that converts public entry methods that use sparse tensors as
-    input parameters and/or output return values into wrapper functions
+    input parameters and/or output return values into wrapper methods
     that [dis]assemble the individual tensors that constitute the actual
     storage used externally into MLIR sparse tensors. This pass can be used
     to prepare the public entry methods of a program that is compiled by the
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index f9b6397e0f086..b4cefec8fb21f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -132,29 +132,29 @@ void convVals(OpBuilder &builder, Location loc, TypeRange types,
 namespace {
 
 // A rewriting rules that converts public entry methods that use sparse tensors
-// as input parameters and/or output return values into wrapper functions
-// that [dis]assemble the individual tensors that constitute the actual
-// storage used externally into MLIR sparse tensors.
+// as input parameters and/or output return values into wrapper methods that
+// [dis]assemble the individual tensors that constitute the actual storage used
+// externally into MLIR sparse tensors before calling the original method.
 //
 // In particular, each sparse tensor input
 //
 // void foo(..., t, ...) { }
 //
-// adds the following strucuture in a wrapper
+// makes the original foo() internal and adds the following wrapper method
 //
-// void spiface_foo(..., t1..tn, ...) {
+// void foo(..., t1..tn, ...) {
 //   t = assemble t1..tn
-//   foo(..., t, ...)
+//   _internal_foo(..., t, ...)
 // }
 //
 // and likewise, each output tensor
 //
 // ... T ... bar(...) { return ..., t, ...; }
 //
-// adds the following structure in a wrapper
+// makes the original bar() internal and adds the following wrapper method
 //
-// ... T1..TN ... spiface_bar(..., t1'..tn') {
-//   ..., t, ... = bar(...)
+// ... T1..TN ... bar(..., t1'..tn') {
+//   ..., t, ... = _internal_bar(...)
 //   t1..tn = disassemble t, t1'..tn'
 //   return ..., t1..tn, ...
 // }
@@ -168,9 +168,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
 
   LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                 PatternRewriter &rewriter) const override {
-    // Only a rewrite an entry with the c-interface requested.
-    if (!funcOp->getAttrOfType<UnitAttr>(
-            LLVM::LLVMDialect::getEmitCWrapperAttrName()))
+    // Only rewrite public entry methods.
+    if (funcOp.isPrivate())
       return failure();
 
     // Translate sparse tensor types to external types.
@@ -180,29 +179,29 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     convTypes(funcOp.getArgumentTypes(), inputTypes);
     convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);
 
-    // Only sparse inputs or outputs need a wrapper function.
+    // Only sparse inputs or outputs need a wrapper method.
     if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
         outputTypes.size() == funcOp.getResultTypes().size())
       return failure();
 
-    // Start the new wrapper function. Together with the c-interface mangling,
-    // a sparse external entry point eventually will have a name like:
-    //    _mlir_ciface_spiface_XXX(...)
+    // Modify the original method into an internal, private method.
+    auto orgName = funcOp.getName();
+    std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
+    funcOp.setName(wrapper);
+    funcOp.setPrivate();
+
+    // Start the new public wrapper method with original name.
     Location loc = funcOp.getLoc();
     ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
     MLIRContext *context = modOp.getContext();
     OpBuilder moduleBuilder(modOp.getBodyRegion());
-    std::string wrapper = llvm::formatv("spiface_{0}", funcOp.getName()).str();
     unsigned extra = inputTypes.size();
     inputTypes.append(extraTypes);
     auto func = moduleBuilder.create<func::FuncOp>(
-        loc, wrapper, FunctionType::get(context, inputTypes, outputTypes));
+        loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
     func.setPublic();
-    func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
-                  UnitAttr::get(context));
 
-    // Construct new wrapper function body.
-    auto org = SymbolRefAttr::get(context, funcOp.getName());
+    // Construct new wrapper method body.
     OpBuilder::InsertionGuard insertionGuard(rewriter);
     Block *body = func.addEntryBlock();
     rewriter.setInsertionPointToStart(body);
@@ -212,7 +211,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
              ValueRange(), inputs, 0, /*isIn=*/true);
 
-    // Call original function.
+    // Call original, now internal method.
+    auto org = SymbolRefAttr::get(context, wrapper);
     auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
                                               inputs);
 
@@ -222,8 +222,13 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
              body->getArguments(), outputs, extra, /*isIn=*/false);
     rewriter.create<func::ReturnOp>(loc, outputs);
 
-    // Strip the c-interface attribute from the original function.
-    funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
+    // Finally, migrate a potential c-interface property.
+    if (funcOp->getAttrOfType<UnitAttr>(
+            LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
+      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
+                    UnitAttr::get(context));
+      funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
+    }
     return success();
   }
 };
diff --git a/mlir/test/Dialect/SparseTensor/external.mlir b/mlir/test/Dialect/SparseTensor/external.mlir
index 57df8aca3a6a5..c17ba13e86c92 100644
--- a/mlir/test/Dialect/SparseTensor/external.mlir
+++ b/mlir/test/Dialect/SparseTensor/external.mlir
@@ -3,95 +3,100 @@
 // -----
 
 // CHECK-LABEL: func.func @nop(
-// CHECK-SAME:    %[[A:.*]]: tensor<100xf32>) -> tensor<100xf32> attributes {llvm.emit_c_interface} {
+// CHECK-SAME:    %[[A:.*]]: tensor<100xf32>) -> tensor<100xf32> {
 // CHECK:         return %[[A]] : tensor<100xf32>
 // CHECK:       }
-func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> attributes { llvm.emit_c_interface } {
+func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> {
   return %arg0 : tensor<100xf32>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @spiface_sparse_in(
+// CHECK-LABEL: func.func @sparse_in(
 // CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
 // CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> attributes {llvm.emit_c_interface} {
+// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
 // CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
-// CHECK:         %[[F:.*]] = call @sparse_in(%[[I]])
+// CHECK:         %[[F:.*]] = call @_internal_sparse_in(%[[I]])
 // CHECK:         return %[[F]] : tensor<64x64xf32>
 // CHECK:       }
+// CHECK:       func.func private @_internal_sparse_in
 #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> attributes { llvm.emit_c_interface } {
+func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
   %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32, #sparse> to tensor<64x64xf32>
   return %0 : tensor<64x64xf32>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @spiface_sparse_in2(
+// CHECK-LABEL: func.func @sparse_in2(
 // CHECK-SAME:    %[[X:.*]]: tensor<100xf32>,
 // CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
 // CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> attributes {llvm.emit_c_interface} {
+// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
 // CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
-// CHECK:         %[[F:.*]] = call @sparse_in2(%[[X]], %[[I]])
+// CHECK:         %[[F:.*]] = call @_internal_sparse_in2(%[[X]], %[[I]])
 // CHECK:         return %[[F]] : tensor<64x64xf32>
 // CHECK:       }
+// CHECK:       func.func private @_internal_sparse_in2
 #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> attributes { llvm.emit_c_interface } {
+func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
   %0 = sparse_tensor.convert %arg1 : tensor<64x64xf32, #sparse> to tensor<64x64xf32>
   return %0 : tensor<64x64xf32>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @spiface_sparse_out(
+// CHECK-LABEL: func.func @sparse_out(
 // CHECK-SAME:    %[[X:.*]]: tensor<64x64xf32>,
 // CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
 // CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) attributes {llvm.emit_c_interface} {
-// CHECK:         %[[F:.*]] = call @sparse_out(%[[X]])
+// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK:         %[[F:.*]] = call @_internal_sparse_out(%[[X]])
 // CHECK:         sparse_tensor.disassemble %[[F]]
 // CHECK:         return
 // CHECK:       }
+// CHECK:       func.func private @_internal_sparse_out
 #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> attributes { llvm.emit_c_interface } {
+func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
   %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
   return %0 : tensor<64x64xf32, #sparse>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @spiface_sparse_out2(
+// CHECK-LABEL: func.func @sparse_out2(
 // CHECK-SAME:    %[[X:.*]]: tensor<64x64xf32>,
 // CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
 // CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) attributes {llvm.emit_c_interface} {
-// CHECK:         %[[F:.*]]:2 = call @sparse_out2(%[[X]])
+// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK:         %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
 // CHECK:         sparse_tensor.disassemble %[[F]]#1
 // CHECK:         return %[[F]]#0
 // CHECK:       }
+// CHECK:       func.func private @_internal_sparse_out2
 #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64x64xf32, #sparse>) attributes { llvm.emit_c_interface } {
+func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64x64xf32, #sparse>) {
   %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
   return %arg0, %0 : tensor<64x64xf32>, tensor<64x64xf32, #sparse>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @spiface_sparse_inout(
+// CHECK-LABEL: func.func @sparse_inout(
 // CHECK-SAME:    %[[A:.*0]]: tensor<?xf32>,
 // CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
 // CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
 // CHECK-SAME:    %[[D:.*3]]: tensor<?xf32>,
 // CHECK-SAME:    %[[E:.*4]]: tensor<?xindex>,
-// CHECK-SAME:    %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) attributes {llvm.emit_c_interface} {
+// CHECK-SAME:    %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
 // CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
-// CHECK:         %[[F:.*]] = call @sparse_inout(%[[I]])
+// CHECK:         %[[F:.*]] = call @_internal_sparse_inout(%[[I]])
 // CHECK:         sparse_tensor.disassemble %[[F]]
 // CHECK:         return
 // CHECK:       }
+// CHECK:       func.func private @_internal_sparse_inout
 #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> attributes { llvm.emit_c_interface } {
+func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> {
   return %arg0 : tensor<64x64xf32, #sparse>
 }
diff --git a/mlir/test/Dialect/SparseTensor/torch_linalg.mlir b/mlir/test/Dialect/SparseTensor/torch_linalg.mlir
new file mode 100644
index 0000000000000..f29e6b143783a
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/torch_linalg.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt %s --sparse-assembler                 | FileCheck %s --check-prefix=CHECK-HI
+// RUN: mlir-opt %s --sparse-assembler \
+// RUN:             --linalg-generalize-named-ops \
+// RUN:             --linalg-fuse-elementwise-ops \
+// RUN:             --sparsification-and-bufferization | FileCheck %s --check-prefix=CHECK-MID
+// RUN: mlir-opt %s --sparse-assembler \
+// RUN:             --sparsifier                       | FileCheck %s --check-prefix=CHECK-LOW
+
+//
+// An example of a module generated by torch-mlir with a sparse tensor from
+// torch.sparse. The MLIR sparsifier should be able to provide the external
+// API through a wrapper method under the original name (plus its ciface).
+// Various passes should compose without trouble.
+//
+
+// CHECK-HI-LABEL: func.func @main
+// CHECK-HI:         sparse_tensor.assemble
+// CHECK-HI:         call @_internal_main
+// CHECK-HI:         return
+// CHECK-HI:       func.func private @_internal_main
+// CHECK-HI:         linalg.matmul
+// CHECK-HI:         return
+//
+// CHECK-MID-LABEL: func.func @main
+// CHECK-MID:         memref.load
+// CHECK-MID:         call @_internal_main
+// CHECK-MID:         return
+// CHECK-MID:       func.func private @_internal_main
+// CHECK-MID:         scf.for
+// CHECK-MID:           scf.for
+// CHECK-MID:         return
+
+// CHECK-LOW-LABEL: llvm.func @main
+// CHECK-LOW:         llvm.call @_internal_main
+// CHECK-LOW:         llvm.return
+// CHECK-LOW:       llvm.func @_mlir_ciface_main
+// CHECK-LOW:         llvm.call @main
+// CHECK-LOW:         llvm.return
+// CHECK-LOW:       llvm.func @_internal_main
+// CHECK-LOW-SAME:    sym_visibility = "private"
+// CHECK-LOW:         llvm.return
+
+#csr = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+module {
+  func.func @main(%arg0: tensor<64x64xf32, #csr>,
+                  %arg1: tensor<64x64xf32>) -> tensor<64x64xf32> attributes {llvm.emit_c_interface} {
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = tensor.empty() : tensor<64x64xf32>
+    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64x64xf32>) -> tensor<64x64xf32>
+    %2 = linalg.matmul
+      ins(%arg0, %arg1 : tensor<64x64xf32, #csr>, tensor<64x64xf32>)
+      outs(%1 : tensor<64x64xf32>) -> tensor<64x64xf32>
+    return %2 : tensor<64x64xf32>
+  }
+}



More information about the Mlir-commits mailing list