[Mlir-commits] [mlir] [mlir][sparse] unifies sparse_tensor.sort_coo/sort into one operation. (PR #66722)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 18 17:09:33 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

<details>
<summary>Changes</summary>

The functionality of the two operations largely overlaps, so let's simplify and keep only one of them.

---

Patch is 91.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66722.diff


14 Files Affected:

- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+13-62) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+10-35) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp (+142-172) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+7-5) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp (-1) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+16-22) 
- (modified) mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir (+23-82) 
- (modified) mlir/test/Dialect/SparseTensor/codegen.mlir (+3-3) 
- (modified) mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/invalid.mlir (+19-28) 
- (modified) mlir/test/Dialect/SparseTensor/roundtrip.mlir (+7-57) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir (+1-1) 
- (removed) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir (-187) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir (+27-24) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 94301dbcd9f7b42..59815fc755ee5f3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -762,81 +762,32 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
 // Sparse Tensor Sorting Operations.
 //===----------------------------------------------------------------------===//
 
-def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
-    Arguments<(ins Index:$n,
-               Variadic<StridedMemRefRankOf<[AnyInteger, Index], [1]>>:$xs,
-               Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
-               SparseTensorSortKindAttr:$algorithm)>  {
-  string summary = "Sorts the arrays in xs and ys lexicographically on the "
-                   "integral values found in the xs list";
-  string description = [{
-    Lexicographically sort the first `n` values in `xs` along with the values in
-    `ys`. Conceptually, the values being sorted are tuples produced by
-    `zip(zip(xs), zip(ys))`. In particular, values in `ys` needed to be sorted
-    along with values in `xs`, but values in `ys` don't affect the
-    lexicographical order. The order in which arrays appear in `xs` affects the
-    sorting result. The operator updates `xs` and `ys` in place with the result
-    of the sorting.
-
-    For example, assume x1=[4, 3], x2=[1, 2], y1=[10, 5], then the output of
-    "sort 2, x1, x2 jointly y1" are x1=[3, 4], x2=[2, 1], y1=[5, 10] while the
-    output of "sort 2, x2, x1, jointly y1" are x2=[1, 2], x1=[4, 3], y1=[10, 5].
-
-    Buffers in `xs` needs to have the same integral element type while buffers
-    in `ys` can have different numeric element types. All buffers in `xs` and
-    `ys` should have a dimension not less than `n`. The behavior of the operator
-    is undefined if this condition is not met. The operator requires at least
-    one buffer in `xs` while `ys` can be empty.
-
-    The enum attribute `algorithm` indicates the sorting algorithm used to
-    implement the operator: hybrid_quick_sort, insertion_sort_stable,
-    quick_sort, or heap_sort.
-
-    Note that this operation is "impure" in the sense that its behavior is
-    solely defined by side-effects and not SSA values.
-
-    Example:
-
-    ```mlir
-    sparse_tensor.sort insertion_sort_stable %n, %x1, %x2 jointly y1, %y2
-      : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
-    ```
-
-    ```mlir
-    sparse_tensor.sort hybrid_quick_sort %n, %x1, %x2 jointly y1, %y2
-      { alg=1 : index}
-      : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
-    ```
-  }];
-  let assemblyFormat = "$algorithm $n `,` $xs (`jointly` $ys^)? attr-dict"
-                       "`:` type($xs) (`jointly` type($ys)^)?";
-  let hasVerifier = 1;
-}
-
 def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
     Arguments<(ins Index:$n, StridedMemRefRankOf<[AnyInteger, Index], [1]>:$xy,
                Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
-               OptionalAttr<IndexAttr>:$nx, OptionalAttr<IndexAttr>:$ny,
+               AffineMapAttr:$perm_map, OptionalAttr<IndexAttr>:$ny,
                SparseTensorSortKindAttr:$algorithm)>  {
   let summary = "Sorts the arrays in xs and ys lexicographically on the "
                 "integral values found in the xs list";
   let description = [{
-    Sparse_tensor.sort_coo is similar to sparse_tensor.sort, except that all the
-    `xs` values and some `ys` values are put in the linear buffer `xy`. The
-    optional index attribute `nx` provides the number of `xs` values in `xy`.
-    When `nx` is not explicitly specified, its value is 1. The optional index
-    attribute `ny` provides the number of `ys` values in `xy`. When `ny` is not
-    explicitly specified, its value is 0. This instruction supports a more
-    efficient way to store the COO definition in sparse tensor type.
-
-    The buffer xy should have a dimension not less than n * (nx + ny) while the
+    Sparse_tensor.sort_coo sorts the `xs` values along with some `ys` values
+    that are put in a single linear buffer `xy`.
+    The affine map attribute `perm_map` specifies the permutation to be applied to
+    the `xs` before comparison; the rank of the permutation map
+    also specifies the number of `xs` values in `xy`.
+    The optional index attribute `ny` provides the number of `ys` values in `xy`.
+    When `ny` is not explicitly specified, its value is 0.
+    This instruction supports a more efficient way to store the COO definition
+    in sparse tensor type.
+
+    The buffer xy should have a dimension not less than n * (rank(perm_map) + ny) while the
     buffers in `ys` should have a dimension not less than `n`. The behavior of
     the operator is undefined if this condition is not met.
 
     Example:
 
     ```mlir
-    sparse_tensor.sort_coo insertion_sort_stable %n, %x { nx = 2 : index}
+    sparse_tensor.sort_coo insertion_sort_stable %n, %x { perm_map = affine_map<(i,j) -> (j,i)> }
       : memref<?xindex>
     ```
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index e71d2a8dd623a6a..9675a61109477b5 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1353,35 +1353,15 @@ LogicalResult SelectOp::verify() {
   return success();
 }
 
-LogicalResult SortOp::verify() {
-  if (getXs().empty())
-    return emitError("need at least one xs buffer.");
-
-  std::optional<int64_t> n = getConstantIntValue(getN());
-
-  Type xtp = getMemRefType(getXs().front()).getElementType();
-  auto checkTypes = [&](ValueRange operands,
-                        bool checkEleType = true) -> LogicalResult {
-    for (Value opnd : operands) {
-      auto mtp = getMemRefType(opnd);
-      const DynSize sh = mtp.getShape()[0];
-      // We can't check the size of dynamic dimension at compile-time, but all
-      // xs and ys should have a dimension not less than n at runtime.
-      if (n && !ShapedType::isDynamic(sh) && sh < n.value())
-        return emitError(llvm::formatv("xs and ys need to have a dimension >= n"
-                                       ": {0} < {1}",
-                                       sh, n.value()));
-
-      if (checkEleType && xtp != mtp.getElementType())
-        return emitError("mismatch xs element types");
-    }
-    return success();
-  };
-  RETURN_FAILURE_IF_FAILED(checkTypes(getXs()))
-  return n ? checkTypes(getYs(), false) : success();
-}
-
 LogicalResult SortCooOp::verify() {
+  AffineMap xPerm = getPermMap();
+  uint64_t nx = xPerm.getNumDims();
+  if (nx < 1)
+    emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
+
+  if (!xPerm.isPermutation())
+    emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
+
   std::optional<int64_t> cn = getConstantIntValue(getN());
   // We can't check the size of the buffers when n or buffer dimensions aren't
   // compile-time constants.
@@ -1389,12 +1369,6 @@ LogicalResult SortCooOp::verify() {
     return success();
 
   uint64_t n = cn.value();
-  uint64_t nx = 1;
-  if (auto nxAttr = getNxAttr()) {
-    nx = nxAttr.getInt();
-    if (nx < 1)
-      emitError(llvm::formatv("Expected nx > 1, got {0}", nx));
-  }
   uint64_t ny = 0;
   if (auto nyAttr = getNyAttr()) {
     ny = nyAttr.getInt();
@@ -1409,7 +1383,8 @@ LogicalResult SortCooOp::verify() {
       emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
   };
 
-  checkDim(getXy(), n * (nx + ny), "Expected dimension(xy) >= n * (nx + ny)");
+  checkDim(getXy(), n * (nx + ny),
+           "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
 
   for (Value opnd : getYs()) {
     checkDim(opnd, n, "Expected dimension(y) >= n");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 029ecb0708941fe..3181395a474cfec 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -45,46 +45,43 @@ static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
 static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
 static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
 
-using FuncGeneratorType = function_ref<void(
-    OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;
+using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
+                                            AffineMap, uint64_t, uint32_t)>;
 
 /// Constructs a function name with this format to facilitate quick sort:
-///   <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
-///   <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
+///   <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort
+///   <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
-                                         StringRef namePrefix, uint64_t nx,
-                                         uint64_t ny, bool isCoo,
-                                         ValueRange operands) {
-  nameOstream << namePrefix << nx << "_"
-              << getMemRefType(operands[xStartIdx]).getElementType();
+                                         StringRef namePrefix, AffineMap xPerm,
+                                         uint64_t ny, ValueRange operands) {
+  nameOstream << namePrefix;
+  for (auto res : xPerm.getResults())
+    nameOstream << res.cast<AffineDimExpr>().getPosition() << "_";
 
-  if (isCoo)
-    nameOstream << "_coo_" << ny;
+  nameOstream << getMemRefType(operands[xStartIdx]).getElementType();
+  nameOstream << "_coo_" << ny;
 
-  uint64_t yBufferOffset = isCoo ? 1 : nx;
+  constexpr uint64_t yBufferOffset = 1;
   for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
     nameOstream << "_" << getMemRefType(v).getElementType();
 }
 
 /// Looks up a function that is appropriate for the given operands being
 /// sorted, and creates such a function if it doesn't exist yet. The
-/// parameters `nx` and `ny` tell the number of x and y values provided
-/// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
-/// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
+/// parameters `xPerm` and `ny` tell the number of x and y values provided
+/// by the buffer in xStartIdx.
 //
 // All sorting function generators take (lo, hi, xs, ys) in `operands` as
 // parameters for the sorting functions. Other parameters, such as the recursive
 // call depth, are appended to the end of the parameter list as
 // "trailing parameters".
-static FlatSymbolRefAttr
-getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
-                         TypeRange resultTypes, StringRef namePrefix,
-                         uint64_t nx, uint64_t ny, bool isCoo,
-                         ValueRange operands, FuncGeneratorType createFunc,
-                         uint32_t nTrailingP = 0) {
+static FlatSymbolRefAttr getMangledSortHelperFunc(
+    OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes,
+    StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands,
+    FuncGeneratorType createFunc, uint32_t nTrailingP = 0) {
   SmallString<32> nameBuffer;
   llvm::raw_svector_ostream nameOstream(nameBuffer);
-  getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
+  getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny,
                                operands.drop_back(nTrailingP));
 
   ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
@@ -101,7 +98,7 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
         loc, nameOstream.str(),
         FunctionType::get(context, operands.getTypes(), resultTypes));
     func.setPrivate();
-    createFunc(builder, module, func, nx, ny, isCoo, nTrailingP);
+    createFunc(builder, module, func, xPerm, ny, nTrailingP);
   }
 
   return result;
@@ -110,27 +107,19 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
 /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
 /// The code to process the value pairs is generated by `bodyBuilder`.
 static void forEachIJPairInXs(
-    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
-    bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
-  Value iOffset, jOffset;
-  if (isCoo) {
-    Value cstep = constantIndex(builder, loc, nx + ny);
-    iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
-    jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
-  }
-  for (uint64_t k = 0; k < nx; k++) {
-    scf::IfOp ifOp;
-    Value i, j, buffer;
-    if (isCoo) {
-      Value ck = constantIndex(builder, loc, k);
-      i = builder.create<arith::AddIOp>(loc, ck, iOffset);
-      j = builder.create<arith::AddIOp>(loc, ck, jOffset);
-      buffer = args[xStartIdx];
-    } else {
-      i = args[0];
-      j = args[1];
-      buffer = args[xStartIdx + k];
-    }
+    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
+    uint64_t ny,
+    function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
+  Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
+  Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
+  Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
+  for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) {
+    unsigned actualK = xPerm.getResult(k).cast<AffineDimExpr>().getPosition();
+    Value ak = constantIndex(builder, loc, actualK);
+    Value i = builder.create<arith::AddIOp>(loc, ak, iOffset);
+    Value j = builder.create<arith::AddIOp>(loc, ak, jOffset);
+    Value buffer = args[xStartIdx];
+
     bodyBuilder(k, i, j, buffer);
   }
 }
@@ -138,21 +127,28 @@ static void forEachIJPairInXs(
 /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
 /// The code to process the value pairs is generated by `bodyBuilder`.
 static void forEachIJPairInAllBuffers(
-    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
-    bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
-
-  // Create code for the first (nx + ny) buffers. When isCoo==true, these
-  // logical buffers are all from the xy buffer of the sort_coo operator.
-  forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder);
+    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
+    uint64_t ny,
+    function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
+
+  // Create code for the first (xPerm + ny) buffers.
+  SmallVector<AffineExpr> exps(xPerm.getResults().begin(),
+                               xPerm.getResults().end());
+  for (unsigned y = 0; y < ny; y++) {
+    exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults()));
+  }
+  AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext());
+  assert(xyPerm.isPermutation());
 
-  uint64_t numHandledBuffers = isCoo ? 1 : nx + ny;
+  forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder);
 
+  constexpr uint64_t numHandledBuffers = 1;
   // Create code for the remaining buffers.
   Value i = args[0];
   Value j = args[1];
   for (const auto &arg :
        llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
-    bodyBuilder(arg.index() + nx + ny, i, j, arg.value());
+    bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value());
   }
 }
 
