[Mlir-commits] [mlir] a361035 - [mlir][sparse] change memref argument to proper SSA components
Aart Bik
llvmlistbot at llvm.org
Tue Sep 27 16:37:49 PDT 2022
Author: Aart Bik
Date: 2022-09-27T16:37:37-07:00
New Revision: a3610359b51dffaa7aefa8e325e58978c105a8a3
URL: https://github.com/llvm/llvm-project/commit/a3610359b51dffaa7aefa8e325e58978c105a8a3
DIFF: https://github.com/llvm/llvm-project/commit/a3610359b51dffaa7aefa8e325e58978c105a8a3.diff
LOG: [mlir][sparse] change memref argument to proper SSA components
The indices for insert/compress were previously provided as
a memref<?xindex> of proper rank, since that better matched
the arguments of the runtime support library. However, with
proper codegen coming, providing the indices as individual SSA
values is much cleaner. This also brings sparse_tensor.insert
closer to unification with tensor.insert, planned for the
longer term.
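For context, the surface syntax changes as follows (taken from the op
documentation examples updated in this revision); the indices now appear
as plain SSA values in square brackets instead of a memref argument:

```mlir
// Before: indices and value passed through memrefs.
sparse_tensor.insert %tensor, %indices, %val
  : tensor<1024x1024xf64, #CSR>, memref<?xindex>, memref<f64>
sparse_tensor.compress %0, %1, %values, %filled, %added, %2
  : tensor<4x4xf64, #CSR>, memref<?xindex>, memref<?xf64>,
    memref<?xi1>, memref<?xindex>, index

// After: indices passed as proper SSA values.
sparse_tensor.insert %val into %tensor[%i,%j] : tensor<1024x1024xf64, #CSR>
sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
  : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<4x4xf64, #CSR>
```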
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D134404
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir
mlir/test/Dialect/SparseTensor/conversion.mlir
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
mlir/test/Dialect/SparseTensor/sparse_expand.mlir
mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
mlir/test/Dialect/SparseTensor/sparse_index.mlir
mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
mlir/test/Dialect/SparseTensor/sparse_out.mlir
mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
mlir/test/Dialect/SparseTensor/sparse_transpose.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 2d2a0499a7b9b..0a670ec1445a3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -197,14 +197,18 @@ def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [NoSideEffect]>,
// as our sparse abstractions evolve.
//===----------------------------------------------------------------------===//
-def SparseTensor_InsertOp : SparseTensor_Op<"insert", []>,
- Arguments<(ins AnySparseTensor:$tensor,
- StridedMemRefRankOf<[Index], [1]>:$indices,
- AnyType:$value)> {
+def SparseTensor_InsertOp : SparseTensor_Op<"insert",
+ [TypesMatchWith<"value type matches element type of tensor",
+ "tensor", "value",
+ "$_self.cast<ShapedType>().getElementType()">]>,
+ Arguments<(ins AnyType:$value,
+ AnySparseTensor:$tensor,
+ Variadic<Index>:$indices)> {
string summary = "Inserts a value into given sparse tensor";
string description = [{
- Inserts the given value at given indices into the underlying sparse
- storage format of the given tensor with the given indices. This
+ Inserts the given value at given indices into the underlying
+ sparse storage format of the given tensor with the given indices.
+ The arity of indices must match the rank of the tensor. This
operation can only be applied when a tensor materializes unintialized
with a `bufferization.alloc_tensor` operation and the final tensor
is constructed with a `load` operation that has the `hasInserts`
@@ -219,17 +223,18 @@ def SparseTensor_InsertOp : SparseTensor_Op<"insert", []>,
Note that this operation is "impure" in the sense that its behavior
is solely defined by side-effects and not SSA values. The semantics
- may be refined over time as our sparse abstractions evolve.
+ may be refined over time as our sparse abstractions evolve. In
+ particular, this operation is scheduled to be unified with the
+ dense counterpart `tensor.insert` that has proper SSA semantics.
Example:
```mlir
- sparse_tensor.insert %tensor, %indices, %val
- : tensor<1024x1024xf64, #CSR>, memref<?xindex>, memref<f64>
+ sparse_tensor.insert %val into %tensor[%i,%j] : tensor<1024x1024xf64, #CSR>
```
}];
- let assemblyFormat = "$tensor `,` $indices `,` $value attr-dict `:`"
- " type($tensor) `,` type($indices) `,` type($value)";
+ let assemblyFormat = "$value `into` $tensor `[` $indices `]` attr-dict`:` type($tensor)";
+ let hasVerifier = 1;
}
def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", []>,
@@ -294,29 +299,31 @@ def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
Example:
```mlir
- %values, %filled, %added, %count = sparse_tensor.expand %0
- : tensor<4x4xf64, #CSR> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+ %values, %filled, %added, %count = sparse_tensor.expand %tensor
+ : tensor<4x4xf64, #CSR> to memref<?xf64>, memref<?xi1>, memref<?xindex>
```
}];
let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($values)"
- " `,` type($filled) `,` type($added) `,` type($count)";
+ " `,` type($filled) `,` type($added)";
}
def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
- Arguments<(ins AnySparseTensor:$tensor,
- StridedMemRefRankOf<[Index],[1]>:$indices,
- AnyStridedMemRefOfRank<1>:$values,
+ Arguments<(ins AnyStridedMemRefOfRank<1>:$values,
StridedMemRefRankOf<[I1],[1]>:$filled,
StridedMemRefRankOf<[Index],[1]>:$added,
- Index:$count)> {
+ Index:$count,
+ AnySparseTensor:$tensor,
+ Variadic<Index>:$indices)> {
string summary = "Compressed an access pattern for insertion";
string description = [{
Finishes a single access pattern expansion by moving inserted elements
- into the sparse storage scheme. The values and filled array are reset
- in a *sparse* fashion by only iterating over set elements through an
- indirection using the added array, so that the operations are kept
- proportional to the number of nonzeros. See the 'expand' operation
- for more details.
+ into the sparse storage scheme of the given tensor with the given
+ indices. The arity of indices is one less than the rank of the tensor,
+ with the remainder innermost indices defined through the added array.
+ The values and filled array are reset in a *sparse* fashion by only
+ iterating over set elements through an indirection using the added
+ array, so that the operations are kept proportional to the number of
+ nonzeros. See the `sparse_tensor.expand` operation for more details.
Note that this operation is "impure" in the sense that its behavior
is solely defined by side-effects and not SSA values. The semantics
@@ -325,15 +332,15 @@ def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
Example:
```mlir
- sparse_tensor.compress %0, %1, %values, %filled, %added, %2
- : tensor<4x4xf64, #CSR>, memref<?xindex>, memref<?xf64>,
- memref<?xi1>, memref<?xindex>, index
+ sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
+ : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<4x4xf64, #CSR>
```
}];
- let assemblyFormat = "$tensor `,` $indices `,` $values `,` $filled `,`"
- " $added `,` $count attr-dict `:` type($tensor) `,`"
- " type($indices) `,` type($values) `,` type($filled) `,`"
- " type($added) `,` type($count)";
+ let assemblyFormat = "$values `,` $filled `,` $added `,` $count"
+ " `into` $tensor `[` $indices `]` attr-dict"
+ " `:` type($values) `,` type($filled) `,` type($added)"
+ " `,` type($tensor)";
+ let hasVerifier = 1;
}
def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 867803ceb1ddb..9b625483d4daa 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -453,6 +453,20 @@ LogicalResult ConcatenateOp::verify() {
return success();
}
+LogicalResult InsertOp::verify() {
+ RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
+ if (ttp.getRank() != static_cast<int64_t>(getIndices().size()))
+ return emitOpError("incorrect number of indices");
+ return success();
+}
+
+LogicalResult CompressOp::verify() {
+ RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
+ if (ttp.getRank() != 1 + static_cast<int64_t>(getIndices().size()))
+ return emitOpError("incorrect number of indices");
+ return success();
+}
+
LogicalResult ForeachOp::verify() {
auto t = getTensor().getType().cast<RankedTensorType>();
auto args = getBody()->getArguments();
@@ -471,7 +485,6 @@ LogicalResult ForeachOp::verify() {
emitError(llvm::formatv("Unmatched element type between input tensor and "
"block argument, expected:{0}, got: {1}",
elemTp, valueTp));
-
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index d6fe145e77610..3469eb5613977 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -1144,10 +1144,21 @@ class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
matchAndRewrite(InsertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Note that the current regime only allows for strict lexicographic
- // index order.
- Type elemTp = op.getTensor().getType().cast<ShapedType>().getElementType();
+ // index order. All values are passed by reference through stack
+ // allocated memrefs.
+ Location loc = op->getLoc();
+ auto tp = op.getTensor().getType().cast<RankedTensorType>();
+ auto elemTp = tp.getElementType();
+ unsigned rank = tp.getRank();
+ auto mref = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
+ auto vref = genAllocaScalar(rewriter, loc, elemTp);
+ for (unsigned i = 0; i < rank; i++)
+ rewriter.create<memref::StoreOp>(loc, adaptor.getIndices()[i], mref,
+ constantIndex(rewriter, loc, i));
+ rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref);
SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
- replaceOpWithFuncCall(rewriter, op, name, {}, adaptor.getOperands(),
+ replaceOpWithFuncCall(rewriter, op, name, {},
+ {adaptor.getTensor(), mref, vref},
EmitCInterface::On);
return success();
}
@@ -1212,9 +1223,21 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
// all-zero/false by only iterating over the set elements, so the
// complexity remains proportional to the sparsity of the expanded
// access pattern.
- Type elemTp = op.getTensor().getType().cast<ShapedType>().getElementType();
+ Value values = adaptor.getValues();
+ Value filled = adaptor.getFilled();
+ Value added = adaptor.getAdded();
+ Value count = adaptor.getCount();
+ Value tensor = adaptor.getTensor();
+ auto tp = op.getTensor().getType().cast<RankedTensorType>();
+ Type elemTp = tp.getElementType();
+ unsigned rank = tp.getRank();
+ auto mref = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
+ for (unsigned i = 0; i < rank - 1; i++)
+ rewriter.create<memref::StoreOp>(loc, adaptor.getIndices()[i], mref,
+ constantIndex(rewriter, loc, i));
SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
- replaceOpWithFuncCall(rewriter, op, name, {}, adaptor.getOperands(),
+ replaceOpWithFuncCall(rewriter, op, name, {},
+ {tensor, mref, values, filled, added, count},
EmitCInterface::On);
// Deallocate the buffers on exit of the loop nest.
Operation *parent = op;
@@ -1225,9 +1248,9 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
parent = parent->getParentOp())
;
rewriter.setInsertionPointAfter(parent);
- rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[2]);
- rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[3]);
- rewriter.create<memref::DeallocOp>(loc, adaptor.getOperands()[4]);
+ rewriter.create<memref::DeallocOp>(loc, values);
+ rewriter.create<memref::DeallocOp>(loc, filled);
+ rewriter.create<memref::DeallocOp>(loc, added);
return success();
}
};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index bce09130e58bb..35845677082dd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -55,14 +55,14 @@ enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor, kCustom };
// Code generation.
struct CodeGen {
CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops,
- OpOperand *op, unsigned nest)
+ OpOperand *op, unsigned nest, std::vector<unsigned> &ts)
: options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
pointers(numTensors, std::vector<Value>(numLoops)),
indices(numTensors, std::vector<Value>(numLoops)),
highs(numTensors, std::vector<Value>(numLoops)),
pidxs(numTensors, std::vector<Value>(numLoops)),
idxs(numTensors, std::vector<Value>(numLoops)), sparseOut(op),
- outerParNest(nest) {}
+ outerParNest(nest), topSort(ts) {}
/// Sparsification options.
SparsificationOptions options;
/// Universal dense indices and upper bounds (by index). The loops array
@@ -89,13 +89,10 @@ struct CodeGen {
Reduction redKind = kNoReduc;
unsigned redCustom = -1u;
// Sparse tensor as output. Implemented either through direct injective
- // insertion in lexicographic index order (where indices are updated
- // in the temporary array `lexIdx`) or through access pattern expansion
+ // insertion in lexicographic index order or through access pattern expansion
// in the innermost loop nest (`expValues` through `expCount`).
OpOperand *sparseOut;
unsigned outerParNest;
- Value lexIdx;
- Value lexVal;
Value expValues;
Value expFilled;
Value expAdded;
@@ -103,6 +100,8 @@ struct CodeGen {
// Current vector length and mask.
unsigned curVecLength = 1;
Value curVecMask;
+ // Topsort (reference should remain in scope).
+ std::vector<unsigned> &topSort;
};
} // namespace
@@ -203,8 +202,8 @@ static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
-/// expressions of all tensors are admissable. Returns false if
-/// no annotations are found or inadmissable constructs occur.
+/// expressions of all tensors are admissible. Returns false if
+/// no annotations are found or inadmissible constructs occur.
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
bool annotated = false;
for (OpOperand *t : op.getInputAndOutputOperands()) {
@@ -217,7 +216,7 @@ static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
unsigned tensor = t->getOperandNumber();
AffineExpr a = map.getResult(perm(enc, d));
if (!findAffine(merger, tensor, a, toDimLevelFormat(enc, d)))
- return false; // inadmissable affine expression
+ return false; // inadmissible affine expression
}
}
return annotated;
@@ -330,12 +329,16 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
unsigned tensor = t->getOperandNumber();
for (unsigned i = 0; i < n; i++)
if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) ||
- merger.isDimLevelType(tensor, i, DimLvlType::kSingleton))
+ merger.isDimLevelType(tensor, i, DimLvlType::kSingleton)) {
for (unsigned j = 0; j < n; j++)
if (merger.isDimLevelType(tensor, j, DimLvlType::kUndef)) {
adjM[i][j] = true;
inDegree[j]++;
}
+ } else {
+ assert(merger.isDimLevelType(tensor, i, DimLvlType::kDense) ||
+ merger.isDimLevelType(tensor, i, DimLvlType::kUndef));
+ }
}
}
// Topologically sort the iteration graph to determine loop order.
@@ -351,9 +354,9 @@ static bool isMaterializing(Value val) {
val.getDefiningOp<bufferization::AllocTensorOp>();
}
-/// Returns true when the tensor expression is admissable for codegen.
-/// Since all sparse input tensors are admissable, we just need to check
-/// whether the out tensor in the tensor expression codegen is admissable.
+/// Returns true when the tensor expression is admissible for codegen.
+/// Since all sparse input tensors are admissible, we just need to check
+/// whether the out tensor in the tensor expression codegen is admissible.
/// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
/// nesting depth when a "truly dynamic" sparse tensor output occurs.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
@@ -368,7 +371,7 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
if (!enc)
return true;
// An all-dense annotated "sparse" output tensor becomes a linearized random
- // access 1-dim memref. Also admissable since insertions cannot occur.
+ // access 1-dim memref. Also admissible since insertions cannot occur.
bool allDense = true;
auto iteratorTypes = op.iterator_types().getValue();
unsigned numLoops = iteratorTypes.size();
@@ -377,12 +380,15 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
merger.isDimLevelType(tensor, i, DimLvlType::kSingleton)) {
allDense = false;
break;
+ } else {
+ assert(merger.isDimLevelType(tensor, i, DimLvlType::kDense) ||
+ merger.isDimLevelType(tensor, i, DimLvlType::kUndef));
}
if (allDense)
return true;
// A tensor expression with a sparse output tensor that changes its values
// but not its nonzero structure, an operation called "simply dynamic" in
- // [Bik96,Ch9], is also admissable without special codegen.
+ // [Bik96,Ch9], is also admissible without special codegen.
if (merger.isSingleCondition(tensor, exp))
return true;
// Accept "truly dynamic" if the output tensor materializes uninitialized
@@ -394,9 +400,9 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
break; // terminate at first reduction
nest++;
}
- // Determine admissable dynamic insertion situations:
+ // Determine admissible dynamic insertion situations:
// (1) fully injective, since there are no reductions,
- // (2) admissable 1-d expansion in innermost dimension.
+ // (2) admissible 1-d expansion in innermost dimension.
if (nest >= op.getRank(lhs) - 1) {
*sparseOut = lhs;
outerParNest = nest;
@@ -537,14 +543,13 @@ static Value genOutputBuffer(CodeGen &codegen, OpBuilder &builder,
}
/// Local bufferization of all dense and sparse data structures.
-/// This code enables testing the first prototype sparse compiler.
-// TODO: replace this with a proliferated bufferization strategy
static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
linalg::GenericOp op) {
Location loc = op.getLoc();
assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
// For every tensor, find lower and upper bound on dimensions, set the
// same bounds on loop indices, and obtain dense or sparse buffer(s).
+ auto dynShape = {ShapedType::kDynamicSize};
SmallVector<Value, 4> args;
for (OpOperand *t : op.getInputAndOutputOperands()) {
unsigned tensor = t->getOperandNumber();
@@ -558,21 +563,28 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
if (a.getKind() != AffineExprKind::DimId)
continue; // compound
unsigned idx = a.cast<AffineDimExpr>().getPosition();
- // Handle sparse storage schemes.
+      // Handle the different storage schemes.
if (merger.isDimLevelType(tensor, idx, DimLvlType::kCompressed)) {
- auto dynShape = {ShapedType::kDynamicSize};
+ // Compressed dimension, fetch pointer and indices.
auto ptrTp =
MemRefType::get(dynShape, getPointerOverheadType(builder, enc));
auto indTp =
MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
auto dim = builder.getIndexAttr(d);
- // Generate sparse primitives to obtains pointer and indices.
codegen.pointers[tensor][idx] =
builder.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
codegen.indices[tensor][idx] =
builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
} else if (merger.isDimLevelType(tensor, idx, DimLvlType::kSingleton)) {
- llvm_unreachable("TODO: not implemented yet");
+ // Singleton dimension, fetch indices.
+ auto indTp =
+ MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
+ auto dim = builder.getIndexAttr(d);
+ codegen.indices[tensor][idx] =
+ builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
+ } else {
+ // Dense dimension, nothing to fetch.
+ assert(merger.isDimLevelType(tensor, idx, DimLvlType::kDense));
}
// Find upper bound in current dimension.
unsigned p = perm(enc, d);
@@ -595,17 +607,8 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
else
codegen.buffers[tensor] =
genOutputBuffer(codegen, builder, op, denseTp, args);
- } else if (t == codegen.sparseOut) {
- // True sparse output needs a lexIdx array.
- Value rank = constantIndex(builder, loc, op.getRank(t));
- auto dynShape = {ShapedType::kDynamicSize};
- auto memTp = MemRefType::get(dynShape, builder.getIndexType());
- codegen.lexIdx = builder.create<memref::AllocaOp>(loc, memTp, rank);
- codegen.lexVal = builder.create<memref::AllocaOp>(
- loc, MemRefType::get({}, elementType));
- } else {
- // Annotated sparse tensors.
- auto dynShape = {ShapedType::kDynamicSize};
+ } else if (t != codegen.sparseOut) {
+ // Annotated sparse tensors (not involved in output).
auto sparseTp = MemRefType::get(dynShape, elementType);
codegen.buffers[tensor] =
builder.create<ToValuesOp>(loc, sparseTp, t->get());
@@ -802,8 +805,13 @@ static void genInsertionStore(CodeGen &codegen, OpBuilder &builder,
Location loc = op.getLoc();
// Direct insertion in lexicographic index order.
if (!codegen.expValues) {
- builder.create<memref::StoreOp>(loc, rhs, codegen.lexVal);
- builder.create<InsertOp>(loc, t->get(), codegen.lexIdx, codegen.lexVal);
+ unsigned rank = op.getRank(t);
+ SmallVector<Value, 4> indices;
+ for (unsigned i = 0; i < rank; i++) {
+ assert(codegen.loops[codegen.topSort[i]]);
+ indices.push_back(codegen.loops[codegen.topSort[i]]);
+ }
+ builder.create<InsertOp>(loc, rhs, t->get(), indices);
return;
}
// Generates insertion code along expanded access pattern.
@@ -1155,9 +1163,14 @@ static void genExpansion(Merger &merger, CodeGen &codegen, OpBuilder &builder,
codegen.expCount = res.getResult(3);
} else {
assert(codegen.expValues);
- builder.create<CompressOp>(loc, tensor, codegen.lexIdx, codegen.expValues,
- codegen.expFilled, codegen.expAdded,
- codegen.expCount);
+ SmallVector<Value, 4> indices;
+ for (unsigned i = 0; i < at; i++) {
+ assert(codegen.loops[codegen.topSort[i]]);
+ indices.push_back(codegen.loops[codegen.topSort[i]]);
+ }
+ builder.create<CompressOp>(loc, codegen.expValues, codegen.expFilled,
+ codegen.expAdded, codegen.expCount, tensor,
+ indices);
codegen.expValues = codegen.expFilled = codegen.expAdded =
codegen.expCount = Value();
}
@@ -1167,37 +1180,53 @@ static void genExpansion(Merger &merger, CodeGen &codegen, OpBuilder &builder,
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, OpBuilder &builder,
- linalg::GenericOp op, std::vector<unsigned> &topSort,
- unsigned at, BitVector &inits) {
+ linalg::GenericOp op, unsigned at, BitVector &inits) {
+ std::vector<unsigned> &topSort(codegen.topSort);
bool needsUniv = false;
Location loc = op.getLoc();
unsigned idx = topSort[at];
// Initialize sparse positions.
for (unsigned b = 0, be = inits.size(); b < be; b++) {
- if (inits[b]) {
- unsigned tensor = merger.tensor(b);
- assert(idx == merger.index(b));
- if (merger.isDimLevelType(b, DimLvlType::kCompressed)) {
- // Initialize sparse index.
- unsigned pat = at;
- for (; pat != 0; pat--) {
- if (codegen.pidxs[tensor][topSort[pat - 1]])
- break;
- }
- Value ptr = codegen.pointers[tensor][idx];
- Value one = constantIndex(builder, loc, 1);
- Value p0 = (pat == 0) ? constantIndex(builder, loc, 0)
- : codegen.pidxs[tensor][topSort[pat - 1]];
- codegen.pidxs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p0);
- Value p1 = builder.create<arith::AddIOp>(loc, p0, one);
- codegen.highs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p1);
- } else if (merger.isDimLevelType(b, DimLvlType::kSingleton)) {
- llvm_unreachable("TODO: not implemented yet");
- } else {
- // Dense index still in play.
- needsUniv = true;
+ if (!inits[b])
+ continue;
+ unsigned tensor = merger.tensor(b);
+ assert(idx == merger.index(b));
+ if (merger.isDimLevelType(b, DimLvlType::kCompressed)) {
+ // Initialize sparse index that will implement the iteration:
+ // for pidx_idx = pointers(pidx_idx-1), pointers(1+pidx_idx-1)
+ unsigned pat = at;
+ for (; pat != 0; pat--) {
+ if (codegen.pidxs[tensor][topSort[pat - 1]])
+ break;
+ }
+ Value ptr = codegen.pointers[tensor][idx];
+ Value one = constantIndex(builder, loc, 1);
+ Value p0 = (pat == 0) ? constantIndex(builder, loc, 0)
+ : codegen.pidxs[tensor][topSort[pat - 1]];
+ codegen.pidxs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p0);
+ Value p1 = builder.create<arith::AddIOp>(loc, p0, one);
+ codegen.highs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p1);
+ } else if (merger.isDimLevelType(b, DimLvlType::kSingleton)) {
+ // Initialize sparse index that will implement the "iteration":
+ // for pidx_idx = pidx_idx-1, 1+pidx_idx-1
+ // We rely on subsequent loop unrolling to get rid of the loop
+ // if it is not involved in co-iteration with anything else.
+ unsigned pat = at;
+ for (; pat != 0; pat--) {
+ if (codegen.pidxs[tensor][topSort[pat - 1]])
+ break;
}
+ Value one = constantIndex(builder, loc, 1);
+ Value p0 = (pat == 0) ? constantIndex(builder, loc, 0)
+ : codegen.pidxs[tensor][topSort[pat - 1]];
+ codegen.pidxs[tensor][idx] = p0;
+ codegen.highs[tensor][idx] = builder.create<arith::AddIOp>(loc, p0, one);
+ } else {
+ assert(merger.isDimLevelType(b, DimLvlType::kDense) ||
+ merger.isDimLevelType(b, DimLvlType::kUndef));
+ // Dense index still in play.
+ needsUniv = true;
}
}
@@ -1284,15 +1313,13 @@ static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder,
assert(idx == merger.index(fb));
auto iteratorTypes = op.iterator_types().getValue();
bool isReduction = linalg::isReductionIterator(iteratorTypes[idx]);
- bool isSparse = merger.isDimLevelType(fb, DimLvlType::kCompressed);
+ bool isSparse = merger.isDimLevelType(fb, DimLvlType::kCompressed) ||
+ merger.isDimLevelType(fb, DimLvlType::kSingleton);
bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) &&
denseUnitStrides(merger, op, idx);
bool isParallel =
isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);
- assert(!merger.isDimLevelType(fb, DimLvlType::kSingleton) &&
- "TODO: implement");
-
// Prepare vector length.
if (isVector)
codegen.curVecLength = codegen.options.vectorLength;
@@ -1360,11 +1387,17 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder,
// Construct the while-loop with a parameter for each index.
Type indexType = builder.getIndexType();
for (unsigned b = 0, be = indices.size(); b < be; b++) {
- if (indices[b] && merger.isDimLevelType(b, DimLvlType::kCompressed)) {
+ if (!indices[b])
+ continue;
+ if (merger.isDimLevelType(b, DimLvlType::kCompressed) ||
+ merger.isDimLevelType(b, DimLvlType::kSingleton)) {
unsigned tensor = merger.tensor(b);
assert(idx == merger.index(b));
types.push_back(indexType);
operands.push_back(codegen.pidxs[tensor][idx]);
+ } else {
+ assert(merger.isDimLevelType(b, DimLvlType::kDense) ||
+ merger.isDimLevelType(b, DimLvlType::kUndef));
}
}
if (codegen.redVal) {
@@ -1393,8 +1426,10 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder,
Value cond;
unsigned o = 0;
for (unsigned b = 0, be = indices.size(); b < be; b++) {
- // TODO: singleton
- if (indices[b] && merger.isDimLevelType(b, DimLvlType::kCompressed)) {
+ if (!indices[b])
+ continue;
+ if (merger.isDimLevelType(b, DimLvlType::kCompressed) ||
+ merger.isDimLevelType(b, DimLvlType::kSingleton)) {
unsigned tensor = merger.tensor(b);
assert(idx == merger.index(b));
Value op1 = before->getArgument(o);
@@ -1403,6 +1438,9 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder,
op1, op2);
cond = cond ? builder.create<arith::AndIOp>(loc, cond, opc) : opc;
codegen.pidxs[tensor][idx] = after->getArgument(o++);
+ } else {
+ assert(merger.isDimLevelType(b, DimLvlType::kDense) ||
+ merger.isDimLevelType(b, DimLvlType::kUndef));
}
}
if (codegen.redVal)
@@ -1420,12 +1458,12 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder,
/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder,
- linalg::GenericOp op, std::vector<unsigned> &topSort,
- unsigned at, bool needsUniv, BitVector &indices) {
- unsigned idx = topSort[at];
+ linalg::GenericOp op, unsigned at, bool needsUniv,
+ BitVector &indices) {
+ unsigned idx = codegen.topSort[at];
if (indices.count() == 1) {
bool isOuter = at == 0;
- bool isInner = at == topSort.size() - 1;
+ bool isInner = at == codegen.topSort.size() - 1;
return genFor(merger, codegen, builder, op, isOuter, isInner, idx, indices);
}
return genWhile(merger, codegen, builder, op, idx, needsUniv, indices);
@@ -1434,16 +1472,19 @@ static Operation *genLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder,
/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen, OpBuilder &builder,
- linalg::GenericOp op, std::vector<unsigned> &topSort,
- unsigned at, bool needsUniv, BitVector &locals) {
+ linalg::GenericOp op, unsigned at, bool needsUniv,
+ BitVector &locals) {
+ std::vector<unsigned> &topSort(codegen.topSort);
Location loc = op.getLoc();
unsigned idx = topSort[at];
// Initialize sparse indices.
Value min;
for (unsigned b = 0, be = locals.size(); b < be; b++) {
- // TODO: singleton
- if (locals[b] && merger.isDimLevelType(b, DimLvlType::kCompressed)) {
+ if (!locals[b])
+ continue;
+ if (merger.isDimLevelType(b, DimLvlType::kCompressed) ||
+ merger.isDimLevelType(b, DimLvlType::kSingleton)) {
unsigned tensor = merger.tensor(b);
assert(idx == merger.index(b));
Value ptr = codegen.indices[tensor][idx];
@@ -1459,6 +1500,9 @@ static void genLocals(Merger &merger, CodeGen &codegen, OpBuilder &builder,
min = load;
}
}
+ } else {
+ assert(merger.isDimLevelType(b, DimLvlType::kDense) ||
+ merger.isDimLevelType(b, DimLvlType::kUndef));
}
}
@@ -1486,14 +1530,6 @@ static void genLocals(Merger &merger, CodeGen &codegen, OpBuilder &builder,
codegen, builder, loc, codegen.sizes[idx], p, codegen.loops[idx]);
}
}
-
- // Move the insertion indices in lexicographic index order. During access
- // pattern expansion, we can skip setting the innermost dimension.
- if (codegen.sparseOut && !codegen.expValues) {
- Value pos = constantIndex(builder, loc, at);
- builder.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx,
- pos);
- }
}
/// Generates the induction structure for a while-loop.
@@ -1531,8 +1567,10 @@ static void genWhileInduction(Merger &merger, CodeGen &codegen,
SmallVector<Value, 4> operands;
Value one = constantIndex(builder, loc, 1);
for (unsigned b = 0, be = induction.size(); b < be; b++) {
- // TODO: singleton
- if (induction[b] && merger.isDimLevelType(b, DimLvlType::kCompressed)) {
+ if (!induction[b])
+ continue;
+ if (merger.isDimLevelType(b, DimLvlType::kCompressed) ||
+ merger.isDimLevelType(b, DimLvlType::kSingleton)) {
unsigned tensor = merger.tensor(b);
assert(idx == merger.index(b));
Value op1 = codegen.idxs[tensor][idx];
@@ -1543,6 +1581,9 @@ static void genWhileInduction(Merger &merger, CodeGen &codegen,
Value add = builder.create<arith::AddIOp>(loc, op3, one);
operands.push_back(builder.create<arith::SelectOp>(loc, cmp, add, op3));
codegen.pidxs[tensor][idx] = whileOp->getResult(o++);
+ } else {
+ assert(merger.isDimLevelType(b, DimLvlType::kDense) ||
+ merger.isDimLevelType(b, DimLvlType::kUndef));
}
}
if (codegen.redVal) {
@@ -1592,21 +1633,23 @@ static scf::IfOp genIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
SmallVector<Type, 4> types;
Value cond;
for (unsigned b = 0, be = conditions.size(); b < be; b++) {
- if (conditions[b]) {
- unsigned tensor = merger.tensor(b);
- assert(idx == merger.index(b));
- Value clause;
- // TODO: singleton
- if (merger.isDimLevelType(b, DimLvlType::kCompressed)) {
- Value op1 = codegen.idxs[tensor][idx];
- Value op2 = codegen.loops[idx];
- clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
- op1, op2);
- } else {
- clause = constantI1(builder, loc, true);
- }
- cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
+ if (!conditions[b])
+ continue;
+ unsigned tensor = merger.tensor(b);
+ assert(idx == merger.index(b));
+ Value clause;
+ if (merger.isDimLevelType(b, DimLvlType::kCompressed) ||
+ merger.isDimLevelType(b, DimLvlType::kSingleton)) {
+ Value op1 = codegen.idxs[tensor][idx];
+ Value op2 = codegen.loops[idx];
+ clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, op1,
+ op2);
+ } else {
+ assert(merger.isDimLevelType(b, DimLvlType::kDense) ||
+ merger.isDimLevelType(b, DimLvlType::kUndef));
+ clause = constantI1(builder, loc, true);
}
+ cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
}
if (codegen.redVal)
types.push_back(codegen.redVal.getType());
@@ -1642,9 +1685,8 @@ static void endIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
/// Starts a loop sequence at given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
- linalg::GenericOp op, std::vector<unsigned> &topSort,
- unsigned exp, unsigned at, unsigned idx, unsigned ldx,
- unsigned lts) {
+ linalg::GenericOp op, unsigned exp, unsigned at,
+ unsigned idx, unsigned ldx, unsigned lts) {
assert(codegen.curVecLength == 1);
assert(!codegen.loops[idx]);
// Emit invariants at this loop sequence level.
@@ -1654,7 +1696,7 @@ static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
// Emit further intitialization at this loop sequence level.
unsigned l0 = merger.set(lts)[0];
bool needsUniv =
- genInit(merger, codegen, builder, op, topSort, at, merger.lat(l0).bits);
+ genInit(merger, codegen, builder, op, at, merger.lat(l0).bits);
// Maintain the universal index only if it is actually
// consumed by a subsequent lattice point.
if (needsUniv) {
@@ -1671,15 +1713,13 @@ static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
/// Starts a single loop in current sequence.
static Operation *startLoop(Merger &merger, CodeGen &codegen,
OpBuilder &builder, linalg::GenericOp op,
- std::vector<unsigned> &topSort, unsigned at,
- unsigned li, bool needsUniv) {
+ unsigned at, unsigned li, bool needsUniv) {
assert(codegen.curVecLength == 1);
// Emit the for/while-loop control.
- Operation *loop = genLoop(merger, codegen, builder, op, topSort, at,
- needsUniv, merger.lat(li).simple);
+ Operation *loop = genLoop(merger, codegen, builder, op, at, needsUniv,
+ merger.lat(li).simple);
// Emit the locals for this loop.
- genLocals(merger, codegen, builder, op, topSort, at, needsUniv,
- merger.lat(li).bits);
+ genLocals(merger, codegen, builder, op, at, needsUniv, merger.lat(li).bits);
return loop;
}
@@ -1704,6 +1744,7 @@ static void endLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
linalg::GenericOp op, unsigned exp, unsigned at,
unsigned idx, unsigned ldx) {
assert(codegen.curVecLength == 1);
+ assert(codegen.loops[idx]);
codegen.loops[idx] = Value();
// Bring a pending reduction back from SIMD form when sequence ends.
if (codegen.redVal)
@@ -1720,24 +1761,23 @@ static void endLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iterations spaces.
static void genStmt(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
- linalg::GenericOp op, std::vector<unsigned> &topSort,
- unsigned exp, unsigned at) {
+ linalg::GenericOp op, unsigned exp, unsigned at) {
// At each leaf, assign remaining tensor (sub)expression to output tensor.
- if (at == topSort.size()) {
- unsigned ldx = topSort[at - 1];
+ if (at == codegen.topSort.size()) {
+ unsigned ldx = codegen.topSort[at - 1];
Value rhs = genExp(merger, codegen, rewriter, op, exp, ldx);
genTensorStore(merger, codegen, rewriter, op, exp, rhs);
return;
}
// Construct iteration lattices for current loop index, with L0 at top.
- unsigned idx = topSort[at];
- unsigned ldx = at == 0 ? -1u : topSort[at - 1];
+ unsigned idx = codegen.topSort[at];
+ unsigned ldx = at == 0 ? -1u : codegen.topSort[at - 1];
unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
// Start a loop sequence.
- bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
- idx, ldx, lts);
+ bool needsUniv =
+ startLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx, lts);
// Emit a loop for every lattice point L0 >= Li in this loop sequence.
unsigned lsize = merger.set(lts).size();
@@ -1745,7 +1785,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
// Start a loop.
unsigned li = merger.set(lts)[i];
Operation *loop =
- startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);
+ startLoop(merger, codegen, rewriter, op, at, li, needsUniv);
// Visit all lattices points with Li >= Lj to generate the
// loop-body, possibly with if statements for coiteration.
@@ -1760,10 +1800,10 @@ static void genStmt(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
if (isWhile) {
scf::IfOp ifOp =
genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
- genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
+ genStmt(merger, codegen, rewriter, op, ej, at + 1);
endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, cntInput);
} else {
- genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
+ genStmt(merger, codegen, rewriter, op, ej, at + 1);
}
}
}
@@ -1859,9 +1899,10 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
// Recursively generates code if admissible.
merger.setHasSparseOut(sparseOut != nullptr);
- CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest);
+ CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest,
+ topSort);
genBuffers(merger, codegen, rewriter, op);
- genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
+ genStmt(merger, codegen, rewriter, op, exp, 0);
genResult(merger, codegen, rewriter, op);
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index bd8c8f2a65b81..35d11c4d8251a 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -855,7 +855,7 @@ Type Merger::inferType(unsigned e, Value src) {
/// Ensures that sparse compiler can generate code for expression.
static bool isAdmissableBranchExp(Operation *op, Block *block, Value v) {
- // Arguments are always admissable.
+ // Arguments are always admissible.
if (auto arg = v.dyn_cast<BlockArgument>())
return true;
// Accept index anywhere.
@@ -866,7 +866,7 @@ static bool isAdmissableBranchExp(Operation *op, Block *block, Value v) {
if (def->getBlock() != block)
return def->getBlock() != op->getBlock(); // invariant?
// Operation defined within branch. Anything is accepted,
- // as long as all subexpressions are admissable.
+ // as long as all subexpressions are admissible.
for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
if (!isAdmissableBranchExp(op, block, def->getOperand(i)))
return false;
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 207eaa1324d4f..baef485b82384 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -309,7 +309,7 @@ func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> {
func.func @sparse_expansion1() -> memref<?xindex> {
%0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSR>
%values, %filled, %added, %count = sparse_tensor.expand %0
- : tensor<4x8xf64, #CSR> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+ : tensor<4x8xf64, #CSR> to memref<?xf64>, memref<?xi1>, memref<?xindex>
return %added : memref<?xindex>
}
@@ -324,7 +324,7 @@ func.func @sparse_expansion1() -> memref<?xindex> {
func.func @sparse_expansion2() -> memref<?xindex> {
%0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSC>
%values, %filled, %added, %count = sparse_tensor.expand %0
- : tensor<4x8xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+ : tensor<4x8xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>
return %added : memref<?xindex>
}
@@ -344,7 +344,7 @@ func.func @sparse_expansion2() -> memref<?xindex> {
func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
%0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #CSC>
%values, %filled, %added, %count = sparse_tensor.expand %0
- : tensor<?x?xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+ : tensor<?x?xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>
return %added : memref<?xindex>
}
@@ -354,34 +354,34 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
-// CHECK-SAME: %[[A5:.*5]]: memref<?xindex>,
-// CHECK-SAME: %[[A6:.*6]]: memref<?xf64>,
-// CHECK-SAME: %[[A7:.*7]]: memref<?xi1>,
-// CHECK-SAME: %[[A8:.*8]]: memref<?xindex>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xi1>,
+// CHECK-SAME: %[[A7:.*7]]: memref<?xindex>,
+// CHECK-SAME: %[[A8:.*8]]: index,
// CHECK-SAME: %[[A9:.*9]]: index)
// CHECK-DAG: %[[B0:.*]] = arith.constant false
// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// TODO: sort
-// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A9]] step %[[C1]] {
-// CHECK-NEXT: %[[INDEX:.*]] = memref.load %[[A8]][%[[I]]] : memref<?xindex>
+// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] {
+// CHECK-NEXT: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
// TODO: insert
-// CHECK-DAG: memref.store %[[F0]], %[[A6]][%[[INDEX]]] : memref<?xf64>
-// CHECK-DAG: memref.store %[[B0]], %[[A7]][%[[INDEX]]] : memref<?xi1>
+// CHECK-DAG: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
+// CHECK-DAG: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
// CHECK-NEXT: }
-// CHECK-DAG: memref.dealloc %[[A6]] : memref<?xf64>
-// CHECK-DAG: memref.dealloc %[[A7]] : memref<?xi1>
-// CHECK-DAG: memref.dealloc %[[A8]] : memref<?xindex>
+// CHECK-DAG: memref.dealloc %[[A5]] : memref<?xf64>
+// CHECK-DAG: memref.dealloc %[[A6]] : memref<?xi1>
+// CHECK-DAG: memref.dealloc %[[A7]] : memref<?xindex>
// CHECK: return
-func.func @sparse_compression(%arg0: tensor<8x8xf64, #CSR>,
- %arg1: memref<?xindex>,
- %arg2: memref<?xf64>,
- %arg3: memref<?xi1>,
- %arg4: memref<?xindex>,
- %arg5: index) {
- sparse_tensor.compress %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
- : tensor<8x8xf64, #CSR>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
+ %values: memref<?xf64>,
+ %filled: memref<?xi1>,
+ %added: memref<?xindex>,
+ %count: index,
+ %i: index) {
+ sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
+ : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #CSR>
return
}
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 8c03d37f2bd32..39c606ae409e6 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -487,14 +487,20 @@ func.func @sparse_reconstruct_ins(%arg0: tensor<128xf32, #SparseVector>) -> tens
// CHECK-LABEL: func @sparse_insert(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>,
-// CHECK-SAME: %[[B:.*]]: memref<?xindex>,
-// CHECK-SAME: %[[C:.*]]: memref<f32>) {
-// CHECK: call @lexInsertF32(%[[A]], %[[B]], %[[C]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f32>) -> ()
+// CHECK-SAME: %[[B:.*]]: index,
+// CHECK-SAME: %[[C:.*]]: f32) {
+// CHECK-DAG: %[[M:.*]] = memref.alloca() : memref<1xindex>
+// CHECK-DAG: %[[V:.*]] = memref.alloca() : memref<f32>
+// CHECK-DAG: %[[MC:.*]] = memref.cast %[[M]] : memref<1xindex> to memref<?xindex>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: memref.store %[[B]], %[[M]][%[[C0]]] : memref<1xindex>
+// CHECK-DAG: memref.store %[[C]], %[[V]][] : memref<f32>
+// CHECK: call @lexInsertF32(%[[A]], %[[MC]], %[[V]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f32>) -> ()
// CHECK: return
func.func @sparse_insert(%arg0: tensor<128xf32, #SparseVector>,
- %arg1: memref<?xindex>,
- %arg2: memref<f32>) {
- sparse_tensor.insert %arg0, %arg1, %arg2 : tensor<128xf32, #SparseVector>, memref<?xindex>, memref<f32>
+ %arg1: index,
+ %arg2: f32) {
+ sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf32, #SparseVector>
return
}
@@ -510,7 +516,7 @@ func.func @sparse_insert(%arg0: tensor<128xf32, #SparseVector>,
func.func @sparse_expansion1() -> memref<?xindex> {
%0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSR>
%values, %filled, %added, %count = sparse_tensor.expand %0
- : tensor<4x8xf64, #CSR> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+ : tensor<4x8xf64, #CSR> to memref<?xf64>, memref<?xi1>, memref<?xindex>
return %added : memref<?xindex>
}
@@ -526,7 +532,7 @@ func.func @sparse_expansion1() -> memref<?xindex> {
func.func @sparse_expansion2() -> memref<?xindex> {
%0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSC>
%values, %filled, %added, %count = sparse_tensor.expand %0
- : tensor<4x8xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+ : tensor<4x8xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>
return %added : memref<?xindex>
}
@@ -543,26 +549,34 @@ func.func @sparse_expansion2() -> memref<?xindex> {
func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
%0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #CSC>
%values, %filled, %added, %count = sparse_tensor.expand %0
- : tensor<?x?xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+ : tensor<?x?xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>
return %added : memref<?xindex>
}
// CHECK-LABEL: func @sparse_compression(
// CHECK-SAME: %[[A:.*0]]: !llvm.ptr<i8>,
-// CHECK-SAME: %[[B:.*1]]: memref<?xindex>,
-// CHECK-SAME: %[[C:.*2]]: memref<?xf64>,
-// CHECK-SAME: %[[D:.*3]]: memref<?xi1>,
-// CHECK-SAME: %[[E:.*4]]: memref<?xindex>,
-// CHECK: call @expInsertF64(%[[A]],
-// CHECK-DAG: memref.dealloc %[[C]] : memref<?xf64>
-// CHECK-DAG: memref.dealloc %[[D]] : memref<?xi1>
-// CHECK-DAG: memref.dealloc %[[E]] : memref<?xindex>
+// CHECK-SAME: %[[B:.*1]]: memref<?xf64>,
+// CHECK-SAME: %[[C:.*2]]: memref<?xi1>,
+// CHECK-SAME: %[[D:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[E:.*4]]: index,
+// CHECK-SAME: %[[F:.*5]]: index)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[X:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG: %[[Y:.*]] = memref.cast %[[X]] : memref<2xindex> to memref<?xindex>
+// CHECK: memref.store %[[F]], %[[X]][%[[C0]]] : memref<2xindex>
+// CHECK: call @expInsertF64(%[[A]], %[[Y]], %[[B]], %[[C]], %[[D]], %[[E]])
+// CHECK-DAG: memref.dealloc %[[B]] : memref<?xf64>
+// CHECK-DAG: memref.dealloc %[[C]] : memref<?xi1>
+// CHECK-DAG: memref.dealloc %[[D]] : memref<?xindex>
// CHECK: return
-func.func @sparse_compression(%arg0: tensor<8x8xf64, #CSR>,
- %arg1: memref<?xindex>, %arg2: memref<?xf64>, %arg3: memref<?xi1>,
- %arg4: memref<?xindex>, %arg5: index) {
- sparse_tensor.compress %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
- : tensor<8x8xf64, #CSR>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
+ %values: memref<?xf64>,
+ %filled: memref<?xi1>,
+ %added: memref<?xindex>,
+ %count: index,
+ %i: index) {
+ sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
+ : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #CSR>
return
}
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 7425d19efb3b8..7ac5179b87bf9 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -106,9 +106,19 @@ func.func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64
// -----
-func.func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: memref<?xindex>, %arg2: f64) {
- // expected-error at +1 {{'sparse_tensor.insert' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
- sparse_tensor.insert %arg0, %arg1, %arg2 : tensor<128xf64>, memref<?xindex>, f64
+func.func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: index, %arg2: f64) {
+ // expected-error at +1 {{'sparse_tensor.insert' 'tensor' must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
+ sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64>
+ return
+}
+
+// -----
+
+#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+
+func.func @sparse_wrong_arity_insert(%arg0: tensor<128x64xf64, #CSR>, %arg1: index, %arg2: f64) {
+ // expected-error at +1 {{'sparse_tensor.insert' op incorrect number of indices}}
+ sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128x64xf64, #CSR>
return
}
@@ -117,18 +127,38 @@ func.func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: memref<?xind
func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) {
// expected-error at +1 {{'sparse_tensor.expand' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
%values, %filled, %added, %count = sparse_tensor.expand %arg0
- : tensor<128xf64> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+ : tensor<128xf64> to memref<?xf64>, memref<?xi1>, memref<?xindex>
return
}
// -----
-func.func @sparse_unannotated_compression(%arg0: tensor<128xf64>, %arg1: memref<?xindex>,
- %arg2: memref<?xf64>, %arg3: memref<?xi1>,
- %arg4: memref<?xindex>, %arg5: index) {
- // expected-error at +1 {{'sparse_tensor.compress' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
- sparse_tensor.compress %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
- : tensor<128xf64>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+func.func @sparse_unannotated_compression(%arg0: memref<?xf64>,
+ %arg1: memref<?xi1>,
+ %arg2: memref<?xindex>,
+ %arg3: index,
+ %arg4: tensor<8x8xf64>,
+ %arg5: index) {
+ // expected-error at +1 {{'sparse_tensor.compress' op operand #4 must be sparse tensor of any type values, but got 'tensor<8x8xf64>'}}
+ sparse_tensor.compress %arg0, %arg1, %arg2, %arg3 into %arg4[%arg5]
+ : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64>
+ return
+}
+
+// -----
+
+#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+
+func.func @sparse_wrong_arity_compression(%arg0: memref<?xf64>,
+ %arg1: memref<?xi1>,
+ %arg2: memref<?xindex>,
+ %arg3: index,
+ %arg4: tensor<8x8xf64, #CSR>,
+ %arg5: index) {
+ // expected-error at +1 {{'sparse_tensor.compress' op incorrect number of indices}}
+ sparse_tensor.compress %arg0, %arg1, %arg2, %arg3 into %arg4[%arg5,%arg5]
+ : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #CSR>
+ return
}
// -----
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 7059b15905436..e5ffe85284c70 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -121,12 +121,12 @@ func.func @sparse_load_ins(%arg0: tensor<16x32xf64, #DenseMatrix>) -> tensor<16x
// CHECK-LABEL: func @sparse_insert(
// CHECK-SAME: %[[A:.*]]: tensor<128xf64, #sparse_tensor.encoding<{{.*}}>>,
-// CHECK-SAME: %[[B:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[B:.*]]: index,
// CHECK-SAME: %[[C:.*]]: f64) {
-// CHECK: sparse_tensor.insert %[[A]], %[[B]], %[[C]] : tensor<128xf64, #{{.*}}>, memref<?xindex>, f64
+// CHECK: sparse_tensor.insert %[[C]] into %[[A]][%[[B]]] : tensor<128xf64, #{{.*}}>
// CHECK: return
-func.func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: memref<?xindex>, %arg2: f64) {
- sparse_tensor.insert %arg0, %arg1, %arg2 : tensor<128xf64, #SparseVector>, memref<?xindex>, f64
+func.func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) {
+ sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
return
}
@@ -149,12 +149,12 @@ func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2:
// CHECK-LABEL: func @sparse_expansion(
// CHECK-SAME: %[[A:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>)
-// CHECK: sparse_tensor.expand %[[A]]
-// CHECK: return
-func.func @sparse_expansion(%arg0: tensor<8x8xf64, #SparseMatrix>) {
- %values, %filled, %added, %count = sparse_tensor.expand %arg0
- : tensor<8x8xf64, #SparseMatrix> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
- return
+// CHECK: %{{.*}}, %{{.*}}, %{{.*}}, %[[T:.*]] = sparse_tensor.expand %[[A]]
+// CHECK: return %[[T]] : index
+func.func @sparse_expansion(%tensor: tensor<8x8xf64, #SparseMatrix>) -> index {
+ %values, %filled, %added, %count = sparse_tensor.expand %tensor
+ : tensor<8x8xf64, #SparseMatrix> to memref<?xf64>, memref<?xi1>, memref<?xindex>
+ return %count : index
}
// -----
@@ -162,14 +162,22 @@ func.func @sparse_expansion(%arg0: tensor<8x8xf64, #SparseMatrix>) {
#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
// CHECK-LABEL: func @sparse_compression(
-// CHECK-SAME: %[[A:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>,
-// CHECK: sparse_tensor.compress %[[A]]
+// CHECK-SAME: %[[A0:.*0]]: memref<?xf64>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xi1>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[A3:.*3]]: index
+// CHECK-SAME: %[[A4:.*4]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME: %[[A5:.*5]]: index)
+// CHECK: sparse_tensor.compress %[[A0]], %[[A1]], %[[A2]], %[[A3]] into %[[A4]][%[[A5]]
// CHECK: return
-func.func @sparse_compression(%arg0: tensor<8x8xf64, #SparseMatrix>,
- %arg1: memref<?xindex>, %arg2: memref<?xf64>, %arg3: memref<?xi1>,
- %arg4: memref<?xindex>, %arg5: index) {
- sparse_tensor.compress %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
- : tensor<8x8xf64, #SparseMatrix>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+func.func @sparse_compression(%values: memref<?xf64>,
+ %filled: memref<?xi1>,
+ %added: memref<?xindex>,
+ %count: index,
+ %tensor: tensor<8x8xf64, #SparseMatrix>,
+ %index: index) {
+ sparse_tensor.compress %values, %filled, %added, %count into %tensor[%index]
+ : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #SparseMatrix>
return
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
index fd4e297985f80..96c8b00e4de25 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
@@ -37,11 +37,11 @@
//
// CHECK-SPARSE-LABEL: func @kernel(
// CHECK-SPARSE: %[[A:.*]], %[[B:.*]], %[[C:.*]], %{{.*}} = sparse_tensor.expand
-// CHECK-SPARSE: scf.for {{.*}} {
+// CHECK-SPARSE: %[[COUNT:.*]] = scf.for {{.*}} {
// CHECK-SPARSE: scf.for {{.*}} {
// CHECK-SPARSE: }
// CHECK-SPARSE: }
-// CHECK-SPARSE: sparse_tensor.compress %{{.*}}, %{{.*}}, %[[A]], %[[B]], %[[C]]
+// CHECK-SPARSE: sparse_tensor.compress %[[A]], %[[B]], %[[C]], %[[COUNT]] into
// CHECK-SPARSE: %[[RET:.*]] = sparse_tensor.load %{{.*}} hasInserts
// CHECK-SPARSE: return %[[RET]]
//
@@ -86,11 +86,11 @@ func.func @kernel(%arga: tensor<?x?xf64, #DCSC>) -> tensor<?xf64, #SV> {
// CHECK-SPARSE-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-SPARSE: scf.for %{{.*}} = %[[C0]] to %[[C8]] step %[[C1]] {
// CHECK-SPARSE: %[[A:.*]], %[[B:.*]], %[[C:.*]], %{{.*}} = sparse_tensor.expand
-// CHECK-SPARSE: scf.for {{.*}} {
+// CHECK-SPARSE: %[[COUNT:.*]] = scf.for {{.*}} {
// CHECK-SPARSE: scf.for {{.*}} {
// CHECK-SPARSE: }
// CHECK-SPARSE: }
-// CHECK-SPARSE: sparse_tensor.compress %{{.*}}, %{{.*}}, %[[A]], %[[B]], %[[C]]
+// CHECK-SPARSE: sparse_tensor.compress %[[A]], %[[B]], %[[C]], %[[COUNT]] into
// CHECK-SPARSE: }
// CHECK-SPARSE: %[[RET:.*]] = sparse_tensor.load %{{.*}} hasInserts
// CHECK-SPARSE: return %[[RET]]
@@ -134,11 +134,11 @@ func.func @matmul1(%A: tensor<8x2xf64, #CSR>,
// CHECK-SPARSE-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-SPARSE: scf.for %{{.*}} = %[[C0]] to %[[C4]] step %[[C1]] {
// CHECK-SPARSE: %[[A:.*]], %[[B:.*]], %[[C:.*]], %{{.*}} = sparse_tensor.expand
-// CHECK-SPARSE: scf.for {{.*}} {
+// CHECK-SPARSE: %[[COUNT:.*]] = scf.for {{.*}} {
// CHECK-SPARSE: scf.for {{.*}} {
// CHECK-SPARSE: }
// CHECK-SPARSE: }
-// CHECK-SPARSE: sparse_tensor.compress %{{.*}}, %{{.*}}, %[[A]], %[[B]], %[[C]]
+// CHECK-SPARSE: sparse_tensor.compress %[[A]], %[[B]], %[[C]], %[[COUNT]]
// CHECK-SPARSE: }
// CHECK-SPARSE: %[[RET:.*]] = sparse_tensor.load %{{.*}} hasInserts
// CHECK-SPARSE: return %[[RET]]
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index 7b87411a08f2d..132566653c972 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -5,16 +5,16 @@
// CHECK-LABEL: func.func @fill_zero_after_alloc(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>,
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
-// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
-// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_7:.*]] = arith.constant false
-// CHECK: %[[VAL_8:.*]] = arith.constant true
-// CHECK: %[[VAL_9:.*]] = arith.constant 100 : index
-// CHECK: %[[VAL_10:.*]] = arith.constant 300 : index
-// CHECK: %[[VAL_11:.*]] = arith.constant 1 : i8
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant false
+// CHECK-DAG: %[[VAL_8:.*]] = arith.constant true
+// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 100 : index
+// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 300 : index
+// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 1 : i8
// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi8>
// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi8> to memref<?xi8>
// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi8>
@@ -47,67 +47,67 @@
// CHECK: %[[VAL_33:.*]] = call @sparsePointers0(%[[VAL_1]], %[[VAL_6]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
// CHECK: %[[VAL_34:.*]] = call @sparseIndices0(%[[VAL_1]], %[[VAL_6]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
// CHECK: %[[VAL_35:.*]] = call @sparseValuesF64(%[[VAL_1]]) : (!llvm.ptr<i8>) -> memref<?xf64>
-// CHECK: %[[VAL_36:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[VAL_37:.*]] = memref.cast %[[VAL_36]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[VAL_38:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_6]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_40:.*]] = %[[VAL_38]] to %[[VAL_39]] step %[[VAL_6]] {
-// CHECK: %[[VAL_41:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_40]]] : memref<?xindex>
-// CHECK: memref.store %[[VAL_41]], %[[VAL_36]]{{\[}}%[[VAL_5]]] : memref<2xindex>
-// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_40]]] : memref<?xindex>
-// CHECK: %[[VAL_43:.*]] = arith.addi %[[VAL_40]], %[[VAL_6]] : index
-// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_43]]] : memref<?xindex>
-// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_6]]] : memref<?xindex>
-// CHECK: %[[VAL_47:.*]]:3 = scf.while (%[[VAL_48:.*]] = %[[VAL_42]], %[[VAL_49:.*]] = %[[VAL_45]], %[[VAL_50:.*]] = %[[VAL_5]]) : (index, index, index) -> (index, index, index) {
-// CHECK: %[[VAL_51:.*]] = arith.cmpi ult, %[[VAL_48]], %[[VAL_44]] : index
-// CHECK: %[[VAL_52:.*]] = arith.cmpi ult, %[[VAL_49]], %[[VAL_46]] : index
-// CHECK: %[[VAL_53:.*]] = arith.andi %[[VAL_51]], %[[VAL_52]] : i1
-// CHECK: scf.condition(%[[VAL_53]]) %[[VAL_48]], %[[VAL_49]], %[[VAL_50]] : index, index, index
+// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_38:.*]] = %[[VAL_36]] to %[[VAL_37]] step %[[VAL_6]] {
+// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_38]]] : memref<?xindex>
+// CHECK: %[[VAL_40:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_38]]] : memref<?xindex>
+// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_38]], %[[VAL_6]] : index
+// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_41]]] : memref<?xindex>
+// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK: %[[VAL_45:.*]]:3 = scf.while (%[[VAL_46:.*]] = %[[VAL_40]], %[[VAL_47:.*]] = %[[VAL_43]], %[[VAL_48:.*]] = %[[VAL_5]]) : (index, index, index) -> (index, index, index) {
+// CHECK: %[[VAL_49:.*]] = arith.cmpi ult, %[[VAL_46]], %[[VAL_42]] : index
+// CHECK: %[[VAL_50:.*]] = arith.cmpi ult, %[[VAL_47]], %[[VAL_44]] : index
+// CHECK: %[[VAL_51:.*]] = arith.andi %[[VAL_49]], %[[VAL_50]] : i1
+// CHECK: scf.condition(%[[VAL_51]]) %[[VAL_46]], %[[VAL_47]], %[[VAL_48]] : index, index, index
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_54:.*]]: index, %[[VAL_55:.*]]: index, %[[VAL_56:.*]]: index):
-// CHECK: %[[VAL_57:.*]] = memref.load %[[VAL_29]]{{\[}}%[[VAL_54]]] : memref<?xindex>
-// CHECK: %[[VAL_58:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_55]]] : memref<?xindex>
-// CHECK: %[[VAL_59:.*]] = arith.cmpi ult, %[[VAL_58]], %[[VAL_57]] : index
-// CHECK: %[[VAL_60:.*]] = arith.select %[[VAL_59]], %[[VAL_58]], %[[VAL_57]] : index
-// CHECK: %[[VAL_61:.*]] = arith.cmpi eq, %[[VAL_57]], %[[VAL_60]] : index
-// CHECK: %[[VAL_62:.*]] = arith.cmpi eq, %[[VAL_58]], %[[VAL_60]] : index
-// CHECK: %[[VAL_63:.*]] = arith.andi %[[VAL_61]], %[[VAL_62]] : i1
-// CHECK: %[[VAL_64:.*]] = scf.if %[[VAL_63]] -> (index) {
-// CHECK: %[[VAL_65:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_54]]] : memref<?xf64>
-// CHECK: %[[VAL_66:.*]] = memref.load %[[VAL_33]]{{\[}}%[[VAL_55]]] : memref<?xindex>
-// CHECK: %[[VAL_67:.*]] = arith.addi %[[VAL_55]], %[[VAL_6]] : index
-// CHECK: %[[VAL_68:.*]] = memref.load %[[VAL_33]]{{\[}}%[[VAL_67]]] : memref<?xindex>
-// CHECK: %[[VAL_69:.*]] = scf.for %[[VAL_70:.*]] = %[[VAL_66]] to %[[VAL_68]] step %[[VAL_6]] iter_args(%[[VAL_71:.*]] = %[[VAL_56]]) -> (index) {
-// CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_34]]{{\[}}%[[VAL_70]]] : memref<?xindex>
-// CHECK: %[[VAL_73:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_72]]] : memref<300xf64>
-// CHECK: %[[VAL_74:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_70]]] : memref<?xf64>
-// CHECK: %[[VAL_75:.*]] = arith.mulf %[[VAL_65]], %[[VAL_74]] : f64
-// CHECK: %[[VAL_76:.*]] = arith.addf %[[VAL_73]], %[[VAL_75]] : f64
-// CHECK: %[[VAL_77:.*]] = memref.load %[[VAL_22]]{{\[}}%[[VAL_72]]] : memref<300xi1>
-// CHECK: %[[VAL_78:.*]] = arith.cmpi eq, %[[VAL_77]], %[[VAL_7]] : i1
-// CHECK: %[[VAL_79:.*]] = scf.if %[[VAL_78]] -> (index) {
-// CHECK: memref.store %[[VAL_8]], %[[VAL_22]]{{\[}}%[[VAL_72]]] : memref<300xi1>
-// CHECK: memref.store %[[VAL_72]], %[[VAL_24]]{{\[}}%[[VAL_71]]] : memref<300xindex>
-// CHECK: %[[VAL_80:.*]] = arith.addi %[[VAL_71]], %[[VAL_6]] : index
-// CHECK: scf.yield %[[VAL_80]] : index
+// CHECK: ^bb0(%[[VAL_52:.*]]: index, %[[VAL_53:.*]]: index, %[[VAL_54:.*]]: index):
+// CHECK: %[[VAL_55:.*]] = memref.load %[[VAL_29]]{{\[}}%[[VAL_52]]] : memref<?xindex>
+// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_53]]] : memref<?xindex>
+// CHECK: %[[VAL_57:.*]] = arith.cmpi ult, %[[VAL_56]], %[[VAL_55]] : index
+// CHECK: %[[VAL_58:.*]] = arith.select %[[VAL_57]], %[[VAL_56]], %[[VAL_55]] : index
+// CHECK: %[[VAL_59:.*]] = arith.cmpi eq, %[[VAL_55]], %[[VAL_58]] : index
+// CHECK: %[[VAL_60:.*]] = arith.cmpi eq, %[[VAL_56]], %[[VAL_58]] : index
+// CHECK: %[[VAL_61:.*]] = arith.andi %[[VAL_59]], %[[VAL_60]] : i1
+// CHECK: %[[VAL_62:.*]] = scf.if %[[VAL_61]] -> (index) {
+// CHECK: %[[VAL_63:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_52]]] : memref<?xf64>
+// CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_33]]{{\[}}%[[VAL_53]]] : memref<?xindex>
+// CHECK: %[[VAL_65:.*]] = arith.addi %[[VAL_53]], %[[VAL_6]] : index
+// CHECK: %[[VAL_66:.*]] = memref.load %[[VAL_33]]{{\[}}%[[VAL_65]]] : memref<?xindex>
+// CHECK: %[[VAL_67:.*]] = scf.for %[[VAL_68:.*]] = %[[VAL_64]] to %[[VAL_66]] step %[[VAL_6]] iter_args(%[[VAL_69:.*]] = %[[VAL_54]]) -> (index) {
+// CHECK: %[[VAL_70:.*]] = memref.load %[[VAL_34]]{{\[}}%[[VAL_68]]] : memref<?xindex>
+// CHECK: %[[VAL_71:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_70]]] : memref<300xf64>
+// CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_68]]] : memref<?xf64>
+// CHECK: %[[VAL_73:.*]] = arith.mulf %[[VAL_63]], %[[VAL_72]] : f64
+// CHECK: %[[VAL_74:.*]] = arith.addf %[[VAL_71]], %[[VAL_73]] : f64
+// CHECK: %[[VAL_75:.*]] = memref.load %[[VAL_22]]{{\[}}%[[VAL_70]]] : memref<300xi1>
+// CHECK: %[[VAL_76:.*]] = arith.cmpi eq, %[[VAL_75]], %[[VAL_7]] : i1
+// CHECK: %[[VAL_77:.*]] = scf.if %[[VAL_76]] -> (index) {
+// CHECK: memref.store %[[VAL_8]], %[[VAL_22]]{{\[}}%[[VAL_70]]] : memref<300xi1>
+// CHECK: memref.store %[[VAL_70]], %[[VAL_24]]{{\[}}%[[VAL_69]]] : memref<300xindex>
+// CHECK: %[[VAL_78:.*]] = arith.addi %[[VAL_69]], %[[VAL_6]] : index
+// CHECK: scf.yield %[[VAL_78]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_71]] : index
+// CHECK: scf.yield %[[VAL_69]] : index
// CHECK: }
-// CHECK: memref.store %[[VAL_76]], %[[VAL_20]]{{\[}}%[[VAL_72]]] : memref<300xf64>
-// CHECK: scf.yield %[[VAL_81:.*]] : index
+// CHECK: memref.store %[[VAL_74]], %[[VAL_20]]{{\[}}%[[VAL_70]]] : memref<300xf64>
+// CHECK: scf.yield %[[VAL_79:.*]] : index
// CHECK: }
-// CHECK: scf.yield %[[VAL_82:.*]] : index
+// CHECK: scf.yield %[[VAL_80:.*]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_56]] : index
+// CHECK: scf.yield %[[VAL_54]] : index
// CHECK: }
-// CHECK: %[[VAL_83:.*]] = arith.addi %[[VAL_54]], %[[VAL_6]] : index
-// CHECK: %[[VAL_84:.*]] = arith.select %[[VAL_61]], %[[VAL_83]], %[[VAL_54]] : index
-// CHECK: %[[VAL_85:.*]] = arith.addi %[[VAL_55]], %[[VAL_6]] : index
-// CHECK: %[[VAL_86:.*]] = arith.select %[[VAL_62]], %[[VAL_85]], %[[VAL_55]] : index
-// CHECK: scf.yield %[[VAL_84]], %[[VAL_86]], %[[VAL_87:.*]] : index, index, index
+// CHECK: %[[VAL_81:.*]] = arith.addi %[[VAL_52]], %[[VAL_6]] : index
+// CHECK: %[[VAL_82:.*]] = arith.select %[[VAL_59]], %[[VAL_81]], %[[VAL_52]] : index
+// CHECK: %[[VAL_83:.*]] = arith.addi %[[VAL_53]], %[[VAL_6]] : index
+// CHECK: %[[VAL_84:.*]] = arith.select %[[VAL_60]], %[[VAL_83]], %[[VAL_53]] : index
+// CHECK: scf.yield %[[VAL_82]], %[[VAL_84]], %[[VAL_85:.*]] : index, index, index
// CHECK: }
-// CHECK: func.call @expInsertF64(%[[VAL_19]], %[[VAL_37]], %[[VAL_21]], %[[VAL_23]], %[[VAL_25]], %[[VAL_88:.*]]#2) : (!llvm.ptr<i8>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index) -> ()
+// CHECK: %[[VAL_86:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[VAL_87:.*]] = memref.cast %[[VAL_86]] : memref<2xindex> to memref<?xindex>
+// CHECK: memref.store %[[VAL_39]], %[[VAL_86]]{{\[}}%[[VAL_5]]] : memref<2xindex>
+// CHECK: func.call @expInsertF64(%[[VAL_19]], %[[VAL_87]], %[[VAL_21]], %[[VAL_23]], %[[VAL_25]], %[[VAL_88:.*]]#2) : (!llvm.ptr<i8>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index) -> ()
// CHECK: }
// CHECK: memref.dealloc %[[VAL_20]] : memref<300xf64>
// CHECK: memref.dealloc %[[VAL_22]] : memref<300xi1>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
index d47d2832c969c..6d363cfe47bd0 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
@@ -350,36 +350,32 @@ func.func @divbyc(%arga: tensor<32xf64, #SV>,
return %0 : tensor<32xf64>
}
-// CHECK-LABEL: func @zero_preserving_math(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_4:.*]] = bufferization.alloc_tensor() : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
-// CHECK: %[[VAL_8:.*]] = memref.alloca(%[[VAL_2]]) : memref<?xindex>
-// CHECK: %[[BUF:.*]] = memref.alloca() : memref<f64>
-// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_1]]] : memref<?xindex>
-// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_2]] {
-// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
-// CHECK: memref.store %[[VAL_12]], %[[VAL_8]]{{\[}}%[[VAL_1]]] : memref<?xindex>
-// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xf64>
-// CHECK: %[[VAL_14:.*]] = math.absf %[[VAL_13]] : f64
-// CHECK: %[[VAL_15:.*]] = math.ceil %[[VAL_14]] : f64
-// CHECK: %[[VAL_16:.*]] = math.floor %[[VAL_15]] : f64
-// CHECK: %[[VAL_17:.*]] = math.sqrt %[[VAL_16]] : f64
-// CHECK: %[[VAL_18:.*]] = math.expm1 %[[VAL_17]] : f64
-// CHECK: %[[VAL_19:.*]] = math.log1p %[[VAL_18]] : f64
-// CHECK: %[[VAL_20:.*]] = math.sin %[[VAL_19]] : f64
-// CHECK: %[[VAL_21:.*]] = math.tanh %[[VAL_20]] : f64
-// CHECK: memref.store %[[VAL_21]], %[[BUF]][] : memref<f64>
-// CHECK: sparse_tensor.insert %[[VAL_4]], %[[VAL_8]], %[[BUF]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>, memref<?xindex>, memref<f64>
+// CHECK-LABEL: func.func @zero_preserving_math(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>) -> tensor<32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>> {
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = bufferization.alloc_tensor() : tensor<32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
+// CHECK: %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>> to memref<?xf64>
+// CHECK: %[[VAL_7:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_1]]] : memref<?xindex>
+// CHECK: %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_7]] to %[[VAL_8]] step %[[VAL_2]] {
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_9]]] : memref<?xindex>
+// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_9]]] : memref<?xf64>
+// CHECK: %[[VAL_12:.*]] = math.absf %[[VAL_11]] : f64
+// CHECK: %[[VAL_13:.*]] = math.ceil %[[VAL_12]] : f64
+// CHECK: %[[VAL_14:.*]] = math.floor %[[VAL_13]] : f64
+// CHECK: %[[VAL_15:.*]] = math.sqrt %[[VAL_14]] : f64
+// CHECK: %[[VAL_16:.*]] = math.expm1 %[[VAL_15]] : f64
+// CHECK: %[[VAL_17:.*]] = math.log1p %[[VAL_16]] : f64
+// CHECK: %[[VAL_18:.*]] = math.sin %[[VAL_17]] : f64
+// CHECK: %[[VAL_19:.*]] = math.tanh %[[VAL_18]] : f64
+// CHECK: sparse_tensor.insert %[[VAL_19]] into %[[VAL_3]]{{\[}}%[[VAL_10]]] : tensor<32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
+// CHECK: }
+// CHECK: %[[VAL_20:.*]] = sparse_tensor.load %[[VAL_3]] hasInserts : tensor<32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
+// CHECK: return %[[VAL_20]] : tensor<32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
// CHECK: }
-// CHECK: %[[VAL_22:.*]] = sparse_tensor.load %[[VAL_4]] hasInserts : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: return %[[VAL_22]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: }
func.func @zero_preserving_math(%arga: tensor<32xf64, #SV>) -> tensor<32xf64, #SV> {
%c32 = arith.constant 32 : index
%xinp = bufferization.alloc_tensor() : tensor<32xf64, #SV>
@@ -400,30 +396,26 @@ func.func @zero_preserving_math(%arga: tensor<32xf64, #SV>) -> tensor<32xf64, #S
return %0 : tensor<32xf64, #SV>
}
-// CHECK-LABEL: func.func @complex_divbyc(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xcomplex<f64>, #sparse_tensor.encoding<{{.*}}>> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_3:.*]] = complex.constant [0.000000e+00, 1.000000e+00] : complex<f64>
-// CHECK: %[[VAL_4:.*]] = bufferization.alloc_tensor() : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{{.*}}>>
-// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{{.*}}>> to memref<?xcomplex<f64>>
-// CHECK: %[[VAL_8:.*]] = memref.alloca(%[[VAL_2]]) : memref<?xindex>
-// CHECK: %[[VAL_9:.*]] = memref.alloca() : memref<complex<f64>>
-// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_1]]] : memref<?xindex>
-// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] {
-// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
-// CHECK: memref.store %[[VAL_13]], %[[VAL_8]]{{\[}}%[[VAL_1]]] : memref<?xindex>
-// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref<?xcomplex<f64>>
-// CHECK: %[[VAL_15:.*]] = complex.div %[[VAL_14]], %[[VAL_3]] : complex<f64>
-// CHECK: memref.store %[[VAL_15]], %[[VAL_9]][] : memref<complex<f64>>
-// CHECK: sparse_tensor.insert %[[VAL_4]], %[[VAL_8]], %[[VAL_9]] : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{{.*}}>>, memref<?xindex>, memref<complex<f64>>
+// CHECK-LABEL: func.func @complex_divbyc(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xcomplex<f64>, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>) -> tensor<32xcomplex<f64>, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>> {
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = complex.constant [0.000000e+00, 1.000000e+00] : complex<f64>
+// CHECK: %[[VAL_4:.*]] = bufferization.alloc_tensor() : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
+// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>> to memref<?xcomplex<f64>>
+// CHECK: %[[VAL_8:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_1]]] : memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_2]] {
+// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_10]]] : memref<?xcomplex<f64>>
+// CHECK: %[[VAL_13:.*]] = complex.div %[[VAL_12]], %[[VAL_3]] : complex<f64>
+// CHECK: sparse_tensor.insert %[[VAL_13]] into %[[VAL_4]]{{\[}}%[[VAL_11]]] : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
+// CHECK: }
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.load %[[VAL_4]] hasInserts : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
+// CHECK: return %[[VAL_14]] : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
// CHECK: }
-// CHECK: %[[VAL_16:.*]] = sparse_tensor.load %[[VAL_4]] hasInserts : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{{.*}}>>
-// CHECK: return %[[VAL_16]] : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{{.*}}>>
-// CHECK: }
func.func @complex_divbyc(%arg0: tensor<32xcomplex<f64>, #SV>) -> tensor<32xcomplex<f64>, #SV> {
%c = complex.constant [0.0, 1.0] : complex<f64>
%init = bufferization.alloc_tensor() : tensor<32xcomplex<f64>, #SV>
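(The rewritten checks above likewise show the new insert form; a minimal sketch, where %v, %t and %i are placeholder names standing in for the computed value, the destination sparse tensor, and the loop index from the test:

  // %v, %t, %i are illustrative placeholder names, not from the patch itself.
  sparse_tensor.insert %v into %t[%i] : tensor<32xcomplex<f64>, #SV>

Because the value and indices are passed directly as SSA operands, the scratch memref<?xindex> and memref<complex<f64>> buffers from the old checks are no longer generated.)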
diff --git a/mlir/test/Dialect/SparseTensor/sparse_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
index cd4ac517ddfe2..db332cc192d26 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_index.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
@@ -1,4 +1,3 @@
-// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// RUN: mlir-opt %s -sparsification | FileCheck %s
#DenseMatrix = #sparse_tensor.encoding<{
@@ -18,7 +17,7 @@
doc = "X(i,j) = A(i,j) * i * j"
}
-// CHECK-LABEL: func @dense_index(
+// CHECK-LABEL: func.func @dense_index(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi64, #sparse_tensor.encoding
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
@@ -68,43 +67,38 @@ func.func @dense_index(%arga: tensor<?x?xi64, #DenseMatrix>)
return %r : tensor<?x?xi64, #DenseMatrix>
}
-// CHECK-LABEL: func @sparse_index(
+
+// CHECK-LABEL: func.func @sparse_index(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi64, #sparse_tensor.encoding
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
// CHECK-DAG: %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG: %[[VAL_6:.*]] = bufferization.alloc_tensor(%[[VAL_4]], %[[VAL_5]]) : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK: %[[VAL_12:.*]] = memref.alloca(%[[VAL_3]]) : memref<?xindex>
-// CHECK: %[[BUF:.*]] = memref.alloca() : memref<i64>
-// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_1]]] : memref<?xindex>
-// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_2]] {
-// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xindex>
-// CHECK: memref.store %[[VAL_16]], %[[VAL_12]]{{\[}}%[[VAL_1]]] : memref<?xindex>
-// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref<?xindex>
-// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_2]] : index
-// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_19]] step %[[VAL_2]] {
-// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_20]]] : memref<?xindex>
-// CHECK: memref.store %[[VAL_21]], %[[VAL_12]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK: %[[VAL_22:.*]] = arith.index_cast %[[VAL_21]] : index to i64
-// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_16]] : index to i64
-// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<?xi64>
-// CHECK: %[[VAL_25:.*]] = arith.muli %[[VAL_23]], %[[VAL_24]] : i64
-// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_22]], %[[VAL_25]] : i64
-// CHECK: memref.store %[[VAL_26]], %[[BUF]][] : memref<i64>
-// CHECK: sparse_tensor.insert %[[VAL_6]], %[[VAL_12]], %[[BUF]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_5:.*]] = bufferization.alloc_tensor(%[[VAL_3]], %[[VAL_4]]) : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_1]]] : memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_11]] to %[[VAL_12]] step %[[VAL_2]] {
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref<?xindex>
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_13]]] : memref<?xindex>
+// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_2]] : index
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_18:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_2]] {
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK: %[[VAL_20:.*]] = arith.index_cast %[[VAL_19]] : index to i64
+// CHECK: %[[VAL_21:.*]] = arith.index_cast %[[VAL_14]] : index to i64
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref<?xi64>
+// CHECK: %[[VAL_23:.*]] = arith.muli %[[VAL_21]], %[[VAL_22]] : i64
+// CHECK: %[[VAL_24:.*]] = arith.muli %[[VAL_20]], %[[VAL_23]] : i64
+// CHECK: sparse_tensor.insert %[[VAL_24]] into %[[VAL_5]]{{\[}}%[[VAL_14]], %[[VAL_19]]] : tensor<?x?xi64, #sparse_tensor.encoding
// CHECK: }
// CHECK: }
-// CHECK: %[[VAL_27:.*]] = sparse_tensor.load %[[VAL_6]] hasInserts : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK: return %[[VAL_27]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK: %[[VAL_25:.*]] = sparse_tensor.load %[[VAL_5]] hasInserts : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK: return %[[VAL_25]] : tensor<?x?xi64, #sparse_tensor.encoding
// CHECK: }
func.func @sparse_index(%arga: tensor<?x?xi64, #SparseMatrix>)
-> tensor<?x?xi64, #SparseMatrix> {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
index 3d4e0f09c7529..0f629eaa23197 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
@@ -1,4 +1,3 @@
-// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// RUN: mlir-opt %s \
// RUN: --linalg-generalize-named-ops --linalg-fuse-elementwise-ops \
// RUN: --sparsification | FileCheck %s
@@ -7,41 +6,41 @@
#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
-// CHECK-LABEL: func @matmul1(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<10x20xf32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-LABEL: func.func @matmul1(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<20x30xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<10x30xf32>) -> tensor<10x30xf32> {
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 30 : index
-// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<10x20xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<10x20xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<10x20xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<10x20xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<20x30xf32>
-// CHECK-DAG: %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<10x30xf32>
-// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_4]] {
-// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref<?xindex>
-// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xindex>
-// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_16]], %[[VAL_4]] : index
-// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_20]] step %[[VAL_4]] {
-// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref<?xindex>
-// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref<?xf32>
-// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] {
-// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_17]], %[[VAL_24]]] : memref<10x30xf32>
-// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_22]], %[[VAL_24]]] : memref<20x30xf32>
-// CHECK: %[[VAL_27:.*]] = arith.mulf %[[VAL_23]], %[[VAL_26]] : f32
-// CHECK: %[[VAL_28:.*]] = arith.addf %[[VAL_25]], %[[VAL_27]] : f32
-// CHECK: memref.store %[[VAL_28]], %[[VAL_13]]{{\[}}%[[VAL_17]], %[[VAL_24]]] : memref<10x30xf32>
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 30 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xf32>
+// CHECK: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<20x30xf32>
+// CHECK: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<10x30xf32>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_5]] {
+// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_5]] : index
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_19]] step %[[VAL_5]] {
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_20]]] : memref<?xf32>
+// CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_16]], %[[VAL_23]]] : memref<10x30xf32>
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_23]]] : memref<20x30xf32>
+// CHECK: %[[VAL_26:.*]] = arith.mulf %[[VAL_22]], %[[VAL_25]] : f32
+// CHECK: %[[VAL_27:.*]] = arith.addf %[[VAL_24]], %[[VAL_26]] : f32
+// CHECK: memref.store %[[VAL_27]], %[[VAL_12]]{{\[}}%[[VAL_16]], %[[VAL_23]]] : memref<10x30xf32>
// CHECK: }
// CHECK: }
// CHECK: }
-// CHECK: %[[VAL_29:.*]] = bufferization.to_tensor %[[VAL_13]] : memref<10x30xf32>
-// CHECK: return %[[VAL_29]] : tensor<10x30xf32>
+// CHECK: %[[VAL_28:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<10x30xf32>
+// CHECK: return %[[VAL_28]] : tensor<10x30xf32>
// CHECK: }
func.func @matmul1(%a: tensor<10x20xf32, #DCSR>,
%b: tensor<20x30xf32>,
@@ -55,91 +54,88 @@ func.func @matmul1(%a: tensor<10x20xf32, #DCSR>,
//
// Computes C = A x B with all matrices sparse (SpMSpM) in DCSR.
//
-// CHECK-LABEL: func @matmul2(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> {
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant false
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant true
-// CHECK: %[[VAL_8:.*]] = bufferization.alloc_tensor() : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
-// CHECK: %[[VAL_14:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 0 : index} : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK: %[[VAL_15:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index} : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK: %[[VAL_16:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 1 : index} : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK: %[[VAL_17:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index} : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK: %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
-// CHECK: %[[VAL_19:.*]] = memref.alloca(%[[VAL_5]]) : memref<?xindex>
-// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_20]] to %[[VAL_21]] step %[[VAL_4]] {
-// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_22]]] : memref<?xindex>
-// CHECK: memref.store %[[VAL_23]], %[[VAL_19]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK: %[[VAL_24:.*]], %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.expand %[[VAL_8]] : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
-// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_22]]] : memref<?xindex>
-// CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_22]], %[[VAL_4]] : index
-// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref<?xindex>
-// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK: %[[VAL_33:.*]]:3 = scf.while (%[[VAL_34:.*]] = %[[VAL_28]], %[[VAL_35:.*]] = %[[VAL_31]], %[[VAL_36:.*]] = %[[VAL_27]]) : (index, index, index) -> (index, index, index) {
-// CHECK: %[[VAL_37:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_30]] : index
-// CHECK: %[[VAL_38:.*]] = arith.cmpi ult, %[[VAL_35]], %[[VAL_32]] : index
-// CHECK: %[[VAL_39:.*]] = arith.andi %[[VAL_37]], %[[VAL_38]] : i1
-// CHECK: scf.condition(%[[VAL_39]]) %[[VAL_34]], %[[VAL_35]], %[[VAL_36]] : index, index, index
+// CHECK-LABEL: func.func @matmul2(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) -> tensor<4x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> {
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant false
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant true
+// CHECK: %[[VAL_6:.*]] = bufferization.alloc_tensor() : tensor<4x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<4x8xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<4x8xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<4x8xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<4x8xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x8xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xf64>
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 0 : index} : tensor<8x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index} : tensor<8x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 1 : index} : tensor<8x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_15:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index} : tensor<8x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<8x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xf64>
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_19:.*]] = %[[VAL_17]] to %[[VAL_18]] step %[[VAL_3]] {
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK: %[[VAL_21:.*]], %[[VAL_22:.*]], %[[VAL_23:.*]], %[[VAL_24:.*]] = sparse_tensor.expand %[[VAL_6]] : tensor<4x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xf64>, memref<?xi1>, memref<?xindex>
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_19]], %[[VAL_3]] : index
+// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_26]]] : memref<?xindex>
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_30:.*]]:3 = scf.while (%[[VAL_31:.*]] = %[[VAL_25]], %[[VAL_32:.*]] = %[[VAL_28]], %[[VAL_33:.*]] = %[[VAL_24]]) : (index, index, index) -> (index, index, index) {
+// CHECK: %[[VAL_34:.*]] = arith.cmpi ult, %[[VAL_31]], %[[VAL_27]] : index
+// CHECK: %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_32]], %[[VAL_29]] : index
+// CHECK: %[[VAL_36:.*]] = arith.andi %[[VAL_34]], %[[VAL_35]] : i1
+// CHECK: scf.condition(%[[VAL_36]]) %[[VAL_31]], %[[VAL_32]], %[[VAL_33]] : index, index, index
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_40:.*]]: index, %[[VAL_41:.*]]: index, %[[VAL_42:.*]]: index):
-// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_40]]] : memref<?xindex>
-// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_41]]] : memref<?xindex>
-// CHECK: %[[VAL_45:.*]] = arith.cmpi ult, %[[VAL_44]], %[[VAL_43]] : index
-// CHECK: %[[VAL_46:.*]] = arith.select %[[VAL_45]], %[[VAL_44]], %[[VAL_43]] : index
-// CHECK: %[[VAL_47:.*]] = arith.cmpi eq, %[[VAL_43]], %[[VAL_46]] : index
-// CHECK: %[[VAL_48:.*]] = arith.cmpi eq, %[[VAL_44]], %[[VAL_46]] : index
-// CHECK: %[[VAL_49:.*]] = arith.andi %[[VAL_47]], %[[VAL_48]] : i1
-// CHECK: %[[VAL_50:.*]] = scf.if %[[VAL_49]] -> (index) {
-// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_40]]] : memref<?xf64>
-// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_41]]] : memref<?xindex>
-// CHECK: %[[VAL_53:.*]] = arith.addi %[[VAL_41]], %[[VAL_4]] : index
-// CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_53]]] : memref<?xindex>
-// CHECK: %[[VAL_55:.*]] = scf.for %[[VAL_56:.*]] = %[[VAL_52]] to %[[VAL_54]] step %[[VAL_4]] iter_args(%[[VAL_57:.*]] = %[[VAL_42]]) -> (index) {
-// CHECK: %[[VAL_58:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_56]]] : memref<?xindex>
-// CHECK: %[[VAL_59:.*]] = memref.load %[[VAL_24]]{{\[}}%[[VAL_58]]] : memref<?xf64>
-// CHECK: %[[VAL_60:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_56]]] : memref<?xf64>
-// CHECK: %[[VAL_61:.*]] = arith.mulf %[[VAL_51]], %[[VAL_60]] : f64
-// CHECK: %[[VAL_62:.*]] = arith.addf %[[VAL_59]], %[[VAL_61]] : f64
-// CHECK: %[[VAL_63:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_58]]] : memref<?xi1>
-// CHECK: %[[VAL_64:.*]] = arith.cmpi eq, %[[VAL_63]], %[[VAL_6]] : i1
-// CHECK: %[[VAL_65:.*]] = scf.if %[[VAL_64]] -> (index) {
-// CHECK: memref.store %[[VAL_7]], %[[VAL_25]]{{\[}}%[[VAL_58]]] : memref<?xi1>
-// CHECK: memref.store %[[VAL_58]], %[[VAL_26]]{{\[}}%[[VAL_57]]] : memref<?xindex>
-// CHECK: %[[VAL_66:.*]] = arith.addi %[[VAL_57]], %[[VAL_4]] : index
-// CHECK: scf.yield %[[VAL_66]] : index
+// CHECK: ^bb0(%[[VAL_37:.*]]: index, %[[VAL_38:.*]]: index, %[[VAL_39:.*]]: index):
+// CHECK: %[[VAL_40:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_37]]] : memref<?xindex>
+// CHECK: %[[VAL_41:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_38]]] : memref<?xindex>
+// CHECK: %[[VAL_42:.*]] = arith.cmpi ult, %[[VAL_41]], %[[VAL_40]] : index
+// CHECK: %[[VAL_43:.*]] = arith.select %[[VAL_42]], %[[VAL_41]], %[[VAL_40]] : index
+// CHECK: %[[VAL_44:.*]] = arith.cmpi eq, %[[VAL_40]], %[[VAL_43]] : index
+// CHECK: %[[VAL_45:.*]] = arith.cmpi eq, %[[VAL_41]], %[[VAL_43]] : index
+// CHECK: %[[VAL_46:.*]] = arith.andi %[[VAL_44]], %[[VAL_45]] : i1
+// CHECK: %[[VAL_47:.*]] = scf.if %[[VAL_46]] -> (index) {
+// CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_37]]] : memref<?xf64>
+// CHECK: %[[VAL_49:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_38]]] : memref<?xindex>
+// CHECK: %[[VAL_50:.*]] = arith.addi %[[VAL_38]], %[[VAL_3]] : index
+// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_50]]] : memref<?xindex>
+// CHECK: %[[VAL_52:.*]] = scf.for %[[VAL_53:.*]] = %[[VAL_49]] to %[[VAL_51]] step %[[VAL_3]] iter_args(%[[VAL_54:.*]] = %[[VAL_39]]) -> (index) {
+// CHECK: %[[VAL_55:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_53]]] : memref<?xindex>
+// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_21]]{{\[}}%[[VAL_55]]] : memref<?xf64>
+// CHECK: %[[VAL_57:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_53]]] : memref<?xf64>
+// CHECK: %[[VAL_58:.*]] = arith.mulf %[[VAL_48]], %[[VAL_57]] : f64
+// CHECK: %[[VAL_59:.*]] = arith.addf %[[VAL_56]], %[[VAL_58]] : f64
+// CHECK: %[[VAL_60:.*]] = memref.load %[[VAL_22]]{{\[}}%[[VAL_55]]] : memref<?xi1>
+// CHECK: %[[VAL_61:.*]] = arith.cmpi eq, %[[VAL_60]], %[[VAL_4]] : i1
+// CHECK: %[[VAL_62:.*]] = scf.if %[[VAL_61]] -> (index) {
+// CHECK: memref.store %[[VAL_5]], %[[VAL_22]]{{\[}}%[[VAL_55]]] : memref<?xi1>
+// CHECK: memref.store %[[VAL_55]], %[[VAL_23]]{{\[}}%[[VAL_54]]] : memref<?xindex>
+// CHECK: %[[VAL_63:.*]] = arith.addi %[[VAL_54]], %[[VAL_3]] : index
+// CHECK: scf.yield %[[VAL_63]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_57]] : index
+// CHECK: scf.yield %[[VAL_54]] : index
// CHECK: }
-// CHECK: memref.store %[[VAL_62]], %[[VAL_24]]{{\[}}%[[VAL_58]]] : memref<?xf64>
-// CHECK: scf.yield %[[VAL_67:.*]] : index
+// CHECK: memref.store %[[VAL_59]], %[[VAL_21]]{{\[}}%[[VAL_55]]] : memref<?xf64>
+// CHECK: scf.yield %[[VAL_64:.*]] : index
// CHECK: }
-// CHECK: scf.yield %[[VAL_68:.*]] : index
+// CHECK: scf.yield %[[VAL_65:.*]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_42]] : index
+// CHECK: scf.yield %[[VAL_39]] : index
// CHECK: }
-// CHECK: %[[VAL_69:.*]] = arith.cmpi eq, %[[VAL_43]], %[[VAL_46]] : index
-// CHECK: %[[VAL_70:.*]] = arith.addi %[[VAL_40]], %[[VAL_4]] : index
-// CHECK: %[[VAL_71:.*]] = arith.select %[[VAL_69]], %[[VAL_70]], %[[VAL_40]] : index
-// CHECK: %[[VAL_72:.*]] = arith.cmpi eq, %[[VAL_44]], %[[VAL_46]] : index
-// CHECK: %[[VAL_73:.*]] = arith.addi %[[VAL_41]], %[[VAL_4]] : index
-// CHECK: %[[VAL_74:.*]] = arith.select %[[VAL_72]], %[[VAL_73]], %[[VAL_41]] : index
-// CHECK: scf.yield %[[VAL_71]], %[[VAL_74]], %[[VAL_75:.*]] : index, index, index
+// CHECK: %[[VAL_66:.*]] = arith.cmpi eq, %[[VAL_40]], %[[VAL_43]] : index
+// CHECK: %[[VAL_67:.*]] = arith.addi %[[VAL_37]], %[[VAL_3]] : index
+// CHECK: %[[VAL_68:.*]] = arith.select %[[VAL_66]], %[[VAL_67]], %[[VAL_37]] : index
+// CHECK: %[[VAL_69:.*]] = arith.cmpi eq, %[[VAL_41]], %[[VAL_43]] : index
+// CHECK: %[[VAL_70:.*]] = arith.addi %[[VAL_38]], %[[VAL_3]] : index
+// CHECK: %[[VAL_71:.*]] = arith.select %[[VAL_69]], %[[VAL_70]], %[[VAL_38]] : index
+// CHECK: scf.yield %[[VAL_68]], %[[VAL_71]], %[[VAL_72:.*]] : index, index, index
// CHECK: }
-// CHECK: sparse_tensor.compress %[[VAL_8]], %[[VAL_19]], %[[VAL_24]], %[[VAL_25]], %[[VAL_26]], %[[VAL_76:.*]]#2 : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+// CHECK: sparse_tensor.compress %[[VAL_21]], %[[VAL_22]], %[[VAL_23]], %[[VAL_73:.*]]#2 into %[[VAL_6]]{{\[}}%[[VAL_20]]] : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<4x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
// CHECK: }
-// CHECK: %[[VAL_77:.*]] = sparse_tensor.load %[[VAL_8]] hasInserts : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: return %[[VAL_77]] : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_74:.*]] = sparse_tensor.load %[[VAL_6]] hasInserts : tensor<4x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK: return %[[VAL_74]] : tensor<4x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
// CHECK: }
func.func @matmul2(%A: tensor<4x8xf64, #DCSR>,
%B: tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> {
@@ -151,45 +147,45 @@ func.func @matmul2(%A: tensor<4x8xf64, #DCSR>,
return %D: tensor<4x4xf64, #DCSR>
}
-// CHECK-LABEL: func @conv2d(
+// CHECK-LABEL: func.func @conv2d(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xi32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<6x6xi32>) -> tensor<6x6xi32> {
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 6 : index
-// CHECK-DAG: %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<8x8xi32>
-// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 0 : index} : tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index} : tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 1 : index} : tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index} : tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<3x3xi32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<6x6xi32>
-// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_4]] {
-// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xindex>
-// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref<?xindex>
-// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_16]], %[[VAL_4]] : index
-// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_20]] step %[[VAL_4]] {
-// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref<?xindex>
-// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]]] : memref<?xi32>
-// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] {
-// CHECK: scf.for %[[VAL_25:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] {
-// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_25]], %[[VAL_24]]] : memref<6x6xi32>
-// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_17]] : index
-// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_24]], %[[VAL_22]] : index
-// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_27]], %[[VAL_28]]] : memref<8x8xi32>
-// CHECK: %[[VAL_30:.*]] = arith.muli %[[VAL_29]], %[[VAL_23]] : i32
-// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_26]], %[[VAL_30]] : i32
-// CHECK: memref.store %[[VAL_31]], %[[VAL_13]]{{\[}}%[[VAL_25]], %[[VAL_24]]] : memref<6x6xi32>
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 6 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<8x8xi32>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 0 : index} : tensor<3x3xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index} : tensor<3x3xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 1 : index} : tensor<3x3xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index} : tensor<3x3xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<3x3xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xi32>
+// CHECK: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<6x6xi32>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_5]] {
+// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_5]] : index
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_19]] step %[[VAL_5]] {
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<?xi32>
+// CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]], %[[VAL_23]]] : memref<6x6xi32>
+// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_24]], %[[VAL_16]] : index
+// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_23]], %[[VAL_21]] : index
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_26]], %[[VAL_27]]] : memref<8x8xi32>
+// CHECK: %[[VAL_29:.*]] = arith.muli %[[VAL_28]], %[[VAL_22]] : i32
+// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_25]], %[[VAL_29]] : i32
+// CHECK: memref.store %[[VAL_30]], %[[VAL_12]]{{\[}}%[[VAL_24]], %[[VAL_23]]] : memref<6x6xi32>
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
-// CHECK: %[[VAL_32:.*]] = bufferization.to_tensor %[[VAL_13]] : memref<6x6xi32>
-// CHECK: return %[[VAL_32]] : tensor<6x6xi32>
+// CHECK: %[[VAL_31:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<6x6xi32>
+// CHECK: return %[[VAL_31]] : tensor<6x6xi32>
// CHECK: }
func.func @conv2d(%input: tensor<8x8xi32>,
%filter: tensor<3x3xi32, #DCSR>,
@@ -200,45 +196,45 @@ func.func @conv2d(%input: tensor<8x8xi32>,
return %0 : tensor<6x6xi32>
}
-// CHECK-LABEL: func @quantized_matmul(
+// CHECK-LABEL: func.func @quantized_matmul(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x3xi8>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x6xi8, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x6xi8, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<5x6xi64>) -> tensor<5x6xi64> {
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 2 : i64
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 5 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_0]] : memref<5x3xi8>
-// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 0 : index} : tensor<3x6xi8, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index} : tensor<3x6xi8, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 1 : index} : tensor<3x6xi8, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index} : tensor<3x6xi8, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<3x6xi8, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK-DAG: %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<5x6xi64>
-// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_15]] to %[[VAL_16]] step %[[VAL_5]] {
-// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_17]]] : memref<?xindex>
-// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_17]]] : memref<?xindex>
-// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_17]], %[[VAL_5]] : index
-// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_20]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_19]] to %[[VAL_21]] step %[[VAL_5]] {
-// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_22]]] : memref<?xindex>
-// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_22]]] : memref<?xi8>
-// CHECK: scf.for %[[VAL_25:.*]] = %[[VAL_4]] to %[[VAL_6]] step %[[VAL_5]] {
-// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_25]], %[[VAL_23]]] : memref<5x6xi64>
-// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_25]], %[[VAL_18]]] : memref<5x3xi8>
-// CHECK: %[[VAL_28:.*]] = arith.extsi %[[VAL_27]] : i8 to i64
-// CHECK: %[[VAL_29:.*]] = arith.subi %[[VAL_28]], %[[VAL_3]] : i64
-// CHECK: %[[VAL_30:.*]] = arith.extsi %[[VAL_24]] : i8 to i64
-// CHECK: %[[VAL_31:.*]] = arith.muli %[[VAL_29]], %[[VAL_30]] : i64
-// CHECK: %[[VAL_32:.*]] = arith.addi %[[VAL_26]], %[[VAL_31]] : i64
-// CHECK: memref.store %[[VAL_32]], %[[VAL_14]]{{\[}}%[[VAL_25]], %[[VAL_23]]] : memref<5x6xi64>
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 2 : i64
+// CHECK: %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_0]] : memref<5x3xi8>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 0 : index} : tensor<3x6xi8, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index} : tensor<3x6xi8, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 1 : index} : tensor<3x6xi8, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index} : tensor<3x6xi8, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<3x6xi8, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xi8>
+// CHECK: %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<5x6xi64>
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_5]] {
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_16]], %[[VAL_5]] : index
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_20]] step %[[VAL_5]] {
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]]] : memref<?xindex>
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_21]]] : memref<?xi8>
+// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]], %[[VAL_22]]] : memref<5x6xi64>
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_24]], %[[VAL_17]]] : memref<5x3xi8>
+// CHECK: %[[VAL_27:.*]] = arith.extsi %[[VAL_26]] : i8 to i64
+// CHECK: %[[VAL_28:.*]] = arith.subi %[[VAL_27]], %[[VAL_6]] : i64
+// CHECK: %[[VAL_29:.*]] = arith.extsi %[[VAL_23]] : i8 to i64
+// CHECK: %[[VAL_30:.*]] = arith.muli %[[VAL_28]], %[[VAL_29]] : i64
+// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_25]], %[[VAL_30]] : i64
+// CHECK: memref.store %[[VAL_31]], %[[VAL_13]]{{\[}}%[[VAL_24]], %[[VAL_22]]] : memref<5x6xi64>
// CHECK: }
// CHECK: }
// CHECK: }
-// CHECK: %[[VAL_33:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<5x6xi64>
-// CHECK: return %[[VAL_33]] : tensor<5x6xi64>
+// CHECK: %[[VAL_32:.*]] = bufferization.to_tensor %[[VAL_13]] : memref<5x6xi64>
+// CHECK: return %[[VAL_32]] : tensor<5x6xi64>
// CHECK: }
func.func @quantized_matmul(%input1: tensor<5x3xi8>,
%input2: tensor<3x6xi8, #DCSR>,
@@ -251,55 +247,58 @@ func.func @quantized_matmul(%input1: tensor<5x3xi8>,
return %0: tensor<5x6xi64>
}
-// CHECK-LABEL: func @sparse_dot(
+// CHECK-LABEL: func.func @sparse_dot(
+// CHECK-SAME: %[[VAL_0:.*0]]: tensor<1024xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>,
+// CHECK-SAME: %[[VAL_1:.*1]]: tensor<1024xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>,
+// CHECK-SAME: %[[VAL_2:.*2]]: tensor<f32>) -> tensor<f32> {
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0:.*]] {dimension = 0 : index} : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
-// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1:.*]] {dimension = 0 : index} : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index} : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
-// CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2:.*]] : memref<f32>
-// CHECK-DAG: %[[VAL_13:.*]] = memref.load %[[VAL_11]][] : memref<f32>
-// CHECK-DAG: %[[VAL_14:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK-DAG: %[[VAL_15:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK-DAG: %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK-DAG: %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK: %[[VAL_18:.*]]:3 = scf.while (%[[VAL_19:.*]] = %[[VAL_14]], %[[VAL_20:.*]] = %[[VAL_16]], %[[VAL_21:.*]] = %[[VAL_13]]) : (index, index, f32) -> (index, index, f32) {
-// CHECK: %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_19]], %[[VAL_15]] : index
-// CHECK: %[[VAL_23:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_17]] : index
-// CHECK: %[[VAL_24:.*]] = arith.andi %[[VAL_22]], %[[VAL_23]] : i1
-// CHECK: scf.condition(%[[VAL_24]]) %[[VAL_19]], %[[VAL_20]], %[[VAL_21]] : index, index, f32
+// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<1024xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<1024xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<1024xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>> to memref<?xf32>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 0 : index} : tensor<1024xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index} : tensor<1024xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<1024xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>> to memref<?xf32>
+// CHECK: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<f32>
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_11]][] : memref<f32>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_17:.*]]:3 = scf.while (%[[VAL_18:.*]] = %[[VAL_13]], %[[VAL_19:.*]] = %[[VAL_15]], %[[VAL_20:.*]] = %[[VAL_12]]) : (index, index, f32) -> (index, index, f32) {
+// CHECK: %[[VAL_21:.*]] = arith.cmpi ult, %[[VAL_18]], %[[VAL_14]] : index
+// CHECK: %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_19]], %[[VAL_16]] : index
+// CHECK: %[[VAL_23:.*]] = arith.andi %[[VAL_21]], %[[VAL_22]] : i1
+// CHECK: scf.condition(%[[VAL_23]]) %[[VAL_18]], %[[VAL_19]], %[[VAL_20]] : index, index, f32
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_25:.*]]: index, %[[VAL_26:.*]]: index, %[[VAL_27:.*]]: f32):
-// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_25]]] : memref<?xindex>
-// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_26]]] : memref<?xindex>
-// CHECK: %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_28]] : index
-// CHECK: %[[VAL_31:.*]] = arith.select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index
-// CHECK: %[[VAL_32:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
-// CHECK: %[[VAL_33:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
-// CHECK: %[[VAL_34:.*]] = arith.andi %[[VAL_32]], %[[VAL_33]] : i1
-// CHECK: %[[VAL_35:.*]] = scf.if %[[VAL_34]] -> (f32) {
-// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_25]]] : memref<?xf32>
-// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_26]]] : memref<?xf32>
-// CHECK: %[[VAL_38:.*]] = arith.mulf %[[VAL_36]], %[[VAL_37]] : f32
-// CHECK: %[[VAL_39:.*]] = arith.addf %[[VAL_27]], %[[VAL_38]] : f32
-// CHECK: scf.yield %[[VAL_39]] : f32
+// CHECK: ^bb0(%[[VAL_24:.*]]: index, %[[VAL_25:.*]]: index, %[[VAL_26:.*]]: f32):
+// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_24]]] : memref<?xindex>
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_25]]] : memref<?xindex>
+// CHECK: %[[VAL_29:.*]] = arith.cmpi ult, %[[VAL_28]], %[[VAL_27]] : index
+// CHECK: %[[VAL_30:.*]] = arith.select %[[VAL_29]], %[[VAL_28]], %[[VAL_27]] : index
+// CHECK: %[[VAL_31:.*]] = arith.cmpi eq, %[[VAL_27]], %[[VAL_30]] : index
+// CHECK: %[[VAL_32:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_30]] : index
+// CHECK: %[[VAL_33:.*]] = arith.andi %[[VAL_31]], %[[VAL_32]] : i1
+// CHECK: %[[VAL_34:.*]] = scf.if %[[VAL_33]] -> (f32) {
+// CHECK: %[[VAL_35:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref<?xf32>
+// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref<?xf32>
+// CHECK: %[[VAL_37:.*]] = arith.mulf %[[VAL_35]], %[[VAL_36]] : f32
+// CHECK: %[[VAL_38:.*]] = arith.addf %[[VAL_26]], %[[VAL_37]] : f32
+// CHECK: scf.yield %[[VAL_38]] : f32
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_27]] : f32
+// CHECK: scf.yield %[[VAL_26]] : f32
// CHECK: }
-// CHECK: %[[VAL_40:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
-// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_25]], %[[VAL_4]] : index
-// CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_40]], %[[VAL_41]], %[[VAL_25]] : index
-// CHECK: %[[VAL_43:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
-// CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_26]], %[[VAL_4]] : index
-// CHECK: %[[VAL_45:.*]] = arith.select %[[VAL_43]], %[[VAL_44]], %[[VAL_26]] : index
-// CHECK: scf.yield %[[VAL_42]], %[[VAL_45]], %[[VAL_46:.*]] : index, index, f32
+// CHECK: %[[VAL_39:.*]] = arith.cmpi eq, %[[VAL_27]], %[[VAL_30]] : index
+// CHECK: %[[VAL_40:.*]] = arith.addi %[[VAL_24]], %[[VAL_4]] : index
+// CHECK: %[[VAL_41:.*]] = arith.select %[[VAL_39]], %[[VAL_40]], %[[VAL_24]] : index
+// CHECK: %[[VAL_42:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_30]] : index
+// CHECK: %[[VAL_43:.*]] = arith.addi %[[VAL_25]], %[[VAL_4]] : index
+// CHECK: %[[VAL_44:.*]] = arith.select %[[VAL_42]], %[[VAL_43]], %[[VAL_25]] : index
+// CHECK: scf.yield %[[VAL_41]], %[[VAL_44]], %[[VAL_45:.*]] : index, index, f32
// CHECK: }
-// CHECK: memref.store %[[VAL_47:.*]]#2, %[[VAL_11]][] : memref<f32>
-// CHECK: %[[VAL_48:.*]] = bufferization.to_tensor %[[VAL_11]] : memref<f32>
-// CHECK: return %[[VAL_48]] : tensor<f32>
+// CHECK: memref.store %[[VAL_46:.*]]#2, %[[VAL_11]][] : memref<f32>
+// CHECK: %[[VAL_47:.*]] = bufferization.to_tensor %[[VAL_11]] : memref<f32>
+// CHECK: return %[[VAL_47]] : tensor<f32>
// CHECK: }
func.func @sparse_dot(%a: tensor<1024xf32, #SparseVector>,
%b: tensor<1024xf32, #SparseVector>,
diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
index 9e1724d5d45ff..d52a81f8d6ffb 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -1,4 +1,3 @@
-// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// RUN: mlir-opt %s -sparsification | FileCheck %s
#CSR = #sparse_tensor.encoding<{
@@ -23,7 +22,7 @@
doc = "X(i,j) *= 2 or X(i,j) += X(i,j)"
}
-// CHECK-LABEL: func @sparse_simply_dynamic1(
+// CHECK-LABEL: func.func @sparse_simply_dynamic1(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 2.000000e+00 : f32
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
@@ -57,7 +56,7 @@ func.func @sparse_simply_dynamic1(%argx: tensor<32x16xf32, #DCSR>) -> tensor<32x
return %0 : tensor<32x16xf32, #DCSR>
}
-// CHECK-LABEL: func @sparse_simply_dynamic2(
+// CHECK-LABEL: func.func @sparse_simply_dynamic2(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
@@ -99,35 +98,29 @@ func.func @sparse_simply_dynamic2(%argx: tensor<32x16xf32, #DCSR>) -> tensor<32x
doc = "X(i,j) = A(i,j) * 2.0"
}
-// CHECK-LABEL: func @sparse_truly_dynamic(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 10 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_7:.*]] = bufferization.alloc_tensor() : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
-// CHECK: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
-// CHECK: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
-// CHECK: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
-// CHECK: %[[VAL_11:.*]] = memref.alloca(%[[VAL_5]]) : memref<?xindex>
-// CHECK: %[[BUF:.*]] = memref.alloca() : memref<f32>
-// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_2]] step %[[VAL_4]] {
-// CHECK: memref.store %[[VAL_12]], %[[VAL_11]]{{\[}}%[[VAL_6]]] : memref<?xindex>
-// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<?xindex>
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_4]] : index
-// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_14]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_4]] {
-// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref<?xindex>
-// CHECK: memref.store %[[VAL_17]], %[[VAL_11]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_16]]] : memref<?xf32>
-// CHECK: %[[VAL_19:.*]] = arith.mulf %[[VAL_18]], %[[VAL_1]] : f32
-// CHECK: memref.store %[[VAL_19]], %[[BUF]][] : memref<f32>
-// CHECK: sparse_tensor.insert %[[VAL_7]], %[[VAL_11]], %[[BUF]] : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK-LABEL: func.func @sparse_truly_dynamic(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>) -> tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> {
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 10 : index
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[VAL_5:.*]] = bufferization.alloc_tensor() : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xf32>
+// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_3]] {
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_9]]] : memref<?xindex>
+// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_9]], %[[VAL_3]] : index
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_10]] to %[[VAL_12]] step %[[VAL_3]] {
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref<?xindex>
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_13]]] : memref<?xf32>
+// CHECK: %[[VAL_16:.*]] = arith.mulf %[[VAL_15]], %[[VAL_4]] : f32
+// CHECK: sparse_tensor.insert %[[VAL_16]] into %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_14]]] : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
// CHECK: }
// CHECK: }
-// CHECK: %[[VAL_20:.*]] = sparse_tensor.load %[[VAL_7]] hasInserts : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
-// CHECK: return %[[VAL_20]] : tensor<10x20xf32, #sparse_tensor.encoding<{
+// CHECK: %[[VAL_17:.*]] = sparse_tensor.load %[[VAL_5]] hasInserts : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK: return %[[VAL_17]] : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
// CHECK: }
func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20xf32, #DCSR> {
%s = arith.constant 2.0 : f32
@@ -152,136 +145,129 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
doc = "X(i,j) = SUM_k A(i,j,k) * B(i,j,k)"
}
-// CHECK-LABEL: func @sumred(
+// CHECK-LABEL: func.func @sumred(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>>)
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>>) -> tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> {
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xi32, #{{.*}}>>
-// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xi32, #{{.*}}>>
-// CHECK: %[[VAL_8:.*]] = bufferization.alloc_tensor(%[[VAL_6]], %[[VAL_7]]) : tensor<?x?xi32, #{{.*}}>>
-// CHECK: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
-// CHECK: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
-// CHECK: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
-// CHECK: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
-// CHECK: %[[VAL_13:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 2 : index} : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
-// CHECK: %[[VAL_14:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 2 : index} : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
-// CHECK: %[[VAL_15:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xi32>
-// CHECK: %[[VAL_16:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 0 : index} : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
-// CHECK: %[[VAL_17:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index} : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
-// CHECK: %[[VAL_18:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 1 : index} : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
-// CHECK: %[[VAL_19:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index} : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
-// CHECK: %[[VAL_20:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 2 : index} : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
-// CHECK: %[[VAL_21:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 2 : index} : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
-// CHECK: %[[VAL_22:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xi32>
-// CHECK: %[[VAL_23:.*]] = memref.alloca(%[[VAL_4]]) : memref<?xindex>
-// CHECK: %[[BUF:.*]] = memref.alloca() : memref<i32>
-// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK: %[[VAL_28:.*]]:2 = scf.while (%[[VAL_29:.*]] = %[[VAL_24]], %[[VAL_30:.*]] = %[[VAL_26]]) : (index, index) -> (index, index) {
-// CHECK: %[[VAL_31:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_25]] : index
-// CHECK: %[[VAL_32:.*]] = arith.cmpi ult, %[[VAL_30]], %[[VAL_27]] : index
-// CHECK: %[[VAL_33:.*]] = arith.andi %[[VAL_31]], %[[VAL_32]] : i1
-// CHECK: scf.condition(%[[VAL_33]]) %[[VAL_29]], %[[VAL_30]] : index, index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>>
+// CHECK: %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>>
+// CHECK: %[[VAL_7:.*]] = bufferization.alloc_tensor(%[[VAL_5]], %[[VAL_6]]) : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 2 : index} : tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 2 : index} : tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>> to memref<?xi32>
+// CHECK: %[[VAL_15:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 0 : index} : tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_16:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index} : tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_17:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 1 : index} : tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index} : tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_19:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 2 : index} : tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_20:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 2 : index} : tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_21:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>> to memref<?xi32>
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_26:.*]]:2 = scf.while (%[[VAL_27:.*]] = %[[VAL_22]], %[[VAL_28:.*]] = %[[VAL_24]]) : (index, index) -> (index, index) {
+// CHECK: %[[VAL_29:.*]] = arith.cmpi ult, %[[VAL_27]], %[[VAL_23]] : index
+// CHECK: %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_28]], %[[VAL_25]] : index
+// CHECK: %[[VAL_31:.*]] = arith.andi %[[VAL_29]], %[[VAL_30]] : i1
+// CHECK: scf.condition(%[[VAL_31]]) %[[VAL_27]], %[[VAL_28]] : index, index
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_34:.*]]: index, %[[VAL_35:.*]]: index):
-// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_34]]] : memref<?xindex>
-// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_35]]] : memref<?xindex>
-// CHECK: %[[VAL_38:.*]] = arith.cmpi ult, %[[VAL_37]], %[[VAL_36]] : index
-// CHECK: %[[VAL_39:.*]] = arith.select %[[VAL_38]], %[[VAL_37]], %[[VAL_36]] : index
-// CHECK: memref.store %[[VAL_39]], %[[VAL_23]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK: %[[VAL_40:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_39]] : index
-// CHECK: %[[VAL_41:.*]] = arith.cmpi eq, %[[VAL_37]], %[[VAL_39]] : index
-// CHECK: %[[VAL_42:.*]] = arith.andi %[[VAL_40]], %[[VAL_41]] : i1
-// CHECK: scf.if %[[VAL_42]] {
-// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_34]]] : memref<?xindex>
-// CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_34]], %[[VAL_3]] : index
-// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_44]]] : memref<?xindex>
-// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_35]]] : memref<?xindex>
-// CHECK: %[[VAL_47:.*]] = arith.addi %[[VAL_35]], %[[VAL_3]] : index
-// CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_47]]] : memref<?xindex>
-// CHECK: %[[VAL_49:.*]]:2 = scf.while (%[[VAL_50:.*]] = %[[VAL_43]], %[[VAL_51:.*]] = %[[VAL_46]]) : (index, index) -> (index, index) {
-// CHECK: %[[VAL_52:.*]] = arith.cmpi ult, %[[VAL_50]], %[[VAL_45]] : index
-// CHECK: %[[VAL_53:.*]] = arith.cmpi ult, %[[VAL_51]], %[[VAL_48]] : index
-// CHECK: %[[VAL_54:.*]] = arith.andi %[[VAL_52]], %[[VAL_53]] : i1
-// CHECK: scf.condition(%[[VAL_54]]) %[[VAL_50]], %[[VAL_51]] : index, index
+// CHECK: ^bb0(%[[VAL_32:.*]]: index, %[[VAL_33:.*]]: index):
+// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_32]]] : memref<?xindex>
+// CHECK: %[[VAL_35:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_33]]] : memref<?xindex>
+// CHECK: %[[VAL_36:.*]] = arith.cmpi ult, %[[VAL_35]], %[[VAL_34]] : index
+// CHECK: %[[VAL_37:.*]] = arith.select %[[VAL_36]], %[[VAL_35]], %[[VAL_34]] : index
+// CHECK: %[[VAL_38:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_37]] : index
+// CHECK: %[[VAL_39:.*]] = arith.cmpi eq, %[[VAL_35]], %[[VAL_37]] : index
+// CHECK: %[[VAL_40:.*]] = arith.andi %[[VAL_38]], %[[VAL_39]] : i1
+// CHECK: scf.if %[[VAL_40]] {
+// CHECK: %[[VAL_41:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_32]]] : memref<?xindex>
+// CHECK: %[[VAL_42:.*]] = arith.addi %[[VAL_32]], %[[VAL_3]] : index
+// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_42]]] : memref<?xindex>
+// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_33]]] : memref<?xindex>
+// CHECK: %[[VAL_45:.*]] = arith.addi %[[VAL_33]], %[[VAL_3]] : index
+// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_45]]] : memref<?xindex>
+// CHECK: %[[VAL_47:.*]]:2 = scf.while (%[[VAL_48:.*]] = %[[VAL_41]], %[[VAL_49:.*]] = %[[VAL_44]]) : (index, index) -> (index, index) {
+// CHECK: %[[VAL_50:.*]] = arith.cmpi ult, %[[VAL_48]], %[[VAL_43]] : index
+// CHECK: %[[VAL_51:.*]] = arith.cmpi ult, %[[VAL_49]], %[[VAL_46]] : index
+// CHECK: %[[VAL_52:.*]] = arith.andi %[[VAL_50]], %[[VAL_51]] : i1
+// CHECK: scf.condition(%[[VAL_52]]) %[[VAL_48]], %[[VAL_49]] : index, index
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_55:.*]]: index, %[[VAL_56:.*]]: index):
-// CHECK: %[[VAL_57:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_55]]] : memref<?xindex>
-// CHECK: %[[VAL_58:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_56]]] : memref<?xindex>
-// CHECK: %[[VAL_59:.*]] = arith.cmpi ult, %[[VAL_58]], %[[VAL_57]] : index
-// CHECK: %[[VAL_60:.*]] = arith.select %[[VAL_59]], %[[VAL_58]], %[[VAL_57]] : index
-// CHECK: memref.store %[[VAL_60]], %[[VAL_23]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK: %[[VAL_61:.*]] = arith.cmpi eq, %[[VAL_57]], %[[VAL_60]] : index
-// CHECK: %[[VAL_62:.*]] = arith.cmpi eq, %[[VAL_58]], %[[VAL_60]] : index
-// CHECK: %[[VAL_63:.*]] = arith.andi %[[VAL_61]], %[[VAL_62]] : i1
-// CHECK: scf.if %[[VAL_63]] {
-// CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_55]]] : memref<?xindex>
-// CHECK: %[[VAL_65:.*]] = arith.addi %[[VAL_55]], %[[VAL_3]] : index
-// CHECK: %[[VAL_66:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_65]]] : memref<?xindex>
-// CHECK: %[[VAL_67:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_56]]] : memref<?xindex>
-// CHECK: %[[VAL_68:.*]] = arith.addi %[[VAL_56]], %[[VAL_3]] : index
-// CHECK: %[[VAL_69:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_68]]] : memref<?xindex>
-// CHECK: %[[VAL_70:.*]]:3 = scf.while (%[[VAL_71:.*]] = %[[VAL_64]], %[[VAL_72:.*]] = %[[VAL_67]], %[[VAL_73:.*]] = %[[VAL_5]]) : (index, index, i32) -> (index, index, i32) {
-// CHECK: %[[VAL_74:.*]] = arith.cmpi ult, %[[VAL_71]], %[[VAL_66]] : index
-// CHECK: %[[VAL_75:.*]] = arith.cmpi ult, %[[VAL_72]], %[[VAL_69]] : index
-// CHECK: %[[VAL_76:.*]] = arith.andi %[[VAL_74]], %[[VAL_75]] : i1
-// CHECK: scf.condition(%[[VAL_76]]) %[[VAL_71]], %[[VAL_72]], %[[VAL_73]] : index, index, i32
+// CHECK: ^bb0(%[[VAL_53:.*]]: index, %[[VAL_54:.*]]: index):
+// CHECK: %[[VAL_55:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_53]]] : memref<?xindex>
+// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_54]]] : memref<?xindex>
+// CHECK: %[[VAL_57:.*]] = arith.cmpi ult, %[[VAL_56]], %[[VAL_55]] : index
+// CHECK: %[[VAL_58:.*]] = arith.select %[[VAL_57]], %[[VAL_56]], %[[VAL_55]] : index
+// CHECK: %[[VAL_59:.*]] = arith.cmpi eq, %[[VAL_55]], %[[VAL_58]] : index
+// CHECK: %[[VAL_60:.*]] = arith.cmpi eq, %[[VAL_56]], %[[VAL_58]] : index
+// CHECK: %[[VAL_61:.*]] = arith.andi %[[VAL_59]], %[[VAL_60]] : i1
+// CHECK: scf.if %[[VAL_61]] {
+// CHECK: %[[VAL_62:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_53]]] : memref<?xindex>
+// CHECK: %[[VAL_63:.*]] = arith.addi %[[VAL_53]], %[[VAL_3]] : index
+// CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_63]]] : memref<?xindex>
+// CHECK: %[[VAL_65:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_54]]] : memref<?xindex>
+// CHECK: %[[VAL_66:.*]] = arith.addi %[[VAL_54]], %[[VAL_3]] : index
+// CHECK: %[[VAL_67:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_66]]] : memref<?xindex>
+// CHECK: %[[VAL_68:.*]]:3 = scf.while (%[[VAL_69:.*]] = %[[VAL_62]], %[[VAL_70:.*]] = %[[VAL_65]], %[[VAL_71:.*]] = %[[VAL_4]]) : (index, index, i32) -> (index, index, i32) {
+// CHECK: %[[VAL_72:.*]] = arith.cmpi ult, %[[VAL_69]], %[[VAL_64]] : index
+// CHECK: %[[VAL_73:.*]] = arith.cmpi ult, %[[VAL_70]], %[[VAL_67]] : index
+// CHECK: %[[VAL_74:.*]] = arith.andi %[[VAL_72]], %[[VAL_73]] : i1
+// CHECK: scf.condition(%[[VAL_74]]) %[[VAL_69]], %[[VAL_70]], %[[VAL_71]] : index, index, i32
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_77:.*]]: index, %[[VAL_78:.*]]: index, %[[VAL_79:.*]]: i32):
-// CHECK: %[[VAL_80:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_77]]] : memref<?xindex>
-// CHECK: %[[VAL_81:.*]] = memref.load %[[VAL_21]]{{\[}}%[[VAL_78]]] : memref<?xindex>
-// CHECK: %[[VAL_82:.*]] = arith.cmpi ult, %[[VAL_81]], %[[VAL_80]] : index
-// CHECK: %[[VAL_83:.*]] = arith.select %[[VAL_82]], %[[VAL_81]], %[[VAL_80]] : index
-// CHECK: memref.store %[[VAL_83]], %[[VAL_23]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK: %[[VAL_84:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_83]] : index
-// CHECK: %[[VAL_85:.*]] = arith.cmpi eq, %[[VAL_81]], %[[VAL_83]] : index
-// CHECK: %[[VAL_86:.*]] = arith.andi %[[VAL_84]], %[[VAL_85]] : i1
-// CHECK: %[[VAL_87:.*]] = scf.if %[[VAL_86]] -> (i32) {
-// CHECK: %[[VAL_88:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_77]]] : memref<?xi32>
-// CHECK: %[[VAL_89:.*]] = memref.load %[[VAL_22]]{{\[}}%[[VAL_78]]] : memref<?xi32>
-// CHECK: %[[VAL_90:.*]] = arith.muli %[[VAL_88]], %[[VAL_89]] : i32
-// CHECK: %[[VAL_91:.*]] = arith.addi %[[VAL_79]], %[[VAL_90]] : i32
-// CHECK: scf.yield %[[VAL_91]] : i32
+// CHECK: ^bb0(%[[VAL_75:.*]]: index, %[[VAL_76:.*]]: index, %[[VAL_77:.*]]: i32):
+// CHECK: %[[VAL_78:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_75]]] : memref<?xindex>
+// CHECK: %[[VAL_79:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_76]]] : memref<?xindex>
+// CHECK: %[[VAL_80:.*]] = arith.cmpi ult, %[[VAL_79]], %[[VAL_78]] : index
+// CHECK: %[[VAL_81:.*]] = arith.select %[[VAL_80]], %[[VAL_79]], %[[VAL_78]] : index
+// CHECK: %[[VAL_82:.*]] = arith.cmpi eq, %[[VAL_78]], %[[VAL_81]] : index
+// CHECK: %[[VAL_83:.*]] = arith.cmpi eq, %[[VAL_79]], %[[VAL_81]] : index
+// CHECK: %[[VAL_84:.*]] = arith.andi %[[VAL_82]], %[[VAL_83]] : i1
+// CHECK: %[[VAL_85:.*]] = scf.if %[[VAL_84]] -> (i32) {
+// CHECK: %[[VAL_86:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_75]]] : memref<?xi32>
+// CHECK: %[[VAL_87:.*]] = memref.load %[[VAL_21]]{{\[}}%[[VAL_76]]] : memref<?xi32>
+// CHECK: %[[VAL_88:.*]] = arith.muli %[[VAL_86]], %[[VAL_87]] : i32
+// CHECK: %[[VAL_89:.*]] = arith.addi %[[VAL_77]], %[[VAL_88]] : i32
+// CHECK: scf.yield %[[VAL_89]] : i32
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_79]] : i32
+// CHECK: scf.yield %[[VAL_77]] : i32
// CHECK: }
-// CHECK: %[[VAL_92:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_83]] : index
-// CHECK: %[[VAL_93:.*]] = arith.addi %[[VAL_77]], %[[VAL_3]] : index
-// CHECK: %[[VAL_94:.*]] = arith.select %[[VAL_92]], %[[VAL_93]], %[[VAL_77]] : index
-// CHECK: %[[VAL_95:.*]] = arith.cmpi eq, %[[VAL_81]], %[[VAL_83]] : index
-// CHECK: %[[VAL_96:.*]] = arith.addi %[[VAL_78]], %[[VAL_3]] : index
-// CHECK: %[[VAL_97:.*]] = arith.select %[[VAL_95]], %[[VAL_96]], %[[VAL_78]] : index
-// CHECK: scf.yield %[[VAL_94]], %[[VAL_97]], %[[VAL_98:.*]] : index, index, i32
+// CHECK: %[[VAL_90:.*]] = arith.cmpi eq, %[[VAL_78]], %[[VAL_81]] : index
+// CHECK: %[[VAL_91:.*]] = arith.addi %[[VAL_75]], %[[VAL_3]] : index
+// CHECK: %[[VAL_92:.*]] = arith.select %[[VAL_90]], %[[VAL_91]], %[[VAL_75]] : index
+// CHECK: %[[VAL_93:.*]] = arith.cmpi eq, %[[VAL_79]], %[[VAL_81]] : index
+// CHECK: %[[VAL_94:.*]] = arith.addi %[[VAL_76]], %[[VAL_3]] : index
+// CHECK: %[[VAL_95:.*]] = arith.select %[[VAL_93]], %[[VAL_94]], %[[VAL_76]] : index
+// CHECK: scf.yield %[[VAL_92]], %[[VAL_95]], %[[VAL_96:.*]] : index, index, i32
// CHECK: }
-// CHECK: memref.store %[[VAL_70]]#2, %[[BUF]][] : memref<i32>
-// CHECK: sparse_tensor.insert %[[VAL_8]], %[[VAL_23]], %[[BUF]] : tensor<?x?xi32, #{{.*}}>, memref<?xindex>, memref<i32>
+// CHECK: sparse_tensor.insert %[[VAL_97:.*]]#2 into %[[VAL_7]]{{\[}}%[[VAL_37]], %[[VAL_58]]] : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
// CHECK: } else {
// CHECK: }
-// CHECK: %[[VAL_100:.*]] = arith.cmpi eq, %[[VAL_57]], %[[VAL_60]] : index
-// CHECK: %[[VAL_101:.*]] = arith.addi %[[VAL_55]], %[[VAL_3]] : index
-// CHECK: %[[VAL_102:.*]] = arith.select %[[VAL_100]], %[[VAL_101]], %[[VAL_55]] : index
-// CHECK: %[[VAL_103:.*]] = arith.cmpi eq, %[[VAL_58]], %[[VAL_60]] : index
-// CHECK: %[[VAL_104:.*]] = arith.addi %[[VAL_56]], %[[VAL_3]] : index
-// CHECK: %[[VAL_105:.*]] = arith.select %[[VAL_103]], %[[VAL_104]], %[[VAL_56]] : index
-// CHECK: scf.yield %[[VAL_102]], %[[VAL_105]] : index, index
+// CHECK: %[[VAL_98:.*]] = arith.cmpi eq, %[[VAL_55]], %[[VAL_58]] : index
+// CHECK: %[[VAL_99:.*]] = arith.addi %[[VAL_53]], %[[VAL_3]] : index
+// CHECK: %[[VAL_100:.*]] = arith.select %[[VAL_98]], %[[VAL_99]], %[[VAL_53]] : index
+// CHECK: %[[VAL_101:.*]] = arith.cmpi eq, %[[VAL_56]], %[[VAL_58]] : index
+// CHECK: %[[VAL_102:.*]] = arith.addi %[[VAL_54]], %[[VAL_3]] : index
+// CHECK: %[[VAL_103:.*]] = arith.select %[[VAL_101]], %[[VAL_102]], %[[VAL_54]] : index
+// CHECK: scf.yield %[[VAL_100]], %[[VAL_103]] : index, index
// CHECK: }
// CHECK: } else {
// CHECK: }
-// CHECK: %[[VAL_106:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_39]] : index
-// CHECK: %[[VAL_107:.*]] = arith.addi %[[VAL_34]], %[[VAL_3]] : index
-// CHECK: %[[VAL_108:.*]] = arith.select %[[VAL_106]], %[[VAL_107]], %[[VAL_34]] : index
-// CHECK: %[[VAL_109:.*]] = arith.cmpi eq, %[[VAL_37]], %[[VAL_39]] : index
-// CHECK: %[[VAL_110:.*]] = arith.addi %[[VAL_35]], %[[VAL_3]] : index
-// CHECK: %[[VAL_111:.*]] = arith.select %[[VAL_109]], %[[VAL_110]], %[[VAL_35]] : index
-// CHECK: scf.yield %[[VAL_108]], %[[VAL_111]] : index, index
+// CHECK: %[[VAL_104:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_37]] : index
+// CHECK: %[[VAL_105:.*]] = arith.addi %[[VAL_32]], %[[VAL_3]] : index
+// CHECK: %[[VAL_106:.*]] = arith.select %[[VAL_104]], %[[VAL_105]], %[[VAL_32]] : index
+// CHECK: %[[VAL_107:.*]] = arith.cmpi eq, %[[VAL_35]], %[[VAL_37]] : index
+// CHECK: %[[VAL_108:.*]] = arith.addi %[[VAL_33]], %[[VAL_3]] : index
+// CHECK: %[[VAL_109:.*]] = arith.select %[[VAL_107]], %[[VAL_108]], %[[VAL_33]] : index
+// CHECK: scf.yield %[[VAL_106]], %[[VAL_109]] : index, index
// CHECK: }
-// CHECK: %[[VAL_112:.*]] = sparse_tensor.load %[[VAL_8]] hasInserts : tensor<?x?xi32, #{{.*}}>
-// CHECK: return %[[VAL_112]] : tensor<?x?xi32, #{{.*}}>
+// CHECK: %[[VAL_110:.*]] = sparse_tensor.load %[[VAL_7]] hasInserts : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK: return %[[VAL_110]] : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
// CHECK: }
func.func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
%argb: tensor<?x?x?xi32, #SparseTensor>) -> tensor<?x?xi32, #DCSR> {
@@ -312,93 +298,90 @@ func.func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
doc = "C(i,j) = SUM_k A(i,k) * B(k,j)"
}
-// CHECK-LABEL: func @matmat(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> {
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant false
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant true
-// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[VAL_3]] : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor(%[[VAL_7]], %[[VAL_8]]) : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK: %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK: %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK: %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK: %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
-// CHECK: %[[VAL_15:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 0 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK: %[[VAL_16:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK: %[[VAL_17:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 1 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK: %[[VAL_18:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
-// CHECK: %[[VAL_19:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
-// CHECK: %[[VAL_20:.*]] = memref.alloca(%[[VAL_4]]) : memref<?xindex>
-// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_21]] to %[[VAL_22]] step %[[VAL_3]] {
-// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]]] : memref<?xindex>
-// CHECK: memref.store %[[VAL_24]], %[[VAL_20]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK: %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]], %[[VAL_28:.*]] = sparse_tensor.expand %[[VAL_9]] : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>, memref<?xi1>, memref<?xindex>, index
-// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_23]]] : memref<?xindex>
-// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_23]], %[[VAL_3]] : index
-// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_30]]] : memref<?xindex>
-// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK: %[[VAL_34:.*]]:3 = scf.while (%[[VAL_35:.*]] = %[[VAL_29]], %[[VAL_36:.*]] = %[[VAL_32]], %[[VAL_37:.*]] = %[[VAL_28]]) : (index, index, index) -> (index, index, index) {
-// CHECK: %[[VAL_38:.*]] = arith.cmpi ult, %[[VAL_35]], %[[VAL_31]] : index
-// CHECK: %[[VAL_39:.*]] = arith.cmpi ult, %[[VAL_36]], %[[VAL_33]] : index
-// CHECK: %[[VAL_40:.*]] = arith.andi %[[VAL_38]], %[[VAL_39]] : i1
-// CHECK: scf.condition(%[[VAL_40]]) %[[VAL_35]], %[[VAL_36]], %[[VAL_37]] : index, index, index
+// CHECK-LABEL: func.func @matmat(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) -> tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant false
+// CHECK: %[[VAL_5:.*]] = arith.constant true
+// CHECK: %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_1]], %[[VAL_3]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK: %[[VAL_8:.*]] = bufferization.alloc_tensor(%[[VAL_6]], %[[VAL_7]]) : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xf32>
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 0 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_15:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_16:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 1 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_17:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index} : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xf32>
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_19]] to %[[VAL_20]] step %[[VAL_3]] {
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref<?xindex>
+// CHECK: %[[VAL_23:.*]], %[[VAL_24:.*]], %[[VAL_25:.*]], %[[VAL_26:.*]] = sparse_tensor.expand %[[VAL_8]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xf32>, memref<?xi1>, memref<?xindex>
+// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]]] : memref<?xindex>
+// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_21]], %[[VAL_3]] : index
+// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_28]]] : memref<?xindex>
+// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_32:.*]]:3 = scf.while (%[[VAL_33:.*]] = %[[VAL_27]], %[[VAL_34:.*]] = %[[VAL_30]], %[[VAL_35:.*]] = %[[VAL_26]]) : (index, index, index) -> (index, index, index) {
+// CHECK: %[[VAL_36:.*]] = arith.cmpi ult, %[[VAL_33]], %[[VAL_29]] : index
+// CHECK: %[[VAL_37:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_31]] : index
+// CHECK: %[[VAL_38:.*]] = arith.andi %[[VAL_36]], %[[VAL_37]] : i1
+// CHECK: scf.condition(%[[VAL_38]]) %[[VAL_33]], %[[VAL_34]], %[[VAL_35]] : index, index, index
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_41:.*]]: index, %[[VAL_42:.*]]: index, %[[VAL_43:.*]]: index):
-// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_41]]] : memref<?xindex>
-// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_42]]] : memref<?xindex>
-// CHECK: %[[VAL_46:.*]] = arith.cmpi ult, %[[VAL_45]], %[[VAL_44]] : index
-// CHECK: %[[VAL_47:.*]] = arith.select %[[VAL_46]], %[[VAL_45]], %[[VAL_44]] : index
-// CHECK: %[[VAL_48:.*]] = arith.cmpi eq, %[[VAL_44]], %[[VAL_47]] : index
-// CHECK: %[[VAL_49:.*]] = arith.cmpi eq, %[[VAL_45]], %[[VAL_47]] : index
-// CHECK: %[[VAL_50:.*]] = arith.andi %[[VAL_48]], %[[VAL_49]] : i1
-// CHECK: %[[VAL_51:.*]] = scf.if %[[VAL_50]] -> (index) {
-// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_41]]] : memref<?xf32>
-// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_42]]] : memref<?xindex>
-// CHECK: %[[VAL_54:.*]] = arith.addi %[[VAL_42]], %[[VAL_3]] : index
-// CHECK: %[[VAL_55:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_54]]] : memref<?xindex>
-// CHECK: %[[VAL_56:.*]] = scf.for %[[VAL_57:.*]] = %[[VAL_53]] to %[[VAL_55]] step %[[VAL_3]] iter_args(%[[VAL_58:.*]] = %[[VAL_43]]) -> (index) {
-// CHECK: %[[VAL_59:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_57]]] : memref<?xindex>
-// CHECK: %[[VAL_60:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_59]]] : memref<?xf32>
-// CHECK: %[[VAL_61:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_57]]] : memref<?xf32>
-// CHECK: %[[VAL_62:.*]] = arith.mulf %[[VAL_52]], %[[VAL_61]] : f32
-// CHECK: %[[VAL_63:.*]] = arith.addf %[[VAL_60]], %[[VAL_62]] : f32
-// CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_59]]] : memref<?xi1>
-// CHECK: %[[VAL_65:.*]] = arith.cmpi eq, %[[VAL_64]], %[[VAL_5]] : i1
-// CHECK: %[[VAL_66:.*]] = scf.if %[[VAL_65]] -> (index) {
-// CHECK: memref.store %[[VAL_6]], %[[VAL_26]]{{\[}}%[[VAL_59]]] : memref<?xi1>
-// CHECK: memref.store %[[VAL_59]], %[[VAL_27]]{{\[}}%[[VAL_58]]] : memref<?xindex>
-// CHECK: %[[VAL_67:.*]] = arith.addi %[[VAL_58]], %[[VAL_3]] : index
-// CHECK: scf.yield %[[VAL_67]] : index
+// CHECK: ^bb0(%[[VAL_39:.*]]: index, %[[VAL_40:.*]]: index, %[[VAL_41:.*]]: index):
+// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_39]]] : memref<?xindex>
+// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_40]]] : memref<?xindex>
+// CHECK: %[[VAL_44:.*]] = arith.cmpi ult, %[[VAL_43]], %[[VAL_42]] : index
+// CHECK: %[[VAL_45:.*]] = arith.select %[[VAL_44]], %[[VAL_43]], %[[VAL_42]] : index
+// CHECK: %[[VAL_46:.*]] = arith.cmpi eq, %[[VAL_42]], %[[VAL_45]] : index
+// CHECK: %[[VAL_47:.*]] = arith.cmpi eq, %[[VAL_43]], %[[VAL_45]] : index
+// CHECK: %[[VAL_48:.*]] = arith.andi %[[VAL_46]], %[[VAL_47]] : i1
+// CHECK: %[[VAL_49:.*]] = scf.if %[[VAL_48]] -> (index) {
+// CHECK: %[[VAL_50:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_39]]] : memref<?xf32>
+// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_40]]] : memref<?xindex>
+// CHECK: %[[VAL_52:.*]] = arith.addi %[[VAL_40]], %[[VAL_3]] : index
+// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_52]]] : memref<?xindex>
+// CHECK: %[[VAL_54:.*]] = scf.for %[[VAL_55:.*]] = %[[VAL_51]] to %[[VAL_53]] step %[[VAL_3]] iter_args(%[[VAL_56:.*]] = %[[VAL_41]]) -> (index) {
+// CHECK: %[[VAL_57:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_55]]] : memref<?xindex>
+// CHECK: %[[VAL_58:.*]] = memref.load %[[VAL_23]]{{\[}}%[[VAL_57]]] : memref<?xf32>
+// CHECK: %[[VAL_59:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_55]]] : memref<?xf32>
+// CHECK: %[[VAL_60:.*]] = arith.mulf %[[VAL_50]], %[[VAL_59]] : f32
+// CHECK: %[[VAL_61:.*]] = arith.addf %[[VAL_58]], %[[VAL_60]] : f32
+// CHECK: %[[VAL_62:.*]] = memref.load %[[VAL_24]]{{\[}}%[[VAL_57]]] : memref<?xi1>
+// CHECK: %[[VAL_63:.*]] = arith.cmpi eq, %[[VAL_62]], %[[VAL_4]] : i1
+// CHECK: %[[VAL_64:.*]] = scf.if %[[VAL_63]] -> (index) {
+// CHECK: memref.store %[[VAL_5]], %[[VAL_24]]{{\[}}%[[VAL_57]]] : memref<?xi1>
+// CHECK: memref.store %[[VAL_57]], %[[VAL_25]]{{\[}}%[[VAL_56]]] : memref<?xindex>
+// CHECK: %[[VAL_65:.*]] = arith.addi %[[VAL_56]], %[[VAL_3]] : index
+// CHECK: scf.yield %[[VAL_65]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_58]] : index
+// CHECK: scf.yield %[[VAL_56]] : index
// CHECK: }
-// CHECK: memref.store %[[VAL_63]], %[[VAL_25]]{{\[}}%[[VAL_59]]] : memref<?xf32>
-// CHECK: scf.yield %[[VAL_68:.*]] : index
+// CHECK: memref.store %[[VAL_61]], %[[VAL_23]]{{\[}}%[[VAL_57]]] : memref<?xf32>
+// CHECK: scf.yield %[[VAL_66:.*]] : index
// CHECK: }
-// CHECK: scf.yield %[[VAL_69:.*]] : index
+// CHECK: scf.yield %[[VAL_67:.*]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_43]] : index
+// CHECK: scf.yield %[[VAL_41]] : index
// CHECK: }
-// CHECK: %[[VAL_70:.*]] = arith.cmpi eq, %[[VAL_44]], %[[VAL_47]] : index
-// CHECK: %[[VAL_71:.*]] = arith.addi %[[VAL_41]], %[[VAL_3]] : index
-// CHECK: %[[VAL_72:.*]] = arith.select %[[VAL_70]], %[[VAL_71]], %[[VAL_41]] : index
-// CHECK: %[[VAL_73:.*]] = arith.cmpi eq, %[[VAL_45]], %[[VAL_47]] : index
-// CHECK: %[[VAL_74:.*]] = arith.addi %[[VAL_42]], %[[VAL_3]] : index
-// CHECK: %[[VAL_75:.*]] = arith.select %[[VAL_73]], %[[VAL_74]], %[[VAL_42]] : index
-// CHECK: scf.yield %[[VAL_72]], %[[VAL_75]], %[[VAL_76:.*]] : index, index, index
+// CHECK: %[[VAL_68:.*]] = arith.cmpi eq, %[[VAL_42]], %[[VAL_45]] : index
+// CHECK: %[[VAL_69:.*]] = arith.addi %[[VAL_39]], %[[VAL_3]] : index
+// CHECK: %[[VAL_70:.*]] = arith.select %[[VAL_68]], %[[VAL_69]], %[[VAL_39]] : index
+// CHECK: %[[VAL_71:.*]] = arith.cmpi eq, %[[VAL_43]], %[[VAL_45]] : index
+// CHECK: %[[VAL_72:.*]] = arith.addi %[[VAL_40]], %[[VAL_3]] : index
+// CHECK: %[[VAL_73:.*]] = arith.select %[[VAL_71]], %[[VAL_72]], %[[VAL_40]] : index
+// CHECK: scf.yield %[[VAL_70]], %[[VAL_73]], %[[VAL_74:.*]] : index, index, index
// CHECK: }
-// CHECK: sparse_tensor.compress %[[VAL_9]], %[[VAL_20]], %[[VAL_25]], %[[VAL_26]], %[[VAL_27]], %[[VAL_77:.*]]#2 : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>, memref<?xindex>, memref<?xf32>, memref<?xi1>, memref<?xindex>, index
+// CHECK: sparse_tensor.compress %[[VAL_23]], %[[VAL_24]], %[[VAL_25]], %[[VAL_75:.*]]#2 into %[[VAL_8]]{{\[}}%[[VAL_22]]] : memref<?xf32>, memref<?xi1>, memref<?xindex>, tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
// CHECK: }
-// CHECK: %[[VAL_78:.*]] = sparse_tensor.load %[[VAL_9]] hasInserts : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: return %[[VAL_78]] : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_76:.*]] = sparse_tensor.load %[[VAL_8]] hasInserts : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK: return %[[VAL_76]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
// CHECK: }
func.func @matmat(%arga: tensor<?x?xf32, #DCSR>,
%argb: tensor<?x?xf32, #DCSR>) -> tensor<?x?xf32, #DCSR> {
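For reference: the rewritten CHECK lines above exercise the new SSA-based form of
sparse_tensor.compress, which passes the expanded-access buffers and the insertion
count as operands and names the destination tensor and its outer indices directly.
A minimal sketch, with illustrative value names only:

  sparse_tensor.compress %values, %filled, %added, %count into %dest[%i]
    : memref<?xf32>, memref<?xi1>, memref<?xindex>, tensor<?x?xf32, #DCSR>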
diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
index 5d52fa5256502..d62604556f4a3 100755
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -123,66 +123,63 @@ func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
return %3 : tensor<8x8xf64>
}
-
-// CHECK-LABEL: func @sparse_sampled_dd_unfused(
-// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding
-// CHECK-SAME: %[[TMP_arg1:.*]]: tensor<8x8xf64>,
-// CHECK-SAME: %[[TMP_arg2:.*]]: tensor<8x8xf64>)
-// CHECK-DAG: %[[TMP_c8:.*]] = arith.constant 8 : index
-// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[TMP_false:.*]] = arith.constant false
-// CHECK-DAG: %[[TMP_true:.*]] = arith.constant true
-// CHECK-DAG: %[[TMP_cst:.*]] = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
-// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor() copy(%[[TMP_cst]]) {bufferization.escape = [false]}
-// CHECK: %[[TMP_1:.*]] = bufferization.alloc_tensor() {bufferization.escape = [false]}
-// CHECK: %[[TMP_2:.*]] = bufferization.to_memref %[[TMP_arg1]] : memref<8x8xf64>
-// CHECK: %[[TMP_3:.*]] = bufferization.to_memref %[[TMP_arg2]] : memref<8x8xf64>
-// CHECK: %[[TMP_4:.*]] = sparse_tensor.pointers %[[TMP_arg0]] {dimension = 0 : index}
-// CHECK: %[[TMP_5:.*]] = sparse_tensor.indices %[[TMP_arg0]] {dimension = 0 : index}
-// CHECK: %[[TMP_6:.*]] = sparse_tensor.pointers %[[TMP_arg0]] {dimension = 1 : index}
-// CHECK: %[[TMP_7:.*]] = sparse_tensor.indices %[[TMP_arg0]] {dimension = 1 : index}
-// CHECK: %[[TMP_8:.*]] = sparse_tensor.values %[[TMP_arg0]]
-// CHECK: %[[TMP_9:.*]] = memref.alloca(%[[TMP_c2]]) : memref<?xindex>
-// CHECK: %[[TMP_10:.*]] = memref.load %[[TMP_4]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_11:.*]] = memref.load %[[TMP_4]][%[[TMP_c1]]] : memref<?xindex>
-// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_10]] to %[[TMP_11]] step %[[TMP_c1]] {
-// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_5]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK: memref.store %[[TMP_13]], %[[TMP_9]][%[[TMP_c0]]] : memref<?xindex>
-// CHECK: %[[TMP_values:.*]], %[[TMP_filled:.*]], %[[TMP_added:.*]], %[[TMP_count:.*]] = sparse_tensor.expand %[[TMP_1]]
-// CHECK: %[[TMP_14:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_c0]] to %[[TMP_c8]] step %[[TMP_c1]] iter_args(%[[TMP_arg5:.*]] = %[[TMP_count]]) -> (index) {
-// CHECK: %[[TMP_15:.*]] = memref.load %[[TMP_2]][%[[TMP_13]], %[[TMP_arg4]]] : memref<8x8xf64>
-// CHECK: %[[TMP_16:.*]] = memref.load %[[TMP_6]][%[[TMP_arg3]]] : memref<?xindex>
-// CHECK: %[[TMP_17:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_18:.*]] = memref.load %[[TMP_6]][%[[TMP_17]]] : memref<?xindex>
-// CHECK: %[[TMP_19:.*]] = scf.for %[[TMP_arg6:.*]] = %[[TMP_16]] to %[[TMP_18]] step %[[TMP_c1]] iter_args(%[[TMP_arg7:.*]] = %[[TMP_arg5]]) -> (index) {
-// CHECK: %[[TMP_20:.*]] = memref.load %[[TMP_7]][%[[TMP_arg6]]] : memref<?xindex>
-// CHECK: %[[TMP_21:.*]] = memref.load %[[TMP_values]][%[[TMP_20]]] : memref<?xf64>
-// CHECK: %[[TMP_22:.*]] = memref.load %[[TMP_3]][%[[TMP_arg4]], %[[TMP_20]]] : memref<8x8xf64>
-// CHECK: %[[TMP_23:.*]] = arith.mulf %[[TMP_15]], %[[TMP_22]] : f64
-// CHECK: %[[TMP_24:.*]] = memref.load %[[TMP_8]][%[[TMP_arg6]]] : memref<?xf64>
-// CHECK: %[[TMP_25:.*]] = arith.mulf %[[TMP_23]], %[[TMP_24]] : f64
-// CHECK: %[[TMP_26:.*]] = arith.addf %[[TMP_21]], %[[TMP_25]] : f64
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_filled]][%[[TMP_20]]] : memref<?xi1>
-// CHECK: %[[TMP_28:.*]] = arith.cmpi eq, %[[TMP_27]], %[[TMP_false]] : i1
-// CHECK: %[[TMP_29:.*]] = scf.if %[[TMP_28]] -> (index) {
-// CHECK: memref.store %[[TMP_true]], %[[TMP_filled]][%[[TMP_20]]] : memref<?xi1>
-// CHECK: memref.store %[[TMP_20]], %[[TMP_added]][%[[TMP_arg7]]] : memref<?xindex>
-// CHECK: %[[TMP_30:.*]] = arith.addi %[[TMP_arg7]], %[[TMP_c1]] : index
-// CHECK: scf.yield %[[TMP_30]] : index
-// CHECK: } else {
-// CHECK: scf.yield %[[TMP_arg7]] : index
-// CHECK: }
-// CHECK: memref.store %[[TMP_26]], %[[TMP_values]][%[[TMP_20]]] : memref<?xf64>
-// CHECK: scf.yield %[[TMP_29]] : index
-// CHECK: }
-// CHECK: scf.yield %[[TMP_19]] : index
-// CHECK: }
-// CHECK: sparse_tensor.compress %[[TMP_1]], %[[TMP_9]], %[[TMP_values]], %[[TMP_filled]], %[[TMP_added]], %[[TMP_14]]
-// CHECK: }
-// CHECK: %[[TMP_12:.*]] = sparse_tensor.load %[[TMP_1]] hasInserts
-// CHECK: return %[[TMP_12]] : tensor<8x8xf64, #sparse_tensor.encoding
+// CHECK-LABEL: func.func @sparse_sampled_dd_unfused(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8xf64>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant false
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant true
+// CHECK-DAG: %[[VAL_8:.*]] = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
+// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor() copy(%[[VAL_8]]) {bufferization.escape = [false]} : tensor<8x8xf64>
+// CHECK: %[[VAL_10:.*]] = bufferization.alloc_tensor() {bufferization.escape = [false]} : tensor<8x8xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
+// CHECK: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64>
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_15:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_16:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xf64>
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_20:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] {
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]], %[[VAL_23:.*]], %[[VAL_24:.*]], %[[VAL_25:.*]] = sparse_tensor.expand %[[VAL_10]] : tensor<8x8xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref<?xf64>, memref<?xi1>, memref<?xindex>
+// CHECK: %[[VAL_26:.*]] = scf.for %[[VAL_27:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] iter_args(%[[VAL_28:.*]] = %[[VAL_25]]) -> (index) {
+// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_27]]] : memref<8x8xf64>
+// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_20]], %[[VAL_5]] : index
+// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_31]]] : memref<?xindex>
+// CHECK: %[[VAL_33:.*]] = scf.for %[[VAL_34:.*]] = %[[VAL_30]] to %[[VAL_32]] step %[[VAL_5]] iter_args(%[[VAL_35:.*]] = %[[VAL_28]]) -> (index) {
+// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_34]]] : memref<?xindex>
+// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_22]]{{\[}}%[[VAL_36]]] : memref<?xf64>
+// CHECK: %[[VAL_38:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_27]], %[[VAL_36]]] : memref<8x8xf64>
+// CHECK: %[[VAL_39:.*]] = arith.mulf %[[VAL_29]], %[[VAL_38]] : f64
+// CHECK: %[[VAL_40:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_34]]] : memref<?xf64>
+// CHECK: %[[VAL_41:.*]] = arith.mulf %[[VAL_39]], %[[VAL_40]] : f64
+// CHECK: %[[VAL_42:.*]] = arith.addf %[[VAL_37]], %[[VAL_41]] : f64
+// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_23]]{{\[}}%[[VAL_36]]] : memref<?xi1>
+// CHECK: %[[VAL_44:.*]] = arith.cmpi eq, %[[VAL_43]], %[[VAL_6]] : i1
+// CHECK: %[[VAL_45:.*]] = scf.if %[[VAL_44]] -> (index) {
+// CHECK: memref.store %[[VAL_7]], %[[VAL_23]]{{\[}}%[[VAL_36]]] : memref<?xi1>
+// CHECK: memref.store %[[VAL_36]], %[[VAL_24]]{{\[}}%[[VAL_35]]] : memref<?xindex>
+// CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_35]], %[[VAL_5]] : index
+// CHECK: scf.yield %[[VAL_46]] : index
+// CHECK: } else {
+// CHECK: scf.yield %[[VAL_35]] : index
+// CHECK: }
+// CHECK: memref.store %[[VAL_42]], %[[VAL_22]]{{\[}}%[[VAL_36]]] : memref<?xf64>
+// CHECK: scf.yield %[[VAL_47:.*]] : index
+// CHECK: }
+// CHECK: scf.yield %[[VAL_48:.*]] : index
+// CHECK: }
+// CHECK: sparse_tensor.compress %[[VAL_22]], %[[VAL_23]], %[[VAL_24]], %[[VAL_49:.*]] into %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK: }
+// CHECK: %[[VAL_50:.*]] = sparse_tensor.load %[[VAL_10]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK: return %[[VAL_50]] : tensor<8x8xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK: }
func.func @sparse_sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
%arga: tensor<8x8xf64>,
%argb: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_transpose.mlir b/mlir/test/Dialect/SparseTensor/sparse_transpose.mlir
index 3fed42cf6418f..8b8b1c00e1762 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_transpose.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_transpose.mlir
@@ -16,38 +16,32 @@
// TODO: improve auto-conversion followed by yield
// CHECK-LABEL: func.func @sparse_transpose_auto(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) -> tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) -> tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> {
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_4:.*]] = bufferization.alloc_tensor() : tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
-// CHECK: %[[VAL_5:.*]] = sparse_tensor.convert %[[VAL_0]] : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)> }>>
-// CHECK: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_5]] {dimension = 0 : index} : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)> }>> to memref<?xindex>
-// CHECK: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_5]] {dimension = 0 : index} : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)> }>> to memref<?xindex>
-// CHECK: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_5]] {dimension = 1 : index} : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)> }>> to memref<?xindex>
-// CHECK: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_5]] {dimension = 1 : index} : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)> }>> to memref<?xindex>
-// CHECK: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_5]] : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)> }>> to memref<?xf64>
-// CHECK: %[[VAL_11:.*]] = memref.alloca(%[[VAL_3]]) : memref<?xindex>
-// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<f64>
-// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_1]]] : memref<?xindex>
-// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_2]] {
+// CHECK: %[[VAL_3:.*]] = bufferization.alloc_tensor() : tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK: %[[VAL_4:.*]] = sparse_tensor.convert %[[VAL_0]] : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)> }>>
+// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_4]] {dimension = 0 : index} : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)> }>> to memref<?xindex>
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_4]] {dimension = 0 : index} : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)> }>> to memref<?xindex>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_4]] {dimension = 1 : index} : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)> }>> to memref<?xindex>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_4]] {dimension = 1 : index} : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)> }>> to memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_4]] : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)> }>> to memref<?xf64>
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_1]]] : memref<?xindex>
+// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] {
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_12]], %[[VAL_2]] : index
// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
-// CHECK: memref.store %[[VAL_16]], %[[VAL_11]]{{\[}}%[[VAL_1]]] : memref<?xindex>
-// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xindex>
-// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_2]] : index
-// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_19]] step %[[VAL_2]] {
-// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]]] : memref<?xindex>
-// CHECK: memref.store %[[VAL_21]], %[[VAL_11]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_20]]] : memref<?xf64>
-// CHECK: memref.store %[[VAL_22]], %[[VAL_12]][] : memref<f64>
-// CHECK: sparse_tensor.insert %[[VAL_4]], %[[VAL_11]], %[[VAL_12]] : tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>, memref<?xindex>, memref<f64>
+// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_14]] to %[[VAL_16]] step %[[VAL_2]] {
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xindex>
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_17]]] : memref<?xf64>
+// CHECK: sparse_tensor.insert %[[VAL_19]] into %[[VAL_3]]{{\[}}%[[VAL_13]], %[[VAL_18]]] : tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
// CHECK: }
// CHECK: }
-// CHECK: %[[VAL_23:.*]] = sparse_tensor.load %[[VAL_4]] hasInserts : tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
-// CHECK: bufferization.dealloc_tensor %[[VAL_5]] : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)> }>>
-// CHECK: return %[[VAL_23]] : tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK: %[[VAL_20:.*]] = sparse_tensor.load %[[VAL_3]] hasInserts : tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK: bufferization.dealloc_tensor %[[VAL_4]] : tensor<3x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)> }>>
+// CHECK: return %[[VAL_20]] : tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
// CHECK: }
func.func @sparse_transpose_auto(%arga: tensor<3x4xf64, #DCSR>)
-> tensor<4x3xf64, #DCSR> {
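For reference: the rewritten CHECK line above exercises the new SSA-based form of
sparse_tensor.insert, which takes the value and the destination indices as operands
instead of a memref<?xindex> of indices. A minimal sketch, with illustrative value
names only:

  sparse_tensor.insert %v into %dest[%i, %j] : tensor<4x3xf64, #DCSR>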