[Mlir-commits] [mlir] [mlir][sparse] allow for direct-out passing of sparse tensor buffers (PR #88327)

Aart Bik llvmlistbot at llvm.org
Wed Apr 10 18:50:39 PDT 2024


https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/88327

>From a21187969ba0a63b6e98ed615c666bf96b75094b Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Wed, 10 Apr 2024 16:08:47 -0700
Subject: [PATCH 1/4] [mlir][sparse] allow for direct-out passing of sparse
 tensor buffers

In order to support various external frameworks (JAX vs PyTorch)
we need a bit more flexibility in [dis]assembling external buffers
to and from sparse tensors in MLIR land. This PR adds a direct-out
option that avoids the rigid requirement of pre-allocated buffers
imposed by the copy-out semantics.

Note that over time, we expect the [dis]assemble operations to
converge into something that supports all sorts of external frameworks.
Until then, this option helps in experimenting with different options.
---
 .../Dialect/SparseTensor/Transforms/Passes.h  |  3 +-
 .../Dialect/SparseTensor/Transforms/Passes.td |  9 ++
 .../Transforms/SparseAssembler.cpp            | 96 ++++++++++++-------
 .../Transforms/SparseTensorConversion.cpp     |  9 +-
 .../Transforms/SparseTensorPasses.cpp         |  3 +-
 .../Dialect/SparseTensor/external_direct.mlir | 35 +++++++
 6 files changed, 116 insertions(+), 39 deletions(-)
 create mode 100644 mlir/test/Dialect/SparseTensor/external_direct.mlir

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 61b07d222d156b..d6d038ef65bdf4 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -60,9 +60,10 @@ enum class SparseEmitStrategy {
 // The SparseAssembler pass.
 //===----------------------------------------------------------------------===//
 
-void populateSparseAssembler(RewritePatternSet &patterns);
+void populateSparseAssembler(RewritePatternSet &patterns, bool directOut);
 
 std::unique_ptr<Pass> createSparseAssembler();
+std::unique_ptr<Pass> createSparseAssembler(bool directOut);
 
 //===----------------------------------------------------------------------===//
 // The SparseReinterpretMap pass.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 58e2d6f32386c3..4706d5ba2f218c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -23,12 +23,21 @@ def SparseAssembler : Pass<"sparse-assembler", "ModuleOp"> {
     sparse tensors as numpy arrays from and to Python. Note that eventual
     bufferization decisions (e.g. who [de]allocates the underlying memory)
     should be resolved in agreement with the external runtime.
+
+    By default, the pass uses the [dis]assemble operations to input and output
+    sparse tensors. When the direct-out option is set, however, the output
+    directly returns the MLIR allocated buffers to the external runtime.
   }];
   let constructor = "mlir::createSparseAssembler()";
   let dependentDialects = [
+    "bufferization::BufferizationDialect",
     "sparse_tensor::SparseTensorDialect",
     "tensor::TensorDialect",
   ];
+  let options = [
+    Option<"directOut", "direct-out", "bool",
+      "false", "Directly returns buffers externally">,
+  ];
 }
 
 def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index a91d32a23cac9f..a2edc75fc38c02 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -8,6 +8,7 @@
 
 #include "Utils/CodegenUtils.h"
 
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
@@ -24,7 +25,7 @@ using namespace sparse_tensor;
 
 // Convert type range to new types range, with sparse tensors externalized.
 static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
