[Mlir-commits] [mlir] 2aceadd - [mlir][sparse] Put the implementation for the insertion operation into subroutines.
llvmlistbot at llvm.org
Thu Dec 1 13:24:04 PST 2022
Author: bixia1
Date: 2022-12-01T13:23:59-08:00
New Revision: 2aceadda78209f3ab0f69c8c277f4051fa161644
URL: https://github.com/llvm/llvm-project/commit/2aceadda78209f3ab0f69c8c277f4051fa161644
DIFF: https://github.com/llvm/llvm-project/commit/2aceadda78209f3ab0f69c8c277f4051fa161644.diff
LOG: [mlir][sparse] Put the implementation for the insertion operation into subroutines.
Previously, we generated an inlined implementation for the insert operation and
observed an increase in MLIR compile time due to the size of the main routine.
We now put the insert operation implementation in subroutines and leave the
inlining decision to the MLIR compiler.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D138957
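
For readers who only want the gist: each generated subroutine gets a mangled
name that encodes the per-dimension level types, the static shape, the element
type, and the index/pointer bit widths (e.g. _insert_D_C_8_8_f64_64_32 for a
CSR 8x8 f64 tensor with 64-bit indices and 32-bit pointers, as seen in the
updated tests). A minimal, self-contained sketch of that mangling in plain C++
follows; the enum and function names here are illustrative stand-ins, not part
of the patch, and the non-identity dimension-ordering component appended by
the real helper is omitted.

  #include <cstdint>
  #include <sstream>
  #include <string>
  #include <vector>

  // Illustrative stand-in for the per-dimension level types queried by
  // isCompressedDim/isSingletonDim in the patch.
  enum class DimLevel { Dense, Compressed, Singleton };

  // Builds a name of the form
  //   _insert_[C|S|D]..._<shape>_<eltType>_<indexBitWidth>_<pointerBitWidth>
  std::string mangleInsertFuncName(const std::vector<DimLevel> &dims,
                                   const std::vector<int64_t> &shape,
                                   const std::string &eltType,
                                   unsigned indexBitWidth,
                                   unsigned pointerBitWidth) {
    std::ostringstream os;
    os << "_insert_";
    // One letter per dimension: C(ompressed), S(ingleton), or D(ense).
    for (DimLevel d : dims)
      os << (d == DimLevel::Compressed  ? "C_"
             : d == DimLevel::Singleton ? "S_"
                                        : "D_");
    // Static sizes are baked into the generated code, so they take part in
    // the name; dynamic sizes are loaded from the dimSizes buffer at runtime.
    for (int64_t s : shape)
      os << s << "_";
    os << eltType << "_" << indexBitWidth << "_" << pointerBitWidth;
    return os.str();
  }

  // mangleInsertFuncName({DimLevel::Dense, DimLevel::Compressed}, {8, 8},
  //                      "f64", 64, 32) yields "_insert_D_C_8_8_f64_64_32".
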
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir
mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
Removed:
################################################################################
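
The other half of the change is the call-site helper: look up the mangled
symbol in the module, create it as a private function and populate its body on
a miss, then call it and thread the returned fields back to the caller. A
condensed, simplified sketch of that pattern using standard func-dialect APIs
is shown here before the patch itself; it is illustrative only (the committed
helper is genInsertionCallHelper below, which inserts the new function right
before the calling function rather than at the start of the module).

  #include "mlir/Dialect/Func/IR/FuncOps.h"
  #include "mlir/IR/Builders.h"
  #include "mlir/IR/BuiltinOps.h"
  #include "llvm/ADT/STLFunctionalExtras.h"

  using namespace mlir;

  // Look up `name` in the module; if it does not exist yet, create a private
  // function and let `bodyGen` fill in its body, then emit a call to it.
  static func::CallOp getOrCreateAndCall(
      OpBuilder &builder, ModuleOp module, Location loc, StringRef name,
      ValueRange operands, TypeRange results,
      llvm::function_ref<void(OpBuilder &, func::FuncOp)> bodyGen) {
    auto func = module.lookupSymbol<func::FuncOp>(name);
    if (!func) {
      // Create the function at module scope without disturbing the current
      // insertion point of the caller's body.
      OpBuilder::InsertionGuard guard(builder);
      builder.setInsertionPointToStart(module.getBody());
      func = builder.create<func::FuncOp>(
          loc, name,
          FunctionType::get(builder.getContext(), operands.getTypes(),
                            results));
      func.setPrivate();
      bodyGen(builder, func);
    }
    return builder.create<func::CallOp>(loc, func, operands);
  }
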
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 775b20dab1751..97f1f952e5bd5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -31,6 +31,11 @@ using namespace mlir::sparse_tensor;
namespace {
+using FuncGeneratorType =
+ function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, RankedTensorType)>;
+
+static constexpr const char kInsertFuncNamePrefix[] = "_insert_";
+
static constexpr uint64_t dimSizesIdx = 0;
static constexpr uint64_t memSizesIdx = 1;
static constexpr uint64_t fieldsIdx = 2;
@@ -476,12 +481,24 @@ static Value genCompressed(OpBuilder &builder, Location loc,
///
/// TODO: better unord/not-unique; also generalize, optimize, specialize!
///
-static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
- SmallVectorImpl<Value> &fields,
- SmallVectorImpl<Value> &indices, Value value) {
+static void genInsertBody(OpBuilder &builder, ModuleOp module,
+ func::FuncOp func, RankedTensorType rtp) {
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ Block *entryBlock = func.addEntryBlock();
+ builder.setInsertionPointToStart(entryBlock);
+
+ Location loc = func.getLoc();
+ ValueRange args = entryBlock->getArguments();
unsigned rank = rtp.getShape().size();
- assert(rank == indices.size());
- unsigned field = fieldsIdx; // start past header
+
+ // Construct fields and indices arrays from parameters.
+ ValueRange tmp = args.drop_back(rank + 1);
+ SmallVector<Value> fields(tmp.begin(), tmp.end());
+ tmp = args.take_back(rank + 1).drop_back();
+ SmallVector<Value> indices(tmp.begin(), tmp.end());
+ Value value = args.back();
+
+ unsigned field = fieldsIdx; // Start past header.
Value pos = constantZero(builder, loc, builder.getIndexType());
// Generate code for every dimension.
for (unsigned d = 0; d < rank; d++) {
@@ -519,6 +536,77 @@ static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
else
genStore(builder, loc, value, fields[field++], pos);
assert(fields.size() == field);
+ builder.create<func::ReturnOp>(loc, fields);
+}
+
+/// Generates a call to a function to perform an insertion operation. If the
+/// function doesn't exist yet, call `createFunc` to generate the function.
+static void genInsertionCallHelper(OpBuilder &builder, RankedTensorType rtp,
+ SmallVectorImpl<Value> &fields,
+ SmallVectorImpl<Value> &indices, Value value,
+ func::FuncOp insertPoint,
+ StringRef namePrefix,
+ FuncGeneratorType createFunc) {
+ // The mangled name of the function has this format:
+ // <namePrefix>_[C|S|D]_<shape>_<ordering>_<eltType>
+ // _<indexBitWidth>_<pointerBitWidth>
+ SmallString<32> nameBuffer;
+ llvm::raw_svector_ostream nameOstream(nameBuffer);
+ nameOstream << namePrefix;
+ unsigned rank = rtp.getShape().size();
+ assert(rank == indices.size());
+ for (unsigned d = 0; d < rank; d++) {
+ if (isCompressedDim(rtp, d)) {
+ nameOstream << "C_";
+ } else if (isSingletonDim(rtp, d)) {
+ nameOstream << "S_";
+ } else {
+ nameOstream << "D_";
+ }
+ }
+ // Static dim sizes are used in the generated code while dynamic sizes are
+ // loaded from the dimSizes buffer. This is the reason for adding the shape
+ // to the function name.
+ for (auto d : rtp.getShape())
+ nameOstream << d << "_";
+ SparseTensorEncodingAttr enc = getSparseTensorEncoding(rtp);
+ // Permutation information is also used in generating insertion.
+ if (enc.getDimOrdering() && !enc.getDimOrdering().isIdentity())
+ nameOstream << enc.getDimOrdering() << "_";
+ nameOstream << rtp.getElementType() << "_";
+ nameOstream << enc.getIndexBitWidth() << "_" << enc.getPointerBitWidth();
+
+ // Look up the function.
+ ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
+ MLIRContext *context = module.getContext();
+ auto result = SymbolRefAttr::get(context, nameOstream.str());
+ auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
+
+ // Construct parameters for fields and indices.
+ SmallVector<Value> operands(fields.begin(), fields.end());
+ operands.append(indices.begin(), indices.end());
+ operands.push_back(value);
+ Location loc = insertPoint.getLoc();
+
+ if (!func) {
+ // Create the function.
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ builder.setInsertionPoint(insertPoint);
+
+ func = builder.create<func::FuncOp>(
+ loc, nameOstream.str(),
+ FunctionType::get(context, ValueRange(operands).getTypes(),
+ ValueRange(fields).getTypes()));
+ func.setPrivate();
+ createFunc(builder, module, func, rtp);
+ }
+
+ // Generate a call to perform the insertion and update `fields` with values
+ // returned from the call.
+ func::CallOp call = builder.create<func::CallOp>(loc, func, operands);
+ for (size_t i = 0; i < fields.size(); i++) {
+ fields[i] = call.getResult(i);
+ }
}
/// Generates insertion finalization code.
@@ -865,7 +953,9 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
Value value = genLoad(rewriter, loc, values, index);
indices.push_back(index);
// TODO: faster for subsequent insertions?
- genInsert(rewriter, loc, dstType, fields, indices, value);
+ auto insertPoint = op->template getParentOfType<func::FuncOp>();
+ genInsertionCallHelper(rewriter, dstType, fields, indices, value,
+ insertPoint, kInsertFuncNamePrefix, genInsertBody);
genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values,
index);
genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, index);
@@ -899,7 +989,10 @@ class SparseInsertConverter : public OpConversionPattern<InsertOp> {
SmallVector<Value> indices(adaptor.getIndices());
// Generate insertion.
Value value = adaptor.getValue();
- genInsert(rewriter, op->getLoc(), dstType, fields, indices, value);
+ auto insertPoint = op->template getParentOfType<func::FuncOp>();
+ genInsertionCallHelper(rewriter, dstType, fields, indices, value,
+ insertPoint, kInsertFuncNamePrefix, genInsertBody);
+
// Replace operation with resulting memrefs.
rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), dstType, fields));
return success();
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 660cd7552db8a..0e3eb03bda78c 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -381,6 +381,17 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
return %added : memref<?xindex>
}
+// CHECK-LABEL: func.func private @_insert_C_100_f64_0_0(
+// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: index,
+// CHECK-SAME: %[[A6:.*6]]: f64)
+// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]]
+//
// CHECK-LABEL: func @sparse_compression_1d(
// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
@@ -396,18 +407,19 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
-// CHECK: %[[R:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] iter_args(%[[P0:.*]] = %[[A3]], %[[P1:.*]] = %[[A4]]) -> (memref<?xindex>, memref<?xf64>) {
+// CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
// CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
// CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref<?xf64>
-// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[VAL]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: %[[C:.*]]:5 = func.call @_insert_C_100_f64_0_0(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[INDEX]], %[[VAL]])
// CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
// CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
-// CHECK: scf.yield %{{.*}}, %[[PV]] : memref<?xindex>, memref<?xf64>
+// CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
// CHECK: }
// CHECK: memref.dealloc %[[A5]] : memref<?xf64>
// CHECK: memref.dealloc %[[A6]] : memref<?xi1>
// CHECK: memref.dealloc %[[A7]] : memref<?xindex>
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[R]]#0, %[[R]]#1
+// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
%values: memref<?xf64>,
@@ -420,6 +432,18 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
return %1 : tensor<100xf64, #SV>
}
+// CHECK-LABEL: func.func private @_insert_D_C_8_8_f64_64_32(
+// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: index,
+// CHECK-SAME: %[[A6:.*6]]: index,
+// CHECK-SAME: %[[A7:.*7]]: f64)
+// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A7]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]]
+//
// CHECK-LABEL: func @sparse_compression(
// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
@@ -436,18 +460,19 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
-// CHECK: %[[R:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] iter_args(%[[P0:.*]] = %[[A3]], %[[P1:.*]] = %[[A4]]) -> (memref<?xi64>, memref<?xf64>) {
+// CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<2xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>) {
// CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
// CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref<?xf64>
-// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[VAL]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: %[[C:.*]]:5 = func.call @_insert_D_C_8_8_f64_64_32(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[A9]], %[[INDEX]], %[[VAL]])
// CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
// CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
-// CHECK: scf.yield %{{.*}}, %[[PV]] : memref<?xi64>, memref<?xf64>
+// CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<2xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
// CHECK: }
// CHECK: memref.dealloc %[[A5]] : memref<?xf64>
// CHECK: memref.dealloc %[[A6]] : memref<?xi1>
// CHECK: memref.dealloc %[[A7]] : memref<?xindex>
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[R]]#0, %[[R]]#1
+// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
%values: memref<?xf64>,
@@ -461,6 +486,18 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
return %1 : tensor<8x8xf64, #CSR>
}
+// CHECK-LABEL: func.func private @_insert_D_C_8_8_f64_0_0(
+// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: index,
+// CHECK-SAME: %[[A6:.*6]]: index,
+// CHECK-SAME: %[[A7:.*7]]: f64)
+// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A7]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]]
+//
// CHECK-LABEL: func @sparse_compression_unordered(
// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
@@ -477,18 +514,19 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-NOT: sparse_tensor.sort
-// CHECK: %[[R:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] iter_args(%[[P0:.*]] = %[[A3]], %[[P1:.*]] = %[[A4]]) -> (memref<?xindex>, memref<?xf64>) {
+// CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]]
+// CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
// CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
// CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref<?xf64>
-// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[VAL]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: %[[C:.*]]:5 = func.call @_insert_D_C_8_8_f64_0_0(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[A9]], %[[INDEX]], %[[VAL]])
// CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
// CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
-// CHECK: scf.yield %{{.*}}, %[[PV]] : memref<?xindex>, memref<?xf64>
+// CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
// CHECK: }
// CHECK: memref.dealloc %[[A5]] : memref<?xf64>
// CHECK: memref.dealloc %[[A6]] : memref<?xi1>
// CHECK: memref.dealloc %[[A7]] : memref<?xindex>
-// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[R]]#0, %[[R]]#1
+// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
%values: memref<?xf64>,
@@ -502,7 +540,7 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
return %1 : tensor<8x8xf64, #UCSR>
}
-// CHECK-LABEL: func @sparse_insert(
+// CHECK-LABEL: func.func private @_insert_C_128_f64_0_0(
// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
@@ -512,6 +550,16 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
// CHECK-SAME: %[[A6:.*6]]: f64)
// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[P]] :
+// CHECK: func @sparse_insert(
+// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: index,
+// CHECK-SAME: %[[A6:.*6]]: f64)
+// CHECK: %[[R:.*]]:5 = call @_insert_C_128_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
+// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> {
%0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV>
@@ -519,7 +567,7 @@ func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64)
return %1 : tensor<128xf64, #SV>
}
-// CHECK-LABEL: func @sparse_insert_typed(
+// CHECK-LABEL: func.func private @_insert_C_128_f64_64_32(
// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
@@ -529,6 +577,16 @@ func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64)
// CHECK-SAME: %[[A6:.*6]]: f64)
// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[P]] :
+// CHECK: func @sparse_insert_typed(
+// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: index,
+// CHECK-SAME: %[[A6:.*6]]: f64)
+// CHECK: %[[R:.*]]:5 = call @_insert_C_128_f64_64_32(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
+// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
%0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
@@ -547,4 +605,4 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
func.func @sparse_nop_convert(%arg0: tensor<32xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
%0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor<?xf32, #SparseVector>
return %0 : tensor<?xf32, #SparseVector>
-}
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
index daaf04a98dfc0..3808a5b200749 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -12,6 +12,43 @@
//
// Computes C = A x B with all matrices sparse (SpMSpM) in CSR.
//
+// CHECK-LABEL: func.func private @_insert_D_C_4_4_f64_0_0(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<2xindex>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<3xindex>,
+// CHECK-SAME: %[[VAL_2:[^ ]+]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_3:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_4:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[VAL_5:[^ ]+]]: index,
+// CHECK-SAME: %[[VAL_6:.*]]: index,
+// CHECK-SAME: %[[VAL_7:.*]]: f64) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
+// CHECK-DAG: %[[VAL_8:.*]] = arith.constant false
+// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_5]], %[[VAL_9]] : index
+// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_9]]] : memref<3xindex>
+// CHECK: %[[VAL_14:.*]] = arith.subi %[[VAL_12]], %[[VAL_9]] : index
+// CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_12]] : index
+// CHECK: %[[VAL_16:.*]] = scf.if %[[VAL_15]] -> (i1) {
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = arith.cmpi eq, %[[VAL_17]], %[[VAL_6]] : index
+// CHECK: scf.yield %[[VAL_18]] : i1
+// CHECK: } else {
+// CHECK: memref.store %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK: scf.yield %[[VAL_8]] : i1
+// CHECK: }
+// CHECK: %[[VAL_19:.*]] = scf.if %[[VAL_20:.*]] -> (memref<?xindex>) {
+// CHECK: scf.yield %[[VAL_3]] : memref<?xindex>
+// CHECK: } else {
+// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_9]] : index
+// CHECK: memref.store %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = sparse_tensor.push_back %[[VAL_1]], %[[VAL_3]], %[[VAL_6]] {idx = 1 : index} : memref<3xindex>, memref<?xindex>, index
+// CHECK: scf.yield %[[VAL_22]] : memref<?xindex>
+// CHECK: }
+// CHECK: %[[VAL_23:.*]] = sparse_tensor.push_back %[[VAL_1]], %[[VAL_4]], %[[VAL_7]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_24:.*]], %[[VAL_23]] : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK: }
+
// CHECK-LABEL: func.func @matmul(
// CHECK-SAME: %[[VAL_0:.*0]]: memref<2xindex>,
// CHECK-SAME: %[[VAL_1:.*1]]: memref<3xindex>,
@@ -22,22 +59,21 @@
// CHECK-SAME: %[[VAL_6:.*6]]: memref<3xindex>,
// CHECK-SAME: %[[VAL_7:.*7]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_8:.*8]]: memref<?xindex>,
-// CHECK-SAME: %[[VAL_9:.*9]]: memref<?xf64>)
-// CHECK-SAME: -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
+// CHECK-SAME: %[[VAL_9:.*9]]: memref<?xf64>) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_14:.*]] = arith.constant false
// CHECK-DAG: %[[VAL_15:.*]] = arith.constant true
-// CHECK-DAG: %[[VAL_16:.*]] = memref.alloc() : memref<2xindex>
-// CHECK-DAG: %[[VAL_17:.*]] = memref.alloc() : memref<3xindex>
-// CHECK-DAG: %[[VAL_18:.*]] = memref.alloc() : memref<16xindex>
-// CHECK-DAG: %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<16xindex> to memref<?xindex>
-// CHECK-DAG: %[[VAL_20:.*]] = memref.alloc() : memref<16xindex>
-// CHECK-DAG: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<16xindex> to memref<?xindex>
-// CHECK-DAG: %[[VAL_22:.*]] = memref.alloc() : memref<16xf64>
-// CHECK-DAG: %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<16xf64> to memref<?xf64>
+// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<2xindex>
+// CHECK: %[[VAL_17:.*]] = memref.alloc() : memref<3xindex>
+// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<16xindex>
+// CHECK: %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<16xindex> to memref<?xindex>
+// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<16xindex>
+// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<16xindex> to memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = memref.alloc() : memref<16xf64>
+// CHECK: %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<16xf64> to memref<?xf64>
// CHECK: linalg.fill ins(%[[VAL_12]] : index) outs(%[[VAL_17]] : memref<3xindex>)
// CHECK: memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_12]]] : memref<2xindex>
// CHECK: memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_13]]] : memref<2xindex>
@@ -49,84 +85,61 @@
// CHECK: %[[VAL_29:.*]] = memref.cast %[[VAL_28]] : memref<4xindex> to memref<?xindex>
// CHECK: linalg.fill ins(%[[VAL_11]] : f64) outs(%[[VAL_26]] : memref<4xf64>)
// CHECK: linalg.fill ins(%[[VAL_14]] : i1) outs(%[[VAL_27]] : memref<4xi1>)
-// CHECK: %[[VAL_30:.*]]:2 = scf.for %[[VAL_31:.*]] = %[[VAL_12]] to %[[VAL_10]] step %[[VAL_13]] iter_args(%[[VAL_32:.*]] = %[[VAL_21]], %[[VAL_33:.*]] = %[[VAL_23]]) -> (memref<?xindex>, memref<?xf64>) {
-// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_31]]] : memref<?xindex>
-// CHECK: %[[VAL_35:.*]] = arith.addi %[[VAL_31]], %[[VAL_13]] : index
-// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_35]]] : memref<?xindex>
-// CHECK: %[[VAL_37:.*]] = scf.for %[[VAL_38:.*]] = %[[VAL_34]] to %[[VAL_36]] step %[[VAL_13]] iter_args(%[[VAL_39:.*]] = %[[VAL_12]]) -> (index) {
-// CHECK: %[[VAL_40:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_38]]] : memref<?xindex>
-// CHECK: %[[VAL_41:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_38]]] : memref<?xf64>
-// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_40]]] : memref<?xindex>
-// CHECK: %[[VAL_43:.*]] = arith.addi %[[VAL_40]], %[[VAL_13]] : index
-// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_43]]] : memref<?xindex>
-// CHECK: %[[VAL_45:.*]] = scf.for %[[VAL_46:.*]] = %[[VAL_42]] to %[[VAL_44]] step %[[VAL_13]] iter_args(%[[VAL_47:.*]] = %[[VAL_39]]) -> (index) {
-// CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_46]]] : memref<?xindex>
-// CHECK: %[[VAL_49:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_48]]] : memref<4xf64>
-// CHECK: %[[VAL_50:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_46]]] : memref<?xf64>
-// CHECK: %[[VAL_51:.*]] = arith.mulf %[[VAL_41]], %[[VAL_50]] : f64
-// CHECK: %[[VAL_52:.*]] = arith.addf %[[VAL_49]], %[[VAL_51]] : f64
-// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_48]]] : memref<4xi1>
-// CHECK: %[[VAL_54:.*]] = arith.cmpi eq, %[[VAL_53]], %[[VAL_14]] : i1
-// CHECK: %[[VAL_55:.*]] = scf.if %[[VAL_54]] -> (index) {
-// CHECK: memref.store %[[VAL_15]], %[[VAL_27]]{{\[}}%[[VAL_48]]] : memref<4xi1>
-// CHECK: memref.store %[[VAL_48]], %[[VAL_28]]{{\[}}%[[VAL_47]]] : memref<4xindex>
-// CHECK: %[[VAL_56:.*]] = arith.addi %[[VAL_47]], %[[VAL_13]] : index
-// CHECK: scf.yield %[[VAL_56]] : index
+// CHECK: %[[VAL_30:.*]]:5 = scf.for %[[VAL_31:.*]] = %[[VAL_12]] to %[[VAL_10]] step %[[VAL_13]] iter_args(%[[VAL_32:.*]] = %[[VAL_16]], %[[VAL_33:.*]] = %[[VAL_17]], %[[VAL_34:.*]] = %[[VAL_25]], %[[VAL_35:.*]] = %[[VAL_21]], %[[VAL_36:.*]] = %[[VAL_23]]) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
+// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_31]]] : memref<?xindex>
+// CHECK: %[[VAL_38:.*]] = arith.addi %[[VAL_31]], %[[VAL_13]] : index
+// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_38]]] : memref<?xindex>
+// CHECK: %[[VAL_40:.*]] = scf.for %[[VAL_41:.*]] = %[[VAL_37]] to %[[VAL_39]] step %[[VAL_13]] iter_args(%[[VAL_42:.*]] = %[[VAL_12]]) -> (index) {
+// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_41]]] : memref<?xindex>
+// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_41]]] : memref<?xf64>
+// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_43]]] : memref<?xindex>
+// CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_43]], %[[VAL_13]] : index
+// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_46]]] : memref<?xindex>
+// CHECK: %[[VAL_48:.*]] = scf.for %[[VAL_49:.*]] = %[[VAL_45]] to %[[VAL_47]] step %[[VAL_13]] iter_args(%[[VAL_50:.*]] = %[[VAL_42]]) -> (index) {
+// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_49]]] : memref<?xindex>
+// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_51]]] : memref<4xf64>
+// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_49]]] : memref<?xf64>
+// CHECK: %[[VAL_54:.*]] = arith.mulf %[[VAL_44]], %[[VAL_53]] : f64
+// CHECK: %[[VAL_55:.*]] = arith.addf %[[VAL_52]], %[[VAL_54]] : f64
+// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_51]]] : memref<4xi1>
+// CHECK: %[[VAL_57:.*]] = arith.cmpi eq, %[[VAL_56]], %[[VAL_14]] : i1
+// CHECK: %[[VAL_58:.*]] = scf.if %[[VAL_57]] -> (index) {
+// CHECK: memref.store %[[VAL_15]], %[[VAL_27]]{{\[}}%[[VAL_51]]] : memref<4xi1>
+// CHECK: memref.store %[[VAL_51]], %[[VAL_28]]{{\[}}%[[VAL_50]]] : memref<4xindex>
+// CHECK: %[[VAL_59:.*]] = arith.addi %[[VAL_50]], %[[VAL_13]] : index
+// CHECK: scf.yield %[[VAL_59]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_47]] : index
+// CHECK: scf.yield %[[VAL_50]] : index
// CHECK: }
-// CHECK: memref.store %[[VAL_52]], %[[VAL_26]]{{\[}}%[[VAL_48]]] : memref<4xf64>
-// CHECK: scf.yield %[[VAL_57:.*]] : index
-// CHECK: }
-// CHECK: scf.yield %[[VAL_58:.*]] : index
-// CHECK: }
-// CHECK: sparse_tensor.sort %[[VAL_59:.*]], %[[VAL_29]] : memref<?xindex>
-// CHECK: %[[VAL_60:.*]]:2 = scf.for %[[VAL_61:.*]] = %[[VAL_12]] to %[[VAL_59]] step %[[VAL_13]] iter_args(%[[VAL_62:.*]] = %[[VAL_32]], %[[VAL_63:.*]] = %[[VAL_33]]) -> (memref<?xindex>, memref<?xf64>) {
-// CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_61]]] : memref<4xindex>
-// CHECK: %[[VAL_65:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_64]]] : memref<4xf64>
-// CHECK: %[[VAL_66:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_31]]] : memref<?xindex>
-// CHECK: %[[VAL_67:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_35]]] : memref<?xindex>
-// CHECK: %[[VAL_68:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_13]]] : memref<3xindex>
-// CHECK: %[[VAL_69:.*]] = arith.subi %[[VAL_67]], %[[VAL_13]] : index
-// CHECK: %[[VAL_70:.*]] = arith.cmpi ult, %[[VAL_66]], %[[VAL_67]] : index
-// CHECK: %[[VAL_71:.*]] = scf.if %[[VAL_70]] -> (i1) {
-// CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_62]]{{\[}}%[[VAL_69]]] : memref<?xindex>
-// CHECK: %[[VAL_73:.*]] = arith.cmpi eq, %[[VAL_72]], %[[VAL_64]] : index
-// CHECK: scf.yield %[[VAL_73]] : i1
-// CHECK: } else {
-// CHECK: memref.store %[[VAL_68]], %[[VAL_25]]{{\[}}%[[VAL_31]]] : memref<?xindex>
-// CHECK: scf.yield %[[VAL_14]] : i1
-// CHECK: }
-// CHECK: %[[VAL_74:.*]] = scf.if %[[VAL_75:.*]] -> (memref<?xindex>) {
-// CHECK: scf.yield %[[VAL_62]] : memref<?xindex>
-// CHECK: } else {
-// CHECK: %[[VAL_76:.*]] = arith.addi %[[VAL_68]], %[[VAL_13]] : index
-// CHECK: memref.store %[[VAL_76]], %[[VAL_25]]{{\[}}%[[VAL_35]]] : memref<?xindex>
-// CHECK: %[[VAL_77:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_62]], %[[VAL_64]] {idx = 1 : index} : memref<3xindex>, memref<?xindex>, index
-// CHECK: scf.yield %[[VAL_77]] : memref<?xindex>
-// CHECK: }
-// CHECK: %[[VAL_78:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_63]], %[[VAL_65]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
-// CHECK: memref.store %[[VAL_11]], %[[VAL_26]]{{\[}}%[[VAL_64]]] : memref<4xf64>
-// CHECK: memref.store %[[VAL_14]], %[[VAL_27]]{{\[}}%[[VAL_64]]] : memref<4xi1>
-// CHECK: scf.yield %[[VAL_79:.*]], %[[VAL_78]] : memref<?xindex>, memref<?xf64>
+// CHECK: memref.store %[[VAL_55]], %[[VAL_26]]{{\[}}%[[VAL_51]]] : memref<4xf64>
+// CHECK: scf.yield %[[VAL_60:.*]] : index
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: sparse_tensor.sort %[[VAL_62:.*]], %[[VAL_29]] : memref<?xindex>
+// CHECK: %[[VAL_63:.*]]:5 = scf.for %[[VAL_64:.*]] = %[[VAL_12]] to %[[VAL_62]] step %[[VAL_13]] iter_args(%[[VAL_65:.*]] = %[[VAL_32]], %[[VAL_66:.*]] = %[[VAL_33]], %[[VAL_67:.*]] = %[[VAL_34]], %[[VAL_68:.*]] = %[[VAL_35]], %[[VAL_69:.*]] = %[[VAL_36]]) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
+// CHECK: %[[VAL_70:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_64]]] : memref<4xindex>
+// CHECK: %[[VAL_71:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64>
+// CHECK: %[[VAL_72:.*]]:5 = func.call @_insert_D_C_4_4_f64_0_0(%[[VAL_65]], %[[VAL_66]], %[[VAL_67]], %[[VAL_68]], %[[VAL_69]], %[[VAL_31]], %[[VAL_70]], %[[VAL_71]]) : (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>, index, index, f64) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>)
+// CHECK: memref.store %[[VAL_11]], %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64>
+// CHECK: memref.store %[[VAL_14]], %[[VAL_27]]{{\[}}%[[VAL_70]]] : memref<4xi1>
+// CHECK: scf.yield %[[VAL_72]]#0, %[[VAL_72]]#1, %[[VAL_72]]#2, %[[VAL_72]]#3, %[[VAL_72]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
// CHECK: }
-// CHECK: scf.yield %[[VAL_80:.*]]#0, %[[VAL_80]]#1 : memref<?xindex>, memref<?xf64>
-// CHECK: }
+// CHECK: scf.yield %[[VAL_73:.*]]#0, %[[VAL_73]]#1, %[[VAL_73]]#2, %[[VAL_73]]#3, %[[VAL_73]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: memref.dealloc %[[VAL_26]] : memref<4xf64>
// CHECK: memref.dealloc %[[VAL_27]] : memref<4xi1>
// CHECK: memref.dealloc %[[VAL_28]] : memref<4xindex>
-// CHECK: %[[VAL_81:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_12]]] : memref<3xindex>
-// CHECK: %[[VAL_82:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_12]]] : memref<?xindex>
-// CHECK: %[[VAL_83:.*]] = scf.for %[[VAL_84:.*]] = %[[VAL_13]] to %[[VAL_81]] step %[[VAL_13]] iter_args(%[[VAL_85:.*]] = %[[VAL_82]]) -> (index) {
-// CHECK: %[[VAL_86:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_84]]] : memref<?xindex>
-// CHECK: %[[VAL_87:.*]] = arith.cmpi eq, %[[VAL_86]], %[[VAL_12]] : index
-// CHECK: %[[VAL_88:.*]] = arith.select %[[VAL_87]], %[[VAL_85]], %[[VAL_86]] : index
-// CHECK: scf.if %[[VAL_87]] {
-// CHECK: memref.store %[[VAL_85]], %[[VAL_25]]{{\[}}%[[VAL_84]]] : memref<?xindex>
+// CHECK: %[[VAL_74:.*]] = memref.load %[[VAL_75:.*]]#1{{\[}}%[[VAL_12]]] : memref<3xindex>
+// CHECK: %[[VAL_76:.*]] = memref.load %[[VAL_75]]#2{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK: %[[VAL_77:.*]] = scf.for %[[VAL_78:.*]] = %[[VAL_13]] to %[[VAL_74]] step %[[VAL_13]] iter_args(%[[VAL_79:.*]] = %[[VAL_76]]) -> (index) {
+// CHECK: %[[VAL_80:.*]] = memref.load %[[VAL_75]]#2{{\[}}%[[VAL_78]]] : memref<?xindex>
+// CHECK: %[[VAL_81:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_12]] : index
+// CHECK: %[[VAL_82:.*]] = arith.select %[[VAL_81]], %[[VAL_79]], %[[VAL_80]] : index
+// CHECK: scf.if %[[VAL_81]] {
+// CHECK: memref.store %[[VAL_79]], %[[VAL_75]]#2{{\[}}%[[VAL_78]]] : memref<?xindex>
// CHECK: }
-// CHECK: scf.yield %[[VAL_88]] : index
+// CHECK: scf.yield %[[VAL_82]] : index
// CHECK: }
-// CHECK: return %[[VAL_16]], %[[VAL_17]], %[[VAL_25]], %[[VAL_89:.*]]#0, %[[VAL_89]]#1 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK: return %[[VAL_75]]#0, %[[VAL_75]]#1, %[[VAL_75]]#2, %[[VAL_75]]#3, %[[VAL_75]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
// CHECK: }
func.func @matmul(%A: tensor<4x8xf64, #CSR>,
%B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {