[Mlir-commits] [mlir] [mlir][sparse] allow for direct-out passing of sparse tensor buffers (PR #88327)
Aart Bik
llvmlistbot at llvm.org
Wed Apr 10 18:50:39 PDT 2024
https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/88327
From a21187969ba0a63b6e98ed615c666bf96b75094b Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Wed, 10 Apr 2024 16:08:47 -0700
Subject: [PATCH 1/4] [mlir][sparse] allow for direct-out passing of sparse
tensor buffers
In order to support various external frameworks (e.g., JAX vs. PyTorch),
we need a bit more flexibility in [dis]assembling external buffers
to and from sparse tensors in MLIR land. This PR adds a direct-out
option that avoids the rigid pre-allocated-buffer, copy-out semantics.
Note that over time, we expect the [dis]assemble operations to
converge into something that supports all sorts of external frameworks.
Until then, this option helps in experimenting with different designs.
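To make the new contract concrete, here is a minimal sketch of the
direct-out rewrite, mirroring the @sparse_out case in the new
external_direct.mlir test below. The wrapper name @_internal_sparse_out
follows the test's CHECK lines; the memref shapes and level attributes
are illustrative for the dense/compressed encoding used in the test:

    #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>

    // After the pass, the public wrapper returns the underlying buffers
    // directly (no pre-allocated output arguments, no disassemble).
    func.func @sparse_out(%arg0: tensor<64x64xf32>)
        -> (memref<?xindex>, memref<?xindex>, memref<?xf32>) {
      %f = call @_internal_sparse_out(%arg0)
          : (tensor<64x64xf32>) -> tensor<64x64xf32, #sparse>
      %p = sparse_tensor.positions %f { level = 1 : index }
          : tensor<64x64xf32, #sparse> to memref<?xindex>
      %c = sparse_tensor.coordinates %f { level = 1 : index }
          : tensor<64x64xf32, #sparse> to memref<?xindex>
      %v = sparse_tensor.values %f
          : tensor<64x64xf32, #sparse> to memref<?xf32>
      return %p, %c, %v : memref<?xindex>, memref<?xindex>, memref<?xf32>
    }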
---
.../Dialect/SparseTensor/Transforms/Passes.h | 3 +-
.../Dialect/SparseTensor/Transforms/Passes.td | 9 ++
.../Transforms/SparseAssembler.cpp | 96 ++++++++++++-------
.../Transforms/SparseTensorConversion.cpp | 9 +-
.../Transforms/SparseTensorPasses.cpp | 3 +-
.../Dialect/SparseTensor/external_direct.mlir | 35 +++++++
6 files changed, 116 insertions(+), 39 deletions(-)
create mode 100644 mlir/test/Dialect/SparseTensor/external_direct.mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 61b07d222d156b..d6d038ef65bdf4 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -60,9 +60,10 @@ enum class SparseEmitStrategy {
// The SparseAssembler pass.
//===----------------------------------------------------------------------===//
-void populateSparseAssembler(RewritePatternSet &patterns);
+void populateSparseAssembler(RewritePatternSet &patterns, bool directOut);
std::unique_ptr<Pass> createSparseAssembler();
+std::unique_ptr<Pass> createSparseAssembler(bool directOut);
//===----------------------------------------------------------------------===//
// The SparseReinterpretMap pass.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 58e2d6f32386c3..4706d5ba2f218c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -23,12 +23,21 @@ def SparseAssembler : Pass<"sparse-assembler", "ModuleOp"> {
sparse tensors as numpy arrays from and to Python. Note that eventual
bufferization decisions (e.g. who [de]allocates the underlying memory)
should be resolved in agreement with the external runtime.
+
+ By default, the pass uses the [dis]assemble operations to input and output
+ sparse tensors. When the direct-out option is set, however, the pass
+ instead returns the MLIR-allocated buffers directly to the external runtime.
}];
let constructor = "mlir::createSparseAssembler()";
let dependentDialects = [
+ "bufferization::BufferizationDialect",
"sparse_tensor::SparseTensorDialect",
"tensor::TensorDialect",
];
+ let options = [
+ Option<"directOut", "direct-out", "bool",
+ "false", "Directly returns buffers externally">,
+ ];
}
def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
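For reference, the new flag is exercised from the command line exactly as
in the RUN line of the test added below (input.mlir standing in for the
test's %s):

    mlir-opt --sparse-assembler="direct-out=True" input.mlir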
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index a91d32a23cac9f..a2edc75fc38c02 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -8,6 +8,7 @@
#include "Utils/CodegenUtils.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
@@ -24,7 +25,7 @@ using namespace sparse_tensor;
// Convert type range to new types range, with sparse tensors externalized.
static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
- SmallVectorImpl<Type> *extraTypes = nullptr) {
+ SmallVectorImpl<Type> *extraTypes, bool directOut) {
for (auto type : types) {
// All "dense" data passes through unmodified.
if (!getSparseTensorEncoding(type)) {
@@ -32,31 +33,38 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
continue;
}
- // Convert the external representation of the position/coordinate array
+ // Convert the external representations of the pos/crd/val arrays.
const SparseTensorType stt(cast<RankedTensorType>(type));
- foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
- Type t, FieldIndex,
- SparseTensorFieldKind kind,
- Level, LevelType) {
- if (kind == SparseTensorFieldKind::CrdMemRef ||
- kind == SparseTensorFieldKind::PosMemRef ||
- kind == SparseTensorFieldKind::ValMemRef) {
- ShapedType st = t.cast<ShapedType>();
- auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
- convTypes.push_back(rtp);
- if (extraTypes)
- extraTypes->push_back(rtp);
- }
- return true;
- });
+ foreachFieldAndTypeInSparseTensor(
+ stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
+ SparseTensorFieldKind kind,
+ Level, LevelType) {
+ if (kind == SparseTensorFieldKind::PosMemRef ||
+ kind == SparseTensorFieldKind::CrdMemRef ||
+ kind == SparseTensorFieldKind::ValMemRef) {
+ auto st = t.cast<ShapedType>();
+ auto shape = st.getShape();
+ auto eltTp = st.getElementType();
+ Type rtp;
+ if (directOut) {
+ rtp = MemRefType::get(shape, eltTp);
+ } else {
+ rtp = RankedTensorType::get(shape, eltTp);
+ if (extraTypes)
+ extraTypes->push_back(rtp);
+ }
+ convTypes.push_back(rtp);
+ }
+ return true;
+ });
}
}
// Convert input and output values to [dis]assemble ops for sparse tensors.
static void convVals(OpBuilder &builder, Location loc, TypeRange types,
ValueRange fromVals, ValueRange extraVals,
- SmallVectorImpl<Value> &toVals, unsigned extra,
- bool isIn) {
+ SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
+ bool directOut) {
unsigned idx = 0;
for (auto type : types) {
// All "dense" data passes through unmodified.
@@ -73,21 +81,34 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
if (!isIn)
inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
- // Collect the external representations of the pos/crd arrays.
+ // Collect the external representations of the pos/crd/val arrays.
foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
SparseTensorFieldKind kind,
- Level, LevelType) {
- if (kind == SparseTensorFieldKind::CrdMemRef ||
- kind == SparseTensorFieldKind::PosMemRef ||
+ Level lv, LevelType) {
+ if (kind == SparseTensorFieldKind::PosMemRef ||
+ kind == SparseTensorFieldKind::CrdMemRef ||
kind == SparseTensorFieldKind::ValMemRef) {
if (isIn) {
inputs.push_back(fromVals[idx++]);
} else {
ShapedType st = t.cast<ShapedType>();
auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
- inputs.push_back(extraVals[extra++]);
- retTypes.push_back(rtp);
- cntTypes.push_back(builder.getIndexType());
+ if (directOut) {
+ Value mem;
+ if (kind == SparseTensorFieldKind::PosMemRef)
+ mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
+ lv);
+ else if (kind == SparseTensorFieldKind::CrdMemRef)
+ mem = builder.create<sparse_tensor::ToCoordinatesOp>(
+ loc, inputs[0], lv);
+ else
+ mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
+ toVals.push_back(mem);
+ } else {
+ inputs.push_back(extraVals[extra++]);
+ retTypes.push_back(rtp);
+ cntTypes.push_back(builder.getIndexType());
+ }
}
}
return true;
@@ -97,7 +118,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
// Assemble multiple inputs into a single sparse tensor.
auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
toVals.push_back(a.getResult());
- } else {
+ } else if (!directOut) {
// Disassemble a single sparse input into multiple outputs.
// Note that this includes the counters, which are dropped.
unsigned len = retTypes.size();
@@ -144,11 +165,14 @@ namespace {
// return ..., t1..tn, ...
// }
//
-// TODO: refine output sparse tensors to work well with external framework
+// (with a direct-out variant without the disassemble).
//
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
using OpRewritePattern::OpRewritePattern;
+ SparseFuncAssembler(MLIRContext *context, bool dO)
+ : OpRewritePattern(context), directOut(dO) {}
+
LogicalResult matchAndRewrite(func::FuncOp funcOp,
PatternRewriter &rewriter) const override {
// Only rewrite public entry methods.
@@ -159,8 +183,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
SmallVector<Type> extraTypes;
- convTypes(funcOp.getArgumentTypes(), inputTypes);
- convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);
+ convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, directOut);
+ convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut);
// Only sparse inputs or outputs need a wrapper method.
if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
@@ -192,7 +216,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
// Convert inputs.
SmallVector<Value> inputs;
convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
- ValueRange(), inputs, 0, /*isIn=*/true);
+ ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);
// Call the original, now private method. A subsequent inlining pass can
// determine whether cloning the method body in place is worthwhile.
@@ -203,7 +227,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
// Convert outputs and return.
SmallVector<Value> outputs;
convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
- body->getArguments(), outputs, extra, /*isIn=*/false);
+ body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
rewriter.create<func::ReturnOp>(loc, outputs);
// Finally, migrate a potential c-interface property.
@@ -215,6 +239,9 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
}
return success();
}
+
+private:
+ const bool directOut;
};
} // namespace
@@ -223,6 +250,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//
-void mlir::populateSparseAssembler(RewritePatternSet &patterns) {
- patterns.add<SparseFuncAssembler>(patterns.getContext());
+void mlir::populateSparseAssembler(RewritePatternSet &patterns,
+ bool directOut) {
+ patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index c52fa3751e6b4a..f0d162bdb84d96 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -767,6 +767,12 @@ class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
};
/// Sparse conversion rule for the sparse_tensor.disassemble operator.
+/// Note that the current implementation simply exposes the buffers to
+/// the external client. This assumes the client only reads the buffers
/// (usually copying them into external data structures, such as numpy
/// arrays). The semantics of the disassemble operation technically
/// require that the copy already be done here, using the out-levels
/// and out-values clauses.
class SparseTensorDisassembleConverter
: public OpConversionPattern<DisassembleOp> {
public:
@@ -774,9 +780,6 @@ class SparseTensorDisassembleConverter
LogicalResult
matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // We simply expose the buffers to the external client. This
- // assumes the client only reads the buffers (usually copying it
- // to the external data structures, such as numpy arrays).
Location loc = op->getLoc();
auto stt = getSparseTensorType(op.getTensor());
SmallVector<Value> retVal;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index acea25f023980a..b42d58634a36c4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -50,11 +50,12 @@ namespace {
struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
SparseAssembler() = default;
SparseAssembler(const SparseAssembler &pass) = default;
+ SparseAssembler(bool dO) { directOut = dO; }
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
- populateSparseAssembler(patterns);
+ populateSparseAssembler(patterns, directOut);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
diff --git a/mlir/test/Dialect/SparseTensor/external_direct.mlir b/mlir/test/Dialect/SparseTensor/external_direct.mlir
new file mode 100644
index 00000000000000..97a6d3031d90cd
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/external_direct.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s --sparse-assembler="direct-out=True" -split-input-file | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: func.func @sparse_out(
+// CHECK-SAME: %[[X:.*0]]: tensor<64x64xf32>)
+// CHECK: %[[F:.*]] = call @_internal_sparse_out(%[[X]])
+// CHECK: %[[P:.*]] = sparse_tensor.positions %[[F]]
+// CHECK: %[[C:.*]] = sparse_tensor.coordinates %[[F]]
+// CHECK: %[[V:.*]] = sparse_tensor.values %[[F]]
+// CHECK: return %[[P]], %[[C]], %[[V]]
+// CHECK: }
+// CHECK: func.func private @_internal_sparse_out
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
+ %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
+ return %0 : tensor<64x64xf32, #sparse>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @sparse_out2(
+// CHECK-SAME: %[[X:.*0]]: tensor<64x64xf32>)
+// CHECK: %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
+// CHECK: %[[P:.*]] = sparse_tensor.positions %[[F]]#1
+// CHECK: %[[C:.*]] = sparse_tensor.coordinates %[[F]]#1
+// CHECK: %[[V:.*]] = sparse_tensor.values %[[F]]#1
+// CHECK: return %[[F]]#0, %[[P]], %[[C]], %[[V]]
+// CHECK: }
+// CHECK: func.func private @_internal_sparse_out2
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64x64xf32, #sparse>) {
+ %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
+ return %arg0, %0 : tensor<64x64xf32>, tensor<64x64xf32, #sparse>
+}
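For contrast with the test above, a hedged sketch of what the default
(copy-out) path produces for the same @sparse_out: the wrapper takes
pre-allocated output tensors as extra arguments and copies the result
out through sparse_tensor.disassemble, whose length counters are
dropped. Argument names here are hypothetical; the disassemble syntax
follows the op's documented form:

    #sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>

    // Default mode: %op, %oc, %ov are caller-allocated output buffers
    // appended to the signature; the wrapper fills them via disassemble
    // and returns the filled tensors, dropping the counters.
    func.func @sparse_out(%x: tensor<64x64xf32>,
                          %op: tensor<?xindex>, %oc: tensor<?xindex>,
                          %ov: tensor<?xf32>)
        -> (tensor<?xindex>, tensor<?xindex>, tensor<?xf32>) {
      %f = call @_internal_sparse_out(%x)
          : (tensor<64x64xf32>) -> tensor<64x64xf32, #sparse>
      %p, %c, %v, %pl, %cl, %vl = sparse_tensor.disassemble %f
          : tensor<64x64xf32, #sparse>
          out_lvls(%op, %oc : tensor<?xindex>, tensor<?xindex>)
          out_vals(%ov : tensor<?xf32>)
          -> (tensor<?xindex>, tensor<?xindex>), tensor<?xf32>,
             (index, index), index
      return %p, %c, %v : tensor<?xindex>, tensor<?xindex>, tensor<?xf32>
    }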
From 5c9614fdfe7847649842635720e319a5f1f7500a Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Wed, 10 Apr 2024 16:20:55 -0700
Subject: [PATCH 2/4] edit
---
.../Transforms/SparseAssembler.cpp | 31 +++++++++----------
1 file changed, 14 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index a2edc75fc38c02..bcb97fad81f953 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -90,27 +90,24 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
kind == SparseTensorFieldKind::ValMemRef) {
if (isIn) {
inputs.push_back(fromVals[idx++]);
+ } else if (directOut) {
+ Value mem;
+ if (kind == SparseTensorFieldKind::PosMemRef)
+ mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
+ lv);
+ else if (kind == SparseTensorFieldKind::CrdMemRef)
+ mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0],
+ lv);
+ else
+ mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
+ toVals.push_back(mem);
} else {
ShapedType st = t.cast<ShapedType>();
auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
- if (directOut) {
- Value mem;
- if (kind == SparseTensorFieldKind::PosMemRef)
- mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
- lv);
- else if (kind == SparseTensorFieldKind::CrdMemRef)
- mem = builder.create<sparse_tensor::ToCoordinatesOp>(
- loc, inputs[0], lv);
- else
- mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
- toVals.push_back(mem);
- } else {
- inputs.push_back(extraVals[extra++]);
- retTypes.push_back(rtp);
- cntTypes.push_back(builder.getIndexType());
- }
+ inputs.push_back(extraVals[extra++]);
+ retTypes.push_back(rtp);
+ cntTypes.push_back(builder.getIndexType());
}
- }
return true;
});
From 42c75bc49c09f36096b1ebe4ead68256cb201985 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Wed, 10 Apr 2024 16:32:16 -0700
Subject: [PATCH 3/4] edit
---
mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index bcb97fad81f953..8ae9ad6bf2151b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -108,6 +108,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
retTypes.push_back(rtp);
cntTypes.push_back(builder.getIndexType());
}
+ }
return true;
});
From 5fe3b6dc93e6f1eb7a437ab002bcacaaa243796c Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Wed, 10 Apr 2024 18:48:55 -0700
Subject: [PATCH 4/4] addressed reviewer feedback
---
.../SparseTensor/Transforms/SparseAssembler.cpp | 17 ++++++-----------
.../Dialect/SparseTensor/external_direct.mlir | 17 +++++++++++++++++
2 files changed, 23 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index 8ae9ad6bf2151b..eafbe95b7aebe0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -42,14 +42,9 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
if (kind == SparseTensorFieldKind::PosMemRef ||
kind == SparseTensorFieldKind::CrdMemRef ||
kind == SparseTensorFieldKind::ValMemRef) {
- auto st = t.cast<ShapedType>();
- auto shape = st.getShape();
- auto eltTp = st.getElementType();
- Type rtp;
- if (directOut) {
- rtp = MemRefType::get(shape, eltTp);
- } else {
- rtp = RankedTensorType::get(shape, eltTp);
+ auto rtp = t.cast<ShapedType>();
+ if (!directOut) {
+ rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
if (extraTypes)
extraTypes->push_back(rtp);
}
@@ -102,8 +97,8 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
toVals.push_back(mem);
} else {
- ShapedType st = t.cast<ShapedType>();
- auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
+ ShapedType rtp = t.cast<ShapedType>();
+ rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
inputs.push_back(extraVals[extra++]);
retTypes.push_back(rtp);
cntTypes.push_back(builder.getIndexType());
@@ -181,7 +176,7 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
SmallVector<Type> extraTypes;
- convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, directOut);
+ convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, false);
convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut);
// Only sparse inputs or outputs need a wrapper method.
diff --git a/mlir/test/Dialect/SparseTensor/external_direct.mlir b/mlir/test/Dialect/SparseTensor/external_direct.mlir
index 97a6d3031d90cd..78c4a295686b33 100644
--- a/mlir/test/Dialect/SparseTensor/external_direct.mlir
+++ b/mlir/test/Dialect/SparseTensor/external_direct.mlir
@@ -2,6 +2,23 @@
// -----
+// CHECK-LABEL: func.func @sparse_in(
+// CHECK-SAME: %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
+// CHECK: %[[F:.*]] = call @_internal_sparse_in(%[[I]])
+// CHECK: return %[[F]] : tensor<64x64xf32>
+// CHECK: }
+// CHECK: func.func private @_internal_sparse_in
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
+ %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32, #sparse> to tensor<64x64xf32>
+ return %0 : tensor<64x64xf32>
+}
+
+// -----
+
// CHECK-LABEL: func.func @sparse_out(
// CHECK-SAME: %[[X:.*0]]: tensor<64x64xf32>)
// CHECK: %[[F:.*]] = call @_internal_sparse_out(%[[X]])