[llvm] [mlir][sparse] implementing stageSparseOpPass as an interface (PR #69022)
Peiming Liu via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 17 10:50:38 PDT 2023
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/69022
>From 77aa8c2cf261ad344f75cc5591cc1f394714b269 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 12 Oct 2023 22:58:50 +0000
Subject: [PATCH 1/8] [mlir][sparse] implementing stage convertOp as an
interface
---
.../Dialect/SparseTensor/IR/CMakeLists.txt | 6 +
.../Dialect/SparseTensor/IR/SparseTensor.h | 1 +
.../SparseTensor/IR/SparseTensorInterfaces.h | 30 +++
.../SparseTensor/IR/SparseTensorInterfaces.td | 47 ++++
.../SparseTensor/IR/SparseTensorOps.td | 18 +-
.../Dialect/SparseTensor/IR/CMakeLists.txt | 1 +
.../SparseTensor/IR/SparseTensorDialect.cpp | 38 ++-
.../IR/SparseTensorInterfaces.cpp | 57 +++++
.../Transforms/SparseTensorRewriting.cpp | 223 +++++++-----------
.../Transforms/StageSparseOperations.cpp | 53 +----
10 files changed, 275 insertions(+), 199 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
create mode 100644 mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
create mode 100644 mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt
index 25a2e4869cc7824..54ad9491cce512c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt
@@ -12,3 +12,9 @@ set(LLVM_TARGET_DEFINITIONS SparseTensorTypes.td)
mlir_tablegen(SparseTensorTypes.h.inc -gen-typedef-decls)
mlir_tablegen(SparseTensorTypes.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(MLIRSparseTensorTypesIncGen)
+
+set(LLVM_TARGET_DEFINITIONS SparseTensorInterfaces.td)
+mlir_tablegen(SparseTensorInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(SparseTensorInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRSparseTensorInterfacesIncGen)
+add_dependencies(mlir-headers MLIRSparseTensorInterfacesIncGen)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 3eb9ce010cb006f..cbca0a7f8cc0e3a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -11,6 +11,7 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
new file mode 100644
index 000000000000000..f75e02266578495
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
@@ -0,0 +1,30 @@
+//===- SparseTensorInterfaces.h - sparse tensor operations interfaces------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
+#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+class PatternRewriter;
+
+namespace sparse_tensor {
+class StageWithSortSparseOp;
+
+namespace detail {
+LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op,
+ PatternRewriter &rewriter);
+} // namespace detail
+} // namespace sparse_tensor
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h.inc"
+
+#endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
new file mode 100644
index 000000000000000..29dc946227f5075
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
@@ -0,0 +1,47 @@
+//===- SparseTensorInterfaces.td --------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SPARSETENSOR_IR_SPARSETENSORINTERFACES
+#define SPARSETENSOR_IR_SPARSETENSORINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+// The 'LinalgContractionOpInterface' provides access to the
+// 'ContractionOpInterface'.
+def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> {
+ let description = [{
+ A stage-with-sort sparse tensor operation is an operation that produces
+ unordered intermediate output. An extra sort is required to obtain the final
+ ordered result.
+
+ E.g., convert csr -> csc need to be implemented as
+ convert csr -> unordered coo -> sort by column -> csc; and
+ concatenate csr, csc -> csr can be staged into
+ concatenate csr, csr -> unordered coo -> sort by row -> csr.
+ }];
+ let cppNamespace = "::mlir::sparse_tensor";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/"Return true if the operation needs an extra sort to produce the final result.",
+ /*retTy=*/"bool",
+ /*methodName=*/"needExtraSort",
+ /*args=*/(ins),
+ /*methodBody=*/"">,
+ InterfaceMethod<
+ /*desc=*/"Stage the operation, return the final result value after staging.",
+ /*retTy=*/"::mlir::LogicalResult",
+ /*methodName=*/"stageWithSort",
+ /*args=*/(ins "::mlir::PatternRewriter &":$rewriter),
+ /*methodBody=*/[{
+ return detail::stageWithSortImpl($_op, rewriter);
+ }]>,
+ ];
+}
+
+
+#endif // SPARSETENSOR_IR_SPARSETENSORINTERFACES
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 9016634fa3be8dd..a1493c6aebee2b3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -12,6 +12,7 @@
include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
+include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -153,7 +154,7 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
}
def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
- [Pure]>,
+ [Pure, StageWithSortSparseOpInterface]>,
Arguments<(ins AnyTensor:$source)>,
Results<(outs AnyTensor:$dest)> {
string summary = "Converts between different tensor types";
@@ -197,9 +198,9 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
}];
let extraClassDeclaration = [{
- // Whether the convert can be done by a single step (either a sort or a foreach),
- // or it would require a tmp buffer (sort, then foreach).
- bool directConvertable();
+ // Whether the convert can be done by a single step or it would require
+ // an extra sort. Inherited from StageWithSortSparseOpInterface.
+ bool needExtraSort();
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
@@ -334,7 +335,8 @@ def SparseTensor_NumberOfEntriesOp : SparseTensor_Op<"number_of_entries", [Pure]
let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
}
-def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure]>,
+def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate",
+ [Pure, StageWithSortSparseOpInterface]>,
Arguments<(ins Variadic<AnyRankedTensor>:$inputs, DimensionAttr:$dimension)>,
Results<(outs AnyRankedTensor:$result)> {
@@ -357,6 +359,12 @@ def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure]>,
```
}];
+ let extraClassDeclaration = [{
+ // Whether the concatenate can be done by a single step or it would require
+ // an extra sort. Inherited from StageWithSortSparseOpInterface.
+ bool needExtraSort();
+ }];
+
let assemblyFormat = "$inputs attr-dict `:` type($inputs) `to` type($result)";
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
index b22194d45062acc..dd6f1037f71b53f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
@@ -29,6 +29,7 @@ endif()
add_mlir_dialect_library(MLIRSparseTensorDialect
SparseTensorDialect.cpp
+ SparseTensorInterfaces.cpp
Detail/Var.cpp
Detail/DimLvlMap.cpp
Detail/LvlTypeParser.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 61522fb0dcd24b5..cc7ed639cbde66c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1065,18 +1065,18 @@ OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
return {};
}
-bool ConvertOp::directConvertable() {
+bool ConvertOp::needExtraSort() {
SparseTensorType srcStt = getSparseTensorType(getSource());
SparseTensorType dstStt = getSparseTensorType(getDest());
- // We can always directly convert to unordered sparse tensor or dense tensor
- // since dense tensor support random access.
+ // We do not need an extra sort when returning unordered sparse tensors or
+ // dense tensor since dense tensor support random access.
if (dstStt.isAllDense() || !dstStt.isAllOrdered())
- return true;
+ return false;
if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
srcStt.hasSameDimToLvl(dstStt)) {
- return true;
+ return false;
}
// Source and dest tensors are ordered in different ways. We only do direct
@@ -1086,9 +1086,9 @@ bool ConvertOp::directConvertable() {
// performance.
if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
if (isa<SparseElementsAttr>(constOp.getValue()))
- return true;
+ return false;
- return false;
+ return true;
}
LogicalResult ToPositionsOp::verify() {
@@ -1248,6 +1248,23 @@ LogicalResult UnaryOp::verify() {
return success();
}
+bool ConcatenateOp::needExtraSort() {
+  SparseTensorType dstStt = getSparseTensorType(*this);
+  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
+    return false; // Dense or unordered output: no sort is required.
+  // Every input must itself be ordered (as the pre-staging rewriter required)
+  // and share the output's dimToLvl, or its foreach order is not usable as-is.
+  bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
+    SparseTensorType stt = getSparseTensorType(op);
+    return stt.isAllOrdered() && stt.hasSameDimToLvl(dstStt);
+  });
+  // TODO: when conDim != 0 but maps to the first level of all inputs/output and
+  // they share a dimToLvl, no tmp COO is needed either (e.g., CSC by column).
+  bool directLowerable =
+      allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
+  return !directLowerable;
+}
+
LogicalResult ConcatenateOp::verify() {
const auto dstTp = getSparseTensorType(*this);
const Dimension concatDim = getDimension();
@@ -1287,9 +1304,10 @@ LogicalResult ConcatenateOp::verify() {
// If all dimension are statically known, the sum of all the input
// dimensions should be equal to the output dimension.
if (sumSz != dstSh)
- return emitError(
- "The concatenation dimension of the output tensor should be the "
- "sum of all the concatenation dimensions of the input tensors.");
+ return emitError("The concatenation dimension of the output tensor "
+ "should be the "
+ "sum of all the concatenation dimensions of the "
+ "input tensors.");
}
} else {
DynSize prev = dstSh;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
new file mode 100644
index 000000000000000..898eff26f5477f8
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -0,0 +1,57 @@
+//===- SparseTensorInterfaces.cpp - SparseTensor interfaces impl ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc"
+
+LogicalResult
+sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
+ PatternRewriter &rewriter) {
+ // TODO: Implement it as an Interface, this can be reused from other
+ // operations too (e.g., concatenate, reshape, etc).
+ if (!op.needExtraSort())
+ return failure();
+
+ Location loc = op.getLoc();
+ Type finalTp = op->getOpResult(0).getType(); // Only result #0 is staged.
+ SparseTensorType dstStt(finalTp.cast<RankedTensorType>()); // NOTE(review): assumes result #0 is a ranked tensor - confirm.
+
+ Type srcCOOTp = getCOOFromTypeWithOrdering(
+ dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);
+
+ // Clone the original op, retyping the clone's result #0 to the unordered COO.
+ Operation *cloned = rewriter.clone(*op.getOperation());
+ rewriter.updateRootInPlace(cloned, [cloned, srcCOOTp]() {
+ cloned->getOpResult(0).setType(srcCOOTp);
+ });
+ Value srcCOO = cloned->getOpResult(0);
+
+ // Sort: reorder the unordered COO into the ordered COO type (dstCOOTp).
+ Type dstCOOTp = getCOOFromTypeWithOrdering(
+ dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
+ Value dstCOO = rewriter.create<ReorderCOOOp>(
+ loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);
+
+ // Replace the op: reuse the sorted COO when it already has the final type.
+ if (dstCOO.getType() == finalTp) {
+ rewriter.replaceOp(op, dstCOO);
+ } else {
+ // Otherwise convert the sorted COO to the requested (non-COO) result type.
+ rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO);
+ }
+ // TODO: deallocate the extra COOs; we should probably delegate this to the
+ // buffer deallocation pass.
+ return success();
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index a1ab2495f5f7b5e..fe2c333a7062705 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -829,10 +829,56 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
}
};
+struct TensorLike { // One insert/finalize API over a sparse tensor or a dense memref.
+ TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
+ ValueRange sizes)
+ : isSparse(rtt.getEncoding() != nullptr) { // Any encoding counts as sparse.
+ SmallVector<Value> dynSzs;
+ getDynamicSizes(rtt, sizes, dynSzs);
+
+ if (isSparse)
+ val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
+ else
+ val = allocDenseTensor(builder, loc, rtt, sizes);
+ };
+
+ void insertOrStore(OpBuilder &builder, Location loc, Value v,
+ ValueRange crds) {
+ if (isSparse)
+ val = builder.create<InsertOp>(loc, v, val, crds); // Track the value produced by the insert.
+ else
+ builder.create<memref::StoreOp>(loc, v, val, crds); // Store into the memref; val is unchanged.
+ }
+
+ Value getSSA() const {
+ // A dense memref has no SSA chain to thread through; return a null Value.
+ return isSparse ? val : nullptr;
+ }
+
+ Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
+ if (isSparse)
+ return builder.create<LoadOp>(loc, val, true); // NOTE(review): 'true' is presumably LoadOp's hasInserts flag - confirm.
+ return builder.create<bufferization::ToTensorOp>(loc, rtp, val);
+ }
+
+ void updateSSA(Value v) {
+ // Only the sparse case tracks an SSA value; calling this for dense asserts.
+ assert(isSparse);
+ val = v;
+ }
+
+private:
+ bool isSparse;
+ Value val; // Either a memref (for a dense tensor) or a sparse tensor value.
+};
+
struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ConcatenateOp op,
PatternRewriter &rewriter) const override {
+    // Unstaged op: emit the error and fail the match instead of lowering it.
+    if (op.needExtraSort())
+      return op.emitError("ConcatenateOp not staged");
const Location loc = op.getLoc();
const auto dstTp = getSparseTensorType(op);
const Dimension dimRank = dstTp.getDimRank();
@@ -852,94 +898,54 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
// foreach in %s1 : insert d0, d1, %tmp
// foreach in %s2 : insert d0, d1 + size(s1), %tmp
// foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp
- // %t = convert_to_dest_tensor(%tmp)
- //
- // NOTE: this cannot be `const` because it will be changed when
- // `needTmpCOO`, but that's buried in the conditional below and
- // thus not easily extracted.
- auto encDst = dstTp.getEncoding();
- Value dst; // Destination tensor for inserting source tensor values.
- bool needTmpCOO = true;
- const bool allDense = dstTp.hasEncoding() && dstTp.isAllDense();
- Value annotatedDenseDst;
- if (dstTp.hasEncoding()) {
- bool allOrdered = false;
- // When concatenating on dimension 0, and all inputs are sorted
- // and have an identity dimToLvl, the concatenate will generate
- // coords in lexOrder thus no need for the tmp COO buffer.
- // TODO: When conDim != 0, as long as conDim is the first dimension
- // in all input/output buffers, and all input/output buffers have the same
- // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate
- // CSC matrices along column).
- if (!allDense && conDim == 0 && dstTp.isIdentity()) {
- for (auto i : op.getInputs()) {
- const auto stt = getSparseTensorType(i);
- allOrdered = stt.isAllOrdered() && stt.isIdentity();
- if (!allOrdered)
- break;
- }
- }
-
- needTmpCOO = !allDense && !allOrdered;
- const RankedTensorType tp = getBufferType(dstTp, needTmpCOO);
- encDst = needTmpCOO ? getSparseTensorEncoding(tp) : encDst;
- SmallVector<Value> dynSizes;
- getDynamicSizes(dstTp, sizes, dynSizes);
- dst = rewriter.create<AllocTensorOp>(loc, tp, dynSizes).getResult();
- if (allDense) {
- // Create a view of the values buffer to match the unannotated dense
- // tensor.
- Value valuesBuffer = genToValues(rewriter, loc, dst);
- Value dimCoords =
- genAlloca(rewriter, loc, dimRank, rewriter.getIndexType(),
- /*staticShape=*/true);
- annotatedDenseDst = dst;
- dst = reshapeValuesToLevels(rewriter, loc, encDst, sizes, valuesBuffer,
- dimCoords);
- }
- } else {
- // TODO: Dense buffers should be allocated/deallocated via the callback
- // in BufferizationOptions.
- dst = allocDenseTensor(rewriter, loc, dstTp, sizes);
- }
+ TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
Value offset = constantIndex(rewriter, loc, 0);
- SmallVector<Value> initArgs;
- if (encDst && !allDense)
- initArgs.push_back(dst);
+ Value iterArg = dstBuf.getSSA();
+
ForeachOp foreachOp;
for (Value input : op.getInputs()) {
// Build a for op for each input tensor to append new values into the
// output tensor.
foreachOp = rewriter.create<ForeachOp>(
- loc, input, initArgs,
+ loc, input, iterArg ? ValueRange{iterArg} : ValueRange{},
[&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
ValueRange reduc) {
SmallVector<Value> dstLcvs(dstTp.getLvlRank());
for (Dimension d = 0; d < dimRank; d++) {
Value crd = dcvs[d];
+ // Transform coordinates for the concatenating dim.
if (d == conDim)
- // Transform coordinates for the concatenating dim.
crd = builder.create<arith::AddIOp>(loc, crd, offset);
// FIXME: `toStoredDim` is deprecated
- dstLcvs[toStoredDim(encDst, d)] = crd;
+ dstLcvs[toStoredDim(dstTp.getEncoding(), d)] = crd;
}
- if (encDst && !allDense) {
- Value cond = genIsNonzero(rewriter, loc, v);
- scf::IfOp ifOp = builder.create<scf::IfOp>(
- loc, TypeRange(reduc.front().getType()), cond, /*else*/ true);
+
+ if (!reduc.empty())
+ dstBuf.updateSSA(reduc.front());
+
+ if (!dstTp.isAllDense()) {
+ Value cond = genIsNonzero(builder, loc, v);
+ auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
+ /*else*/ true);
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- Value t =
- builder.create<InsertOp>(loc, v, reduc.front(), dstLcvs);
- rewriter.create<scf::YieldOp>(loc, t);
- rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
- rewriter.create<scf::YieldOp>(loc, reduc.front());
- rewriter.setInsertionPointAfter(ifOp);
- rewriter.create<sparse_tensor::YieldOp>(loc, ifOp.getResult(0));
+ dstBuf.insertOrStore(builder, loc, v, dstLcvs);
+ builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+
+ // Exits the ifOp, update the sparse tensor SSA value.
+ builder.setInsertionPointAfter(ifOp);
+ assert(!reduc.empty());
+ dstBuf.updateSSA(ifOp.getResult(0));
} else {
- builder.create<memref::StoreOp>(loc, v, dst, dstLcvs);
- builder.create<sparse_tensor::YieldOp>(loc);
+ dstBuf.insertOrStore(builder, loc, v, dstLcvs);
}
+ if (reduc.empty())
+ builder.create<sparse_tensor::YieldOp>(loc);
+ else
+ builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
});
// Accumulates the offset. Note that only static-shaped inputs are allowed
// by concatenate op verifier, which saves us from computing the offset
@@ -948,88 +954,27 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
assert(sh.has_value());
offset = rewriter.create<arith::AddIOp>(
loc, offset, constantIndex(rewriter, loc, *sh));
- if (encDst && !allDense) {
- dst = foreachOp.getResult(0);
- initArgs[0] = dst;
- }
- }
- // Temp variable to avoid needing to call `getRankedTensorType`
- // in the three use-sites below.
- const RankedTensorType dstRTT = dstTp;
- if (!encDst) {
- rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstRTT, dst);
- } else if (allDense) {
- rewriter.replaceOp(
- op, rewriter.create<ConvertOp>(loc, dstRTT, annotatedDenseDst)
- .getResult());
- } else {
- dst = rewriter.create<LoadOp>(loc, dst, true);
- if (needTmpCOO) {
- Value tmpCoo = dst;
- Type dstCooTp = getCOOType(dstRTT, true);
- // TODO: this should be a sort_coo operation.
- dst = rewriter
- .create<ReorderCOOOp>(loc, dstCooTp, tmpCoo,
- SparseTensorSortKind::HybridQuickSort)
- .getResult();
- dst = rewriter.create<ConvertOp>(loc, dstRTT, dst).getResult();
- rewriter.create<DeallocTensorOp>(loc, tmpCoo);
+ if (!foreachOp.getResults().empty()) {
+ iterArg = foreachOp.getResult(0);
+ dstBuf.updateSSA(iterArg);
}
- rewriter.replaceOp(op, dst);
}
- return success();
- }
-};
-struct TensorLike {
- TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
- ValueRange sizes)
- : isSparse(rtt.getEncoding() != nullptr) {
- SmallVector<Value> dynSzs;
- getDynamicSizes(rtt, sizes, dynSzs);
-
- if (isSparse)
- val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
- else
- val = allocDenseTensor(builder, loc, rtt, sizes);
- };
-
- void insertOrStore(OpBuilder &builder, Location loc, Value v,
- ValueRange crds) {
- if (isSparse)
- val = builder.create<InsertOp>(loc, v, val, crds);
- else
- builder.create<memref::StoreOp>(loc, v, val, crds);
- }
-
- Value getSSA() const {
- // We don't need to maintain the SSA chain for a memref value.
- return isSparse ? val : nullptr;
- }
-
- Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
- if (isSparse)
- return builder.create<LoadOp>(loc, val, true);
- return builder.create<bufferization::ToTensorOp>(loc, rtp, val);
- }
+ if (!foreachOp.getResults().empty())
+ dstBuf.updateSSA(iterArg);
- void updateSSA(Value v) {
- // Dense memref is a non-SSA value.
- assert(isSparse);
- val = v;
+ Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
+ rewriter.replaceOp(op, ret);
+ return success();
}
-
-private:
- bool isSparse;
- Value val; // either a memref (for dense tensor) or a sparse tensor.
};
struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ConvertOp op,
PatternRewriter &rewriter) const override {
- if (!op.directConvertable())
+ if (op.needExtraSort())
return op.emitError("ConvertOp not staged.");
// TODO: Maybe we want a different operation for this too.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
index 4c163ea6e067ba6..101238fc16581fb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
@@ -15,56 +15,19 @@ using namespace mlir::sparse_tensor;
namespace {
-struct StageUnorderedConvert : public OpRewritePattern<ConvertOp> {
- using OpRewritePattern<ConvertOp>::OpRewritePattern;
+template <typename StageWithSortOp>
+struct StageUnorderedConvert : public OpRewritePattern<StageWithSortOp> {
+ using OpRewritePattern<StageWithSortOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(ConvertOp op,
+ LogicalResult matchAndRewrite(StageWithSortOp op,
PatternRewriter &rewriter) const override {
- // TODO: Implement it as an Interface, this can be reused from other
- // operations too (e.g., concatenate, reshape, etc).
- if (op.directConvertable())
- return failure();
-
- Location loc = op.getLoc();
- SparseTensorType srcStt = getSparseTensorType(op.getSource());
- SparseTensorType dstStt = getSparseTensorType(op.getDest());
-
- // Just to make sure that convert to dense tensor is always direct.
- assert(!dstStt.isAllDense());
-
- // source -> coo
- // The tmp COO must be unordered, otherwise it is a direct conversion.
- assert(!(srcStt.hasSameDimToLvl(dstStt) && srcStt.isAllOrdered()));
- (void)srcStt; // to silence warning when assertion is disabled
-
- Type srcCOOTp = getCOOFromTypeWithOrdering(
- dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);
- Value srcCOO = op.getSource();
- if (srcCOO.getType() != srcCOOTp)
- srcCOO = rewriter.create<ConvertOp>(loc, srcCOOTp, op.getSource());
-
- // -> sort
- Type dstCOOTp = getCOOFromTypeWithOrdering(
- dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
- Value dstCOO = rewriter.create<ReorderCOOOp>(
- loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);
-
- // -> dest.
- if (dstCOO.getType() == op.getType()) {
- rewriter.replaceOp(op, dstCOO);
- } else {
- // Need an extra conversion if the target type is not COO.
- rewriter.replaceOpWithNewOp<ConvertOp>(op, op.getDest().getType(),
- dstCOO);
- }
- // TODO: deallocate extra COOs, we should probably delegate it to buffer
- // deallocation pass.
-
- return success();
+ return llvm::cast<StageWithSortSparseOp>(op.getOperation())
+ .stageWithSort(rewriter);
}
};
} // namespace
void mlir::populateStageSparseOperationsPatterns(RewritePatternSet &patterns) {
- patterns.add<StageUnorderedConvert>(patterns.getContext());
+ patterns.add<StageUnorderedConvert<ConvertOp>,
+ StageUnorderedConvert<ConcatenateOp>>(patterns.getContext());
}
>From 2969a6a7bbe98633228b1179566b553204fed434 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 13 Oct 2023 18:53:39 +0000
Subject: [PATCH 2/8] revert unintended change
---
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index cc7ed639cbde66c..c5e97e97063706f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1304,10 +1304,9 @@ LogicalResult ConcatenateOp::verify() {
// If all dimension are statically known, the sum of all the input
// dimensions should be equal to the output dimension.
if (sumSz != dstSh)
- return emitError("The concatenation dimension of the output tensor "
- "should be the "
- "sum of all the concatenation dimensions of the "
- "input tensors.");
+ return emitError(
+ "The concatenation dimension of the output tensor should be the "
+ "sum of all the concatenation dimensions of the input tensors.");
}
} else {
DynSize prev = dstSh;
>From a87051dd8dba92cf47ce0a05159167837e877590 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 13 Oct 2023 20:09:14 +0000
Subject: [PATCH 3/8] renaming variables
---
.../SparseTensor/Transforms/StageSparseOperations.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
index 101238fc16581fb..5875cd4f9fd9d18 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
@@ -16,7 +16,7 @@ using namespace mlir::sparse_tensor;
namespace {
template <typename StageWithSortOp>
-struct StageUnorderedConvert : public OpRewritePattern<StageWithSortOp> {
+struct StageUnorderedSparseOps : public OpRewritePattern<StageWithSortOp> {
using OpRewritePattern<StageWithSortOp>::OpRewritePattern;
LogicalResult matchAndRewrite(StageWithSortOp op,
@@ -28,6 +28,6 @@ struct StageUnorderedConvert : public OpRewritePattern<StageWithSortOp> {
} // namespace
void mlir::populateStageSparseOperationsPatterns(RewritePatternSet &patterns) {
- patterns.add<StageUnorderedConvert<ConvertOp>,
- StageUnorderedConvert<ConcatenateOp>>(patterns.getContext());
+ patterns.add<StageUnorderedSparseOps<ConvertOp>,
+ StageUnorderedSparseOps<ConcatenateOp>>(patterns.getContext());
}
>From ff045b5c4c86af8c38a4db375592f1e86b31a112 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 16 Oct 2023 16:57:08 +0000
Subject: [PATCH 4/8] update bazel build config
---
.../llvm-project-overlay/mlir/BUILD.bazel | 30 +++++++++++++++++++
1 file changed, 30 insertions(+)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 63f9cdafce88b90..eb694569dd99cff 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1231,6 +1231,15 @@ td_library(
deps = [":OpBaseTdFiles"],
)
+td_library(
+ name = "SparseTensorInterfacesTdFiles",
+ srcs = [
+ "include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td",
+ ],
+ includes = ["include"],
+ deps = [":OpBaseTdFiles"],
+)
+
td_library(
name = "TilingInterfaceTdFiles",
srcs = ["include/mlir/Interfaces/TilingInterface.td"],
@@ -2683,6 +2692,7 @@ td_library(
srcs = [
"include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td",
"include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td",
+ "include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td",
"include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td",
"include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td",
],
@@ -2801,6 +2811,23 @@ gentbl_cc_library(
deps = [":PassBaseTdFiles"],
)
+gentbl_cc_library(
+ name = "SparseTensorInterfacesIncGen",
+ tbl_outs = [
+ (
+ ["-gen-op-interface-decls"],
+ "include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h.inc",
+ ),
+ (
+ ["-gen-op-interface-defs"],
+ "include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td",
+ deps = [":SparseTensorInterfacesTdFiles"],
+)
+
# This library is shared by both SparseTensorDialect and
# SparseTensorRuntime, so it must not depend on any of the MLIR/LLVM
# internals or else mlir_c_runner_utils will inherit that dependency.
@@ -2823,9 +2850,11 @@ cc_library(
"lib/Dialect/SparseTensor/IR/Detail/Var.cpp",
"lib/Dialect/SparseTensor/IR/Detail/Var.h",
"lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp",
+ "lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp",
],
hdrs = [
"include/mlir/Dialect/SparseTensor/IR/SparseTensor.h",
+ "include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h",
"include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h",
"include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h",
],
@@ -2837,6 +2866,7 @@ cc_library(
":InferTypeOpInterface",
":SparseTensorAttrDefsIncGen",
":SparseTensorEnums",
+ ":SparseTensorInterfacesIncGen",
":SparseTensorOpsIncGen",
":SparseTensorTypesIncGen",
"//llvm:Support",
>From b4fbdbd283594752f96a66213e012a3272119113 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 16 Oct 2023 17:05:11 +0000
Subject: [PATCH 5/8] fix comments
---
.../mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td | 2 --
1 file changed, 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
index 29dc946227f5075..2931027621cdf58 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
@@ -11,8 +11,6 @@
include "mlir/IR/OpBase.td"
-// The 'LinalgContractionOpInterface' provides access to the
-// 'ContractionOpInterface'.
def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> {
let description = [{
A stage-with-sort sparse tensor operation is an operation that produces
>From d6093ea681f9224ea79ff0d092947ad5b2d62bb2 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 16 Oct 2023 17:07:23 +0000
Subject: [PATCH 6/8] fix comments
---
mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp | 2 --
1 file changed, 2 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
index 898eff26f5477f8..304a81bf529d9ea 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -19,8 +19,6 @@ using namespace mlir::sparse_tensor;
LogicalResult
sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
PatternRewriter &rewriter) {
- // TODO: Implement it as an Interface, this can be reused from other
- // operations too (e.g., concatenate, reshape, etc).
if (!op.needExtraSort())
return failure();
>From b2418f8d554bb31f87fbe24b57e413e1c2355569 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 16 Oct 2023 17:32:16 +0000
Subject: [PATCH 7/8] address comments
---
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 2 +-
.../Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index c5e97e97063706f..420d271e16af87a 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1256,7 +1256,7 @@ bool ConcatenateOp::needExtraSort() {
bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
return getSparseTensorType(op).hasSameDimToLvl(dstStt);
});
- // TODO: When conDim != 0, as long as conDim corresponding to the first level
+ // TODO: When conDim != 0, as long as conDim corresponds to the first level
// in all input/output buffers, and all input/output buffers have the same
// dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate
// CSC matrices along column).
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index fe2c333a7062705..fa982b9bf95064c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -905,7 +905,7 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
ForeachOp foreachOp;
for (Value input : op.getInputs()) {
- // Build a for op for each input tensor to append new values into the
+ // Builds a for op for each input tensor to append new values into the
// output tensor.
foreachOp = rewriter.create<ForeachOp>(
loc, input, iterArg ? ValueRange{iterArg} : ValueRange{},
@@ -914,7 +914,7 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
SmallVector<Value> dstLcvs(dstTp.getLvlRank());
for (Dimension d = 0; d < dimRank; d++) {
Value crd = dcvs[d];
- // Transform coordinates for the concatenating dim.
+ // Transforms coordinates for the concatenating dim.
if (d == conDim)
crd = builder.create<arith::AddIOp>(loc, crd, offset);
// FIXME: `toStoredDim` is deprecated
>From 3381fbbebcbe43310e53f03c7d2f148e77446f39 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 17 Oct 2023 17:49:47 +0000
Subject: [PATCH 8/8] address comments
---
.../mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h | 3 ++-
.../mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td | 2 +-
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td | 4 ++--
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 4 ++--
mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp | 2 +-
.../Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp | 4 ++--
6 files changed, 10 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
index f75e02266578495..ebbc522123a5990 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
@@ -1,4 +1,5 @@
-//===- SparseTensorInterface.h - sparse tensor operations interfaces-------===//
+//===- SparseTensorInterfaces.h - sparse tensor operations
+//interfaces-------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
index 2931027621cdf58..1379363ff75f420 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
@@ -27,7 +27,7 @@ def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> {
InterfaceMethod<
/*desc=*/"Return true if the operation needs an extra sort to produce the final result.",
/*retTy=*/"bool",
- /*methodName=*/"needExtraSort",
+ /*methodName=*/"needsExtraSort",
/*args=*/(ins),
/*methodBody=*/"">,
InterfaceMethod<
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index a1493c6aebee2b3..3d1807094797ec6 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -200,7 +200,7 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
let extraClassDeclaration = [{
// Whether the convert can be done by a single step or it would require
// an extra sort. Inherited from StageWithSortSparseOpInterface.
- bool needExtraSort();
+ bool needsExtraSort();
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
@@ -362,7 +362,7 @@ def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate",
let extraClassDeclaration = [{
// Whether the concatenate can be done by a single step or it would require
// an extra sort. Inherited from StageWithSortSparseOpInterface.
- bool needExtraSort();
+ bool needsExtraSort();
}];
let assemblyFormat = "$inputs attr-dict `:` type($inputs) `to` type($result)";
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 420d271e16af87a..cd1e585438ddac9 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1065,7 +1065,7 @@ OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
return {};
}
-bool ConvertOp::needExtraSort() {
+bool ConvertOp::needsExtraSort() {
SparseTensorType srcStt = getSparseTensorType(getSource());
SparseTensorType dstStt = getSparseTensorType(getDest());
@@ -1248,7 +1248,7 @@ LogicalResult UnaryOp::verify() {
return success();
}
-bool ConcatenateOp::needExtraSort() {
+bool ConcatenateOp::needsExtraSort() {
SparseTensorType dstStt = getSparseTensorType(*this);
if (dstStt.isAllDense() || !dstStt.isAllOrdered())
return false;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
index 304a81bf529d9ea..d8769eacc44f39b 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -19,7 +19,7 @@ using namespace mlir::sparse_tensor;
LogicalResult
sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
PatternRewriter &rewriter) {
- if (!op.needExtraSort())
+ if (!op.needsExtraSort())
return failure();
Location loc = op.getLoc();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index fa982b9bf95064c..1bfee3aa1d7ee8e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -876,7 +876,7 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ConcatenateOp op,
PatternRewriter &rewriter) const override {
- if (op.needExtraSort())
+ if (op.needsExtraSort())
op.emitError("ConcatenateOp not staged");
const Location loc = op.getLoc();
@@ -974,7 +974,7 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ConvertOp op,
PatternRewriter &rewriter) const override {
- if (op.needExtraSort())
+ if (op.needsExtraSort())
return op.emitError("ConvertOp not staged.");
// TODO: Maybe we want a different operation for this too.
More information about the llvm-commits mailing list