[Mlir-commits] [mlir] 083ddff - [mlir][sparse] introduce sparse_tensor::StorageSpecifierToLLVM pass
Peiming Liu
llvmlistbot at llvm.org
Thu Dec 22 14:45:20 PST 2022
Author: Peiming Liu
Date: 2022-12-22T22:45:15Z
New Revision: 083ddffe476b73f7c05b06c3f6f54ebfcbf34727
URL: https://github.com/llvm/llvm-project/commit/083ddffe476b73f7c05b06c3f6f54ebfcbf34727
DIFF: https://github.com/llvm/llvm-project/commit/083ddffe476b73f7c05b06c3f6f54ebfcbf34727.diff
LOG: [mlir][sparse] introduce sparse_tensor::StorageSpecifierToLLVM pass
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D140122
Added:
mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index e8396b6bb7716..fc4ab2cb792ff 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -52,9 +52,9 @@ def SparseTensor_StorageSpecifier : SparseTensor_Type<"StorageSpecifier"> {
let parameters = (ins SparseTensorEncodingAttr : $encoding);
let builders = [
+ TypeBuilder<(ins "SparseTensorEncodingAttr":$encoding)>,
TypeBuilderWithInferredContext<(ins "SparseTensorEncodingAttr":$encoding), [{
- assert(encoding && "sparse tensor encoding should not be null");
- return $_get(encoding.getContext(), encoding);
+ return get(encoding.getContext(), encoding);
}]>,
TypeBuilderWithInferredContext<(ins "Type":$type), [{
return get(getSparseTensorEncoding(type));
@@ -71,6 +71,10 @@ def SparseTensor_StorageSpecifier : SparseTensor_Type<"StorageSpecifier"> {
Type getFieldType(StorageSpecifierKind kind, std::optional<APInt> dim) const;
}];
+ // We skip the default builder that simply takes the input sparse tensor encoding
+ // attribute since we need to normalize the dimension level types and remove
+ // fields that are irrelevant to the sparse tensor storage scheme.
+ let skipDefaultBuilders = 1;
let assemblyFormat="`<` qualified($encoding) `>`";
}
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index d4e3210a6a342..df465e443d36a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -158,6 +158,19 @@ std::unique_ptr<Pass>
createPostSparsificationRewritePass(bool enableRT, bool enableForeach = true,
bool enableConvert = true);
+//===----------------------------------------------------------------------===//
+// The SparseStorageSpecifierToLLVM pass.
+//===----------------------------------------------------------------------===//
+
+class StorageSpecifierToLLVMTypeConverter : public TypeConverter {
+public:
+ StorageSpecifierToLLVMTypeConverter();
+};
+
+void populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
+ RewritePatternSet &patterns);
+std::unique_ptr<Pass> createStorageSpecifierToLLVMPass();
+
//===----------------------------------------------------------------------===//
// Other rewriting rules and passes.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 4c827c1432056..6ec9d4eb9e46e 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -301,4 +301,28 @@ def SparseVectorization : Pass<"sparse-vectorization", "ModuleOp"> {
];
}
+def StorageSpecifierToLLVM : Pass<"sparse-storage-specifier-to-llvm", "ModuleOp"> {
+ let summary = "Lower sparse storage specifer to llvm structure";
+ let description = [{
+ This pass rewrites sparse tensor storage specifier-related operations into
+ LLVMDialect, and converts sparse tensor storage specifier into an llvm.struct.
+
+ Example of the conversion:
+ ```mlir
+ Before:
+ %0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 0
+ : !sparse_tensor.storage_specifier<#CSR> to i64
+
+ After:
+ %0 = llvm.extractvalue %arg0[0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+ ```
+ }];
+ let constructor = "mlir::createStorageSpecifierToLLVMPass()";
+ let dependentDialects = [
+ "arith::ArithDialect",
+ "LLVM::LLVMDialect",
+ "sparse_tensor::SparseTensorDialect",
+ ];
+}
+
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index e4e54b7143319..1e9aab8deb172 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -323,6 +323,28 @@ uint64_t mlir::sparse_tensor::toStoredDim(RankedTensorType type, uint64_t d) {
// SparseTensorDialect Types.
//===----------------------------------------------------------------------===//
+/// We normalize the sparse tensor encoding attribute by always using
+/// ordered/unique DLT such that "compressed-nu-no" and "compressed-nu" (as well
+/// as other variants) lead to the same storage specifier type, and by stripping
+/// irrelevant fields that do not alter the sparse tensor memory layout.
+static SparseTensorEncodingAttr
+getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
+ SmallVector<DimLevelType> dlts;
+ for (auto dlt : enc.getDimLevelType())
+ dlts.push_back(*getDimLevelType(*getLevelFormat(dlt), true, true));
+
+ return SparseTensorEncodingAttr::get(
+ enc.getContext(), dlts,
+ AffineMap(), // dimOrdering (irrelevant to storage specifier)
+ AffineMap(), // highLvlOrdering (irrelevant to storage specifier)
+ enc.getPointerBitWidth(), enc.getIndexBitWidth());
+}
+
+StorageSpecifierType
+StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
+ return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
+}
+
IntegerType StorageSpecifierType::getSizesType() const {
unsigned idxBitWidth =
getEncoding().getIndexBitWidth() ? getEncoding().getIndexBitWidth() : 64u;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index cfa73bf246ecc..410bf343b8fc1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -3,10 +3,12 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
CodegenEnv.cpp
CodegenUtils.cpp
SparseBufferRewriting.cpp
+ SparseStorageSpecifierToLLVM.cpp
SparseTensorCodegen.cpp
SparseTensorConversion.cpp
SparseTensorPasses.cpp
SparseTensorRewriting.cpp
+ SparseTensorStorageLayout.cpp
SparseVectorization.cpp
Sparsification.cpp
SparsificationAndBufferizationPass.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
new file mode 100644
index 0000000000000..d6a9007baad0c
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
@@ -0,0 +1,184 @@
+//===- SparseStorageSpecifierToLLVM.cpp - convert specifier to llvm -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodegenUtils.h"
+#include "SparseTensorStorageLayout.h"
+
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace sparse_tensor;
+
+static SmallVector<Type, 2> getSpecifierFields(StorageSpecifierType tp) {
+ MLIRContext *ctx = tp.getContext();
+ auto enc = tp.getEncoding();
+ unsigned rank = enc.getDimLevelType().size();
+
+ SmallVector<Type, 2> result;
+ auto indexType = tp.getSizesType();
+ auto dimSizes = LLVM::LLVMArrayType::get(ctx, indexType, rank);
+ auto memSizes = LLVM::LLVMArrayType::get(ctx, indexType,
+ getNumDataFieldsFromEncoding(enc));
+ result.push_back(dimSizes);
+ result.push_back(memSizes);
+ return result;
+}
+
+static Type convertSpecifier(StorageSpecifierType tp) {
+ return LLVM::LLVMStructType::getLiteral(tp.getContext(),
+ getSpecifierFields(tp));
+}
+
+StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() {
+ addConversion([](Type type) { return type; });
+ addConversion([](StorageSpecifierType tp) { return convertSpecifier(tp); });
+}
+
+constexpr uint64_t kDimSizePosInSpecifier = 0;
+constexpr uint64_t kMemSizePosInSpecifier = 1;
+
+class SpecifierStructBuilder : public StructBuilder {
+public:
+ explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) {
+ assert(value);
+ }
+
+ // Undef value for dimension sizes, all zero value for memory sizes.
+ static Value getInitValue(OpBuilder &builder, Location loc, Type structType);
+
+ Value dimSize(OpBuilder &builder, Location loc, unsigned dim);
+ void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value size);
+
+ Value memSize(OpBuilder &builder, Location loc, unsigned pos);
+ void setMemSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
+};
+
+Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
+ Type structType) {
+ Value metaData = builder.create<LLVM::UndefOp>(loc, structType);
+ SpecifierStructBuilder md(metaData);
+ auto memSizeArrayType = structType.cast<LLVM::LLVMStructType>()
+ .getBody()[kMemSizePosInSpecifier]
+ .cast<LLVM::LLVMArrayType>();
+
+ Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
+ // Fill memSizes array with zero.
+ for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
+ md.setMemSize(builder, loc, i, zero);
+
+ return md;
+}
+
+/// Builds IR extracting the dim-th dimension size from the descriptor.
+Value SpecifierStructBuilder::dimSize(OpBuilder &builder, Location loc,
+ unsigned dim) {
+ return builder.create<LLVM::ExtractValueOp>(
+ loc, value, ArrayRef<int64_t>({kDimSizePosInSpecifier, dim}));
+}
+
+/// Builds IR inserting the dim-th dimension size into the descriptor.
+void SpecifierStructBuilder::setDimSize(OpBuilder &builder, Location loc,
+ unsigned dim, Value size) {
+ value = builder.create<LLVM::InsertValueOp>(
+ loc, value, size, ArrayRef<int64_t>({kDimSizePosInSpecifier, dim}));
+}
+
+/// Builds IR extracting the pos-th memory size from the descriptor.
+Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
+ unsigned pos) {
+ return builder.create<LLVM::ExtractValueOp>(
+ loc, value, ArrayRef<int64_t>({kMemSizePosInSpecifier, pos}));
+}
+
+/// Builds IR inserting the pos-th memory size into the descriptor.
+void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
+ unsigned pos, Value size) {
+ value = builder.create<LLVM::InsertValueOp>(
+ loc, value, size, ArrayRef<int64_t>({kMemSizePosInSpecifier, pos}));
+}
+
+template <typename Base, typename SourceOp>
+class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
+public:
+ using OpAdaptor = typename SourceOp::Adaptor;
+ using OpConversionPattern<SourceOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SpecifierStructBuilder spec(adaptor.getSpecifier());
+ Value v;
+ if (op.getSpecifierKind() == StorageSpecifierKind::DimSize) {
+ v = Base::onDimSize(rewriter, op, spec,
+ op.getDim().value().getZExtValue());
+ } else {
+ auto enc = op.getSpecifier().getType().getEncoding();
+ builder::StorageLayout layout(enc);
+ Optional<unsigned> dim = std::nullopt;
+ if (op.getDim())
+ dim = op.getDim().value().getZExtValue();
+ unsigned idx = layout.getMemRefFieldIndex(op.getSpecifierKind(), dim);
+ v = Base::onMemSize(rewriter, op, spec, idx);
+ }
+
+ rewriter.replaceOp(op, v);
+ return success();
+ }
+};
+
+struct StorageSpecifierSetOpConverter
+ : public SpecifierGetterSetterOpConverter<StorageSpecifierSetOpConverter,
+ SetStorageSpecifierOp> {
+ using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
+ static Value onDimSize(OpBuilder &builder, SetStorageSpecifierOp op,
+ SpecifierStructBuilder &spec, unsigned d) {
+ spec.setDimSize(builder, op.getLoc(), d, op.getValue());
+ return spec;
+ }
+
+ static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op,
+ SpecifierStructBuilder &spec, unsigned i) {
+ spec.setMemSize(builder, op.getLoc(), i, op.getValue());
+ return spec;
+ }
+};
+
+struct StorageSpecifierGetOpConverter
+ : public SpecifierGetterSetterOpConverter<StorageSpecifierGetOpConverter,
+ GetStorageSpecifierOp> {
+ using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
+ static Value onDimSize(OpBuilder &builder, GetStorageSpecifierOp op,
+ SpecifierStructBuilder &spec, unsigned d) {
+ return spec.dimSize(builder, op.getLoc(), d);
+ }
+ static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op,
+ SpecifierStructBuilder &spec, unsigned i) {
+ return spec.memSize(builder, op.getLoc(), i);
+ }
+};
+
+struct StorageSpecifierInitOpConverter
+ : public OpConversionPattern<StorageSpecifierInitOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
+ rewriter.replaceOp(op, SpecifierStructBuilder::getInitValue(
+ rewriter, op.getLoc(), llvmType));
+ return success();
+ }
+};
+
+void mlir::populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
+ RewritePatternSet &patterns) {
+ patterns.add<StorageSpecifierGetOpConverter, StorageSpecifierSetOpConverter,
+ StorageSpecifierInitOpConverter>(converter,
+ patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 6845aa0d81877..c384d6ad966c9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -28,6 +28,7 @@ namespace mlir {
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
#define GEN_PASS_DEF_SPARSEVECTORIZATION
+#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir
@@ -193,9 +194,14 @@ struct SparseTensorCodegenPass
target.addLegalOp<SortOp>();
target.addLegalOp<SortCooOp>();
target.addLegalOp<PushBackOp>();
- // All dynamic rules below accept new function, call, return, and various
- // tensor and bufferization operations as legal output of the rewriting
- // provided that all sparse tensor types have been fully rewritten.
+ // Storage specifier outlives sparse tensor pipeline.
+ target.addLegalOp<GetStorageSpecifierOp>();
+ target.addLegalOp<SetStorageSpecifierOp>();
+ target.addLegalOp<StorageSpecifierInitOp>();
+ // All dynamic rules below accept new function, call, return, and
+ // various tensor and bufferization operations as legal output of the
+ // rewriting provided that all sparse tensor types have been fully
+ // rewritten.
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return converter.isSignatureLegal(op.getFunctionType());
});
@@ -271,6 +277,44 @@ struct SparseVectorizationPass
}
};
+struct StorageSpecifierToLLVMPass
+ : public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
+
+ StorageSpecifierToLLVMPass() = default;
+
+ void runOnOperation() override {
+ auto *ctx = &getContext();
+ ConversionTarget target(*ctx);
+ RewritePatternSet patterns(ctx);
+ StorageSpecifierToLLVMTypeConverter converter;
+
+ // All ops in the sparse dialect must go!
+ target.addIllegalDialect<SparseTensorDialect>();
+ target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+ return converter.isSignatureLegal(op.getFunctionType());
+ });
+ target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
+ return converter.isSignatureLegal(op.getCalleeType());
+ });
+ target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
+ return converter.isLegal(op.getOperandTypes());
+ });
+ target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();
+
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+ converter);
+ populateCallOpTypeConversionPattern(patterns, converter);
+ populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
+ populateReturnOpTypeConversionPattern(patterns, converter);
+ scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+ target);
+ populateStorageSpecifierToLLVMPatterns(converter, patterns);
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -355,3 +399,7 @@ mlir::createSparseVectorizationPass(unsigned vectorLength,
return std::make_unique<SparseVectorizationPass>(
vectorLength, enableVLAVectorization, enableSIMDIndex32);
}
+
+std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
+ return std::make_unique<StorageSpecifierToLLVMPass>();
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
new file mode 100644
index 0000000000000..4dcd0345ed9db
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
@@ -0,0 +1,188 @@
+//===- SparseTensorStorageLayout.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "SparseTensorStorageLayout.h"
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace sparse_tensor;
+
+static Value createIndexCast(OpBuilder &builder, Location loc, Value value,
+ Type to) {
+ if (value.getType() != to)
+ return builder.create<arith::IndexCastOp>(loc, to, value);
+ return value;
+}
+
+static IntegerAttr fromOptionalInt(MLIRContext *ctx, Optional<unsigned> dim) {
+ if (!dim)
+ return nullptr;
+ return IntegerAttr::get(IndexType::get(ctx), dim.value());
+}
+
+unsigned
+builder::StorageLayout::getMemRefFieldIndex(SparseTensorFieldKind kind,
+ Optional<unsigned> dim) const {
+ unsigned fieldIdx = -1u;
+ foreachFieldInSparseTensor(
+ enc,
+ [dim, kind, &fieldIdx](unsigned fIdx, SparseTensorFieldKind fKind,
+ unsigned fDim, DimLevelType dlt) -> bool {
+ if ((dim && fDim == dim.value() && kind == fKind) ||
+ (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
+ fieldIdx = fIdx;
+ // Returns false to break the iteration.
+ return false;
+ }
+ return true;
+ });
+ assert(fieldIdx != -1u);
+ return fieldIdx;
+}
+
+unsigned
+builder::StorageLayout::getMemRefFieldIndex(StorageSpecifierKind kind,
+ Optional<unsigned> dim) const {
+ return getMemRefFieldIndex(toFieldKind(kind), dim);
+}
+
+Value builder::SparseTensorSpecifier::getInitValue(OpBuilder &builder,
+ Location loc,
+ RankedTensorType rtp) {
+ return builder.create<StorageSpecifierInitOp>(
+ loc, StorageSpecifierType::get(getSparseTensorEncoding(rtp)));
+}
+
+Value builder::SparseTensorSpecifier::getSpecifierField(
+ OpBuilder &builder, Location loc, StorageSpecifierKind kind,
+ Optional<unsigned> dim) {
+ return createIndexCast(builder, loc,
+ builder.create<GetStorageSpecifierOp>(
+ loc, getFieldType(kind, dim), specifier, kind,
+ fromOptionalInt(specifier.getContext(), dim)),
+ builder.getIndexType());
+}
+
+void builder::SparseTensorSpecifier::setSpecifierField(
+ OpBuilder &builder, Location loc, Value v, StorageSpecifierKind kind,
+ Optional<unsigned> dim) {
+ specifier = builder.create<SetStorageSpecifierOp>(
+ loc, specifier, kind, fromOptionalInt(specifier.getContext(), dim),
+ createIndexCast(builder, loc, v, getFieldType(kind, dim)));
+}
+
+constexpr uint64_t kDataFieldStartingIdx = 0;
+
+void sparse_tensor::builder::foreachFieldInSparseTensor(
+ const SparseTensorEncodingAttr enc,
+ llvm::function_ref<bool(unsigned, SparseTensorFieldKind, unsigned,
+ DimLevelType)>
+ callback) {
+ assert(enc);
+
+#define RETURN_ON_FALSE(idx, kind, dim, dlt) \
+ if (!(callback(idx, kind, dim, dlt))) \
+ return;
+
+ static_assert(kDataFieldStartingIdx == 0);
+ unsigned fieldIdx = kDataFieldStartingIdx;
+ // Per-dimension storage.
+ for (unsigned r = 0, rank = enc.getDimLevelType().size(); r < rank; r++) {
+ // Dimension level types apply in order to the reordered dimension.
+ // As a result, the compound type can be constructed directly in the given
+ // order.
+ auto dlt = getDimLevelType(enc, r);
+ if (isCompressedDLT(dlt)) {
+ RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt);
+ RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
+ } else if (isSingletonDLT(dlt)) {
+ RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
+ } else {
+ assert(isDenseDLT(dlt)); // no fields
+ }
+ }
+ // The values array.
+ RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u,
+ DimLevelType::Undef);
+
+ // Put metadata at the end.
+ RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::StorageSpec, -1u,
+ DimLevelType::Undef);
+
+#undef RETURN_ON_FALSE
+}
+
+void sparse_tensor::builder::foreachFieldAndTypeInSparseTensor(
+ RankedTensorType rType,
+ llvm::function_ref<bool(Type, unsigned, SparseTensorFieldKind, unsigned,
+ DimLevelType)>
+ callback) {
+ auto enc = getSparseTensorEncoding(rType);
+ assert(enc);
+ // Construct the basic types.
+ Type idxType = enc.getIndexType();
+ Type ptrType = enc.getPointerType();
+ Type eltType = rType.getElementType();
+
+ Type metaDataType = StorageSpecifierType::get(enc);
+ // memref<? x ptr> pointers
+ Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType);
+ // memref<? x idx> indices
+ Type idxMemType = MemRefType::get({ShapedType::kDynamic}, idxType);
+ // memref<? x eltType> values
+ Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
+
+ foreachFieldInSparseTensor(
+ enc,
+ [metaDataType, ptrMemType, idxMemType, valMemType,
+ callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind,
+ unsigned dim, DimLevelType dlt) -> bool {
+ switch (fieldKind) {
+ case SparseTensorFieldKind::StorageSpec:
+ return callback(metaDataType, fieldIdx, fieldKind, dim, dlt);
+ case SparseTensorFieldKind::PtrMemRef:
+ return callback(ptrMemType, fieldIdx, fieldKind, dim, dlt);
+ case SparseTensorFieldKind::IdxMemRef:
+ return callback(idxMemType, fieldIdx, fieldKind, dim, dlt);
+ case SparseTensorFieldKind::ValMemRef:
+ return callback(valMemType, fieldIdx, fieldKind, dim, dlt);
+ };
+ llvm_unreachable("unrecognized field kind");
+ });
+}
+
+unsigned
+sparse_tensor::builder::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
+ unsigned numFields = 0;
+ foreachFieldInSparseTensor(enc,
+ [&numFields](unsigned, SparseTensorFieldKind,
+ unsigned, DimLevelType) -> bool {
+ numFields++;
+ return true;
+ });
+ return numFields;
+}
+
+unsigned sparse_tensor::builder::getNumDataFieldsFromEncoding(
+ SparseTensorEncodingAttr enc) {
+ unsigned numFields = 0; // one value memref
+ foreachFieldInSparseTensor(enc,
+ [&numFields](unsigned fidx, SparseTensorFieldKind,
+ unsigned, DimLevelType) -> bool {
+ if (fidx >= kDataFieldStartingIdx)
+ numFields++;
+ return true;
+ });
+ numFields -= 1; // the last field is MetaData field
+ assert(numFields ==
+ builder::getNumFieldsFromEncoding(enc) - kDataFieldStartingIdx - 1);
+ return numFields;
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
new file mode 100644
index 0000000000000..9b4e2352b8f34
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -0,0 +1,361 @@
+//===- SparseTensorStorageLayout.h ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines utilities for lowering and accessing sparse tensor
+// types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_
+#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_
+
+#include "mlir/Conversion/LLVMCommon/StructBuilder.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace sparse_tensor {
+
+// FIXME: this is a tmp namespace
+namespace builder {
+//===----------------------------------------------------------------------===//
+// SparseTensorDescriptor and helpers, manage the sparse tensor memory layout
+// scheme.
+//
+// Sparse tensor storage scheme for rank-dimensional tensor is organized
+// as a single compound type with the following fields. Note that every
+// memref with ? size actually behaves as a "vector", i.e. the stored
+// size is the capacity and the used size resides in the memSizes array.
+//
+// struct {
+// ; per-dimension d:
+// ; if dense:
+// <nothing>
+// ; if compressed:
+// memref<? x ptr> pointers-d ; pointers for sparse dim d
+// memref<? x idx> indices-d ; indices for sparse dim d
+// ; if singleton:
+// memref<? x idx> indices-d ; indices for singleton dim d
+// memref<? x eltType> values ; values
+//
+// ; sparse tensor metadata
+// struct {
+// array<rank x int> dimSizes ; sizes for each dimension
+// array<n x int> memSizes; ; sizes for each data memref
+// }
+// };
+//
+//===----------------------------------------------------------------------===//
+enum class SparseTensorFieldKind : uint32_t {
+ StorageSpec = 0,
+ PtrMemRef = 1,
+ IdxMemRef = 2,
+ ValMemRef = 3
+};
+
+static_assert(static_cast<uint32_t>(SparseTensorFieldKind::PtrMemRef) ==
+ static_cast<uint32_t>(StorageSpecifierKind::PtrMemSize));
+static_assert(static_cast<uint32_t>(SparseTensorFieldKind::IdxMemRef) ==
+ static_cast<uint32_t>(StorageSpecifierKind::IdxMemSize));
+static_assert(static_cast<uint32_t>(SparseTensorFieldKind::ValMemRef) ==
+ static_cast<uint32_t>(StorageSpecifierKind::ValMemSize));
+
+/// For each field that will be allocated for the given sparse tensor encoding,
+/// calls the callback with the corresponding field index, field kind, dimension
+/// (for sparse tensor level memrefs) and dimlevelType.
+/// The field index always starts with zero and increments by one between two
+/// callback invocations.
+/// Ideally, all other methods should rely on this function to query the
+/// fields of a sparse tensor instead of relying on ad-hoc index computation.
+void foreachFieldInSparseTensor(
+ SparseTensorEncodingAttr,
+ llvm::function_ref<bool(unsigned /*fieldIdx*/,
+ SparseTensorFieldKind /*fieldKind*/,
+ unsigned /*dim (if applicable)*/,
+ DimLevelType /*DLT (if applicable)*/)>);
+
+/// Same as above, except that it also builds the Type for the corresponding
+/// field.
+void foreachFieldAndTypeInSparseTensor(
+ RankedTensorType,
+ llvm::function_ref<bool(Type /*fieldType*/, unsigned /*fieldIdx*/,
+ SparseTensorFieldKind /*fieldKind*/,
+ unsigned /*dim (if applicable)*/,
+ DimLevelType /*DLT (if applicable)*/)>);
+
+/// Gets the total number of fields for the given sparse tensor encoding.
+unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc);
+
+/// Gets the total number of data fields (index arrays, pointer arrays, and a
+/// value array) for the given sparse tensor encoding.
+unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc);
+
+inline StorageSpecifierKind toSpecifierKind(SparseTensorFieldKind kind) {
+ assert(kind != SparseTensorFieldKind::StorageSpec);
+ return static_cast<StorageSpecifierKind>(kind);
+}
+
+inline SparseTensorFieldKind toFieldKind(StorageSpecifierKind kind) {
+ assert(kind != StorageSpecifierKind::DimSize);
+ return static_cast<SparseTensorFieldKind>(kind);
+}
+
+class StorageLayout {
+public:
+ explicit StorageLayout(SparseTensorEncodingAttr enc) : enc(enc) {}
+
+ ///
+ /// Getters: get the field index for required field.
+ ///
+ unsigned getMemRefFieldIndex(SparseTensorFieldKind kind,
+ Optional<unsigned> dim) const;
+
+ unsigned getMemRefFieldIndex(StorageSpecifierKind kind,
+ Optional<unsigned> dim) const;
+
+private:
+ unsigned getFieldIndex(unsigned dim, SparseTensorFieldKind kind) const;
+ SparseTensorEncodingAttr enc;
+};
+
+class SparseTensorSpecifier {
+public:
+ explicit SparseTensorSpecifier(Value specifier) : specifier(specifier) {}
+
+ // Undef value for dimension sizes, all zero value for memory sizes.
+ static Value getInitValue(OpBuilder &builder, Location loc,
+ RankedTensorType rtp);
+
+ /*implicit*/ operator Value() { return specifier; }
+
+ Value getSpecifierField(OpBuilder &builder, Location loc,
+ StorageSpecifierKind kind, Optional<unsigned> dim);
+
+ void setSpecifierField(OpBuilder &builder, Location loc, Value v,
+ StorageSpecifierKind kind, Optional<unsigned> dim);
+
+ Type getFieldType(StorageSpecifierKind kind, Optional<unsigned> dim) {
+ return specifier.getType().getFieldType(kind, dim);
+ }
+
+private:
+ TypedValue<StorageSpecifierType> specifier;
+};
+
+/// A helper class around an array of values that corresponds to a sparse
+/// tensor; it provides a set of meaningful APIs to query and update a
+/// particular field in a consistent way.
+/// Users should not make assumptions about how a sparse tensor is laid out but
+/// should instead rely on this class to access the right value for each field.
+template <bool mut>
+class SparseTensorDescriptorImpl {
+private:
+ // Uses ValueRange for immutable descriptors; uses SmallVectorImpl<Value> &
+ // for mutable descriptors.
+ // Using SmallVector for mutable descriptors allows users to reuse it as a
+ // temporary buffer to append values in some special cases, though users are
+ // responsible for restoring the buffer to a legal state after use. It is
+ // probably not a clean way, but it is the most efficient way to avoid copying
+ // the fields into another SmallVector. If a cleaner way is wanted, we
+ // should change it to MutableArrayRef instead.
+ using ValueArrayRef = typename std::conditional<mut, SmallVectorImpl<Value> &,
+ ValueRange>::type;
+
+public:
+ SparseTensorDescriptorImpl(Type tp, ValueArrayRef fields)
+ : rType(tp.cast<RankedTensorType>()), fields(fields) {
+ assert(getSparseTensorEncoding(tp) &&
+ builder::getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) ==
+ fields.size());
+ // We should make sure the class is trivially copyable (and should be small
+ // enough) such that we can pass it by value.
+ static_assert(
+ std::is_trivially_copyable_v<SparseTensorDescriptorImpl<mut>>);
+ }
+
+ // Implicit (and cheap) type conversion from MutSparseTensorDescriptor to
+ // SparseTensorDescriptor.
+ template <typename T = SparseTensorDescriptorImpl<true>>
+ /*implicit*/ SparseTensorDescriptorImpl(std::enable_if_t<!mut, T> &mDesc)
+ : rType(mDesc.getTensorType()), fields(mDesc.getFields()) {}
+
+ unsigned getMemRefFieldIndex(SparseTensorFieldKind kind,
+ Optional<unsigned> dim) const {
+ // Delegates to storage layout.
+ StorageLayout layout(getSparseTensorEncoding(rType));
+ return layout.getMemRefFieldIndex(kind, dim);
+ }
+
+ unsigned getPtrMemRefIndex(unsigned ptrDim) const {
+ return getMemRefFieldIndex(SparseTensorFieldKind::PtrMemRef, ptrDim);
+ }
+
+ unsigned getIdxMemRefIndex(unsigned idxDim) const {
+ return getMemRefFieldIndex(SparseTensorFieldKind::IdxMemRef, idxDim);
+ }
+
+ unsigned getValMemRefIndex() const {
+ return getMemRefFieldIndex(SparseTensorFieldKind::ValMemRef, std::nullopt);
+ }
+
+ unsigned getNumFields() const { return fields.size(); }
+
+ ///
+ /// Getters: get the value for required field.
+ ///
+
+ Value getSpecifierField(OpBuilder &builder, Location loc,
+ StorageSpecifierKind kind,
+ Optional<unsigned> dim) const {
+ SparseTensorSpecifier md(fields.back());
+ return md.getSpecifierField(builder, loc, kind, dim);
+ }
+
+ Value getDimSize(OpBuilder &builder, Location loc, unsigned dim) const {
+ return getSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim);
+ }
+
+ Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
+ return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize,
+ dim);
+ }
+
+ Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
+ return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize,
+ dim);
+ }
+
+ Value getValMemSize(OpBuilder &builder, Location loc) const {
+ return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
+ std::nullopt);
+ }
+
+ Value getPtrMemRef(unsigned ptrDim) const {
+ return getMemRefField(SparseTensorFieldKind::PtrMemRef, ptrDim);
+ }
+
+ Value getIdxMemRef(unsigned idxDim) const {
+ return getMemRefField(SparseTensorFieldKind::IdxMemRef, idxDim);
+ }
+
+ Value getValMemRef() const {
+ return getMemRefField(SparseTensorFieldKind::ValMemRef, std::nullopt);
+ }
+
+ Value getMemRefField(SparseTensorFieldKind kind,
+ Optional<unsigned> dim) const {
+ return fields[getMemRefFieldIndex(kind, dim)];
+ }
+
+ Value getMemRefField(unsigned fidx) const {
+ assert(fidx < fields.size() - 1);
+ return fields[fidx];
+ }
+
+ Value getField(unsigned fidx) const {
+ assert(fidx < fields.size());
+ return fields[fidx];
+ }
+
+ ///
+ /// Setters: update the value for required field (only enabled for
+ /// MutSparseTensorDescriptor).
+ ///
+
+ template <typename T = Value>
+ void setMemRefField(SparseTensorFieldKind kind, Optional<unsigned> dim,
+ std::enable_if_t<mut, T> v) {
+ fields[getMemRefFieldIndex(kind, dim)] = v;
+ }
+
+ template <typename T = Value>
+ void setMemRefField(unsigned fidx, std::enable_if_t<mut, T> v) {
+ assert(fidx < fields.size() - 1);
+ fields[fidx] = v;
+ }
+
+ template <typename T = Value>
+ void setField(unsigned fidx, std::enable_if_t<mut, T> v) {
+ assert(fidx < fields.size());
+ fields[fidx] = v;
+ }
+
+ template <typename T = Value>
+ void setSpecifierField(OpBuilder &builder, Location loc,
+ StorageSpecifierKind kind, Optional<unsigned> dim,
+ std::enable_if_t<mut, T> v) {
+ SparseTensorSpecifier md(fields.back());
+ md.setSpecifierField(builder, loc, v, kind, dim);
+ fields.back() = md;
+ }
+
+ template <typename T = Value>
+ void setDimSize(OpBuilder &builder, Location loc, unsigned dim,
+ std::enable_if_t<mut, T> v) {
+ setSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim, v);
+ }
+
+ ValueRange getMemRefFields() const {
+ ValueRange ret = fields;
+ // drop the last metadata fields
+ return ret.slice(0, fields.size() - 1);
+ }
+
+ Type getMemRefElementType(SparseTensorFieldKind kind,
+ Optional<unsigned> dim) const {
+ return getMemRefField(kind, dim)
+ .getType()
+ .template cast<MemRefType>()
+ .getElementType();
+ }
+
+ RankedTensorType getTensorType() const { return rType; }
+ ValueArrayRef getFields() const { return fields; }
+
+private:
+ RankedTensorType rType;
+ ValueArrayRef fields;
+};
+
+using SparseTensorDescriptor = SparseTensorDescriptorImpl<false>; // Immutable.
+using MutSparseTensorDescriptor = SparseTensorDescriptorImpl<true>; // Mutable.
+
+/// Returns the UnrealizedConversionCastOp ("tuple") that defines the adapted tensor.
+inline UnrealizedConversionCastOp getTuple(Value tensor) {
+ return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
+}
+
+/// Packs the given values as a "tuple": result 0 of a new UnrealizedConversionCastOp to `tp`.
+inline Value genTuple(OpBuilder &builder, Location loc, Type tp,
+ ValueRange values) {
+ return builder.create<UnrealizedConversionCastOp>(loc, TypeRange(tp), values)
+ .getResult(0);
+}
+
+inline Value genTuple(OpBuilder &builder, Location loc,
+ SparseTensorDescriptor desc) { // Packs desc's fields as desc's tensor type.
+ return genTuple(builder, loc, desc.getTensorType(), desc.getFields());
+}
+
+inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
+ auto tuple = getTuple(tensor); // Cast op whose inputs become the fields.
+ return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs());
+}
+
+inline MutSparseTensorDescriptor
+getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
+ auto tuple = getTuple(tensor); // Cast op whose inputs become the fields.
+ fields.assign(tuple.getInputs().begin(), tuple.getInputs().end()); // Overwrites the caller's buffer.
+ return MutSparseTensorDescriptor(tuple.getResultTypes()[0], fields); // Holds a reference to `fields`.
+}
+
+} // namespace builder
+} // namespace sparse_tensor
+} // namespace mlir
+#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 8d01a322222a7..e199d90903cdc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -147,6 +147,7 @@ class SparsificationAndBufferizationPass
} else {
pm.addPass(createSparseTensorCodegenPass(enableBufferInitialization));
pm.addPass(createSparseBufferRewritePass(enableBufferInitialization));
+ pm.addPass(createStorageSpecifierToLLVMPass());
}
if (failed(runPipeline(pm, getOperation())))
return signalPassFailure();
diff --git a/mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir b/mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir
new file mode 100644
index 0000000000000..ecdaf3bf9c964
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -sparse-storage-specifier-to-llvm --cse --canonicalize | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+
+// CHECK-LABEL: func.func @sparse_metadata_init() -> !llvm.struct<(array<2 x i64>, array<3 x i64>)> {
+// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i64
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.undef : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_1]][1, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][1, 1] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1, 2] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK: return %[[VAL_4]] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK: }
+func.func @sparse_metadata_init() -> !sparse_tensor.storage_specifier<#CSR> { // Exercises storage_specifier.init.
+ %0 = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier<#CSR>
+ return %0 : !sparse_tensor.storage_specifier<#CSR>
+}
+
+// CHECK-LABEL: func.func @sparse_get_md(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>) -> i64 {
+// CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK: return %[[VAL_1]] : i64
+func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#CSR>) -> i64 { // Exercises storage_specifier.get.
+ %0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 0
+ : !sparse_tensor.storage_specifier<#CSR> to i64
+ return %0 : i64
+}
+
+// CHECK-LABEL: func.func @sparse_set_md(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>,
+// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.struct<(array<2 x i64>, array<3 x i64>)> {
+// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK: return %[[VAL_2]] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+func.func @sparse_set_md(%arg0: !sparse_tensor.storage_specifier<#CSR>, %arg1: i64)
+ -> !sparse_tensor.storage_specifier<#CSR> { // Exercises storage_specifier.set.
+ %0 = sparse_tensor.storage_specifier.set %arg0 dim_sz at 0 with %arg1
+ : i64, !sparse_tensor.storage_specifier<#CSR>
+ return %0 : !sparse_tensor.storage_specifier<#CSR>
+}
More information about the Mlir-commits
mailing list