[Mlir-commits] [mlir] 083ddff - [mlir][sparse] introduce sparse_tensor::StorageSpecifierToLLVM pass

Peiming Liu llvmlistbot at llvm.org
Thu Dec 22 14:45:20 PST 2022


Author: Peiming Liu
Date: 2022-12-22T22:45:15Z
New Revision: 083ddffe476b73f7c05b06c3f6f54ebfcbf34727

URL: https://github.com/llvm/llvm-project/commit/083ddffe476b73f7c05b06c3f6f54ebfcbf34727
DIFF: https://github.com/llvm/llvm-project/commit/083ddffe476b73f7c05b06c3f6f54ebfcbf34727.diff

LOG: [mlir][sparse] introduce sparse_tensor::StorageSpecifierToLLVM pass

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D140122

Added: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
    mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index e8396b6bb7716..fc4ab2cb792ff 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -52,9 +52,9 @@ def SparseTensor_StorageSpecifier : SparseTensor_Type<"StorageSpecifier"> {
 
   let parameters = (ins SparseTensorEncodingAttr : $encoding);
   let builders = [
+    TypeBuilder<(ins "SparseTensorEncodingAttr":$encoding)>,
     TypeBuilderWithInferredContext<(ins "SparseTensorEncodingAttr":$encoding), [{
-      assert(encoding && "sparse tensor encoding should not be null");
-      return $_get(encoding.getContext(), encoding);
+      return get(encoding.getContext(), encoding);
     }]>,
     TypeBuilderWithInferredContext<(ins "Type":$type), [{
       return get(getSparseTensorEncoding(type));
@@ -71,6 +71,10 @@ def SparseTensor_StorageSpecifier : SparseTensor_Type<"StorageSpecifier"> {
     Type getFieldType(StorageSpecifierKind kind, std::optional<APInt> dim) const;
   }];
 
+  // We skip the default builder, which simply takes the input sparse tensor
+  // encoding attribute, since we need to normalize the dimension level types
+  // and strip fields that are irrelevant to the sparse tensor storage scheme.
+  let skipDefaultBuilders = 1;
   let assemblyFormat="`<` qualified($encoding) `>`";
 }
 

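A hedged C++ sketch, not part of this patch: all three builder entry points declared above funnel into the normalizing `get` defined later in SparseTensorDialect.cpp, so they agree on the resulting type. The helper name and the input tensor type are hypothetical.

```c++
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include <cassert>
using namespace mlir;
using namespace mlir::sparse_tensor;

// Hypothetical helper, for illustration only.
static void specifierBuilderSketch(RankedTensorType tp) {
  // Assumes tp carries a sparse tensor encoding.
  SparseTensorEncodingAttr enc = getSparseTensorEncoding(tp);
  auto a = StorageSpecifierType::get(enc.getContext(), enc); // explicit context
  auto b = StorageSpecifierType::get(enc);                   // inferred context
  auto c = StorageSpecifierType::get(tp);                    // from tensor type
  assert(a == b && b == c && "all builders agree after normalization");
}
```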
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index d4e3210a6a342..df465e443d36a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -158,6 +158,19 @@ std::unique_ptr<Pass>
 createPostSparsificationRewritePass(bool enableRT, bool enableForeach = true,
                                     bool enableConvert = true);
 
+//===----------------------------------------------------------------------===//
+// The SparseStorageSpecifierToLLVM pass.
+//===----------------------------------------------------------------------===//
+
+class StorageSpecifierToLLVMTypeConverter : public TypeConverter {
+public:
+  StorageSpecifierToLLVMTypeConverter();
+};
+
+void populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
+                                            RewritePatternSet &patterns);
+std::unique_ptr<Pass> createStorageSpecifierToLLVMPass();
+
 //===----------------------------------------------------------------------===//
 // Other rewriting rules and passes.
 //===----------------------------------------------------------------------===//

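A hedged sketch, not part of this patch, of how a downstream pipeline could wire these declarations together by hand; it mirrors what StorageSpecifierToLLVMPass itself does in SparseTensorPasses.cpp below. The helper name is hypothetical.

```c++
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;

// Hypothetical driver, for illustration only.
static LogicalResult lowerSpecifiersSketch(ModuleOp module) {
  MLIRContext *ctx = module.getContext();
  StorageSpecifierToLLVMTypeConverter converter;
  RewritePatternSet patterns(ctx);
  populateStorageSpecifierToLLVMPatterns(converter, patterns);
  ConversionTarget target(*ctx);
  // Specifier ops must be rewritten; arith/LLVM output is legal.
  target.addIllegalDialect<sparse_tensor::SparseTensorDialect>();
  target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();
  return applyPartialConversion(module, target, std::move(patterns));
}
```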
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 4c827c1432056..6ec9d4eb9e46e 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -301,4 +301,28 @@ def SparseVectorization : Pass<"sparse-vectorization", "ModuleOp"> {
   ];
 }
 
+def StorageSpecifierToLLVM : Pass<"sparse-storage-specifier-to-llvm", "ModuleOp"> {
+  let summary = "Lower sparse storage specifier to llvm structure";
+  let description = [{
+     This pass rewrites sparse tensor storage specifier-related operations into
+     the LLVM dialect, and converts the storage specifier type into llvm.struct.
+
+     Example of the conversion:
+     ```mlir
+     Before:
+       %0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 0
+       : !sparse_tensor.storage_specifier<#CSR> to i64
+
+     After:
+       %0 = llvm.extractvalue %arg0[0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+     ```
+  }];
+  let constructor = "mlir::createStorageSpecifierToLLVMPass()";
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "LLVM::LLVMDialect",
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index e4e54b7143319..1e9aab8deb172 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -323,6 +323,28 @@ uint64_t mlir::sparse_tensor::toStoredDim(RankedTensorType type, uint64_t d) {
 // SparseTensorDialect Types.
 //===----------------------------------------------------------------------===//
 
+/// We normalize the sparse tensor encoding attribute by always using an
+/// ordered/unique DLT such that "compressed-nu-no" and "compressed-nu" (and
+/// other variants) lead to the same storage specifier type, and by stripping
+/// irrelevant fields that do not alter the sparse tensor memory layout.
+static SparseTensorEncodingAttr
+getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
+  SmallVector<DimLevelType> dlts;
+  for (auto dlt : enc.getDimLevelType())
+    dlts.push_back(*getDimLevelType(*getLevelFormat(dlt), true, true));
+
+  return SparseTensorEncodingAttr::get(
+      enc.getContext(), dlts,
+      AffineMap(), // dimOrdering (irrelevant to storage specifier)
+      AffineMap(), // highLvlOrdering (irrelevant to storage specifier)
+      enc.getPointerBitWidth(), enc.getIndexBitWidth());
+}
+
+StorageSpecifierType
+StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
+  return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
+}
+
 IntegerType StorageSpecifierType::getSizesType() const {
   unsigned idxBitWidth =
       getEncoding().getIndexBitWidth() ? getEncoding().getIndexBitWidth() : 64u;

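A hedged sketch, not part of this patch, of the normalization in action: the two encodings below differ only in the ordered/unique bits of the compressed level, yet they yield the same specifier type. The helper name is hypothetical.

```c++
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include <cassert>
using namespace mlir;
using namespace mlir::sparse_tensor;

// Hypothetical check, for illustration only.
static void normalizationSketch(MLIRContext *ctx) {
  auto encA = SparseTensorEncodingAttr::get(
      ctx, {DimLevelType::Dense, DimLevelType::Compressed},
      AffineMap(), AffineMap(), /*pointerBitWidth=*/64, /*indexBitWidth=*/64);
  auto encB = SparseTensorEncodingAttr::get(
      ctx, {DimLevelType::Dense, DimLevelType::CompressedNuNo},
      AffineMap(), AffineMap(), /*pointerBitWidth=*/64, /*indexBitWidth=*/64);
  // Both compressed variants normalize to the ordered/unique form, and the
  // bit widths match, hence the same storage specifier type.
  assert(StorageSpecifierType::get(ctx, encA) ==
         StorageSpecifierType::get(ctx, encB));
}
```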
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index cfa73bf246ecc..410bf343b8fc1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -3,10 +3,12 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   CodegenEnv.cpp
   CodegenUtils.cpp
   SparseBufferRewriting.cpp
+  SparseStorageSpecifierToLLVM.cpp
   SparseTensorCodegen.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
   SparseTensorRewriting.cpp
+  SparseTensorStorageLayout.cpp
   SparseVectorization.cpp
   Sparsification.cpp
   SparsificationAndBufferizationPass.cpp

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
new file mode 100644
index 0000000000000..d6a9007baad0c
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
@@ -0,0 +1,184 @@
+//===- SparseStorageSpecifierToLLVM.cpp - convert specifier to llvm -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodegenUtils.h"
+#include "SparseTensorStorageLayout.h"
+
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace sparse_tensor;
+
+static SmallVector<Type, 2> getSpecifierFields(StorageSpecifierType tp) {
+  MLIRContext *ctx = tp.getContext();
+  auto enc = tp.getEncoding();
+  unsigned rank = enc.getDimLevelType().size();
+
+  SmallVector<Type, 2> result;
+  auto indexType = tp.getSizesType();
+  auto dimSizes = LLVM::LLVMArrayType::get(ctx, indexType, rank);
+  auto memSizes = LLVM::LLVMArrayType::get(ctx, indexType,
+                                           getNumDataFieldsFromEncoding(enc));
+  result.push_back(dimSizes);
+  result.push_back(memSizes);
+  return result;
+}
+
+static Type convertSpecifier(StorageSpecifierType tp) {
+  return LLVM::LLVMStructType::getLiteral(tp.getContext(),
+                                          getSpecifierFields(tp));
+}
+
+StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() {
+  addConversion([](Type type) { return type; });
+  addConversion([](StorageSpecifierType tp) { return convertSpecifier(tp); });
+}
+
+constexpr uint64_t kDimSizePosInSpecifier = 0;
+constexpr uint64_t kMemSizePosInSpecifier = 1;
+
+class SpecifierStructBuilder : public StructBuilder {
+public:
+  explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) {
+    assert(value);
+  }
+
+  // Undef value for dimension sizes, all zero value for memory sizes.
+  static Value getInitValue(OpBuilder &builder, Location loc, Type structType);
+
+  Value dimSize(OpBuilder &builder, Location loc, unsigned dim);
+  void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value size);
+
+  Value memSize(OpBuilder &builder, Location loc, unsigned pos);
+  void setMemSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
+};
+
+Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
+                                           Type structType) {
+  Value metaData = builder.create<LLVM::UndefOp>(loc, structType);
+  SpecifierStructBuilder md(metaData);
+  auto memSizeArrayType = structType.cast<LLVM::LLVMStructType>()
+                              .getBody()[kMemSizePosInSpecifier]
+                              .cast<LLVM::LLVMArrayType>();
+
+  Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
+  // Fill memSizes array with zero.
+  for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
+    md.setMemSize(builder, loc, i, zero);
+
+  return md;
+}
+
+/// Builds IR extracting the dim-th dimension size from the descriptor.
+Value SpecifierStructBuilder::dimSize(OpBuilder &builder, Location loc,
+                                      unsigned dim) {
+  return builder.create<LLVM::ExtractValueOp>(
+      loc, value, ArrayRef<int64_t>({kDimSizePosInSpecifier, dim}));
+}
+
+/// Builds IR inserting the dim-th dimension size into the descriptor.
+void SpecifierStructBuilder::setDimSize(OpBuilder &builder, Location loc,
+                                        unsigned dim, Value size) {
+  value = builder.create<LLVM::InsertValueOp>(
+      loc, value, size, ArrayRef<int64_t>({kDimSizePosInSpecifier, dim}));
+}
+
+/// Builds IR extracting the pos-th memory size from the descriptor.
+Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
+                                      unsigned pos) {
+  return builder.create<LLVM::ExtractValueOp>(
+      loc, value, ArrayRef<int64_t>({kMemSizePosInSpecifier, pos}));
+}
+
+/// Builds IR inserting the pos-th memory size into the descriptor.
+void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
+                                        unsigned pos, Value size) {
+  value = builder.create<LLVM::InsertValueOp>(
+      loc, value, size, ArrayRef<int64_t>({kMemSizePosInSpecifier, pos}));
+}
+
+template <typename Base, typename SourceOp>
+class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
+public:
+  using OpAdaptor = typename SourceOp::Adaptor;
+  using OpConversionPattern<SourceOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SpecifierStructBuilder spec(adaptor.getSpecifier());
+    Value v;
+    if (op.getSpecifierKind() == StorageSpecifierKind::DimSize) {
+      v = Base::onDimSize(rewriter, op, spec,
+                          op.getDim().value().getZExtValue());
+    } else {
+      auto enc = op.getSpecifier().getType().getEncoding();
+      builder::StorageLayout layout(enc);
+      Optional<unsigned> dim = std::nullopt;
+      if (op.getDim())
+        dim = op.getDim().value().getZExtValue();
+      unsigned idx = layout.getMemRefFieldIndex(op.getSpecifierKind(), dim);
+      v = Base::onMemSize(rewriter, op, spec, idx);
+    }
+
+    rewriter.replaceOp(op, v);
+    return success();
+  }
+};
+
+struct StorageSpecifierSetOpConverter
+    : public SpecifierGetterSetterOpConverter<StorageSpecifierSetOpConverter,
+                                              SetStorageSpecifierOp> {
+  using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
+  static Value onDimSize(OpBuilder &builder, SetStorageSpecifierOp op,
+                         SpecifierStructBuilder &spec, unsigned d) {
+    spec.setDimSize(builder, op.getLoc(), d, op.getValue());
+    return spec;
+  }
+
+  static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op,
+                         SpecifierStructBuilder &spec, unsigned i) {
+    spec.setMemSize(builder, op.getLoc(), i, op.getValue());
+    return spec;
+  }
+};
+
+struct StorageSpecifierGetOpConverter
+    : public SpecifierGetterSetterOpConverter<StorageSpecifierGetOpConverter,
+                                              GetStorageSpecifierOp> {
+  using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
+  static Value onDimSize(OpBuilder &builder, GetStorageSpecifierOp op,
+                         SpecifierStructBuilder &spec, unsigned d) {
+    return spec.dimSize(builder, op.getLoc(), d);
+  }
+  static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op,
+                         SpecifierStructBuilder &spec, unsigned i) {
+    return spec.memSize(builder, op.getLoc(), i);
+  }
+};
+
+struct StorageSpecifierInitOpConverter
+    : public OpConversionPattern<StorageSpecifierInitOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
+    rewriter.replaceOp(op, SpecifierStructBuilder::getInitValue(
+                               rewriter, op.getLoc(), llvmType));
+    return success();
+  }
+};
+
+void mlir::populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
+                                                  RewritePatternSet &patterns) {
+  patterns.add<StorageSpecifierGetOpConverter, StorageSpecifierSetOpConverter,
+               StorageSpecifierInitOpConverter>(converter,
+                                                patterns.getContext());
+}

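A hedged sketch, not part of this patch, showing what the converter above produces for a CSR specifier. With 64-bit widths this matches the !llvm.struct<(array<2 x i64>, array<3 x i64>)> seen in the new test: two dimension sizes plus memory sizes for the pointer, index, and value buffers. The helper name is hypothetical.

```c++
using namespace mlir;
using namespace mlir::sparse_tensor;

// Hypothetical helper, for illustration only.
static Type lowerCsrSpecifierSketch(MLIRContext *ctx,
                                    SparseTensorEncodingAttr csrEnc) {
  StorageSpecifierToLLVMTypeConverter converter;
  // For (dense, compressed): dimSizes has rank (2) entries, and memSizes
  // has getNumDataFieldsFromEncoding (3) entries.
  return converter.convertType(StorageSpecifierType::get(ctx, csrEnc));
}
```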
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 6845aa0d81877..c384d6ad966c9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -28,6 +28,7 @@ namespace mlir {
 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
 #define GEN_PASS_DEF_SPARSEBUFFERREWRITE
 #define GEN_PASS_DEF_SPARSEVECTORIZATION
+#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 } // namespace mlir
 
@@ -193,9 +194,14 @@ struct SparseTensorCodegenPass
     target.addLegalOp<SortOp>();
     target.addLegalOp<SortCooOp>();
     target.addLegalOp<PushBackOp>();
-    // All dynamic rules below accept new function, call, return, and various
-    // tensor and bufferization operations as legal output of the rewriting
-    // provided that all sparse tensor types have been fully rewritten.
+    // Storage specifier outlives sparse tensor pipeline.
+    target.addLegalOp<GetStorageSpecifierOp>();
+    target.addLegalOp<SetStorageSpecifierOp>();
+    target.addLegalOp<StorageSpecifierInitOp>();
+    // All dynamic rules below accept new function, call, return, and
+    // various tensor and bufferization operations as legal output of the
+    // rewriting provided that all sparse tensor types have been fully
+    // rewritten.
     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
       return converter.isSignatureLegal(op.getFunctionType());
     });
@@ -271,6 +277,44 @@ struct SparseVectorizationPass
   }
 };
 
+struct StorageSpecifierToLLVMPass
+    : public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
+
+  StorageSpecifierToLLVMPass() = default;
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    ConversionTarget target(*ctx);
+    RewritePatternSet patterns(ctx);
+    StorageSpecifierToLLVMTypeConverter converter;
+
+    // All ops in the sparse dialect must go!
+    target.addIllegalDialect<SparseTensorDialect>();
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return converter.isSignatureLegal(op.getFunctionType());
+    });
+    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
+      return converter.isSignatureLegal(op.getCalleeType());
+    });
+    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
+      return converter.isLegal(op.getOperandTypes());
+    });
+    target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();
+
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                   converter);
+    populateCallOpTypeConversionPattern(patterns, converter);
+    populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
+    populateReturnOpTypeConversionPattern(patterns, converter);
+    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                         target);
+    populateStorageSpecifierToLLVMPatterns(converter, patterns);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -355,3 +399,7 @@ mlir::createSparseVectorizationPass(unsigned vectorLength,
   return std::make_unique<SparseVectorizationPass>(
       vectorLength, enableVLAVectorization, enableSIMDIndex32);
 }
+
+std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
+  return std::make_unique<StorageSpecifierToLLVMPass>();
+}

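A hedged sketch, not part of this patch, of driving the new pass from C++ (the new test exercises the same lowering via -sparse-storage-specifier-to-llvm). The helper name is hypothetical; the flag values mirror the defaults used elsewhere in this patch.

```c++
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
using namespace mlir;

// Hypothetical pipeline, for illustration only.
static LogicalResult runSpecifierPipelineSketch(ModuleOp module) {
  PassManager pm(module.getContext());
  pm.addPass(createSparseTensorCodegenPass(/*enableBufferInitialization=*/false));
  pm.addPass(createSparseBufferRewritePass(/*enableBufferInitialization=*/false));
  pm.addPass(createStorageSpecifierToLLVMPass());
  return pm.run(module);
}
```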
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
new file mode 100644
index 0000000000000..4dcd0345ed9db
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
@@ -0,0 +1,188 @@
+//===- SparseTensorStorageLayout.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "SparseTensorStorageLayout.h"
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace sparse_tensor;
+
+static Value createIndexCast(OpBuilder &builder, Location loc, Value value,
+                             Type to) {
+  if (value.getType() != to)
+    return builder.create<arith::IndexCastOp>(loc, to, value);
+  return value;
+}
+
+static IntegerAttr fromOptionalInt(MLIRContext *ctx, Optional<unsigned> dim) {
+  if (!dim)
+    return nullptr;
+  return IntegerAttr::get(IndexType::get(ctx), dim.value());
+}
+
+unsigned
+builder::StorageLayout::getMemRefFieldIndex(SparseTensorFieldKind kind,
+                                            Optional<unsigned> dim) const {
+  unsigned fieldIdx = -1u;
+  foreachFieldInSparseTensor(
+      enc,
+      [dim, kind, &fieldIdx](unsigned fIdx, SparseTensorFieldKind fKind,
+                             unsigned fDim, DimLevelType dlt) -> bool {
+        if ((dim && fDim == dim.value() && kind == fKind) ||
+            (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
+          fieldIdx = fIdx;
+          // Returns false to break the iteration.
+          return false;
+        }
+        return true;
+      });
+  assert(fieldIdx != -1u);
+  return fieldIdx;
+}
+
+unsigned
+builder::StorageLayout::getMemRefFieldIndex(StorageSpecifierKind kind,
+                                            Optional<unsigned> dim) const {
+  return getMemRefFieldIndex(toFieldKind(kind), dim);
+}
+
+Value builder::SparseTensorSpecifier::getInitValue(OpBuilder &builder,
+                                                   Location loc,
+                                                   RankedTensorType rtp) {
+  return builder.create<StorageSpecifierInitOp>(
+      loc, StorageSpecifierType::get(getSparseTensorEncoding(rtp)));
+}
+
+Value builder::SparseTensorSpecifier::getSpecifierField(
+    OpBuilder &builder, Location loc, StorageSpecifierKind kind,
+    Optional<unsigned> dim) {
+  return createIndexCast(builder, loc,
+                         builder.create<GetStorageSpecifierOp>(
+                             loc, getFieldType(kind, dim), specifier, kind,
+                             fromOptionalInt(specifier.getContext(), dim)),
+                         builder.getIndexType());
+}
+
+void builder::SparseTensorSpecifier::setSpecifierField(
+    OpBuilder &builder, Location loc, Value v, StorageSpecifierKind kind,
+    Optional<unsigned> dim) {
+  specifier = builder.create<SetStorageSpecifierOp>(
+      loc, specifier, kind, fromOptionalInt(specifier.getContext(), dim),
+      createIndexCast(builder, loc, v, getFieldType(kind, dim)));
+}
+
+constexpr uint64_t kDataFieldStartingIdx = 0;
+
+void sparse_tensor::builder::foreachFieldInSparseTensor(
+    const SparseTensorEncodingAttr enc,
+    llvm::function_ref<bool(unsigned, SparseTensorFieldKind, unsigned,
+                            DimLevelType)>
+        callback) {
+  assert(enc);
+
+#define RETURN_ON_FALSE(idx, kind, dim, dlt)                                   \
+  if (!(callback(idx, kind, dim, dlt)))                                        \
+    return;
+
+  static_assert(kDataFieldStartingIdx == 0);
+  unsigned fieldIdx = kDataFieldStartingIdx;
+  // Per-dimension storage.
+  for (unsigned r = 0, rank = enc.getDimLevelType().size(); r < rank; r++) {
+    // Dimension level types apply in order to the reordered dimension.
+    // As a result, the compound type can be constructed directly in the given
+    // order.
+    auto dlt = getDimLevelType(enc, r);
+    if (isCompressedDLT(dlt)) {
+      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt);
+      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
+    } else if (isSingletonDLT(dlt)) {
+      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
+    } else {
+      assert(isDenseDLT(dlt)); // no fields
+    }
+  }
+  // The values array.
+  RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u,
+                  DimLevelType::Undef);
+
+  // Put metadata at the end.
+  RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::StorageSpec, -1u,
+                  DimLevelType::Undef);
+
+#undef RETURN_ON_FALSE
+}
+
+void sparse_tensor::builder::foreachFieldAndTypeInSparseTensor(
+    RankedTensorType rType,
+    llvm::function_ref<bool(Type, unsigned, SparseTensorFieldKind, unsigned,
+                            DimLevelType)>
+        callback) {
+  auto enc = getSparseTensorEncoding(rType);
+  assert(enc);
+  // Construct the basic types.
+  Type idxType = enc.getIndexType();
+  Type ptrType = enc.getPointerType();
+  Type eltType = rType.getElementType();
+
+  Type metaDataType = StorageSpecifierType::get(enc);
+  // memref<? x ptr>  pointers
+  Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType);
+  // memref<? x idx>  indices
+  Type idxMemType = MemRefType::get({ShapedType::kDynamic}, idxType);
+  // memref<? x eltType> values
+  Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
+
+  foreachFieldInSparseTensor(
+      enc,
+      [metaDataType, ptrMemType, idxMemType, valMemType,
+       callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind,
+                 unsigned dim, DimLevelType dlt) -> bool {
+        switch (fieldKind) {
+        case SparseTensorFieldKind::StorageSpec:
+          return callback(metaDataType, fieldIdx, fieldKind, dim, dlt);
+        case SparseTensorFieldKind::PtrMemRef:
+          return callback(ptrMemType, fieldIdx, fieldKind, dim, dlt);
+        case SparseTensorFieldKind::IdxMemRef:
+          return callback(idxMemType, fieldIdx, fieldKind, dim, dlt);
+        case SparseTensorFieldKind::ValMemRef:
+          return callback(valMemType, fieldIdx, fieldKind, dim, dlt);
+        };
+        llvm_unreachable("unrecognized field kind");
+      });
+}
+
+unsigned
+sparse_tensor::builder::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
+  unsigned numFields = 0;
+  foreachFieldInSparseTensor(enc,
+                             [&numFields](unsigned, SparseTensorFieldKind,
+                                          unsigned, DimLevelType) -> bool {
+                               numFields++;
+                               return true;
+                             });
+  return numFields;
+}
+
+unsigned sparse_tensor::builder::getNumDataFieldsFromEncoding(
+    SparseTensorEncodingAttr enc) {
+  unsigned numFields = 0; // one value memref
+  foreachFieldInSparseTensor(enc,
+                             [&numFields](unsigned fidx, SparseTensorFieldKind,
+                                          unsigned, DimLevelType) -> bool {
+                               if (fidx >= kDataFieldStartingIdx)
+                                 numFields++;
+                               return true;
+                             });
+  numFields -= 1; // the last field is the metadata field
+  assert(numFields ==
+         builder::getNumFieldsFromEncoding(enc) - kDataFieldStartingIdx - 1);
+  return numFields;
+}

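A hedged sketch, not part of this patch, of the enumeration order produced by foreachFieldInSparseTensor for a CSR encoding (dense, compressed): the pointer and index memrefs of the compressed level, then the values memref, then the trailing specifier, i.e. four fields of which three are data fields. The helper name is hypothetical.

```c++
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::sparse_tensor;

// Hypothetical dump, for illustration only.
static void dumpFieldOrderSketch(SparseTensorEncodingAttr csrEnc) {
  builder::foreachFieldInSparseTensor(
      csrEnc,
      [](unsigned fieldIdx, builder::SparseTensorFieldKind kind, unsigned dim,
         DimLevelType dlt) -> bool {
        llvm::errs() << "field #" << fieldIdx << " kind="
                     << static_cast<uint32_t>(kind) << " dim=" << dim << "\n";
        return true; // keep iterating over all fields
      });
}
```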
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
new file mode 100644
index 0000000000000..9b4e2352b8f34
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -0,0 +1,361 @@
+//===- SparseTensorStorageLayout.h ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines utilities for lowering and accessing sparse tensor
+// types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_
+#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_
+
+#include "mlir/Conversion/LLVMCommon/StructBuilder.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace sparse_tensor {
+
+// FIXME: this is a tmp namespace
+namespace builder {
+//===----------------------------------------------------------------------===//
+// SparseTensorDescriptor and helpers that manage the sparse tensor memory
+// layout scheme.
+//
+// The sparse tensor storage scheme for a rank-dimensional tensor is organized
+// as a single compound type with the following fields. Note that every
+// memref with ? size actually behaves as a "vector", i.e. the stored
+// size is the capacity and the used size resides in the memSizes array.
+//
+// struct {
+//   ; per-dimension d:
+//   ;  if dense:
+//        <nothing>
+//   ;  if compressed:
+//        memref<? x ptr>  pointers-d  ; pointers for sparse dim d
+//        memref<? x idx>  indices-d   ; indices for sparse dim d
+//   ;  if singleton:
+//        memref<? x idx>  indices-d   ; indices for singleton dim d
+//   memref<? x eltType> values        ; values
+//
+//   ; sparse tensor metadata
+//   struct {
+//     array<rank x int> dimSizes    ; sizes for each dimension
+//     array<n x int> memSizes;      ; sizes for each data memref
+//   }
+// };
+//
+//===----------------------------------------------------------------------===//
+enum class SparseTensorFieldKind : uint32_t {
+  StorageSpec = 0,
+  PtrMemRef = 1,
+  IdxMemRef = 2,
+  ValMemRef = 3
+};
+
+static_assert(static_cast<uint32_t>(SparseTensorFieldKind::PtrMemRef) ==
+              static_cast<uint32_t>(StorageSpecifierKind::PtrMemSize));
+static_assert(static_cast<uint32_t>(SparseTensorFieldKind::IdxMemRef) ==
+              static_cast<uint32_t>(StorageSpecifierKind::IdxMemSize));
+static_assert(static_cast<uint32_t>(SparseTensorFieldKind::ValMemRef) ==
+              static_cast<uint32_t>(StorageSpecifierKind::ValMemSize));
+
+/// For each field that will be allocated for the given sparse tensor
+/// encoding, calls the callback with the corresponding field index, field
+/// kind, dimension (for sparse tensor level memrefs), and dimension level
+/// type. The field index always starts at zero and increments by one
+/// between two callback invocations.
+/// Ideally, all other methods should rely on this function to query sparse
+/// tensor fields instead of relying on ad-hoc index computations.
+void foreachFieldInSparseTensor(
+    SparseTensorEncodingAttr,
+    llvm::function_ref<bool(unsigned /*fieldIdx*/,
+                            SparseTensorFieldKind /*fieldKind*/,
+                            unsigned /*dim (if applicable)*/,
+                            DimLevelType /*DLT (if applicable)*/)>);
+
+/// Same as above, except that it also builds the Type for the corresponding
+/// field.
+void foreachFieldAndTypeInSparseTensor(
+    RankedTensorType,
+    llvm::function_ref<bool(Type /*fieldType*/, unsigned /*fieldIdx*/,
+                            SparseTensorFieldKind /*fieldKind*/,
+                            unsigned /*dim (if applicable)*/,
+                            DimLevelType /*DLT (if applicable)*/)>);
+
+/// Gets the total number of fields for the given sparse tensor encoding.
+unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc);
+
+/// Gets the total number of data fields (index arrays, pointer arrays, and a
+/// value array) for the given sparse tensor encoding.
+unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc);
+
+inline StorageSpecifierKind toSpecifierKind(SparseTensorFieldKind kind) {
+  assert(kind != SparseTensorFieldKind::StorageSpec);
+  return static_cast<StorageSpecifierKind>(kind);
+}
+
+inline SparseTensorFieldKind toFieldKind(StorageSpecifierKind kind) {
+  assert(kind != StorageSpecifierKind::DimSize);
+  return static_cast<SparseTensorFieldKind>(kind);
+}
+
+class StorageLayout {
+public:
+  explicit StorageLayout(SparseTensorEncodingAttr enc) : enc(enc) {}
+
+  ///
+  /// Getters: get the field index for required field.
+  ///
+  unsigned getMemRefFieldIndex(SparseTensorFieldKind kind,
+                               Optional<unsigned> dim) const;
+
+  unsigned getMemRefFieldIndex(StorageSpecifierKind kind,
+                               Optional<unsigned> dim) const;
+
+private:
+  unsigned getFieldIndex(unsigned dim, SparseTensorFieldKind kind) const;
+  SparseTensorEncodingAttr enc;
+};
+
+class SparseTensorSpecifier {
+public:
+  explicit SparseTensorSpecifier(Value specifier) : specifier(specifier) {}
+
+  // Undef value for dimension sizes, all zero value for memory sizes.
+  static Value getInitValue(OpBuilder &builder, Location loc,
+                            RankedTensorType rtp);
+
+  /*implicit*/ operator Value() { return specifier; }
+
+  Value getSpecifierField(OpBuilder &builder, Location loc,
+                          StorageSpecifierKind kind, Optional<unsigned> dim);
+
+  void setSpecifierField(OpBuilder &builder, Location loc, Value v,
+                         StorageSpecifierKind kind, Optional<unsigned> dim);
+
+  Type getFieldType(StorageSpecifierKind kind, Optional<unsigned> dim) {
+    return specifier.getType().getFieldType(kind, dim);
+  }
+
+private:
+  TypedValue<StorageSpecifierType> specifier;
+};
+
+/// A helper class around an array of values that correspond to a sparse
+/// tensor; it provides a set of meaningful APIs to query and update a
+/// particular field in a consistent way. Users should not make assumptions
+/// about how a sparse tensor is laid out but instead rely on this class to
+/// access the right value for the right field.
+template <bool mut>
+class SparseTensorDescriptorImpl {
+private:
+  // Uses ValueRange for immutable descriptors; uses SmallVectorImpl<Value> &
+  // for mutable descriptors.
+  // Using SmallVector for the mutable descriptor allows users to reuse it as
+  // a tmp buffer to append values in some special cases, though users are
+  // responsible for restoring the buffer to a legal state after such use. It
+  // is probably not the cleanest way, but it is the most efficient way to
+  // avoid copying the fields into another SmallVector. If a cleaner way is
+  // wanted, we should change it to MutableArrayRef instead.
+  using ValueArrayRef = typename std::conditional<mut, SmallVectorImpl<Value> &,
+                                                  ValueRange>::type;
+
+public:
+  SparseTensorDescriptorImpl(Type tp, ValueArrayRef fields)
+      : rType(tp.cast<RankedTensorType>()), fields(fields) {
+    assert(getSparseTensorEncoding(tp) &&
+           builder::getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) ==
+               fields.size());
+    // Make sure the class is trivially copyable (and small enough) so that
+    // we can pass it by value.
+    static_assert(
+        std::is_trivially_copyable_v<SparseTensorDescriptorImpl<mut>>);
+  }
+
+  // Implicit (and cheap) type conversion from MutSparseTensorDescriptor to
+  // SparseTensorDescriptor.
+  template <typename T = SparseTensorDescriptorImpl<true>>
+  /*implicit*/ SparseTensorDescriptorImpl(std::enable_if_t<!mut, T> &mDesc)
+      : rType(mDesc.getTensorType()), fields(mDesc.getFields()) {}
+
+  unsigned getMemRefFieldIndex(SparseTensorFieldKind kind,
+                               Optional<unsigned> dim) const {
+    // Delegates to storage layout.
+    StorageLayout layout(getSparseTensorEncoding(rType));
+    return layout.getMemRefFieldIndex(kind, dim);
+  }
+
+  unsigned getPtrMemRefIndex(unsigned ptrDim) const {
+    return getMemRefFieldIndex(SparseTensorFieldKind::PtrMemRef, ptrDim);
+  }
+
+  unsigned getIdxMemRefIndex(unsigned idxDim) const {
+    return getMemRefFieldIndex(SparseTensorFieldKind::IdxMemRef, idxDim);
+  }
+
+  unsigned getValMemRefIndex() const {
+    return getMemRefFieldIndex(SparseTensorFieldKind::ValMemRef, std::nullopt);
+  }
+
+  unsigned getNumFields() const { return fields.size(); }
+
+  ///
+  /// Getters: get the value for required field.
+  ///
+
+  Value getSpecifierField(OpBuilder &builder, Location loc,
+                          StorageSpecifierKind kind,
+                          Optional<unsigned> dim) const {
+    SparseTensorSpecifier md(fields.back());
+    return md.getSpecifierField(builder, loc, kind, dim);
+  }
+
+  Value getDimSize(OpBuilder &builder, Location loc, unsigned dim) const {
+    return getSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim);
+  }
+
+  Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
+    return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize,
+                             dim);
+  }
+
+  Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
+    return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize,
+                             dim);
+  }
+
+  Value getValMemSize(OpBuilder &builder, Location loc) const {
+    return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
+                             std::nullopt);
+  }
+
+  Value getPtrMemRef(unsigned ptrDim) const {
+    return getMemRefField(SparseTensorFieldKind::PtrMemRef, ptrDim);
+  }
+
+  Value getIdxMemRef(unsigned idxDim) const {
+    return getMemRefField(SparseTensorFieldKind::IdxMemRef, idxDim);
+  }
+
+  Value getValMemRef() const {
+    return getMemRefField(SparseTensorFieldKind::ValMemRef, std::nullopt);
+  }
+
+  Value getMemRefField(SparseTensorFieldKind kind,
+                       Optional<unsigned> dim) const {
+    return fields[getMemRefFieldIndex(kind, dim)];
+  }
+
+  Value getMemRefField(unsigned fidx) const {
+    assert(fidx < fields.size() - 1);
+    return fields[fidx];
+  }
+
+  Value getField(unsigned fidx) const {
+    assert(fidx < fields.size());
+    return fields[fidx];
+  }
+
+  ///
+  /// Setters: update the value for required field (only enabled for
+  /// MutSparseTensorDescriptor).
+  ///
+
+  template <typename T = Value>
+  void setMemRefField(SparseTensorFieldKind kind, Optional<unsigned> dim,
+                      std::enable_if_t<mut, T> v) {
+    fields[getMemRefFieldIndex(kind, dim)] = v;
+  }
+
+  template <typename T = Value>
+  void setMemRefField(unsigned fidx, std::enable_if_t<mut, T> v) {
+    assert(fidx < fields.size() - 1);
+    fields[fidx] = v;
+  }
+
+  template <typename T = Value>
+  void setField(unsigned fidx, std::enable_if_t<mut, T> v) {
+    assert(fidx < fields.size());
+    fields[fidx] = v;
+  }
+
+  template <typename T = Value>
+  void setSpecifierField(OpBuilder &builder, Location loc,
+                         StorageSpecifierKind kind, Optional<unsigned> dim,
+                         std::enable_if_t<mut, T> v) {
+    SparseTensorSpecifier md(fields.back());
+    md.setSpecifierField(builder, loc, v, kind, dim);
+    fields.back() = md;
+  }
+
+  template <typename T = Value>
+  void setDimSize(OpBuilder &builder, Location loc, unsigned dim,
+                  std::enable_if_t<mut, T> v) {
+    setSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim, v);
+  }
+
+  ValueRange getMemRefFields() const {
+    ValueRange ret = fields;
+    // Drop the last metadata field.
+    return ret.slice(0, fields.size() - 1);
+  }
+
+  Type getMemRefElementType(SparseTensorFieldKind kind,
+                            Optional<unsigned> dim) const {
+    return getMemRefField(kind, dim)
+        .getType()
+        .template cast<MemRefType>()
+        .getElementType();
+  }
+
+  RankedTensorType getTensorType() const { return rType; }
+  ValueArrayRef getFields() const { return fields; }
+
+private:
+  RankedTensorType rType;
+  ValueArrayRef fields;
+};
+
+using SparseTensorDescriptor = SparseTensorDescriptorImpl<false>;
+using MutSparseTensorDescriptor = SparseTensorDescriptorImpl<true>;
+
+/// Returns the "tuple" value of the adapted tensor.
+inline UnrealizedConversionCastOp getTuple(Value tensor) {
+  return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
+}
+
+/// Packs the given values as a "tuple" value.
+inline Value genTuple(OpBuilder &builder, Location loc, Type tp,
+                      ValueRange values) {
+  return builder.create<UnrealizedConversionCastOp>(loc, TypeRange(tp), values)
+      .getResult(0);
+}
+
+inline Value genTuple(OpBuilder &builder, Location loc,
+                      SparseTensorDescriptor desc) {
+  return genTuple(builder, loc, desc.getTensorType(), desc.getFields());
+}
+
+inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
+  auto tuple = getTuple(tensor);
+  return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs());
+}
+
+inline MutSparseTensorDescriptor
+getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
+  auto tuple = getTuple(tensor);
+  fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
+  return MutSparseTensorDescriptor(tuple.getResultTypes()[0], fields);
+}
+
+} // namespace builder
+} // namespace sparse_tensor
+} // namespace mlir
+#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_

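Finally, a hedged sketch, not part of this patch, of typical descriptor usage inside a conversion pattern; the input value is assumed to be a "tuple" packed by genTuple above, and the helper name is hypothetical.

```c++
using namespace mlir;
using namespace mlir::sparse_tensor;

// Hypothetical helper, for illustration only.
static Value loadSparseValuesSketch(OpBuilder &b, Location loc, Value tensor) {
  // Assumes tensor is the result of genTuple(...) for a sparse tensor type.
  auto desc = builder::getDescriptorFromTensorTuple(tensor);
  Value inUseSize = desc.getValMemSize(b, loc); // in-use size, not capacity
  (void)inUseSize;
  return desc.getValMemRef(); // the underlying values memref
}
```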
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 8d01a322222a7..e199d90903cdc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -147,6 +147,7 @@ class SparsificationAndBufferizationPass
       } else {
         pm.addPass(createSparseTensorCodegenPass(enableBufferInitialization));
         pm.addPass(createSparseBufferRewritePass(enableBufferInitialization));
+        pm.addPass(createStorageSpecifierToLLVMPass());
       }
       if (failed(runPipeline(pm, getOperation())))
         return signalPassFailure();

diff --git a/mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir b/mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir
new file mode 100644
index 0000000000000..ecdaf3bf9c964
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -sparse-storage-specifier-to-llvm --cse --canonicalize | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+
+// CHECK-LABEL:   func.func @sparse_metadata_init() -> !llvm.struct<(array<2 x i64>, array<3 x i64>)> {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 0 : i64
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.undef : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK:           %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_1]][1, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK:           %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][1, 1] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK:           %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1, 2] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK:           return %[[VAL_4]] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK:         }
+func.func @sparse_metadata_init() -> !sparse_tensor.storage_specifier<#CSR> {
+  %0 = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier<#CSR>
+  return %0 : !sparse_tensor.storage_specifier<#CSR>
+}
+
+// CHECK-LABEL:   func.func @sparse_get_md(
+// CHECK-SAME:      %[[VAL_0:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>) -> i64 {
+// CHECK:           %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK:           return %[[VAL_1]] : i64
+func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#CSR>) -> i64 {
+  %0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 0
+       : !sparse_tensor.storage_specifier<#CSR> to i64
+  return %0 : i64
+}
+
+// CHECK-LABEL:   func.func @sparse_set_md(
+// CHECK-SAME:      %[[VAL_0:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>,
+// CHECK-SAME:      %[[VAL_1:.*]]: i64) -> !llvm.struct<(array<2 x i64>, array<3 x i64>)> {
+// CHECK:           %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK:           return %[[VAL_2]] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+func.func @sparse_set_md(%arg0: !sparse_tensor.storage_specifier<#CSR>, %arg1: i64)
+          -> !sparse_tensor.storage_specifier<#CSR> {
+  %0 = sparse_tensor.storage_specifier.set %arg0 dim_sz at 0 with %arg1
+       : i64, !sparse_tensor.storage_specifier<#CSR>
+  return %0 : !sparse_tensor.storage_specifier<#CSR>
+}
