[Mlir-commits] [mlir] [mlir][sparse] ensure [dis]assembler wrapper methods properly inline (PR #81907)

Aart Bik llvmlistbot at llvm.org
Thu Feb 15 11:22:22 PST 2024


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/81907

None

>From f402ac26978de419771a6f295412c7f5492457f3 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 15 Feb 2024 10:56:27 -0800
Subject: [PATCH] [mlir][sparse] ensure [dis]assembler wrapper methods properly
 inline

---
 .../SparseTensor/Transforms/SparseAssembler.cpp        |  7 +++----
 mlir/test/Dialect/SparseTensor/torch_linalg.mlir       | 10 +++++++++-
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index 98f9d15d09fa32..9414d81e6bf5c6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -61,7 +61,7 @@ void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
   }
 }
 
-// Convert input and output values to [dis[assemble ops for sparse tensors.
+// Convert input and output values to [dis]assemble ops for sparse tensors.
 void convVals(OpBuilder &builder, Location loc, TypeRange types,
               ValueRange fromVals, ValueRange extraVals,
               SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn) {
@@ -161,8 +161,6 @@ namespace {
 //
 // TODO: refine output sparse tensors to work well with external framework
 //
-// TODO: use "inlining" instead of a wrapper?
-//
 struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -211,7 +209,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
              ValueRange(), inputs, 0, /*isIn=*/true);
 
-    // Call original, now internal method.
+    // Call the original, now private method. A subsequent inlining pass can
+    // determine whether cloning the method body in place is worthwhile.
     auto org = SymbolRefAttr::get(context, wrapper);
     auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
                                               inputs);
diff --git a/mlir/test/Dialect/SparseTensor/torch_linalg.mlir b/mlir/test/Dialect/SparseTensor/torch_linalg.mlir
index f29e6b143783a6..4bb5938b2e44ec 100644
--- a/mlir/test/Dialect/SparseTensor/torch_linalg.mlir
+++ b/mlir/test/Dialect/SparseTensor/torch_linalg.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt %s --sparse-assembler                 | FileCheck %s --check-prefix=CHECK-HI
 // RUN: mlir-opt %s --sparse-assembler \
+// RUN:             --inline                           | FileCheck %s --check-prefix=CHECK-INL
+// RUN: mlir-opt %s --sparse-assembler \
 // RUN:             --linalg-generalize-named-ops \
 // RUN:             --linalg-fuse-elementwise-ops \
 // RUN:             --sparsification-and-bufferization | FileCheck %s --check-prefix=CHECK-MID
@@ -20,7 +22,13 @@
 // CHECK-HI:       func.func private @_internal_main
 // CHECK-HI:         linalg.matmul
 // CHECK-HI:         return
-//
+
+// CHECK-INL-LABEL: func.func @main
+// CHECK-INL:         sparse_tensor.assemble
+// CHECK-INL:         linalg.matmul
+// CHECK-INL:         return
+// CHECK-INL-NOT:   func.func private @_internal_main
+
 // CHECK-MID-LABEL: func.func @main
 // CHECK-MID:          memref.load
 // CHECK-MID:          call @_internal_main



More information about the Mlir-commits mailing list