[Mlir-commits] [mlir] 8237cac - [mlir][sparse] extend storage specifier operations for slices.

Peiming Liu llvmlistbot at llvm.org
Fri Mar 10 10:58:53 PST 2023


Author: Peiming Liu
Date: 2023-03-10T18:58:47Z
New Revision: 8237cac612c6a8d00d673cee9c445f5aae2949d7

URL: https://github.com/llvm/llvm-project/commit/8237cac612c6a8d00d673cee9c445f5aae2949d7
DIFF: https://github.com/llvm/llvm-project/commit/8237cac612c6a8d00d673cee9c445f5aae2949d7.diff

LOG: [mlir][sparse] extend storage specifier operations for slices.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D141641

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index a4ea438b5d991..a5b96a86596eb 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -368,6 +368,8 @@ def SparseTensorStorageSpecifierKindEnum
         I32EnumAttrCase<"PosMemSize", 1, "pos_mem_sz">,
         I32EnumAttrCase<"CrdMemSize", 2, "crd_mem_sz">,
         I32EnumAttrCase<"ValMemSize", 3, "val_mem_sz">,
+        I32EnumAttrCase<"DimOffset",  4, "dim_offset">, // per-dim slice offset
+        I32EnumAttrCase<"DimStride",  5, "dim_stride">, // per-dim slice stride
       ]> {
   let genSpecializedAttr = 0;
   let cppNamespace = SparseTensor_Dialect.cppNamespace;

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index dc6e795933431..336f19686a1c5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -358,21 +358,47 @@ def SparseTensor_ToSliceStrideOp : SparseTensor_Op<"slice.stride", [Pure]>,
 }
 
 def SparseTensor_StorageSpecifierInitOp : SparseTensor_Op<"storage_specifier.init", [Pure]>,
+    Arguments<(ins Optional<SparseTensorStorageSpecifier>:$source)>,
     Results<(outs SparseTensorStorageSpecifier:$result)> {
   let summary = "";
   let description = [{
     Returns an initial storage specifier value.  A storage specifier
     value holds the level-sizes, position arrays, coordinate arrays,
     and the value array.
+    If this is a specifier for slices, it also holds the extra strides/offsets
+    for each tensor dimension.
+
+    When `source` is given, only its memory sizes are carried over into the
+    new specifier.
+
+    TODO: The sparse tensor slice support is currently in an unstable state,
+    and is subject to change in the future.
 
     Example:
 
     ```mlir
+    #CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ]}>
+    #CSR_SLICE = #sparse_tensor.encoding<{
+      dimLevelType = [ "dense", "compressed" ],
+      slice = [ (1, 4, 1), (1, 4, 2) ]
+    }>
+
     %0 = sparse_tensor.storage_specifier.init :  !sparse_tensor.storage_specifier<#CSR>
+    %1 = sparse_tensor.storage_specifier.init with %src
+         : from !sparse_tensor.storage_specifier<#CSR> to
+           !sparse_tensor.storage_specifier<#CSR_SLICE>
     ```
   }];
 
-  let assemblyFormat = "attr-dict `:` qualified(type($result))";
+  let builders = [
+    OpBuilder<(ins "Type":$result),
+    [{
+      build($_builder, $_state, result, Value());
+    }]>
+  ];
+
+  let assemblyFormat = "attr-dict (`with` $source^)? `:` (`from` qualified(type($source))^ `to`)?"
+                                                        " qualified(type($result))";
 }
 
 def SparseTensor_GetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.get", [Pure]>,

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 5220e4df2af4c..64112222f912a 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -620,6 +620,12 @@ static LogicalResult verifySparsifierGetterSetter(
   const auto enc = md.getType().getEncoding();
   const Level lvlRank = enc.getLvlRank();
 
+  // TODO: reject slice metadata accesses on non-slice specifiers, e.g.:
+  //  if (mdKind == StorageSpecifierKind::DimOffset ||
+  //      mdKind == StorageSpecifierKind::DimStride)
+  //    if (!enc.isSlice())
+  //      return op->emitError("requested slice data on non-slice tensor");
+
   if (mdKind != StorageSpecifierKind::ValMemSize) {
     if (!lvl)
       return op->emitError("missing level argument");

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
index 39fe6e811ff13..f3a6adbf0eceb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
@@ -21,12 +21,12 @@ namespace {
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
-static SmallVector<Type, 2> getSpecifierFields(StorageSpecifierType tp) {
+static SmallVector<Type, 4> getSpecifierFields(StorageSpecifierType tp) {
   MLIRContext *ctx = tp.getContext();
   auto enc = tp.getEncoding();
   const Level lvlRank = enc.getLvlRank();
 
-  SmallVector<Type, 2> result;
+  SmallVector<Type, 4> result;
   // TODO: how can we get the lowering type for index type in the later pipeline
   // to be consistent? LLVM::StructureType does not allow index fields.
   auto sizeType = IntegerType::get(tp.getContext(), 64);
@@ -35,6 +35,16 @@ static SmallVector<Type, 2> getSpecifierFields(StorageSpecifierType tp) {
                                            getNumDataFieldsFromEncoding(enc));
   result.push_back(lvlSizes);
   result.push_back(memSizes);
+
+  if (enc.isSlice()) {
+    // Slices carry two extra arrays: per-dimension offsets and strides.
+    auto dimOffset = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
+    auto dimStride = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
+
+    result.push_back(dimOffset);
+    result.push_back(dimStride);
+  }
+
   return result;
 }
 
@@ -49,11 +59,13 @@ static Type convertSpecifier(StorageSpecifierType tp) {
 
 constexpr uint64_t kLvlSizePosInSpecifier = 0;
 constexpr uint64_t kMemSizePosInSpecifier = 1;
+constexpr uint64_t kDimOffsetPosInSpecifier = 2; // slice encodings only
+constexpr uint64_t kDimStridePosInSpecifier = 3; // slice encodings only
 
 class SpecifierStructBuilder : public StructBuilder {
 private:
   Value extractField(OpBuilder &builder, Location loc,
-                     ArrayRef<int64_t> indices) {
+                     ArrayRef<int64_t> indices) const {
     return genCast(builder, loc,
                    builder.create<LLVM::ExtractValueOp>(loc, value, indices),
                    builder.getIndexType());
@@ -71,36 +83,69 @@ class SpecifierStructBuilder : public StructBuilder {
     assert(value);
   }
 
-  // Undef value for level-sizes, all zero values for memory-sizes.
-  static Value getInitValue(OpBuilder &builder, Location loc, Type structType);
+  // Undef level sizes; memory sizes are zeroed or copied from `source`.
+  static Value getInitValue(OpBuilder &builder, Location loc, Type structType,
+                            Value source);
 
-  Value lvlSize(OpBuilder &builder, Location loc, Level lvl);
+  Value lvlSize(OpBuilder &builder, Location loc, Level lvl) const;
   void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value size);
 
-  Value memSize(OpBuilder &builder, Location loc, FieldIndex fidx);
+  Value dimOffset(OpBuilder &builder, Location loc, Dimension dim) const;
+  void setDimOffset(OpBuilder &builder, Location loc, Dimension dim,
+                    Value size);
+
+  Value dimStride(OpBuilder &builder, Location loc, Dimension dim) const;
+  void setDimStride(OpBuilder &builder, Location loc, Dimension dim,
+                    Value size);
+
+  Value memSize(OpBuilder &builder, Location loc, FieldIndex fidx) const;
   void setMemSize(OpBuilder &builder, Location loc, FieldIndex fidx,
                   Value size);
+
+  Value memSizeArray(OpBuilder &builder, Location loc) const;
+  void setMemSizeArray(OpBuilder &builder, Location loc, Value array);
 };
 
 Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
-                                           Type structType) {
+                                           Type structType, Value source) {
   Value metaData = builder.create<LLVM::UndefOp>(loc, structType);
   SpecifierStructBuilder md(metaData);
-  auto memSizeArrayType = structType.cast<LLVM::LLVMStructType>()
-                              .getBody()[kMemSizePosInSpecifier]
-                              .cast<LLVM::LLVMArrayType>();
+  if (!source) {
+    auto memSizeArrayType = structType.cast<LLVM::LLVMStructType>()
+                                .getBody()[kMemSizePosInSpecifier]
+                                .cast<LLVM::LLVMArrayType>();
+
+    Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
+    // Fill memSizes array with zero.
+    for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
+      md.setMemSize(builder, loc, i, zero);
+  } else {
+    // Copy the non-slice information (the memory sizes array) from `source`.
+    SpecifierStructBuilder sourceMd(source);
+    md.setMemSizeArray(builder, loc, sourceMd.memSizeArray(builder, loc));
+  }
+  return md;
+}
 
-  Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
-  // Fill memSizes array with zero.
-  for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
-    md.setMemSize(builder, loc, i, zero);
+/// Builds IR extracting the `dim`-th slice offset from the descriptor.
+Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc,
+                                        Dimension dim) const {
+  return extractField(
+      builder, loc,
+      ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)});
+}
 
-  return md;
+/// Builds IR inserting the `dim`-th slice offset into the descriptor.
+void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc,
+                                          Dimension dim, Value size) {
+  const auto d = static_cast<int64_t>(dim);
+  insertField(builder, loc, ArrayRef<int64_t>{kDimOffsetPosInSpecifier, d},
+              size);
 }
 
 /// Builds IR extracting the `lvl`-th level-size from the descriptor.
 Value SpecifierStructBuilder::lvlSize(OpBuilder &builder, Location loc,
-                                      Level lvl) {
+                                      Level lvl) const {
   // This static_cast makes the narrowing of `lvl` explicit, as required
   // by the braces notation for the ctor.
   return extractField(
@@ -119,18 +164,52 @@ void SpecifierStructBuilder::setLvlSize(OpBuilder &builder, Location loc,
       size);
 }
 
-/// Builds IR extracting the `fidx`-th memory-size from the descriptor.
+/// Builds IR extracting the `dim`-th slice stride from the descriptor.
+Value SpecifierStructBuilder::dimStride(OpBuilder &builder, Location loc,
+                                        Dimension dim) const {
+  return extractField(
+      builder, loc,
+      ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)});
+}
+
+/// Builds IR inserting the `dim`-th slice stride into the descriptor.
+void SpecifierStructBuilder::setDimStride(OpBuilder &builder, Location loc,
+                                          Dimension dim, Value size) {
+  insertField(
+      builder, loc,
+      ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)},
+      size);
+}
+
+/// Builds IR extracting the `fidx`-th memory-size from the descriptor.
 Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