@@ -168,7 +164,7 @@ static void forEachIJPairInAllBuffers(
 //     ...
 //     swap(yn[i], yn[j]);
 static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
-                       uint64_t nx, uint64_t ny, bool isCoo) {
+                       AffineMap xPerm, uint64_t ny) {
   auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
     Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
     Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
@@ -176,20 +172,20 @@ static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
     builder.create<memref::StoreOp>(loc, vi, buffer, j);
   };
 
-  forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair);
+  forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair);
 }
 
 /// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
 /// each pair is create via `compareBuilder`.
 static Value createInlinedCompareImplementation(
-    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
-    bool isCoo,
+    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
+    uint64_t ny,
     function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
         compareBuilder) {
   Value result;
   auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
     bool isFirstDim = (k == 0);
-    bool isLastDim = (k == nx - 1);
+    bool isLastDim = (k == xPerm.getNumResults() - 1);
     Value val =
         compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
     if (isFirstDim) {
@@ -202,7 +198,7 @@ static Value createInlinedCompareImplementation(
     }
   };
 
-  forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder);
+  forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder);
 
   builder.setInsertionPointAfterValue(result);
   return result;
@@ -252,12 +248,12 @@ static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
 //     else if (x2[2] != x2[j]))
 //       and so on ...
 static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
