[llvm] [mlir][sparse] implementing stageSparseOpPass as an interface (PR #69022)

Peiming Liu via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 17 10:50:38 PDT 2023


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/69022

>From 77aa8c2cf261ad344f75cc5591cc1f394714b269 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 12 Oct 2023 22:58:50 +0000
Subject: [PATCH 1/8] [mlir][sparse] implementing stage convertOp as an
 interface

---
 .../Dialect/SparseTensor/IR/CMakeLists.txt    |   6 +
 .../Dialect/SparseTensor/IR/SparseTensor.h    |   1 +
 .../SparseTensor/IR/SparseTensorInterfaces.h  |  30 +++
 .../SparseTensor/IR/SparseTensorInterfaces.td |  47 ++++
 .../SparseTensor/IR/SparseTensorOps.td        |  18 +-
 .../Dialect/SparseTensor/IR/CMakeLists.txt    |   1 +
 .../SparseTensor/IR/SparseTensorDialect.cpp   |  38 ++-
 .../IR/SparseTensorInterfaces.cpp             |  57 +++++
 .../Transforms/SparseTensorRewriting.cpp      | 223 +++++++-----------
 .../Transforms/StageSparseOperations.cpp      |  53 +----
 10 files changed, 275 insertions(+), 199 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
 create mode 100644 mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
 create mode 100644 mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt
index 25a2e4869cc7824..54ad9491cce512c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt
@@ -12,3 +12,9 @@ set(LLVM_TARGET_DEFINITIONS SparseTensorTypes.td)
 mlir_tablegen(SparseTensorTypes.h.inc -gen-typedef-decls)
 mlir_tablegen(SparseTensorTypes.cpp.inc -gen-typedef-defs)
 add_public_tablegen_target(MLIRSparseTensorTypesIncGen)
+
+set(LLVM_TARGET_DEFINITIONS SparseTensorInterfaces.td)
+mlir_tablegen(SparseTensorInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(SparseTensorInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRSparseTensorInterfacesIncGen)
+add_dependencies(mlir-headers MLIRSparseTensorInterfacesIncGen)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 3eb9ce010cb006f..cbca0a7f8cc0e3a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
new file mode 100644
index 000000000000000..f75e02266578495
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
@@ -0,0 +1,30 @@
+//===- SparseTensorInterfaces.h - sparse tensor operations interfaces------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
+#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+class PatternRewriter;
+
+namespace sparse_tensor {
+class StageWithSortSparseOp;
+
+namespace detail {
+LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op,
+                                PatternRewriter &rewriter);
+} // namespace detail
+} // namespace sparse_tensor
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h.inc"
+
+#endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
new file mode 100644
index 000000000000000..29dc946227f5075
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
@@ -0,0 +1,47 @@
+//===- SparseTensorInterfaces.td --------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SPARSETENSOR_IR_SPARSETENSORINTERFACES
+#define SPARSETENSOR_IR_SPARSETENSORINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+// The 'StageWithSortSparseOpInterface' is implemented by sparse tensor
+// operations that may need an extra sort to produce their final result.
+def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> {
+  let description = [{
+    A stage-with-sort sparse tensor operation is an operation that produces
+    unordered intermediate output. An extra sort is required to obtain the final
+    ordered result.
+
+    E.g., convert csr -> csc needs to be implemented as
+          convert csr -> unordered coo -> sort by column -> csc; and
+          concatenate csr, csc -> csr can be staged into
+          concatenate csr, csc -> unordered coo -> sort by row -> csr.
+  }];
+  let cppNamespace = "::mlir::sparse_tensor";
+  let methods = [
+    InterfaceMethod<
+    /*desc=*/"Return true if the operation needs an extra sort to produce the final result.",
+    /*retTy=*/"bool",
+    /*methodName=*/"needExtraSort",
+    /*args=*/(ins),
+    /*methodBody=*/"">,
+    InterfaceMethod<
+    /*desc=*/"Stage the operation, return success iff the operation has been staged.",
+    /*retTy=*/"::mlir::LogicalResult",
+    /*methodName=*/"stageWithSort",
+    /*args=*/(ins "::mlir::PatternRewriter &":$rewriter),
+    /*methodBody=*/[{
+        return detail::stageWithSortImpl($_op, rewriter);
+    }]>,
+  ];
+}
+
+
+#endif // SPARSETENSOR_IR_SPARSETENSORINTERFACES
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 9016634fa3be8dd..a1493c6aebee2b3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -12,6 +12,7 @@
 include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td"
 include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td"
 include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
+include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
@@ -153,7 +154,7 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
 }
 
 def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
