[Mlir-commits] [mlir] 8237cac - [mlir][sparse] extend storage specifier operations for slices.
Peiming Liu
llvmlistbot at llvm.org
Fri Mar 10 10:58:53 PST 2023
Author: Peiming Liu
Date: 2023-03-10T18:58:47Z
New Revision: 8237cac612c6a8d00d673cee9c445f5aae2949d7
URL: https://github.com/llvm/llvm-project/commit/8237cac612c6a8d00d673cee9c445f5aae2949d7
DIFF: https://github.com/llvm/llvm-project/commit/8237cac612c6a8d00d673cee9c445f5aae2949d7.diff
LOG: [mlir][sparse] extend storage specifier operations for slices.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D141641
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index a4ea438b5d991..a5b96a86596eb 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -368,6 +368,8 @@ def SparseTensorStorageSpecifierKindEnum
I32EnumAttrCase<"PosMemSize", 1, "pos_mem_sz">,
I32EnumAttrCase<"CrdMemSize", 2, "crd_mem_sz">,
I32EnumAttrCase<"ValMemSize", 3, "val_mem_sz">,
+ I32EnumAttrCase<"DimOffset", 4, "dim_offset">,
+ I32EnumAttrCase<"DimStride", 5, "dim_stride">,
]> {
let genSpecializedAttr = 0;
let cppNamespace = SparseTensor_Dialect.cppNamespace;
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index dc6e795933431..336f19686a1c5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -358,21 +358,44 @@ def SparseTensor_ToSliceStrideOp : SparseTensor_Op<"slice.stride", [Pure]>,
}
def SparseTensor_StorageSpecifierInitOp : SparseTensor_Op<"storage_specifier.init", [Pure]>,
+ Arguments<(ins Optional<SparseTensorStorageSpecifier>:$source)>,
Results<(outs SparseTensorStorageSpecifier:$result)> {
let summary = "";
let description = [{
Returns an initial storage specifier value. A storage specifier
value holds the level-sizes, position arrays, coordinate arrays,
and the value array.
+ If this is a specifier for slices, it also holds the extra strides/offsets
+ for each tensor dimension.
+
+ TODO: The sparse tensor slice support is currently in an unstable state, and
+ is subject to change in the future.
Example:
```mlir
+ #CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ]}>
+ #CSR_SLICE = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "compressed" ],
+ slice = [ (1, 4, 1), (1, 4, 2) ]
+ }>
+
%0 = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier<#CSR>
+ %1 = sparse_tensor.storage_specifier.init with %src
+ : from !sparse_tensor.storage_specifier<#CSR> to
+ !sparse_tensor.storage_specifier<#CSR_SLICE>
```
}];
- let assemblyFormat = "attr-dict `:` qualified(type($result))";
+ let builders = [
+ OpBuilder<(ins "Type":$result),
+ [{
+ build($_builder, $_state, result, Value());
+ }]>
+ ];
+
+ let assemblyFormat = "attr-dict (`with` $source^)? `:` (`from` qualified(type($source))^ `to`)?"
+ " qualified(type($result))";
}
def SparseTensor_GetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.get", [Pure]>,
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 5220e4df2af4c..64112222f912a 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -620,6 +620,12 @@ static LogicalResult verifySparsifierGetterSetter(
const auto enc = md.getType().getEncoding();
const Level lvlRank = enc.getLvlRank();
+ // TODO:
+ // if (mdKind == StorageSpecifierKind::DimOffset ||
+ // mdKind == StorageSpecifierKind::DimStride)
+ // if (!enc.isSlice())
+ // return op->emitError("requested slice data on non-slice tensor");
+
if (mdKind != StorageSpecifierKind::ValMemSize) {
if (!lvl)
return op->emitError("missing level argument");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
index 39fe6e811ff13..f3a6adbf0eceb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
@@ -21,12 +21,12 @@ namespace {
// Helper methods.
//===----------------------------------------------------------------------===//
-static SmallVector<Type, 2> getSpecifierFields(StorageSpecifierType tp) {
+static SmallVector<Type, 4> getSpecifierFields(StorageSpecifierType tp) {
MLIRContext *ctx = tp.getContext();
auto enc = tp.getEncoding();
const Level lvlRank = enc.getLvlRank();
- SmallVector<Type, 2> result;
+ SmallVector<Type, 4> result;
// TODO: how can we get the lowering type for index type in the later pipeline
// to be consistent? LLVM::StructureType does not allow index fields.
auto sizeType = IntegerType::get(tp.getContext(), 64);
@@ -35,6 +35,16 @@ static SmallVector<Type, 2> getSpecifierFields(StorageSpecifierType tp) {
getNumDataFieldsFromEncoding(enc));
result.push_back(lvlSizes);
result.push_back(memSizes);
+
+ if (enc.isSlice()) {
+ // Extra fields are required for the slice information.
+ auto dimOffset = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
+ auto dimStride = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
+
+ result.push_back(dimOffset);
+ result.push_back(dimStride);
+ }
+
return result;
}
@@ -49,11 +59,13 @@ static Type convertSpecifier(StorageSpecifierType tp) {
constexpr uint64_t kLvlSizePosInSpecifier = 0;
constexpr uint64_t kMemSizePosInSpecifier = 1;
+constexpr uint64_t kDimOffsetPosInSpecifier = 2;
+constexpr uint64_t kDimStridePosInSpecifier = 3;
class SpecifierStructBuilder : public StructBuilder {
private:
Value extractField(OpBuilder &builder, Location loc,
- ArrayRef<int64_t> indices) {
+ ArrayRef<int64_t> indices) const {
return genCast(builder, loc,
builder.create<LLVM::ExtractValueOp>(loc, value, indices),
builder.getIndexType());
@@ -71,36 +83,69 @@ class SpecifierStructBuilder : public StructBuilder {
assert(value);
}
- // Undef value for level-sizes, all zero values for memory-sizes.
- static Value getInitValue(OpBuilder &builder, Location loc, Type structType);
+ // Undef value for level-sizes and slice offsets/strides; memory-sizes are
+ // all zero, or copied from `source` when one is given.
+ static Value getInitValue(OpBuilder &builder, Location loc, Type structType,
+ Value source);
- Value lvlSize(OpBuilder &builder, Location loc, Level lvl);
+ Value lvlSize(OpBuilder &builder, Location loc, Level lvl) const;
void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value size);
- Value memSize(OpBuilder &builder, Location loc, FieldIndex fidx);
+ Value dimOffset(OpBuilder &builder, Location loc, Dimension dim) const;
+ void setDimOffset(OpBuilder &builder, Location loc, Dimension dim,
+ Value size);
+
+ Value dimStride(OpBuilder &builder, Location loc, Dimension dim) const;
+ void setDimStride(OpBuilder &builder, Location loc, Dimension dim,
+ Value size);
+
+ Value memSize(OpBuilder &builder, Location loc, FieldIndex fidx) const;
void setMemSize(OpBuilder &builder, Location loc, FieldIndex fidx,
Value size);
+
+ Value memSizeArray(OpBuilder &builder, Location loc) const;
+ void setMemSizeArray(OpBuilder &builder, Location loc, Value array);
};
Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
- Type structType) {
+ Type structType, Value source) {
Value metaData = builder.create<LLVM::UndefOp>(loc, structType);
SpecifierStructBuilder md(metaData);
- auto memSizeArrayType = structType.cast<LLVM::LLVMStructType>()
- .getBody()[kMemSizePosInSpecifier]
- .cast<LLVM::LLVMArrayType>();
+ if (!source) {
+ auto memSizeArrayType = structType.cast<LLVM::LLVMStructType>()
+ .getBody()[kMemSizePosInSpecifier]
+ .cast<LLVM::LLVMArrayType>();
+
+ Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
+ // Fill memSizes array with zero.
+ for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
+ md.setMemSize(builder, loc, i, zero);
+ } else {
+ // We copy non-slice information (memory sizes array) from source
+ SpecifierStructBuilder sourceMd(source);
+ md.setMemSizeArray(builder, loc, sourceMd.memSizeArray(builder, loc));
+ }
+ return md;
+}
- Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
- // Fill memSizes array with zero.
- for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
- md.setMemSize(builder, loc, i, zero);
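+/// Builds IR extracting the `dim`-th slice offset from the descriptor.
+/// NOTE(review): unlike dimStride(), this bypasses extractField() and so does
+/// not cast the extracted value to index; confirm this is intended.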
+Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc,
+ Dimension dim) const {
+ return builder.create<LLVM::ExtractValueOp>(
+ loc, value,
+ ArrayRef<int64_t>({kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)}));
+}
- return md;
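+/// Builds IR inserting the `dim`-th slice offset into the descriptor.
+/// NOTE(review): unlike setDimStride(), this bypasses insertField(); confirm
+/// this is intended.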
+void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc,
+ Dimension dim, Value size) {
+ value = builder.create<LLVM::InsertValueOp>(
+ loc, value, size,
+ ArrayRef<int64_t>({kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)}));
}
/// Builds IR extracting the `lvl`-th level-size from the descriptor.
Value SpecifierStructBuilder::lvlSize(OpBuilder &builder, Location loc,
- Level lvl) {
+ Level lvl) const {
// This static_cast makes the narrowing of `lvl` explicit, as required
// by the braces notation for the ctor.
return extractField(
@@ -119,18 +164,52 @@ void SpecifierStructBuilder::setLvlSize(OpBuilder &builder, Location loc,
size);
}
-/// Builds IR extracting the `fidx`-th memory-size from the descriptor.
+/// Builds IR extracting the `dim`-th slice stride from the descriptor.
+Value SpecifierStructBuilder::dimStride(OpBuilder &builder, Location loc,
+ Dimension dim) const {
+ return extractField(
+ builder, loc,
+ ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)});
+}
+
+/// Builds IR inserting the `dim`-th slice stride into the descriptor.
+void SpecifierStructBuilder::setDimStride(OpBuilder &builder, Location loc,
+ Dimension dim, Value size) {
+ insertField(
+ builder, loc,
+ ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)},
+ size);
+}
+
+/// Builds IR extracting the `fidx`-th memory-size from the descriptor.
Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
- FieldIndex fidx) {
- return extractField(builder, loc,
- ArrayRef<int64_t>{kMemSizePosInSpecifier, fidx});
+ FieldIndex fidx) const {
+ return extractField(
+ builder, loc,
+ ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)});
}
/// Builds IR inserting the `fidx`-th memory-size into the descriptor.
void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
FieldIndex fidx, Value size) {
- insertField(builder, loc, ArrayRef<int64_t>{kMemSizePosInSpecifier, fidx},
- size);
+ insertField(
+ builder, loc,
+ ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)},
+ size);
+}
+
+/// Builds IR extracting the memory size array from the descriptor.
+Value SpecifierStructBuilder::memSizeArray(OpBuilder &builder,
+ Location loc) const {
+ return builder.create<LLVM::ExtractValueOp>(loc, value,
+ kMemSizePosInSpecifier);
+}
+
+/// Builds IR inserting the memory size array into the descriptor.
+void SpecifierStructBuilder::setMemSizeArray(OpBuilder &builder, Location loc,
+ Value array) {
+ value = builder.create<LLVM::InsertValueOp>(loc, value, array,
+ kMemSizePosInSpecifier);
}
} // namespace
@@ -158,20 +237,37 @@ class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SpecifierStructBuilder spec(adaptor.getSpecifier());
- Value v;
- if (op.getSpecifierKind() == StorageSpecifierKind::LvlSize) {
- assert(op.getLevel().has_value());
- v = Base::onLvlSize(rewriter, op, spec, op.getLevel().value());
- } else {
+ switch (op.getSpecifierKind()) {
+ case StorageSpecifierKind::LvlSize: {
+ Value v = Base::onLvlSize(rewriter, op, spec, (*op.getLevel()));
+ rewriter.replaceOp(op, v);
+ return success();
+ }
+ case StorageSpecifierKind::DimOffset: {
+ Value v = Base::onDimOffset(rewriter, op, spec, (*op.getLevel()));
+ rewriter.replaceOp(op, v);
+ return success();
+ }
+ case StorageSpecifierKind::DimStride: {
+ Value v = Base::onDimStride(rewriter, op, spec, (*op.getLevel()));
+ rewriter.replaceOp(op, v);
+ return success();
+ }
+ case StorageSpecifierKind::CrdMemSize:
+ case StorageSpecifierKind::PosMemSize:
+ case StorageSpecifierKind::ValMemSize: {
auto enc = op.getSpecifier().getType().getEncoding();
StorageLayout layout(enc);
- FieldIndex fidx =
- layout.getMemRefFieldIndex(op.getSpecifierKind(), op.getLevel());
- v = Base::onMemSize(rewriter, op, spec, fidx);
+ std::optional<unsigned> lvl;
+ if (op.getLevel())
+ lvl = (*op.getLevel());
+ unsigned idx = layout.getMemRefFieldIndex(op.getSpecifierKind(), lvl);
+ Value v = Base::onMemSize(rewriter, op, spec, idx);
+ rewriter.replaceOp(op, v);
+ return success();
}
-
- rewriter.replaceOp(op, v);
- return success();
+ }
+ llvm_unreachable("unrecognized specifer kind");
}
};
@@ -179,12 +275,25 @@ struct StorageSpecifierSetOpConverter
: public SpecifierGetterSetterOpConverter<StorageSpecifierSetOpConverter,
SetStorageSpecifierOp> {
using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
+
static Value onLvlSize(OpBuilder &builder, SetStorageSpecifierOp op,
SpecifierStructBuilder &spec, Level lvl) {
spec.setLvlSize(builder, op.getLoc(), lvl, op.getValue());
return spec;
}
+ static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op,
+ SpecifierStructBuilder &spec, Dimension d) {
+ spec.setDimOffset(builder, op.getLoc(), d, op.getValue());
+ return spec;
+ }
+
+ static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op,
+ SpecifierStructBuilder &spec, Dimension d) {
+ spec.setDimStride(builder, op.getLoc(), d, op.getValue());
+ return spec;
+ }
+
static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op,
SpecifierStructBuilder &spec, FieldIndex fidx) {
spec.setMemSize(builder, op.getLoc(), fidx, op.getValue());
@@ -196,10 +305,22 @@ struct StorageSpecifierGetOpConverter
: public SpecifierGetterSetterOpConverter<StorageSpecifierGetOpConverter,
GetStorageSpecifierOp> {
using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
+
static Value onLvlSize(OpBuilder &builder, GetStorageSpecifierOp op,
SpecifierStructBuilder &spec, Level lvl) {
return spec.lvlSize(builder, op.getLoc(), lvl);
}
+
+ static Value onDimOffset(OpBuilder &builder, GetStorageSpecifierOp op,
+ const SpecifierStructBuilder &spec, Dimension d) {
+ return spec.dimOffset(builder, op.getLoc(), d);
+ }
+
+ static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op,
+ const SpecifierStructBuilder &spec, Dimension d) {
+ return spec.dimStride(builder, op.getLoc(), d);
+ }
+
static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op,
SpecifierStructBuilder &spec, FieldIndex fidx) {
return spec.memSize(builder, op.getLoc(), fidx);
@@ -214,8 +335,9 @@ struct StorageSpecifierInitOpConverter
matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
- rewriter.replaceOp(op, SpecifierStructBuilder::getInitValue(
- rewriter, op.getLoc(), llvmType));
+ rewriter.replaceOp(
+ op, SpecifierStructBuilder::getInitValue(
+ rewriter, op.getLoc(), llvmType, adaptor.getSource()));
return success();
}
};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
index 18eb63c02d2e9..69cc3af3a5bdd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -59,6 +59,11 @@ namespace sparse_tensor {
// access the AOS coordinates array. In the code below, the method `getCOOStart`
// is used to find the start of the "trailing COO region".
//
+// If the sparse tensor is a slice (produced by a `tensor.extract_slice`
+// operation), then instead of allocating a new sparse tensor for it, it reuses
+// the same set of MemRefs but attaches an additional set of slicing metadata
+// holding the per-dimension slice offsets and strides.
+//
// Examples.
//
// #CSR storage of 2-dim matrix yields
@@ -73,6 +78,15 @@ namespace sparse_tensor {
// memref<?xf64> ; values
// struct<(array<2 x i64>, array<3 x i64>)>) ; lvl0, lvl1, 3xsizes
//
+// Slice on #COO storage of 2-dim matrix yields
+// ;; Inherited from the original sparse tensor
+// memref<?xindex>, ; positions-0, essentially [0,sz]
+// memref<?xindex> ; AOS coordinates storage
+// memref<?xf64> ; values
+// struct<(array<2 x i64>, array<3 x i64>, ; lvl0, lvl1, 3xsizes
+// ;; Extra slicing-metadata
+// array<2 x i64>, array<2 x i64>)>) ; dim offset, dim stride.
+//
//===----------------------------------------------------------------------===//
enum class SparseTensorFieldKind : uint32_t {
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 7f6c2b2106adc..caf994cf8c192 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -259,6 +259,17 @@ func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>)
return %0 : index
}
+//// -----
+//
+//#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+//
+//func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 {
+// // _e_xpected-error at +1 {{requested slice data on non-slice tensor}}
+// %0 = sparse_tensor.storage_specifier.get %arg0 dim_offset at 0
+// : !sparse_tensor.storage_specifier<#SparseVector> to i64
+// return %0 : i64
+//}
+
// -----
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 3b1569b7d6728..ff622a4bb408f 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -179,6 +179,25 @@ func.func @sparse_metadata_init() -> !sparse_tensor.storage_specifier<#SparseVec
// -----
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+#SparseVector_Slice = #sparse_tensor.encoding<{
+ dimLevelType = ["compressed"],
+ slice = [ (?, ?, ?) ]
+}>
+
+// CHECK-LABEL: func @sparse_metadata_init(
+// CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>
+// CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.init with %[[A]] :
+// CHECK: return %[[T]] : !sparse_tensor.storage_specifier<#{{.*}}>
+func.func @sparse_metadata_init(%src : !sparse_tensor.storage_specifier<#SparseVector>)
+ -> !sparse_tensor.storage_specifier<#SparseVector_Slice> {
+ %0 = sparse_tensor.storage_specifier.init with %src : from !sparse_tensor.storage_specifier<#SparseVector>
+ to !sparse_tensor.storage_specifier<#SparseVector_Slice>
+ return %0 : !sparse_tensor.storage_specifier<#SparseVector_Slice>
+}
+
+// -----
+
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
// CHECK-LABEL: func @sparse_get_md(
@@ -191,6 +210,41 @@ func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>)
return %0 : index
}
+// -----
+
+#SparseVector_Slice = #sparse_tensor.encoding<{
+ dimLevelType = ["compressed"],
+ slice = [ (?, ?, ?) ]
+}>
+
+// CHECK-LABEL: func @sparse_get_md(
+// CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>
+// CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.get %[[A]] dim_offset at 0
+// CHECK: return %[[T]] : index
+func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector_Slice>) -> index {
+ %0 = sparse_tensor.storage_specifier.get %arg0 dim_offset at 0
+ : !sparse_tensor.storage_specifier<#SparseVector_Slice>
+ return %0 : index
+}
+
+// -----
+
+#SparseVector = #sparse_tensor.encoding<{
+ dimLevelType = ["compressed"],
+ slice = [ (?, ?, ?) ]
+}>
+
+// CHECK-LABEL: func @sparse_get_md(
+// CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>
+// CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.get %[[A]] dim_stride at 0
+// CHECK: return %[[T]] : index
+func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> index {
+ %0 = sparse_tensor.storage_specifier.get %arg0 dim_stride at 0
+ : !sparse_tensor.storage_specifier<#SparseVector>
+ return %0 : index
+}
+
+
// -----
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
More information about the Mlir-commits
mailing list