-                                    ValueRange args, uint64_t nx, uint64_t ny,
-                                    bool isCoo, uint32_t nTrailingP = 0) {
+                                    ValueRange args, AffineMap xPerm,
+                                    uint64_t ny, uint32_t nTrailingP = 0) {
   // Compare functions don't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
-  return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
+  return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
                                             createEqCompare);
 }
 
@@ -306,12 +302,12 @@ static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
 //   else
 //       and so on ...
 static Value createInlinedLessThan(OpBuilder &builder, Location loc,
-                                   ValueRange args, uint64_t nx, uint64_t ny,
-                                   bool isCoo, uint32_t nTrailingP = 0) {
+                                   ValueRange args, AffineMap xPerm,
+                                   uint64_t ny, uint32_t nTrailingP = 0) {
   // Compare functions don't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
-  return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
+  return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
                                             createLessThanCompare);
 }
 
@@ -329,8 +325,8 @@ static Value createInlinedLessThan(OpBuilder &builder, Location loc,
 //   return lo;
 //
 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
-                                   func::FuncOp func, uint64_t nx, uint64_t ny,
-                                   bool isCoo, uint32_t nTrailingP = 0) {
+                                   func::FuncOp func, AffineMap xPerm,
+                                   uint64_t ny, uint32_t nTrailingP = 0) {
   // Binary search doesn't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
@@ -368,11 +364,10 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
 
   // Compare xs[p] < xs[mid].
   SmallVector<Value> compareOperands{p, mid};
-  uint64_t numXBuffers = isCoo ? 1 : nx;
+  constexpr uint64_t numXBuffers = 1;
   compareOp...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/66722


More information about the Mlir-commits mailing list