-  [Pure]>,
+  [Pure, StageWithSortSparseOpInterface]>,
     Arguments<(ins AnyTensor:$source)>,
     Results<(outs AnyTensor:$dest)> {
   string summary = "Converts between different tensor types";
@@ -197,9 +198,9 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
   }];
 
   let extraClassDeclaration = [{
-     // Whether the convert can be done by a single step (either a sort or a foreach),
-     // or it would require a tmp buffer (sort, then foreach).
-     bool directConvertable();
+     // Whether the convert can be done by a single step or it would require
+     // an extra sort. Inherited from StageWithSortSparseOpInterface.
+     bool needExtraSort();
   }];
 
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
@@ -334,7 +335,8 @@ def SparseTensor_NumberOfEntriesOp : SparseTensor_Op<"number_of_entries", [Pure]
   let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
 }
 
-def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure]>,
+def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate",
+                                 [Pure, StageWithSortSparseOpInterface]>,
     Arguments<(ins Variadic<AnyRankedTensor>:$inputs, DimensionAttr:$dimension)>,
     Results<(outs AnyRankedTensor:$result)> {
 
@@ -357,6 +359,12 @@ def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure]>,
      ```
    }];
 
+  let extraClassDeclaration = [{
+     // Whether the concatenate can be done by a single step or it would require
+     // an extra sort. Inherited from StageWithSortSparseOpInterface.
+     bool needExtraSort();
+  }];
+
   let assemblyFormat = "$inputs attr-dict `:` type($inputs) `to` type($result)";
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
index b22194d45062acc..dd6f1037f71b53f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
@@ -29,6 +29,7 @@ endif()
 
 add_mlir_dialect_library(MLIRSparseTensorDialect
   SparseTensorDialect.cpp
+  SparseTensorInterfaces.cpp
   Detail/Var.cpp
   Detail/DimLvlMap.cpp
   Detail/LvlTypeParser.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 61522fb0dcd24b5..cc7ed639cbde66c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1065,18 +1065,18 @@ OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-bool ConvertOp::directConvertable() {
+bool ConvertOp::needExtraSort() {
   SparseTensorType srcStt = getSparseTensorType(getSource());
   SparseTensorType dstStt = getSparseTensorType(getDest());
 
-  // We can always directly convert to unordered sparse tensor or dense tensor
-  // since dense tensor support random access.
+  // We do not need an extra sort when returning unordered sparse tensors or
+  // dense tensors, since dense tensors support random access.
   if (dstStt.isAllDense() || !dstStt.isAllOrdered())
-    return true;
+    return false;
 
   if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
       srcStt.hasSameDimToLvl(dstStt)) {
-    return true;
+    return false;
   }
 
   // Source and dest tensors are ordered in different ways. We only do direct
@@ -1086,9 +1086,9 @@ bool ConvertOp::directConvertable() {
   // performance.
   if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
     if (isa<SparseElementsAttr>(constOp.getValue()))
-      return true;
+      return false;
 
-  return false;
+  return true;
 }
 
 LogicalResult ToPositionsOp::verify() {
@@ -1248,6 +1248,23 @@ LogicalResult UnaryOp::verify() {
   return success();
 }
 
+bool ConcatenateOp::needExtraSort() {
+  SparseTensorType dstStt = getSparseTensorType(*this);
+  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
+    return false;
+
+  bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
+    return getSparseTensorType(op).hasSameDimToLvl(dstStt);
+  });
+  // TODO: When conDim != 0, as long as conDim corresponds to the first level
+  // in all input/output buffers, and all input/output buffers have the same
+  // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenate
+  // CSC matrices along column).
+  bool directLowerable =
+      allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
+  return !directLowerable;
+}
+
 LogicalResult ConcatenateOp::verify() {
   const auto dstTp = getSparseTensorType(*this);
   const Dimension concatDim = getDimension();
@@ -1287,9 +1304,10 @@ LogicalResult ConcatenateOp::verify() {
         // If all dimension are statically known, the sum of all the input
         // dimensions should be equal to the output dimension.
         if (sumSz != dstSh)
-          return emitError(
-              "The concatenation dimension of the output tensor should be the "
-              "sum of all the concatenation dimensions of the input tensors.");
+          return emitError("The concatenation dimension of the output tensor "
+                           "should be the "
+                           "sum of all the concatenation dimensions of the "
+                           "input tensors.");
       }
     } else {
       DynSize prev = dstSh;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
new file mode 100644
index 000000000000000..898eff26f5477f8
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -0,0 +1,57 @@
+//===- SparseTensorInterfaces.cpp - SparseTensor interfaces impl ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc"
+
+LogicalResult
+sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
+                                         PatternRewriter &rewriter) {
+  // Nothing to do if the operation does not need an extra sort; report
+  // failure so that the caller knows the operation was left unchanged.
+  if (!op.needExtraSort())
+    return failure();
+
+  Location loc = op.getLoc();
+  Type finalTp = op->getOpResult(0).getType();
+  SparseTensorType dstStt(finalTp.cast<RankedTensorType>());
+
+  Type srcCOOTp = getCOOFromTypeWithOrdering(
+      dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);
+
+  // Clones the original operation but changing the output to an unordered COO.
+  Operation *cloned = rewriter.clone(*op.getOperation());
+  rewriter.updateRootInPlace(cloned, [cloned, srcCOOTp]() {
+    cloned->getOpResult(0).setType(srcCOOTp);
+  });
+  Value srcCOO = cloned->getOpResult(0);
+
+  // -> sort
+  Type dstCOOTp = getCOOFromTypeWithOrdering(
+      dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
+  Value dstCOO = rewriter.create<ReorderCOOOp>(
+      loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);
+
+  // -> dest.
+  if (dstCOO.getType() == finalTp) {
+    rewriter.replaceOp(op, dstCOO);
+  } else {
+    // Need an extra conversion if the target type is not COO.
+    rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO);
+  }
+  // TODO: deallocate extra COOs, we should probably delegate it to buffer
+  // deallocation pass.
+  return success();
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index a1ab2495f5f7b5e..fe2c333a7062705 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -829,10 +829,56 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
   }
 };
 
+struct TensorLike {
+  TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
+             ValueRange sizes)
+      : isSparse(rtt.getEncoding() != nullptr) {
+    SmallVector<Value> dynSzs;
+    getDynamicSizes(rtt, sizes, dynSzs);
+
+    if (isSparse)
+      val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
+    else
+      val = allocDenseTensor(builder, loc, rtt, sizes);
+  };
+
+  void insertOrStore(OpBuilder &builder, Location loc, Value v,
+                     ValueRange crds) {
+    if (isSparse)
+      val = builder.create<InsertOp>(loc, v, val, crds);
+    else
+      builder.create<memref::StoreOp>(loc, v, val, crds);
+  }
+
+  Value getSSA() const {
+    // We don't need to maintain the SSA chain for a memref value.
+    return isSparse ? val : nullptr;
+  }
+
+  Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
+    if (isSparse)
+      return builder.create<LoadOp>(loc, val, true);
+    return builder.create<bufferization::ToTensorOp>(loc, rtp, val);
+  }
+
+  void updateSSA(Value v) {
+    // Dense memref is a non-SSA value.
+    assert(isSparse);
+    val = v;
+  }
+
+private:
+  bool isSparse;
+  Value val; // either a memref (for dense tensor) or a sparse tensor.
+};
+
 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(ConcatenateOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.needExtraSort())
+      op.emitError("ConcatenateOp not staged");
+
     const Location loc = op.getLoc();
     const auto dstTp = getSparseTensorType(op);
     const Dimension dimRank = dstTp.getDimRank();
@@ -852,94 +898,54 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
     // foreach in %s1 : insert d0, d1, %tmp
     // foreach in %s2 : insert d0, d1 + size(s1), %tmp
     // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp
-    // %t = convert_to_dest_tensor(%tmp)
-    //
-    // NOTE: this cannot be `const` because it will be changed when
-    // `needTmpCOO`, but that's buried in the conditional below and
-    // thus not easily extracted.
-    auto encDst = dstTp.getEncoding();
-    Value dst; // Destination tensor for inserting source tensor values.
-    bool needTmpCOO = true;
-    const bool allDense = dstTp.hasEncoding() && dstTp.isAllDense();
-    Value annotatedDenseDst;
-    if (dstTp.hasEncoding()) {
-      bool allOrdered = false;
-      // When concatenating on dimension 0, and all inputs are sorted
-      // and have an identity dimToLvl, the concatenate will generate
-      // coords in lexOrder thus no need for the tmp COO buffer.
-      // TODO: When conDim != 0, as long as conDim is the first dimension
-      // in all input/output buffers, and all input/output buffers have the same
-      // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate
-      // CSC matrices along column).
-      if (!allDense && conDim == 0 && dstTp.isIdentity()) {
-        for (auto i : op.getInputs()) {
-          const auto stt = getSparseTensorType(i);
-          allOrdered = stt.isAllOrdered() && stt.isIdentity();
-          if (!allOrdered)
-            break;
-        }
-      }
-
-      needTmpCOO = !allDense && !allOrdered;
-      const RankedTensorType tp = getBufferType(dstTp, needTmpCOO);
-      encDst = needTmpCOO ? getSparseTensorEncoding(tp) : encDst;
-      SmallVector<Value> dynSizes;
-      getDynamicSizes(dstTp, sizes, dynSizes);
-      dst = rewriter.create<AllocTensorOp>(loc, tp, dynSizes).getResult();
-      if (allDense) {
-        // Create a view of the values buffer to match the unannotated dense
-        // tensor.
-        Value valuesBuffer = genToValues(rewriter, loc, dst);
-        Value dimCoords =
-            genAlloca(rewriter, loc, dimRank, rewriter.getIndexType(),
-                      /*staticShape=*/true);
-        annotatedDenseDst = dst;
-        dst = reshapeValuesToLevels(rewriter, loc, encDst, sizes, valuesBuffer,
-                                    dimCoords);
-      }
-    } else {
-      // TODO: Dense buffers should be allocated/deallocated via the callback
-      // in BufferizationOptions.
-      dst = allocDenseTensor(rewriter, loc, dstTp, sizes);
-    }
 
+    TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
     Value offset = constantIndex(rewriter, loc, 0);
-    SmallVector<Value> initArgs;
-    if (encDst && !allDense)
-      initArgs.push_back(dst);
+    Value iterArg = dstBuf.getSSA();
+
     ForeachOp foreachOp;
     for (Value input : op.getInputs()) {
       // Build a for op for each input tensor to append new values into the
       // output tensor.
       foreachOp = rewriter.create<ForeachOp>(
-          loc, input, initArgs,
+          loc, input, iterArg ? ValueRange{iterArg} : ValueRange{},
           [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
               ValueRange reduc) {
             SmallVector<Value> dstLcvs(dstTp.getLvlRank());
             for (Dimension d = 0; d < dimRank; d++) {
               Value crd = dcvs[d];
+              // Transform coordinates for the concatenating dim.
               if (d == conDim)
-                // Transform coordinates for the concatenating dim.
                 crd = builder.create<arith::AddIOp>(loc, crd, offset);
               // FIXME: `toStoredDim` is deprecated
-              dstLcvs[toStoredDim(encDst, d)] = crd;
+              dstLcvs[toStoredDim(dstTp.getEncoding(), d)] = crd;
             }
-            if (encDst && !allDense) {
-              Value cond = genIsNonzero(rewriter, loc, v);
-              scf::IfOp ifOp = builder.create<scf::IfOp>(
-                  loc, TypeRange(reduc.front().getType()), cond, /*else*/ true);
+
+            if (!reduc.empty())
+              dstBuf.updateSSA(reduc.front());
+
+            if (!dstTp.isAllDense()) {
+              Value cond = genIsNonzero(builder, loc, v);
+              auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
+                                                    /*else*/ true);
+              builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+
               builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-              Value t =
-                  builder.create<InsertOp>(loc, v, reduc.front(), dstLcvs);
-              rewriter.create<scf::YieldOp>(loc, t);
-              rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
-              rewriter.create<scf::YieldOp>(loc, reduc.front());
-              rewriter.setInsertionPointAfter(ifOp);
-              rewriter.create<sparse_tensor::YieldOp>(loc, ifOp.getResult(0));
+              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
+              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+
+              // Exits the ifOp, update the sparse tensor SSA value.
+              builder.setInsertionPointAfter(ifOp);
+              assert(!reduc.empty());
+              dstBuf.updateSSA(ifOp.getResult(0));
             } else {
-              builder.create<memref::StoreOp>(loc, v, dst, dstLcvs);
-              builder.create<sparse_tensor::YieldOp>(loc);
+              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
             }
+            if (reduc.empty())
+              builder.create<sparse_tensor::YieldOp>(loc);
+            else
+              builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
           });
       // Accumulates the offset. Note that only static-shaped inputs are allowed
       // by concatenate op verifier, which saves us from computing the offset
@@ -948,88 +954,27 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
       assert(sh.has_value());
       offset = rewriter.create<arith::AddIOp>(
           loc, offset, constantIndex(rewriter, loc, *sh));
-      if (encDst && !allDense) {
-        dst = foreachOp.getResult(0);
-        initArgs[0] = dst;
-      }
-    }
 
-    // Temp variable to avoid needing to call `getRankedTensorType`
-    // in the three use-sites below.
-    const RankedTensorType dstRTT = dstTp;
-    if (!encDst) {
-      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstRTT, dst);
-    } else if (allDense) {
-      rewriter.replaceOp(
-          op, rewriter.create<ConvertOp>(loc, dstRTT, annotatedDenseDst)
-                  .getResult());
-    } else {
-      dst = rewriter.create<LoadOp>(loc, dst, true);
-      if (needTmpCOO) {
-        Value tmpCoo = dst;
-        Type dstCooTp = getCOOType(dstRTT, true);
-        // TODO: this should be a sort_coo operation.
-        dst = rewriter
-                  .create<ReorderCOOOp>(loc, dstCooTp, tmpCoo,
-                                        SparseTensorSortKind::HybridQuickSort)
-                  .getResult();
-        dst = rewriter.create<ConvertOp>(loc, dstRTT, dst).getResult();
-        rewriter.create<DeallocTensorOp>(loc, tmpCoo);
+      if (!foreachOp.getResults().empty()) {
+        iterArg = foreachOp.getResult(0);
+        dstBuf.updateSSA(iterArg);
       }
-      rewriter.replaceOp(op, dst);
     }
-    return success();
-  }
-};
 
-struct TensorLike {
-  TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
-             ValueRange sizes)
-      : isSparse(rtt.getEncoding() != nullptr) {
-    SmallVector<Value> dynSzs;
-    getDynamicSizes(rtt, sizes, dynSzs);
-
-    if (isSparse)
-      val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
-    else
-      val = allocDenseTensor(builder, loc, rtt, sizes);
-  };
-
-  void insertOrStore(OpBuilder &builder, Location loc, Value v,
-                     ValueRange crds) {
-    if (isSparse)
-      val = builder.create<InsertOp>(loc, v, val, crds);
-    else
-      builder.create<memref::StoreOp>(loc, v, val, crds);
-  }
-
-  Value getSSA() const {
-    // We don't need to maintain the SSA chain for a memref value.
-    return isSparse ? val : nullptr;
-  }
-
-  Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
-    if (isSparse)
-      return builder.create<LoadOp>(loc, val, true);
-    return builder.create<bufferization::ToTensorOp>(loc, rtp, val);
-  }
+    if (!foreachOp.getResults().empty())
+      dstBuf.updateSSA(iterArg);
 
-  void updateSSA(Value v) {
-    // Dense memref is a non-SSA value.
-    assert(isSparse);
-    val = v;
+    Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
+    rewriter.replaceOp(op, ret);
+    return success();
   }
-
-private:
-  bool isSparse;
-  Value val; // either a memref (for dense tensor) or a sparse tensor.
 };
 
 struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(ConvertOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!op.directConvertable())
+    if (op.needExtraSort())
       return op.emitError("ConvertOp not staged.");
 
     // TODO: Maybe we want a different operation for this too.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
index 4c163ea6e067ba6..101238fc16581fb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
@@ -15,56 +15,19 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
-struct StageUnorderedConvert : public OpRewritePattern<ConvertOp> {
-  using OpRewritePattern<ConvertOp>::OpRewritePattern;
+template <typename StageWithSortOp>
+struct StageUnorderedConvert : public OpRewritePattern<StageWithSortOp> {
+  using OpRewritePattern<StageWithSortOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(ConvertOp op,
+  LogicalResult matchAndRewrite(StageWithSortOp op,
                                 PatternRewriter &rewriter) const override {
-    // TODO: Implement it as an Interface, this can be reused from other
-    // operations too (e.g., concatenate, reshape, etc).
-    if (op.directConvertable())
-      return failure();
-
-    Location loc = op.getLoc();
-    SparseTensorType srcStt = getSparseTensorType(op.getSource());
-    SparseTensorType dstStt = getSparseTensorType(op.getDest());
-
-    // Just to make sure that convert to dense tensor is always direct.
-    assert(!dstStt.isAllDense());
-
-    // source -> coo
-    // The tmp COO must be unordered, otherwise it is a direct conversion.
-    assert(!(srcStt.hasSameDimToLvl(dstStt) && srcStt.isAllOrdered()));
-    (void)srcStt; // to silence warning when assertion is disabled
-
-    Type srcCOOTp = getCOOFromTypeWithOrdering(
-        dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);
-    Value srcCOO = op.getSource();
-    if (srcCOO.getType() != srcCOOTp)
-      srcCOO = rewriter.create<ConvertOp>(loc, srcCOOTp, op.getSource());
-
-    // -> sort
-    Type dstCOOTp = getCOOFromTypeWithOrdering(
-        dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
-    Value dstCOO = rewriter.create<ReorderCOOOp>(
-        loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);
-
-    // -> dest.
-    if (dstCOO.getType() == op.getType()) {
-      rewriter.replaceOp(op, dstCOO);
-    } else {
-      // Need an extra conversion if the target type is not COO.
-      rewriter.replaceOpWithNewOp<ConvertOp>(op, op.getDest().getType(),
-                                             dstCOO);
-    }
-    // TODO: deallocate extra COOs, we should probably delegate it to buffer
-    // deallocation pass.
-
-    return success();
+    return llvm::cast<StageWithSortSparseOp>(op.getOperation())
+        .stageWithSort(rewriter);
   }
 };
 } // namespace
 
 void mlir::populateStageSparseOperationsPatterns(RewritePatternSet &patterns) {
-  patterns.add<StageUnorderedConvert>(patterns.getContext());
+  patterns.add<StageUnorderedConvert<ConvertOp>,
+               StageUnorderedConvert<ConcatenateOp>>(patterns.getContext());
 }

>From 2969a6a7bbe98633228b1179566b553204fed434 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 13 Oct 2023 18:53:39 +0000
Subject: [PATCH 2/8] revert unintended change

---
 mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index cc7ed639cbde66c..c5e97e97063706f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1304,10 +1304,9 @@ LogicalResult ConcatenateOp::verify() {
         // If all dimension are statically known, the sum of all the input
         // dimensions should be equal to the output dimension.
         if (sumSz != dstSh)
-          return emitError("The concatenation dimension of the output tensor "
-                           "should be the "
-                           "sum of all the concatenation dimensions of the "
-                           "input tensors.");
+          return emitError(
+              "The concatenation dimension of the output tensor should be the "
+              "sum of all the concatenation dimensions of the input tensors.");
       }
     } else {
       DynSize prev = dstSh;

>From a87051dd8dba92cf47ce0a05159167837e877590 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 13 Oct 2023 20:09:14 +0000
Subject: [PATCH 3/8] renaming variables

---
 .../SparseTensor/Transforms/StageSparseOperations.cpp       | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
index 101238fc16581fb..5875cd4f9fd9d18 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
@@ -16,7 +16,7 @@ using namespace mlir::sparse_tensor;
 namespace {
 
 template <typename StageWithSortOp>
-struct StageUnorderedConvert : public OpRewritePattern<StageWithSortOp> {
+struct StageUnorderedSparseOps : public OpRewritePattern<StageWithSortOp> {
   using OpRewritePattern<StageWithSortOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(StageWithSortOp op,
@@ -28,6 +28,6 @@ struct StageUnorderedConvert : public OpRewritePattern<StageWithSortOp> {
 } // namespace
 
 void mlir::populateStageSparseOperationsPatterns(RewritePatternSet &patterns) {
-  patterns.add<StageUnorderedConvert<ConvertOp>,
-               StageUnorderedConvert<ConcatenateOp>>(patterns.getContext());
+  patterns.add<StageUnorderedSparseOps<ConvertOp>,
+               StageUnorderedSparseOps<ConcatenateOp>>(patterns.getContext());
 }

>From ff045b5c4c86af8c38a4db375592f1e86b31a112 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 16 Oct 2023 16:57:08 +0000
Subject: [PATCH 4/8] update bazel build config

---
 .../llvm-project-overlay/mlir/BUILD.bazel     | 30 +++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 63f9cdafce88b90..eb694569dd99cff 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1231,6 +1231,15 @@ td_library(
     deps = [":OpBaseTdFiles"],
 )
 
+td_library(
+    name = "SparseTensorInterfacesTdFiles",
+    srcs = [
+        "include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td",
+    ],
+    includes = ["include"],
+    deps = [":OpBaseTdFiles"],
+)
+
 td_library(
     name = "TilingInterfaceTdFiles",
     srcs = ["include/mlir/Interfaces/TilingInterface.td"],
@@ -2683,6 +2692,7 @@ td_library(
     srcs = [
         "include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td",
         "include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td",
+        "include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td",
         "include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td",
         "include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td",
     ],
@@ -2801,6 +2811,23 @@ gentbl_cc_library(
     deps = [":PassBaseTdFiles"],
 )
 
+gentbl_cc_library(
+    name = "SparseTensorInterfacesIncGen",
+    tbl_outs = [
+        (
+            ["-gen-op-interface-decls"],
+            "include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h.inc",
+        ),
+        (
+            ["-gen-op-interface-defs"],
+            "include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td",
+    deps = [":SparseTensorInterfacesTdFiles"],
+)
+
 # This library is shared by both SparseTensorDialect and
 # SparseTensorRuntime, so it must not depend on any of the MLIR/LLVM
 # internals or else mlir_c_runner_utils will inherit that dependency.
@@ -2823,9 +2850,11 @@ cc_library(
         "lib/Dialect/SparseTensor/IR/Detail/Var.cpp",
         "lib/Dialect/SparseTensor/IR/Detail/Var.h",
         "lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp",
+        "lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp",
     ],
     hdrs = [
         "include/mlir/Dialect/SparseTensor/IR/SparseTensor.h",
+        "include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h",
         "include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h",
         "include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h",
     ],
@@ -2837,6 +2866,7 @@ cc_library(
         ":InferTypeOpInterface",
         ":SparseTensorAttrDefsIncGen",
         ":SparseTensorEnums",
+        ":SparseTensorInterfacesIncGen",
         ":SparseTensorOpsIncGen",
         ":SparseTensorTypesIncGen",
         "//llvm:Support",

>From b4fbdbd283594752f96a66213e012a3272119113 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 16 Oct 2023 17:05:11 +0000
Subject: [PATCH 5/8] fix comments

---
 .../mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td      | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
index 29dc946227f5075..2931027621cdf58 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
@@ -11,8 +11,6 @@
 
 include "mlir/IR/OpBase.td"
 
-// The 'LinalgContractionOpInterface' provides access to the
-// 'ContractionOpInterface'.
 def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> {
   let description = [{
     A stage-with-sort sparse tensor operation is an operation that produces

>From d6093ea681f9224ea79ff0d092947ad5b2d62bb2 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 16 Oct 2023 17:07:23 +0000
Subject: [PATCH 6/8] fix comments

---
 mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
index 898eff26f5477f8..304a81bf529d9ea 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -19,8 +19,6 @@ using namespace mlir::sparse_tensor;
 LogicalResult
 sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
                                          PatternRewriter &rewriter) {
-  // TODO: Implement it as an Interface, this can be reused from other
-  // operations too (e.g., concatenate, reshape, etc).
   if (!op.needExtraSort())
     return failure();
 

>From b2418f8d554bb31f87fbe24b57e413e1c2355569 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 16 Oct 2023 17:32:16 +0000
Subject: [PATCH 7/8] address comments

---
 mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp      | 2 +-
 .../Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index c5e97e97063706f..420d271e16af87a 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1256,7 +1256,7 @@ bool ConcatenateOp::needExtraSort() {
   bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
     return getSparseTensorType(op).hasSameDimToLvl(dstStt);
   });
-  // TODO: When conDim != 0, as long as conDim corresponding to  the first level
+  // TODO: When conDim != 0, as long as conDim corresponding to the first level
   // in all input/output buffers, and all input/output buffers have the same
   // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate
   // CSC matrices along column).
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index fe2c333a7062705..fa982b9bf95064c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -905,7 +905,7 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
 
     ForeachOp foreachOp;
     for (Value input : op.getInputs()) {
-      // Build a for op for each input tensor to append new values into the
+      // Builds a for op for each input tensor to append new values into the
       // output tensor.
       foreachOp = rewriter.create<ForeachOp>(
           loc, input, iterArg ? ValueRange{iterArg} : ValueRange{},
@@ -914,7 +914,7 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
             SmallVector<Value> dstLcvs(dstTp.getLvlRank());
             for (Dimension d = 0; d < dimRank; d++) {
               Value crd = dcvs[d];
-              // Transform coordinates for the concatenating dim.
+              // Transforms coordinates for the concatenating dim.
               if (d == conDim)
                 crd = builder.create<arith::AddIOp>(loc, crd, offset);
               // FIXME: `toStoredDim` is deprecated

>From 3381fbbebcbe43310e53f03c7d2f148e77446f39 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 17 Oct 2023 17:49:47 +0000
Subject: [PATCH 8/8] address comments

---
 .../mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h     | 3 ++-
 .../mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td    | 2 +-
 mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td  | 4 ++--
 mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp      | 4 ++--
 mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp   | 2 +-
 .../Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp | 4 ++--
 6 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
index f75e02266578495..ebbc522123a5990 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
@@ -1,4 +1,5 @@
-//===- SparseTensorInterface.h - sparse tensor operations interfaces-------===//
+//===- SparseTensorInterfaces.h - sparse tensor operations
+//interfaces-------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
index 2931027621cdf58..1379363ff75f420 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
@@ -27,7 +27,7 @@ def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> {
     InterfaceMethod<
     /*desc=*/"Return true if the operation needs an extra sort to produce the final result.",
     /*retTy=*/"bool",
-    /*methodName=*/"needExtraSort",
+    /*methodName=*/"needsExtraSort",
     /*args=*/(ins),
     /*methodBody=*/"">,
     InterfaceMethod<
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index a1493c6aebee2b3..3d1807094797ec6 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -200,7 +200,7 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
   let extraClassDeclaration = [{
      // Whether the convert can be done by a single step or it would require
      // an extra sort. Inherited from StageWithSortSparseOpInterface.
-     bool needExtraSort();
+     bool needsExtraSort();
   }];
 
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
@@ -362,7 +362,7 @@ def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate",
   let extraClassDeclaration = [{
      // Whether the concatenate can be done by a single step or it would require
      // an extra sort. Inherited from StageWithSortSparseOpInterface.
-     bool needExtraSort();
+     bool needsExtraSort();
   }];
 
   let assemblyFormat = "$inputs attr-dict `:` type($inputs) `to` type($result)";
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 420d271e16af87a..cd1e585438ddac9 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1065,7 +1065,7 @@ OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-bool ConvertOp::needExtraSort() {
+bool ConvertOp::needsExtraSort() {
   SparseTensorType srcStt = getSparseTensorType(getSource());
   SparseTensorType dstStt = getSparseTensorType(getDest());
 
@@ -1248,7 +1248,7 @@ LogicalResult UnaryOp::verify() {
   return success();
 }
 
-bool ConcatenateOp::needExtraSort() {
+bool ConcatenateOp::needsExtraSort() {
   SparseTensorType dstStt = getSparseTensorType(*this);
   if (dstStt.isAllDense() || !dstStt.isAllOrdered())
     return false;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
index 304a81bf529d9ea..d8769eacc44f39b 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -19,7 +19,7 @@ using namespace mlir::sparse_tensor;
 LogicalResult
 sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
                                          PatternRewriter &rewriter) {
-  if (!op.needExtraSort())
+  if (!op.needsExtraSort())
     return failure();
 
   Location loc = op.getLoc();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index fa982b9bf95064c..1bfee3aa1d7ee8e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -876,7 +876,7 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(ConcatenateOp op,
                                 PatternRewriter &rewriter) const override {
-    if (op.needExtraSort())
+    if (op.needsExtraSort())
       op.emitError("ConcatenateOp not staged");
 
     const Location loc = op.getLoc();
@@ -974,7 +974,7 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(ConvertOp op,
                                 PatternRewriter &rewriter) const override {
-    if (op.needExtraSort())
+    if (op.needsExtraSort())
       return op.emitError("ConvertOp not staged.");
 
     // TODO: Maybe we want a different operation for this too.



More information about the llvm-commits mailing list