-                      SmallVectorImpl<Type> *extraTypes = nullptr) {
+                      SmallVectorImpl<Type> *extraTypes, bool directOut) {
   for (auto type : types) {
     // All "dense" data passes through unmodified.
     if (!getSparseTensorEncoding(type)) {
@@ -32,31 +33,38 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
       continue;
     }
 
-    // Convert the external representation of the position/coordinate array
+    // Convert the external representations of the pos/crd/val arrays.
     const SparseTensorType stt(cast<RankedTensorType>(type));
-    foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
-                                               Type t, FieldIndex,
-                                               SparseTensorFieldKind kind,
-                                               Level, LevelType) {
-      if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef ||
-          kind == SparseTensorFieldKind::ValMemRef) {
-        ShapedType st = t.cast<ShapedType>();
-        auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
-        convTypes.push_back(rtp);
-        if (extraTypes)
-          extraTypes->push_back(rtp);
-      }
-      return true;
-    });
+    foreachFieldAndTypeInSparseTensor(
+        stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
+                                                 SparseTensorFieldKind kind,
+                                                 Level, LevelType) {
+          if (kind == SparseTensorFieldKind::PosMemRef ||
+              kind == SparseTensorFieldKind::CrdMemRef ||
+              kind == SparseTensorFieldKind::ValMemRef) {
+            auto st = t.cast<ShapedType>();
+            auto shape = st.getShape();
+            auto eltTp = st.getElementType();
+            Type rtp;
+            if (directOut) {
+              rtp = MemRefType::get(shape, eltTp);
+            } else {
+              rtp = RankedTensorType::get(shape, eltTp);
+              if (extraTypes)
+                extraTypes->push_back(rtp);
+            }
+            convTypes.push_back(rtp);
+          }
+          return true;
+        });
   }
 }
 
 // Convert input and output values to [dis]assemble ops for sparse tensors.
 static void convVals(OpBuilder &builder, Location loc, TypeRange types,
                      ValueRange fromVals, ValueRange extraVals,
-                     SmallVectorImpl<Value> &toVals, unsigned extra,
-                     bool isIn) {
+                     SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
+                     bool directOut) {
   unsigned idx = 0;
   for (auto type : types) {
     // All "dense" data passes through unmodified.
@@ -73,21 +81,34 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
     if (!isIn)
       inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
 
-    // Collect the external representations of the pos/crd arrays.
+    // Collect the external representations of the pos/crd/val arrays.
     foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
                                                      SparseTensorFieldKind kind,
-                                                     Level, LevelType) {
-      if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef ||
+                                                     Level lv, LevelType) {
+      if (kind == SparseTensorFieldKind::PosMemRef ||
+          kind == SparseTensorFieldKind::CrdMemRef ||
           kind == SparseTensorFieldKind::ValMemRef) {
         if (isIn) {
           inputs.push_back(fromVals[idx++]);
         } else {
           ShapedType st = t.cast<ShapedType>();
           auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
-          inputs.push_back(extraVals[extra++]);
-          retTypes.push_back(rtp);
-          cntTypes.push_back(builder.getIndexType());
+          if (directOut) {
+            Value mem;
+            if (kind == SparseTensorFieldKind::PosMemRef)
+              mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
+                                                                 lv);
+            else if (kind == SparseTensorFieldKind::CrdMemRef)
+              mem = builder.create<sparse_tensor::ToCoordinatesOp>(
+                  loc, inputs[0], lv);
+            else
+              mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
+            toVals.push_back(mem);
+          } else {
+            inputs.push_back(extraVals[extra++]);
+            retTypes.push_back(rtp);
+            cntTypes.push_back(builder.getIndexType());
+          }
         }
       }
       return true;
@@ -97,7 +118,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
       // Assemble multiple inputs into a single sparse tensor.
       auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
       toVals.push_back(a.getResult());
-    } else {
+    } else if (!directOut) {
       // Disassemble a single sparse input into multiple outputs.
       // Note that this includes the counters, which are dropped.
       unsigned len = retTypes.size();
@@ -144,11 +165,14 @@ namespace {
 //   return ..., t1..tn, ...
 // }
 //
-// TODO: refine output sparse tensors to work well with external framework
+// (with a direct-out variant without the disassemble).
 //
 struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
   using OpRewritePattern::OpRewritePattern;
 
+  SparseFuncAssembler(MLIRContext *context, bool dO)
+      : OpRewritePattern(context), directOut(dO) {}
+
   LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                 PatternRewriter &rewriter) const override {
     // Only rewrite public entry methods.
@@ -159,8 +183,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     SmallVector<Type> inputTypes;
     SmallVector<Type> outputTypes;
     SmallVector<Type> extraTypes;
-    convTypes(funcOp.getArgumentTypes(), inputTypes);
-    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);
+    convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, directOut);
+    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut);
 
     // Only sparse inputs or outputs need a wrapper method.
     if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
@@ -192,7 +216,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     // Convert inputs.
     SmallVector<Value> inputs;
     convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
-             ValueRange(), inputs, 0, /*isIn=*/true);
+             ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);
 
     // Call the original, now private method. A subsequent inlining pass can
     // determine whether cloning the method body in place is worthwhile.
@@ -203,7 +227,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     // Convert outputs and return.
     SmallVector<Value> outputs;
     convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
-             body->getArguments(), outputs, extra, /*isIn=*/false);
+             body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
     rewriter.create<func::ReturnOp>(loc, outputs);
 
     // Finally, migrate a potential c-interface property.
