[Mlir-commits] [mlir] [mlir][sparse] refine sparse assembler strategy (PR #80521)
Aart Bik
llvmlistbot at llvm.org
Fri Feb 2 16:52:44 PST 2024
https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/80521
Rewrite *all* public methods that take or return sparse tensors, making each original an internal, private method and exposing a wrapper under the original name. This works a bit better in practice (for example, when combined with the c-interface mechanism of torch-mlir).
>From 548a4eab0b2e1d9c22567023b29e212d83198714 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Fri, 2 Feb 2024 16:51:38 -0800
Subject: [PATCH] [mlir][sparse] refine sparse assembler strategy
Rewrite *all* public methods that use sparse tensors, making the
originals internal, private methods and exposing wrappers under
the original names. This works a bit better in practice (for
example, when combined with the c-interface mechanism of torch-mlir).
---
.../Dialect/SparseTensor/Transforms/Passes.td | 2 +-
.../Transforms/SparseAssembler.cpp | 55 ++++++++++---------
mlir/test/Dialect/SparseTensor/external.mlir | 49 +++++++++--------
.../Dialect/SparseTensor/torch_linalg.mlir | 55 +++++++++++++++++++
4 files changed, 113 insertions(+), 48 deletions(-)
create mode 100644 mlir/test/Dialect/SparseTensor/torch_linalg.mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 8772d5f127949..58e2d6f32386c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -15,7 +15,7 @@ def SparseAssembler : Pass<"sparse-assembler", "ModuleOp"> {
let summary = "Add [dis]assemble operations on external sparse tensors";
let description = [{
A pass that converts public entry methods that use sparse tensors as
- input parameters and/or output return values into wrapper functions
+ input parameters and/or output return values into wrapper methods
that [dis]assemble the individual tensors that constitute the actual
storage used externally into MLIR sparse tensors. This pass can be used
to prepare the public entry methods of a program that is compiled by the
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index f9b6397e0f086..b4cefec8fb21f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -132,29 +132,29 @@ void convVals(OpBuilder &builder, Location loc, TypeRange types,
namespace {
// A rewriting rules that converts public entry methods that use sparse tensors
-// as input parameters and/or output return values into wrapper functions
-// that [dis]assemble the individual tensors that constitute the actual
-// storage used externally into MLIR sparse tensors.
+// as input parameters and/or output return values into wrapper methods that
+// [dis]assemble the individual tensors that constitute the actual storage used
+// externally into MLIR sparse tensors before calling the original method.
//
// In particular, each sparse tensor input
//
// void foo(..., t, ...) { }
//
-// adds the following strucuture in a wrapper
+// makes the original foo() internal and adds the following wrapper method
//
-// void spiface_foo(..., t1..tn, ...) {
+// void foo(..., t1..tn, ...) {
// t = assemble t1..tn
-// foo(..., t, ...)
+// _internal_foo(..., t, ...)
// }
//
// and likewise, each output tensor
//
// ... T ... bar(...) { return ..., t, ...; }
//
-// adds the following structure in a wrapper
+// makes the original bar() internal and adds the following wrapper method
//
-// ... T1..TN ... spiface_bar(..., t1'..tn') {
-// ..., t, ... = bar(...)
+// ... T1..TN ... bar(..., t1'..tn') {
+// ..., t, ... = _internal_bar(...)
// t1..tn = disassemble t, t1'..tn'
// return ..., t1..tn, ...
// }
@@ -168,9 +168,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
LogicalResult matchAndRewrite(func::FuncOp funcOp,
PatternRewriter &rewriter) const override {
- // Only a rewrite an entry with the c-interface requested.
- if (!funcOp->getAttrOfType<UnitAttr>(
- LLVM::LLVMDialect::getEmitCWrapperAttrName()))
+ // Only rewrite public entry methods.
+ if (funcOp.isPrivate())
return failure();
// Translate sparse tensor types to external types.
@@ -180,29 +179,29 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
convTypes(funcOp.getArgumentTypes(), inputTypes);
convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);
- // Only sparse inputs or outputs need a wrapper function.
+ // Only sparse inputs or outputs need a wrapper method.
if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
outputTypes.size() == funcOp.getResultTypes().size())
return failure();
- // Start the new wrapper function. Together with the c-interface mangling,
- // a sparse external entry point eventually will have a name like:
- // _mlir_ciface_spiface_XXX(...)
+ // Modify the original method into an internal, private method.
+ auto orgName = funcOp.getName();
+ std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
+ funcOp.setName(wrapper);
+ funcOp.setPrivate();
+
+ // Start the new public wrapper method with original name.
Location loc = funcOp.getLoc();
ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
MLIRContext *context = modOp.getContext();
OpBuilder moduleBuilder(modOp.getBodyRegion());
- std::string wrapper = llvm::formatv("spiface_{0}", funcOp.getName()).str();
unsigned extra = inputTypes.size();
inputTypes.append(extraTypes);
auto func = moduleBuilder.create<func::FuncOp>(
- loc, wrapper, FunctionType::get(context, inputTypes, outputTypes));
+ loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
func.setPublic();
- func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
- UnitAttr::get(context));
- // Construct new wrapper function body.
- auto org = SymbolRefAttr::get(context, funcOp.getName());
+ // Construct new wrapper method body.
OpBuilder::InsertionGuard insertionGuard(rewriter);
Block *body = func.addEntryBlock();
rewriter.setInsertionPointToStart(body);
@@ -212,7 +211,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
ValueRange(), inputs, 0, /*isIn=*/true);
- // Call original function.
+ // Call original, now internal method.
+ auto org = SymbolRefAttr::get(context, wrapper);
auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
inputs);
@@ -222,8 +222,13 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
body->getArguments(), outputs, extra, /*isIn=*/false);
rewriter.create<func::ReturnOp>(loc, outputs);
- // Strip the c-interface attribute from the original function.
- funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
+ // Finally, migrate a potential c-interface property.
+ if (funcOp->getAttrOfType<UnitAttr>(
+ LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
+ func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
+ UnitAttr::get(context));
+ funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
+ }
return success();
}
};
diff --git a/mlir/test/Dialect/SparseTensor/external.mlir b/mlir/test/Dialect/SparseTensor/external.mlir
index 57df8aca3a6a5..c17ba13e86c92 100644
--- a/mlir/test/Dialect/SparseTensor/external.mlir
+++ b/mlir/test/Dialect/SparseTensor/external.mlir
@@ -3,95 +3,100 @@
// -----
// CHECK-LABEL: func.func @nop(
-// CHECK-SAME: %[[A:.*]]: tensor<100xf32>) -> tensor<100xf32> attributes {llvm.emit_c_interface} {
+// CHECK-SAME: %[[A:.*]]: tensor<100xf32>) -> tensor<100xf32> {
// CHECK: return %[[A]] : tensor<100xf32>
// CHECK: }
-func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> attributes { llvm.emit_c_interface } {
+func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> {
return %arg0 : tensor<100xf32>
}
// -----
-// CHECK-LABEL: func.func @spiface_sparse_in(
+// CHECK-LABEL: func.func @sparse_in(
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> attributes {llvm.emit_c_interface} {
+// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
-// CHECK: %[[F:.*]] = call @sparse_in(%[[I]])
+// CHECK: %[[F:.*]] = call @_internal_sparse_in(%[[I]])
// CHECK: return %[[F]] : tensor<64x64xf32>
// CHECK: }
+// CHECK: func.func private @_internal_sparse_in
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> attributes { llvm.emit_c_interface } {
+func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
%0 = sparse_tensor.convert %arg0 : tensor<64x64xf32, #sparse> to tensor<64x64xf32>
return %0 : tensor<64x64xf32>
}
// -----
-// CHECK-LABEL: func.func @spiface_sparse_in2(
+// CHECK-LABEL: func.func @sparse_in2(
// CHECK-SAME: %[[X:.*]]: tensor<100xf32>,
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> attributes {llvm.emit_c_interface} {
+// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
-// CHECK: %[[F:.*]] = call @sparse_in2(%[[X]], %[[I]])
+// CHECK: %[[F:.*]] = call @_internal_sparse_in2(%[[X]], %[[I]])
// CHECK: return %[[F]] : tensor<64x64xf32>
// CHECK: }
+// CHECK: func.func private @_internal_sparse_in2
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> attributes { llvm.emit_c_interface } {
+func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
%0 = sparse_tensor.convert %arg1 : tensor<64x64xf32, #sparse> to tensor<64x64xf32>
return %0 : tensor<64x64xf32>
}
// -----
-// CHECK-LABEL: func.func @spiface_sparse_out(
+// CHECK-LABEL: func.func @sparse_out(
// CHECK-SAME: %[[X:.*]]: tensor<64x64xf32>,
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) attributes {llvm.emit_c_interface} {
-// CHECK: %[[F:.*]] = call @sparse_out(%[[X]])
+// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK: %[[F:.*]] = call @_internal_sparse_out(%[[X]])
// CHECK: sparse_tensor.disassemble %[[F]]
// CHECK: return
// CHECK: }
+// CHECK: func.func private @_internal_sparse_out
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> attributes { llvm.emit_c_interface } {
+func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
%0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
return %0 : tensor<64x64xf32, #sparse>
}
// -----
-// CHECK-LABEL: func.func @spiface_sparse_out2(
+// CHECK-LABEL: func.func @sparse_out2(
// CHECK-SAME: %[[X:.*]]: tensor<64x64xf32>,
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) attributes {llvm.emit_c_interface} {
-// CHECK: %[[F:.*]]:2 = call @sparse_out2(%[[X]])
+// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK: %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
// CHECK: sparse_tensor.disassemble %[[F]]#1
// CHECK: return %[[F]]#0
// CHECK: }
+// CHECK: func.func private @_internal_sparse_out2
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64x64xf32, #sparse>) attributes { llvm.emit_c_interface } {
+func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64x64xf32, #sparse>) {
%0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
return %arg0, %0 : tensor<64x64xf32>, tensor<64x64xf32, #sparse>
}
// -----
-// CHECK-LABEL: func.func @spiface_sparse_inout(
+// CHECK-LABEL: func.func @sparse_inout(
// CHECK-SAME: %[[A:.*0]]: tensor<?xf32>,
// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
// CHECK-SAME: %[[D:.*3]]: tensor<?xf32>,
// CHECK-SAME: %[[E:.*4]]: tensor<?xindex>,
-// CHECK-SAME: %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) attributes {llvm.emit_c_interface} {
+// CHECK-SAME: %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
-// CHECK: %[[F:.*]] = call @sparse_inout(%[[I]])
+// CHECK: %[[F:.*]] = call @_internal_sparse_inout(%[[I]])
// CHECK: sparse_tensor.disassemble %[[F]]
// CHECK: return
// CHECK: }
+// CHECK: func.func private @_internal_sparse_inout
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> attributes { llvm.emit_c_interface } {
+func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> {
return %arg0 : tensor<64x64xf32, #sparse>
}
diff --git a/mlir/test/Dialect/SparseTensor/torch_linalg.mlir b/mlir/test/Dialect/SparseTensor/torch_linalg.mlir
new file mode 100644
index 0000000000000..f29e6b143783a
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/torch_linalg.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt %s --sparse-assembler | FileCheck %s --check-prefix=CHECK-HI
+// RUN: mlir-opt %s --sparse-assembler \
+// RUN: --linalg-generalize-named-ops \
+// RUN: --linalg-fuse-elementwise-ops \
+// RUN: --sparsification-and-bufferization | FileCheck %s --check-prefix=CHECK-MID
+// RUN: mlir-opt %s --sparse-assembler \
+// RUN: --sparsifier | FileCheck %s --check-prefix=CHECK-LOW
+
+//
+// An example of a module generated by torch-mlir with a sparse tensor from
+// torch.sparse. The MLIR sparsifier should be able to provide the external
+// API through wrapper methods (sparse assembler and c-interface). Various
+// passes should compose without trouble.
+//
+
+// CHECK-HI-LABEL: func.func @main
+// CHECK-HI: sparse_tensor.assemble
+// CHECK-HI: call @_internal_main
+// CHECK-HI: return
+// CHECK-HI: func.func private @_internal_main
+// CHECK-HI: linalg.matmul
+// CHECK-HI: return
+//
+// CHECK-MID-LABEL: func.func @main
+// CHECK-MID: memref.load
+// CHECK-MID: call @_internal_main
+// CHECK-MID: return
+// CHECK-MID: func.func private @_internal_main
+// CHECK-MID: scf.for
+// CHECK-MID: scf.for
+// CHECK-MID: return
+
+// CHECK-LOW-LABEL: llvm.func @main
+// CHECK-LOW: llvm.call @_internal_main
+// CHECK-LOW: llvm.return
+// CHECK-LOW: llvm.func @_mlir_ciface_main
+// CHECK-LOW: llvm.call @main
+// CHECK-LOW: llvm.return
+// CHECK-LOW: llvm.func @_internal_main
+// CHECK-LOW-SAME: {sym_visibility = "private"}
+// CHECK-LOW: llvm.return
+
+#csr = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+module {
+ func.func @main(%arg0: tensor<64x64xf32, #csr>,
+ %arg1: tensor<64x64xf32>) -> tensor<64x64xf32> attributes {llvm.emit_c_interface} {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<64x64xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64x64xf32>) -> tensor<64x64xf32>
+ %2 = linalg.matmul
+ ins(%arg0, %arg1 : tensor<64x64xf32, #csr>, tensor<64x64xf32>)
+ outs(%1 : tensor<64x64xf32>) -> tensor<64x64xf32>
+ return %2 : tensor<64x64xf32>
+ }
+}
More information about the Mlir-commits
mailing list