[Mlir-commits] [mlir] edca72f - [mlir][sparse] Refactoring: remove dependence on tuple type when lowering sparse tensors.
Peiming Liu
llvmlistbot at llvm.org
Wed Sep 7 10:53:57 PDT 2022
Author: Peiming Liu
Date: 2022-09-07T17:53:48Z
New Revision: edca72f5bcb039840fda28e324af4614d4e46fde
URL: https://github.com/llvm/llvm-project/commit/edca72f5bcb039840fda28e324af4614d4e46fde
DIFF: https://github.com/llvm/llvm-project/commit/edca72f5bcb039840fda28e324af4614d4e46fde.diff
LOG: [mlir][sparse] Refactoring: remove dependence on tuple type when lowering sparse tensors.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D133390
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
Removed:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 5c56f16d71ee8..7035fb16d8e18 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -624,79 +624,4 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
let hasVerifier = 1;
}
-//===----------------------------------------------------------------------===//
-// Sparse Tensor Storage Operation. These operations are used internally by
-// sparse tensor codegen to progressively lower sparse tensors.
-//===----------------------------------------------------------------------===//
-
-def SparseTensor_StorageOp : SparseTensor_Op<"storage", []>,
- Arguments<(ins Variadic<AnyType>:$inputs)>,
- Results<(outs AnyTuple:$result)> {
- let summary = "Pack a list of value into one sparse tensor storage value";
- let description = [{
- Pack a list of value into one sparse tensor storage value (represented as
- a tuple) at the given index.
-
- The result tuple elements' type should match the corresponding type in the
- input array.
-
- Example:
-
- ```mlir
- %0 = sparse_tensor.storage(%1, %2): memref<?xf64>, memref<?xf64>
- to tuple<memref<?xf64>, memref<?xf64>>
- ```
- }];
-
- let assemblyFormat = " attr-dict `(` $inputs `)``:` type($inputs) `to` type($result)";
- let hasVerifier = 1;
-}
-
-def SparseTensor_StorageGetOp : SparseTensor_Op<"storage_get", []>,
- Arguments<(ins AnyTuple:$storage,
- IndexAttr:$idx)>,
- Results<(outs AnyType:$result)> {
- let summary = "Get the data stored in the sparse tensor storage at the given index";
- let description = [{
- Get the data stored in the sparse tensor storage (represented as a tuple)
- at the given index.
-
- The result type should match the corresponding element type in the tuple.
-
- Example:
-
- ```mlir
- %0 = sparse_tensor.storage_get %arg0[0] : tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
- ```
- }];
-
- let assemblyFormat = " $storage attr-dict `[`$idx`]` `:` type($storage) `to` type($result)";
- let hasVerifier = 1;
-}
-
-def SparseTensor_StorageSetOp : SparseTensor_Op<"storage_set", []>,
- Arguments<(ins AnyTuple:$storage,
- AnyType:$value,
- IndexAttr:$idx)>,
- Results<(outs AnyTuple:$result)> {
- let summary = "Set the data stored in the sparse tensor storage at given index";
- let description = [{
- Set the data stored in the sparse tensor storage (represented as a tuple)
- at the given index. Return a new SSA value with the corresponding element
- updated (others remain unchanged).
-
- The result type should match the original tuple type with only the updated
- element type changed accordingly.
-
- Example:
-
- ```mlir
- %0 = sparse_tensor.storage_set %arg0, %arg1 at 0 : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to tuple<memref<?xf64>, memref<?xf64>, f64>
- ```
- }];
-
- let assemblyFormat = " $storage attr-dict `[`$idx`]``,` $value `:` type($storage) `,` type($value) `to` type($result)";
- let hasVerifier = 1;
-}
-
#endif // SPARSETENSOR_OPS
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 227b70a381192..fd885f646221a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -155,22 +155,6 @@ void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
std::unique_ptr<Pass> createSparseTensorCodegenPass();
-//===----------------------------------------------------------------------===//
-// The SparseTensorStorageExpansion pass.
-//===----------------------------------------------------------------------===//
-
-/// Sparse tensor storage type converter from compound to expanded form.
-class SparseTensorStorageTupleExpander : public TypeConverter {
-public:
- SparseTensorStorageTupleExpander();
-};
-
-/// Sets up sparse tensor storage expansion rules.
-void populateSparseTensorStorageExpansionPatterns(TypeConverter &typeConverter,
- RewritePatternSet &patterns);
-
-std::unique_ptr<Pass> createSparseTensorStorageExpansionPass();
-
//===----------------------------------------------------------------------===//
// Other rewriting rules and passes.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index cd6b77ea50eea..f7f4a39a95f23 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -175,39 +175,4 @@ def SparseTensorCodegen : Pass<"sparse-tensor-codegen", "ModuleOp"> {
];
}
-def SparseTensorStorageExpansion : Pass<"sparse-tensor-storage-expansion", "ModuleOp"> {
- let summary = "Expand compounded sparse tensor storage into individual SSA values";
- let description = [{
- A pass that expands sparse tensor storage (aggregated by tuple) into
- individual SSA values. It also lowers sparse tensor storage operations,
- e.g., sparse_tensor.storage_get and sparse_tensor.storage_set.
-
- Example of the conversion:
-
- ```mlir
- Before:
- func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>,
- memref<?xf64>,
- f64>)
- -> tuple<memref<?xf64>,
- memref<?xf64>,
- f64> {
- return %arg0 : tuple<memref<?xf64>, memref<?xf64>, f64>
- }
- After:
- func.func @sparse_storage_set(%arg0: memref<?xf64>,
- %arg1: memref<?xf64>,
- %arg2: f64)
- -> (memref<?xf64>, memref<?xf64>, f64) {
- return %arg0, %arg1, %arg2 : memref<?xf64>, memref<?xf64>, f64
- }
- ```
- }];
- let constructor = "mlir::createSparseTensorStorageExpansionPass()";
- let dependentDialects = [
- "sparse_tensor::SparseTensorDialect",
- ];
-}
-
-
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index ef32dea8efff2..8691b94351f9f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -482,65 +482,6 @@ LogicalResult YieldOp::verify() {
"expected parent op to be sparse_tensor unary, binary, or reduce");
}
-//===----------------------------------------------------------------------===//
-// Sparse Tensor Storage Operation.
-//===----------------------------------------------------------------------===//
-
-LogicalResult StorageOp::verify() {
- auto retTypes = getResult().getType().getTypes();
- if (retTypes.size() != getInputs().size())
- return emitError("The number of inputs is inconsistent with output tuple");
-
- for (auto pair : llvm::zip(getInputs(), retTypes)) {
- auto input = std::get<0>(pair);
- auto retTy = std::get<1>(pair);
-
- if (input.getType() != retTy)
- return emitError(llvm::formatv("Type mismatch between input (type={0}) "
- "and output tuple element (type={1})",
- input.getType(), retTy));
- }
- return success();
-}
-
-LogicalResult StorageGetOp::verify() {
- uint64_t extractIdx = getIdx().getZExtValue();
- auto innerTypeArray = getStorage().getType().getTypes();
- if (extractIdx >= innerTypeArray.size())
- return emitError(llvm::formatv(
- "Out-of-bound access with index={0} on tuple with length={1}",
- extractIdx, innerTypeArray.size()));
-
- auto expectedTy = getStorage().getType().getType(extractIdx);
- auto returnTy = getResult().getType();
- if (expectedTy != returnTy)
- return emitError(llvm::formatv(
- "Type mismatch between the returning type (type={0}) and the "
- "corresponding element type at index {1} (type={2})",
- expectedTy, extractIdx, returnTy));
- return success();
-}
-
-LogicalResult StorageSetOp::verify() {
- uint64_t setIdx = getIdx().getZExtValue();
- SmallVector<Type, 8> expectedElemTy(getStorage().getType().getTypes());
- if (setIdx >= expectedElemTy.size())
- return emitError(llvm::formatv(
- "Out-of-bound access with index = {0} on tuple with length={1}", setIdx,
- expectedElemTy.size()));
-
- // Updates the element type after storage_set.
- expectedElemTy[setIdx] = getValue().getType();
- auto expectedTy = TupleType::get(getContext(), expectedElemTy);
- auto returnTy = getResult().getType();
- if (expectedTy != returnTy)
- return emitError(
- llvm::formatv("Type mismatch between the returning type "
- "(type={0}) and the expected type (type={1})",
- returnTy, expectedTy));
- return success();
-}
-
//===----------------------------------------------------------------------===//
// TensorDialect Methods.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 39b633a6c7f6a..640ee67302b1b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -7,7 +7,6 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
SparseTensorConversion.cpp
SparseTensorPasses.cpp
SparseTensorRewriting.cpp
- SparseTensorStorageExpansion.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 022c4be443a0a..3c9caf71512b8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -54,8 +54,30 @@ static unsigned toStored(const SparseTensorEncodingAttr &enc, unsigned i) {
return i;
}
+/// Flatten a list of operands that may contain sparse tensors.
+static void flattenOperands(ValueRange operands,
+ SmallVectorImpl<Value> &flattened) {
+ // In case of
+ // sparse_tensor, c, sparse_tensor
+ // ==>
+ // memref ..., c, memref ...
+ for (auto operand : operands) {
+ if (auto cast =
+ dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp());
+ cast && getSparseTensorEncoding(cast->getResultTypes()[0]))
+ // An unrealized_conversion_cast is inserted by the type converter to
+ // bridge the gap in the 1:N conversion between a sparse tensor and its
+ // fields. In this case, take the operands of the cast and replace the
+ // sparse tensor value with the flattened list of field values.
+ flattened.append(cast.getOperands().begin(), cast.getOperands().end());
+ else
+ flattened.push_back(operand);
+ }
+}
+
/// Maps a sparse tensor type to the appropriate compounded buffers.
-static Optional<Type> convertSparseTensorType(Type type) {
+static Optional<LogicalResult>
+convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
auto enc = getSparseTensorEncoding(type);
if (!enc)
return llvm::None;
@@ -86,7 +108,6 @@ static Optional<Type> convertSparseTensorType(Type type) {
// };
//
unsigned rank = rType.getShape().size();
- SmallVector<Type, 8> fields;
// The dimSizes array.
fields.push_back(MemRefType::get({rank}, indexType));
// Per-dimension storage.
@@ -115,10 +136,7 @@ static Optional<Type> convertSparseTensorType(Type type) {
}
// The values array.
fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, eltType));
- // Sparse tensor storage (temporarily) lives in a tuple. This allows a
- // simple 1:1 type conversion during codegen. A subsequent pass uses
- // a 1:N type conversion to expand the tuple into its fields.
- return TupleType::get(context, fields);
+ return success();
}
// Returns field index of sparse tensor type for pointers/indices, when set.
@@ -158,25 +176,6 @@ static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
return -1;
}
-/// Returns field type in tuple at given index.
-static Type getFieldType(Value tuple, unsigned field) {
- return tuple.getType().cast<TupleType>().getType(field);
-}
-
-/// Creates tuple get operation at given index.
-static Value createTupleGet(OpBuilder &builder, Location loc, Value tuple,
- unsigned field) {
- Type indexType = builder.getIndexType();
- return builder.create<StorageGetOp>(loc, getFieldType(tuple, field), tuple,
- builder.getIntegerAttr(indexType, field));
-}
-
-/// Creates tuple.
-static Value createTupleMake(OpBuilder &builder, Location loc, Type type,
- ValueRange values) {
- return builder.create<StorageOp>(loc, type, values);
-}
-
/// Create allocation operation.
static Value createAllocation(OpBuilder &builder, Location loc, Type type,
Value sz) {
@@ -184,14 +183,15 @@ static Value createAllocation(OpBuilder &builder, Location loc, Type type,
return builder.create<memref::AllocOp>(loc, memType, sz);
}
-/// Creates allocation tuple for sparse tensor type.
+/// Creates allocation for each field in sparse tensor type.
///
/// TODO: for efficiency, we will need heuristis to make educated guesses
/// on the required final sizes; also, we will need an improved
/// memory allocation scheme with capacity and reallocation
///
-static Value createAllocTuple(OpBuilder &builder, Location loc, Type type,
- ValueRange dynSizes) {
+static void createAllocFields(OpBuilder &builder, Location loc, Type type,
+ ValueRange dynSizes,
+ SmallVectorImpl<Value> &fields) {
auto enc = getSparseTensorEncoding(type);
assert(enc);
// Construct the basic types.
@@ -202,10 +202,8 @@ static Value createAllocTuple(OpBuilder &builder, Location loc, Type type,
Type idxType = idxWidth ? builder.getIntegerType(idxWidth) : indexType;
Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType;
Type eltType = rType.getElementType();
- // Build the allocation tuple, using heuristics for pre-allocation.
auto shape = rType.getShape();
unsigned rank = shape.size();
- SmallVector<Value, 8> fields;
bool allDense = true;
Value one = constantIndex(builder, loc, 1);
Value linear = one;
@@ -254,9 +252,6 @@ static Value createAllocTuple(OpBuilder &builder, Location loc, Type type,
// In all other case, we resort to the heuristical initial value.
Value valuesSz = allDense ? linear : heuristic;
fields.push_back(createAllocation(builder, loc, eltType, valuesSz));
- // Construct tuple allocation.
- Type tupleType = *convertSparseTensorType(type);
- return createTupleMake(builder, loc, tupleType, fields);
}
/// Returns integral constant, if defined.
@@ -270,14 +265,80 @@ static Optional<int64_t> getConstantInt(Value val) {
// Codegen rules.
//===----------------------------------------------------------------------===//
-/// Sparse codegen rule for returns.
+/// Sparse tensor storage conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
+ SmallVector<Value, 8> flattened;
+ flattenOperands(adaptor.getOperands(), flattened);
+ // Create a return with the flattened value extracted from sparse tensors.
+ rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
+ return success();
+ }
+};
+
+/// Sparse tensor storage conversion rule for calls.
+class SparseCallConverter : public OpConversionPattern<func::CallOp> {
+public:
+ // The default CallOp converter can not handle 1:N type conversion.
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ // In case of:
+ // sparse_tensor, f, sparse_tensor = call @foo(...)
+ // ==>
+ // memref..., f, memref = call @foo(...) replace with
+ // cast(memref...)->sparse_tensor, f, cast(memref...)->sparse_tensor
+ SmallVector<Type, 8> finalRetTy;
+ if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
+ return failure();
+
+ // (1) Generates a new call with flattened return values.
+ SmallVector<Value, 8> flattened;
+ flattenOperands(adaptor.getOperands(), flattened);
+ auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
+ finalRetTy, flattened);
+ // (2) Create cast operation for sparse tensor returns.
+ SmallVector<Value, 4> castedRet;
+ // Tracks the offset of the current return value (of the original call)
+ // relative to the new call (after sparse tensor flattening).
+ unsigned retOffset = 0;
+ // Temporary buffer to hold the flattened list of types for
+ // a sparse tensor.
+ SmallVector<Type, 8> sparseFlat;
+ for (auto ret : op.getResults()) {
+ assert(retOffset < newCall.getNumResults());
+ auto retType = ret.getType();
+ if (failed(typeConverter->convertType(retType, sparseFlat)))
+ // This should never happen.
+ llvm_unreachable("Failed to convert type in sparse tensor codegen");
+
+ // Converted types cannot be empty when the type conversion succeeds.
+ assert(!sparseFlat.empty());
+ if (sparseFlat.size() > 1) {
+ auto flatSize = sparseFlat.size();
+ ValueRange sparseElem(iterator_range<ResultRange::iterator>(
+ newCall.result_begin() + retOffset,
+ newCall.result_begin() + retOffset + flatSize));
+ auto castOp = rewriter.create<UnrealizedConversionCastOp>(
+ loc, TypeRange({retType}), sparseElem);
+ castedRet.push_back(castOp.getResult(0));
+ retOffset += flatSize;
+ } else {
+ // If this is a 1:1 conversion, no need for casting.
+ castedRet.push_back(newCall.getResult(retOffset));
+ retOffset++;
+ }
+ sparseFlat.clear();
+ }
+
+ assert(castedRet.size() == op.getNumResults());
+ rewriter.replaceOp(op, castedRet);
return success();
}
};
@@ -306,10 +367,11 @@ class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
}
// Any other query can consult the dimSizes array at field 0 using,
// accounting for the reordering applied to the sparse storage.
- Value tuple = adaptor.getSource();
- Value dimSizes = createTupleGet(rewriter, loc, tuple, 0);
+ auto tuple = llvm::cast<UnrealizedConversionCastOp>(
+ adaptor.getSource().getDefiningOp());
rewriter.replaceOpWithNewOp<memref::LoadOp>(
- op, dimSizes, constantIndex(rewriter, loc, toStored(enc, *index)));
+ op, tuple.getInputs().front(),
+ constantIndex(rewriter, loc, toStored(enc, *index)));
return success();
}
};
@@ -345,10 +407,13 @@ class SparseTensorAllocConverter
return failure();
if (op.getCopy())
return rewriter.notifyMatchFailure(op, "tensor copy not implemented");
- // Construct allocation tuple.
- Value tuple = createAllocTuple(rewriter, op->getLoc(), resType,
- adaptor.getOperands());
- rewriter.replaceOp(op, tuple);
+
+ // Construct allocation for each field.
+ Location loc = op.getLoc();
+ SmallVector<Value, 8> fields;
+ createAllocFields(rewriter, loc, resType, adaptor.getOperands(), fields);
+ rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+ op, TypeRange{resType}, fields);
return success();
}
};
@@ -364,86 +429,101 @@ class SparseTensorDeallocConverter
auto enc = getSparseTensorEncoding(op.getTensor().getType());
if (!enc)
return failure();
- // Replace the tuple deallocation with field deallocations.
- Location loc = op->getLoc();
- Value tuple = adaptor.getTensor();
- for (unsigned i = 0, sz = tuple.getType().cast<TupleType>().size(); i < sz;
- i++) {
- Value mem = createTupleGet(rewriter, loc, tuple, i);
- rewriter.create<memref::DeallocOp>(loc, mem);
- }
+
+ // Replace the sparse tensor deallocation with field deallocations.
+ Location loc = op.getLoc();
+ auto tuple = llvm::cast<UnrealizedConversionCastOp>(
+ adaptor.getTensor().getDefiningOp());
+ for (auto input : tuple.getInputs())
+ // Deallocate every buffer used to store the sparse tensor handler.
+ rewriter.create<memref::DeallocOp>(loc, input);
+
rewriter.eraseOp(op);
return success();
}
};
-/// Sparse codegen rule for pointer accesses.
-class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
+/// Sparse codegen rule for tensor rematerialization.
+class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
+ matchAndRewrite(LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Optional<int64_t> index = getConstantInt(adaptor.getOperands()[1]);
- if (!index)
- return failure();
- // Replace the requested pointer access with corresponding field.
- Location loc = op->getLoc();
- Value tuple = adaptor.getTensor();
- unsigned i = getFieldIndex(op.getTensor().getType(), /*ptrDim=*/*index, -1);
- rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
+ if (op.getHasInserts()) {
+ // Finalize any pending insertions.
+ // TODO: implement
+ }
+ rewriter.replaceOp(op, adaptor.getOperands());
return success();
}
};
-/// Sparse codegen rule for index accesses.
-class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
+/// Base class for getter-like operations, e.g., to_indices, to_pointers.
+template <typename SourceOp, typename Base>
+class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
public:
- using OpConversionPattern::OpConversionPattern;
+ using OpAdaptor = typename SourceOp::Adaptor;
+ using OpConversionPattern<SourceOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
+ matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Optional<int64_t> index = getConstantInt(adaptor.getOperands()[1]);
- if (!index)
+ // Replace the requested pointer access with corresponding field.
+ // The cast_op is inserted by type converter to intermix 1:N type
+ // conversion.
+ auto tuple = llvm::cast<UnrealizedConversionCastOp>(
+ adaptor.getTensor().getDefiningOp());
+ auto idx = Base::getIndexForOp(tuple, op);
+ if (!idx)
+ // Failed to get the index.
return failure();
- // Replace the requested indices access with corresponding field.
- Location loc = op->getLoc();
- Value tuple = adaptor.getTensor();
- unsigned i = getFieldIndex(op.getTensor().getType(), -1, /*idxDim=*/*index);
- rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
+ auto fields = tuple.getInputs();
+ assert(*idx < fields.size());
+ rewriter.replaceOp(op, fields[*idx]);
return success();
}
};
-/// Sparse codegen rule for value accesses.
-class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
+/// Sparse codegen rule for pointer accesses.
+class SparseToPointersConverter
+ : public SparseGetterOpConverter<ToPointersOp, SparseToPointersConverter> {
public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Replace the requested values access with corresponding field.
- Location loc = op->getLoc();
- Value tuple = adaptor.getTensor();
- unsigned i = tuple.getType().cast<TupleType>().size() - 1; // last
- rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
- return success();
+ using SparseGetterOpConverter::SparseGetterOpConverter;
+ // Callback for SparseGetterOpConverter.
+ static Optional<unsigned> getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
+ ToPointersOp op) {
+ Optional<int64_t> dim = getConstantInt(op.getDim());
+ if (!dim)
+ return llvm::None; // variable dim
+ return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/*dim, -1);
}
};
-/// Sparse codegen rule for tensor rematerialization.
-class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
+/// Sparse codegen rule for index accesses.
+class SparseToIndicesConverter
+ : public SparseGetterOpConverter<ToIndicesOp, SparseToIndicesConverter> {
public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(LoadOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- if (op.getHasInserts()) {
- // Finalize any pending insertions.
- // TODO: implement
- }
- rewriter.replaceOp(op, adaptor.getOperands());
- return success();
+ using SparseGetterOpConverter::SparseGetterOpConverter;
+ // Callback for SparseGetterOpConverter.
+ static Optional<unsigned> getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
+ ToIndicesOp op) {
+ Optional<int64_t> dim = getConstantInt(op.getDim());
+ if (!dim)
+ return llvm::None; // variable dim
+ return getFieldIndex(op.getTensor().getType(), -1, /*idxDim=*/*dim);
+ }
+};
+
+/// Sparse codegen rule for value accesses.
+class SparseToValuesConverter
+ : public SparseGetterOpConverter<ToValuesOp, SparseToValuesConverter> {
+public:
+ using SparseGetterOpConverter::SparseGetterOpConverter;
+ // Callback for SparseGetterOpConverter.
+ static Optional<unsigned> getIndexForOp(UnrealizedConversionCastOp tuple,
+ ToValuesOp /*op*/) {
+ // The last field holds the value buffer.
+ return tuple.getInputs().size() - 1;
}
};
@@ -466,9 +546,9 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
- patterns.add<SparseReturnConverter, SparseDimOpConverter, SparseCastConverter,
- SparseTensorAllocConverter, SparseTensorDeallocConverter,
- SparseToPointersConverter, SparseToIndicesConverter,
- SparseToValuesConverter, SparseTensorLoadConverter>(
- typeConverter, patterns.getContext());
+ patterns.add<SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
+ SparseCastConverter, SparseTensorAllocConverter,
+ SparseTensorDeallocConverter, SparseToPointersConverter,
+ SparseToIndicesConverter, SparseToValuesConverter,
+ SparseTensorLoadConverter>(typeConverter, patterns.getContext());
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 505ae79e26fac..fee4222cb53d3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -24,7 +24,6 @@ namespace mlir {
#define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
-#define GEN_PASS_DEF_SPARSETENSORSTORAGEEXPANSION
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir
@@ -154,9 +153,8 @@ struct SparseTensorCodegenPass
RewritePatternSet patterns(ctx);
SparseTensorTypeToBufferConverter converter;
ConversionTarget target(*ctx);
- // Almost everything in the sparse dialect must go!
+ // Everything in the sparse dialect must go!
target.addIllegalDialect<SparseTensorDialect>();
- target.addLegalOp<StorageGetOp, StorageSetOp, StorageOp>();
// All dynamic rules below accept new function, call, return, and various
// tensor and bufferization operations as legal output of the rewriting
// provided that all sparse tensor types have been fully rewritten.
@@ -181,53 +179,13 @@ struct SparseTensorCodegenPass
target.addLegalDialect<arith::ArithmeticDialect,
bufferization::BufferizationDialect,
memref::MemRefDialect, scf::SCFDialect>();
- // Populate with rules and apply rewriting rules.
- populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
- converter);
- populateCallOpTypeConversionPattern(patterns, converter);
- scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
- target);
- populateSparseTensorCodegenPatterns(converter, patterns);
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
- signalPassFailure();
- }
-};
-
-struct SparseTensorStorageExpansionPass
- : public impl::SparseTensorStorageExpansionBase<
- SparseTensorStorageExpansionPass> {
-
- SparseTensorStorageExpansionPass() = default;
- SparseTensorStorageExpansionPass(
- const SparseTensorStorageExpansionPass &pass) = default;
-
- void runOnOperation() override {
- auto *ctx = &getContext();
- RewritePatternSet patterns(ctx);
- SparseTensorStorageTupleExpander converter;
- ConversionTarget target(*ctx);
- // Now, everything in the sparse dialect must go!
- target.addIllegalDialect<SparseTensorDialect>();
- // All dynamic rules below accept new function, call, return.
- target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
- return converter.isSignatureLegal(op.getFunctionType());
- });
- target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
- return converter.isSignatureLegal(op.getCalleeType());
- });
- target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
- return converter.isLegal(op.getOperandTypes());
- });
- // We generate UnrealizedConversionCastOp to intermix tuples and a
- // list of types.
target.addLegalOp<UnrealizedConversionCastOp>();
// Populate with rules and apply rewriting rules.
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
converter);
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
target);
- populateSparseTensorStorageExpansionPatterns(converter, patterns);
+ populateSparseTensorCodegenPatterns(converter, patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
@@ -277,7 +235,3 @@ std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
return std::make_unique<SparseTensorCodegenPass>();
}
-
-std::unique_ptr<Pass> mlir::createSparseTensorStorageExpansionPass() {
- return std::make_unique<SparseTensorStorageExpansionPass>();
-}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
deleted file mode 100644
index 1f7afa1d77804..0000000000000
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
+++ /dev/null
@@ -1,218 +0,0 @@
-//===- SparseTensorStorageExpansion.cpp - Sparse tensor storage expansion ===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// The sparse tensor storage expansion pass expands the compound storage for
-// sparse tensors (using tuple) to flattened SSA values.
-//
-//===----------------------------------------------------------------------===//
-
-#include "CodegenUtils.h"
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
-#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-using namespace mlir;
-using namespace mlir::sparse_tensor;
-
-namespace {
-
-//===----------------------------------------------------------------------===//
-// Helper methods.
-//===----------------------------------------------------------------------===//
-
-/// Expands sparse tensor storage tuple.
-static Optional<LogicalResult>
-convertSparseTensorStorageTuple(Type t, SmallVectorImpl<Type> &result) {
- if (auto tuple = t.dyn_cast<TupleType>()) {
- // Note that it does not handle nest tuples, but it is fine
- // for sparse compiler as they will not be generated.
- result.append(tuple.getTypes().begin(), tuple.getTypes().end());
- return success();
- }
- return llvm::None;
-}
-
-/// Flatten a list of operands that may contain tuples.
-static void flattenOperands(ValueRange operands,
- SmallVectorImpl<Value> &flattened) {
- // In case of
- // tuple<a, b>, c, tuple<d, e>
- // ==>
- // a, b, c, d, e
- for (auto operand : operands) {
- if (auto cast =
- dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp());
- cast && cast->getResultTypes()[0].isa<TupleType>())
- // An unrealized_conversion_cast will be inserted by type converter to
- // inter-mix the gap between 1:N conversion between tuple and types.
- // In this case, take the operands in the cast and replace the tuple
- // output with the flattened type array.
- flattened.append(cast.getOperands().begin(), cast.getOperands().end());
- else
- flattened.push_back(operand);
- }
-}
-//===----------------------------------------------------------------------===//
-// Conversion rules.
-//===----------------------------------------------------------------------===//
-
-/// Sparse tensor storage conversion rule for sparse_tensor::storage.
-class SparseStorageConversion : public OpConversionPattern<StorageOp> {
-public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(StorageOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Simply convert it to a unrealize_conversion_cast.
- // We should guarantee that all uses of sparse_tensor.storage op will
- // be eventually eliminated by accessing the flattened SSA values directly.
- rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
- op, TypeRange{op.getType()}, adaptor.getInputs());
- return success();
- }
-};
-
-/// Sparse tensor storage conversion rule for sparse_tensor::storage_get.
-class SparseStorageGetConverter : public OpConversionPattern<StorageGetOp> {
-public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(StorageGetOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto castOp =
- cast<UnrealizedConversionCastOp>(adaptor.getStorage().getDefiningOp());
- uint64_t idx = op.getIdx().getZExtValue();
- assert(idx < castOp.getOperands().size());
-
- rewriter.replaceOp(op, castOp.getOperand(idx));
- return success();
- }
-};
-
-/// Sparse tensor storage conversion rule for sparse_tensor::storage_set.
-class SparseStorageSetConverter : public OpConversionPattern<StorageSetOp> {
-public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(StorageSetOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto castOp =
- cast<UnrealizedConversionCastOp>(adaptor.getStorage().getDefiningOp());
- uint64_t idx = op.getIdx().getZExtValue();
-
- SmallVector<Value, 8> values(castOp.getOperands());
- assert(idx < values.size());
-
- // Updates the corresponding element.
- values[idx] = adaptor.getValue();
- rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
- op, TypeRange{op.getType()}, values);
- return success();
- }
-};
-
-/// Sparse tensor storage conversion rule for returns.
-class SparseStorageReturnConverter
- : public OpConversionPattern<func::ReturnOp> {
-public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- SmallVector<Value, 8> flattened;
- flattenOperands(adaptor.getOperands(), flattened);
- // Create a return with the flattened value extracted from tuple.
- rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
- return success();
- }
-};
-
-/// Sparse tensor storage conversion rule for calls.
-class SparseStorageCallConverter : public OpConversionPattern<func::CallOp> {
-public:
- // The default CallOp converter can not handle 1:N type conversion properly
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- // In case of:
- // tuple(a, b), f, tuple(c, d) = call @foo(...)
- // ==>
- // a, b, f, c, d = call @foo(...)
- // cast(a, b)->tuple, f, cast(c,d)->tuple
- SmallVector<Type, 8> finalRetTy;
- if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
- return failure();
-
- // (1) Genereates new call with flattened return value.
- SmallVector<Value, 8> flattened;
- flattenOperands(adaptor.getOperands(), flattened);
- auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
- finalRetTy, flattened);
-
- // (2) Create cast operation for tuple returns.
- SmallVector<Value, 4> castedRet;
- // Tracks the offset of current return value (of the orignal call)
- // relative to the new call (after tuple flattening);
- unsigned retOffset = 0;
- for (auto ret : op.getResults()) {
- assert(retOffset < newCall.getNumResults());
- auto tupleRet = ret.getType().dyn_cast<TupleType>();
- if (tupleRet) {
- auto tupleSize = tupleRet.size();
- // NOTE: The range is computed under the assumption of non-recursive
- // tuple type.
- ValueRange tupleElem(iterator_range<ResultRange::iterator>(
- newCall.result_begin() + retOffset,
- newCall.result_begin() + retOffset + tupleSize));
- auto castOp = rewriter.create<UnrealizedConversionCastOp>(
- loc, TypeRange({tupleRet}), tupleElem);
- castedRet.push_back(castOp.getResult(0));
- retOffset += tupleSize;
- } else {
- // If this not a tuple, simply add it into returned values.
- castedRet.push_back(ret);
- retOffset++;
- }
- }
-
- assert(castedRet.size() == op.getNumResults());
- rewriter.replaceOp(op, castedRet);
- return success();
- }
-};
-
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// Sparse tensor storage expansion
-//===----------------------------------------------------------------------===//
-
-mlir::SparseTensorStorageTupleExpander::SparseTensorStorageTupleExpander() {
- addConversion([](Type type) { return type; });
- addConversion(convertSparseTensorStorageTuple);
-}
-
-//===----------------------------------------------------------------------===//
-// Public method for populating conversion rules.
-//===----------------------------------------------------------------------===//
-
-/// Populates the given patterns list with conversion rules required
-/// to expand compounded sparse tensor tuples.
-void mlir::populateSparseTensorStorageExpansionPatterns(
- TypeConverter &typeConverter, RewritePatternSet &patterns) {
- patterns.add<SparseStorageConversion, SparseStorageGetConverter,
- SparseStorageSetConverter, SparseStorageReturnConverter,
- SparseStorageCallConverter>(typeConverter,
- patterns.getContext());
-}
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 8c7968022e6f6..89fb8a9129fa5 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s --check-prefixes=CHECK,CHECK-CODEGEN
-// RUN: mlir-opt %s --sparse-tensor-codegen --sparse-tensor-storage-expansion --canonicalize --cse | FileCheck %s --check-prefixes=CHECK,CHECK-STORAGE
+// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s
#SparseVector = #sparse_tensor.encoding<{
dimLevelType = [ "compressed" ],
@@ -41,96 +40,114 @@
dimOrdering = affine_map<(i, j, k) -> (k, i, j)>
}>
-// CHECK-CODEGEN-LABEL: func @sparse_nop(
-// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
-// CHECK-CODEGEN: return %[[A]] : tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>
-//
-// CHECK-STORAGE-LABEL: func @sparse_nop(
-// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xi32>,
-// CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xf64>)
-// CHECK-STORAGE: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] : memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
+// CHECK-LABEL: func @sparse_nop(
+// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xf64>)
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] : memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
return %arg0 : tensor<?xf64, #SparseVector>
}
-// CHECK-CODEGEN-LABEL: func @sparse_nop_cast(
-// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>>)
-// CHECK-CODEGEN: return %[[A]] : tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>>
+// CHECK-LABEL: func @sparse_nop_multi_ret(
+// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xf64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<1xindex>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xi32>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xi64>,
+// CHECK-SAME: %[[A7:.*7]]: memref<?xf64>) ->
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]]
+func.func @sparse_nop_multi_ret(%arg0: tensor<?xf64, #SparseVector>,
+ %arg1: tensor<?xf64, #SparseVector>) ->
+ (tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>) {
+ return %arg0, %arg1 : tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>
+}
+
+// CHECK-LABEL: func @sparse_nop_call(
+// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xf64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<1xindex>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xi32>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xi64>,
+// CHECK-SAME: %[[A7:.*7]]: memref<?xf64>)
+// CHECK: %[[T0:.*]]:8 = call @sparse_nop_multi_ret(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]])
+// CHECK: return %[[T0]]#0, %[[T0]]#1, %[[T0]]#2, %[[T0]]#3, %[[T0]]#4, %[[T0]]#5, %[[T0]]#6, %[[T0]]#7
+func.func @sparse_nop_call(%arg0: tensor<?xf64, #SparseVector>,
+ %arg1: tensor<?xf64, #SparseVector>) ->
+ (tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>) {
+ %1, %2 = call @sparse_nop_multi_ret(%arg0, %arg1) :
+ (tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>) ->
+ (tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
+ return %1, %2: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>
+}
+
//
-// CHECK-STORAGE-LABEL: func @sparse_nop_cast(
-// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<1xindex>,
-// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xi32>,
-// CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xf32>)
-// CHECK-STORAGE: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] : memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>
+// CHECK-LABEL: func @sparse_nop_cast(
+// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xf32>)
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] : memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>
func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
%0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor<?xf32, #SparseVector>
return %0 : tensor<?xf32, #SparseVector>
}
-// CHECK-CODEGEN-LABEL: func @sparse_nop_cast_3d(
-// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf32>>)
-// CHECK-CODEGEN: return %[[A]] : tuple<memref<3xindex>, memref<?xf32>>
//
-// CHECK-STORAGE-LABEL: func @sparse_nop_cast_3d(
-// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xf32>)
-// CHECK-STORAGE: return %[[A0]], %[[A1]] : memref<3xindex>, memref<?xf32>
+// CHECK-LABEL: func @sparse_nop_cast_3d(
+// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xf32>)
+// CHECK: return %[[A0]], %[[A1]] : memref<3xindex>, memref<?xf32>
func.func @sparse_nop_cast_3d(%arg0: tensor<10x20x30xf32, #Dense3D>) -> tensor<?x?x?xf32, #Dense3D> {
%0 = tensor.cast %arg0 : tensor<10x20x30xf32, #Dense3D> to tensor<?x?x?xf32, #Dense3D>
return %0 : tensor<?x?x?xf32, #Dense3D>
}
-// CHECK-CODEGEN-LABEL: func @sparse_dense_2d(
-// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xf64>>)
//
-// CHECK-STORAGE-LABEL: func @sparse_dense_2d(
-// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xf64>) {
-// CHECK-STORAGE: return
+// CHECK-LABEL: func @sparse_dense_2d(
+// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xf64>) {
+// CHECK: return
func.func @sparse_dense_2d(%arg0: tensor<?x?xf64, #Dense2D>) {
return
}
-// CHECK-CODEGEN-LABEL: func @sparse_row(
-// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
//
-// CHECK-STORAGE-LABEL: func @sparse_row(
-// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xi32>,
-// CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xf64>) {
-// CHECK-STORAGE: return
+// CHECK-LABEL: func @sparse_row(
+// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xf64>) {
+// CHECK: return
func.func @sparse_row(%arg0: tensor<?x?xf64, #Row>) {
return
}
-// CHECK-CODEGEN-LABEL: func @sparse_csr(
-// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
//
-// CHECK-STORAGE-LABEL: func @sparse_csr(
-// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xi32>,
-// CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xf64>) {
-// CHECK-STORAGE: return
+// CHECK-LABEL: func @sparse_csr(
+// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xf64>) {
+// CHECK: return
func.func @sparse_csr(%arg0: tensor<?x?xf64, #CSR>) {
return
}
-// CHECK-CODEGEN-LABEL: func @sparse_dcsr(
-// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
//
-// CHECK-STORAGE-LABEL: func @sparse_dcsr(
-// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xi32>,
-// CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-STORAGE-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-STORAGE-SAME: %[[A5:.*5]]: memref<?xf64>) {
-// CHECK-STORAGE: return
+// CHECK-LABEL: func @sparse_dcsr(
+// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>) {
+// CHECK: return
func.func @sparse_dcsr(%arg0: tensor<?x?xf64, #DCSR>) {
return
}
@@ -139,16 +156,12 @@ func.func @sparse_dcsr(%arg0: tensor<?x?xf64, #DCSR>) {
// Querying for dimension 1 in the tensor type can immediately
// fold using the original static dimension sizes.
//
-// CHECK-CODEGEN-LABEL: func @sparse_dense_3d(
-// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf64>>)
-// CHECK-CODEGEN: %[[C:.*]] = arith.constant 20 : index
-// CHECK-CODEGEN: return %[[C]] : index
//
-// CHECK-STORAGE-LABEL: func @sparse_dense_3d(
-// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xf64>)
-// CHECK-STORAGE: %[[C:.*]] = arith.constant 20 : index
-// CHECK-STORAGE: return %[[C]] : index
+// CHECK-LABEL: func @sparse_dense_3d(
+// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xf64>)
+// CHECK: %[[C:.*]] = arith.constant 20 : index
+// CHECK: return %[[C]] : index
func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
%c = arith.constant 1 : index
%0 = tensor.dim %arg0, %c : tensor<10x20x30xf64, #Dense3D>
@@ -160,103 +173,74 @@ func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
// into querying for dimension 2 in the stored sparse tensor scheme,
// since the latter honors the dimOrdering.
//
-// CHECK-CODEGEN-LABEL: func @sparse_dense_3d_dyn(
-// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf64>>)
-// CHECK-CODEGEN: %[[C:.*]] = arith.constant 2 : index
-// CHECK-CODEGEN: %[[F:.*]] = sparse_tensor.storage_get %[[A]][0] : tuple<memref<3xindex>, memref<?xf64>> to memref<3xindex>
-// CHECK-CODEGEN: %[[L:.*]] = memref.load %[[F]][%[[C]]] : memref<3xindex>
-// CHECK-CODEGEN: return %[[L]] : index
//
-// CHECK-STORAGE-LABEL: func @sparse_dense_3d_dyn(
-// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<3xindex>,
-// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xf64>)
-// CHECK-STORAGE: %[[C:.*]] = arith.constant 2 : index
-// CHECK-STORAGE: %[[L:.*]] = memref.load %[[A0]][%[[C]]] : memref<3xindex>
-// CHECK-STORAGE: return %[[L]] : index
+// CHECK-LABEL: func @sparse_dense_3d_dyn(
+// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xf64>)
+// CHECK: %[[C:.*]] = arith.constant 2 : index
+// CHECK: %[[L:.*]] = memref.load %[[A0]][%[[C]]] : memref<3xindex>
+// CHECK: return %[[L]] : index
func.func @sparse_dense_3d_dyn(%arg0: tensor<?x?x?xf64, #Dense3D>) -> index {
%c = arith.constant 1 : index
%0 = tensor.dim %arg0, %c : tensor<?x?x?xf64, #Dense3D>
return %0 : index
}
-// CHECK-CODEGEN-LABEL: func @sparse_pointers_dcsr(
-// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
-// CHECK-CODEGEN: %[[F:.*]] = sparse_tensor.storage_get %[[A]][3] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi32>
-// CHECK-CODEGEN: return %[[F]] : memref<?xi32>
//
-// CHECK-STORAGE-LABEL: func @sparse_pointers_dcsr(
-// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xi32>,
-// CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-STORAGE-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-STORAGE-SAME: %[[A5:.*5]]: memref<?xf64>)
-// CHECK-STORAGE: return %[[A3]] : memref<?xi32>
+// CHECK-LABEL: func @sparse_pointers_dcsr(
+// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>)
+// CHECK: return %[[A3]] : memref<?xi32>
func.func @sparse_pointers_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi32> {
%c = arith.constant 1 : index
%0 = sparse_tensor.pointers %arg0, %c : tensor<?x?xf64, #DCSR> to memref<?xi32>
return %0 : memref<?xi32>
}
-// CHECK-CODEGEN-LABEL: func @sparse_indices_dcsr(
-// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
-// CHECK-CODEGEN: %[[F:.*]] = sparse_tensor.storage_get %[[A]][4] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi64>
-// CHECK-CODEGEN: return %[[F]] : memref<?xi64>
//
-// CHECK-STORAGE-LABEL: func @sparse_indices_dcsr(
-// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xi32>,
-// CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-STORAGE-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-STORAGE-SAME: %[[A5:.*5]]: memref<?xf64>)
-// CHECK-STORAGE: return %[[A4]] : memref<?xi64>
+// CHECK-LABEL: func @sparse_indices_dcsr(
+// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>)
+// CHECK: return %[[A4]] : memref<?xi64>
func.func @sparse_indices_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi64> {
%c = arith.constant 1 : index
%0 = sparse_tensor.indices %arg0, %c : tensor<?x?xf64, #DCSR> to memref<?xi64>
return %0 : memref<?xi64>
}
-// CHECK-CODEGEN-LABEL: func @sparse_values_dcsr(
-// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
-// CHECK-CODEGEN: %[[F:.*]] = sparse_tensor.storage_get %[[A]][5] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xf64>
-// CHECK-CODEGEN: return %[[F]] : memref<?xf64>
//
-// CHECK-STORAGE-LABEL: func @sparse_values_dcsr(
-// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xi32>,
-// CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xi32>,
-// CHECK-STORAGE-SAME: %[[A4:.*4]]: memref<?xi64>,
-// CHECK-STORAGE-SAME: %[[A5:.*5]]: memref<?xf64>)
-// CHECK-STORAGE: return %[[A5]] : memref<?xf64>
+// CHECK-LABEL: func @sparse_values_dcsr(
+// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>)
+// CHECK: return %[[A5]] : memref<?xf64>
func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
%0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #DCSR> to memref<?xf64>
return %0 : memref<?xf64>
}
-// CHECK-CODEGEN-LABEL: func @sparse_dealloc_csr(
-// CHECK-CODEGEN-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
-// CHECK-CODEGEN: %[[F0:.*]] = sparse_tensor.storage_get %[[A]][0] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<2xindex>
-// CHECK-CODEGEN: memref.dealloc %[[F0]] : memref<2xindex>
-// CHECK-CODEGEN: %[[F1:.*]] = sparse_tensor.storage_get %[[A]][1] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi32>
-// CHECK-CODEGEN: memref.dealloc %[[F1]] : memref<?xi32>
-// CHECK-CODEGEN: %[[F2:.*]] = sparse_tensor.storage_get %[[A]][2] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi64>
-// CHECK-CODEGEN: memref.dealloc %[[F2]] : memref<?xi64>
-// CHECK-CODEGEN: %[[F3:.*]] = sparse_tensor.storage_get %[[A]][3] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xf64>
-// CHECK-CODEGEN: memref.dealloc %[[F3]] : memref<?xf64>
-// CHECK-CODEGEN: return
//
-// CHECK-STORAGE-LABEL: func @sparse_dealloc_csr(
-// CHECK-STORAGE-SAME: %[[A0:.*0]]: memref<2xindex>,
-// CHECK-STORAGE-SAME: %[[A1:.*1]]: memref<?xi32>,
-// CHECK-STORAGE-SAME: %[[A2:.*2]]: memref<?xi64>,
-// CHECK-STORAGE-SAME: %[[A3:.*3]]: memref<?xf64>) {
-// CHECK-STORAGE: memref.dealloc %[[A0]] : memref<2xindex>
-// CHECK-STORAGE: memref.dealloc %[[A1]] : memref<?xi32>
-// CHECK-STORAGE: memref.dealloc %[[A2]] : memref<?xi64>
-// CHECK-STORAGE: memref.dealloc %[[A3]] : memref<?xf64>
-// CHECK-STORAGE: return
+// CHECK-LABEL: func @sparse_dealloc_csr(
+// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xi32>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi64>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xf64>) {
+// CHECK: memref.dealloc %[[A0]] : memref<2xindex>
+// CHECK: memref.dealloc %[[A1]] : memref<?xi32>
+// CHECK: memref.dealloc %[[A2]] : memref<?xi64>
+// CHECK: memref.dealloc %[[A3]] : memref<?xf64>
+// CHECK: return
func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
bufferization.dealloc_tensor %arg0 : tensor<?x?xf64, #CSR>
return
@@ -264,8 +248,7 @@ func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
// CHECK-LABEL: func @sparse_alloc_csc(
// CHECK-SAME: %[[A:.*]]: index) ->
-// CHECK-CODEGEN-SAME: tuple<memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>>
-// CHECK-STORAGE-SAME: memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK-SAME: memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
@@ -278,9 +261,7 @@ func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
// CHECK: %[[T4:.*]] = memref.cast %[[T3]] : memref<1xindex> to memref<?xindex>
// CHECK: %[[T5:.*]] = memref.alloc() : memref<1xf64>
// CHECK: %[[T6:.*]] = memref.cast %[[T5]] : memref<1xf64> to memref<?xf64>
-// CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[T0]], %[[T2]], %[[T4]], %[[T6]])
-// CHECK-CODEGEN: return %[[T]]
-// CHECK-STORAGE: return %[[T0]], %[[T2]], %[[T4]], %[[T6]]
+// CHECK: return %[[T0]], %[[T2]], %[[T4]], %[[T6]]
func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
%0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC>
%1 = sparse_tensor.load %0 : tensor<10x?xf64, #CSC>
@@ -288,8 +269,7 @@ func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
}
// CHECK-LABEL: func @sparse_alloc_3d() ->
-// CHECK-CODEGEN-SAME: tuple<memref<3xindex>, memref<?xf64>>
-// CHECK-STORAGE-SAME: memref<3xindex>, memref<?xf64>
+// CHECK-SAME: memref<3xindex>, memref<?xf64>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
@@ -302,9 +282,7 @@ func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
// CHECK: memref.store %[[C20]], %[[A0]][%[[C2]]] : memref<3xindex>
// CHECK: %[[A:.*]] = memref.alloc() : memref<6000xf64>
// CHECK: %[[A1:.*]] = memref.cast %[[A]] : memref<6000xf64> to memref<?xf64>
-// CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[A0]], %[[A1]])
-// CHECK-CODEGEN: return %[[T]] : tuple<memref<3xindex>, memref<?xf64>>
-// CHECK-STORAGE: return %[[A0]], %[[A1]] : memref<3xindex>, memref<?xf64>
+// CHECK: return %[[A0]], %[[A1]] : memref<3xindex>, memref<?xf64>
func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> {
%0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D>
%1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D>
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index b9555e8861a25..ce495e0c7f227 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -442,63 +442,3 @@ func.func @invalid_concat_size_mismatch(%arg0: tensor<2x4xf64, #DC>,
tensor<4x4xf64, #DC> to tensor<9x4xf64, #DC>
return %0 : tensor<9x4xf64, #DC>
}
-
-// -----
-
-func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
- tuple<memref<?xf64>, memref<?xf64>> {
- // expected-error@+1{{The number of inputs is inconsistent with output}}
- %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
- : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xf64>, memref<?xf64>>
- return %0 : tuple<memref<?xf64>, memref<?xf64>>
-}
-
-// -----
-
-func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
- tuple<memref<?xi64>, memref<?xf64>, f64> {
- // expected-error@+1{{Type mismatch between}}
- %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
- : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xi64>, memref<?xf64>, f64>
- return %0 : tuple<memref<?xi64>, memref<?xf64>, f64>
-}
-
-// -----
-
-func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
- // expected-error@+1{{Out-of-bound access}}
- %0 = sparse_tensor.storage_get %arg0[3]
- : tuple<memref<?xf64>, memref<?xf64>, f64> to
- memref<?xf64>
- return %0 : memref<?xf64>
-}
-
-// -----
-
-func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
- // expected-error@+1{{Type mismatch}}
- %0 = sparse_tensor.storage_get %arg0[2]
- : tuple<memref<?xf64>, memref<?xf64>, f64> to
- memref<?xf64>
- return %0 : memref<?xf64>
-}
-
-// -----
-
-func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
- // expected-error@+1{{Out-of-bound access}}
- %0 = sparse_tensor.storage_set %arg0[3], %arg1
- : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
- tuple<memref<?xf64>, memref<?xf64>, f64>
- return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
-}
-
-// -----
-
-func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
- // expected-error@+1{{Type mismatch}}
- %0 = sparse_tensor.storage_set %arg0[2], %arg1
- : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
- tuple<memref<?xf64>, memref<?xf64>, f64>
- return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
-}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index c37b4e7b53ac8..5edc977de7c00 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -314,50 +314,3 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #SparseMatrix>,
tensor<4x4xf64, #SparseMatrix> to tensor<9x4xf64, #SparseMatrix>
return %0 : tensor<9x4xf64, #SparseMatrix>
}
-
-// -----
-
-
-// CHECK: func @sparse_storage_new(
-// CHECK-SAME: %[[A0:.*0]]: memref<?xf64>,
-// CHECK-SAME: %[[A1:.*1]]: memref<?xf64>,
-// CHECK-SAME: %[[A2:.*]]: f64
-// CHECK: %[[TMP_0:.*]] = sparse_tensor.storage(%[[A0]], %[[A1]], %[[A2]])
-// CHECK: return %[[TMP_0]] : tuple<memref<?xf64>, memref<?xf64>, f64>
-func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
- tuple<memref<?xf64>, memref<?xf64>, f64> {
- %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
- : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xf64>, memref<?xf64>, f64>
- return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
-}
-
-// -----
-
-// CHECK-LABEL: func @sparse_storage_get(
-// CHECK-SAME: %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>
-// CHECK: %[[TMP0:.*]] = sparse_tensor.storage_get %[[A0]][0] :
-// CHECK-SAME: tuple<memref<?xf64>, memref<?xf64>, f64>
-// CHECK-SAME: to memref<?xf64>
-// CHECK: return %[[TMP0]] : memref<?xf64>
-func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
- %0 = sparse_tensor.storage_get %arg0[0]
- : tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
- return %0 : memref<?xf64>
-}
-
-// -----
-
-// CHECK-LABEL: func @sparse_storage_set(
-// CHECK-SAME: %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>,
-// CHECK-SAME: %[[A1:.*]]: memref<?xf64>
-// CHECK: %[[TMP0:.*]] = sparse_tensor.storage_set %[[A0]][0], %[[A1]] :
-// CHECK-SAME: tuple<memref<?xf64>, memref<?xf64>, f64>,
-// CHECK-SAME: memref<?xf64>
-// CHECK-SAME: to tuple<memref<?xf64>, memref<?xf64>, f64>
-// CHECK: return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
-func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
- %0 = sparse_tensor.storage_set %arg0[0], %arg1
- : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
- tuple<memref<?xf64>, memref<?xf64>, f64>
- return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
-}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
deleted file mode 100644
index d2d4769353a3c..0000000000000
--- a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
+++ /dev/null
@@ -1,60 +0,0 @@
-// RUN: mlir-opt %s -sparse-tensor-storage-expansion -cse | FileCheck %s
-
-// CHECK-LABEL: func @sparse_storage_expand(
-// CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
-// CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
-// CHECK-SAME: %[[TMP_arg2:.*]]: f64
-// CHECK return %[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]]
-func.func @sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>)
- -> tuple<memref<?xf64>, memref<?xf64>, f64> {
- return %arg0 : tuple<memref<?xf64>, memref<?xf64>, f64>
-}
-
-// CHECK-LABEL: func @call_sparse_storage_expand(
-// CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
-// CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
-// CHECK-SAME: %[[TMP_arg2:.*]]: f64)
-// CHECK: %[[TMP_0:.*]]:3 = call @sparse_storage_expand(%[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]])
-// CHECK: return %[[TMP_0]]#0, %[[TMP_0]]#1, %[[TMP_0]]#2 : memref<?xf64>, memref<?xf64>, f64
-func.func @call_sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>)
- -> tuple<memref<?xf64>, memref<?xf64>, f64> {
- %1 = call @sparse_storage_expand(%arg0) : (tuple<memref<?xf64>, memref<?xf64>, f64>) ->
- tuple<memref<?xf64>, memref<?xf64>, f64>
- return %1 : tuple<memref<?xf64>, memref<?xf64>, f64>
-}
-
-// CHECK-LABEL: func @sparse_storage(
-// CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
-// CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
-// CHECK-SAME: %[[TMP_arg2:.*2]]: memref<?xf64>)
-// CHECK: return %[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]]
-func.func @sparse_storage(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: memref<?xf64>)
- -> tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>> {
- %1 = sparse_tensor.storage(%arg0, %arg1, %arg2) : memref<?xf64>, memref<?xf64>, memref<?xf64> to tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>>
- return %1 : tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>>
-}
-
-// CHECK-LABEL: func @sparse_storage_get(
-// CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
-// CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
-// CHECK-SAME: %[[TMP_arg2:.*]]: f64)
-// CHECK: return %[[TMP_arg0]] : memref<?xf64>
-func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
- %0 = sparse_tensor.storage_get %arg0[0]
- : tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
- return %0 : memref<?xf64>
-}
-
-// CHECK-LABEL: func @sparse_storage_set(
-// CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
-// CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
-// CHECK-SAME: %[[TMP_arg2:.*]]: f64,
-// CHECK-SAME: %[[TMP_arg3:.*]]: memref<?xf64>)
-// CHECK: return %[[TMP_arg3]], %[[TMP_arg1]], %[[TMP_arg2]] : memref<?xf64>, memref<?xf64>, f64
-func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>,
- %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
- %0 = sparse_tensor.storage_set %arg0[0], %arg1
- : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
- tuple<memref<?xf64>, memref<?xf64>, f64>
- return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
-}
More information about the Mlir-commits
mailing list