[Mlir-commits] [mlir] 44ff23d - [mlir][sparse] unconditionally use IndexType for sparse_tensor.specifier

Peiming Liu llvmlistbot at llvm.org
Wed Feb 22 12:21:39 PST 2023


Author: Peiming Liu
Date: 2023-02-22T20:21:34Z
New Revision: 44ff23d5e49058bcaa170f71540398b4a290a642

URL: https://github.com/llvm/llvm-project/commit/44ff23d5e49058bcaa170f71540398b4a290a642
DIFF: https://github.com/llvm/llvm-project/commit/44ff23d5e49058bcaa170f71540398b4a290a642.diff

LOG: [mlir][sparse] unconditionally use IndexType for sparse_tensor.specifier

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D144574
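
In IR terms, the getter's result and the setter's value operand are now
unconditionally `index`, so the explicit value type disappears from the
assembly format. A minimal before/after sketch (reusing the `#COO`
specifier from the op documentation below):

```mlir
// Before: the field type was spelled out and could be any integer type.
%0 = sparse_tensor.storage_specifier.get %arg0 idx_mem_sz at 0
     : !sparse_tensor.storage_specifier<#COO> to i64

// After: the result is always index, so no value type is printed.
%1 = sparse_tensor.storage_specifier.get %arg0 idx_mem_sz at 0
     : !sparse_tensor.storage_specifier<#COO>
```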

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
    mlir/test/Dialect/SparseTensor/codegen.mlir
    mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
    mlir/test/Dialect/SparseTensor/fold.mlir
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir
    mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
    mlir/test/Dialect/SparseTensor/sparse_pack.mlir
    mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index b719d29be3c4d..4d06cb0b088d0 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -365,7 +365,7 @@ def SparseTensor_GetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.get"
     Arguments<(ins SparseTensorStorageSpecifier:$specifier,
                    SparseTensorStorageSpecifierKindAttr:$specifierKind,
                    OptionalAttr<IndexAttr>:$dim)>,
-    Results<(outs AnyType:$result)> {
+    Results<(outs Index:$result)> {
   let summary = "";
   let description = [{
     Returns the requested field of the given storage_specifier.
@@ -374,12 +374,12 @@ def SparseTensor_GetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.get"
 
     ```mlir
     %0 = sparse_tensor.storage_specifier.get %arg0 idx_mem_sz at 0
-         : !sparse_tensor.storage_specifier<#COO> to i64
+         : !sparse_tensor.storage_specifier<#COO>
     ```
   }];
 
   let assemblyFormat = "$specifier $specifierKind (`at` $dim^)? attr-dict `:` "
-                       "qualified(type($specifier)) `to` type($result)";
+                       "qualified(type($specifier))";
   let hasVerifier = 1;
   let hasFolder = 1;
 }
@@ -389,7 +389,7 @@ def SparseTensor_SetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.set"
     Arguments<(ins SparseTensorStorageSpecifier:$specifier,
                    SparseTensorStorageSpecifierKindAttr:$specifierKind,
                    OptionalAttr<IndexAttr>:$dim,
-                   AnyType:$value)>,
+                   Index:$value)>,
     Results<(outs SparseTensorStorageSpecifier:$result)> {
   let summary = "";
   let description = [{
@@ -400,12 +400,12 @@ def SparseTensor_SetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.set"
 
     ```mlir
     %0 = sparse_tensor.storage_specifier.set %arg0 idx_mem_sz at 0 with %new_sz
-       : i32, !sparse_tensor.storage_specifier<#COO>
+       : !sparse_tensor.storage_specifier<#COO>
 
     ```
   }];
   let assemblyFormat = "$specifier $specifierKind (`at` $dim^)? `with` $value attr-dict `:` "
-                       "type($value) `,` qualified(type($result))";
+                       "qualified(type($result))";
   let hasVerifier = 1;
 }
 

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index fc4c5d870d62a..3ae40d625c8a2 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -65,13 +65,6 @@ def SparseTensor_StorageSpecifier : SparseTensor_Type<"StorageSpecifier"> {
     }]>
   ];
 
-  let extraClassDeclaration = [{
-    // Get the integer type used to store memory and dimension sizes.
-    IntegerType getSizesType() const;
-    Type getFieldType(StorageSpecifierKind kind, std::optional<unsigned> dim) const;
-    Type getFieldType(StorageSpecifierKind kind, std::optional<APInt> dim) const;
-  }];
-
   // We skipped the default builder that simply takes the input sparse tensor encoding
   // attribute since we need to normalize the dimension level type and remove unrelated
  // fields that are irrelevant to the sparse tensor storage scheme.

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 207834bacf7e6..9279ec7dddca9 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -571,7 +571,11 @@ getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
       enc.getContext(), dlts,
      AffineMap(), // dimOrdering (irrelevant to storage specifier)
      AffineMap(), // highLvlOrdering (irrelevant to storage specifier)
-      enc.getPointerBitWidth(), enc.getIndexBitWidth(),
+      // Always use index for memSize, dimSize instead of reusing
+      // getBitwidth from pointers/indices.
+      // It allows us to reuse the same SSA value for different bitwidths,
+      // and it avoids casting between index/integer (returned by DimOp).
+      0, 0,
      // FIXME: we should keep the slice information; for now it is okay as
      // only constants can be used for slices.
       ArrayRef<SparseTensorDimSliceAttr>{} /*enc.getDimSlices()*/);
@@ -582,36 +586,6 @@ StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
   return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
 }
 
-IntegerType StorageSpecifierType::getSizesType() const {
-  unsigned idxBitWidth =
-      getEncoding().getIndexBitWidth() ? getEncoding().getIndexBitWidth() : 64u;
-  unsigned ptrBitWidth =
-      getEncoding().getIndexBitWidth() ? getEncoding().getIndexBitWidth() : 64u;
-
-  return IntegerType::get(getContext(), std::max(idxBitWidth, ptrBitWidth));
-}
-
-// FIXME: see note [CLARIFY_DIM_LVL] in
-// "lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h"
-Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind,
-                                        std::optional<unsigned> dim) const {
-  if (kind != StorageSpecifierKind::ValMemSize)
-    assert(dim);
-
-  // Right now, we store every sizes metadata using the same size type.
-  // TODO: the field size type can be defined dimensional wise after sparse
-  // tensor encoding supports per dimension index/pointer bitwidth.
-  return getSizesType();
-}
-
-// FIXME: see note [CLARIFY_DIM_LVL] in
-// "lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h"
-Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind,
-                                        std::optional<APInt> dim) const {
-  return getFieldType(kind, dim ? std::optional(dim.value().getZExtValue())
-                                : std::nullopt);
-}
-
 //===----------------------------------------------------------------------===//
 // SparseTensorDialect Operations.
 //===----------------------------------------------------------------------===//
@@ -776,12 +750,6 @@ LogicalResult ToSliceStrideOp::verify() {
 LogicalResult GetStorageSpecifierOp::verify() {
   RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
       getSpecifierKind(), getDim(), getSpecifier(), getOperation()))
-  // Checks the result type
-  if (getSpecifier().getType().getFieldType(getSpecifierKind(), getDim()) !=
-      getResult().getType()) {
-    return emitError(
-        "type mismatch between requested specifier field and result value");
-  }
   return success();
 }
 
@@ -802,12 +770,6 @@ OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
 LogicalResult SetStorageSpecifierOp::verify() {
   RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
       getSpecifierKind(), getDim(), getSpecifier(), getOperation()))
-  // Checks the input type
-  if (getSpecifier().getType().getFieldType(getSpecifierKind(), getDim()) !=
-      getValue().getType()) {
-    return emitError(
-        "type mismatch between requested specifier field and input value");
-  }
   return success();
 }
 

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 809e9712c752a..cb4eda192d9a9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -204,6 +204,39 @@ StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
 // Misc code generators.
 //===----------------------------------------------------------------------===//
 
+Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
+                             Type dstTy) {
+  Type srcTy = value.getType();
+  if (srcTy != dstTy) {
+    // int <=> index
+    if (dstTy.isa<IndexType>() || srcTy.isa<IndexType>())
+      return builder.create<arith::IndexCastOp>(loc, dstTy, value);
+
+    bool ext = srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
+
+    // float => float.
+    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && ext)
+      return builder.create<arith::ExtFOp>(loc, dstTy, value);
+
+    if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !ext)
+      return builder.create<arith::TruncFOp>(loc, dstTy, value);
+
+    // int => int
+    if (srcTy.isUnsignedInteger() && dstTy.isa<IntegerType>() && ext)
+      return builder.create<arith::ExtUIOp>(loc, dstTy, value);
+
+    if (srcTy.isSignedInteger() && dstTy.isa<IntegerType>() && ext)
+      return builder.create<arith::ExtSIOp>(loc, dstTy, value);
+
+    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !ext)
+      return builder.create<arith::TruncIOp>(loc, dstTy, value);
+
+    llvm_unreachable("unhandled type casting");
+  }
+
+  return value;
+}
+
 mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
   if (tp.isa<FloatType>())
     return builder.getFloatAttr(tp, 1.0);
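
The `genCast` helper above only emits standard `arith` casts; as a sketch,
the IR it produces for a few hypothetical source/destination pairs (`%i`,
`%f`, `%j` are made-up values):

```mlir
// int <=> index, in either direction.
%0 = arith.index_cast %i : i64 to index
// float widening / narrowing.
%1 = arith.extf %f : f16 to f32
// integer narrowing; widening uses arith.extui (unsigned) or
// arith.extsi (signed) instead.
%2 = arith.trunci %j : i64 to i32
```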

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 2624d5c826f4a..de78010e57f22 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -78,6 +78,9 @@ StringRef primaryTypeFunctionSuffix(Type elemTp);
 // Misc code generators and utilities.
 //===----------------------------------------------------------------------===//
 
+/// Adds type casting between index/integer/float types when needed.
+Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy);
+
 /// Generates a 1-valued attribute of the given type.  This supports
 /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
 /// for unsupported types we raise `llvm_unreachable` rather than

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
index d634a1d1f5377..caa4dd5d722f6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
@@ -27,7 +27,9 @@ static SmallVector<Type, 2> getSpecifierFields(StorageSpecifierType tp) {
   const Level lvlRank = enc.getLvlRank();
 
   SmallVector<Type, 2> result;
-  auto indexType = tp.getSizesType();
+  // TODO: how can we keep the lowering of the index type consistent with
+  // the later pipeline? LLVM::LLVMStructType does not allow index fields.
+  auto indexType = IntegerType::get(tp.getContext(), 64);
   auto dimSizes = LLVM::LLVMArrayType::get(ctx, indexType, lvlRank);
   auto memSizes = LLVM::LLVMArrayType::get(ctx, indexType,
                                            getNumDataFieldsFromEncoding(enc));
@@ -49,6 +51,21 @@ constexpr uint64_t kDimSizePosInSpecifier = 0;
 constexpr uint64_t kMemSizePosInSpecifier = 1;
 
 class SpecifierStructBuilder : public StructBuilder {
+private:
+  Value extractField(OpBuilder &builder, Location loc,
+                     ArrayRef<int64_t> indices) {
+    return genCast(builder, loc,
+                   builder.create<LLVM::ExtractValueOp>(loc, value, indices),
+                   builder.getIndexType());
+  }
+
+  void insertField(OpBuilder &builder, Location loc, ArrayRef<int64_t> indices,
+                   Value v) {
+    value = builder.create<LLVM::InsertValueOp>(
+        loc, value, genCast(builder, loc, v, builder.getIntegerType(64)),
+        indices);
+  }
+
 public:
   explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) {
     assert(value);
@@ -83,29 +100,30 @@ Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
 /// Builds IR extracting the pos-th size from the descriptor.
 Value SpecifierStructBuilder::dimSize(OpBuilder &builder, Location loc,
                                       unsigned dim) {
-  return builder.create<LLVM::ExtractValueOp>(
-      loc, value, ArrayRef<int64_t>({kDimSizePosInSpecifier, dim}));
+  return extractField(builder, loc,
+                      ArrayRef<int64_t>{kDimSizePosInSpecifier, dim});
 }
 
 /// Builds IR inserting the pos-th size into the descriptor.
 void SpecifierStructBuilder::setDimSize(OpBuilder &builder, Location loc,
                                         unsigned dim, Value size) {
-  value = builder.create<LLVM::InsertValueOp>(
-      loc, value, size, ArrayRef<int64_t>({kDimSizePosInSpecifier, dim}));
+
+  insertField(builder, loc, ArrayRef<int64_t>{kDimSizePosInSpecifier, dim},
+              size);
 }
 
 /// Builds IR extracting the pos-th memory size from the descriptor.
 Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
                                       unsigned pos) {
-  return builder.create<LLVM::ExtractValueOp>(
-      loc, value, ArrayRef<int64_t>({kMemSizePosInSpecifier, pos}));
+  return extractField(builder, loc,
+                      ArrayRef<int64_t>{kMemSizePosInSpecifier, pos});
 }
 
 /// Builds IR inserting the pos-th memory size into the descriptor.
 void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
                                         unsigned pos, Value size) {
-  value = builder.create<LLVM::InsertValueOp>(
-      loc, value, size, ArrayRef<int64_t>({kMemSizePosInSpecifier, pos}));
+  insertField(builder, loc, ArrayRef<int64_t>{kMemSizePosInSpecifier, pos},
+              size);
 }
 
 } // namespace
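
The net effect of `extractField`/`insertField` is that every size is stored
as an `i64` inside the LLVM struct but exposed as `index`. A sketch of the
lowered IR for one getter/setter pair, assuming a specifier struct of type
`!llvm.struct<(array<1 x i64>, array<3 x i64>)>` and hypothetical values
`%spec` and `%v`:

```mlir
// get: pull the i64 field out of the struct, then cast it back to index.
%0 = llvm.extractvalue %spec[1, 0] : !llvm.struct<(array<1 x i64>, array<3 x i64>)>
%1 = arith.index_cast %0 : i64 to index

// set: cast the index value to i64 before inserting it.
%2 = arith.index_cast %v : index to i64
%3 = llvm.insertvalue %2, %spec[1, 0] : !llvm.struct<(array<1 x i64>, array<3 x i64>)>
```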

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 335f743e2db3d..fc1ea386b4223 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -67,25 +67,18 @@ static void flattenOperands(ValueRange operands,
   }
 }
 
-/// Adds index conversions where needed.
-static Value toType(OpBuilder &builder, Location loc, Value value, Type tp) {
-  if (value.getType() != tp)
-    return builder.create<arith::IndexCastOp>(loc, tp, value);
-  return value;
-}
-
 /// Generates a load with proper index typing.
 static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
-  idx = toType(builder, loc, idx, builder.getIndexType());
+  idx = genCast(builder, loc, idx, builder.getIndexType());
   return builder.create<memref::LoadOp>(loc, mem, idx);
 }
 
 /// Generates a store with proper index typing and (for indices) proper value.
 static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
                      Value idx) {
-  idx = toType(builder, loc, idx, builder.getIndexType());
-  val = toType(builder, loc, val,
-               mem.getType().cast<ShapedType>().getElementType());
+  idx = genCast(builder, loc, idx, builder.getIndexType());
+  val = genCast(builder, loc, val,
+                mem.getType().cast<ShapedType>().getElementType());
   builder.create<memref::StoreOp>(loc, val, mem, idx);
 }
 
@@ -141,7 +134,7 @@ static void createPushback(OpBuilder &builder, Location loc,
 
   auto pushBackOp = builder.create<PushBackOp>(
       loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field,
-      toType(builder, loc, value, etp), repeat);
+      genCast(builder, loc, value, etp), repeat);
 
   desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
   desc.setSpecifierField(builder, loc, specFieldKind, lvl,
@@ -338,7 +331,7 @@ static Value genCompressed(OpBuilder &builder, Location loc,
     msz = builder.create<arith::DivUIOp>(loc, msz, idxStrideC);
   }
   Value phim1 = builder.create<arith::SubIOp>(
-      loc, toType(builder, loc, phi, indexType), one);
+      loc, genCast(builder, loc, phi, indexType), one);
   // Conditional expression.
   Value lt =
       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, plo, phi);
@@ -350,9 +343,9 @@ static Value genCompressed(OpBuilder &builder, Location loc,
       builder, loc, desc.getMemRefField(idxIndex),
       idxStride > 1 ? builder.create<arith::MulIOp>(loc, phim1, idxStrideC)
                     : phim1);
-  Value eq = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
-                                           toType(builder, loc, crd, indexType),
-                                           indices[lvl]);
+  Value eq = builder.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
+      indices[lvl]);
   builder.create<scf::YieldOp>(loc, eq);
   builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
   if (lvl > 0)
@@ -1226,8 +1219,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
     // Converts MemRefs back to Tensors.
     Value data = rewriter.create<bufferization::ToTensorOp>(loc, dataBuf);
     Value indices = rewriter.create<bufferization::ToTensorOp>(loc, idxBuf);
-    Value nnz = toType(rewriter, loc, desc.getValMemSize(rewriter, loc),
-                       op.getNnz().getType());
+    Value nnz = genCast(rewriter, loc, desc.getValMemSize(rewriter, loc),
+                        op.getNnz().getType());
 
     rewriter.replaceOp(op, {data, indices, nnz});
     return success();
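
As a small usage illustration of the rewritten helpers: a `genStore` of an
`i64` value into a hypothetical `memref<?xi32>` at an `i64` index boils down
to two `genCast`s plus the store:

```mlir
%i = arith.index_cast %idx : i64 to index
%v = arith.trunci %val : i64 to i32
memref.store %v, %mem[%i] : memref<?xi32>
```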

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
index be59ba83f0f4b..b3336b5fc6ae2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
@@ -22,13 +22,6 @@ using namespace sparse_tensor;
 // Private helper methods.
 //===----------------------------------------------------------------------===//
 
-static Value createIndexCast(OpBuilder &builder, Location loc, Value value,
-                             Type to) {
-  if (value.getType() != to)
-    return builder.create<arith::IndexCastOp>(loc, to, value);
-  return value;
-}
-
 static IntegerAttr fromOptionalInt(MLIRContext *ctx,
                                    std::optional<unsigned> dim) {
   if (!dim)
@@ -90,20 +83,17 @@ Value SparseTensorSpecifier::getInitValue(OpBuilder &builder, Location loc,
 Value SparseTensorSpecifier::getSpecifierField(OpBuilder &builder, Location loc,
                                                StorageSpecifierKind kind,
                                                std::optional<unsigned> dim) {
-  return createIndexCast(builder, loc,
-                         builder.create<GetStorageSpecifierOp>(
-                             loc, getFieldType(kind, dim), specifier, kind,
-                             fromOptionalInt(specifier.getContext(), dim)),
-                         builder.getIndexType());
+  return builder.create<GetStorageSpecifierOp>(
+      loc, specifier, kind, fromOptionalInt(specifier.getContext(), dim));
 }
 
 void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc,
                                               Value v,
                                               StorageSpecifierKind kind,
                                               std::optional<unsigned> dim) {
+  assert(v.getType().isIndex());
   specifier = builder.create<SetStorageSpecifierOp>(
-      loc, specifier, kind, fromOptionalInt(specifier.getContext(), dim),
-      createIndexCast(builder, loc, v, getFieldType(kind, dim)));
+      loc, specifier, kind, fromOptionalInt(specifier.getContext(), dim), v);
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
index 40207d756425e..c30a15d87baca 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -241,11 +241,6 @@ class SparseTensorSpecifier {
                          StorageSpecifierKind kind,
                          std::optional<unsigned> dim);
 
-  // FIXME: see note [CLARIFY_DIM_LVL].
-  Type getFieldType(StorageSpecifierKind kind, std::optional<unsigned> dim) {
-    return specifier.getType().getFieldType(kind, dim);
-  }
-
 private:
   TypedValue<StorageSpecifierType> specifier;
 };
@@ -283,6 +278,8 @@ class SparseTensorDescriptorImpl {
   /// Getters: get the value for required field.
   ///
 
+  Value getSpecifier() const { return fields.back(); }
+
   // FIXME: see note [CLARIFY_DIM_LVL].
   Value getSpecifierField(OpBuilder &builder, Location loc,
                           StorageSpecifierKind kind,

diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 081ad5b2cf1e1..4b1beb51713e7 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -190,8 +190,7 @@ func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
 //  CHECK-SAME: %[[A0:.*]]: memref<?xf64>,
 //  CHECK-SAME: %[[A1:.*]]: !sparse_tensor.storage_specifier
 //       CHECK: %[[A2:.*]] = sparse_tensor.storage_specifier.get %[[A1]] dim_sz at 2
-//       CHECK: %[[A3:.*]] = arith.index_cast %[[A2]] : i64 to index
-//       CHECK: return %[[A3]] : index
+//       CHECK: return %[[A2]] : index
 func.func @sparse_dense_3d_dyn(%arg0: tensor<?x?x?xf64, #Dense3D>) -> index {
   %c = arith.constant 1 : index
   %0 = tensor.dim %arg0, %c : tensor<?x?x?xf64, #Dense3D>
@@ -260,8 +259,7 @@ func.func @sparse_values_coo(%arg0: tensor<?x?x?xf64, #ccoo>) -> memref<?xf64> {
 //  CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
 //       CHECK: %[[C2:.*]] = arith.constant 2 : index
 //       CHECK: %[[S0:.*]] = sparse_tensor.storage_specifier.get %[[A5]]  idx_mem_sz at 1
-//       CHECK: %[[S1:.*]] = arith.index_cast %[[S0]]
-//       CHECK: %[[S2:.*]] = arith.divui %[[S1]], %[[C2]] : index
+//       CHECK: %[[S2:.*]] = arith.divui %[[S0]], %[[C2]] : index
 //       CHECK: %[[R1:.*]] = memref.subview %[[A3]][0] {{\[}}%[[S2]]] [2] : memref<?xindex> to memref<?xindex, strided<[2]>>
 //       CHECK: %[[R2:.*]] = memref.cast %[[R1]] : memref<?xindex, strided<[2]>> to memref<?xindex, strided<[?], offset: ?>>
 //       CHECK: return %[[R2]] : memref<?xindex, strided<[?], offset: ?>>
@@ -288,8 +286,7 @@ func.func @sparse_indices_buffer_coo(%arg0: tensor<?x?x?xf64, #ccoo>) -> memref<
 //  CHECK-SAME: %[[A1:.*]]: memref<?xi64>,
 //  CHECK-SAME: %[[A2:.*]]: memref<?xf64>,
 //  CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier
-//       CHECK: %[[A4:.*]] = sparse_tensor.storage_specifier.get %[[A3]] val_mem_sz
-//       CHECK: %[[NOE:.*]] = arith.index_cast %[[A4]] : i64 to index
+//       CHECK: %[[NOE:.*]] = sparse_tensor.storage_specifier.get %[[A3]] val_mem_sz
 //       CHECK: return %[[NOE]] : index
 func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index {
   %0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector>
@@ -312,8 +309,8 @@ func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
 
 // CHECK-LABEL:   func.func @sparse_alloc_csc(
 //  CHECK-SAME:     %[[A0:.*]]: index) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
-//       CHECK:     %[[A1:.*]] = arith.constant 10 : i64
-//       CHECK:     %[[A2:.*]] = arith.constant 0 : index
+//   CHECK-DAG:     %[[A1:.*]] = arith.constant 10 : index
+//   CHECK-DAG:     %[[A2:.*]] = arith.constant 0 : index
 //       CHECK:     %[[A3:.*]] = memref.alloc() : memref<16xindex>
 //       CHECK:     %[[A4:.*]] = memref.cast %[[A3]] : memref<16xindex> to memref<?xindex>
 //       CHECK:     %[[A5:.*]] = memref.alloc() : memref<16xindex>
@@ -321,17 +318,13 @@ func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
 //       CHECK:     %[[A7:.*]] = memref.alloc() : memref<16xf64>
 //       CHECK:     %[[A8:.*]] = memref.cast %[[A7]] : memref<16xf64> to memref<?xf64>
 //       CHECK:     %[[A9:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier
-//       CHECK:     %[[A10:.*]] = arith.index_cast %[[A0]] : index to i64
-//       CHECK:     %[[A11:.*]] = sparse_tensor.storage_specifier.set %[[A9]]  dim_sz at 0 with %[[A10]] : i64, !sparse_tensor.storage_specifier
-//       CHECK:     %[[A12:.*]] = sparse_tensor.storage_specifier.set %[[A11]]  dim_sz at 1 with %[[A1]] : i64, !sparse_tensor.storage_specifier
-//       CHECK:     %[[A13:.*]] = sparse_tensor.storage_specifier.get %[[A12]]  ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
-//       CHECK:     %[[A14:.*]] = arith.index_cast %[[A13]] : i64 to index
-//       CHECK:     %[[A15:.*]], %[[A16:.*]] = sparse_tensor.push_back %[[A14]], %[[A4]], %[[A2]] : index, memref<?xindex>, index
-//       CHECK:     %[[A17:.*]] = arith.index_cast %[[A16]] : index to i64
-//       CHECK:     %[[A18:.*]] = sparse_tensor.storage_specifier.set %[[A12]]  ptr_mem_sz at 1 with %[[A17]] : i64, !sparse_tensor.storage_specifier
-//       CHECK:     %[[A23:.*]], %[[A24:.*]] = sparse_tensor.push_back %[[A16]], %[[A15]], %[[A2]], %[[A0]] : index, memref<?xindex>, index, index
-//       CHECK:     %[[A25:.*]] = arith.index_cast %[[A24]] : index to i64
-//       CHECK:     %[[A26:.*]] = sparse_tensor.storage_specifier.set %[[A18]]  ptr_mem_sz at 1 with %[[A25]] : i64, !sparse_tensor.storage_specifier
+//       CHECK:     %[[A11:.*]] = sparse_tensor.storage_specifier.set %[[A9]]  dim_sz at 0 with %[[A0]] : !sparse_tensor.storage_specifier
+//       CHECK:     %[[A12:.*]] = sparse_tensor.storage_specifier.set %[[A11]]  dim_sz at 1 with %[[A1]] : !sparse_tensor.storage_specifier
+//       CHECK:     %[[A14:.*]] = sparse_tensor.storage_specifier.get %[[A12]]  ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
+//       CHECK:     %[[A15:.*]], %[[A17:.*]] = sparse_tensor.push_back %[[A14]], %[[A4]], %[[A2]] : index, memref<?xindex>, index
+//       CHECK:     %[[A18:.*]] = sparse_tensor.storage_specifier.set %[[A12]]  ptr_mem_sz at 1 with %[[A17]] : !sparse_tensor.storage_specifier
+//       CHECK:     %[[A23:.*]], %[[A25:.*]] = sparse_tensor.push_back %[[A17]], %[[A15]], %[[A2]], %[[A0]] : index, memref<?xindex>, index, index
+//       CHECK:     %[[A26:.*]] = sparse_tensor.storage_specifier.set %[[A18]]  ptr_mem_sz at 1 with %[[A25]] : !sparse_tensor.storage_specifier
 //       CHECK:     return %[[A23]], %[[A6]], %[[A8]], %[[A26]] : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
   %0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC>
@@ -340,23 +333,21 @@ func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
 }
 
 // CHECK-LABEL:   func.func @sparse_alloc_3d() -> (memref<?xf64>, !sparse_tensor.storage_specifier
-//       CHECK:     %[[A0:.*]] = arith.constant 6000 : index
-//       CHECK:     %[[A1:.*]] = arith.constant 20 : i64
-//       CHECK:     %[[A2:.*]] = arith.constant 10 : i64
-//       CHECK:     %[[A3:.*]] = arith.constant 30 : i64
-//       CHECK:     %[[A4:.*]] = arith.constant 0.000000e+00 : f64
+//   CHECK-DAG:     %[[A0:.*]] = arith.constant 6000 : index
+//   CHECK-DAG:     %[[A1:.*]] = arith.constant 20 : index
+//   CHECK-DAG:     %[[A2:.*]] = arith.constant 10 : index
+//   CHECK-DAG:     %[[A3:.*]] = arith.constant 30 : index
+//   CHECK-DAG:     %[[A4:.*]] = arith.constant 0.000000e+00 : f64
 //       CHECK:     %[[A5:.*]] = memref.alloc() : memref<6000xf64>
 //       CHECK:     %[[A6:.*]] = memref.cast %[[A5]] : memref<6000xf64> to memref<?xf64>
 //       CHECK:     %[[A7:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier
-//       CHECK:     %[[A8:.*]] = sparse_tensor.storage_specifier.set %[[A7]]  dim_sz at 0 with %[[A3]] : i64, !sparse_tensor.storage_specifier
-//       CHECK:     %[[A9:.*]] = sparse_tensor.storage_specifier.set %[[A8]]  dim_sz at 1 with %[[A2]] : i64, !sparse_tensor.storage_specifier
-//       CHECK:     %[[A10:.*]] = sparse_tensor.storage_specifier.set %[[A9]]  dim_sz at 2 with %[[A1]] : i64, !sparse_tensor.storage_specifier
-//       CHECK:     %[[A11:.*]] = sparse_tensor.storage_specifier.get %[[A10]]  val_mem_sz : !sparse_tensor.storage_specifier
-//       CHECK:     %[[A12:.*]] = arith.index_cast %[[A11]] : i64 to index
-//       CHECK:     %[[A13:.*]], %[[A14:.*]] = sparse_tensor.push_back %[[A12]], %[[A6]], %[[A4]], %[[A0]] : index, memref<?xf64>, f64, index
-//       CHECK:     %[[A15:.*]] = arith.index_cast %[[A14]] : index to i64
-//       CHECK:     %[[A16:.*]] = sparse_tensor.storage_specifier.set %[[A10]]  val_mem_sz with %[[A15]] : i64, !sparse_tensor.storage_specifier
-//       CHECK:     return %[[A13]], %[[A16]] : memref<?xf64>, !sparse_tensor.storage_specifier
+//       CHECK:     %[[A8:.*]] = sparse_tensor.storage_specifier.set %[[A7]]  dim_sz at 0 with %[[A3]] : !sparse_tensor.storage_specifier
+//       CHECK:     %[[A9:.*]] = sparse_tensor.storage_specifier.set %[[A8]]  dim_sz at 1 with %[[A2]] : !sparse_tensor.storage_specifier
+//       CHECK:     %[[A10:.*]] = sparse_tensor.storage_specifier.set %[[A9]]  dim_sz at 2 with %[[A1]] : !sparse_tensor.storage_specifier
+//       CHECK:     %[[A12:.*]] = sparse_tensor.storage_specifier.get %[[A10]]  val_mem_sz : !sparse_tensor.storage_specifier
+//       CHECK:     %[[A15:.*]], %[[A14:.*]] = sparse_tensor.push_back %[[A12]], %[[A6]], %[[A4]], %[[A0]] : index, memref<?xf64>, f64, index
+//       CHECK:     %[[A16:.*]] = sparse_tensor.storage_specifier.set %[[A10]]  val_mem_sz with %[[A14]] : !sparse_tensor.storage_specifier
+//       CHECK:     return %[[A15]], %[[A16]] : memref<?xf64>, !sparse_tensor.storage_specifier
 func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> {
   %0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D>
   %1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D>
@@ -503,8 +494,7 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
 //       CHECK:     memref.dealloc %[[A4]] : memref<?xf64>
 //       CHECK:     memref.dealloc %[[A5]] : memref<?xi1>
 //       CHECK:     memref.dealloc %[[A6]] : memref<?xindex>
-//       CHECK:     %[[A23:.*]] = sparse_tensor.storage_specifier.get %[[A24:.*]]#3  ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
-//       CHECK:     %[[A25:.*]] = arith.index_cast %[[A23]] : i64 to index
+//       CHECK:     %[[A25:.*]] = sparse_tensor.storage_specifier.get %[[A24:.*]]#3  ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
 //       CHECK:     %[[A26:.*]] = memref.load %[[A24]]#0{{\[}}%[[A13]]] : memref<?xi32>
 //       CHECK:     %[[A27:.*]] = scf.for %[[A28:.*]] = %[[A12]] to %[[A25]] step %[[A12]] iter_args(%[[A29:.*]] = %[[A26]]) -> (i32) {
 //       CHECK:       %[[A30:.*]] = memref.load %[[A24]]#0{{\[}}%[[A28]]] : memref<?xi32>
@@ -562,8 +552,7 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
 //       CHECK:     memref.dealloc %[[A4]] : memref<?xf64>
 //       CHECK:     memref.dealloc %[[A5]] : memref<?xi1>
 //       CHECK:     memref.dealloc %[[A6]] : memref<?xindex>
-//       CHECK:     %[[A22:.*]] = sparse_tensor.storage_specifier.get %[[A23:.*]]#3  ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
-//       CHECK:     %[[A24:.*]] = arith.index_cast %[[A22]] : i64 to index
+//       CHECK:     %[[A24:.*]] = sparse_tensor.storage_specifier.get %[[A23:.*]]#3  ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
 //       CHECK:     %[[A25:.*]] = memref.load %[[A23]]#0{{\[}}%[[A11]]] : memref<?xindex>
 //       CHECK:     %[[A26:.*]] = scf.for %[[A27:.*]] = %[[A12]] to %[[A24]] step %[[A12]] iter_args(%[[A28:.*]] = %[[A25]]) -> (index) {
 //       CHECK:       %[[A29:.*]] = memref.load %[[A23]]#0{{\[}}%[[A27]]] : memref<?xindex>

diff --git a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
index 33bbe6a71ad07..7a0d668082d52 100644
--- a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
@@ -17,16 +17,12 @@
 //       CHECK:     %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<16xf64> to memref<?xf64>
 //       CHECK:     linalg.fill ins(%[[VAL_2]] : f64) outs(%[[VAL_8]] : memref<16xf64>)
 //       CHECK:     %[[VAL_10:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier
-//       CHECK:     %[[VAL_11:.*]] = arith.index_cast %[[VAL_0]] : index to i64
-//       CHECK:     %[[VAL_12:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]]  dim_sz at 0 with %[[VAL_11]] : i64, !sparse_tensor.storage_specifier
-//       CHECK:     %[[VAL_13:.*]] = sparse_tensor.storage_specifier.get %[[VAL_12]]  ptr_mem_sz at 0 : !sparse_tensor.storage_specifier
-//       CHECK:     %[[VAL_14:.*]] = arith.index_cast %[[VAL_13]] : i64 to index
-//       CHECK:     %[[VAL_15:.*]], %[[VAL_16:.*]] = sparse_tensor.push_back %[[VAL_14]], %[[VAL_5]], %[[VAL_3]] : index, memref<?xindex>, index
-//       CHECK:     %[[VAL_17:.*]] = arith.index_cast %[[VAL_16]] : index to i64
-//       CHECK:     %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_12]]  ptr_mem_sz at 0 with %[[VAL_17]] : i64, !sparse_tensor.storage_specifier
-//       CHECK:     %[[VAL_19:.*]], %[[VAL_20:.*]] = sparse_tensor.push_back %[[VAL_16]], %[[VAL_15]], %[[VAL_3]], %[[VAL_1]] : index, memref<?xindex>, index, index
-//       CHECK:     %[[VAL_21:.*]] = arith.index_cast %[[VAL_20]] : index to i64
-//       CHECK:     %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]]  ptr_mem_sz at 0 with %[[VAL_21]] : i64, !sparse_tensor.storage_specifier
+//       CHECK:     %[[VAL_12:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]]  dim_sz at 0 with %[[VAL_0]] : !sparse_tensor.storage_specifier
+//       CHECK:     %[[VAL_14:.*]] = sparse_tensor.storage_specifier.get %[[VAL_12]]  ptr_mem_sz at 0 : !sparse_tensor.storage_specifier
+//       CHECK:     %[[VAL_15:.*]], %[[VAL_17:.*]] = sparse_tensor.push_back %[[VAL_14]], %[[VAL_5]], %[[VAL_3]] : index, memref<?xindex>, index
+//       CHECK:     %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_12]]  ptr_mem_sz at 0 with %[[VAL_17]] : !sparse_tensor.storage_specifier
+//       CHECK:     %[[VAL_19:.*]], %[[VAL_21:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_15]], %[[VAL_3]], %[[VAL_1]] : index, memref<?xindex>, index, index
+//       CHECK:     %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]]  ptr_mem_sz at 0 with %[[VAL_21]] : !sparse_tensor.storage_specifier
 //       CHECK:     return %[[VAL_19]], %[[VAL_7]], %[[VAL_9]], %[[VAL_22]] : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 func.func @sparse_alloc_sparse_vector(%arg0: index) -> tensor<?xf64, #SV> {
   %0 = bufferization.alloc_tensor(%arg0) : tensor<?xf64, #SV>

diff --git a/mlir/test/Dialect/SparseTensor/fold.mlir b/mlir/test/Dialect/SparseTensor/fold.mlir
index 7397c0b22958a..3e559109189fe 100644
--- a/mlir/test/Dialect/SparseTensor/fold.mlir
+++ b/mlir/test/Dialect/SparseTensor/fold.mlir
@@ -48,17 +48,17 @@ func.func @sparse_concat_dce(%arg0: tensor<2xf64, #SparseVector>,
 
 // CHECK-LABEL: func @sparse_get_specifier_dce_fold(
 //  CHECK-SAME:  %[[A0:.*]]: !sparse_tensor.storage_specifier
-//  CHECK-SAME:  %[[A1:.*]]: i64,
-//  CHECK-SAME:  %[[A2:.*]]: i64)
+//  CHECK-SAME:  %[[A1:.*]]: index,
+//  CHECK-SAME:  %[[A2:.*]]: index)
 //   CHECK-NOT:  sparse_tensor.storage_specifier.set
 //   CHECK-NOT:  sparse_tensor.storage_specifier.get
 //       CHECK:  return %[[A1]]
-func.func @sparse_get_specifier_dce_fold(%arg0: !sparse_tensor.storage_specifier<#SparseVector>, %arg1: i64, %arg2: i64) -> i64 {
+func.func @sparse_get_specifier_dce_fold(%arg0: !sparse_tensor.storage_specifier<#SparseVector>, %arg1: index, %arg2: index) -> index {
   %0 = sparse_tensor.storage_specifier.set %arg0 dim_sz at 0 with %arg1
-       : i64, !sparse_tensor.storage_specifier<#SparseVector>
+       : !sparse_tensor.storage_specifier<#SparseVector>
   %1 = sparse_tensor.storage_specifier.set %0 ptr_mem_sz at 0 with %arg2
-       : i64, !sparse_tensor.storage_specifier<#SparseVector>
+       : !sparse_tensor.storage_specifier<#SparseVector>
   %2 = sparse_tensor.storage_specifier.get %1 dim_sz at 0
-       : !sparse_tensor.storage_specifier<#SparseVector> to i64
-  return %2 : i64
+       : !sparse_tensor.storage_specifier<#SparseVector>
+  return %2 : index
 }

diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 8f52d20942f9d..8d0a0e7d69869 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -252,68 +252,44 @@ func.func @sparse_slice_stride(%arg0: tensor<2x8xf64, #CSR_SLICE>) -> index {
 
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
-func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 {
+func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> index {
   // expected-error@+1 {{redundant level argument for querying value memory size}}
   %0 = sparse_tensor.storage_specifier.get %arg0 val_mem_sz at 0
-       : !sparse_tensor.storage_specifier<#SparseVector> to i64
-  return %0 : i64
+       : !sparse_tensor.storage_specifier<#SparseVector>
+  return %0 : index
 }
 
 // -----
 
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
-func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 {
+func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> index {
   // expected-error@+1 {{missing level argument}}
   %0 = sparse_tensor.storage_specifier.get %arg0 idx_mem_sz
-       : !sparse_tensor.storage_specifier<#SparseVector> to i64
-  return %0 : i64
+       : !sparse_tensor.storage_specifier<#SparseVector>
+  return %0 : index
 }
 
 // -----
 
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
-func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 {
+func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> index {
   // expected-error@+1 {{requested level out of bound}}
   %0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 1
-       : !sparse_tensor.storage_specifier<#SparseVector> to i64
-  return %0 : i64
+       : !sparse_tensor.storage_specifier<#SparseVector>
+  return %0 : index
 }
 
 // -----
 
 #COO = #sparse_tensor.encoding<{dimLevelType = ["compressed-nu", "singleton"]}>
 
-func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#COO>) -> i64 {
+func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#COO>) -> index {
   // expected-error@+1 {{requested pointer memory size on a singleton level}}
   %0 = sparse_tensor.storage_specifier.get %arg0 ptr_mem_sz at 1
-       : !sparse_tensor.storage_specifier<#COO> to i64
-  return %0 : i64
-}
-
-// -----
-
-#COO = #sparse_tensor.encoding<{dimLevelType = ["compressed-nu", "singleton"]}>
-
-func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#COO>) -> i64 {
-  // expected-error@+1 {{type mismatch between requested }}
-  %0 = sparse_tensor.storage_specifier.get %arg0 ptr_mem_sz at 0
-       : !sparse_tensor.storage_specifier<#COO> to i32
-  return %0 : i32
-}
-
-// -----
-
-#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
-
-func.func @sparse_set_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>,
-                         %arg1: i32)
-          -> !sparse_tensor.storage_specifier<#SparseVector> {
-  // expected-error@+1 {{type mismatch between requested }}
-  %0 = sparse_tensor.storage_specifier.set %arg0 dim_sz at 0 with %arg1
-       : i32, !sparse_tensor.storage_specifier<#SparseVector>
-  return %0 : !sparse_tensor.storage_specifier<#SparseVector>
+       : !sparse_tensor.storage_specifier<#COO>
+  return %0 : index
 }
 
 // -----

diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 70c9a9862d533..608b6c80e24f3 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -184,11 +184,11 @@ func.func @sparse_metadata_init() -> !sparse_tensor.storage_specifier<#SparseVec
 // CHECK-LABEL: func @sparse_get_md(
 //  CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>
 //       CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.get %[[A]] dim_sz at 0
-//       CHECK: return %[[T]] : i64
-func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 {
+//       CHECK: return %[[T]] : index
+func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> index {
   %0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 0
-       : !sparse_tensor.storage_specifier<#SparseVector> to i64
-  return %0 : i64
+       : !sparse_tensor.storage_specifier<#SparseVector>
+  return %0 : index
 }
 
 // -----
@@ -197,13 +197,13 @@ func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>)
 
 // CHECK-LABEL: func @sparse_set_md(
 //  CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>,
-//  CHECK-SAME: %[[I:.*]]: i64)
+//  CHECK-SAME: %[[I:.*]]: index)
 //       CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.set %[[A]] dim_sz at 0 with %[[I]]
 //       CHECK: return %[[T]] : !sparse_tensor.storage_specifier<#{{.*}}>
-func.func @sparse_set_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>, %arg1: i64)
+func.func @sparse_set_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>, %arg1: index)
           -> !sparse_tensor.storage_specifier<#SparseVector> {
   %0 = sparse_tensor.storage_specifier.set %arg0 dim_sz at 0 with %arg1
-       : i64, !sparse_tensor.storage_specifier<#SparseVector>
+       : !sparse_tensor.storage_specifier<#SparseVector>
   return %0 : !sparse_tensor.storage_specifier<#SparseVector>
 }
 

diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
index 70a5fa1338ad9..aaa7b581675d8 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -25,8 +25,7 @@
 // CHECK:           %[[VAL_9:.*]] = arith.addi %[[VAL_4]], %[[VAL_8]] : index
 // CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]  idx_mem_sz at 1 : !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : i64 to index
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]  idx_mem_sz at 1 : !sparse_tensor.storage_specifier
 // CHECK:           %[[VAL_14:.*]] = arith.subi %[[VAL_11]], %[[VAL_8]] : index
 // CHECK:           %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_10]], %[[VAL_11]] : index
 // CHECK:           %[[VAL_16:.*]] = scf.if %[[VAL_15]] -> (i1) {
@@ -42,16 +41,13 @@
 // CHECK:           } else {
 // CHECK:             %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_8]] : index
 // CHECK:             memref.store %[[VAL_21]], %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref<?xindex>
-// CHECK:             %[[VAL_22:.*]], %[[VAL_23:.*]] = sparse_tensor.push_back %[[VAL_13]], %[[VAL_1]], %[[VAL_5]] : index, memref<?xindex>, index
-// CHECK:             %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : index to i64
-// CHECK:             %[[VAL_25:.*]] = sparse_tensor.storage_specifier.set %[[VAL_3]]  idx_mem_sz at 1 with %[[VAL_24]] : i64, !sparse_tensor.storage_specifier
+// CHECK:             %[[VAL_22:.*]], %[[VAL_24:.*]] = sparse_tensor.push_back %[[VAL_13]], %[[VAL_1]], %[[VAL_5]] : index, memref<?xindex>, index
+// CHECK:             %[[VAL_25:.*]] = sparse_tensor.storage_specifier.set %[[VAL_3]]  idx_mem_sz at 1 with %[[VAL_24]] : !sparse_tensor.storage_specifier
 // CHECK:             scf.yield %[[VAL_22]], %[[VAL_25]] : memref<?xindex>, !sparse_tensor.storage_specifier
 // CHECK:           }
-// CHECK:           %[[VAL_26:.*]] = sparse_tensor.storage_specifier.get %[[VAL_27:.*]]#1  val_mem_sz : !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_28:.*]] = arith.index_cast %[[VAL_26]] : i64 to index
+// CHECK:           %[[VAL_28:.*]] = sparse_tensor.storage_specifier.get %[[VAL_27:.*]]#1  val_mem_sz : !sparse_tensor.storage_specifier
 // CHECK:           %[[VAL_29:.*]], %[[VAL_30:.*]] = sparse_tensor.push_back %[[VAL_28]], %[[VAL_2]], %[[VAL_6]] : index, memref<?xf64>, f64
-// CHECK:           %[[VAL_31:.*]] = arith.index_cast %[[VAL_30]] : index to i64
-// CHECK:           %[[VAL_32:.*]] = sparse_tensor.storage_specifier.set %[[VAL_27]]#1  val_mem_sz with %[[VAL_31]] : i64, !sparse_tensor.storage_specifier
+// CHECK:           %[[VAL_32:.*]] = sparse_tensor.storage_specifier.set %[[VAL_27]]#1  val_mem_sz with %[[VAL_30]] : !sparse_tensor.storage_specifier
 // CHECK:           return %[[VAL_0]], %[[VAL_27]]#0, %[[VAL_29]], %[[VAL_32]] : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 // CHECK:         }
 
@@ -64,94 +60,89 @@
 // CHECK-SAME:      %[[VAL_5:.*5]]: memref<?xindex>,
 // CHECK-SAME:      %[[VAL_6:.*6]]: memref<?xf64>,
 // CHECK-SAME:      %[[VAL_7:.*7]]: !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_8:.*]] = arith.constant 4 : index
-// CHECK:           %[[VAL_9:.*]] = arith.constant 4 : i64
-// CHECK:           %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK:           %[[VAL_11:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_12:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_13:.*]] = arith.constant false
-// CHECK:           %[[VAL_14:.*]] = arith.constant true
-// CHECK:           %[[VAL_15:.*]] = memref.alloc() : memref<16xindex>
-// CHECK:           %[[VAL_16:.*]] = memref.cast %[[VAL_15]] : memref<16xindex> to memref<?xindex>
-// CHECK:           %[[VAL_17:.*]] = memref.alloc() : memref<16xindex>
-// CHECK:           %[[VAL_18:.*]] = memref.cast %[[VAL_17]] : memref<16xindex> to memref<?xindex>
-// CHECK:           %[[VAL_19:.*]] = memref.alloc() : memref<16xf64>
-// CHECK:           %[[VAL_20:.*]] = memref.cast %[[VAL_19]] : memref<16xf64> to memref<?xf64>
-// CHECK:           %[[VAL_21:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]]  dim_sz at 0 with %[[VAL_9]] : i64, !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_23:.*]] = sparse_tensor.storage_specifier.set %[[VAL_22]]  dim_sz at 1 with %[[VAL_9]] : i64, !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_24:.*]] = sparse_tensor.storage_specifier.get %[[VAL_23]]  ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_25:.*]] = arith.index_cast %[[VAL_24]] : i64 to index
-// CHECK:           %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.push_back %[[VAL_25]], %[[VAL_16]], %[[VAL_11]] : index, memref<?xindex>, index
-// CHECK:           %[[VAL_28:.*]] = arith.index_cast %[[VAL_27]] : index to i64
-// CHECK:           %[[VAL_29:.*]] = sparse_tensor.storage_specifier.set %[[VAL_23]]  ptr_mem_sz at 1 with %[[VAL_28]] : i64, !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_32:.*]], %[[VAL_33:.*]] = sparse_tensor.push_back %[[VAL_27]], %[[VAL_26]], %[[VAL_11]], %[[VAL_8]] : index, memref<?xindex>, index, index
-// CHECK:           %[[VAL_34:.*]] = arith.index_cast %[[VAL_33]] : index to i64
-// CHECK:           %[[VAL_35:.*]] = sparse_tensor.storage_specifier.set %[[VAL_29]]  ptr_mem_sz at 1 with %[[VAL_34]] : i64, !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_36:.*]] = memref.alloc() : memref<4xf64>
-// CHECK:           %[[VAL_37:.*]] = memref.alloc() : memref<4xi1>
-// CHECK:           %[[VAL_38:.*]] = memref.alloc() : memref<4xindex>
-// CHECK:           %[[VAL_39:.*]] = memref.cast %[[VAL_38]] : memref<4xindex> to memref<?xindex>
-// CHECK:           linalg.fill ins(%[[VAL_10]] : f64) outs(%[[VAL_36]] : memref<4xf64>)
-// CHECK:           linalg.fill ins(%[[VAL_13]] : i1) outs(%[[VAL_37]] : memref<4xi1>)
-// CHECK:           %[[VAL_40:.*]]:4 = scf.for %[[VAL_41:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_12]] iter_args(%[[VAL_42:.*]] = %[[VAL_32]], %[[VAL_43:.*]] = %[[VAL_18]], %[[VAL_44:.*]] = %[[VAL_20]], %[[VAL_45:.*]] = %[[VAL_35]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
-// CHECK:             %[[VAL_46:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_41]]] : memref<?xindex>
-// CHECK:             %[[VAL_47:.*]] = arith.addi %[[VAL_41]], %[[VAL_12]] : index
-// CHECK:             %[[VAL_48:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_47]]] : memref<?xindex>
-// CHECK:             %[[VAL_49:.*]] = scf.for %[[VAL_50:.*]] = %[[VAL_46]] to %[[VAL_48]] step %[[VAL_12]] iter_args(%[[VAL_51:.*]] = %[[VAL_11]]) -> (index) {
-// CHECK:               %[[VAL_52:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_50]]] : memref<?xindex>
-// CHECK:               %[[VAL_53:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_50]]] : memref<?xf64>
-// CHECK:               %[[VAL_54:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_52]]] : memref<?xindex>
-// CHECK:               %[[VAL_55:.*]] = arith.addi %[[VAL_52]], %[[VAL_12]] : index
-// CHECK:               %[[VAL_56:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_55]]] : memref<?xindex>
-// CHECK:               %[[VAL_57:.*]] = scf.for %[[VAL_58:.*]] = %[[VAL_54]] to %[[VAL_56]] step %[[VAL_12]] iter_args(%[[VAL_59:.*]] = %[[VAL_51]]) -> (index) {
-// CHECK:                 %[[VAL_60:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_58]]] : memref<?xindex>
-// CHECK:                 %[[VAL_61:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_60]]] : memref<4xf64>
-// CHECK:                 %[[VAL_62:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_58]]] : memref<?xf64>
-// CHECK:                 %[[VAL_63:.*]] = arith.mulf %[[VAL_53]], %[[VAL_62]] : f64
-// CHECK:                 %[[VAL_64:.*]] = arith.addf %[[VAL_61]], %[[VAL_63]] : f64
-// CHECK:                 %[[VAL_65:.*]] = memref.load %[[VAL_37]]{{\[}}%[[VAL_60]]] : memref<4xi1>
-// CHECK:                 %[[VAL_66:.*]] = arith.cmpi eq, %[[VAL_65]], %[[VAL_13]] : i1
-// CHECK:                 %[[VAL_67:.*]] = scf.if %[[VAL_66]] -> (index) {
-// CHECK:                   memref.store %[[VAL_14]], %[[VAL_37]]{{\[}}%[[VAL_60]]] : memref<4xi1>
-// CHECK:                   memref.store %[[VAL_60]], %[[VAL_38]]{{\[}}%[[VAL_59]]] : memref<4xindex>
-// CHECK:                   %[[VAL_68:.*]] = arith.addi %[[VAL_59]], %[[VAL_12]] : index
-// CHECK:                   scf.yield %[[VAL_68]] : index
+// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_12:.*]] = arith.constant false
+// CHECK-DAG:       %[[VAL_13:.*]] = arith.constant true
+// CHECK:           %[[VAL_14:.*]] = memref.alloc() : memref<16xindex>
+// CHECK:           %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<16xindex> to memref<?xindex>
+// CHECK:           %[[VAL_16:.*]] = memref.alloc() : memref<16xindex>
+// CHECK:           %[[VAL_17:.*]] = memref.cast %[[VAL_16]] : memref<16xindex> to memref<?xindex>
+// CHECK:           %[[VAL_18:.*]] = memref.alloc() : memref<16xf64>
+// CHECK:           %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<16xf64> to memref<?xf64>
+// CHECK:           %[[VAL_20:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier
+// CHECK:           %[[VAL_21:.*]] = sparse_tensor.storage_specifier.set %[[VAL_20]]  dim_sz at 0 with %[[VAL_8]] : !sparse_tensor.storage_specifier
+// CHECK:           %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]]  dim_sz at 1 with %[[VAL_8]] : !sparse_tensor.storage_specifier
+// CHECK:           %[[VAL_23:.*]] = sparse_tensor.storage_specifier.get %[[VAL_22]]  ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
+// CHECK:           %[[VAL_24:.*]], %[[VAL_25:.*]] = sparse_tensor.push_back %[[VAL_23]], %[[VAL_15]], %[[VAL_10]] : index, memref<?xindex>, index
+// CHECK:           %[[VAL_26:.*]] = sparse_tensor.storage_specifier.set %[[VAL_22]]  ptr_mem_sz at 1 with %[[VAL_25]] : !sparse_tensor.storage_specifier
+// CHECK:           %[[VAL_27:.*]], %[[VAL_28:.*]] = sparse_tensor.push_back %[[VAL_25]], %[[VAL_24]], %[[VAL_10]], %[[VAL_8]] : index, memref<?xindex>, index, index
+// CHECK:           %[[VAL_29:.*]] = sparse_tensor.storage_specifier.set %[[VAL_26]]  ptr_mem_sz at 1 with %[[VAL_28]] : !sparse_tensor.storage_specifier
+// CHECK:           %[[VAL_30:.*]] = memref.alloc() : memref<4xf64>
+// CHECK:           %[[VAL_31:.*]] = memref.alloc() : memref<4xi1>
+// CHECK:           %[[VAL_32:.*]] = memref.alloc() : memref<4xindex>
+// CHECK:           %[[VAL_33:.*]] = memref.cast %[[VAL_32]] : memref<4xindex> to memref<?xindex>
+// CHECK:           linalg.fill ins(%[[VAL_9]] : f64) outs(%[[VAL_30]] : memref<4xf64>)
+// CHECK:           linalg.fill ins(%[[VAL_12]] : i1) outs(%[[VAL_31]] : memref<4xi1>)
+// CHECK:           %[[VAL_34:.*]]:4 = scf.for %[[VAL_35:.*]] = %[[VAL_10]] to %[[VAL_8]] step %[[VAL_11]] iter_args(%[[VAL_36:.*]] = %[[VAL_27]], %[[VAL_37:.*]] = %[[VAL_17]], %[[VAL_38:.*]] = %[[VAL_19]], %[[VAL_39:.*]] = %[[VAL_29]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK:             %[[VAL_40:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_35]]] : memref<?xindex>
+// CHECK:             %[[VAL_41:.*]] = arith.addi %[[VAL_35]], %[[VAL_11]] : index
+// CHECK:             %[[VAL_42:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_41]]] : memref<?xindex>
+// CHECK:             %[[VAL_43:.*]] = scf.for %[[VAL_44:.*]] = %[[VAL_40]] to %[[VAL_42]] step %[[VAL_11]] iter_args(%[[VAL_45:.*]] = %[[VAL_10]]) -> (index) {
+// CHECK:               %[[VAL_46:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_44]]] : memref<?xindex>
+// CHECK:               %[[VAL_47:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_44]]] : memref<?xf64>
+// CHECK:               %[[VAL_48:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_46]]] : memref<?xindex>
+// CHECK:               %[[VAL_49:.*]] = arith.addi %[[VAL_46]], %[[VAL_11]] : index
+// CHECK:               %[[VAL_50:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_49]]] : memref<?xindex>
+// CHECK:               %[[VAL_51:.*]] = scf.for %[[VAL_52:.*]] = %[[VAL_48]] to %[[VAL_50]] step %[[VAL_11]] iter_args(%[[VAL_53:.*]] = %[[VAL_45]]) -> (index) {
+// CHECK:                 %[[VAL_54:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_52]]] : memref<?xindex>
+// CHECK:                 %[[VAL_55:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_54]]] : memref<4xf64>
+// CHECK:                 %[[VAL_56:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_52]]] : memref<?xf64>
+// CHECK:                 %[[VAL_57:.*]] = arith.mulf %[[VAL_47]], %[[VAL_56]] : f64
+// CHECK:                 %[[VAL_58:.*]] = arith.addf %[[VAL_55]], %[[VAL_57]] : f64
+// CHECK:                 %[[VAL_59:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_54]]] : memref<4xi1>
+// CHECK:                 %[[VAL_60:.*]] = arith.cmpi eq, %[[VAL_59]], %[[VAL_12]] : i1
+// CHECK:                 %[[VAL_61:.*]] = scf.if %[[VAL_60]] -> (index) {
+// CHECK:                   memref.store %[[VAL_13]], %[[VAL_31]]{{\[}}%[[VAL_54]]] : memref<4xi1>
+// CHECK:                   memref.store %[[VAL_54]], %[[VAL_32]]{{\[}}%[[VAL_53]]] : memref<4xindex>
+// CHECK:                   %[[VAL_62:.*]] = arith.addi %[[VAL_53]], %[[VAL_11]] : index
+// CHECK:                   scf.yield %[[VAL_62]] : index
 // CHECK:                 } else {
-// CHECK:                   scf.yield %[[VAL_59]] : index
+// CHECK:                   scf.yield %[[VAL_53]] : index
 // CHECK:                 }
-// CHECK:                 memref.store %[[VAL_64]], %[[VAL_36]]{{\[}}%[[VAL_60]]] : memref<4xf64>
-// CHECK:                 scf.yield %[[VAL_69:.*]] : index
+// CHECK:                 memref.store %[[VAL_58]], %[[VAL_30]]{{\[}}%[[VAL_54]]] : memref<4xf64>
+// CHECK:                 scf.yield %[[VAL_63:.*]] : index
 // CHECK:               } {"Emitted from" = "linalg.generic"}
-// CHECK:               scf.yield %[[VAL_70:.*]] : index
+// CHECK:               scf.yield %[[VAL_64:.*]] : index
 // CHECK:             } {"Emitted from" = "linalg.generic"}
-// CHECK:             sparse_tensor.sort hybrid_quick_sort %[[VAL_71:.*]], %[[VAL_39]] : memref<?xindex>
-// CHECK:             %[[VAL_72:.*]]:4 = scf.for %[[VAL_73:.*]] = %[[VAL_11]] to %[[VAL_71]] step %[[VAL_12]] iter_args(%[[VAL_74:.*]] = %[[VAL_42]], %[[VAL_75:.*]] = %[[VAL_43]], %[[VAL_76:.*]] = %[[VAL_44]], %[[VAL_77:.*]] = %[[VAL_45]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
-// CHECK:               %[[VAL_78:.*]] = memref.load %[[VAL_38]]{{\[}}%[[VAL_73]]] : memref<4xindex>
-// CHECK:               %[[VAL_79:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_78]]] : memref<4xf64>
-// CHECK:               %[[VAL_80:.*]]:4 = func.call @_insert_dense_compressed_4_4_f64_0_0(%[[VAL_74]], %[[VAL_75]], %[[VAL_76]], %[[VAL_77]], %[[VAL_41]], %[[VAL_78]], %[[VAL_79]]) : (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
-// CHECK:               memref.store %[[VAL_10]], %[[VAL_36]]{{\[}}%[[VAL_78]]] : memref<4xf64>
-// CHECK:               memref.store %[[VAL_13]], %[[VAL_37]]{{\[}}%[[VAL_78]]] : memref<4xi1>
-// CHECK:               scf.yield %[[VAL_80]]#0, %[[VAL_80]]#1, %[[VAL_80]]#2, %[[VAL_80]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK:             sparse_tensor.sort  hybrid_quick_sort %[[VAL_65:.*]], %[[VAL_33]] : memref<?xindex>
+// CHECK:             %[[VAL_66:.*]]:4 = scf.for %[[VAL_67:.*]] = %[[VAL_10]] to %[[VAL_65]] step %[[VAL_11]] iter_args(%[[VAL_68:.*]] = %[[VAL_36]], %[[VAL_69:.*]] = %[[VAL_37]], %[[VAL_70:.*]] = %[[VAL_38]], %[[VAL_71:.*]] = %[[VAL_39]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK:               %[[VAL_72:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_67]]] : memref<4xindex>
+// CHECK:               %[[VAL_73:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_72]]] : memref<4xf64>
+// CHECK:               %[[VAL_74:.*]]:4 = func.call @_insert_dense_compressed_4_4_f64_0_0(%[[VAL_68]], %[[VAL_69]], %[[VAL_70]], %[[VAL_71]], %[[VAL_35]], %[[VAL_72]], %[[VAL_73]]) : (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK:               memref.store %[[VAL_9]], %[[VAL_30]]{{\[}}%[[VAL_72]]] : memref<4xf64>
+// CHECK:               memref.store %[[VAL_12]], %[[VAL_31]]{{\[}}%[[VAL_72]]] : memref<4xi1>
+// CHECK:               scf.yield %[[VAL_74]]#0, %[[VAL_74]]#1, %[[VAL_74]]#2, %[[VAL_74]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 // CHECK:             }
-// CHECK:             scf.yield %[[VAL_81:.*]]#0, %[[VAL_81]]#1, %[[VAL_81]]#2, %[[VAL_81]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK:             scf.yield %[[VAL_75:.*]]#0, %[[VAL_75]]#1, %[[VAL_75]]#2, %[[VAL_75]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 // CHECK:           } {"Emitted from" = "linalg.generic"}
-// CHECK:           memref.dealloc %[[VAL_36]] : memref<4xf64>
-// CHECK:           memref.dealloc %[[VAL_37]] : memref<4xi1>
-// CHECK:           memref.dealloc %[[VAL_38]] : memref<4xindex>
-// CHECK:           %[[VAL_82:.*]] = sparse_tensor.storage_specifier.get %[[VAL_83:.*]]#3  ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_84:.*]] = arith.index_cast %[[VAL_82]] : i64 to index
-// CHECK:           %[[VAL_85:.*]] = memref.load %[[VAL_83]]#0{{\[}}%[[VAL_11]]] : memref<?xindex>
-// CHECK:           %[[VAL_86:.*]] = scf.for %[[VAL_87:.*]] = %[[VAL_12]] to %[[VAL_84]] step %[[VAL_12]] iter_args(%[[VAL_88:.*]] = %[[VAL_85]]) -> (index) {
-// CHECK:             %[[VAL_89:.*]] = memref.load %[[VAL_83]]#0{{\[}}%[[VAL_87]]] : memref<?xindex>
-// CHECK:             %[[VAL_90:.*]] = arith.cmpi eq, %[[VAL_89]], %[[VAL_11]] : index
-// CHECK:             %[[VAL_91:.*]] = arith.select %[[VAL_90]], %[[VAL_88]], %[[VAL_89]] : index
-// CHECK:             scf.if %[[VAL_90]] {
-// CHECK:               memref.store %[[VAL_88]], %[[VAL_83]]#0{{\[}}%[[VAL_87]]] : memref<?xindex>
+// CHECK:           memref.dealloc %[[VAL_30]] : memref<4xf64>
+// CHECK:           memref.dealloc %[[VAL_31]] : memref<4xi1>
+// CHECK:           memref.dealloc %[[VAL_32]] : memref<4xindex>
+// CHECK:           %[[VAL_76:.*]] = sparse_tensor.storage_specifier.get %[[VAL_77:.*]]#3  ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
+// CHECK:           %[[VAL_78:.*]] = memref.load %[[VAL_77]]#0{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK:           %[[VAL_79:.*]] = scf.for %[[VAL_80:.*]] = %[[VAL_11]] to %[[VAL_76]] step %[[VAL_11]] iter_args(%[[VAL_81:.*]] = %[[VAL_78]]) -> (index) {
+// CHECK:             %[[VAL_82:.*]] = memref.load %[[VAL_77]]#0{{\[}}%[[VAL_80]]] : memref<?xindex>
+// CHECK:             %[[VAL_83:.*]] = arith.cmpi eq, %[[VAL_82]], %[[VAL_10]] : index
+// CHECK:             %[[VAL_84:.*]] = arith.select %[[VAL_83]], %[[VAL_81]], %[[VAL_82]] : index
+// CHECK:             scf.if %[[VAL_83]] {
+// CHECK:               memref.store %[[VAL_81]], %[[VAL_77]]#0{{\[}}%[[VAL_80]]] : memref<?xindex>
 // CHECK:             }
-// CHECK:             scf.yield %[[VAL_91]] : index
+// CHECK:             scf.yield %[[VAL_84]] : index
 // CHECK:           }
-// CHECK:           return %[[VAL_83]]#0, %[[VAL_83]]#1, %[[VAL_83]]#2, %[[VAL_83]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK:           return %[[VAL_77]]#0, %[[VAL_77]]#1, %[[VAL_77]]#2, %[[VAL_77]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 func.func @matmul(%A: tensor<4x8xf64, #CSR>,
                   %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
   %C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 057153a20c955..cdfd856b19d93 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -19,16 +19,13 @@
 // CHECK:           %[[VAL_10:.*]] = sparse_tensor.storage_specifier.init :
 // CHECK:           %[[VAL_11:.*]] = arith.constant 6 : index
 // CHECK:           %[[VAL_12:.*]] = arith.constant 100 : index
-// CHECK:           %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : index to i32
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]]  dim_sz at 0 with %[[VAL_13]] : i32,
+// CHECK:           %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]]  dim_sz at 0 with %[[VAL_12]]
 // CHECK:           %[[VAL_15:.*]] = arith.constant 2 : index
-// CHECK:           %[[VAL_16:.*]] = arith.index_cast %[[VAL_15]] : index to i32
-// CHECK:           %[[VAL_17:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]]  ptr_mem_sz at 0 with %[[VAL_16]] : i32,
-// CHECK:           %[[VAL_18:.*]] = arith.index_cast %[[VAL_11]] : index to i32
-// CHECK:           %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]]  idx_mem_sz at 0 with %[[VAL_18]] : i32,
-// CHECK:           %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]]  dim_sz at 1 with %[[VAL_13]] : i32,
-// CHECK:           %[[VAL_21:.*]] = sparse_tensor.storage_specifier.set %[[VAL_20]]  idx_mem_sz at 1 with %[[VAL_18]] : i32,
-// CHECK:           %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]]  val_mem_sz with %[[VAL_18]] : i32,
+// CHECK:           %[[VAL_17:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]]  ptr_mem_sz at 0 with %[[VAL_15]]
+// CHECK:           %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]]  idx_mem_sz at 0 with %[[VAL_11]]
+// CHECK:           %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]]  dim_sz at 1 with %[[VAL_12]]
+// CHECK:           %[[VAL_21:.*]] = sparse_tensor.storage_specifier.set %[[VAL_20]]  idx_mem_sz at 1 with %[[VAL_11]]
+// CHECK:           %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]]  val_mem_sz with %[[VAL_11]]
 // CHECK:           return %[[VAL_4]], %[[VAL_7]], %[[VAL_9]], %[[VAL_22]] : memref<?xindex>, memref<?xi32>, memref<?xf64>,
 // CHECK:         }
 func.func @sparse_pack(%data: tensor<6xf64>, %index: tensor<6x2xi32>)
@@ -68,8 +65,7 @@ func.func @sparse_pack(%data: tensor<6xf64>, %index: tensor<6x2xi32>)
 // CHECK:           %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_20:.*]] : memref<6xf64>
 // CHECK:           %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<6x2xi32>
 // CHECK:           %[[VAL_22:.*]] = sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_23:.*]] = arith.index_cast %[[VAL_22]] : i32 to index
-// CHECK:           return %[[VAL_19]], %[[VAL_21]], %[[VAL_23]] : tensor<6xf64>, tensor<6x2xi32>, index
+// CHECK:           return %[[VAL_19]], %[[VAL_21]], %[[VAL_22]] : tensor<6xf64>, tensor<6x2xi32>, index
 // CHECK:         }
 func.func @sparse_unpack(%sp: tensor<100x100xf64, #COO>) -> (tensor<6xf64>, tensor<6x2xi32>, index) {
   %d, %i, %nnz = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO>

diff  --git a/mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir b/mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir
index ecdaf3bf9c964..36dce9e417b5f 100644
--- a/mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir
+++ b/mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir
@@ -16,23 +16,25 @@ func.func @sparse_metadata_init() -> !sparse_tensor.storage_specifier<#CSR> {
 }
 
 // CHECK-LABEL:   func.func @sparse_get_md(
-// CHECK-SAME:      %[[VAL_0:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>) -> i64 {
+// CHECK-SAME:      %[[VAL_0:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>) -> index {
 // CHECK:           %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
-// CHECK:           return %[[VAL_1]] : i64
-func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#CSR>) -> i64 {
+// CHECK:           %[[CAST:.*]] = arith.index_cast %[[VAL_1]] : i64 to index
+// CHECK:           return %[[CAST]] : index
+func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#CSR>) -> index {
   %0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 0
-       : !sparse_tensor.storage_specifier<#CSR> to i64
-  return %0 : i64
+       : !sparse_tensor.storage_specifier<#CSR>
+  return %0 : index
 }
 
 // CHECK-LABEL:   func.func @sparse_set_md(
 // CHECK-SAME:      %[[VAL_0:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>,
-// CHECK-SAME:      %[[VAL_1:.*]]: i64) -> !llvm.struct<(array<2 x i64>, array<3 x i64>)> {
-// CHECK:           %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK-SAME:      %[[VAL_1:.*]]: index) -> !llvm.struct<(array<2 x i64>, array<3 x i64>)> {
+// CHECK:           %[[CAST:.*]] = arith.index_cast %[[VAL_1]] : index to i64
+// CHECK:           %[[VAL_2:.*]] = llvm.insertvalue %[[CAST]], %[[VAL_0]][0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
 // CHECK:           return %[[VAL_2]] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
-func.func @sparse_set_md(%arg0: !sparse_tensor.storage_specifier<#CSR>, %arg1: i64)
+func.func @sparse_set_md(%arg0: !sparse_tensor.storage_specifier<#CSR>, %arg1: index)
           -> !sparse_tensor.storage_specifier<#CSR> {
   %0 = sparse_tensor.storage_specifier.set %arg0 dim_sz at 0 with %arg1
-       : i64, !sparse_tensor.storage_specifier<#CSR>
+       : !sparse_tensor.storage_specifier<#CSR>
   return %0 : !sparse_tensor.storage_specifier<#CSR>
 }

