[Mlir-commits] [mlir] [mlir][sparse] ensure [dis]assembler wrapper methods properly inline (PR #81907)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 15 11:22:52 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Aart Bik (aartbik)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/81907.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp (+3-4)
- (modified) mlir/test/Dialect/SparseTensor/torch_linalg.mlir (+9-1)
``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index 98f9d15d09fa32..9414d81e6bf5c6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -61,7 +61,7 @@ void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
}
}
-// Convert input and output values to [dis[assemble ops for sparse tensors.
+// Convert input and output values to [dis]assemble ops for sparse tensors.
void convVals(OpBuilder &builder, Location loc, TypeRange types,
ValueRange fromVals, ValueRange extraVals,
SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn) {
@@ -161,8 +161,6 @@ namespace {
//
// TODO: refine output sparse tensors to work well with external framework
//
-// TODO: use "inlining" instead of a wrapper?
-//
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
using OpRewritePattern::OpRewritePattern;
@@ -211,7 +209,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
ValueRange(), inputs, 0, /*isIn=*/true);
- // Call original, now internal method.
+ // Call the original, now private method. A subsequent inlining pass can
+ // determine whether cloning the method body in place is worthwhile.
auto org = SymbolRefAttr::get(context, wrapper);
auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
inputs);
diff --git a/mlir/test/Dialect/SparseTensor/torch_linalg.mlir b/mlir/test/Dialect/SparseTensor/torch_linalg.mlir
index f29e6b143783a6..4bb5938b2e44ec 100644
--- a/mlir/test/Dialect/SparseTensor/torch_linalg.mlir
+++ b/mlir/test/Dialect/SparseTensor/torch_linalg.mlir
@@ -1,5 +1,7 @@
// RUN: mlir-opt %s --sparse-assembler | FileCheck %s --check-prefix=CHECK-HI
// RUN: mlir-opt %s --sparse-assembler \
+// RUN: --inline | FileCheck %s --check-prefix=CHECK-INL
+// RUN: mlir-opt %s --sparse-assembler \
// RUN: --linalg-generalize-named-ops \
// RUN: --linalg-fuse-elementwise-ops \
// RUN: --sparsification-and-bufferization | FileCheck %s --check-prefix=CHECK-MID
@@ -20,7 +22,13 @@
// CHECK-HI: func.func private @_internal_main
// CHECK-HI: linalg.matmul
// CHECK-HI: return
-//
+
+// CHECK-INL-LABEL: func.func @main
+// CHECK-INL: sparse_tensor.assemble
+// CHECK-INL: linalg.matmul
+// CHECK-INL: return
+// CHECK-INL-NOT: func.func private @_internal_main
+
// CHECK-MID-LABEL: func.func @main
// CHECK-MID: memref.load
// CHECK-MID: call @_internal_main
``````````
</details>
https://github.com/llvm/llvm-project/pull/81907
More information about the Mlir-commits mailing list