[Mlir-commits] [mlir] [mlir][sparse] fix logical error when generating sort_coo. (PR #66690)

Peiming Liu llvmlistbot at llvm.org
Mon Sep 18 12:35:24 PDT 2023


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/66690

>From 1fa955d50e3ab25809d12a2b54c3841a014b257e Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 18 Sep 2023 19:31:24 +0000
Subject: [PATCH 1/2] [mlir][sparse] fix logical error when generating
 sort_coo.

---
 .../Transforms/SparseBufferRewriting.cpp      |  64 ++--
 .../SparseTensor/buffer_rewriting.mlir        | 340 +-----------------
 .../CPU/sparse_rewrite_sort_coo.mlir          |  42 +--
 3 files changed, 68 insertions(+), 378 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 5e53dbe1cc28381..2935cdad372d3d3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -563,11 +563,14 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
 //   p = (lo+hi)/2  // pivot index
 //   i = lo
 //   j = hi-1
-//   while (i < j) do {
+//   while (true) do {
 //     while (xs[i] < xs[p]) i ++;
 //     i_eq = (xs[i] == xs[p]);
 //     while (xs[j] > xs[p]) j --;
 //     j_eq = (xs[j] == xs[p]);
+//
+//     if (i >= j) return j + 1;
+//
 //     if (i < j) {
 //       swap(xs[i], xs[j])
 //       if (i == p) {
@@ -605,22 +608,22 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
   Value i = lo;
   Value j = builder.create<arith::SubIOp>(loc, hi, c1);
   createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args);
-  SmallVector<Value, 3> operands{i, j, p}; // Exactly three values.
-  SmallVector<Type, 3> types{i.getType(), j.getType(), p.getType()};
+  Value cont = constantI1(builder, loc, true);
+  SmallVector<Value, 3> operands{i, j, p, cont}; // Exactly four values.
+  SmallVector<Type, 3> types{i.getType(), j.getType(), p.getType(),
+                             cont.getType()};
   scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
 
   // The before-region of the WhileOp.
-  Block *before =
-      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc, loc});
+  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types,
+                                      {loc, loc, loc, loc});
   builder.setInsertionPointToEnd(before);
-  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
-                                             before->getArgument(0),
-                                             before->getArgument(1));
-  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
+  builder.create<scf::ConditionOp>(loc, before->getArgument(3),
+                                   before->getArguments());
 
   // The after-region of the WhileOp.
   Block *after =
-      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc});
+      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
   builder.setInsertionPointToEnd(after);
   i = after->getArgument(0);
   j = after->getArgument(1);
@@ -637,7 +640,8 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
   j = jresult;
 
   // If i < j:
-  cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
+  Value cond =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
   SmallVector<Value> swapOperands{i, j};
@@ -675,11 +679,15 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
   builder.setInsertionPointAfter(ifOp2);
   builder.create<scf::YieldOp>(
       loc,
-      ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0)});
+      ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0),
+                 /*cont=*/constantI1(builder, loc, true)});
 
-  // False branch for if i < j:
+  // False branch for if i < j (i.e., i >= j):
   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-  builder.create<scf::YieldOp>(loc, ValueRange{i, j, p});
+  p = builder.create<arith::AddIOp>(loc, j,
+                                    constantOne(builder, loc, j.getType()));
+  builder.create<scf::YieldOp>(
+      loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)});
 
   // Return for the whileOp.
   builder.setInsertionPointAfter(ifOp);
@@ -927,6 +935,8 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
   Location loc = func.getLoc();
   Value lo = args[loIdx];
   Value hi = args[hiIdx];
+  SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
+
   FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
       builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
       ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
@@ -935,14 +945,25 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
                                       TypeRange{IndexType::get(context)},
                                       args.drop_back(nTrailingP))
                 .getResult(0);
-  Value pP1 =
-      builder.create<arith::AddIOp>(loc, p, constantIndex(builder, loc, 1));
+
   Value lenLow = builder.create<arith::SubIOp>(loc, p, lo);
   Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p);
+  // Partitioning already sorts any range with len <= 2; no recursion needed.
+  Value c2 = constantIndex(builder, loc, 2);
+  Value len = builder.create<arith::SubIOp>(loc, hi, lo);
+  Value lenGtTwo =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2);
+  scf::IfOp ifLenGtTwo =
+      builder.create<scf::IfOp>(loc, types, lenGtTwo, /*else=*/true);
+  builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front());
+  // Yield an empty range to mark that the entire region is fully sorted.
+  builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
+
+  // Then branch (len > 2): recursion is needed.
+  builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front());
   Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
                                              lenLow, lenHigh);
 