@@ -215,6 +239,9 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     }
     return success();
   }
+
+private:
+  const bool directOut;
 };
 
 } // namespace
@@ -223,6 +250,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
 // Public method for populating conversion rules.
 //===----------------------------------------------------------------------===//
 
-void mlir::populateSparseAssembler(RewritePatternSet &patterns) {
-  patterns.add<SparseFuncAssembler>(patterns.getContext());
+void mlir::populateSparseAssembler(RewritePatternSet &patterns,
+                                   bool directOut) {
+  patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index c52fa3751e6b4a..f0d162bdb84d96 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -767,6 +767,12 @@ class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
 };
 
 /// Sparse conversion rule for the sparse_tensor.disassemble operator.
+/// Note that the current implementation simply exposes the buffers to
+/// the external client. This assumes the client only reads the buffers
+/// (usually copying them to external data structures, such as numpy
+/// arrays). The semantics of the disassemble operation technically
+/// require that the copying is done here already, using the out-levels
+/// and out-values clauses.
 class SparseTensorDisassembleConverter
     : public OpConversionPattern<DisassembleOp> {
 public:
@@ -774,9 +780,6 @@ class SparseTensorDisassembleConverter
   LogicalResult
   matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // We simply expose the buffers to the external client. This
-    // assumes the client only reads the buffers (usually copying it
-    // to the external data structures, such as numpy arrays).
     Location loc = op->getLoc();
     auto stt = getSparseTensorType(op.getTensor());
     SmallVector<Value> retVal;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index acea25f023980a..b42d58634a36c4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -50,11 +50,12 @@ namespace {
 struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
   SparseAssembler() = default;
   SparseAssembler(const SparseAssembler &pass) = default;
+  SparseAssembler(bool dO) { directOut = dO; }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    populateSparseAssembler(patterns);
+    populateSparseAssembler(patterns, directOut);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
diff --git a/mlir/test/Dialect/SparseTensor/external_direct.mlir b/mlir/test/Dialect/SparseTensor/external_direct.mlir
new file mode 100644
index 00000000000000..97a6d3031d90cd
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/external_direct.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s --sparse-assembler="direct-out=True" -split-input-file | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: func.func @sparse_out(
+// CHECK-SAME:    %[[X:.*0]]: tensor<64x64xf32>)
+// CHECK:         %[[F:.*]] = call @_internal_sparse_out(%[[X]])
+// CHECK:         %[[P:.*]] = sparse_tensor.positions %[[F]]
+// CHECK:         %[[C:.*]] = sparse_tensor.coordinates %[[F]]
+// CHECK:         %[[V:.*]] = sparse_tensor.values %[[F]]
+// CHECK:         return %[[P]], %[[C]], %[[V]]
+// CHECK:       }
+// CHECK:       func.func private @_internal_sparse_out
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
+  %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
+  return %0 : tensor<64x64xf32, #sparse>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @sparse_out2(
+// CHECK-SAME:    %[[X:.*0]]: tensor<64x64xf32>)
+// CHECK:         %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
+// CHECK:         %[[P:.*]] = sparse_tensor.positions %[[F]]#1
+// CHECK:         %[[C:.*]] = sparse_tensor.coordinates %[[F]]#1
+// CHECK:         %[[V:.*]] = sparse_tensor.values %[[F]]#1
+// CHECK:         return %[[F]]#0, %[[P]], %[[C]], %[[V]]
+// CHECK:       }
+// CHECK:       func.func private @_internal_sparse_out2
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64x64xf32, #sparse>) {
+  %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
+  return %arg0, %0 : tensor<64x64xf32>, tensor<64x64xf32, #sparse>
+}

>From 5c9614fdfe7847649842635720e319a5f1f7500a Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Wed, 10 Apr 2024 16:20:55 -0700
Subject: [PATCH 2/4] edit

---
 .../Transforms/SparseAssembler.cpp            | 31 +++++++++----------
 1 file changed, 14 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index a2edc75fc38c02..bcb97fad81f953 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -90,27 +90,24 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
           kind == SparseTensorFieldKind::ValMemRef) {
         if (isIn) {
           inputs.push_back(fromVals[idx++]);
+        } else if (directOut) {
+          Value mem;
+          if (kind == SparseTensorFieldKind::PosMemRef)
+            mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
+                                                               lv);
+          else if (kind == SparseTensorFieldKind::CrdMemRef)
+            mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0],
+                                                                 lv);
+          else
+            mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
+          toVals.push_back(mem);
         } else {
           ShapedType st = t.cast<ShapedType>();
           auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
-          if (directOut) {
-            Value mem;
-            if (kind == SparseTensorFieldKind::PosMemRef)
-              mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
-                                                                 lv);
-            else if (kind == SparseTensorFieldKind::CrdMemRef)
-              mem = builder.create<sparse_tensor::ToCoordinatesOp>(
-                  loc, inputs[0], lv);
-            else
-              mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
-            toVals.push_back(mem);
-          } else {
-            inputs.push_back(extraVals[extra++]);
-            retTypes.push_back(rtp);
-            cntTypes.push_back(builder.getIndexType());
-          }
+          inputs.push_back(extraVals[extra++]);
+          retTypes.push_back(rtp);
+          cntTypes.push_back(builder.getIndexType());
         }