-                                      FieldIndex fidx) {
-  return extractField(builder, loc,
-                      ArrayRef<int64_t>{kMemSizePosInSpecifier, fidx});
+                                      FieldIndex fidx) const {
+  return extractField(
+      builder, loc,
+      ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)});
 }
 
 /// Builds IR inserting the `fidx`-th memory-size into the descriptor.
 void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
                                         FieldIndex fidx, Value size) {
-  insertField(builder, loc, ArrayRef<int64_t>{kMemSizePosInSpecifier, fidx},
-              size);
+  insertField(
+      builder, loc,
+      ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)},
+      size);
+}
+
+/// Builds IR extracting the memory size array from the descriptor.
+Value SpecifierStructBuilder::memSizeArray(OpBuilder &builder,
+                                           Location loc) const {
+  return builder.create<LLVM::ExtractValueOp>(loc, value,
+                                              kMemSizePosInSpecifier);
+}
+
+/// Builds IR inserting the memory size array into the descriptor.
+void SpecifierStructBuilder::setMemSizeArray(OpBuilder &builder, Location loc,
+                                             Value array) {
+  value = builder.create<LLVM::InsertValueOp>(loc, value, array,
+                                              kMemSizePosInSpecifier);
 }
 
 } // namespace
@@ -158,20 +237,37 @@ class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     SpecifierStructBuilder spec(adaptor.getSpecifier());
-    Value v;
-    if (op.getSpecifierKind() == StorageSpecifierKind::LvlSize) {
-      assert(op.getLevel().has_value());
-      v = Base::onLvlSize(rewriter, op, spec, op.getLevel().value());
-    } else {
+    switch (op.getSpecifierKind()) {
+    case StorageSpecifierKind::LvlSize: {
+      Value v = Base::onLvlSize(rewriter, op, spec, (*op.getLevel()));
+      rewriter.replaceOp(op, v);
+      return success();
+    }
+    case StorageSpecifierKind::DimOffset: {
+      Value v = Base::onDimOffset(rewriter, op, spec, (*op.getLevel()));
+      rewriter.replaceOp(op, v);
+      return success();
+    }
+    case StorageSpecifierKind::DimStride: {
+      Value v = Base::onDimStride(rewriter, op, spec, (*op.getLevel()));
+      rewriter.replaceOp(op, v);
+      return success();
+    }
+    case StorageSpecifierKind::CrdMemSize:
+    case StorageSpecifierKind::PosMemSize:
+    case StorageSpecifierKind::ValMemSize: {
       auto enc = op.getSpecifier().getType().getEncoding();
       StorageLayout layout(enc);
-      FieldIndex fidx =
-          layout.getMemRefFieldIndex(op.getSpecifierKind(), op.getLevel());
-      v = Base::onMemSize(rewriter, op, spec, fidx);
+      std::optional<unsigned> lvl;
+      if (op.getLevel())
+        lvl = (*op.getLevel());
+      unsigned idx = layout.getMemRefFieldIndex(op.getSpecifierKind(), lvl);
+      Value v = Base::onMemSize(rewriter, op, spec, idx);
+      rewriter.replaceOp(op, v);
+      return success();
     }
-
-    rewriter.replaceOp(op, v);
-    return success();
+    }
+    llvm_unreachable("unrecognized specifier kind");
   }
 };
 
