[Mlir-commits] [mlir] 44ff23d - [mlir][sparse] unconditionally use IndexType for sparse_tensor.specifier
Peiming Liu
llvmlistbot at llvm.org
Wed Feb 22 12:21:39 PST 2023
Author: Peiming Liu
Date: 2023-02-22T20:21:34Z
New Revision: 44ff23d5e49058bcaa170f71540398b4a290a642
URL: https://github.com/llvm/llvm-project/commit/44ff23d5e49058bcaa170f71540398b4a290a642
DIFF: https://github.com/llvm/llvm-project/commit/44ff23d5e49058bcaa170f71540398b4a290a642.diff
LOG: [mlir][sparse] unconditionally use IndexType for sparse_tensor.specifier
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D144574
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
mlir/test/Dialect/SparseTensor/codegen.mlir
mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
mlir/test/Dialect/SparseTensor/fold.mlir
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
mlir/test/Dialect/SparseTensor/sparse_pack.mlir
mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index b719d29be3c4d..4d06cb0b088d0 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -365,7 +365,7 @@ def SparseTensor_GetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.get"
Arguments<(ins SparseTensorStorageSpecifier:$specifier,
SparseTensorStorageSpecifierKindAttr:$specifierKind,
OptionalAttr<IndexAttr>:$dim)>,
- Results<(outs AnyType:$result)> {
+ Results<(outs Index:$result)> {
let summary = "";
let description = [{
Returns the requested field of the given storage_specifier.
@@ -374,12 +374,12 @@ def SparseTensor_GetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.get"
```mlir
%0 = sparse_tensor.storage_specifier.get %arg0 idx_mem_sz at 0
- : !sparse_tensor.storage_specifier<#COO> to i64
+ : !sparse_tensor.storage_specifier<#COO>
```
}];
let assemblyFormat = "$specifier $specifierKind (`at` $dim^)? attr-dict `:` "
- "qualified(type($specifier)) `to` type($result)";
+ "qualified(type($specifier))";
let hasVerifier = 1;
let hasFolder = 1;
}
@@ -389,7 +389,7 @@ def SparseTensor_SetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.set"
Arguments<(ins SparseTensorStorageSpecifier:$specifier,
SparseTensorStorageSpecifierKindAttr:$specifierKind,
OptionalAttr<IndexAttr>:$dim,
- AnyType:$value)>,
+ Index:$value)>,
Results<(outs SparseTensorStorageSpecifier:$result)> {
let summary = "";
let description = [{
@@ -400,12 +400,12 @@ def SparseTensor_SetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.set"
```mlir
%0 = sparse_tensor.storage_specifier.set %arg0 idx_mem_sz at 0 with %new_sz
- : i32, !sparse_tensor.storage_specifier<#COO>
+ : !sparse_tensor.storage_specifier<#COO>
```
}];
let assemblyFormat = "$specifier $specifierKind (`at` $dim^)? `with` $value attr-dict `:` "
- "type($value) `,` qualified(type($result))";
+ "qualified(type($result))";
let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index fc4c5d870d62a..3ae40d625c8a2 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -65,13 +65,6 @@ def SparseTensor_StorageSpecifier : SparseTensor_Type<"StorageSpecifier"> {
}]>
];
- let extraClassDeclaration = [{
- // Get the integer type used to store memory and dimension sizes.
- IntegerType getSizesType() const;
- Type getFieldType(StorageSpecifierKind kind, std::optional<unsigned> dim) const;
- Type getFieldType(StorageSpecifierKind kind, std::optional<APInt> dim) const;
- }];
-
// We skipped the default builder that simply takes the input sparse tensor encoding
// attribute since we need to normalize the dimension level type and remove unrelated
// fields that are irrelavant to sparse tensor storage scheme.
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 207834bacf7e6..9279ec7dddca9 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -571,7 +571,11 @@ getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
enc.getContext(), dlts,
AffineMap(), // dimOrdering (irrelavant to storage speicifer)
AffineMap(), // highLvlOrdering (irrelavant to storage specifer)
- enc.getPointerBitWidth(), enc.getIndexBitWidth(),
+ // Always use index for memSize, dimSize instead of reusing
+ // getBitwidth from pointers/indices.
+ // It allows us to reuse the same SSA value for different bitwidths.
+ // It also avoids casting between index/integer (returned by DimOp).
+ 0, 0,
// FIXME: we should keep the slice information, for now it is okay as only
// constant can be used for slice
ArrayRef<SparseTensorDimSliceAttr>{} /*enc.getDimSlices()*/);
@@ -582,36 +586,6 @@ StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
}
-IntegerType StorageSpecifierType::getSizesType() const {
- unsigned idxBitWidth =
- getEncoding().getIndexBitWidth() ? getEncoding().getIndexBitWidth() : 64u;
- unsigned ptrBitWidth =
- getEncoding().getIndexBitWidth() ? getEncoding().getIndexBitWidth() : 64u;
-
- return IntegerType::get(getContext(), std::max(idxBitWidth, ptrBitWidth));
-}
-
-// FIXME: see note [CLARIFY_DIM_LVL] in
-// "lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h"
-Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind,
- std::optional<unsigned> dim) const {
- if (kind != StorageSpecifierKind::ValMemSize)
- assert(dim);
-
- // Right now, we store every sizes metadata using the same size type.
- // TODO: the field size type can be defined dimensional wise after sparse
- // tensor encoding supports per dimension index/pointer bitwidth.
- return getSizesType();
-}
-
-// FIXME: see note [CLARIFY_DIM_LVL] in
-// "lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h"
-Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind,
- std::optional<APInt> dim) const {
- return getFieldType(kind, dim ? std::optional(dim.value().getZExtValue())
- : std::nullopt);
-}
-
//===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
//===----------------------------------------------------------------------===//
@@ -776,12 +750,6 @@ LogicalResult ToSliceStrideOp::verify() {
LogicalResult GetStorageSpecifierOp::verify() {
RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
getSpecifierKind(), getDim(), getSpecifier(), getOperation()))
- // Checks the result type
- if (getSpecifier().getType().getFieldType(getSpecifierKind(), getDim()) !=
- getResult().getType()) {
- return emitError(
- "type mismatch between requested specifier field and result value");
- }
return success();
}
@@ -802,12 +770,6 @@ OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
LogicalResult SetStorageSpecifierOp::verify() {
RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
getSpecifierKind(), getDim(), getSpecifier(), getOperation()))
- // Checks the input type
- if (getSpecifier().getType().getFieldType(getSpecifierKind(), getDim()) !=
- getValue().getType()) {
- return emitError(
- "type mismatch between requested specifier field and input value");
- }
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 809e9712c752a..cb4eda192d9a9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -204,6 +204,39 @@ StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
// Misc code generators.
//===----------------------------------------------------------------------===//
+Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
+ Type dstTy) {
+ Type srcTy = value.getType(); // Casts `value` to `dstTy`; returns `value` unchanged when the types already match.
+ if (srcTy != dstTy) {
+ // int <=> index. NOTE(review): float <=> index also takes this branch, but arith.index_cast only accepts integer/index operands.
+ if (dstTy.isa<IndexType>() || srcTy.isa<IndexType>())
+ return builder.create<arith::IndexCastOp>(loc, dstTy, value);
+
+ bool ext = srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
+
+ // float => float.
+ if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && ext)
+ return builder.create<arith::ExtFOp>(loc, dstTy, value);
+
+ if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !ext)
+ return builder.create<arith::TruncFOp>(loc, dstTy, value);
+
+ // int => int. NOTE(review): only explicitly signed/unsigned source types are widened below; signless types such as i32/i64 match neither check and reach llvm_unreachable when widened (narrowing is handled by TruncIOp).
+ if (srcTy.isUnsignedInteger() && dstTy.isa<IntegerType>() && ext)
+ return builder.create<arith::ExtUIOp>(loc, dstTy, value);
+
+ if (srcTy.isSignedInteger() && dstTy.isa<IntegerType>() && ext)
+ return builder.create<arith::ExtSIOp>(loc, dstTy, value);
+
+ if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !ext)
+ return builder.create<arith::TruncIOp>(loc, dstTy, value);
+
+ llvm_unreachable("unhandled type casting"); // Reached e.g. for int <=> float, equal-width mismatches, or widening a signless integer.
+ }
+
+ return value;
+}
+
mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
if (tp.isa<FloatType>())
return builder.getFloatAttr(tp, 1.0);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 2624d5c826f4a..de78010e57f22 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -78,6 +78,9 @@ StringRef primaryTypeFunctionSuffix(Type elemTp);
// Misc code generators and utilities.
//===----------------------------------------------------------------------===//
+/// Add type casting between arith and index types when needed.
+Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy);
+
/// Generates a 1-valued attribute of the given type. This supports
/// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
/// for unsupported types we raise `llvm_unreachable` rather than
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
index d634a1d1f5377..caa4dd5d722f6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
@@ -27,7 +27,9 @@ static SmallVector<Type, 2> getSpecifierFields(StorageSpecifierType tp) {
const Level lvlRank = enc.getLvlRank();
SmallVector<Type, 2> result;
- auto indexType = tp.getSizesType();
+ // TODO: how can we get the lowering type for index type in the later pipeline
+ // to be consistent? LLVM::StructureType does not allow index fields.
+ auto indexType = IntegerType::get(tp.getContext(), 64);
auto dimSizes = LLVM::LLVMArrayType::get(ctx, indexType, lvlRank);
auto memSizes = LLVM::LLVMArrayType::get(ctx, indexType,
getNumDataFieldsFromEncoding(enc));
@@ -49,6 +51,21 @@ constexpr uint64_t kDimSizePosInSpecifier = 0;
constexpr uint64_t kMemSizePosInSpecifier = 1;
class SpecifierStructBuilder : public StructBuilder {
+private:
+ Value extractField(OpBuilder &builder, Location loc,
+ ArrayRef<int64_t> indices) {
+ return genCast(builder, loc,
+ builder.create<LLVM::ExtractValueOp>(loc, value, indices),
+ builder.getIndexType());
+ }
+
+ void insertField(OpBuilder &builder, Location loc, ArrayRef<int64_t> indices,
+ Value v) {
+ value = builder.create<LLVM::InsertValueOp>(
+ loc, value, genCast(builder, loc, v, builder.getIntegerType(64)),
+ indices);
+ }
+
public:
explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) {
assert(value);
@@ -83,29 +100,30 @@ Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
/// Builds IR inserting the pos-th size into the descriptor.
Value SpecifierStructBuilder::dimSize(OpBuilder &builder, Location loc,
unsigned dim) {
- return builder.create<LLVM::ExtractValueOp>(
- loc, value, ArrayRef<int64_t>({kDimSizePosInSpecifier, dim}));
+ return extractField(builder, loc,
+ ArrayRef<int64_t>{kDimSizePosInSpecifier, dim});
}
/// Builds IR inserting the pos-th size into the descriptor.
void SpecifierStructBuilder::setDimSize(OpBuilder &builder, Location loc,
unsigned dim, Value size) {
- value = builder.create<LLVM::InsertValueOp>(
- loc, value, size, ArrayRef<int64_t>({kDimSizePosInSpecifier, dim}));
+
+ insertField(builder, loc, ArrayRef<int64_t>{kDimSizePosInSpecifier, dim},
+ size);
}
/// Builds IR extracting the pos-th memory size into the descriptor.
Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
unsigned pos) {
- return builder.create<LLVM::ExtractValueOp>(
- loc, value, ArrayRef<int64_t>({kMemSizePosInSpecifier, pos}));
+ return extractField(builder, loc,
+ ArrayRef<int64_t>{kMemSizePosInSpecifier, pos});
}
/// Builds IR inserting the pos-th memory size into the descriptor.
void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
unsigned pos, Value size) {
- value = builder.create<LLVM::InsertValueOp>(
- loc, value, size, ArrayRef<int64_t>({kMemSizePosInSpecifier, pos}));
+ insertField(builder, loc, ArrayRef<int64_t>{kMemSizePosInSpecifier, pos},
+ size);
}
} // namespace
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 335f743e2db3d..fc1ea386b4223 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -67,25 +67,18 @@ static void flattenOperands(ValueRange operands,
}
}
-/// Adds index conversions where needed.
-static Value toType(OpBuilder &builder, Location loc, Value value, Type tp) {
- if (value.getType() != tp)
- return builder.create<arith::IndexCastOp>(loc, tp, value);
- return value;
-}
-
/// Generates a load with proper index typing.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
- idx = toType(builder, loc, idx, builder.getIndexType());
+ idx = genCast(builder, loc, idx, builder.getIndexType());
return builder.create<memref::LoadOp>(loc, mem, idx);
}
/// Generates a store with proper index typing and (for indices) proper value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
Value idx) {
- idx = toType(builder, loc, idx, builder.getIndexType());
- val = toType(builder, loc, val,
- mem.getType().cast<ShapedType>().getElementType());
+ idx = genCast(builder, loc, idx, builder.getIndexType());
+ val = genCast(builder, loc, val,
+ mem.getType().cast<ShapedType>().getElementType());
builder.create<memref::StoreOp>(loc, val, mem, idx);
}
@@ -141,7 +134,7 @@ static void createPushback(OpBuilder &builder, Location loc,
auto pushBackOp = builder.create<PushBackOp>(
loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field,
- toType(builder, loc, value, etp), repeat);
+ genCast(builder, loc, value, etp), repeat);
desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
desc.setSpecifierField(builder, loc, specFieldKind, lvl,
@@ -338,7 +331,7 @@ static Value genCompressed(OpBuilder &builder, Location loc,
msz = builder.create<arith::DivUIOp>(loc, msz, idxStrideC);
}
Value phim1 = builder.create<arith::SubIOp>(
- loc, toType(builder, loc, phi, indexType), one);
+ loc, genCast(builder, loc, phi, indexType), one);
// Conditional expression.
Value lt =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, plo, phi);
@@ -350,9 +343,9 @@ static Value genCompressed(OpBuilder &builder, Location loc,
builder, loc, desc.getMemRefField(idxIndex),
idxStride > 1 ? builder.create<arith::MulIOp>(loc, phim1, idxStrideC)
: phim1);
- Value eq = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
- toType(builder, loc, crd, indexType),
- indices[lvl]);
+ Value eq = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
+ indices[lvl]);
builder.create<scf::YieldOp>(loc, eq);
builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
if (lvl > 0)
@@ -1226,8 +1219,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
// Converts MemRefs back to Tensors.
Value data = rewriter.create<bufferization::ToTensorOp>(loc, dataBuf);
Value indices = rewriter.create<bufferization::ToTensorOp>(loc, idxBuf);
- Value nnz = toType(rewriter, loc, desc.getValMemSize(rewriter, loc),
- op.getNnz().getType());
+ Value nnz = genCast(rewriter, loc, desc.getValMemSize(rewriter, loc),
+ op.getNnz().getType());
rewriter.replaceOp(op, {data, indices, nnz});
return success();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
index be59ba83f0f4b..b3336b5fc6ae2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
@@ -22,13 +22,6 @@ using namespace sparse_tensor;
// Private helper methods.
//===----------------------------------------------------------------------===//
-static Value createIndexCast(OpBuilder &builder, Location loc, Value value,
- Type to) {
- if (value.getType() != to)
- return builder.create<arith::IndexCastOp>(loc, to, value);
- return value;
-}
-
static IntegerAttr fromOptionalInt(MLIRContext *ctx,
std::optional<unsigned> dim) {
if (!dim)
@@ -90,20 +83,17 @@ Value SparseTensorSpecifier::getInitValue(OpBuilder &builder, Location loc,
Value SparseTensorSpecifier::getSpecifierField(OpBuilder &builder, Location loc,
StorageSpecifierKind kind,
std::optional<unsigned> dim) {
- return createIndexCast(builder, loc,
- builder.create<GetStorageSpecifierOp>(
- loc, getFieldType(kind, dim), specifier, kind,
- fromOptionalInt(specifier.getContext(), dim)),
- builder.getIndexType());
+ return builder.create<GetStorageSpecifierOp>(
+ loc, specifier, kind, fromOptionalInt(specifier.getContext(), dim));
}
void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc,
Value v,
StorageSpecifierKind kind,
std::optional<unsigned> dim) {
+ assert(v.getType().isIndex());
specifier = builder.create<SetStorageSpecifierOp>(
- loc, specifier, kind, fromOptionalInt(specifier.getContext(), dim),
- createIndexCast(builder, loc, v, getFieldType(kind, dim)));
+ loc, specifier, kind, fromOptionalInt(specifier.getContext(), dim), v);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
index 40207d756425e..c30a15d87baca 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -241,11 +241,6 @@ class SparseTensorSpecifier {
StorageSpecifierKind kind,
std::optional<unsigned> dim);
- // FIXME: see note [CLARIFY_DIM_LVL].
- Type getFieldType(StorageSpecifierKind kind, std::optional<unsigned> dim) {
- return specifier.getType().getFieldType(kind, dim);
- }
-
private:
TypedValue<StorageSpecifierType> specifier;
};
@@ -283,6 +278,8 @@ class SparseTensorDescriptorImpl {
/// Getters: get the value for required field.
///
+ Value getSpecifier() const { return fields.back(); }
+
// FIXME: see note [CLARIFY_DIM_LVL].
Value getSpecifierField(OpBuilder &builder, Location loc,
StorageSpecifierKind kind,
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 081ad5b2cf1e1..4b1beb51713e7 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -190,8 +190,7 @@ func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
// CHECK-SAME: %[[A0:.*]]: memref<?xf64>,
// CHECK-SAME: %[[A1:.*]]: !sparse_tensor.storage_specifier
// CHECK: %[[A2:.*]] = sparse_tensor.storage_specifier.get %[[A1]] dim_sz at 2
-// CHECK: %[[A3:.*]] = arith.index_cast %[[A2]] : i64 to index
-// CHECK: return %[[A3]] : index
+// CHECK: return %[[A2]] : index
func.func @sparse_dense_3d_dyn(%arg0: tensor<?x?x?xf64, #Dense3D>) -> index {
%c = arith.constant 1 : index
%0 = tensor.dim %arg0, %c : tensor<?x?x?xf64, #Dense3D>
@@ -260,8 +259,7 @@ func.func @sparse_values_coo(%arg0: tensor<?x?x?xf64, #ccoo>) -> memref<?xf64> {
// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[S0:.*]] = sparse_tensor.storage_specifier.get %[[A5]] idx_mem_sz at 1
-// CHECK: %[[S1:.*]] = arith.index_cast %[[S0]]
-// CHECK: %[[S2:.*]] = arith.divui %[[S1]], %[[C2]] : index
+// CHECK: %[[S2:.*]] = arith.divui %[[S0]], %[[C2]] : index
// CHECK: %[[R1:.*]] = memref.subview %[[A3]][0] {{\[}}%[[S2]]] [2] : memref<?xindex> to memref<?xindex, strided<[2]>>
// CHECK: %[[R2:.*]] = memref.cast %[[R1]] : memref<?xindex, strided<[2]>> to memref<?xindex, strided<[?], offset: ?>>
// CHECK: return %[[R2]] : memref<?xindex, strided<[?], offset: ?>>
@@ -288,8 +286,7 @@ func.func @sparse_indices_buffer_coo(%arg0: tensor<?x?x?xf64, #ccoo>) -> memref<
// CHECK-SAME: %[[A1:.*]]: memref<?xi64>,
// CHECK-SAME: %[[A2:.*]]: memref<?xf64>,
// CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier
-// CHECK: %[[A4:.*]] = sparse_tensor.storage_specifier.get %[[A3]] val_mem_sz
-// CHECK: %[[NOE:.*]] = arith.index_cast %[[A4]] : i64 to index
+// CHECK: %[[NOE:.*]] = sparse_tensor.storage_specifier.get %[[A3]] val_mem_sz
// CHECK: return %[[NOE]] : index
func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index {
%0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector>
@@ -312,8 +309,8 @@ func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
// CHECK-LABEL: func.func @sparse_alloc_csc(
// CHECK-SAME: %[[A0:.*]]: index) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
-// CHECK: %[[A1:.*]] = arith.constant 10 : i64
-// CHECK: %[[A2:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[A1:.*]] = arith.constant 10 : index
+// CHECK-DAG: %[[A2:.*]] = arith.constant 0 : index
// CHECK: %[[A3:.*]] = memref.alloc() : memref<16xindex>
// CHECK: %[[A4:.*]] = memref.cast %[[A3]] : memref<16xindex> to memref<?xindex>
// CHECK: %[[A5:.*]] = memref.alloc() : memref<16xindex>
@@ -321,17 +318,13 @@ func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
// CHECK: %[[A7:.*]] = memref.alloc() : memref<16xf64>
// CHECK: %[[A8:.*]] = memref.cast %[[A7]] : memref<16xf64> to memref<?xf64>
// CHECK: %[[A9:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier
-// CHECK: %[[A10:.*]] = arith.index_cast %[[A0]] : index to i64
-// CHECK: %[[A11:.*]] = sparse_tensor.storage_specifier.set %[[A9]] dim_sz at 0 with %[[A10]] : i64, !sparse_tensor.storage_specifier
-// CHECK: %[[A12:.*]] = sparse_tensor.storage_specifier.set %[[A11]] dim_sz at 1 with %[[A1]] : i64, !sparse_tensor.storage_specifier
-// CHECK: %[[A13:.*]] = sparse_tensor.storage_specifier.get %[[A12]] ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
-// CHECK: %[[A14:.*]] = arith.index_cast %[[A13]] : i64 to index
-// CHECK: %[[A15:.*]], %[[A16:.*]] = sparse_tensor.push_back %[[A14]], %[[A4]], %[[A2]] : index, memref<?xindex>, index
-// CHECK: %[[A17:.*]] = arith.index_cast %[[A16]] : index to i64
-// CHECK: %[[A18:.*]] = sparse_tensor.storage_specifier.set %[[A12]] ptr_mem_sz at 1 with %[[A17]] : i64, !sparse_tensor.storage_specifier
-// CHECK: %[[A23:.*]], %[[A24:.*]] = sparse_tensor.push_back %[[A16]], %[[A15]], %[[A2]], %[[A0]] : index, memref<?xindex>, index, index
-// CHECK: %[[A25:.*]] = arith.index_cast %[[A24]] : index to i64
-// CHECK: %[[A26:.*]] = sparse_tensor.storage_specifier.set %[[A18]] ptr_mem_sz at 1 with %[[A25]] : i64, !sparse_tensor.storage_specifier
+// CHECK: %[[A11:.*]] = sparse_tensor.storage_specifier.set %[[A9]] dim_sz at 0 with %[[A0]] : !sparse_tensor.storage_specifier
+// CHECK: %[[A12:.*]] = sparse_tensor.storage_specifier.set %[[A11]] dim_sz at 1 with %[[A1]] : !sparse_tensor.storage_specifier
+// CHECK: %[[A14:.*]] = sparse_tensor.storage_specifier.get %[[A12]] ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
+// CHECK: %[[A15:.*]], %[[A17:.*]] = sparse_tensor.push_back %[[A14]], %[[A4]], %[[A2]] : index, memref<?xindex>, index
+// CHECK: %[[A18:.*]] = sparse_tensor.storage_specifier.set %[[A12]] ptr_mem_sz at 1 with %[[A17]] : !sparse_tensor.storage_specifier
+// CHECK: %[[A23:.*]], %[[A25:.*]] = sparse_tensor.push_back %[[A17]], %[[A15]], %[[A2]], %[[A0]] : index, memref<?xindex>, index, index
+// CHECK: %[[A26:.*]] = sparse_tensor.storage_specifier.set %[[A18]] ptr_mem_sz at 1 with %[[A25]] : !sparse_tensor.storage_specifier
// CHECK: return %[[A23]], %[[A6]], %[[A8]], %[[A26]] : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
%0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC>
@@ -340,23 +333,21 @@ func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
}
// CHECK-LABEL: func.func @sparse_alloc_3d() -> (memref<?xf64>, !sparse_tensor.storage_specifier
-// CHECK: %[[A0:.*]] = arith.constant 6000 : index
-// CHECK: %[[A1:.*]] = arith.constant 20 : i64
-// CHECK: %[[A2:.*]] = arith.constant 10 : i64
-// CHECK: %[[A3:.*]] = arith.constant 30 : i64
-// CHECK: %[[A4:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[A0:.*]] = arith.constant 6000 : index
+// CHECK-DAG: %[[A1:.*]] = arith.constant 20 : index
+// CHECK-DAG: %[[A2:.*]] = arith.constant 10 : index
+// CHECK-DAG: %[[A3:.*]] = arith.constant 30 : index
+// CHECK-DAG: %[[A4:.*]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[A5:.*]] = memref.alloc() : memref<6000xf64>
// CHECK: %[[A6:.*]] = memref.cast %[[A5]] : memref<6000xf64> to memref<?xf64>
// CHECK: %[[A7:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier
-// CHECK: %[[A8:.*]] = sparse_tensor.storage_specifier.set %[[A7]] dim_sz at 0 with %[[A3]] : i64, !sparse_tensor.storage_specifier
-// CHECK: %[[A9:.*]] = sparse_tensor.storage_specifier.set %[[A8]] dim_sz at 1 with %[[A2]] : i64, !sparse_tensor.storage_specifier
-// CHECK: %[[A10:.*]] = sparse_tensor.storage_specifier.set %[[A9]] dim_sz at 2 with %[[A1]] : i64, !sparse_tensor.storage_specifier
-// CHECK: %[[A11:.*]] = sparse_tensor.storage_specifier.get %[[A10]] val_mem_sz : !sparse_tensor.storage_specifier
-// CHECK: %[[A12:.*]] = arith.index_cast %[[A11]] : i64 to index
-// CHECK: %[[A13:.*]], %[[A14:.*]] = sparse_tensor.push_back %[[A12]], %[[A6]], %[[A4]], %[[A0]] : index, memref<?xf64>, f64, index
-// CHECK: %[[A15:.*]] = arith.index_cast %[[A14]] : index to i64
-// CHECK: %[[A16:.*]] = sparse_tensor.storage_specifier.set %[[A10]] val_mem_sz with %[[A15]] : i64, !sparse_tensor.storage_specifier
-// CHECK: return %[[A13]], %[[A16]] : memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: %[[A8:.*]] = sparse_tensor.storage_specifier.set %[[A7]] dim_sz at 0 with %[[A3]] : !sparse_tensor.storage_specifier
+// CHECK: %[[A9:.*]] = sparse_tensor.storage_specifier.set %[[A8]] dim_sz at 1 with %[[A2]] : !sparse_tensor.storage_specifier
+// CHECK: %[[A10:.*]] = sparse_tensor.storage_specifier.set %[[A9]] dim_sz at 2 with %[[A1]] : !sparse_tensor.storage_specifier
+// CHECK: %[[A12:.*]] = sparse_tensor.storage_specifier.get %[[A10]] val_mem_sz : !sparse_tensor.storage_specifier
+// CHECK: %[[A15:.*]], %[[A14:.*]] = sparse_tensor.push_back %[[A12]], %[[A6]], %[[A4]], %[[A0]] : index, memref<?xf64>, f64, index
+// CHECK: %[[A16:.*]] = sparse_tensor.storage_specifier.set %[[A10]] val_mem_sz with %[[A14]] : !sparse_tensor.storage_specifier
+// CHECK: return %[[A15]], %[[A16]] : memref<?xf64>, !sparse_tensor.storage_specifier
func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> {
%0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D>
%1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D>
@@ -503,8 +494,7 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
// CHECK: memref.dealloc %[[A4]] : memref<?xf64>
// CHECK: memref.dealloc %[[A5]] : memref<?xi1>
// CHECK: memref.dealloc %[[A6]] : memref<?xindex>
-// CHECK: %[[A23:.*]] = sparse_tensor.storage_specifier.get %[[A24:.*]]#3 ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
-// CHECK: %[[A25:.*]] = arith.index_cast %[[A23]] : i64 to index
+// CHECK: %[[A25:.*]] = sparse_tensor.storage_specifier.get %[[A24:.*]]#3 ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
// CHECK: %[[A26:.*]] = memref.load %[[A24]]#0{{\[}}%[[A13]]] : memref<?xi32>
// CHECK: %[[A27:.*]] = scf.for %[[A28:.*]] = %[[A12]] to %[[A25]] step %[[A12]] iter_args(%[[A29:.*]] = %[[A26]]) -> (i32) {
// CHECK: %[[A30:.*]] = memref.load %[[A24]]#0{{\[}}%[[A28]]] : memref<?xi32>
@@ -562,8 +552,7 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
// CHECK: memref.dealloc %[[A4]] : memref<?xf64>
// CHECK: memref.dealloc %[[A5]] : memref<?xi1>
// CHECK: memref.dealloc %[[A6]] : memref<?xindex>
-// CHECK: %[[A22:.*]] = sparse_tensor.storage_specifier.get %[[A23:.*]]#3 ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
-// CHECK: %[[A24:.*]] = arith.index_cast %[[A22]] : i64 to index
+// CHECK: %[[A24:.*]] = sparse_tensor.storage_specifier.get %[[A23:.*]]#3 ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
// CHECK: %[[A25:.*]] = memref.load %[[A23]]#0{{\[}}%[[A11]]] : memref<?xindex>
// CHECK: %[[A26:.*]] = scf.for %[[A27:.*]] = %[[A12]] to %[[A24]] step %[[A12]] iter_args(%[[A28:.*]] = %[[A25]]) -> (index) {
// CHECK: %[[A29:.*]] = memref.load %[[A23]]#0{{\[}}%[[A27]]] : memref<?xindex>
diff --git a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
index 33bbe6a71ad07..7a0d668082d52 100644
--- a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
@@ -17,16 +17,12 @@
// CHECK: %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<16xf64> to memref<?xf64>
// CHECK: linalg.fill ins(%[[VAL_2]] : f64) outs(%[[VAL_8]] : memref<16xf64>)
// CHECK: %[[VAL_10:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_0]] : index to i64
-// CHECK: %[[VAL_12:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]] dim_sz at 0 with %[[VAL_11]] : i64, !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_13:.*]] = sparse_tensor.storage_specifier.get %[[VAL_12]] ptr_mem_sz at 0 : !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_14:.*]] = arith.index_cast %[[VAL_13]] : i64 to index
-// CHECK: %[[VAL_15:.*]], %[[VAL_16:.*]] = sparse_tensor.push_back %[[VAL_14]], %[[VAL_5]], %[[VAL_3]] : index, memref<?xindex>, index
-// CHECK: %[[VAL_17:.*]] = arith.index_cast %[[VAL_16]] : index to i64
-// CHECK: %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_12]] ptr_mem_sz at 0 with %[[VAL_17]] : i64, !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_19:.*]], %[[VAL_20:.*]] = sparse_tensor.push_back %[[VAL_16]], %[[VAL_15]], %[[VAL_3]], %[[VAL_1]] : index, memref<?xindex>, index, index
-// CHECK: %[[VAL_21:.*]] = arith.index_cast %[[VAL_20]] : index to i64
-// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] ptr_mem_sz at 0 with %[[VAL_21]] : i64, !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]] dim_sz at 0 with %[[VAL_0]] : !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.get %[[VAL_12]] ptr_mem_sz at 0 : !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_15:.*]], %[[VAL_17:.*]] = sparse_tensor.push_back %[[VAL_14]], %[[VAL_5]], %[[VAL_3]] : index, memref<?xindex>, index
+// CHECK: %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_12]] ptr_mem_sz at 0 with %[[VAL_17]] : !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_19:.*]], %[[VAL_21:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_15]], %[[VAL_3]], %[[VAL_1]] : index, memref<?xindex>, index, index
+// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] ptr_mem_sz at 0 with %[[VAL_21]] : !sparse_tensor.storage_specifier
// CHECK: return %[[VAL_19]], %[[VAL_7]], %[[VAL_9]], %[[VAL_22]] : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
func.func @sparse_alloc_sparse_vector(%arg0: index) -> tensor<?xf64, #SV> {
%0 = bufferization.alloc_tensor(%arg0) : tensor<?xf64, #SV>
diff --git a/mlir/test/Dialect/SparseTensor/fold.mlir b/mlir/test/Dialect/SparseTensor/fold.mlir
index 7397c0b22958a..3e559109189fe 100644
--- a/mlir/test/Dialect/SparseTensor/fold.mlir
+++ b/mlir/test/Dialect/SparseTensor/fold.mlir
@@ -48,17 +48,17 @@ func.func @sparse_concat_dce(%arg0: tensor<2xf64, #SparseVector>,
// CHECK-LABEL: func @sparse_get_specifier_dce_fold(
// CHECK-SAME: %[[A0:.*]]: !sparse_tensor.storage_specifier
-// CHECK-SAME: %[[A1:.*]]: i64,
-// CHECK-SAME: %[[A2:.*]]: i64)
+// CHECK-SAME: %[[A1:.*]]: index,
+// CHECK-SAME: %[[A2:.*]]: index)
// CHECK-NOT: sparse_tensor.storage_specifier.set
// CHECK-NOT: sparse_tensor.storage_specifier.get
// CHECK: return %[[A1]]
-func.func @sparse_get_specifier_dce_fold(%arg0: !sparse_tensor.storage_specifier<#SparseVector>, %arg1: i64, %arg2: i64) -> i64 {
+func.func @sparse_get_specifier_dce_fold(%arg0: !sparse_tensor.storage_specifier<#SparseVector>, %arg1: index, %arg2: index) -> index {
%0 = sparse_tensor.storage_specifier.set %arg0 dim_sz at 0 with %arg1
- : i64, !sparse_tensor.storage_specifier<#SparseVector>
+ : !sparse_tensor.storage_specifier<#SparseVector>
%1 = sparse_tensor.storage_specifier.set %0 ptr_mem_sz at 0 with %arg2
- : i64, !sparse_tensor.storage_specifier<#SparseVector>
+ : !sparse_tensor.storage_specifier<#SparseVector>
%2 = sparse_tensor.storage_specifier.get %1 dim_sz at 0
- : !sparse_tensor.storage_specifier<#SparseVector> to i64
- return %2 : i64
+ : !sparse_tensor.storage_specifier<#SparseVector>
+ return %2 : index
}
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 8f52d20942f9d..8d0a0e7d69869 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -252,68 +252,44 @@ func.func @sparse_slice_stride(%arg0: tensor<2x8xf64, #CSR_SLICE>) -> index {
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
-func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 {
+func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> index {
// expected-error@+1 {{redundant level argument for querying value memory size}}
%0 = sparse_tensor.storage_specifier.get %arg0 val_mem_sz at 0
- : !sparse_tensor.storage_specifier<#SparseVector> to i64
- return %0 : i64
+ : !sparse_tensor.storage_specifier<#SparseVector>
+ return %0 : index
}
// -----
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
-func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 {
+func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> index {
// expected-error@+1 {{missing level argument}}
%0 = sparse_tensor.storage_specifier.get %arg0 idx_mem_sz
- : !sparse_tensor.storage_specifier<#SparseVector> to i64
- return %0 : i64
+ : !sparse_tensor.storage_specifier<#SparseVector>
+ return %0 : index
}
// -----
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
-func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 {
+func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> index {
// expected-error@+1 {{requested level out of bound}}
%0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 1
- : !sparse_tensor.storage_specifier<#SparseVector> to i64
- return %0 : i64
+ : !sparse_tensor.storage_specifier<#SparseVector>
+ return %0 : index
}
// -----
#COO = #sparse_tensor.encoding<{dimLevelType = ["compressed-nu", "singleton"]}>
-func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#COO>) -> i64 {
+func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#COO>) -> index {
// expected-error@+1 {{requested pointer memory size on a singleton level}}
%0 = sparse_tensor.storage_specifier.get %arg0 ptr_mem_sz at 1
- : !sparse_tensor.storage_specifier<#COO> to i64
- return %0 : i64
-}
-
-// -----
-
-#COO = #sparse_tensor.encoding<{dimLevelType = ["compressed-nu", "singleton"]}>
-
-func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#COO>) -> i64 {
- // expected-error@+1 {{type mismatch between requested }}
- %0 = sparse_tensor.storage_specifier.get %arg0 ptr_mem_sz at 0
- : !sparse_tensor.storage_specifier<#COO> to i32
- return %0 : i32
-}
-
-// -----
-
-#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
-
-func.func @sparse_set_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>,
- %arg1: i32)
- -> !sparse_tensor.storage_specifier<#SparseVector> {
- // expected-error@+1 {{type mismatch between requested }}
- %0 = sparse_tensor.storage_specifier.set %arg0 dim_sz at 0 with %arg1
- : i32, !sparse_tensor.storage_specifier<#SparseVector>
- return %0 : !sparse_tensor.storage_specifier<#SparseVector>
+ : !sparse_tensor.storage_specifier<#COO>
+ return %0 : index
}
// -----
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 70c9a9862d533..608b6c80e24f3 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -184,11 +184,11 @@ func.func @sparse_metadata_init() -> !sparse_tensor.storage_specifier<#SparseVec
// CHECK-LABEL: func @sparse_get_md(
// CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>
// CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.get %[[A]] dim_sz at 0
-// CHECK: return %[[T]] : i64
-func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 {
+// CHECK: return %[[T]] : index
+func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> index {
%0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 0
- : !sparse_tensor.storage_specifier<#SparseVector> to i64
- return %0 : i64
+ : !sparse_tensor.storage_specifier<#SparseVector>
+ return %0 : index
}
// -----
@@ -197,13 +197,13 @@ func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>)
// CHECK-LABEL: func @sparse_set_md(
// CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>,
-// CHECK-SAME: %[[I:.*]]: i64)
+// CHECK-SAME: %[[I:.*]]: index)
// CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.set %[[A]] dim_sz at 0 with %[[I]]
// CHECK: return %[[T]] : !sparse_tensor.storage_specifier<#{{.*}}>
-func.func @sparse_set_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>, %arg1: i64)
+func.func @sparse_set_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>, %arg1: index)
-> !sparse_tensor.storage_specifier<#SparseVector> {
%0 = sparse_tensor.storage_specifier.set %arg0 dim_sz at 0 with %arg1
- : i64, !sparse_tensor.storage_specifier<#SparseVector>
+ : !sparse_tensor.storage_specifier<#SparseVector>
return %0 : !sparse_tensor.storage_specifier<#SparseVector>
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
index 70a5fa1338ad9..aaa7b581675d8 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -25,8 +25,7 @@
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_4]], %[[VAL_8]] : index
// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref<?xindex>
-// CHECK: %[[VAL_12:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]] idx_mem_sz at 1 : !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : i64 to index
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]] idx_mem_sz at 1 : !sparse_tensor.storage_specifier
// CHECK: %[[VAL_14:.*]] = arith.subi %[[VAL_11]], %[[VAL_8]] : index
// CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_10]], %[[VAL_11]] : index
// CHECK: %[[VAL_16:.*]] = scf.if %[[VAL_15]] -> (i1) {
@@ -42,16 +41,13 @@
// CHECK: } else {
// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_8]] : index
// CHECK: memref.store %[[VAL_21]], %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref<?xindex>
-// CHECK: %[[VAL_22:.*]], %[[VAL_23:.*]] = sparse_tensor.push_back %[[VAL_13]], %[[VAL_1]], %[[VAL_5]] : index, memref<?xindex>, index
-// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : index to i64
-// CHECK: %[[VAL_25:.*]] = sparse_tensor.storage_specifier.set %[[VAL_3]] idx_mem_sz at 1 with %[[VAL_24]] : i64, !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_22:.*]], %[[VAL_24:.*]] = sparse_tensor.push_back %[[VAL_13]], %[[VAL_1]], %[[VAL_5]] : index, memref<?xindex>, index
+// CHECK: %[[VAL_25:.*]] = sparse_tensor.storage_specifier.set %[[VAL_3]] idx_mem_sz at 1 with %[[VAL_24]] : !sparse_tensor.storage_specifier
// CHECK: scf.yield %[[VAL_22]], %[[VAL_25]] : memref<?xindex>, !sparse_tensor.storage_specifier
// CHECK: }
-// CHECK: %[[VAL_26:.*]] = sparse_tensor.storage_specifier.get %[[VAL_27:.*]]#1 val_mem_sz : !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_28:.*]] = arith.index_cast %[[VAL_26]] : i64 to index
+// CHECK: %[[VAL_28:.*]] = sparse_tensor.storage_specifier.get %[[VAL_27:.*]]#1 val_mem_sz : !sparse_tensor.storage_specifier
// CHECK: %[[VAL_29:.*]], %[[VAL_30:.*]] = sparse_tensor.push_back %[[VAL_28]], %[[VAL_2]], %[[VAL_6]] : index, memref<?xf64>, f64
-// CHECK: %[[VAL_31:.*]] = arith.index_cast %[[VAL_30]] : index to i64
-// CHECK: %[[VAL_32:.*]] = sparse_tensor.storage_specifier.set %[[VAL_27]]#1 val_mem_sz with %[[VAL_31]] : i64, !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_32:.*]] = sparse_tensor.storage_specifier.set %[[VAL_27]]#1 val_mem_sz with %[[VAL_30]] : !sparse_tensor.storage_specifier
// CHECK: return %[[VAL_0]], %[[VAL_27]]#0, %[[VAL_29]], %[[VAL_32]] : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: }
@@ -64,94 +60,89 @@
// CHECK-SAME: %[[VAL_5:.*5]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_6:.*6]]: memref<?xf64>,
// CHECK-SAME: %[[VAL_7:.*7]]: !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_8:.*]] = arith.constant 4 : index
-// CHECK: %[[VAL_9:.*]] = arith.constant 4 : i64
-// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK: %[[VAL_11:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_12:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_13:.*]] = arith.constant false
-// CHECK: %[[VAL_14:.*]] = arith.constant true
-// CHECK: %[[VAL_15:.*]] = memref.alloc() : memref<16xindex>
-// CHECK: %[[VAL_16:.*]] = memref.cast %[[VAL_15]] : memref<16xindex> to memref<?xindex>
-// CHECK: %[[VAL_17:.*]] = memref.alloc() : memref<16xindex>
-// CHECK: %[[VAL_18:.*]] = memref.cast %[[VAL_17]] : memref<16xindex> to memref<?xindex>
-// CHECK: %[[VAL_19:.*]] = memref.alloc() : memref<16xf64>
-// CHECK: %[[VAL_20:.*]] = memref.cast %[[VAL_19]] : memref<16xf64> to memref<?xf64>
-// CHECK: %[[VAL_21:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]] dim_sz at 0 with %[[VAL_9]] : i64, !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_23:.*]] = sparse_tensor.storage_specifier.set %[[VAL_22]] dim_sz at 1 with %[[VAL_9]] : i64, !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_24:.*]] = sparse_tensor.storage_specifier.get %[[VAL_23]] ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_25:.*]] = arith.index_cast %[[VAL_24]] : i64 to index
-// CHECK: %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.push_back %[[VAL_25]], %[[VAL_16]], %[[VAL_11]] : index, memref<?xindex>, index
-// CHECK: %[[VAL_28:.*]] = arith.index_cast %[[VAL_27]] : index to i64
-// CHECK: %[[VAL_29:.*]] = sparse_tensor.storage_specifier.set %[[VAL_23]] ptr_mem_sz at 1 with %[[VAL_28]] : i64, !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_32:.*]], %[[VAL_33:.*]] = sparse_tensor.push_back %[[VAL_27]], %[[VAL_26]], %[[VAL_11]], %[[VAL_8]] : index, memref<?xindex>, index, index
-// CHECK: %[[VAL_34:.*]] = arith.index_cast %[[VAL_33]] : index to i64
-// CHECK: %[[VAL_35:.*]] = sparse_tensor.storage_specifier.set %[[VAL_29]] ptr_mem_sz at 1 with %[[VAL_34]] : i64, !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_36:.*]] = memref.alloc() : memref<4xf64>
-// CHECK: %[[VAL_37:.*]] = memref.alloc() : memref<4xi1>
-// CHECK: %[[VAL_38:.*]] = memref.alloc() : memref<4xindex>
-// CHECK: %[[VAL_39:.*]] = memref.cast %[[VAL_38]] : memref<4xindex> to memref<?xindex>
-// CHECK: linalg.fill ins(%[[VAL_10]] : f64) outs(%[[VAL_36]] : memref<4xf64>)
-// CHECK: linalg.fill ins(%[[VAL_13]] : i1) outs(%[[VAL_37]] : memref<4xi1>)
-// CHECK: %[[VAL_40:.*]]:4 = scf.for %[[VAL_41:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_12]] iter_args(%[[VAL_42:.*]] = %[[VAL_32]], %[[VAL_43:.*]] = %[[VAL_18]], %[[VAL_44:.*]] = %[[VAL_20]], %[[VAL_45:.*]] = %[[VAL_35]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_41]]] : memref<?xindex>
-// CHECK: %[[VAL_47:.*]] = arith.addi %[[VAL_41]], %[[VAL_12]] : index
-// CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_47]]] : memref<?xindex>
-// CHECK: %[[VAL_49:.*]] = scf.for %[[VAL_50:.*]] = %[[VAL_46]] to %[[VAL_48]] step %[[VAL_12]] iter_args(%[[VAL_51:.*]] = %[[VAL_11]]) -> (index) {
-// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_50]]] : memref<?xindex>
-// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_50]]] : memref<?xf64>
-// CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_52]]] : memref<?xindex>
-// CHECK: %[[VAL_55:.*]] = arith.addi %[[VAL_52]], %[[VAL_12]] : index
-// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_55]]] : memref<?xindex>
-// CHECK: %[[VAL_57:.*]] = scf.for %[[VAL_58:.*]] = %[[VAL_54]] to %[[VAL_56]] step %[[VAL_12]] iter_args(%[[VAL_59:.*]] = %[[VAL_51]]) -> (index) {
-// CHECK: %[[VAL_60:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_58]]] : memref<?xindex>
-// CHECK: %[[VAL_61:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_60]]] : memref<4xf64>
-// CHECK: %[[VAL_62:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_58]]] : memref<?xf64>
-// CHECK: %[[VAL_63:.*]] = arith.mulf %[[VAL_53]], %[[VAL_62]] : f64
-// CHECK: %[[VAL_64:.*]] = arith.addf %[[VAL_61]], %[[VAL_63]] : f64
-// CHECK: %[[VAL_65:.*]] = memref.load %[[VAL_37]]{{\[}}%[[VAL_60]]] : memref<4xi1>
-// CHECK: %[[VAL_66:.*]] = arith.cmpi eq, %[[VAL_65]], %[[VAL_13]] : i1
-// CHECK: %[[VAL_67:.*]] = scf.if %[[VAL_66]] -> (index) {
-// CHECK: memref.store %[[VAL_14]], %[[VAL_37]]{{\[}}%[[VAL_60]]] : memref<4xi1>
-// CHECK: memref.store %[[VAL_60]], %[[VAL_38]]{{\[}}%[[VAL_59]]] : memref<4xindex>
-// CHECK: %[[VAL_68:.*]] = arith.addi %[[VAL_59]], %[[VAL_12]] : index
-// CHECK: scf.yield %[[VAL_68]] : index
+// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_12:.*]] = arith.constant false
+// CHECK-DAG: %[[VAL_13:.*]] = arith.constant true
+// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<16xindex>
+// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<16xindex> to memref<?xindex>
+// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<16xindex>
+// CHECK: %[[VAL_17:.*]] = memref.cast %[[VAL_16]] : memref<16xindex> to memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<16xf64>
+// CHECK: %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<16xf64> to memref<?xf64>
+// CHECK: %[[VAL_20:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_21:.*]] = sparse_tensor.storage_specifier.set %[[VAL_20]] dim_sz at 0 with %[[VAL_8]] : !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]] dim_sz at 1 with %[[VAL_8]] : !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_23:.*]] = sparse_tensor.storage_specifier.get %[[VAL_22]] ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_24:.*]], %[[VAL_25:.*]] = sparse_tensor.push_back %[[VAL_23]], %[[VAL_15]], %[[VAL_10]] : index, memref<?xindex>, index
+// CHECK: %[[VAL_26:.*]] = sparse_tensor.storage_specifier.set %[[VAL_22]] ptr_mem_sz at 1 with %[[VAL_25]] : !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_27:.*]], %[[VAL_28:.*]] = sparse_tensor.push_back %[[VAL_25]], %[[VAL_24]], %[[VAL_10]], %[[VAL_8]] : index, memref<?xindex>, index, index
+// CHECK: %[[VAL_29:.*]] = sparse_tensor.storage_specifier.set %[[VAL_26]] ptr_mem_sz at 1 with %[[VAL_28]] : !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_30:.*]] = memref.alloc() : memref<4xf64>
+// CHECK: %[[VAL_31:.*]] = memref.alloc() : memref<4xi1>
+// CHECK: %[[VAL_32:.*]] = memref.alloc() : memref<4xindex>
+// CHECK: %[[VAL_33:.*]] = memref.cast %[[VAL_32]] : memref<4xindex> to memref<?xindex>
+// CHECK: linalg.fill ins(%[[VAL_9]] : f64) outs(%[[VAL_30]] : memref<4xf64>)
+// CHECK: linalg.fill ins(%[[VAL_12]] : i1) outs(%[[VAL_31]] : memref<4xi1>)
+// CHECK: %[[VAL_34:.*]]:4 = scf.for %[[VAL_35:.*]] = %[[VAL_10]] to %[[VAL_8]] step %[[VAL_11]] iter_args(%[[VAL_36:.*]] = %[[VAL_27]], %[[VAL_37:.*]] = %[[VAL_17]], %[[VAL_38:.*]] = %[[VAL_19]], %[[VAL_39:.*]] = %[[VAL_29]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_40:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_35]]] : memref<?xindex>
+// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_35]], %[[VAL_11]] : index
+// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_41]]] : memref<?xindex>
+// CHECK: %[[VAL_43:.*]] = scf.for %[[VAL_44:.*]] = %[[VAL_40]] to %[[VAL_42]] step %[[VAL_11]] iter_args(%[[VAL_45:.*]] = %[[VAL_10]]) -> (index) {
+// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_44]]] : memref<?xindex>
+// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_44]]] : memref<?xf64>
+// CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_46]]] : memref<?xindex>
+// CHECK: %[[VAL_49:.*]] = arith.addi %[[VAL_46]], %[[VAL_11]] : index
+// CHECK: %[[VAL_50:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_49]]] : memref<?xindex>
+// CHECK: %[[VAL_51:.*]] = scf.for %[[VAL_52:.*]] = %[[VAL_48]] to %[[VAL_50]] step %[[VAL_11]] iter_args(%[[VAL_53:.*]] = %[[VAL_45]]) -> (index) {
+// CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_52]]] : memref<?xindex>
+// CHECK: %[[VAL_55:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_54]]] : memref<4xf64>
+// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_52]]] : memref<?xf64>
+// CHECK: %[[VAL_57:.*]] = arith.mulf %[[VAL_47]], %[[VAL_56]] : f64
+// CHECK: %[[VAL_58:.*]] = arith.addf %[[VAL_55]], %[[VAL_57]] : f64
+// CHECK: %[[VAL_59:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_54]]] : memref<4xi1>
+// CHECK: %[[VAL_60:.*]] = arith.cmpi eq, %[[VAL_59]], %[[VAL_12]] : i1
+// CHECK: %[[VAL_61:.*]] = scf.if %[[VAL_60]] -> (index) {
+// CHECK: memref.store %[[VAL_13]], %[[VAL_31]]{{\[}}%[[VAL_54]]] : memref<4xi1>
+// CHECK: memref.store %[[VAL_54]], %[[VAL_32]]{{\[}}%[[VAL_53]]] : memref<4xindex>
+// CHECK: %[[VAL_62:.*]] = arith.addi %[[VAL_53]], %[[VAL_11]] : index
+// CHECK: scf.yield %[[VAL_62]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_59]] : index
+// CHECK: scf.yield %[[VAL_53]] : index
// CHECK: }
-// CHECK: memref.store %[[VAL_64]], %[[VAL_36]]{{\[}}%[[VAL_60]]] : memref<4xf64>
-// CHECK: scf.yield %[[VAL_69:.*]] : index
+// CHECK: memref.store %[[VAL_58]], %[[VAL_30]]{{\[}}%[[VAL_54]]] : memref<4xf64>
+// CHECK: scf.yield %[[VAL_63:.*]] : index
// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: scf.yield %[[VAL_70:.*]] : index
+// CHECK: scf.yield %[[VAL_64:.*]] : index
// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: sparse_tensor.sort hybrid_quick_sort %[[VAL_71:.*]], %[[VAL_39]] : memref<?xindex>
-// CHECK: %[[VAL_72:.*]]:4 = scf.for %[[VAL_73:.*]] = %[[VAL_11]] to %[[VAL_71]] step %[[VAL_12]] iter_args(%[[VAL_74:.*]] = %[[VAL_42]], %[[VAL_75:.*]] = %[[VAL_43]], %[[VAL_76:.*]] = %[[VAL_44]], %[[VAL_77:.*]] = %[[VAL_45]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_78:.*]] = memref.load %[[VAL_38]]{{\[}}%[[VAL_73]]] : memref<4xindex>
-// CHECK: %[[VAL_79:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_78]]] : memref<4xf64>
-// CHECK: %[[VAL_80:.*]]:4 = func.call @_insert_dense_compressed_4_4_f64_0_0(%[[VAL_74]], %[[VAL_75]], %[[VAL_76]], %[[VAL_77]], %[[VAL_41]], %[[VAL_78]], %[[VAL_79]]) : (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
-// CHECK: memref.store %[[VAL_10]], %[[VAL_36]]{{\[}}%[[VAL_78]]] : memref<4xf64>
-// CHECK: memref.store %[[VAL_13]], %[[VAL_37]]{{\[}}%[[VAL_78]]] : memref<4xi1>
-// CHECK: scf.yield %[[VAL_80]]#0, %[[VAL_80]]#1, %[[VAL_80]]#2, %[[VAL_80]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: sparse_tensor.sort hybrid_quick_sort %[[VAL_65:.*]], %[[VAL_33]] : memref<?xindex>
+// CHECK: %[[VAL_66:.*]]:4 = scf.for %[[VAL_67:.*]] = %[[VAL_10]] to %[[VAL_65]] step %[[VAL_11]] iter_args(%[[VAL_68:.*]] = %[[VAL_36]], %[[VAL_69:.*]] = %[[VAL_37]], %[[VAL_70:.*]] = %[[VAL_38]], %[[VAL_71:.*]] = %[[VAL_39]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_67]]] : memref<4xindex>
+// CHECK: %[[VAL_73:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_72]]] : memref<4xf64>
+// CHECK: %[[VAL_74:.*]]:4 = func.call @_insert_dense_compressed_4_4_f64_0_0(%[[VAL_68]], %[[VAL_69]], %[[VAL_70]], %[[VAL_71]], %[[VAL_35]], %[[VAL_72]], %[[VAL_73]]) : (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: memref.store %[[VAL_9]], %[[VAL_30]]{{\[}}%[[VAL_72]]] : memref<4xf64>
+// CHECK: memref.store %[[VAL_12]], %[[VAL_31]]{{\[}}%[[VAL_72]]] : memref<4xi1>
+// CHECK: scf.yield %[[VAL_74]]#0, %[[VAL_74]]#1, %[[VAL_74]]#2, %[[VAL_74]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: }
-// CHECK: scf.yield %[[VAL_81:.*]]#0, %[[VAL_81]]#1, %[[VAL_81]]#2, %[[VAL_81]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: scf.yield %[[VAL_75:.*]]#0, %[[VAL_75]]#1, %[[VAL_75]]#2, %[[VAL_75]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: memref.dealloc %[[VAL_36]] : memref<4xf64>
-// CHECK: memref.dealloc %[[VAL_37]] : memref<4xi1>
-// CHECK: memref.dealloc %[[VAL_38]] : memref<4xindex>
-// CHECK: %[[VAL_82:.*]] = sparse_tensor.storage_specifier.get %[[VAL_83:.*]]#3 ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_84:.*]] = arith.index_cast %[[VAL_82]] : i64 to index
-// CHECK: %[[VAL_85:.*]] = memref.load %[[VAL_83]]#0{{\[}}%[[VAL_11]]] : memref<?xindex>
-// CHECK: %[[VAL_86:.*]] = scf.for %[[VAL_87:.*]] = %[[VAL_12]] to %[[VAL_84]] step %[[VAL_12]] iter_args(%[[VAL_88:.*]] = %[[VAL_85]]) -> (index) {
-// CHECK: %[[VAL_89:.*]] = memref.load %[[VAL_83]]#0{{\[}}%[[VAL_87]]] : memref<?xindex>
-// CHECK: %[[VAL_90:.*]] = arith.cmpi eq, %[[VAL_89]], %[[VAL_11]] : index
-// CHECK: %[[VAL_91:.*]] = arith.select %[[VAL_90]], %[[VAL_88]], %[[VAL_89]] : index
-// CHECK: scf.if %[[VAL_90]] {
-// CHECK: memref.store %[[VAL_88]], %[[VAL_83]]#0{{\[}}%[[VAL_87]]] : memref<?xindex>
+// CHECK: memref.dealloc %[[VAL_30]] : memref<4xf64>
+// CHECK: memref.dealloc %[[VAL_31]] : memref<4xi1>
+// CHECK: memref.dealloc %[[VAL_32]] : memref<4xindex>
+// CHECK: %[[VAL_76:.*]] = sparse_tensor.storage_specifier.get %[[VAL_77:.*]]#3 ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
+// CHECK: %[[VAL_78:.*]] = memref.load %[[VAL_77]]#0{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK: %[[VAL_79:.*]] = scf.for %[[VAL_80:.*]] = %[[VAL_11]] to %[[VAL_76]] step %[[VAL_11]] iter_args(%[[VAL_81:.*]] = %[[VAL_78]]) -> (index) {
+// CHECK: %[[VAL_82:.*]] = memref.load %[[VAL_77]]#0{{\[}}%[[VAL_80]]] : memref<?xindex>
+// CHECK: %[[VAL_83:.*]] = arith.cmpi eq, %[[VAL_82]], %[[VAL_10]] : index
+// CHECK: %[[VAL_84:.*]] = arith.select %[[VAL_83]], %[[VAL_81]], %[[VAL_82]] : index
+// CHECK: scf.if %[[VAL_83]] {
+// CHECK: memref.store %[[VAL_81]], %[[VAL_77]]#0{{\[}}%[[VAL_80]]] : memref<?xindex>
// CHECK: }
-// CHECK: scf.yield %[[VAL_91]] : index
+// CHECK: scf.yield %[[VAL_84]] : index
// CHECK: }
-// CHECK: return %[[VAL_83]]#0, %[[VAL_83]]#1, %[[VAL_83]]#2, %[[VAL_83]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK: return %[[VAL_77]]#0, %[[VAL_77]]#1, %[[VAL_77]]#2, %[[VAL_77]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
func.func @matmul(%A: tensor<4x8xf64, #CSR>,
%B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
%C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 057153a20c955..cdfd856b19d93 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -19,16 +19,13 @@
// CHECK: %[[VAL_10:.*]] = sparse_tensor.storage_specifier.init :
// CHECK: %[[VAL_11:.*]] = arith.constant 6 : index
// CHECK: %[[VAL_12:.*]] = arith.constant 100 : index
-// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : index to i32
-// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]] dim_sz at 0 with %[[VAL_13]] : i32,
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]] dim_sz at 0 with %[[VAL_12]]
// CHECK: %[[VAL_15:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_16:.*]] = arith.index_cast %[[VAL_15]] : index to i32
-// CHECK: %[[VAL_17:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]] ptr_mem_sz at 0 with %[[VAL_16]] : i32,
-// CHECK: %[[VAL_18:.*]] = arith.index_cast %[[VAL_11]] : index to i32
-// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]] idx_mem_sz at 0 with %[[VAL_18]] : i32,
-// CHECK: %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]] dim_sz at 1 with %[[VAL_13]] : i32,
-// CHECK: %[[VAL_21:.*]] = sparse_tensor.storage_specifier.set %[[VAL_20]] idx_mem_sz at 1 with %[[VAL_18]] : i32,
-// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]] val_mem_sz with %[[VAL_18]] : i32,
+// CHECK: %[[VAL_17:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]] ptr_mem_sz at 0 with %[[VAL_15]]
+// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]] idx_mem_sz at 0 with %[[VAL_11]]
+// CHECK: %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]] dim_sz at 1 with %[[VAL_12]]
+// CHECK: %[[VAL_21:.*]] = sparse_tensor.storage_specifier.set %[[VAL_20]] idx_mem_sz at 1 with %[[VAL_11]]
+// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]] val_mem_sz with %[[VAL_11]]
// CHECK: return %[[VAL_4]], %[[VAL_7]], %[[VAL_9]], %[[VAL_22]] : memref<?xindex>, memref<?xi32>, memref<?xf64>,
// CHECK: }
func.func @sparse_pack(%data: tensor<6xf64>, %index: tensor<6x2xi32>)
@@ -68,8 +65,7 @@ func.func @sparse_pack(%data: tensor<6xf64>, %index: tensor<6x2xi32>)
// CHECK: %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_20:.*]] : memref<6xf64>
// CHECK: %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<6x2xi32>
// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier
-// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_22]] : i32 to index
-// CHECK: return %[[VAL_19]], %[[VAL_21]], %[[VAL_23]] : tensor<6xf64>, tensor<6x2xi32>, index
+// CHECK: return %[[VAL_19]], %[[VAL_21]], %[[VAL_22]] : tensor<6xf64>, tensor<6x2xi32>, index
// CHECK: }
func.func @sparse_unpack(%sp: tensor<100x100xf64, #COO>) -> (tensor<6xf64>, tensor<6x2xi32>, index) {
%d, %i, %nnz = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO>
diff --git a/mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir b/mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir
index ecdaf3bf9c964..36dce9e417b5f 100644
--- a/mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir
+++ b/mlir/test/Dialect/SparseTensor/specifier_to_llvm.mlir
@@ -16,23 +16,25 @@ func.func @sparse_metadata_init() -> !sparse_tensor.storage_specifier<#CSR> {
}
// CHECK-LABEL: func.func @sparse_get_md(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>) -> i64 {
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>) -> index {
// CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
-// CHECK: return %[[VAL_1]] : i64
-func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#CSR>) -> i64 {
+// CHECK: %[[CAST:.*]] = arith.index_cast %[[VAL_1]] : i64 to index
+// CHECK: return %[[CAST]] : index
+func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#CSR>) -> index {
%0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 0
- : !sparse_tensor.storage_specifier<#CSR> to i64
- return %0 : i64
+ : !sparse_tensor.storage_specifier<#CSR>
+ return %0 : index
}
// CHECK-LABEL: func.func @sparse_set_md(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>,
-// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.struct<(array<2 x i64>, array<3 x i64>)> {
-// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
+// CHECK-SAME: %[[VAL_1:.*]]: index) -> !llvm.struct<(array<2 x i64>, array<3 x i64>)> {
+// CHECK: %[[CAST:.*]] = arith.index_cast %[[VAL_1]] : index to i64
+// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[CAST]], %[[VAL_0]][0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
// CHECK: return %[[VAL_2]] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
-func.func @sparse_set_md(%arg0: !sparse_tensor.storage_specifier<#CSR>, %arg1: i64)
+func.func @sparse_set_md(%arg0: !sparse_tensor.storage_specifier<#CSR>, %arg1: index)
-> !sparse_tensor.storage_specifier<#CSR> {
%0 = sparse_tensor.storage_specifier.set %arg0 dim_sz at 0 with %arg1
- : i64, !sparse_tensor.storage_specifier<#CSR>
+ : !sparse_tensor.storage_specifier<#CSR>
return %0 : !sparse_tensor.storage_specifier<#CSR>
}
More information about the Mlir-commits
mailing list