-      }
       return true;
     });
 

>From 42c75bc49c09f36096b1ebe4ead68256cb201985 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Wed, 10 Apr 2024 16:32:16 -0700
Subject: [PATCH 3/4] edit

---
 mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index bcb97fad81f953..8ae9ad6bf2151b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -108,6 +108,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
           retTypes.push_back(rtp);
           cntTypes.push_back(builder.getIndexType());
         }
+      }
       return true;
     });
 

>From 5fe3b6dc93e6f1eb7a437ab002bcacaaa243796c Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Wed, 10 Apr 2024 18:48:55 -0700
Subject: [PATCH 4/4] addressed reviewer feedback

---
 .../SparseTensor/Transforms/SparseAssembler.cpp | 17 ++++++-----------
 .../Dialect/SparseTensor/external_direct.mlir   | 17 +++++++++++++++++
 2 files changed, 23 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index 8ae9ad6bf2151b..eafbe95b7aebe0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -42,14 +42,9 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
           if (kind == SparseTensorFieldKind::PosMemRef ||
               kind == SparseTensorFieldKind::CrdMemRef ||
               kind == SparseTensorFieldKind::ValMemRef) {
-            auto st = t.cast<ShapedType>();
-            auto shape = st.getShape();
-            auto eltTp = st.getElementType();
-            Type rtp;
-            if (directOut) {
-              rtp = MemRefType::get(shape, eltTp);
-            } else {
-              rtp = RankedTensorType::get(shape, eltTp);
+            auto rtp = t.cast<ShapedType>();
+            if (!directOut) {
+              rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
               if (extraTypes)
                 extraTypes->push_back(rtp);
             }
@@ -102,8 +97,8 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
             mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
           toVals.push_back(mem);
         } else {
-          ShapedType st = t.cast<ShapedType>();
-          auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
+          ShapedType rtp = t.cast<ShapedType>();
+          rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
           inputs.push_back(extraVals[extra++]);
           retTypes.push_back(rtp);
           cntTypes.push_back(builder.getIndexType());
@@ -181,7 +176,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     SmallVector<Type> inputTypes;
     SmallVector<Type> outputTypes;
     SmallVector<Type> extraTypes;
-    convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, directOut);
+    convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, false);
     convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut);
 
     // Only sparse inputs or outputs need a wrapper method.
diff --git a/mlir/test/Dialect/SparseTensor/external_direct.mlir b/mlir/test/Dialect/SparseTensor/external_direct.mlir
index 97a6d3031d90cd..78c4a295686b33 100644
--- a/mlir/test/Dialect/SparseTensor/external_direct.mlir
+++ b/mlir/test/Dialect/SparseTensor/external_direct.mlir
@@ -2,6 +2,23 @@
 
 // -----
 
+// CHECK-LABEL: func.func @sparse_in(
+// CHECK-SAME:    %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK:         %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
+// CHECK:         %[[F:.*]] = call @_internal_sparse_in(%[[I]])
+// CHECK:         return %[[F]] : tensor<64x64xf32>
+// CHECK:       }
+// CHECK:       func.func private @_internal_sparse_in
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
+  %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32, #sparse> to tensor<64x64xf32>
+  return %0 : tensor<64x64xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @sparse_out(
 // CHECK-SAME:    %[[X:.*0]]: tensor<64x64xf32>)
 // CHECK:         %[[F:.*]] = call @_internal_sparse_out(%[[X]])



More information about the Mlir-commits mailing list