@@ -179,12 +275,25 @@ struct StorageSpecifierSetOpConverter
     : public SpecifierGetterSetterOpConverter<StorageSpecifierSetOpConverter,
                                               SetStorageSpecifierOp> {
   using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
+
   static Value onLvlSize(OpBuilder &builder, SetStorageSpecifierOp op,
                          SpecifierStructBuilder &spec, Level lvl) {
     spec.setLvlSize(builder, op.getLoc(), lvl, op.getValue());
     return spec;
   }
 
+  static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op,
+                           SpecifierStructBuilder &spec, Dimension d) {
+    spec.setDimOffset(builder, op.getLoc(), d, op.getValue());
+    return spec;
+  }
+
+  static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op,
+                           SpecifierStructBuilder &spec, Dimension d) {
+    spec.setDimStride(builder, op.getLoc(), d, op.getValue());
+    return spec;
+  }
+
   static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op,
                          SpecifierStructBuilder &spec, FieldIndex fidx) {
     spec.setMemSize(builder, op.getLoc(), fidx, op.getValue());
@@ -196,10 +305,22 @@ struct StorageSpecifierGetOpConverter
     : public SpecifierGetterSetterOpConverter<StorageSpecifierGetOpConverter,
                                               GetStorageSpecifierOp> {
   using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
+
   static Value onLvlSize(OpBuilder &builder, GetStorageSpecifierOp op,
                          SpecifierStructBuilder &spec, Level lvl) {
     return spec.lvlSize(builder, op.getLoc(), lvl);
   }
+
+  static Value onDimOffset(OpBuilder &builder, GetStorageSpecifierOp op,
+                           const SpecifierStructBuilder &spec, Dimension d) {
+    return spec.dimOffset(builder, op.getLoc(), d);
+  }
+
+  static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op,
+                           const SpecifierStructBuilder &spec, Dimension d) {
+    return spec.dimStride(builder, op.getLoc(), d);
+  }
+
   static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op,
                          SpecifierStructBuilder &spec, FieldIndex fidx) {
     return spec.memSize(builder, op.getLoc(), fidx);
@@ -214,8 +335,9 @@ struct StorageSpecifierInitOpConverter
   matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
-    rewriter.replaceOp(op, SpecifierStructBuilder::getInitValue(
-                               rewriter, op.getLoc(), llvmType));
+    rewriter.replaceOp(
+        op, SpecifierStructBuilder::getInitValue(
+                rewriter, op.getLoc(), llvmType, adaptor.getSource()));
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
index 18eb63c02d2e9..69cc3af3a5bdd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -59,6 +59,11 @@ namespace sparse_tensor {
 // access the AOS coordinates array. In the code below, the method `getCOOStart`
 // is used to find the start of the "trailing COO region".
 //
+// If the sparse tensor is a slice (produced by a `tensor.extract_slice`
+// operation), then instead of allocating a new sparse tensor, it reuses the
+// same set of MemRefs and attaches an additional set of slicing metadata
+// holding the per-dimension slice offset and stride.
+//
 // Examples.
 //
 // #CSR storage of 2-dim matrix yields
@@ -73,6 +78,15 @@ namespace sparse_tensor {
 //   memref<?xf64>                             ; values
 //   struct<(array<2 x i64>, array<3 x i64>)>) ; lvl0, lvl1, 3xsizes
 //
+// A slice of #COO storage of a 2-dim matrix yields
+//   ;; Inherited from the original sparse tensor
+//   memref<?xindex>,                          ; positions-0, essentially [0,sz]
+//   memref<?xindex>                           ; AOS coordinates storage
+//   memref<?xf64>                             ; values
+//   struct<(array<2 x i64>, array<3 x i64>,   ; lvl0, lvl1, 3xsizes
+//   ;; Extra slicing-metadata
+//           array<2 x i64>, array<2 x i64>)>) ; dim offset, dim stride.
+//
 //===----------------------------------------------------------------------===//
 
 enum class SparseTensorFieldKind : uint32_t {

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 7f6c2b2106adc..caf994cf8c192 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -259,6 +259,17 @@ func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>)
   return %0 : index
 }
 
+//// -----
+// TODO: enable together with the slice check in verifySparsifierGetterSetter.
+//#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+//
+//func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 {
+//  // _e_xpected-error at +1 {{requested slice data on non-slice tensor}}
+//  %0 = sparse_tensor.storage_specifier.get %arg0 dim_offset at 0
+//       : !sparse_tensor.storage_specifier<#SparseVector> to i64
+//  return %0 : i64
+//}
+
 // -----
 
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 3b1569b7d6728..ff622a4bb408f 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -179,6 +179,25 @@ func.func @sparse_metadata_init() -> !sparse_tensor.storage_specifier<#SparseVec
 
 // -----
 
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+#SparseVector_Slice = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"],
+  slice = [ (?, ?, ?) ]
+}>
+
+// CHECK-LABEL: func @sparse_metadata_init(
+//  CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>
+//       CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.init with %[[A]] :
+//       CHECK: return %[[T]] : !sparse_tensor.storage_specifier<#{{.*}}>
+func.func @sparse_metadata_init(%src : !sparse_tensor.storage_specifier<#SparseVector>)
+                                    -> !sparse_tensor.storage_specifier<#SparseVector_Slice> {
+  %0 = sparse_tensor.storage_specifier.init with %src : from !sparse_tensor.storage_specifier<#SparseVector>
+                                                          to !sparse_tensor.storage_specifier<#SparseVector_Slice>
+  return %0 : !sparse_tensor.storage_specifier<#SparseVector_Slice>
+}
+
+// -----
+
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
 // CHECK-LABEL: func @sparse_get_md(
@@ -191,6 +210,41 @@ func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>)
   return %0 : index
 }
 
+// -----
+
+#SparseVector_Slice = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"],
+  slice = [ (?, ?, ?) ]
+}>
+
+// CHECK-LABEL: func @sparse_get_md(
+//  CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>
+//       CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.get %[[A]] dim_offset at 0
+//       CHECK: return %[[T]] : index
+func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector_Slice>) -> index {
+  %0 = sparse_tensor.storage_specifier.get %arg0 dim_offset at 0
+       : !sparse_tensor.storage_specifier<#SparseVector_Slice>
+  return %0 : index
+}
+
+// -----
+
+#SparseVector_Slice = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"],
+  slice = [ (?, ?, ?) ]
+}>
+
+// CHECK-LABEL: func @sparse_get_md(
+//  CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>
+//       CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.get %[[A]] dim_stride at 0
+//       CHECK: return %[[T]] : index
+func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector_Slice>) -> index {
+  %0 = sparse_tensor.storage_specifier.get %arg0 dim_stride at 0
+       : !sparse_tensor.storage_specifier<#SparseVector_Slice>
+  return %0 : index
+}
+
+
 // -----
 
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>


        


More information about the Mlir-commits mailing list