-  SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
 
   Value c0 = constantIndex(builder, loc, 0);
@@ -961,14 +982,17 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
   // the bigger partition to be processed by the enclosed while-loop.
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
   mayRecursion(lo, p, lenLow);
-  builder.create<scf::YieldOp>(loc, ValueRange{pP1, hi});
+  builder.create<scf::YieldOp>(loc, ValueRange{p, hi});
 
   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-  mayRecursion(pP1, hi, lenHigh);
+  mayRecursion(p, hi, lenHigh);
   builder.create<scf::YieldOp>(loc, ValueRange{lo, p});
 
   builder.setInsertionPointAfter(ifOp);
-  return std::make_pair(ifOp.getResult(0), ifOp.getResult(1));
+  builder.create<scf::YieldOp>(loc, ifOp.getResults());
+
+  builder.setInsertionPointAfter(ifLenGtTwo);
+  return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1));
 }
 
 /// Creates a function to perform insertion sort on the values in the range of
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index 170f851138f82ae..0036bd5c3310b97 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -75,343 +75,9 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
 
 // -----
 
-// CHECK-LABEL:   func.func private @_sparse_partition_1_i8_f32_index(
-// CHECK-SAME:                                                        %[[VAL_0:.*0]]: index,
-// CHECK-SAME:                                                        %[[VAL_1:.*1]]: index,
-// CHECK-SAME:                                                        %[[VAL_2:.*2]]: memref<?xi8>,
-// CHECK-SAME:                                                        %[[VAL_3:.*3]]: memref<?xf32>,
-// CHECK-SAME:                                                        %[[VAL_4:.*4]]: memref<?xindex>) -> index {
-// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1
-// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1000
-// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant -1
-// CHECK:           %[[VAL_8:.*]] = arith.addi %[[VAL_0]], %[[VAL_1]]
-// CHECK:           %[[VAL_9:.*]] = arith.shrui %[[VAL_8]], %[[VAL_5]]
-// CHECK:           %[[VAL_10:.*]] = arith.subi %[[VAL_1]], %[[VAL_5]]
-// CHECK:           %[[VAL_11:.*]] = arith.subi %[[VAL_1]], %[[VAL_0]]
-// CHECK:           %[[VAL_12:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_6]]
-// CHECK:           scf.if %[[VAL_12]] {
-// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:             %[[VAL_14:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]]
-// CHECK:             %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_13]], %[[VAL_14]]
-// CHECK:             scf.if %[[VAL_15]] {
-// CHECK:               %[[VAL_16:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:               %[[VAL_17:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]]
-// CHECK:               memref.store %[[VAL_17]], %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:               memref.store %[[VAL_16]], %[[VAL_2]]{{\[}}%[[VAL_0]]]
-// CHECK:               %[[VAL_18:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:               %[[VAL_19:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_0]]]
-// CHECK:               memref.store %[[VAL_19]], %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:               memref.store %[[VAL_18]], %[[VAL_3]]{{\[}}%[[VAL_0]]]
-// CHECK:               %[[VAL_20:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:               %[[VAL_21:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_0]]]
-// CHECK:               memref.store %[[VAL_21]], %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:               memref.store %[[VAL_20]], %[[VAL_4]]{{\[}}%[[VAL_0]]]
-// CHECK:             }
-// CHECK:             %[[VAL_22:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]]
-// CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:             %[[VAL_24:.*]] = arith.cmpi ult, %[[VAL_22]], %[[VAL_23]]
-// CHECK:             scf.if %[[VAL_24]] {
-// CHECK:               %[[VAL_25:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]]
-// CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:               memref.store %[[VAL_26]], %[[VAL_2]]{{\[}}%[[VAL_10]]]
-// CHECK:               memref.store %[[VAL_25]], %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:               %[[VAL_27:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_10]]]
-// CHECK:               %[[VAL_28:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:               memref.store %[[VAL_28]], %[[VAL_3]]{{\[}}%[[VAL_10]]]
-// CHECK:               memref.store %[[VAL_27]], %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:               %[[VAL_29:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_10]]]
-// CHECK:               %[[VAL_30:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:               memref.store %[[VAL_30]], %[[VAL_4]]{{\[}}%[[VAL_10]]]
-// CHECK:               memref.store %[[VAL_29]], %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:               %[[VAL_31:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:               %[[VAL_32:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]]
-// CHECK:               %[[VAL_33:.*]] = arith.cmpi ult, %[[VAL_31]], %[[VAL_32]]
-// CHECK:               scf.if %[[VAL_33]] {
-// CHECK:                 %[[VAL_34:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:                 %[[VAL_35:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]]
-// CHECK:                 memref.store %[[VAL_35]], %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:                 memref.store %[[VAL_34]], %[[VAL_2]]{{\[}}%[[VAL_0]]]
-// CHECK:                 %[[VAL_36:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:                 %[[VAL_37:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_0]]]
-// CHECK:                 memref.store %[[VAL_37]], %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:                 memref.store %[[VAL_36]], %[[VAL_3]]{{\[}}%[[VAL_0]]]
-// CHECK:                 %[[VAL_38:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:                 %[[VAL_39:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_0]]]
-// CHECK:                 memref.store %[[VAL_39]], %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:                 memref.store %[[VAL_38]], %[[VAL_4]]{{\[}}%[[VAL_0]]]
-// CHECK:               }
-// CHECK:             }
-// CHECK:           } else {
-// CHECK:             %[[VAL_40:.*]] = arith.addi %[[VAL_9]], %[[VAL_1]]
-// CHECK:             %[[VAL_41:.*]] = arith.shrui %[[VAL_40]], %[[VAL_5]]
-// CHECK:             %[[VAL_42:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:             %[[VAL_43:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]]
-// CHECK:             %[[VAL_44:.*]] = arith.cmpi ult, %[[VAL_42]], %[[VAL_43]]
-// CHECK:             scf.if %[[VAL_44]] {
-// CHECK:               %[[VAL_45:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:               %[[VAL_46:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]]
-// CHECK:               memref.store %[[VAL_46]], %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:               memref.store %[[VAL_45]], %[[VAL_2]]{{\[}}%[[VAL_0]]]
-// CHECK:               %[[VAL_47:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:               %[[VAL_48:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_0]]]
-// CHECK:               memref.store %[[VAL_48]], %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:               memref.store %[[VAL_47]], %[[VAL_3]]{{\[}}%[[VAL_0]]]
-// CHECK:               %[[VAL_49:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:               %[[VAL_50:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_0]]]
-// CHECK:               memref.store %[[VAL_50]], %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:               memref.store %[[VAL_49]], %[[VAL_4]]{{\[}}%[[VAL_0]]]
-// CHECK:             }
-// CHECK:             %[[VAL_51:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:             %[[VAL_52:.*]] = arith.cmpi ult, %[[VAL_51]], %[[VAL_51]]
-// CHECK:             scf.if %[[VAL_52]] {
-// CHECK:               %[[VAL_53:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:               memref.store %[[VAL_53]], %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:               memref.store %[[VAL_53]], %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:               %[[VAL_54:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:               memref.store %[[VAL_54]], %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:               memref.store %[[VAL_54]], %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:               %[[VAL_55:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:               memref.store %[[VAL_55]], %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:               memref.store %[[VAL_55]], %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:               %[[VAL_56:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:               %[[VAL_57:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]]
-// CHECK:               %[[VAL_58:.*]] = arith.cmpi ult, %[[VAL_56]], %[[VAL_57]]
-// CHECK:               scf.if %[[VAL_58]] {
-// CHECK:                 %[[VAL_59:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:                 %[[VAL_60:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]]
-// CHECK:                 memref.store %[[VAL_60]], %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:                 memref.store %[[VAL_59]], %[[VAL_2]]{{\[}}%[[VAL_0]]]
-// CHECK:                 %[[VAL_61:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:                 %[[VAL_62:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_0]]]
-// CHECK:                 memref.store %[[VAL_62]], %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:                 memref.store %[[VAL_61]], %[[VAL_3]]{{\[}}%[[VAL_0]]]
-// CHECK:                 %[[VAL_63:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:                 %[[VAL_64:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_0]]]
-// CHECK:                 memref.store %[[VAL_64]], %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:                 memref.store %[[VAL_63]], %[[VAL_4]]{{\[}}%[[VAL_0]]]
-// CHECK:               }
-// CHECK:             }
-// CHECK:             %[[VAL_65:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_41]]]
-// CHECK:             %[[VAL_66:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:             %[[VAL_67:.*]] = arith.cmpi ult, %[[VAL_65]], %[[VAL_66]]
-// CHECK:             scf.if %[[VAL_67]] {
-// CHECK:               %[[VAL_68:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_41]]]
-// CHECK:               %[[VAL_69:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:               memref.store %[[VAL_69]], %[[VAL_2]]{{\[}}%[[VAL_41]]]
-// CHECK:               memref.store %[[VAL_68]], %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:               %[[VAL_70:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_41]]]
-// CHECK:               %[[VAL_71:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:               memref.store %[[VAL_71]], %[[VAL_3]]{{\[}}%[[VAL_41]]]
-// CHECK:               memref.store %[[VAL_70]], %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:               %[[VAL_72:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_41]]]
-// CHECK:               %[[VAL_73:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:               memref.store %[[VAL_73]], %[[VAL_4]]{{\[}}%[[VAL_41]]]
-// CHECK:               memref.store %[[VAL_72]], %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:               %[[VAL_74:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:               %[[VAL_75:.*]] = arith.cmpi ult, %[[VAL_74]], %[[VAL_74]]
-// CHECK:               scf.if %[[VAL_75]] {
-// CHECK:                 %[[VAL_76:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:                 memref.store %[[VAL_76]], %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:                 memref.store %[[VAL_76]], %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:                 %[[VAL_77:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:                 memref.store %[[VAL_77]], %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:                 memref.store %[[VAL_77]], %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:                 %[[VAL_78:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:                 memref.store %[[VAL_78]], %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:                 memref.store %[[VAL_78]], %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:                 %[[VAL_79:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:                 %[[VAL_80:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]]
-// CHECK:                 %[[VAL_81:.*]] = arith.cmpi ult, %[[VAL_79]], %[[VAL_80]]
-// CHECK:                 scf.if %[[VAL_81]] {
-// CHECK:                   %[[VAL_82:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:                   %[[VAL_83:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]]
-// CHECK:                   memref.store %[[VAL_83]], %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:                   memref.store %[[VAL_82]], %[[VAL_2]]{{\[}}%[[VAL_0]]]
-// CHECK:                   %[[VAL_84:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:                   %[[VAL_85:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_0]]]
-// CHECK:                   memref.store %[[VAL_85]], %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:                   memref.store %[[VAL_84]], %[[VAL_3]]{{\[}}%[[VAL_0]]]
-// CHECK:                   %[[VAL_86:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:                   %[[VAL_87:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_0]]]
-// CHECK:                   memref.store %[[VAL_87]], %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:                   memref.store %[[VAL_86]], %[[VAL_4]]{{\[}}%[[VAL_0]]]
-// CHECK:                 }
-// CHECK:               }
-// CHECK:             }
-// CHECK:             %[[VAL_88:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]]
-// CHECK:             %[[VAL_89:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_41]]]
-// CHECK:             %[[VAL_90:.*]] = arith.cmpi ult, %[[VAL_88]], %[[VAL_89]]
-// CHECK:             scf.if %[[VAL_90]] {
-// CHECK:               %[[VAL_91:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]]
-// CHECK:               %[[VAL_92:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_41]]]
-// CHECK:               memref.store %[[VAL_92]], %[[VAL_2]]{{\[}}%[[VAL_10]]]
-// CHECK:               memref.store %[[VAL_91]], %[[VAL_2]]{{\[}}%[[VAL_41]]]
-// CHECK:               %[[VAL_93:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_10]]]
-// CHECK:               %[[VAL_94:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_41]]]
-// CHECK:               memref.store %[[VAL_94]], %[[VAL_3]]{{\[}}%[[VAL_10]]]
-// CHECK:               memref.store %[[VAL_93]], %[[VAL_3]]{{\[}}%[[VAL_41]]]
-// CHECK:               %[[VAL_95:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_10]]]
-// CHECK:               %[[VAL_96:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_41]]]
-// CHECK:               memref.store %[[VAL_96]], %[[VAL_4]]{{\[}}%[[VAL_10]]]
-// CHECK:               memref.store %[[VAL_95]], %[[VAL_4]]{{\[}}%[[VAL_41]]]
-// CHECK:               %[[VAL_97:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_41]]]
-// CHECK:               %[[VAL_98:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:               %[[VAL_99:.*]] = arith.cmpi ult, %[[VAL_97]], %[[VAL_98]]
-// CHECK:               scf.if %[[VAL_99]] {
-// CHECK:                 %[[VAL_100:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_41]]]
-// CHECK:                 %[[VAL_101:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:                 memref.store %[[VAL_101]], %[[VAL_2]]{{\[}}%[[VAL_41]]]
-// CHECK:                 memref.store %[[VAL_100]], %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:                 %[[VAL_102:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_41]]]
-// CHECK:                 %[[VAL_103:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:                 memref.store %[[VAL_103]], %[[VAL_3]]{{\[}}%[[VAL_41]]]
-// CHECK:                 memref.store %[[VAL_102]], %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:                 %[[VAL_104:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_41]]]
-// CHECK:                 %[[VAL_105:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:                 memref.store %[[VAL_105]], %[[VAL_4]]{{\[}}%[[VAL_41]]]
-// CHECK:                 memref.store %[[VAL_104]], %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:                 %[[VAL_106:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:                 %[[VAL_107:.*]] = arith.cmpi ult, %[[VAL_106]], %[[VAL_106]]
-// CHECK:                 scf.if %[[VAL_107]] {
-// CHECK:                   %[[VAL_108:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:                   memref.store %[[VAL_108]], %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:                   memref.store %[[VAL_108]], %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:                   %[[VAL_109:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:                   memref.store %[[VAL_109]], %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:                   memref.store %[[VAL_109]], %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:                   %[[VAL_110:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:                   memref.store %[[VAL_110]], %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:                   memref.store %[[VAL_110]], %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:                   %[[VAL_111:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:                   %[[VAL_112:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]]
-// CHECK:                   %[[VAL_113:.*]] = arith.cmpi ult, %[[VAL_111]], %[[VAL_112]]
-// CHECK:                   scf.if %[[VAL_113]] {
-// CHECK:                     %[[VAL_114:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:                     %[[VAL_115:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]]
-// CHECK:                     memref.store %[[VAL_115]], %[[VAL_2]]{{\[}}%[[VAL_9]]]
-// CHECK:                     memref.store %[[VAL_114]], %[[VAL_2]]{{\[}}%[[VAL_0]]]
-// CHECK:                     %[[VAL_116:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:                     %[[VAL_117:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_0]]]
-// CHECK:                     memref.store %[[VAL_117]], %[[VAL_3]]{{\[}}%[[VAL_9]]]
-// CHECK:                     memref.store %[[VAL_116]], %[[VAL_3]]{{\[}}%[[VAL_0]]]
-// CHECK:                     %[[VAL_118:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:                     %[[VAL_119:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_0]]]
-// CHECK:                     memref.store %[[VAL_119]], %[[VAL_4]]{{\[}}%[[VAL_9]]]
-// CHECK:                     memref.store %[[VAL_118]], %[[VAL_4]]{{\[}}%[[VAL_0]]]
-// CHECK:                   }
-// CHECK:                 }
-// CHECK:               }
-// CHECK:             }
-// CHECK:           }
-// CHECK:           %[[VAL_120:.*]]:3 = scf.while (%[[VAL_121:.*]] = %[[VAL_0]], %[[VAL_122:.*]] = %[[VAL_10]], %[[VAL_123:.*]] = %[[VAL_9]])
-// CHECK:             %[[VAL_124:.*]] = arith.cmpi ult, %[[VAL_121]], %[[VAL_122]]
-// CHECK:             scf.condition(%[[VAL_124]]) %[[VAL_121]], %[[VAL_122]], %[[VAL_123]]
-// CHECK:           } do {
-// CHECK:           ^bb0(%[[VAL_125:.*]]: index, %[[VAL_126:.*]]: index, %[[VAL_127:.*]]: index)
-// CHECK:             %[[VAL_128:.*]] = scf.while (%[[VAL_129:.*]] = %[[VAL_125]])
-// CHECK:               %[[VAL_130:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_129]]]
-// CHECK:               %[[VAL_131:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_127]]]
-// CHECK:               %[[VAL_132:.*]] = arith.cmpi ult, %[[VAL_130]], %[[VAL_131]]
-// CHECK:               scf.condition(%[[VAL_132]]) %[[VAL_129]]
-// CHECK:             } do {
-// CHECK:             ^bb0(%[[VAL_133:.*]]: index):
-// CHECK:               %[[VAL_134:.*]] = arith.addi %[[VAL_133]], %[[VAL_5]]
-// CHECK:               scf.yield %[[VAL_134]]
-// CHECK:             }
-// CHECK:             %[[VAL_135:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_136:.*]]]
-// CHECK:             %[[VAL_137:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_127]]]
-// CHECK:             %[[VAL_138:.*]] = arith.cmpi eq, %[[VAL_135]], %[[VAL_137]]
-// CHECK:             %[[VAL_139:.*]] = scf.while (%[[VAL_140:.*]] = %[[VAL_126]])
-// CHECK:               %[[VAL_141:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_127]]]
-// CHECK:               %[[VAL_142:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_140]]]
-// CHECK:               %[[VAL_143:.*]] = arith.cmpi ult, %[[VAL_141]], %[[VAL_142]]
-// CHECK:               scf.condition(%[[VAL_143]]) %[[VAL_140]]
-// CHECK:             } do {
-// CHECK:             ^bb0(%[[VAL_144:.*]]: index):
-// CHECK:               %[[VAL_145:.*]] = arith.addi %[[VAL_144]], %[[VAL_7]]
-// CHECK:               scf.yield %[[VAL_145]]
-// CHECK:             }
-// CHECK:             %[[VAL_146:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_147:.*]]]
-// CHECK:             %[[VAL_149:.*]] = arith.cmpi eq, %[[VAL_146]], %[[VAL_137]]
-// CHECK:             %[[VAL_150:.*]] = arith.cmpi ult, %[[VAL_136]], %[[VAL_147]]
-// CHECK:             %[[VAL_151:.*]]:3 = scf.if %[[VAL_150]]
-// CHECK:               %[[VAL_152:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_136]]]
-// CHECK:               %[[VAL_153:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_147]]]
-// CHECK:               memref.store %[[VAL_153]], %[[VAL_2]]{{\[}}%[[VAL_136]]]
-// CHECK:               memref.store %[[VAL_152]], %[[VAL_2]]{{\[}}%[[VAL_147]]]
-// CHECK:               %[[VAL_154:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_136]]]
-// CHECK:               %[[VAL_155:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_147]]]
-// CHECK:               memref.store %[[VAL_155]], %[[VAL_3]]{{\[}}%[[VAL_136]]]
-// CHECK:               memref.store %[[VAL_154]], %[[VAL_3]]{{\[}}%[[VAL_147]]]
-// CHECK:               %[[VAL_156:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_136]]]
-// CHECK:               %[[VAL_157:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_147]]]
-// CHECK:               memref.store %[[VAL_157]], %[[VAL_4]]{{\[}}%[[VAL_136]]]
-// CHECK:               memref.store %[[VAL_156]], %[[VAL_4]]{{\[}}%[[VAL_147]]]
-// CHECK:               %[[VAL_158:.*]] = arith.cmpi eq, %[[VAL_136]], %[[VAL_127]]
-// CHECK:               %[[VAL_159:.*]] = scf.if %[[VAL_158]]
-// CHECK:                 scf.yield %[[VAL_147]]
-// CHECK:               } else {
-// CHECK:                 %[[VAL_160:.*]] = arith.cmpi eq, %[[VAL_147]], %[[VAL_127]]
-// CHECK:                 %[[VAL_161:.*]] = arith.select %[[VAL_160]], %[[VAL_136]], %[[VAL_127]]
-// CHECK:                 scf.yield %[[VAL_161]]
-// CHECK:               }
-// CHECK:               %[[VAL_162:.*]] = arith.andi %[[VAL_138]], %[[VAL_149]] : i1
-// CHECK:               %[[VAL_163:.*]]:2 = scf.if %[[VAL_162]]
-// CHECK:                 %[[VAL_164:.*]] = arith.addi %[[VAL_136]], %[[VAL_5]]
-// CHECK:                 %[[VAL_165:.*]] = arith.subi %[[VAL_147]], %[[VAL_5]]
-// CHECK:                 scf.yield %[[VAL_164]], %[[VAL_165]]
-// CHECK:               } else {
-// CHECK:                 scf.yield %[[VAL_136]], %[[VAL_147]]
-// CHECK:               }
-// CHECK:               scf.yield %[[VAL_166:.*]]#0, %[[VAL_166]]#1, %[[VAL_167:.*]]
-// CHECK:             } else {
-// CHECK:               scf.yield %[[VAL_136]], %[[VAL_147]], %[[VAL_127]]
-// CHECK:             }
-// CHECK:             scf.yield %[[VAL_168:.*]]#0, %[[VAL_168]]#1, %[[VAL_168]]#2
-// CHECK:           }
-// CHECK:           return %[[VAL_169:.*]]#2
-// CHECK:         }
-
-// CHECK-LABEL:   func.func private @_sparse_qsort_1_i8_f32_index(
-// CHECK-SAME:                                                   %[[L:arg0]]: index,
-// CHECK-SAME:                                                   %[[H:.*]]: index,
-// CHECK-SAME:                                                   %[[X0:.*]]: memref<?xi8>,
-// CHECK-SAME:                                                   %[[Y0:.*]]: memref<?xf32>,
-// CHECK-SAME:                                                   %[[Y1:.*]]: memref<?xindex>) {
-// CHECK:           %[[C1:.*]] = arith.constant 1
-// CHECK:           scf.while (%[[L2:.*]] = %[[L]], %[[H2:.*]] = %[[H]])
-// CHECK:             %[[Lb:.*]] = arith.addi %[[L2]], %[[C1]]
-// CHECK:             %[[COND:.*]] = arith.cmpi ult, %[[Lb]], %[[H2]]
-// CHECK:             scf.condition(%[[COND]]) %[[L2]], %[[H2]]
-// CHECK:           } do {
-// CHECK:           ^bb0(%[[L3:.*]]: index, %[[H3:.*]]: index)
-// CHECK:             %[[P:.*]] = func.call @_sparse_partition_1_i8_f32_index(%[[L3]], %[[H3]], %[[X0]], %[[Y0]], %[[Y1]])
-// CHECK:             %[[PP1:.*]] = arith.addi %[[P]], %[[C1]] : index
-// CHECK:             %[[LenL:.*]] = arith.subi %[[P]], %[[L3]]
-// CHECK:             %[[LenH:.*]] = arith.subi %[[H3]], %[[P]]
-// CHECK:             %[[Cmp:.*]] = arith.cmpi ule, %[[LenL]], %[[LenH]]
-// CHECK:             %[[L4:.*]] = arith.select %[[Cmp]], %[[PP1]], %[[L3]]
-// CHECK:             %[[H4:.*]] = arith.select %[[Cmp]], %[[H3]], %[[P]]
-// CHECK:             scf.if %[[Cmp]]
-// CHECK:               func.call @_sparse_qsort_1_i8_f32_index(%[[L3]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
-// CHECK:             else
-// CHECK:               func.call @_sparse_qsort_1_i8_f32_index(%[[PP1]], %[[H3]], %[[X0]], %[[Y0]], %[[Y1]])
-// CHECK:             scf.yield %[[L4]], %[[H4]]
-// CHECK:           }
-// CHECK:           return
-// CHECK:         }
-
-// CHECK-LABEL:   func.func @sparse_sort_1d2v_quick(
-// CHECK-SAME:                                %[[N:.*]]: index,
-// CHECK-SAME:                                %[[X0:.*]]: memref<10xi8>,
-// CHECK-SAME:                                %[[Y0:.*]]: memref<?xf32>,
-// CHECK-SAME:                                %[[Y1:.*]]: memref<10xindex>) -> (memref<10xi8>, memref<?xf32>, memref<10xindex>) {
-// CHECK:           %[[C0:.*]] = arith.constant 0
-// CHECK:           %[[DX0:.*]] = memref.cast %[[X0]] : memref<10xi8> to memref<?xi8>
-// CHECK:           %[[DY1:.*]] = memref.cast %[[Y1]] : memref<10xindex> to memref<?xindex>
-// CHECK:           call @_sparse_qsort_1_i8_f32_index(%[[C0]], %[[N]], %[[DX0]], %[[Y0]], %[[DY1]])
-// CHECK:           return %[[X0]], %[[Y0]], %[[Y1]]
-// CHECK:         }
+// CHECK-LABEL:   func.func private @_sparse_partition_1_i8_f32_index
+// CHECK-LABEL:   func.func private @_sparse_qsort_1_i8_f32_index
+// CHECK-LABEL:   func.func @sparse_sort_1d2v_quick
 func.func @sparse_sort_1d2v_quick(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
    -> (memref<10xi8>, memref<?xf32>, memref<10xindex>) {
   sparse_tensor.sort quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
index c3bdc30e355b1d8..ca5dd00d02aff1e 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
@@ -94,14 +94,14 @@ module {
     %y1 = memref.cast %y1s : memref<7xi32> to memref<?xi32>
 
     // Sort "parallel arrays".
-    // CHECK: ( 1, 1, 2, 5, 10 )
-    // CHECK: ( 3, 3, 1, 10, 1 )
-    // CHECK: ( 9, 9, 4, 7, 2 )
-    // CHECK: ( 7, 8, 10, 9, 6 )
-    // CHECK: ( 7, 4, 7, 9, 5 )
-    call @storeValuesToStrided(%x0, %c10, %c2, %c1, %c5, %c1)
+    // CHECK: ( 1, 1, 3, 3, 10 )
+    // CHECK: ( 2, 10, 1, 1, 5 )
+    // CHECK: ( 4, 2, 9, 9, 7 )
+    // CHECK: ( 10, 6, 7, 8, 9 )
+    // CHECK: ( 7, 5, 7, 4, 9 )
+    call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesToStrided(%x1, %c1, %c1, %c3, %c10, %c3)
+    call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesToStrided(%x2, %c2, %c4, %c9, %c7, %c9)
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
@@ -122,14 +122,14 @@ module {
     %y1v = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
     vector.print %y1v : vector<5xi32>
     // Stable sort.
-    // CHECK: ( 1, 1, 2, 5, 10 )
-    // CHECK: ( 3, 3, 1, 10, 1 )
-    // CHECK: ( 9, 9, 4, 7, 2 )
-    // CHECK: ( 8, 7, 10, 9, 6 )
-    // CHECK: ( 4, 7, 7, 9, 5 )
-    call @storeValuesToStrided(%x0, %c10, %c2, %c1, %c5, %c1)
+    // CHECK: ( 1, 1, 3, 3, 10 )
+    // CHECK: ( 2, 10, 1, 1, 5 )
+    // CHECK: ( 4, 2, 9, 9, 7 )
+    // CHECK: ( 10, 6, 8, 7, 9 )
+    // CHECK: ( 7, 5, 4, 7, 9 )
+    call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesToStrided(%x1, %c1, %c1, %c3, %c10, %c3)
+    call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesToStrided(%x2, %c2, %c4, %c9, %c7, %c9)
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
@@ -150,14 +150,14 @@ module {
     %y1v2 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
     vector.print %y1v2 : vector<5xi32>
     // Heap sort.
-    // CHECK: ( 1, 1, 2, 5, 10 )
-    // CHECK: ( 3, 3, 1, 10, 1 )
-    // CHECK: ( 9, 9, 4, 7, 2 )
-    // CHECK: ( 7, 8, 10, 9, 6 )
-    // CHECK: ( 7, 4, 7, 9, 5 )
-    call @storeValuesToStrided(%x0, %c10, %c2, %c1, %c5, %c1)
+    // CHECK: ( 1, 1, 3, 3, 10 )
+    // CHECK: ( 2, 10, 1, 1, 5 )
+    // CHECK: ( 4, 2, 9, 9, 7 )
+    // CHECK: ( 10, 6, 8, 7, 9 )
+    // CHECK: ( 7, 5, 4, 7, 9 )
+    call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesToStrided(%x1, %c1, %c1, %c3, %c10, %c3)
+    call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesToStrided(%x2, %c2, %c4, %c9, %c7, %c9)
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()

>From 3a3c197a1dedbfe3d8b909f0ebe764b7e872b17a Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 18 Sep 2023 19:35:11 +0000
Subject: [PATCH 2/2] cleanup

---
 .../SparseTensor/Transforms/SparseBufferRewriting.cpp       | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 2935cdad372d3d3..fdbbfed82b0eb2b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -608,9 +608,9 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
   Value i = lo;
   Value j = builder.create<arith::SubIOp>(loc, hi, c1);
   createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args);
-  Value cont = constantI1(builder, loc, true);
-  SmallVector<Value, 3> operands{i, j, p, cont}; // Exactly three values.
-  SmallVector<Type, 3> types{i.getType(), j.getType(), p.getType(),
+  Value cont = constantI1(builder, loc, true);   // The value for while (true)
+  SmallVector<Value, 4> operands{i, j, p, cont}; // Exactly four values.
+  SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
                              cont.getType()};
   scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
 



More information about the Mlir-commits mailing list