[Mlir-commits] [mlir] 2aceadd - [mlir][sparse] Put the implementation of the insertion operation into subroutines.

llvmlistbot at llvm.org
Thu Dec 1 13:24:04 PST 2022


Author: bixia1
Date: 2022-12-01T13:23:59-08:00
New Revision: 2aceadda78209f3ab0f69c8c277f4051fa161644

URL: https://github.com/llvm/llvm-project/commit/2aceadda78209f3ab0f69c8c277f4051fa161644
DIFF: https://github.com/llvm/llvm-project/commit/2aceadda78209f3ab0f69c8c277f4051fa161644.diff

LOG: [mlir][sparse] Put the implementation of the insertion operation into subroutines.

Previously, we generated an inlined implementation for the insert operation and
observed an increase in MLIR compile time due to the size of the main routine.
We now put the insert operation implementation into subroutines and leave the
inlining decision to the MLIR compiler.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D138957
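
For readers skimming the diff below, the essence of the change is a
lookup-or-create-then-call pattern: the mangled insertion function is created
lazily the first time it is needed and reused afterwards. The following is a
condensed, hypothetical C++ sketch of what genInsertionCallHelper in the patch
does; the helper name emitOutlinedCall is invented here, and the name mangling
and body generation (genInsertBody) are elided.

  #include "mlir/Dialect/Func/IR/FuncOps.h"
  #include "mlir/IR/Builders.h"
  #include "mlir/IR/BuiltinOps.h"

  using namespace mlir;

  // Hypothetical helper: emit a call to a private function, creating the
  // function lazily on first use (condensed from genInsertionCallHelper).
  static func::CallOp emitOutlinedCall(OpBuilder &builder, Location loc,
                                       ModuleOp module, StringRef name,
                                       ValueRange operands, TypeRange results,
                                       func::FuncOp insertPoint) {
    // Reuse the function if an earlier insertion already created it.
    auto func = module.lookupSymbol<func::FuncOp>(name);
    if (!func) {
      OpBuilder::InsertionGuard guard(builder);
      builder.setInsertionPoint(insertPoint);
      func = builder.create<func::FuncOp>(
          loc, name,
          FunctionType::get(module.getContext(), operands.getTypes(),
                            results));
      func.setPrivate();
      // The body is populated by a callback (cf. genInsertBody in the patch).
    }
    // Inlining of this call is left to the MLIR inliner.
    return builder.create<func::CallOp>(loc, func, operands);
  }

Because the outlined function is private and shared by all insertion sites
with the same tensor type, the main routine stays small and a later inlining
pass can still decide per call site whether inlining pays off.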

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/test/Dialect/SparseTensor/codegen.mlir
    mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 775b20dab1751..97f1f952e5bd5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -31,6 +31,11 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
+using FuncGeneratorType =
+    function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, RankedTensorType)>;
+
+static constexpr const char kInsertFuncNamePrefix[] = "_insert_";
+
 static constexpr uint64_t dimSizesIdx = 0;
 static constexpr uint64_t memSizesIdx = 1;
 static constexpr uint64_t fieldsIdx = 2;
@@ -476,12 +481,24 @@ static Value genCompressed(OpBuilder &builder, Location loc,
 ///
 /// TODO: better unord/not-unique; also generalize, optimize, specialize!
 ///
-static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
-                      SmallVectorImpl<Value> &fields,
-                      SmallVectorImpl<Value> &indices, Value value) {
+static void genInsertBody(OpBuilder &builder, ModuleOp module,
+                          func::FuncOp func, RankedTensorType rtp) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
   unsigned rank = rtp.getShape().size();
-  assert(rank == indices.size());
-  unsigned field = fieldsIdx; // start past header
+
+  // Construct fields and indices arrays from parameters.
+  ValueRange tmp = args.drop_back(rank + 1);
+  SmallVector<Value> fields(tmp.begin(), tmp.end());
+  tmp = args.take_back(rank + 1).drop_back();
+  SmallVector<Value> indices(tmp.begin(), tmp.end());
+  Value value = args.back();
+
+  unsigned field = fieldsIdx; // Start past header.
   Value pos = constantZero(builder, loc, builder.getIndexType());
   // Generate code for every dimension.
   for (unsigned d = 0; d < rank; d++) {
@@ -519,6 +536,77 @@ static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
   else
     genStore(builder, loc, value, fields[field++], pos);
   assert(fields.size() == field);
+  builder.create<func::ReturnOp>(loc, fields);
+}
+
+/// Generates a call to a function to perform an insertion operation. If the
+/// function doesn't exist yet, call `createFunc` to generate the function.
+static void genInsertionCallHelper(OpBuilder &builder, RankedTensorType rtp,
+                                   SmallVectorImpl<Value> &fields,
+                                   SmallVectorImpl<Value> &indices, Value value,
+                                   func::FuncOp insertPoint,
+                                   StringRef namePrefix,
+                                   FuncGeneratorType createFunc) {
+  // The mangled name of the function has this format:
+  //   <namePrefix>_[C|S|D]_<shape>_<ordering>_<eltType>
+  //     _<indexBitWidth>_<pointerBitWidth>
+  SmallString<32> nameBuffer;
+  llvm::raw_svector_ostream nameOstream(nameBuffer);
+  nameOstream << namePrefix;
+  unsigned rank = rtp.getShape().size();
+  assert(rank == indices.size());
+  for (unsigned d = 0; d < rank; d++) {
+    if (isCompressedDim(rtp, d)) {
+      nameOstream << "C_";
+    } else if (isSingletonDim(rtp, d)) {
+      nameOstream << "S_";
+    } else {
+      nameOstream << "D_";
+    }
+  }
+  // Static dim sizes are used in the generated code while dynamic sizes are
+  // loaded from the dimSizes buffer. This is the reason for adding the shape
+  // to the function name.
+  for (auto d : rtp.getShape())
+    nameOstream << d << "_";
+  SparseTensorEncodingAttr enc = getSparseTensorEncoding(rtp);
+  // Permutation information is also used in generating insertion.
+  if (enc.getDimOrdering() && !enc.getDimOrdering().isIdentity())
+    nameOstream << enc.getDimOrdering() << "_";
+  nameOstream << rtp.getElementType() << "_";
+  nameOstream << enc.getIndexBitWidth() << "_" << enc.getPointerBitWidth();
+
+  // Look up the function.
+  ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
+  MLIRContext *context = module.getContext();
+  auto result = SymbolRefAttr::get(context, nameOstream.str());
+  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
+
+  // Construct parameters for fields and indices.
+  SmallVector<Value> operands(fields.begin(), fields.end());
+  operands.append(indices.begin(), indices.end());
+  operands.push_back(value);
+  Location loc = insertPoint.getLoc();
+
+  if (!func) {
+    // Create the function.
+    OpBuilder::InsertionGuard insertionGuard(builder);
+    builder.setInsertionPoint(insertPoint);
+
+    func = builder.create<func::FuncOp>(
+        loc, nameOstream.str(),
+        FunctionType::get(context, ValueRange(operands).getTypes(),
+                          ValueRange(fields).getTypes()));
+    func.setPrivate();
+    createFunc(builder, module, func, rtp);
+  }
+
+  // Generate a call to perform the insertion and update `fields` with values
+  // returned from the call.
+  func::CallOp call = builder.create<func::CallOp>(loc, func, operands);
+  for (size_t i = 0; i < fields.size(); i++) {
+    fields[i] = call.getResult(i);
+  }
 }
 
 /// Generations insertion finalization code.
@@ -865,7 +953,9 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
     Value value = genLoad(rewriter, loc, values, index);
     indices.push_back(index);
     // TODO: faster for subsequent insertions?
-    genInsert(rewriter, loc, dstType, fields, indices, value);
+    auto insertPoint = op->template getParentOfType<func::FuncOp>();
+    genInsertionCallHelper(rewriter, dstType, fields, indices, value,
+                           insertPoint, kInsertFuncNamePrefix, genInsertBody);
     genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values,
              index);
     genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, index);
@@ -899,7 +989,10 @@ class SparseInsertConverter : public OpConversionPattern<InsertOp> {
     SmallVector<Value> indices(adaptor.getIndices());
     // Generate insertion.
     Value value = adaptor.getValue();
-    genInsert(rewriter, op->getLoc(), dstType, fields, indices, value);
+    auto insertPoint = op->template getParentOfType<func::FuncOp>();
+    genInsertionCallHelper(rewriter, dstType, fields, indices, value,
+                           insertPoint, kInsertFuncNamePrefix, genInsertBody);
+
     // Replace operation with resulting memrefs.
     rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), dstType, fields));
     return success();

diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 660cd7552db8a..0e3eb03bda78c 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -381,6 +381,17 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
   return %added : memref<?xindex>
 }
 
+// CHECK-LABEL: func.func private @_insert_C_100_f64_0_0(
+//  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+//  CHECK-SAME: %[[A5:.*5]]: index,
+//  CHECK-SAME: %[[A6:.*6]]: f64)
+//       CHECK:   %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+//       CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]]
+//
 // CHECK-LABEL: func @sparse_compression_1d(
 //  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
 //  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
@@ -396,18 +407,19 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //       CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
-//       CHECK: %[[R:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] iter_args(%[[P0:.*]] = %[[A3]], %[[P1:.*]] = %[[A4]]) -> (memref<?xindex>, memref<?xf64>) {
+//       CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]]
+//  CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
 //       CHECK:   %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
 //       CHECK:   %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref<?xf64>
-//       CHECK:   %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[VAL]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+//       CHECK:   %[[C:.*]]:5 = func.call @_insert_C_100_f64_0_0(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[INDEX]], %[[VAL]])
 //       CHECK:   memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
 //       CHECK:   memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
-//       CHECK:   scf.yield %{{.*}}, %[[PV]] : memref<?xindex>, memref<?xf64>
+//       CHECK:   scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 //       CHECK: }
 //       CHECK: memref.dealloc %[[A5]] : memref<?xf64>
 //       CHECK: memref.dealloc %[[A6]] : memref<?xi1>
 //       CHECK: memref.dealloc %[[A7]] : memref<?xindex>
-//       CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[R]]#0, %[[R]]#1
+//       CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
 //  CHECK-SAME:   memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
                                  %values: memref<?xf64>,
@@ -420,6 +432,18 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
   return %1 : tensor<100xf64, #SV>
 }
 
+// CHECK-LABEL: func.func private @_insert_D_C_8_8_f64_64_32(
+//  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+//  CHECK-SAME: %[[A5:.*5]]: index,
+//  CHECK-SAME: %[[A6:.*6]]: index,
+//  CHECK-SAME: %[[A7:.*7]]: f64)
+//       CHECK:   %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A7]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+//       CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]]
+//
 // CHECK-LABEL: func @sparse_compression(
 //  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
 //  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
@@ -436,18 +460,19 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //       CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
-//       CHECK: %[[R:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] iter_args(%[[P0:.*]] = %[[A3]], %[[P1:.*]] = %[[A4]]) -> (memref<?xi64>, memref<?xf64>) {
+//       CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]]
+//  CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<2xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>) {
 //       CHECK:   %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
 //       CHECK:   %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref<?xf64>
-//       CHECK:   %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[VAL]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+//       CHECK:   %[[C:.*]]:5 = func.call @_insert_D_C_8_8_f64_64_32(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[A9]], %[[INDEX]], %[[VAL]])
 //       CHECK:   memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
 //       CHECK:   memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
-//       CHECK:   scf.yield %{{.*}}, %[[PV]] : memref<?xi64>, memref<?xf64>
+//       CHECK:   scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<2xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
 //       CHECK: }
 //       CHECK: memref.dealloc %[[A5]] : memref<?xf64>
 //       CHECK: memref.dealloc %[[A6]] : memref<?xi1>
 //       CHECK: memref.dealloc %[[A7]] : memref<?xindex>
-//       CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[R]]#0, %[[R]]#1
+//       CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
 //  CHECK-SAME:   memref<2xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
 func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
                               %values: memref<?xf64>,
@@ -461,6 +486,18 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
   return %1 : tensor<8x8xf64, #CSR>
 }
 
+// CHECK-LABEL: func.func private @_insert_D_C_8_8_f64_0_0(
+//  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+//  CHECK-SAME: %[[A5:.*5]]: index,
+//  CHECK-SAME: %[[A6:.*6]]: index,
+//  CHECK-SAME: %[[A7:.*7]]: f64)
+//       CHECK:   %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A7]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+//       CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]]
+//
 // CHECK-LABEL: func @sparse_compression_unordered(
 //  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
 //  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
@@ -477,18 +514,19 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-NOT: sparse_tensor.sort
-//       CHECK: %[[R:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] iter_args(%[[P0:.*]] = %[[A3]], %[[P1:.*]] = %[[A4]]) -> (memref<?xindex>, memref<?xf64>) {
+//       CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]]
+//  CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
 //       CHECK:   %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
 //       CHECK:   %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref<?xf64>
-//       CHECK:   %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[VAL]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+//       CHECK:   %[[C:.*]]:5 = func.call @_insert_D_C_8_8_f64_0_0(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[A9]], %[[INDEX]], %[[VAL]])
 //       CHECK:   memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
 //       CHECK:   memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
-//       CHECK:   scf.yield %{{.*}}, %[[PV]] : memref<?xindex>, memref<?xf64>
+//       CHECK:   scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 //       CHECK: }
 //       CHECK: memref.dealloc %[[A5]] : memref<?xf64>
 //       CHECK: memref.dealloc %[[A6]] : memref<?xi1>
 //       CHECK: memref.dealloc %[[A7]] : memref<?xindex>
-//       CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[R]]#0, %[[R]]#1
+//       CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
 //  CHECK-SAME:   memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
                                         %values: memref<?xf64>,
@@ -502,7 +540,7 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
   return %1 : tensor<8x8xf64, #UCSR>
 }
 
-// CHECK-LABEL: func @sparse_insert(
+// CHECK-LABEL: func.func private @_insert_C_128_f64_0_0(
 //  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
 //  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
 //  CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
@@ -512,6 +550,16 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
 //  CHECK-SAME: %[[A6:.*6]]: f64)
 //       CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]]  {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
 //       CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[P]] :
+//       CHECK: func @sparse_insert(
+//  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+//  CHECK-SAME: %[[A5:.*5]]: index,
+//  CHECK-SAME: %[[A6:.*6]]: f64)
+//       CHECK: %[[R:.*]]:5 = call @_insert_C_128_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
+//       CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
 //  CHECK-SAME:   memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> {
   %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV>
@@ -519,7 +567,7 @@ func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64)
   return %1 : tensor<128xf64, #SV>
 }
 
-// CHECK-LABEL: func @sparse_insert_typed(
+// CHECK-LABEL: func.func private @_insert_C_128_f64_64_32(
 //  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
 //  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
 //  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
@@ -529,6 +577,16 @@ func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64)
 //  CHECK-SAME: %[[A6:.*6]]: f64)
 //       CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]]  {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
 //       CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[P]] :
+//       CHECK: func @sparse_insert_typed(
+//  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+//  CHECK-SAME: %[[A5:.*5]]: index,
+//  CHECK-SAME: %[[A6:.*6]]: f64)
+//       CHECK: %[[R:.*]]:5 = call @_insert_C_128_f64_64_32(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
+//       CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4
 //  CHECK-SAME:   memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
 func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
   %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
@@ -547,4 +605,4 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
 func.func @sparse_nop_convert(%arg0: tensor<32xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
   %0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor<?xf32, #SparseVector>
   return %0 : tensor<?xf32, #SparseVector>
-}
+}
\ No newline at end of file

diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
index daaf04a98dfc0..3808a5b200749 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -12,6 +12,43 @@
 //
 // Computes C = A x B with all matrices sparse (SpMSpM) in CSR.
 //
+// CHECK-LABEL:   func.func private @_insert_D_C_4_4_f64_0_0(
+// CHECK-SAME:      %[[VAL_0:.*]]: memref<2xindex>,
+// CHECK-SAME:      %[[VAL_1:.*]]: memref<3xindex>,
+// CHECK-SAME:      %[[VAL_2:[^ ]+]]: memref<?xindex>,
+// CHECK-SAME:      %[[VAL_3:.*]]: memref<?xindex>,
+// CHECK-SAME:      %[[VAL_4:.*]]: memref<?xf64>,
+// CHECK-SAME:      %[[VAL_5:[^ ]+]]: index,
+// CHECK-SAME:      %[[VAL_6:.*]]: index,
+// CHECK-SAME:      %[[VAL_7:.*]]: f64) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
+// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant false
+// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_10:.*]] = arith.addi %[[VAL_5]], %[[VAL_9]] : index
+// CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_9]]] : memref<3xindex>
+// CHECK:           %[[VAL_14:.*]] = arith.subi %[[VAL_12]], %[[VAL_9]] : index
+// CHECK:           %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_12]] : index
+// CHECK:           %[[VAL_16:.*]] = scf.if %[[VAL_15]] -> (i1) {
+// CHECK:             %[[VAL_17:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK:             %[[VAL_18:.*]] = arith.cmpi eq, %[[VAL_17]], %[[VAL_6]] : index
+// CHECK:             scf.yield %[[VAL_18]] : i1
+// CHECK:           } else {
+// CHECK:             memref.store %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:             scf.yield %[[VAL_8]] : i1
+// CHECK:           }
+// CHECK:           %[[VAL_19:.*]] = scf.if %[[VAL_20:.*]] -> (memref<?xindex>) {
+// CHECK:             scf.yield %[[VAL_3]] : memref<?xindex>
+// CHECK:           } else {
+// CHECK:             %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_9]] : index
+// CHECK:             memref.store %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK:             %[[VAL_22:.*]] = sparse_tensor.push_back %[[VAL_1]], %[[VAL_3]], %[[VAL_6]] {idx = 1 : index} : memref<3xindex>, memref<?xindex>, index
+// CHECK:             scf.yield %[[VAL_22]] : memref<?xindex>
+// CHECK:           }
+// CHECK:           %[[VAL_23:.*]] = sparse_tensor.push_back %[[VAL_1]], %[[VAL_4]], %[[VAL_7]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK:           return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_24:.*]], %[[VAL_23]] : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK:         }
+
 // CHECK-LABEL:   func.func @matmul(
 // CHECK-SAME:      %[[VAL_0:.*0]]: memref<2xindex>,
 // CHECK-SAME:      %[[VAL_1:.*1]]: memref<3xindex>,
@@ -22,22 +59,21 @@
 // CHECK-SAME:      %[[VAL_6:.*6]]: memref<3xindex>,
 // CHECK-SAME:      %[[VAL_7:.*7]]: memref<?xindex>,
 // CHECK-SAME:      %[[VAL_8:.*8]]: memref<?xindex>,
-// CHECK-SAME:      %[[VAL_9:.*9]]: memref<?xf64>)
-// CHECK-SAME: -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
+// CHECK-SAME:      %[[VAL_9:.*9]]: memref<?xf64>) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
 // CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 4 : index
 // CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f64
 // CHECK-DAG:       %[[VAL_12:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_13:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_14:.*]] = arith.constant false
 // CHECK-DAG:       %[[VAL_15:.*]] = arith.constant true
-// CHECK-DAG:       %[[VAL_16:.*]] = memref.alloc() : memref<2xindex>
-// CHECK-DAG:       %[[VAL_17:.*]] = memref.alloc() : memref<3xindex>
-// CHECK-DAG:       %[[VAL_18:.*]] = memref.alloc() : memref<16xindex>
-// CHECK-DAG:       %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<16xindex> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_20:.*]] = memref.alloc() : memref<16xindex>
-// CHECK-DAG:       %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<16xindex> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_22:.*]] = memref.alloc() : memref<16xf64>
-// CHECK-DAG:       %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<16xf64> to memref<?xf64>
+// CHECK:           %[[VAL_16:.*]] = memref.alloc() : memref<2xindex>
+// CHECK:           %[[VAL_17:.*]] = memref.alloc() : memref<3xindex>
+// CHECK:           %[[VAL_18:.*]] = memref.alloc() : memref<16xindex>
+// CHECK:           %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<16xindex> to memref<?xindex>
+// CHECK:           %[[VAL_20:.*]] = memref.alloc() : memref<16xindex>
+// CHECK:           %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<16xindex> to memref<?xindex>
+// CHECK:           %[[VAL_22:.*]] = memref.alloc() : memref<16xf64>
+// CHECK:           %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<16xf64> to memref<?xf64>
 // CHECK:           linalg.fill ins(%[[VAL_12]] : index) outs(%[[VAL_17]] : memref<3xindex>)
 // CHECK:           memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_12]]] : memref<2xindex>
 // CHECK:           memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_13]]] : memref<2xindex>
@@ -49,84 +85,61 @@
 // CHECK:           %[[VAL_29:.*]] = memref.cast %[[VAL_28]] : memref<4xindex> to memref<?xindex>
 // CHECK:           linalg.fill ins(%[[VAL_11]] : f64) outs(%[[VAL_26]] : memref<4xf64>)
 // CHECK:           linalg.fill ins(%[[VAL_14]] : i1) outs(%[[VAL_27]] : memref<4xi1>)
-// CHECK:           %[[VAL_30:.*]]:2 = scf.for %[[VAL_31:.*]] = %[[VAL_12]] to %[[VAL_10]] step %[[VAL_13]] iter_args(%[[VAL_32:.*]] = %[[VAL_21]], %[[VAL_33:.*]] = %[[VAL_23]]) -> (memref<?xindex>, memref<?xf64>) {
-// CHECK:             %[[VAL_34:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_31]]] : memref<?xindex>
-// CHECK:             %[[VAL_35:.*]] = arith.addi %[[VAL_31]], %[[VAL_13]] : index
-// CHECK:             %[[VAL_36:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_35]]] : memref<?xindex>
-// CHECK:             %[[VAL_37:.*]] = scf.for %[[VAL_38:.*]] = %[[VAL_34]] to %[[VAL_36]] step %[[VAL_13]] iter_args(%[[VAL_39:.*]] = %[[VAL_12]]) -> (index) {
-// CHECK:               %[[VAL_40:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_38]]] : memref<?xindex>
-// CHECK:               %[[VAL_41:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_38]]] : memref<?xf64>
-// CHECK:               %[[VAL_42:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_40]]] : memref<?xindex>
-// CHECK:               %[[VAL_43:.*]] = arith.addi %[[VAL_40]], %[[VAL_13]] : index
-// CHECK:               %[[VAL_44:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_43]]] : memref<?xindex>
-// CHECK:               %[[VAL_45:.*]] = scf.for %[[VAL_46:.*]] = %[[VAL_42]] to %[[VAL_44]] step %[[VAL_13]] iter_args(%[[VAL_47:.*]] = %[[VAL_39]]) -> (index) {
-// CHECK:                 %[[VAL_48:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_46]]] : memref<?xindex>
-// CHECK:                 %[[VAL_49:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_48]]] : memref<4xf64>
-// CHECK:                 %[[VAL_50:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_46]]] : memref<?xf64>
-// CHECK:                 %[[VAL_51:.*]] = arith.mulf %[[VAL_41]], %[[VAL_50]] : f64
-// CHECK:                 %[[VAL_52:.*]] = arith.addf %[[VAL_49]], %[[VAL_51]] : f64
-// CHECK:                 %[[VAL_53:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_48]]] : memref<4xi1>
-// CHECK:                 %[[VAL_54:.*]] = arith.cmpi eq, %[[VAL_53]], %[[VAL_14]] : i1
-// CHECK:                 %[[VAL_55:.*]] = scf.if %[[VAL_54]] -> (index) {
-// CHECK:                   memref.store %[[VAL_15]], %[[VAL_27]]{{\[}}%[[VAL_48]]] : memref<4xi1>
-// CHECK:                   memref.store %[[VAL_48]], %[[VAL_28]]{{\[}}%[[VAL_47]]] : memref<4xindex>
-// CHECK:                   %[[VAL_56:.*]] = arith.addi %[[VAL_47]], %[[VAL_13]] : index
-// CHECK:                   scf.yield %[[VAL_56]] : index
+// CHECK:           %[[VAL_30:.*]]:5 = scf.for %[[VAL_31:.*]] = %[[VAL_12]] to %[[VAL_10]] step %[[VAL_13]] iter_args(%[[VAL_32:.*]] = %[[VAL_16]], %[[VAL_33:.*]] = %[[VAL_17]], %[[VAL_34:.*]] = %[[VAL_25]], %[[VAL_35:.*]] = %[[VAL_21]], %[[VAL_36:.*]] = %[[VAL_23]]) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
+// CHECK:             %[[VAL_37:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_31]]] : memref<?xindex>
+// CHECK:             %[[VAL_38:.*]] = arith.addi %[[VAL_31]], %[[VAL_13]] : index
+// CHECK:             %[[VAL_39:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_38]]] : memref<?xindex>
+// CHECK:             %[[VAL_40:.*]] = scf.for %[[VAL_41:.*]] = %[[VAL_37]] to %[[VAL_39]] step %[[VAL_13]] iter_args(%[[VAL_42:.*]] = %[[VAL_12]]) -> (index) {
+// CHECK:               %[[VAL_43:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_41]]] : memref<?xindex>
+// CHECK:               %[[VAL_44:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_41]]] : memref<?xf64>
+// CHECK:               %[[VAL_45:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_43]]] : memref<?xindex>
+// CHECK:               %[[VAL_46:.*]] = arith.addi %[[VAL_43]], %[[VAL_13]] : index
+// CHECK:               %[[VAL_47:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_46]]] : memref<?xindex>
+// CHECK:               %[[VAL_48:.*]] = scf.for %[[VAL_49:.*]] = %[[VAL_45]] to %[[VAL_47]] step %[[VAL_13]] iter_args(%[[VAL_50:.*]] = %[[VAL_42]]) -> (index) {
+// CHECK:                 %[[VAL_51:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_49]]] : memref<?xindex>
+// CHECK:                 %[[VAL_52:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_51]]] : memref<4xf64>
+// CHECK:                 %[[VAL_53:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_49]]] : memref<?xf64>
+// CHECK:                 %[[VAL_54:.*]] = arith.mulf %[[VAL_44]], %[[VAL_53]] : f64
+// CHECK:                 %[[VAL_55:.*]] = arith.addf %[[VAL_52]], %[[VAL_54]] : f64
+// CHECK:                 %[[VAL_56:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_51]]] : memref<4xi1>
+// CHECK:                 %[[VAL_57:.*]] = arith.cmpi eq, %[[VAL_56]], %[[VAL_14]] : i1
+// CHECK:                 %[[VAL_58:.*]] = scf.if %[[VAL_57]] -> (index) {
+// CHECK:                   memref.store %[[VAL_15]], %[[VAL_27]]{{\[}}%[[VAL_51]]] : memref<4xi1>
+// CHECK:                   memref.store %[[VAL_51]], %[[VAL_28]]{{\[}}%[[VAL_50]]] : memref<4xindex>
+// CHECK:                   %[[VAL_59:.*]] = arith.addi %[[VAL_50]], %[[VAL_13]] : index
+// CHECK:                   scf.yield %[[VAL_59]] : index
 // CHECK:                 } else {
-// CHECK:                   scf.yield %[[VAL_47]] : index
+// CHECK:                   scf.yield %[[VAL_50]] : index
 // CHECK:                 }
-// CHECK:                 memref.store %[[VAL_52]], %[[VAL_26]]{{\[}}%[[VAL_48]]] : memref<4xf64>
-// CHECK:                 scf.yield %[[VAL_57:.*]] : index
-// CHECK:               }
-// CHECK:               scf.yield %[[VAL_58:.*]] : index
-// CHECK:             }
-// CHECK:             sparse_tensor.sort %[[VAL_59:.*]], %[[VAL_29]] : memref<?xindex>
-// CHECK:             %[[VAL_60:.*]]:2 = scf.for %[[VAL_61:.*]] = %[[VAL_12]] to %[[VAL_59]] step %[[VAL_13]] iter_args(%[[VAL_62:.*]] = %[[VAL_32]], %[[VAL_63:.*]] = %[[VAL_33]]) -> (memref<?xindex>, memref<?xf64>) {
-// CHECK:               %[[VAL_64:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_61]]] : memref<4xindex>
-// CHECK:               %[[VAL_65:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_64]]] : memref<4xf64>
-// CHECK:               %[[VAL_66:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_31]]] : memref<?xindex>
-// CHECK:               %[[VAL_67:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_35]]] : memref<?xindex>
-// CHECK:               %[[VAL_68:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_13]]] : memref<3xindex>
-// CHECK:               %[[VAL_69:.*]] = arith.subi %[[VAL_67]], %[[VAL_13]] : index
-// CHECK:               %[[VAL_70:.*]] = arith.cmpi ult, %[[VAL_66]], %[[VAL_67]] : index
-// CHECK:               %[[VAL_71:.*]] = scf.if %[[VAL_70]] -> (i1) {
-// CHECK:                 %[[VAL_72:.*]] = memref.load %[[VAL_62]]{{\[}}%[[VAL_69]]] : memref<?xindex>
-// CHECK:                 %[[VAL_73:.*]] = arith.cmpi eq, %[[VAL_72]], %[[VAL_64]] : index
-// CHECK:                 scf.yield %[[VAL_73]] : i1
-// CHECK:               } else {
-// CHECK:                 memref.store %[[VAL_68]], %[[VAL_25]]{{\[}}%[[VAL_31]]] : memref<?xindex>
-// CHECK:                 scf.yield %[[VAL_14]] : i1
-// CHECK:               }
-// CHECK:               %[[VAL_74:.*]] = scf.if %[[VAL_75:.*]] -> (memref<?xindex>) {
-// CHECK:                 scf.yield %[[VAL_62]] : memref<?xindex>
-// CHECK:               } else {
-// CHECK:                 %[[VAL_76:.*]] = arith.addi %[[VAL_68]], %[[VAL_13]] : index
-// CHECK:                 memref.store %[[VAL_76]], %[[VAL_25]]{{\[}}%[[VAL_35]]] : memref<?xindex>
-// CHECK:                 %[[VAL_77:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_62]], %[[VAL_64]] {idx = 1 : index} : memref<3xindex>, memref<?xindex>, index
-// CHECK:                 scf.yield %[[VAL_77]] : memref<?xindex>
-// CHECK:               }
-// CHECK:               %[[VAL_78:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_63]], %[[VAL_65]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
-// CHECK:               memref.store %[[VAL_11]], %[[VAL_26]]{{\[}}%[[VAL_64]]] : memref<4xf64>
-// CHECK:               memref.store %[[VAL_14]], %[[VAL_27]]{{\[}}%[[VAL_64]]] : memref<4xi1>
-// CHECK:               scf.yield %[[VAL_79:.*]], %[[VAL_78]] : memref<?xindex>, memref<?xf64>
+// CHECK:                 memref.store %[[VAL_55]], %[[VAL_26]]{{\[}}%[[VAL_51]]] : memref<4xf64>
+// CHECK:                 scf.yield %[[VAL_60:.*]] : index
+// CHECK:               } {"Emitted from" = "linalg.generic"}
+// CHECK:             sparse_tensor.sort %[[VAL_62:.*]], %[[VAL_29]] : memref<?xindex>
+// CHECK:             %[[VAL_63:.*]]:5 = scf.for %[[VAL_64:.*]] = %[[VAL_12]] to %[[VAL_62]] step %[[VAL_13]] iter_args(%[[VAL_65:.*]] = %[[VAL_32]], %[[VAL_66:.*]] = %[[VAL_33]], %[[VAL_67:.*]] = %[[VAL_34]], %[[VAL_68:.*]] = %[[VAL_35]], %[[VAL_69:.*]] = %[[VAL_36]]) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
+// CHECK:               %[[VAL_70:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_64]]] : memref<4xindex>
+// CHECK:               %[[VAL_71:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64>
+// CHECK:               %[[VAL_72:.*]]:5 = func.call @_insert_D_C_4_4_f64_0_0(%[[VAL_65]], %[[VAL_66]], %[[VAL_67]], %[[VAL_68]], %[[VAL_69]], %[[VAL_31]], %[[VAL_70]], %[[VAL_71]]) : (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>, index, index, f64) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>)
+// CHECK:               memref.store %[[VAL_11]], %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64>
+// CHECK:               memref.store %[[VAL_14]], %[[VAL_27]]{{\[}}%[[VAL_70]]] : memref<4xi1>
+// CHECK:               scf.yield %[[VAL_72]]#0, %[[VAL_72]]#1, %[[VAL_72]]#2, %[[VAL_72]]#3, %[[VAL_72]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 // CHECK:             }
-// CHECK:             scf.yield %[[VAL_80:.*]]#0, %[[VAL_80]]#1 : memref<?xindex>, memref<?xf64>
-// CHECK:           }
+// CHECK:             scf.yield %[[VAL_73:.*]]#0, %[[VAL_73]]#1, %[[VAL_73]]#2, %[[VAL_73]]#3, %[[VAL_73]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK:           } {"Emitted from" = "linalg.generic"}
 // CHECK:           memref.dealloc %[[VAL_26]] : memref<4xf64>
 // CHECK:           memref.dealloc %[[VAL_27]] : memref<4xi1>
 // CHECK:           memref.dealloc %[[VAL_28]] : memref<4xindex>
-// CHECK:           %[[VAL_81:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_12]]] : memref<3xindex>
-// CHECK:           %[[VAL_82:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_12]]] : memref<?xindex>
-// CHECK:           %[[VAL_83:.*]] = scf.for %[[VAL_84:.*]] = %[[VAL_13]] to %[[VAL_81]] step %[[VAL_13]] iter_args(%[[VAL_85:.*]] = %[[VAL_82]]) -> (index) {
-// CHECK:             %[[VAL_86:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_84]]] : memref<?xindex>
-// CHECK:             %[[VAL_87:.*]] = arith.cmpi eq, %[[VAL_86]], %[[VAL_12]] : index
-// CHECK:             %[[VAL_88:.*]] = arith.select %[[VAL_87]], %[[VAL_85]], %[[VAL_86]] : index
-// CHECK:             scf.if %[[VAL_87]] {
-// CHECK:               memref.store %[[VAL_85]], %[[VAL_25]]{{\[}}%[[VAL_84]]] : memref<?xindex>
+// CHECK:           %[[VAL_74:.*]] = memref.load %[[VAL_75:.*]]#1{{\[}}%[[VAL_12]]] : memref<3xindex>
+// CHECK:           %[[VAL_76:.*]] = memref.load %[[VAL_75]]#2{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK:           %[[VAL_77:.*]] = scf.for %[[VAL_78:.*]] = %[[VAL_13]] to %[[VAL_74]] step %[[VAL_13]] iter_args(%[[VAL_79:.*]] = %[[VAL_76]]) -> (index) {
+// CHECK:             %[[VAL_80:.*]] = memref.load %[[VAL_75]]#2{{\[}}%[[VAL_78]]] : memref<?xindex>
+// CHECK:             %[[VAL_81:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_12]] : index
+// CHECK:             %[[VAL_82:.*]] = arith.select %[[VAL_81]], %[[VAL_79]], %[[VAL_80]] : index
+// CHECK:             scf.if %[[VAL_81]] {
+// CHECK:               memref.store %[[VAL_79]], %[[VAL_75]]#2{{\[}}%[[VAL_78]]] : memref<?xindex>
 // CHECK:             }
-// CHECK:             scf.yield %[[VAL_88]] : index
+// CHECK:             scf.yield %[[VAL_82]] : index
 // CHECK:           }
-// CHECK:           return %[[VAL_16]], %[[VAL_17]], %[[VAL_25]], %[[VAL_89:.*]]#0, %[[VAL_89]]#1 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK:           return %[[VAL_75]]#0, %[[VAL_75]]#1, %[[VAL_75]]#2, %[[VAL_75]]#3, %[[VAL_75]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 // CHECK:         }
 func.func @matmul(%A: tensor<4x8xf64, #CSR>,
                   %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {

