[Mlir-commits] [mlir] [mlir][sparse] ensure [dis]assembler wrapper methods properly inline (PR #81907)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 15 11:22:52 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Aart Bik (aartbik)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/81907.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp (+3-4) 
- (modified) mlir/test/Dialect/SparseTensor/torch_linalg.mlir (+9-1) 


``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index 98f9d15d09fa32..9414d81e6bf5c6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -61,7 +61,7 @@ void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
   }
 }
 
-// Convert input and output values to [dis[assemble ops for sparse tensors.
+// Convert input and output values to [dis]assemble ops for sparse tensors.
 void convVals(OpBuilder &builder, Location loc, TypeRange types,
               ValueRange fromVals, ValueRange extraVals,
               SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn) {
@@ -161,8 +161,6 @@ namespace {
 //
 // TODO: refine output sparse tensors to work well with external framework
 //
-// TODO: use "inlining" instead of a wrapper?
-//
 struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -211,7 +209,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
              ValueRange(), inputs, 0, /*isIn=*/true);
 
-    // Call original, now internal method.
+    // Call the original, now private method. A subsequent inlining pass can
+    // determine whether cloning the method body in place is worthwhile.
     auto org = SymbolRefAttr::get(context, wrapper);
     auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
                                               inputs);
diff --git a/mlir/test/Dialect/SparseTensor/torch_linalg.mlir b/mlir/test/Dialect/SparseTensor/torch_linalg.mlir
index f29e6b143783a6..4bb5938b2e44ec 100644
--- a/mlir/test/Dialect/SparseTensor/torch_linalg.mlir
+++ b/mlir/test/Dialect/SparseTensor/torch_linalg.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt %s --sparse-assembler                 | FileCheck %s --check-prefix=CHECK-HI
 // RUN: mlir-opt %s --sparse-assembler \
+// RUN:             --inline                           | FileCheck %s --check-prefix=CHECK-INL
+// RUN: mlir-opt %s --sparse-assembler \
 // RUN:             --linalg-generalize-named-ops \
 // RUN:             --linalg-fuse-elementwise-ops \
 // RUN:             --sparsification-and-bufferization | FileCheck %s --check-prefix=CHECK-MID
@@ -20,7 +22,13 @@
 // CHECK-HI:       func.func private @_internal_main
 // CHECK-HI:         linalg.matmul
 // CHECK-HI:         return
-//
+
+// CHECK-INL-LABEL: func.func @main
+// CHECK-INL:         sparse_tensor.assemble
+// CHECK-INL:         linalg.matmul
+// CHECK-INL:         return
+// CHECK-INL-NOT:   func.func private @_internal_main
+
 // CHECK-MID-LABEL: func.func @main
 // CHECK-MID:          memref.load
 // CHECK-MID:          call @_internal_main

``````````

</details>


https://github.com/llvm/llvm-project/pull/81907


More information about the Mlir-commits mailing list