[Mlir-commits] [mlir] [WIP][mlir][sparse] Setting up sparse_tensor.iterator-related Ops. (PR #85958)
Peiming Liu
llvmlistbot at llvm.org
Thu Mar 28 09:32:54 PDT 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/85958
>From 1e76a2a5417212bb2e5891a00c0069c4118ce778 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 18 Mar 2024 19:40:46 +0000
Subject: [PATCH 1/5] test parse iterate operation
---
.../SparseTensor/IR/SparseTensorOps.td | 17 +++++
.../SparseTensor/IR/SparseTensorTypes.td | 76 +++++++++++++++++++
2 files changed, 93 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 29cf8c32447ecf..6efeb6007d649e 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1418,6 +1418,23 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+// `sparse_tensor.iterate`: iterates over the sparse iteration space
+// `$iterSpace`. `$initArgs` presumably seed loop-carried values that are
+// returned as `$results` (TODO confirm once the verifier lands). The body is
+// the single-block `$region`; the textual form is implemented by custom C++
+// parse/print hooks.
+def IterateOp : SparseTensor_Op<"iterate",
+ [RecursiveMemoryEffects]> {
+
+ let arguments = (ins AnySparseIterSpace:$iterSpace,
+ Variadic<AnyType>:$initArgs);
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region SizedRegion<1>:$region);
+
+ // NOTE(review): no summary/description yet; add them before this op leaves
+ // WIP status.
+ let extraClassDeclaration = [{}];
+
+ let hasCustomAssemblyFormat = 1;
+}
+
//===----------------------------------------------------------------------===//
// Sparse Tensor Debugging and Test-Only Operations.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index 185cff46ae25d5..54a8e4d7ecd398 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -72,4 +72,80 @@ def SparseTensorStorageSpecifier
: Type<CPred<"::llvm::isa<::mlir::sparse_tensor::StorageSpecifierType>($_self)">, "metadata",
"::mlir::sparse_tensor::StorageSpecifierType">;
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Types.
+//===----------------------------------------------------------------------===//
+
+// Type parameter holding an array of sparse level types.
+// Printer: emits `[lt0,lt1,...]`, using each LevelType's toMLIRString().
+// Parser: reads a square-bracketed, comma-separated list, parsing every
+// element with ir_detail::LvlTypeParser; any element failure fails the parse.
+def LevelTypeArrayParameter : ArrayRefParameter<"::mlir::sparse_tensor::LevelType", "level-types"> {
+ let printer = [{
+ auto lvlStrings = llvm::map_range($_self, [](auto lt){ return lt.toMLIRString(); });
+ $_printer << "[" << llvm::join(lvlStrings, ",") << "]";
+ }];
+
+ let parser = [{ [&]() -> FailureOr<SmallVector<::mlir::sparse_tensor::LevelType>> {
+ SmallVector<::mlir::sparse_tensor::LevelType> ret;
+
+ const auto res = $_parser.parseCommaSeparatedList(
+ mlir::OpAsmParser::Delimiter::Square,
+ [&]() -> ParseResult {
+ ::mlir::sparse_tensor::ir_detail::LvlTypeParser lParser;
+ auto lvlTpOrFail = lParser.parseLvlType($_parser);
+ if (failed(lvlTpOrFail))
+ return failure();
+ ret.emplace_back(*lvlTpOrFail);
+ return success();
+ }, " in level-type list");
+
+ if (failed(res))
+ return failure();
+ return ret;
+ }() }];
+}
+
+// `!sparse_tensor.iterator<[...]>`: type of an iterator over a sparse
+// iteration space. It carries the level types of the space it iterates
+// (IterSpaceType::getIteratorType builds it from the same level types).
+def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
+ let mnemonic = "iterator";
+
+ let parameters = (ins
+ LevelTypeArrayParameter: $lvlTypes
+ );
+
+ // NOTE(review): dead commented-out settings; drop them or enable them.
+ // let skipDefaultBuilders = 1;
+ // let hasCustomAssemblyFormat = 1;
+ let assemblyFormat="`<` $lvlTypes `>`";
+}
+
+
+// `!sparse_tensor.iter_space<[...]>`: type of a sparse iteration space. It
+// holds one level type per level the space spans.
+def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
+ let mnemonic = "iter_space";
+
+ let parameters = (ins
+ LevelTypeArrayParameter: $lvlTypes
+ );
+
+ // getIteratorType(): the iterator type with the same level types as this
+ // space, i.e. the type of an iterator over it.
+ let extraClassDeclaration = [{
+ ::mlir::sparse_tensor::IteratorType getIteratorType() const {
+ return IteratorType::get(getContext(), getLvlTypes());
+ }
+ }];
+
+ // NOTE(review): dead commented-out settings; drop them or enable them.
+ // let skipDefaultBuilders = 1;
+ // let hasCustomAssemblyFormat = 1;
+ let assemblyFormat="`<` $lvlTypes `>`";
+}
+
+// Predicates and ODS type constraints that match the iteration-space and
+// iterator types, for use in op argument/result lists.
+// NOTE(review): the doubled "SparseSparse" in the predicate names looks
+// accidental; renaming requires updating every use.
+def IsSparseSparseIterSpaceTypePred
+ : CPred<"::llvm::isa<::mlir::sparse_tensor::IterSpaceType>($_self)">;
+
+def IsSparseSparseIteratorTypePred
+ : CPred<"::llvm::isa<::mlir::sparse_tensor::IteratorType>($_self)">;
+
+def AnySparseIterSpace
+ : Type<IsSparseSparseIterSpaceTypePred, "sparse iteration space",
+ "::mlir::sparse_tensor::IterSpaceType">;
+
+def AnySparseIterator
+ : Type<IsSparseSparseIteratorTypePred, "sparse iterator",
+ "::mlir::sparse_tensor::IteratorType">;
+
+
#endif // SPARSETENSOR_TYPES
>From e64944645a4301b393ec4e8c77ee58aaf4cfc19d Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 20 Mar 2024 16:14:12 +0000
Subject: [PATCH 2/5] test sparse space collapse
---
.../Dialect/SparseTensor/IR/SparseTensor.h | 2 +
.../SparseTensor/IR/SparseTensorOps.td | 41 ++-
.../SparseTensor/IR/SparseTensorTypes.td | 15 +-
.../Dialect/SparseTensor/Transforms/Passes.h | 6 +
.../Dialect/SparseTensor/Transforms/Passes.td | 16 ++
.../SparseTensor/IR/SparseTensorDialect.cpp | 253 +++++++++++++++++-
.../SparseTensor/Transforms/CMakeLists.txt | 1 +
.../Transforms/SparseSpaceCollapse.cpp | 152 +++++++++++
8 files changed, 479 insertions(+), 7 deletions(-)
create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 5e523ec428aefb..78692307820bc5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -17,7 +17,9 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 6efeb6007d649e..f0eaf2191fdbd1 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
//===----------------------------------------------------------------------===//
// Base class.
@@ -1422,16 +1424,51 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
// Sparse Tensor Iteration Operations.
//===----------------------------------------------------------------------===//
+// `sparse_tensor.iteration.extract_space`: extracts the iteration space over
+// the half-open level range [loLvl, hiLvl) of `$tensor`. The result type is
+// inferred from the tensor's level types in that range. The verifier requires
+// `$parentIter` exactly when loLvl != 0; that iterator must cover the levels
+// immediately preceding loLvl.
+def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
+ [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+
+ let arguments = (ins AnySparseTensor:$tensor,
+ Optional<AnySparseIterator>:$parentIter,
+ LevelAttr:$loLvl, LevelAttr:$hiLvl);
+
+ let results = (outs AnySparseIterSpace:$resultSpace);
+
+ // getSpaceDim(): number of levels in the extracted space (hiLvl - loLvl).
+ // NOTE(review): returns `unsigned` while levels are wider; confirm narrowing
+ // is acceptable.
+ let extraClassDeclaration = [{
+ unsigned getSpaceDim() {
+ return getHiLvl() - getLoLvl();
+ }
+ }];
+
+ let hasVerifier = 1;
+ // Textual form:
+ //   sparse_tensor.iteration.extract_space %t [at %it] lvls = lo [to hi]
+ //       : <tensor-type> [, <iterator-type>]
+ // A bare `lo` means the single level [lo, lo + 1).
+ let assemblyFormat = "$tensor (`at`$parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
+ " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
+}
+
def IterateOp : SparseTensor_Op<"iterate",
- [RecursiveMemoryEffects]> {
+ [RecursiveMemoryEffects, RecursivelySpeculatable,
+ DeclareOpInterfaceMethods<LoopLikeOpInterface,
+ ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
+ "getSingleInductionVar", "getYieldedValuesMutable"]>,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getEntrySuccessorOperands"]>,
+ SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
let arguments = (ins AnySparseIterSpace:$iterSpace,
Variadic<AnyType>:$initArgs);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
- let extraClassDeclaration = [{}];
+ let extraClassDeclaration = [{
+ // Entry-block layout: argument 0 is the iterator over `$iterSpace`; the
+ // remaining arguments are the loop-carried iter_args.
+ BlockArgument getIterator() {
+ return getRegion().getArguments().front();
+ }
+ unsigned getNumRegionIterArgs() {
+ return getRegion().getArguments().size() - 1;
+ }
+ }];
+ // Implemented in C++ as IterateOp::verify and IterateOp::verifyRegions.
+ let hasVerifier = 1;
+ let hasRegionVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index 54a8e4d7ecd398..aa674b613e71db 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -109,8 +109,13 @@ def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
LevelTypeArrayParameter: $lvlTypes
);
- // let skipDefaultBuilders = 1;
- // let hasCustomAssemblyFormat = 1;
+ let extraClassDeclaration = [{
+ unsigned getSpaceDim() {
+ return getLvlTypes().size();
+ }
+ }];
+
+
let assemblyFormat="`<` $lvlTypes `>`";
}
@@ -123,13 +128,15 @@ def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
);
let extraClassDeclaration = [{
+ unsigned getSpaceDim() {
+ return getLvlTypes().size();
+ }
+
::mlir::sparse_tensor::IteratorType getIteratorType() const {
return IteratorType::get(getContext(), getLvlTypes());
}
}];
- // let skipDefaultBuilders = 1;
- // let hasCustomAssemblyFormat = 1;
let assemblyFormat="`<` $lvlTypes `>`";
}
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 61b07d222d156b..0e9f5120f7b3dc 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -247,6 +247,12 @@ std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
bool enableBufferInitialization, unsigned vectorLength,
bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen);
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+/// Creates the (experimental) pass that collapses consecutive sparse
+/// iteration spaces extracted from the same tensor into one
+/// multi-dimensional space.
+std::unique_ptr<Pass> createSparseSpaceCollapsePass();
+
//===----------------------------------------------------------------------===//
// Registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 58e2d6f32386c3..3ab75c23dbefa0 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -454,4 +454,20 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "M
];
}
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
+  let summary = "(experimental) sparse space collapsing pass";
+  let description = [{
+    This pass collapses consecutive sparse spaces (extracted from the same
+    tensor) into one multi-dimensional space.
+  }];
+  let constructor = "mlir::createSparseSpaceCollapsePass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6da51bb6b9cacf..6afa3e6309dc65 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1912,13 +1912,264 @@ LogicalResult YieldOp::verify() {
auto *parentOp = (*this)->getParentOp();
if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
- isa<ForeachOp>(parentOp))
+ isa<ForeachOp>(parentOp) || isa<IterateOp>(parentOp))
return success();
return emitOpError("expected parent op to be sparse_tensor unary, binary, "
"reduce, select or foreach");
}
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+/// Parses a level range written either as `lo`, meaning the single level
+/// [lo, lo + 1), or as `lo to hi`. Both bounds are returned as index-typed
+/// integer attributes.
+static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
+                                   IntegerAttr &lvlHiAttr) {
+  Level lvlLo, lvlHi;
+  if (parser.parseInteger(lvlLo))
+    return failure();
+
+  // Without an explicit `to hi`, the range covers exactly one level.
+  const bool hasUpperBound = succeeded(parser.parseOptionalKeyword("to"));
+  if (!hasUpperBound)
+    lvlHi = lvlLo + 1;
+  else if (parser.parseInteger(lvlHi))
+    return failure();
+
+  Type indexTp = parser.getBuilder().getIndexType();
+  lvlLoAttr = IntegerAttr::get(indexTp, lvlLo);
+  lvlHiAttr = IntegerAttr::get(indexTp, lvlHi);
+  return success();
+}
+
+/// Prints a level range in the form accepted by parseLevelRange: a bare `lo`
+/// when the range spans a single level, otherwise `lo to hi`.
+static void printLevelRange(OpAsmPrinter &p, ExtractIterSpaceOp op,
+                            IntegerAttr lvlLo, IntegerAttr lvlHi) {
+  const auto lo = op.getLoLvl();
+  const auto hi = op.getHiLvl();
+  p << lo;
+  // Only multi-level ranges need an explicit upper bound.
+  if (lo + 1 != hi)
+    p << " to " << hi;
+}
+
+/// Infers the result iteration-space type from the tensor's level types in
+/// [loLvl, hiLvl). This hook can run before the verifier (the result type is
+/// not spelled in the textual form). Malformed ranges are therefore diagnosed
+/// here rather than passed to `slice`, which would assert.
+LogicalResult ExtractIterSpaceOp::inferReturnTypes(
+    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
+    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
+    SmallVectorImpl<mlir::Type> &ret) {
+
+  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
+  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
+  Level lo = adaptor.getLoLvl();
+  Level hi = adaptor.getHiLvl();
+  // `hi - lo` would underflow for an empty/inverted range.
+  if (lo >= hi)
+    return emitOptionalError(loc,
+                             "expected smaller level low than level high");
+  if (hi > stt.getLvlRank())
+    return emitOptionalError(
+        loc, "expected level high to not exceed the level rank");
+
+  ArrayRef<LevelType> lts = stt.getLvlTypes().slice(lo, hi - lo);
+  ret.push_back(IterSpaceType::get(ctx, lts));
+  return success();
+}
+
+/// Verifies an extract_space op:
+///  - the level range is non-empty and within the tensor's level rank;
+///  - the result space carries exactly the level types of that range;
+///  - a parent iterator is present exactly when the range does not start at
+///    level 0, and it covers the levels immediately preceding the range.
+LogicalResult ExtractIterSpaceOp::verify() {
+  SparseTensorType stt = getSparseTensorType(getTensor());
+  if (getLoLvl() >= getHiLvl())
+    return emitOpError("expected smaller level low than level high");
+  // Ops built programmatically skip type inference, so an out-of-range upper
+  // bound can reach this point; guard the slice below (it would assert).
+  if (getHiLvl() > stt.getLvlRank())
+    return emitOpError("expected level high to not exceed the level rank");
+
+  ArrayRef<LevelType> lts = stt.getLvlTypes().slice(getLoLvl(), getSpaceDim());
+  if (!getResultSpace().getType().getLvlTypes().equals(lts)) {
+    return emitOpError(
+        "mismatch in iteration space level types and tensor level types");
+  }
+
+  TypedValue<IteratorType> pIter = getParentIter();
+  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
+    return emitOpError(
+        "parent iterator is only needed iff level low does not equal 0");
+  }
+
+  if (pIter) {
+    // The parent iterator must range over the levels directly above loLvl.
+    unsigned pDim = pIter.getType().getSpaceDim();
+    if (getLoLvl() < pDim || !stt.getLvlTypes()
+                                  .slice(getLoLvl() - pDim, pDim)
+                                  .equals(pIter.getType().getLvlTypes())) {
+      return emitOpError(
+          "mismatch in parent iterator level types and tensor level types");
+    }
+  }
+
+  return success();
+}
+
+/// Parses
+///   %it in %space (iter_args(%a = %init, ...))? : <iter_space-type>
+///       (-> (<result-types>))? <region> <attr-dict>?
+/// The region's first argument is the iterator; its type is derived from the
+/// iteration-space type.
+ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::Argument iterator;
+  OpAsmParser::UnresolvedOperand iterSpace;
+
+  // Parses "%iter in %space".
+  if (parser.parseArgument(iterator) || parser.parseKeyword("in") ||
+      parser.parseOperand(iterSpace))
+    return failure();
+
+  // Region arguments start with the iterator, followed by the optional
+  // user-provided iter_args.
+  SmallVector<OpAsmParser::Argument> regionArgs;
+  SmallVector<OpAsmParser::UnresolvedOperand> operands;
+  regionArgs.push_back(iterator);
+  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+  if (hasIterArgs && parser.parseAssignmentList(regionArgs, operands))
+    return failure();
+
+  // Parses ": <iter_space-type> (-> (<result-types>))?".
+  Type iterSpaceTps;
+  if (parser.parseColon())
+    return failure();
+  llvm::SMLoc spaceTpLoc = parser.getCurrentLocation();
+  if (parser.parseType(iterSpaceTps))
+    return failure();
+  // The type is user-supplied: diagnose anything that is not an iteration
+  // space instead of asserting in a cast when deriving the iterator type.
+  auto spaceTp = llvm::dyn_cast<IterSpaceType>(iterSpaceTps);
+  if (!spaceTp)
+    return parser.emitError(spaceTpLoc,
+                            "expected a sparse_tensor.iter_space type");
+  if (hasIterArgs && parser.parseArrowTypeList(result.types))
+    return failure();
+
+  if (regionArgs.size() != result.types.size() + 1) {
+    return parser.emitError(
+        parser.getNameLoc(),
+        "mismatch in number of loop-carried values and defined values");
+  }
+
+  // Resolves input operands.
+  if (parser.resolveOperand(iterSpace, iterSpaceTps, result.operands))
+    return failure();
+
+  if (hasIterArgs) {
+    for (auto argOperandType :
+         llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
+      Type type = std::get<2>(argOperandType);
+      std::get<0>(argOperandType).type = type;
+      if (parser.resolveOperand(std::get<1>(argOperandType), type,
+                                result.operands))
+        return failure();
+    }
+  }
+
+  Region *body = result.addRegion();
+  regionArgs.front().type = spaceTp.getIteratorType();
+  if (parser.parseRegion(*body, regionArgs))
+    return failure();
+
+  IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+/// Prints `<prefix>(%inner = %outer, %inner2 = %outer2, <...>)`.
+/// Each region argument in `blocksArgs` (the "inner" name) is paired
+/// position-wise with the SSA value in `initializers` (the "outer" value) that
+/// seeds it. Nothing is printed when there are no initializers.
+static void printInitializationList(OpAsmPrinter &p,
+                                    Block::BlockArgListType blocksArgs,
+                                    ValueRange initializers,
+                                    StringRef prefix = "") {
+  assert(blocksArgs.size() == initializers.size() &&
+         "expected same length of arguments and initializers");
+  if (initializers.empty())
+    return;
+
+  p << prefix << '(';
+  auto printPair = [&](auto pair) {
+    auto [inner, outer] = pair;
+    p << inner << " = " << outer;
+  };
+  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, printPair);
+  p << ")";
+}
+
+/// Prints the custom form accepted by IterateOp::parse:
+///   %it in %space (iter_args(...))? : <iter_space-type> (-> (...))? <region>
+/// followed by any extra attributes.
+void IterateOp::print(OpAsmPrinter &p) {
+  p << " " << getIterator() << " in " << getIterSpace();
+
+  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
+
+  p << " : " << getIterSpace().getType() << " ";
+  if (!getInitArgs().empty())
+    p << "-> (" << getInitArgs().getTypes() << ") ";
+
+  // The (implicit) terminator only carries information when there are
+  // loop-carried values.
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/!getInitArgs().empty());
+  // The parser accepts an optional attribute dictionary after the region;
+  // print it so that attributes survive a print/parse round-trip.
+  p.printOptionalAttrDict((*this)->getAttrs());
+}
+
+/// Checks that the op defines exactly one result per loop-carried init value.
+LogicalResult IterateOp::verify() {
+  const bool countsMatch = getInitArgs().size() == getNumResults();
+  if (countsMatch)
+    return success();
+  return emitOpError(
+      "mismatch in number of loop-carried values and defined values");
+}
+
+/// Verifies the body against the op:
+///  - the iterator argument's type matches the space's iterator type;
+///  - init operands, region iter_args, yielded values and results agree in
+///    count and, position by position, in type.
+LogicalResult IterateOp::verifyRegions() {
+  if (getIterator().getType() != getIterSpace().getType().getIteratorType())
+    return emitOpError("mismatch in iterator and iteration space type");
+  if (getNumRegionIterArgs() != getNumResults())
+    return emitOpError(
+        "mismatch in number of basic block args and defined values");
+
+  auto inits = getInitArgs();
+  auto blockArgs = getRegionIterArgs();
+  auto yielded = getYieldedValues();
+  auto defined = getResults();
+  const bool sameCount = llvm::all_equal(
+      {inits.size(), blockArgs.size(), yielded.size(), defined.size()});
+  if (!sameCount)
+    return emitOpError() << "number mismatch between iter args and results.";
+
+  unsigned idx = 0;
+  for (auto [init, blockArg, yieldVal, res] :
+       llvm::zip_equal(inits, blockArgs, yielded, defined)) {
+    Type expected = res.getType();
+    if (init.getType() != expected)
+      return emitOpError() << "types mismatch between " << idx
+                           << "th iter operand and defined value";
+    if (blockArg.getType() != expected)
+      return emitOpError() << "types mismatch between " << idx
+                           << "th iter region arg and defined value";
+    if (yieldVal.getType() != expected)
+      return emitOpError() << "types mismatch between " << idx
+                           << "th yield value and defined value";
+    ++idx;
+  }
+  return success();
+}
+
+/// IterateOp implemented interfaces' methods.
+SmallVector<Region *> IterateOp::getLoopRegions() {
+  Region *body = &getRegion();
+  return {body};
+}
+
+std::optional<Value> IterateOp::getSingleInductionVar() {
+  // The iterator block argument plays the role of the induction variable.
+  Value iv = getIterator();
+  return iv;
+}
+
+MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
+  MutableOperandRange inits = getInitArgsMutable();
+  return inits;
+}
+
+Block::BlockArgListType IterateOp::getRegionIterArgs() {
+  // Skip the leading iterator argument; the rest are loop-carried values.
+  Block::BlockArgListType allArgs = getRegion().getArguments();
+  return allArgs.drop_front();
+}
+
+std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
+  Block &body = getRegion().front();
+  auto yield = cast<sparse_tensor::YieldOp>(body.getTerminator());
+  return yield.getResultMutable();
+}
+
+std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
+
+OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  // Entering the body always forwards the init values, whatever the origin.
+  return getInitArgs();
+}
+
+/// Reports the control-flow successors of the loop, regardless of `point`:
+///  - the body, which receives values in the region iter_args;
+///  - the op results (the loop may exit, or never enter the body).
+void IterateOp::getSuccessorRegions(RegionBranchPoint point,
+                                    SmallVectorImpl<RegionSuccessor> &regions) {
+  // Both the operation itself and the region may be branching into the body or
+  // back into the operation itself.
+  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
+  // It is possible for loop not to enter the body.
+  regions.push_back(RegionSuccessor(getResults()));
+}
+
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Dialect Setups.
+//===----------------------------------------------------------------------===//
+
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 3c0f82fc00bb9d..8840da9aa56ef7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
SparseGPUCodegen.cpp
SparseReinterpretMap.cpp
SparseStorageSpecifierToLLVM.cpp
+ SparseSpaceCollapse.cpp
SparseTensorCodegen.cpp
SparseTensorConversion.cpp
SparseTensorPasses.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
new file mode 100644
index 00000000000000..f3207ede9585b4
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -0,0 +1,152 @@
+//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+namespace mlir {
+
+#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
+
+namespace sparse_tensor {
+
+/// Returns true when the loop `node` can be fused with the enclosing loop
+/// `parent`. This requires that:
+///  - `parent` forwards its iter_args unchanged as `node`'s init values;
+///  - `parent` yields exactly `node`'s results.
+bool isCollapsableIterations(LoopLikeOpInterface parent,
+                             LoopLikeOpInterface node) {
+  auto parentIterArgs = parent.getRegionIterArgs();
+  auto nodeInits = node.getInits();
+  if (parentIterArgs.size() != nodeInits.size())
+    return false;
+
+  auto parentYields = parent.getYieldedValues();
+  auto nodeResults = node.getLoopResults().value();
+
+  auto sameValue = [](auto pair) {
+    return std::get<0>(pair) == std::get<1>(pair);
+  };
+  const bool yieldsMatch =
+      llvm::all_of(llvm::zip_equal(parentYields, nodeResults), sameValue);
+  // Parent iter_args should be passed directly to the node's init_args.
+  const bool initsMatch =
+      llvm::all_of(llvm::zip_equal(parentIterArgs, nodeInits), sameValue);
+
+  return yieldsMatch && initsMatch;
+}
+
+/// Returns true if `node` may be appended to a collapse chain whose current
+/// last space is `parent`. That requires all of the following:
+///  - both spaces come from the same tensor, and `parent` has a single use;
+///  - `node` is extracted directly inside an IterateOp that iterates over
+///    `parent` and sits in the same block as `parent`;
+///  - if `parent` itself lies inside an IterateOp, that outer loop forwards
+///    its iter_args to, and yields the results of, the loop over `parent`
+///    (checked by isCollapsableIterations).
+bool legalToCollapse(ExtractIterSpaceOp parent, ExtractIterSpaceOp node) {
+ auto pItOp = llvm::dyn_cast<IterateOp>(parent->getParentOp());
+ auto nItOp = llvm::dyn_cast<IterateOp>(node->getParentOp());
+
+ // Can only collapse spaces extracted from the same tensor.
+ if (parent.getTensor() != node.getTensor() || !parent->hasOneUse())
+ return false;
+
+ // Can only collapse consecutive simple iteration on one tensor (i.e., no
+ // coiteration).
+ if (!nItOp || nItOp.getIterSpace() != parent.getResult() ||
+ nItOp->getBlock() != parent->getBlock())
+ return false;
+
+ // The loop over `parent` (nItOp) must be perfectly nested in the loop that
+ // encloses `parent` (pItOp), if there is one.
+ if (pItOp && !isCollapsableIterations(pItOp, nItOp))
+ return false;
+
+ // TODO: Make sure all other operations in the same basic block as `node` can
+ // be collapsed and sink into the collapsed iteration (through Interfaces
+ // defined in TD files).
+ return true;
+}
+
+/// Collapses a nested chain of iteration spaces `toCollapse` (outermost
+/// first, all extracted from the same tensor) into one space covering levels
+/// [root.loLvl, leaf.hiLvl). The innermost loop is cloned over that space, the
+/// outermost loop's uses are replaced with the clone's results, and the
+/// outermost loop and root space are erased. Nothing is changed if the chain
+/// is shorter than two or the innermost loop cannot be collapsed.
+void collapseSparseSpace(ArrayRef<ExtractIterSpaceOp> toCollapse) {
+  if (toCollapse.size() < 2)
+    return;
+
+  ExtractIterSpaceOp root = toCollapse.front();
+  ExtractIterSpaceOp leaf = toCollapse.back();
+  Location loc = root.getLoc();
+
+  if (!leaf->hasOneUse())
+    return;
+  assert(root->hasOneUse());
+
+  auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
+  auto pItOp = llvm::cast<IterateOp>(leaf->getParentOp());
+
+  // This could either be IterateOp or (TODO: in the future) CoIterateOp.
+  auto loop = llvm::dyn_cast<IterateOp>(*leaf->getUsers().begin());
+  if (!loop || !isCollapsableIterations(pItOp, loop))
+    return;
+
+  // Create the collapsed space only once the rewrite is known to proceed.
+  // Creating it before the bail-out above would leave a dead op in the IR.
+  // It is inserted at the same scope as the root operation.
+  OpBuilder builder(root);
+  auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
+      loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
+      leaf.getHiLvl());
+
+  IRMapping mapper;
+  mapper.map(leaf, collapsedSpace.getResultSpace());
+  for (auto z : llvm::zip_equal(loop.getInitArgs(), rItOp.getInitArgs()))
+    mapper.map(std::get<0>(z), std::get<1>(z));
+
+  auto cloned = llvm::cast<IterateOp>(builder.clone(*loop, mapper));
+  cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
+
+  rItOp.replaceAllUsesWith(cloned.getResults());
+  // Erase collapsed loops.
+  rItOp.erase();
+  root.erase();
+}
+
+/// Experimental pass that collapses nested sparse iteration spaces extracted
+/// from the same tensor. Spaces are gathered in walk order into a chain while
+/// each new space is legal to collapse into the previous one. When a space
+/// breaks the chain, the gathered chain is collapsed and reset.
+struct SparseSpaceCollapsePass
+ : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
+ SparseSpaceCollapsePass() = default;
+
+ void runOnOperation() override {
+ func::FuncOp func = getOperation();
+
+ // A naive (experimental) implementation to collapse consecutive sparse
+ // spaces. It does NOT handle complex cases where multiple spaces are
+ // extracted in the same basic block. E.g.,
+ //
+ // %space1 = extract_space %t1 ...
+ // %space2 = extract_space %t2 ...
+ // sparse_tensor.iterate(%space1) ...
+ //
+ // NOTE(review): collapseSparseSpace erases ops (the outermost loop and the
+ // root space) while this walk is still running. Confirm the walk tolerates
+ // that, or collect chains first and rewrite after the walk.
+ SmallVector<ExtractIterSpaceOp> toCollapse;
+ func->walk([&](ExtractIterSpaceOp op) {
+ if (toCollapse.empty()) {
+ // Root space to collapse.
+ toCollapse.push_back(op);
+ } else {
+ if (legalToCollapse(toCollapse.back(), op)) {
+ toCollapse.push_back(op);
+ } else {
+ // NOTE(review): `op` is discarded here rather than starting a new
+ // chain, so it can never act as a root; confirm this is intended.
+ collapseSparseSpace(toCollapse);
+ toCollapse.clear();
+ }
+ }
+ });
+
+ // Collapse whatever chain is still pending after the walk.
+ collapseSparseSpace(toCollapse);
+ }
+};
+
+} // namespace sparse_tensor
+
+/// Creates the (experimental) sparse-space-collapse pass, implemented by
+/// sparse_tensor::SparseSpaceCollapsePass.
+std::unique_ptr<Pass> createSparseSpaceCollapsePass() {
+ return std::make_unique<sparse_tensor::SparseSpaceCollapsePass>();
+}
+
+} // namespace mlir
>From 69cf58f3d4375ab076ef16e98d9c513b9eea1205 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 22 Mar 2024 18:41:34 +0000
Subject: [PATCH 3/5] test collapsing coordinate extraction from iterator.
---
.../SparseTensor/IR/SparseTensorInterfaces.h | 2 +
.../SparseTensor/IR/SparseTensorInterfaces.td | 15 ++
.../SparseTensor/IR/SparseTensorOps.td | 17 +-
.../SparseTensor/IR/SparseTensorDialect.cpp | 167 +++++++++++-------
.../Transforms/SparseSpaceCollapse.cpp | 119 +++++++++----
5 files changed, 217 insertions(+), 103 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
index c0f31762ee071f..115e08b2cf8b14 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
namespace mlir {
class PatternRewriter;
@@ -20,6 +21,7 @@ class StageWithSortSparseOp;
namespace detail {
LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op,
PatternRewriter &rewriter, Value &tmpBufs);
+
} // namespace detail
} // namespace sparse_tensor
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
index 05eed0483f2c8a..ee1c0b52b47e45 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
@@ -42,5 +42,20 @@ def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> {
];
}
+// Interface for ops that can be re-materialized inside a collapsed sparse
+// iteration. It is presumably consumed by the sparse-space-collapse pass
+// (TODO confirm). Its single method receives a builder, the original nested
+// loops, and the collapsed loop, and returns the replacement values.
+// NOTE(review): the description and the method `desc` are placeholders, and
+// `collaspeOpInto` looks like a typo of `collapseOpInto`. Renaming requires
+// updating every implementation.
+def SparseCollapsableOpInterface : OpInterface<"SparseCollapsableOp"> {
+ let description = [{ TODO }];
+
+ let cppNamespace = "::mlir::sparse_tensor";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/"test",
+ /*retTy=*/"ValueRange",
+ /*methodName=*/"collaspeOpInto",
+ /*args=*/(ins "::mlir::OpBuilder &":$builder,
+ "::mlir::ArrayRef<::mlir::Operation *>":$loops,
+ "::mlir::Operation *":$collapsed)>,
+ ];
+}
#endif // SPARSETENSOR_IR_SPARSETENSORINTERFACES
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index f0eaf2191fdbd1..467030a8d221af 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1280,7 +1280,9 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
let hasVerifier = 1;
}
-def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
+def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
+ ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp", "ForeachOp",
+ "IterateOp"]>]>,
Arguments<(ins Optional<AnyType>:$result)> {
let summary = "Yield from sparse_tensor set-like operations";
let description = [{
@@ -1311,7 +1313,6 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
let assemblyFormat = [{
$result attr-dict `:` type($result)
}];
- let hasVerifier = 1;
}
def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
@@ -1440,10 +1441,20 @@ def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
}];
let hasVerifier = 1;
- let assemblyFormat = "$tensor (`at`$parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
+ let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
" attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
}
+// `sparse_tensor.iteration.coordinate`: takes a sparse iterator and produces
+// a variadic list of `index` values. These are presumably the coordinates at
+// the iterator's current position, one per level it spans (TODO confirm).
+// Textual form: `%it : <iterator-type> -> <index-types>`.
+// NOTE(review): the verifier is disabled, so nothing checks that the result
+// count matches the iterator's level count.
+def CoordinateOp : SparseTensor_Op<"iteration.coordinate",
+ [Pure, DeclareOpInterfaceMethods<SparseCollapsableOpInterface>]> {
+ let arguments = (ins AnySparseIterator:$iterator);
+ let results = (outs Variadic<Index>:$crds);
+
+ let extraClassDeclaration = [{ }];
+ // let hasVerifier = 1;
+ let assemblyFormat = " $iterator attr-dict `:` type($iterator) `->` type($crds)";
+}
+
def IterateOp : SparseTensor_Op<"iterate",
[RecursiveMemoryEffects, RecursivelySpeculatable,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6afa3e6309dc65..54be3c3b4c3e5f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1907,18 +1907,6 @@ LogicalResult SortOp::verify() {
return success();
}
-LogicalResult YieldOp::verify() {
- // Check for compatible parent.
- auto *parentOp = (*this)->getParentOp();
- if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
- isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
- isa<ForeachOp>(parentOp) || isa<IterateOp>(parentOp))
- return success();
-
- return emitOpError("expected parent op to be sparse_tensor unary, binary, "
- "reduce, select or foreach");
-}
-
//===----------------------------------------------------------------------===//
// Sparse Tensor Iteration Operations.
//===----------------------------------------------------------------------===//
@@ -1936,17 +1924,89 @@ static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
} else {
lvlHi = lvlLo + 1;
}
+
+ if (lvlHi <= lvlLo) // Reject empty/inverted ranges and fail the parse.
+ return parser.emitError(parser.getNameLoc(),
+ "expect larger level upper bound than lower bound");
+
lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
return success();
}
-static void printLevelRange(OpAsmPrinter &p, ExtractIterSpaceOp op,
- IntegerAttr lvlLo, IntegerAttr lvlHi) {
- if (op.getLoLvl() + 1 == op.getHiLvl())
- p << op.getLoLvl();
+static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
+ IntegerAttr lvlHi) {
+ unsigned lo = lvlLo.getValue().getZExtValue();
+ unsigned hi = lvlHi.getValue().getZExtValue();
+ if (lo + 1 == hi)
+ p << lo;
else
- p << op.getLoLvl() << " to " << op.getHiLvl();
+ p << lo << " to " << hi;
+}
+
+ParseResult
+parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
+ SmallVectorImpl<OpAsmParser::Argument> &iterators,
+ SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
+ SmallVector<OpAsmParser::UnresolvedOperand> spaces;
+ SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
+ // Parses "%iters, ... in %spaces, ..."
+ if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
+ parser.parseOperandList(spaces))
+ return failure();
+
+ if (iterators.size() != spaces.size())
+ return parser.emitError(
+ parser.getNameLoc(),
+ "mismatch in number of sparse iterators and sparse spaces");
+
+ bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+ if (hasIterArgs)
+ if (parser.parseAssignmentList(iterArgs, initArgs))
+ return failure();
+
+ SmallVector<Type> iterSpaceTps;
+ // parse ": sparse_tensor.iter_space -> ret"
+ if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
+ return failure();
+ if (iterSpaceTps.size() != spaces.size())
+ return parser.emitError(parser.getNameLoc(),
+ "mismatch in number of iteration space operands "
+ "and iteration space types");
+
+ for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
+ IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
+ if (!spaceTp)
+ return parser.emitError(parser.getNameLoc(),
+ "expected sparse_tensor.iter_space type for "
+ "iteration space operands");
+ it.type = spaceTp.getIteratorType();
+ }
+
+ if (hasIterArgs) {
+ if (parser.parseArrowTypeList(state.types))
+ return failure();
+ // Each iter_arg needs exactly one result type; `llvm::zip` below would
+ // otherwise silently skip the unmatched entries (leaving them untyped).
+ if (initArgs.size() != state.types.size())
+ return parser.emitError(
+ parser.getNameLoc(),
+ "mismatch in number of loop-carried values and defined values");
+ }
+
+ // Resolves input operands.
+ if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
+ state.operands))
+ return failure();
+
+ if (hasIterArgs) {
+ for (auto [it, init, tp] : llvm::zip(iterArgs, initArgs, state.types)) {
+ it.type = tp;
+ if (parser.resolveOperand(init, tp, state.operands))
+ return failure();
+ }
+ }
+ return success();
}
LogicalResult ExtractIterSpaceOp::inferReturnTypes(
@@ -1991,60 +2044,45 @@ LogicalResult ExtractIterSpaceOp::verify() {
return success();
}
+ValueRange CoordinateOp::collaspeOpInto(OpBuilder &builder,
+ ArrayRef<Operation *> loops,
+ Operation *collapsed) {
+ assert(llvm::all_of(loops,
+ [](Operation *l) { return llvm::isa<IterateOp>(l); }));
+ auto finalLoop = llvm::cast<IterateOp>(collapsed);
+ SmallVector<Type> retTps(finalLoop.getIterSpace().getType().getSpaceDim(),
+ builder.getIndexType());
+ auto collapsedCoords =
+ builder.create<CoordinateOp>(getLoc(), retTps, finalLoop.getIterator());
+
+ for (Operation *l : loops) {
+ if (getIterator().getParentBlock()->getParentOp() == l) {
+ auto space = llvm::cast<IterateOp>(l)
+ .getIterSpace()
+ .getDefiningOp<ExtractIterSpaceOp>();
+
+ return collapsedCoords.getResults().slice(space.getLoLvl(),
+ space.getSpaceDim());
+ }
+ }
+ llvm_unreachable(
+ "Can not find the corresponding iterate space for the collapsable op.");
+}
+
ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::Argument iterator;
OpAsmParser::UnresolvedOperand iterSpace;
- // Parses %iters in %spaces
- if (parser.parseArgument(iterator) || parser.parseKeyword("in") ||
- parser.parseOperand(iterSpace)) {
+ SmallVector<OpAsmParser::Argument> iters, iterArgs;
+ if (parseSparseSpaceLoop(parser, result, iters, iterArgs))
return failure();
- }
-
- // Parse the optional initial iteration arguments.
- SmallVector<OpAsmParser::Argument> regionArgs;
- SmallVector<OpAsmParser::UnresolvedOperand> operands;
- // Region arguments starts with iterators and follows by optional
- // user-provided iter_args.
- regionArgs.push_back(iterator);
- bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
- if (hasIterArgs)
- if (parser.parseAssignmentList(regionArgs, operands))
- return failure();
-
- // parse ": sparse_tensor.iter_space -> ret"
- Type iterSpaceTps;
- if (parser.parseColon() || parser.parseType(iterSpaceTps))
- return failure();
- if (hasIterArgs)
- if (parser.parseArrowTypeList(result.types))
- return failure();
-
- if (regionArgs.size() != result.types.size() + 1) {
- return parser.emitError(
- parser.getNameLoc(),
- "mismatch in number of loop-carried values and defined values");
- }
-
- // Resolves input operands.
- if (parser.resolveOperand(iterSpace, iterSpaceTps, result.operands))
- return failure();
-
- if (hasIterArgs) {
- for (auto argOperandType :
- llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
- Type type = std::get<2>(argOperandType);
- std::get<0>(argOperandType).type = type;
- if (parser.resolveOperand(std::get<1>(argOperandType), type,
- result.operands))
- return failure();
- }
- }
+ if (iters.size() != 1)
+ return parser.emitError(parser.getNameLoc(),
+ "expected only one iterator/iteration space");
+ iters.append(iterArgs);
Region *body = result.addRegion();
- regionArgs.front().type =
- iterSpaceTps.cast<IterSpaceType>().getIteratorType();
- if (parser.parseRegion(*body, regionArgs))
+ if (parser.parseRegion(*body, iters))
return failure();
IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
@@ -2078,7 +2116,6 @@ static void printInitializationList(OpAsmPrinter &p,
void IterateOp::print(OpAsmPrinter &p) {
p << " " << getIterator() << " in " << getIterSpace();
-
printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
p << " : " << getIterSpace().getType() << " ";
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
index f3207ede9585b4..752b2dfc2a0070 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -20,8 +20,14 @@ namespace mlir {
namespace sparse_tensor {
-bool isCollapsableIterations(LoopLikeOpInterface parent,
- LoopLikeOpInterface node) {
+struct CollapseSpaceInfo { // One level of a collapse chain: a space and the loop over it.
+ ExtractIterSpaceOp space;
+ // Coiteration as well (if make sense)?
+ IterateOp loop;
+ SmallVector<SparseCollapsableOp> collapseOps;
+};
+
+bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
auto pIterArgs = parent.getRegionIterArgs();
auto nInitArgs = node.getInits();
if (pIterArgs.size() != nInitArgs.size())
@@ -44,43 +50,81 @@ bool isCollapsableIterations(LoopLikeOpInterface parent,
return yieldEq && iterArgEq;
}
-bool legalToCollapse(ExtractIterSpaceOp parent, ExtractIterSpaceOp node) {
- auto pItOp = llvm::dyn_cast<IterateOp>(parent->getParentOp());
- auto nItOp = llvm::dyn_cast<IterateOp>(node->getParentOp());
+bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
+ ExtractIterSpaceOp curSpace) {
+
+ auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
+ Value spaceVal = space.getResultSpace();
+ if (spaceVal.hasOneUse())
+ return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
+ return nullptr;
+ };
+
+ if (toCollapse.empty()) {
+ // Collapse root.
+ if (auto itOp = getIterateOpOverSpace(curSpace)) {
+ CollapseSpaceInfo &info = toCollapse.emplace_back();
+ info.space = curSpace;
+ info.loop = itOp;
+ // No operations need to be collapsed at the root level;
+ info.collapseOps = {};
+ return true;
+ }
+ return false;
+ }
+
+ auto parent = toCollapse.back().space;
+ auto pItOp = toCollapse.back().loop;
+ auto nItOp = getIterateOpOverSpace(curSpace);
// Can only collapse spaces extracted from the same tensor.
- if (parent.getTensor() != node.getTensor() || !parent->hasOneUse())
+ if (parent.getTensor() != curSpace.getTensor())
return false;
// Can only collapse consecutive simple iteration on one tensor (i.e., no
// coiteration).
- if (!nItOp || nItOp.getIterSpace() != parent.getResult() ||
- nItOp->getBlock() != parent->getBlock())
+ if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
+ pItOp.getIterator() != curSpace.getParentIter() ||
+ curSpace->getParentOp() != pItOp.getOperation())
return false;
- if (pItOp && !isCollapsableIterations(pItOp, nItOp))
+ if (pItOp && !isCollapsableLoops(pItOp, nItOp))
return false;
// TODO: Make sure all other operations in the same basic block as `node` can
// be collapsed and sink into the collapsed iteration (through Interfaces
// defined in TD files).
+ SmallVector<SparseCollapsableOp> collapsableOps;
+ for (Operation &op : *pItOp.getBody()) {
+ if (&op == curSpace.getOperation() || &op == nItOp.getOperation() ||
+ &op == pItOp.getBody()->getTerminator())
+ continue;
+ // All other ops in parent loop need to be collapsable.
+ auto collapsableOp = llvm::dyn_cast<SparseCollapsableOp>(&op);
+ if (!collapsableOp)
+ return false;
+ collapsableOps.push_back(collapsableOp);
+ }
+
+ CollapseSpaceInfo &info = toCollapse.emplace_back();
+ info.space = curSpace;
+ info.loop = nItOp;
+ info.collapseOps = std::move(collapsableOps);
return true;
}
-void collapseSparseSpace(ArrayRef<ExtractIterSpaceOp> toCollapse) {
+void collapseSparseSpace(SmallVectorImpl<CollapseSpaceInfo> &toCollapse) {
if (toCollapse.size() < 2)
return;
- ExtractIterSpaceOp root = toCollapse.front();
- ExtractIterSpaceOp leaf = toCollapse.back();
+ ExtractIterSpaceOp root = toCollapse.front().space;
+ ExtractIterSpaceOp leaf = toCollapse.back().space;
Location loc = root.getLoc();
- if (!leaf->hasOneUse())
- return;
- assert(root->hasOneUse());
+ assert(root->hasOneUse() && leaf->hasOneUse());
// Insert collapsed operation at the same scope as root operation.
- OpBuilder builder(toCollapse.front());
+ OpBuilder builder(root);
// Construct the collapsed iteration space.
auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
@@ -88,19 +132,29 @@ void collapseSparseSpace(ArrayRef<ExtractIterSpaceOp> toCollapse) {
leaf.getHiLvl());
auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
- auto pItOp = llvm::cast<IterateOp>(leaf->getParentOp());
-
- // This could either be IterateOp or (TODO: in the future) CoIterateOp.
- auto loop = llvm::dyn_cast<IterateOp>(*leaf->getUsers().begin());
- if (!loop || !isCollapsableIterations(pItOp, loop))
- return;
+ auto innermost = toCollapse.back().loop;
IRMapping mapper;
mapper.map(leaf, collapsedSpace.getResultSpace());
- for (auto z : llvm::zip_equal(loop.getInitArgs(), rItOp.getInitArgs()))
+ for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
mapper.map(std::get<0>(z), std::get<1>(z));
- auto cloned = llvm::cast<IterateOp>(builder.clone(*loop, mapper));
+ auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
+ builder.setInsertionPointToStart(cloned.getBody());
+ SmallVector<Operation *> loops =
+ llvm::map_to_vector(toCollapse, [](CollapseSpaceInfo &info) {
+ return info.loop.getOperation();
+ });
+
+ for (const CollapseSpaceInfo &info : toCollapse) {
+ for (SparseCollapsableOp op : info.collapseOps) {
+ ValueRange colVals = op.collaspeOpInto(builder, loops, cloned);
+ for (auto [o, r] : llvm::zip(op->getResults(), colVals))
+ o.replaceAllUsesWith(r);
+ op.erase();
+ }
+ }
+
cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
rItOp.replaceAllUsesWith(cloned.getResults());
@@ -124,18 +178,13 @@ struct SparseSpaceCollapsePass
// %space2 = extract_space %t2 ...
// sparse_tensor.iterate(%sp1) ...
//
- SmallVector<ExtractIterSpaceOp> toCollapse;
+ SmallVector<CollapseSpaceInfo> toCollapse;
func->walk([&](ExtractIterSpaceOp op) {
- if (toCollapse.empty()) {
- // Root space to collapse.
- toCollapse.push_back(op);
- } else {
- if (legalToCollapse(toCollapse.back(), op)) {
- toCollapse.push_back(op);
- } else {
- collapseSparseSpace(toCollapse);
- toCollapse.clear();
- }
+ if (!legalToCollapse(toCollapse, op)) {
+ // if not legal to collapse one more space, collapse the existing ones
+ // and clear.
+ collapseSparseSpace(toCollapse);
+ toCollapse.clear();
}
});
>From d637ed12109f60833001fc57540470366988fc11 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 25 Mar 2024 22:41:52 +0000
Subject: [PATCH 4/5] fuse crds into iterate operation.
---
.../Dialect/SparseTensor/IR/SparseTensor.h | 36 ++++++
.../SparseTensor/IR/SparseTensorAttrDefs.td | 15 +++
.../SparseTensor/IR/SparseTensorOps.td | 30 +++--
.../SparseTensor/IR/SparseTensorDialect.cpp | 119 +++++++++++++-----
.../Transforms/SparseSpaceCollapse.cpp | 44 ++-----
5 files changed, 170 insertions(+), 74 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 78692307820bc5..081a9b8cad8d62 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -22,6 +22,8 @@
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/ADT/bit.h"
+
//===----------------------------------------------------------------------===//
//
// Type aliases to help code be more self-documenting. Unfortunately
@@ -43,6 +45,49 @@ using Level = uint64_t;
/// including the value `ShapedType::kDynamic` (for shapes).
using Size = int64_t;
+/// A simple wrapper to encode a bitset of defined (at most 64) levels.
+/// Bit `i` is set iff level `i` belongs to the set.
+class LevelSet {
+ uint64_t bits = 0;
+
+public:
+ LevelSet() = default;
+ explicit LevelSet(uint64_t bits) : bits(bits) {}
+ operator uint64_t() const { return bits; }
+
+ /// Adds level `i` to the set; requires `i < 64`.
+ LevelSet &set(unsigned i) {
+ assert(i < 64);
+ // Use a 64-bit one: `1 << i` is an `int` shift, undefined for i >= 31.
+ bits |= uint64_t(1) << i;
+ return *this;
+ }
+
+ /// Unions the levels of `rhs` into this set.
+ LevelSet &operator|=(LevelSet rhs) {
+ bits |= static_cast<uint64_t>(rhs);
+ return *this;
+ }
+
+ /// Moves every level `i` to `i + offset`; levels pushed past 63 are dropped.
+ LevelSet &lshift(unsigned offset) {
+ // Shifting a uint64_t by 64 or more is undefined; the result is empty.
+ bits = offset < 64 ? bits << offset : 0;
+ return *this;
+ }
+
+ /// Returns whether level `i` belongs to the set; requires `i < 64`.
+ bool operator[](unsigned i) const {
+ assert(i < 64);
+ return (bits & (uint64_t(1) << i)) != 0;
+ }
+
+ /// Returns the number of levels in the set.
+ unsigned count() const { return llvm::popcount(bits); }
+ /// Returns whether the set contains no level.
+ bool empty() const { return bits == 0; }
+};
+
} // namespace sparse_tensor
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index d3be8a3009ba1e..36c075f52f8e5b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -19,6 +19,21 @@ class SparseTensor_Attr<string name,
list<Trait> traits = []>
: AttrDef<SparseTensor_Dialect, name, traits>;
+//===----------------------------------------------------------------------===//
+// A simple bitset attribute wrapped over a single int64_t to encode a set of
+// sparse tensor levels.
+//===----------------------------------------------------------------------===//
+
+def LevelSetAttr :
+ TypedAttrBase<
+ I64, "IntegerAttr",
+ And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
+ CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
+ "LevelSet attribute"> {
+ let returnType = [{::mlir::sparse_tensor::LevelSet}];
+ let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
+}
+
//===----------------------------------------------------------------------===//
// These attributes are just like `IndexAttr` except that they clarify whether
// the index refers to a dimension (an axis of the semantic tensor) or a level
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 467030a8d221af..9a918760c3190d 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1445,36 +1445,42 @@ def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
" attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
}
-def CoordinateOp : SparseTensor_Op<"iteration.coordinate",
- [Pure, DeclareOpInterfaceMethods<SparseCollapsableOpInterface>]> {
- let arguments = (ins AnySparseIterator:$iterator);
- let results = (outs Variadic<Index>:$crds);
-
- let extraClassDeclaration = [{ }];
- // let hasVerifier = 1;
- let assemblyFormat = " $iterator attr-dict `:` type($iterator) `->` type($crds)";
-}
+// def CoordinateOp : SparseTensor_Op<"iteration.coordinate",
+// [Pure, DeclareOpInterfaceMethods<SparseCollapsableOpInterface>]> {
+// let arguments = (ins AnySparseIterator:$iterator);
+// let results = (outs Variadic<Index>:$crds);
+// let extraClassDeclaration = [{ }];
+// // let hasVerifier = 1;
+// let assemblyFormat = " $iterator attr-dict `:` type($iterator) `->` type($crds)";
+// }
def IterateOp : SparseTensor_Op<"iterate",
[RecursiveMemoryEffects, RecursivelySpeculatable,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
- "getSingleInductionVar", "getYieldedValuesMutable"]>,
+ "getYieldedValuesMutable"]>,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>,
SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
let arguments = (ins AnySparseIterSpace:$iterSpace,
- Variadic<AnyType>:$initArgs);
+ Variadic<AnyType>:$initArgs,
+ LevelSetAttr:$crdUsedLvls);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
let extraClassDeclaration = [{
+ unsigned getSpaceDim() {
+ return getIterSpace().getType().getSpaceDim();
+ }
BlockArgument getIterator() {
return getRegion().getArguments().front();
}
+ Block::BlockArgListType getCrds() {
+ return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+ }
unsigned getNumRegionIterArgs() {
- return getRegion().getArguments().size() - 1;
+ return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
}
}];
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 54be3c3b4c3e5f..2b54c2fda3d739 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -42,6 +42,7 @@ namespace mlir::sparse_tensor {
llvm::hash_code hash_value(LevelType lt) {
return llvm::hash_value(static_cast<uint64_t>(lt));
}
+
} // namespace mlir::sparse_tensor
//===----------------------------------------------------------------------===//
@@ -1944,12 +1945,13 @@ static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
p << lo << " to " << hi;
}
-ParseResult
+static ParseResult
parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
SmallVectorImpl<OpAsmParser::Argument> &iterators,
SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
SmallVector<OpAsmParser::UnresolvedOperand> spaces;
SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
+
// Parses "%iters, ... in %spaces, ..."
if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
parser.parseOperandList(spaces))
@@ -1960,6 +1962,34 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
parser.getNameLoc(),
"mismatch in number of sparse iterators and sparse spaces");
+ // Parse "at(%crd0, _, ...)"
+ LevelSet crdUsedLvlSet;
+ bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at"));
+ unsigned lvlCrdCnt = 0;
+ if (hasUsedCrds) {
+ ParseResult crdList = parser.parseCommaSeparatedList(
+ OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
+ if (parser.parseOptionalKeyword("_")) {
+ if (parser.parseArgument(iterArgs.emplace_back()))
+ return failure();
+ // Always use IndexType for the coordinate.
+ crdUsedLvlSet.set(lvlCrdCnt);
+ iterArgs.back().type = parser.getBuilder().getIndexType();
+ }
+ lvlCrdCnt += 1;
+ return success();
+ });
+ if (failed(crdList)) {
+ return parser.emitError(
+ parser.getNameLoc(),
+ "expecting SSA value or \"_\" for level coordinates");
+ }
+ }
+ // Set the CrdUsedLvl bitset.
+ state.addAttribute("crdUsedLvls",
+ parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
+
+ // Parse "iter_args(%arg = %init, ...)"
bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
if (hasIterArgs)
if (parser.parseAssignmentList(iterArgs, initArgs))
@@ -1980,6 +2010,10 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
return parser.emitError(parser.getNameLoc(),
"expected sparse_tensor.iter_space type for "
"iteration space operands");
+ if (hasUsedCrds && spaceTp.getSpaceDim() != lvlCrdCnt)
+ return parser.emitError(parser.getNameLoc(),
+ "mismatch in number of iteration space dimension "
+ "and specified coordinates");
it.type = spaceTp.getIteratorType();
}
@@ -1993,7 +2027,16 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
return failure();
if (hasIterArgs) {
- for (auto [it, init, tp] : llvm::zip(iterArgs, initArgs, state.types)) {
+ unsigned numCrds = crdUsedLvlSet.count();
+ // Strip off leading args that are used for coordinates.
+ MutableArrayRef args = MutableArrayRef(iterArgs).drop_front(numCrds);
+ // `llvm::zip_equal` asserts on length mismatch, so malformed input such as
+ // more `iter_args` than result types must be diagnosed before the loop.
+ if (initArgs.size() != state.types.size())
+ return parser.emitError(
+ parser.getNameLoc(),
+ "mismatch in number of loop-carried values and defined values");
+ for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
it.type = tp;
if (parser.resolveOperand(init, tp, state.operands))
return failure();
@@ -2044,30 +2081,32 @@ LogicalResult ExtractIterSpaceOp::verify() {
return success();
}
-ValueRange CoordinateOp::collaspeOpInto(OpBuilder &builder,
- ArrayRef<Operation *> loops,
- Operation *collapsed) {
- assert(llvm::all_of(loops,
- [](Operation *l) { return llvm::isa<IterateOp>(l); }));
- auto finalLoop = llvm::cast<IterateOp>(collapsed);
- SmallVector<Type> retTps(finalLoop.getIterSpace().getType().getSpaceDim(),
- builder.getIndexType());
- auto collapsedCoords =
- builder.create<CoordinateOp>(getLoc(), retTps, finalLoop.getIterator());
-
- for (Operation *l : loops) {
- if (getIterator().getParentBlock()->getParentOp() == l) {
- auto space = llvm::cast<IterateOp>(l)
- .getIterSpace()
- .getDefiningOp<ExtractIterSpaceOp>();
-
- return collapsedCoords.getResults().slice(space.getLoLvl(),
- space.getSpaceDim());
- }
- }
- llvm_unreachable(
- "Can not find the corresponding iterate space for the collapsable op.");
-}
+// ValueRange CoordinateOp::collaspeOpInto(OpBuilder &builder,
+// ArrayRef<Operation *> loops,
+// Operation *collapsed) {
+// assert(llvm::all_of(loops,
+// [](Operation *l) { return llvm::isa<IterateOp>(l); }));
+// auto finalLoop = llvm::cast<IterateOp>(collapsed);
+// SmallVector<Type> retTps(finalLoop.getIterSpace().getType().getSpaceDim(),
+// builder.getIndexType());
+// auto collapsedCoords =
+// builder.create<CoordinateOp>(getLoc(), retTps,
+// finalLoop.getIterator());
+
+// for (Operation *l : loops) {
+// if (getIterator().getParentBlock()->getParentOp() == l) {
+// auto space = llvm::cast<IterateOp>(l)
+// .getIterSpace()
+// .getDefiningOp<ExtractIterSpaceOp>();
+
+// return collapsedCoords.getResults().slice(space.getLoLvl(),
+// space.getSpaceDim());
+// }
+// }
+// llvm_unreachable(
+// "Can not find the corresponding iterate space for the collapsable
+// op.");
+// }
ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::Argument iterator;
@@ -2114,8 +2153,30 @@ static void printInitializationList(OpAsmPrinter &p,
p << ")";
}
+static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim, // Prints " at(%crd, _, ...)".
+ Block::BlockArgListType blocksArgs,
+ LevelSet crdUsedLvls) {
+ if (crdUsedLvls.empty()) // No used coordinates: print nothing (no " at(...)").
+ return;
+
+ p << " at(";
+ for (unsigned i = 0; i < spaceDim; i++) { // One entry per level of the space.
+ if (crdUsedLvls[i]) { // Used level: print the next coordinate block argument.
+ p << blocksArgs.front();
+ blocksArgs = blocksArgs.drop_front();
+ } else {
+ p << "_"; // Unused-level placeholder, as accepted by the parser.
+ }
+ if (i != spaceDim - 1)
+ p << ", ";
+ }
+ assert(blocksArgs.empty()); // Every coordinate argument must have been printed.
+ p << ")";
+}
+
void IterateOp::print(OpAsmPrinter &p) {
p << " " << getIterator() << " in " << getIterSpace();
+ printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
p << " : " << getIterSpace().getType() << " ";
@@ -2170,16 +2231,12 @@ LogicalResult IterateOp::verifyRegions() {
/// IterateOp implemented interfaces' methods.
SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
-std::optional<Value> IterateOp::getSingleInductionVar() {
- return getIterator();
-}
-
MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
return getInitArgsMutable();
}
Block::BlockArgListType IterateOp::getRegionIterArgs() {
- return getRegion().getArguments().drop_front();
+ return getRegion().getArguments().take_back(getNumRegionIterArgs());
}
std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
index 752b2dfc2a0070..39c9a9292c9be9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -24,7 +24,6 @@ struct CollapseSpaceInfo {
ExtractIterSpaceOp space;
// Coiteration as well (if make sense)?
IterateOp loop;
- SmallVector<SparseCollapsableOp> collapseOps;
};
bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
@@ -66,8 +65,6 @@ bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
CollapseSpaceInfo &info = toCollapse.emplace_back();
info.space = curSpace;
info.loop = itOp;
- // No operations need to be collapsed at the root level;
- info.collapseOps = {};
return true;
}
return false;
@@ -91,29 +88,13 @@ bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
if (pItOp && !isCollapsableLoops(pItOp, nItOp))
return false;
- // TODO: Make sure all other operations in the same basic block as `node` can
- // be collapsed and sink into the collapsed iteration (through Interfaces
- // defined in TD files).
- SmallVector<SparseCollapsableOp> collapsableOps;
- for (Operation &op : *pItOp.getBody()) {
- if (&op == curSpace.getOperation() || &op == nItOp.getOperation() ||
- &op == pItOp.getBody()->getTerminator())
- continue;
- // All other ops in parent loop need to be collapsable.
- auto collapsableOp = llvm::dyn_cast<SparseCollapsableOp>(&op);
- if (!collapsableOp)
- return false;
- collapsableOps.push_back(collapsableOp);
- }
-
CollapseSpaceInfo &info = toCollapse.emplace_back();
info.space = curSpace;
info.loop = nItOp;
- info.collapseOps = std::move(collapsableOps);
return true;
}
-void collapseSparseSpace(SmallVectorImpl<CollapseSpaceInfo> &toCollapse) {
+void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
if (toCollapse.size() < 2)
return;
@@ -141,21 +122,22 @@ void collapseSparseSpace(SmallVectorImpl<CollapseSpaceInfo> &toCollapse) {
auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
builder.setInsertionPointToStart(cloned.getBody());
- SmallVector<Operation *> loops =
- llvm::map_to_vector(toCollapse, [](CollapseSpaceInfo &info) {
- return info.loop.getOperation();
- });
- for (const CollapseSpaceInfo &info : toCollapse) {
- for (SparseCollapsableOp op : info.collapseOps) {
- ValueRange colVals = op.collaspeOpInto(builder, loops, cloned);
- for (auto [o, r] : llvm::zip(op->getResults(), colVals))
- o.replaceAllUsesWith(r);
- op.erase();
+ LevelSet crdUsedLvls;
+ unsigned shift = 0, argIdx = 1;
+ for (auto info : toCollapse.drop_back()) {
+ LevelSet set = info.loop.getCrdUsedLvls();
+ crdUsedLvls |= set.lshift(shift);
+ shift += info.loop.getSpaceDim();
+ for (BlockArgument crd : info.loop.getCrds()) {
+ BlockArgument collapsedCrd = cloned.getBody()->insertArgument(
+ argIdx++, builder.getIndexType(), crd.getLoc());
+ crd.replaceAllUsesWith(collapsedCrd);
}
}
-
+ crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift);
cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
+ cloned.setCrdUsedLvls(crdUsedLvls);
rItOp.replaceAllUsesWith(cloned.getResults());
// Erase collapsed loops.
>From e7406607c6a34aca915ada8988a7e054d2315796 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 28 Mar 2024 16:32:38 +0000
Subject: [PATCH 5/5] setup lowering passes
---
.../mlir/Dialect/SparseTensor/IR/Enums.h | 4 +
.../SparseTensor/IR/SparseTensorOps.td | 24 ++--
.../Dialect/SparseTensor/Transforms/Passes.h | 16 +++
.../Dialect/SparseTensor/Transforms/Passes.td | 13 ++
.../SparseTensor/Transforms/CMakeLists.txt | 1 +
.../Transforms/SparseIterationToScf.cpp | 76 ++++++++++++
.../Transforms/SparseTensorPasses.cpp | 27 ++++
.../Transforms/Utils/SparseTensorLevel.cpp | 117 +++++++++++-------
.../Transforms/Utils/SparseTensorLevel.h | 6 +
9 files changed, 227 insertions(+), 57 deletions(-)
create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 33f613a46bad84..96ee7111fea2cf 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -357,6 +357,10 @@ struct LevelType {
return hasSparseSemantic();
}
+ constexpr unsigned getNumBuffer() const { // 0 for dense levels, 2 when isWithPosLT(), otherwise 1.
+ return hasDenseSemantic() ? 0 : (isWithPosLT() ? 2 : 1);
+ }
+
std::string toMLIRString() const {
std::string lvlStr = toFormatString(getLvlFmt());
std::string propStr = "";
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 9a918760c3190d..540cfa880a13e2 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -263,7 +263,7 @@ def SparseTensor_ReinterpretMapOp : SparseTensor_Op<"reinterpret_map", [NoMemory
}
def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions",
- [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
+ [Pure, AlwaysSpeculatable, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
Arguments<(ins AnySparseTensor:$tensor, LevelAttr:$level)>,
Results<(outs AnyNon0RankedMemRef:$result)> {
let summary = "Extracts the `level`-th positions array of the `tensor`";
@@ -285,12 +285,12 @@ def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions",
: tensor<64x64xf64, #CSR> to memref<?xindex>
```
}];
- let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)";
+ let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
let hasVerifier = 1;
}
def SparseTensor_ToCoordinatesOp : SparseTensor_Op<"coordinates",
- [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
+ [Pure, AlwaysSpeculatable, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
Arguments<(ins AnySparseTensor:$tensor, LevelAttr:$level)>,
Results<(outs AnyNon0RankedMemRef:$result)> {
let summary = "Extracts the `level`-th coordinates array of the `tensor`";
@@ -312,12 +312,12 @@ def SparseTensor_ToCoordinatesOp : SparseTensor_Op<"coordinates",
: tensor<64x64xf64, #CSR> to memref<?xindex>
```
}];
- let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)";
+ let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
let hasVerifier = 1;
}
def SparseTensor_ToCoordinatesBufferOp : SparseTensor_Op<"coordinates_buffer",
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
Arguments<(ins AnySparseTensor:$tensor)>,
Results<(outs AnyNon0RankedMemRef:$result)> {
let summary = "Extracts the linear coordinates array from a tensor";
@@ -340,16 +340,15 @@ def SparseTensor_ToCoordinatesBufferOp : SparseTensor_Op<"coordinates_buffer",
Example:
```mlir
- %1 = sparse_tensor.coordinates_buffer %0
- : tensor<64x64xf64, #COO> to memref<?xindex>
+ %1 = sparse_tensor.coordinates_buffer %0 : tensor<64x64xf64, #COO>
```
}];
- let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)";
+ let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
let hasVerifier = 1;
}
def SparseTensor_ToValuesOp : SparseTensor_Op<"values",
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
Arguments<(ins AnySparseTensor:$tensor)>,
Results<(outs AnyNon0RankedMemRef:$result)> {
let summary = "Extracts numerical values array from a tensor";
@@ -367,10 +366,10 @@ def SparseTensor_ToValuesOp : SparseTensor_Op<"values",
Example:
```mlir
- %1 = sparse_tensor.values %0 : tensor<64x64xf64, #CSR> to memref<?xf64>
+ %1 = sparse_tensor.values %0 : tensor<64x64xf64, #CSR>
```
}];
- let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)";
+ let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
let hasVerifier = 1;
}
@@ -1438,6 +1437,9 @@ def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
unsigned getSpaceDim() {
return getHiLvl() - getLoLvl();
}
+ ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
+ return getResultSpace().getType().getLvlTypes();
+ }
}];
let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 0e9f5120f7b3dc..8d2b9fe571e20b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -16,6 +16,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
//===----------------------------------------------------------------------===//
// Include the generated pass header (which needs some early definitions).
@@ -142,6 +143,21 @@ void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns);
std::unique_ptr<Pass> createLowerForeachToSCFPass();
+//===----------------------------------------------------------------------===//
+// The LowerSparseIterationToSCF pass.
+//===----------------------------------------------------------------------===//
+
+/// One-to-N type converter for the sparse iteration types (iter_space and
+/// iterator). So far only iteration-space types are converted: each one is
+/// expanded into the position/coordinate buffers of its levels followed by
+/// two `index` values (lower and upper bound); all other types are kept as is.
+class SparseIterationTypeConverter : public OneToNTypeConverter {
+public:
+ SparseIterationTypeConverter();
+};
+
+/// Populates `patterns` with the rewrite patterns that lower sparse iteration
+/// operations to SCF, using `converter` for the converted types.
+/// NOTE(review): the registered patterns are 1:N conversion patterns, which
+/// presumably require `converter` to be a OneToNTypeConverter (e.g.
+/// SparseIterationTypeConverter) - confirm and consider tightening the type.
+void populateLowerSparseIterationToSCFPatterns(TypeConverter &converter,
+ RewritePatternSet &patterns);
+
+/// Creates a pass that lowers sparse iteration operations to SCF loops.
+std::unique_ptr<Pass> createLowerSparseIterationToSCFPass();
+
//===----------------------------------------------------------------------===//
// The SparseTensorConversion pass.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 3ab75c23dbefa0..f27c64a7dee84e 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -210,6 +210,19 @@ def LowerForeachToSCF : Pass<"lower-sparse-foreach-to-scf", "func::FuncOp"> {
];
}
+def LowerSparseIterationToSCF : Pass<"lower-sparse-iteration-to-scf", "func::FuncOp"> {
+  let summary = "lower sparse_tensor.iterate/coiterate into scf loops";
+  let description = [{
+    This pass lowers the sparse iteration operations (e.g.
+    `sparse_tensor.iteration.extract_space` and `sparse_tensor.iterate`) into
+    `scf` loops, expanding every sparse iteration space into the level
+    buffers and bounds that implement it.
+
+    The pass is work in progress: currently only
+    `sparse_tensor.iteration.extract_space` on one-dimensional iteration
+    spaces without a parent iterator is lowered.
+  }];
+  let constructor = "mlir::createLowerSparseIterationToSCFPass()";
+  let dependentDialects = [
+    // The lowering materializes arith constants and index arithmetic.
+    "arith::ArithDialect",
+    "memref::MemRefDialect",
+    "scf::SCFDialect",
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
def SparseTensorConversionPass : Pass<"sparse-tensor-conversion", "ModuleOp"> {
let summary = "Convert sparse tensors and primitives to library calls";
let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 8840da9aa56ef7..c615f9ad2370c7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
SparseAssembler.cpp
SparseBufferRewriting.cpp
SparseGPUCodegen.cpp
+ SparseIterationToScf.cpp
SparseReinterpretMap.cpp
SparseStorageSpecifierToLLVM.cpp
SparseSpaceCollapse.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
new file mode 100644
index 00000000000000..267eff724590e7
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -0,0 +1,76 @@
+
+#include "Utils/CodegenUtils.h"
+#include "Utils/SparseTensorLevel.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+/// Converts an iteration-space type into the flat list of types that
+/// implement it after lowering: for every level of the space, its position
+/// buffer (if the level has one) and its coordinate buffer (if the level has
+/// one), followed by two `index` values for the lower and upper bound.
+static std::optional<LogicalResult>
+convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
+  // TODO: support multi-dimensional iteration spaces. Report a conversion
+  // failure instead of crashing (llvm_unreachable is UB in release builds).
+  if (itSp.getSpaceDim() > 1)
+    return failure();
+
+  auto idxTp = IndexType::get(itSp.getContext());
+  // FIXME: this assumes that the Pos/CrdBitWidth in the sparse tensor encoding
+  // is NOT overridden to non-default values, i.e. that both buffers hold
+  // `index` elements.
+  auto sparseMemRef = MemRefType::get({ShapedType::kDynamic}, idxTp);
+  for (LevelType lt : itSp.getLvlTypes()) {
+    // Position and coordinate buffer in the sparse structure.
+    if (lt.isWithPosLT())
+      fields.push_back(sparseMemRef);
+    if (lt.isWithCrdLT())
+      fields.push_back(sparseMemRef);
+  }
+  // Two indices for lower and upper bound.
+  fields.append({idxTp, idxTp});
+  return success();
+}
+
+namespace {
+
+/// Lowers `sparse_tensor.iteration.extract_space` into the values of the
+/// converted iteration space: the level buffers of the iterated level followed
+/// by the lower/upper bound of the position range of parent position 0.
+///
+/// Only one-dimensional spaces that are not extracted from a parent iterator
+/// are supported for now; other cases are reported as match failures.
+class ExtractIterSpaceConverter
+    : public OneToNOpConversionPattern<ExtractIterSpaceOp> {
+public:
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    // Check all unsupported cases before creating any IR: a pattern must not
+    // modify the IR when it reports a failure.
+    if (op.getSpaceDim() > 1)
+      return rewriter.notifyMatchFailure(
+          op, "multi-dimensional iteration spaces are not supported yet");
+    // TODO: handle spaces extracted from a parent iterator (and batch levels).
+    if (op.getParentIter())
+      return rewriter.notifyMatchFailure(
+          op, "iteration spaces with a parent iterator are not supported yet");
+
+    Location loc = op.getLoc();
+    const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+    std::unique_ptr<SparseTensorLevel> lvl =
+        makeSparseTensorLevel(rewriter, loc, op.getTensor(), 0, op.getLoLvl());
+
+    // Buffers first, then the [lo, hi) bounds, matching convertIterSpaceType.
+    SmallVector<Value> result = llvm::to_vector(lvl->getLvlBuffers());
+    std::pair<Value, Value> bounds = lvl->peekRangeAt(
+        rewriter, loc, /*batchPrefix*/ {}, constantIndex(rewriter, loc, 0));
+    result.append({bounds.first, bounds.second});
+
+    rewriter.replaceOp(op, result, resultMapping);
+    return success();
+  }
+};
+
+} // namespace
+
+/// Builds the converter: iteration-space types are expanded by
+/// convertIterSpaceType, every other type maps to itself.
+mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
+  // Identity fallback for every type not handled by a later conversion.
+  addConversion([](Type tp) -> Type { return tp; });
+  // Conversions registered later are tried first, so iteration spaces reach
+  // this callback before the identity fallback above.
+  addConversion(convertIterSpaceType);
+}
+
+/// Adds the patterns that lower sparse iteration operations (currently only
+/// the iteration-space extraction) to `patterns`; `converter` provides the
+/// type mapping used by these 1:N conversion patterns.
+/// NOTE(review): 1:N patterns presumably need `converter` to actually be a
+/// OneToNTypeConverter (e.g. SparseIterationTypeConverter) - confirm.
+void mlir::populateLowerSparseIterationToSCFPatterns(
+    TypeConverter &converter, RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<ExtractIterSpaceConverter>(converter, ctx);
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index d4c17928d4ca15..3d1a070330a476 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -26,6 +26,7 @@ namespace mlir {
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
+#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
#define GEN_PASS_DEF_LOWERFOREACHTOSCF
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
@@ -152,10 +153,32 @@ struct LowerForeachToSCFPass
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateLowerForeachToSCFPatterns(patterns);
+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
+/// Lowers sparse iteration operations to SCF via 1:N type conversion.
+struct LowerSparseIterationToSCFPass
+    : public impl::LowerSparseIterationToSCFBase<
+          LowerSparseIterationToSCFPass> {
+  LowerSparseIterationToSCFPass() = default;
+  LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
+      default;
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    SparseIterationTypeConverter converter;
+    // applyPartialOneToNConversion takes no ConversionTarget, so op legality
+    // is not enforced here: operations the patterns do not handle (e.g.
+    // unsupported ExtractIterSpaceOps, IterateOps) are left in place.
+    populateLowerSparseIterationToSCFPatterns(converter, patterns);
+    if (failed(applyPartialOneToNConversion(getOperation(), converter,
+                                            std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
struct SparseTensorConversionPass
: public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
SparseTensorConversionPass() = default;
@@ -438,6 +461,10 @@ std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
return std::make_unique<LowerForeachToSCFPass>();
}
+/// Creates a pass that lowers sparse iteration operations to SCF loops.
+std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
+ return std::make_unique<LowerSparseIterationToSCFPass>();
+}
+
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
return std::make_unique<SparseTensorConversionPass>();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index bc27fae5d19480..3b501953ef0abe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -46,21 +46,41 @@ using ValueTuple = std::tuple<Value, Value, Value>;
namespace {
+template <bool hasPosBuffer>
class SparseLevel : public SparseTensorLevel {
+ // It is either an array of size 2 or size 1 depending on whether the level
+ // requires a position array.
+ using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
+ std::array<Value, 1>>;
+
public:
SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value crdBuffer)
- : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
+ BufferT buffers)
+ : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}
+
+ ValueRange getLvlBuffers() const override { return buffers; }
Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value iv) const override {
SmallVector<Value> memCrd(batchPrefix);
memCrd.push_back(iv);
- return genIndexLoad(b, l, crdBuffer, memCrd);
+ return genIndexLoad(b, l, getCrdBuf(), memCrd);
}
protected:
- const Value crdBuffer;
+ // The position buffer is always stored first. Only levels with a position
+ // buffer may call this; member functions of a class template are only
+ // instantiated when used, so the static_assert turns misuse into a clear
+ // compile error.
+ Value getPosBuf() const {
+ static_assert(hasPosBuffer, "sparse level has no position buffer");
+ return buffers[0];
+ }
+
+ // The coordinate buffer is always stored last.
+ Value getCrdBuf() const { return buffers.back(); }
+
+ const BufferT buffers;
};
class DenseLevel : public SparseTensorLevel {
@@ -72,6 +92,8 @@ class DenseLevel : public SparseTensorLevel {
llvm_unreachable("locate random-accessible level instead");
}
+ ValueRange getLvlBuffers() const override { return {}; }
+
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
Value max) const override {
Value posLo = MULI(p, lvlSize);
@@ -88,6 +110,8 @@ class BatchLevel : public SparseTensorLevel {
llvm_unreachable("locate random-accessible level instead");
}
+ ValueRange getLvlBuffers() const override { return {}; }
+
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
Value max) const override {
assert(max == nullptr && "Dense level can not be non-unique.");
@@ -96,11 +120,11 @@ class BatchLevel : public SparseTensorLevel {
}
};
-class CompressedLevel : public SparseLevel {
+class CompressedLevel : public SparseLevel<true> {
public:
CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value posBuffer, Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+ : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
@@ -109,21 +133,18 @@ class CompressedLevel : public SparseLevel {
SmallVector<Value> memCrd(batchPrefix);
memCrd.push_back(p);
- Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+ Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
memCrd.back() = ADDI(p, C_IDX(1));
- Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+ Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
return {pLo, pHi};
}
-
-private:
- const Value posBuffer;
};
-class LooseCompressedLevel : public SparseLevel {
+class LooseCompressedLevel : public SparseLevel<true> {
public:
LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value posBuffer, Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+ : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
@@ -133,21 +154,18 @@ class LooseCompressedLevel : public SparseLevel {
p = MULI(p, C_IDX(2));
memCrd.push_back(p);
- Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+ Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
memCrd.back() = ADDI(p, C_IDX(1));
- Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+ Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
return {pLo, pHi};
}
-
-private:
- const Value posBuffer;
};
-class SingletonLevel : public SparseLevel {
+class SingletonLevel : public SparseLevel<false> {
public:
SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+ : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value segHi) const override {
@@ -159,11 +177,11 @@ class SingletonLevel : public SparseLevel {
}
};
-class NOutOfMLevel : public SparseLevel {
+class NOutOfMLevel : public SparseLevel<false> {
public:
NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+ : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
@@ -1314,6 +1332,30 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
// SparseIterator factory functions.
//===----------------------------------------------------------------------===//
+/// Helper function to create a SparseTensorLevel object for a level of type
+/// `lt` and size `sz` from its already materialized level `buffers`: the
+/// position buffer first (if the level has one), then the coordinate buffer
+/// (if the level has one).
+std::unique_ptr<SparseTensorLevel>
+sparse_tensor::makeSparseTensorLevel(LevelType lt, Value sz, ValueRange buffers,
+                                     unsigned tid, Level lvl) {
+  assert(lt.getNumBuffer() == buffers.size() &&
+         "number of buffers does not match the level type");
+  switch (lt.getLvlFmt()) {
+  case LevelFormat::Dense:
+    return std::make_unique<DenseLevel>(tid, lvl, sz);
+  case LevelFormat::Batch:
+    return std::make_unique<BatchLevel>(tid, lvl, sz);
+  case LevelFormat::Compressed:
+    return std::make_unique<CompressedLevel>(tid, lvl, lt, sz, buffers[0],
+                                             buffers[1]);
+  case LevelFormat::LooseCompressed:
+    return std::make_unique<LooseCompressedLevel>(tid, lvl, lt, sz, buffers[0],
+                                                  buffers[1]);
+  case LevelFormat::Singleton:
+    return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, buffers[0]);
+  case LevelFormat::NOutOfM:
+    return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, buffers[0]);
+  case LevelFormat::Undef:
+    llvm_unreachable("undefined level format");
+  }
+  llvm_unreachable("unrecognizable level format");
+}
+
+/// Creates the SparseTensorLevel for level `lvl` of tensor `t`: materializes
+/// the level size and the position/coordinate buffers required by the level
+/// type, then delegates to the buffer-based overload.
std::unique_ptr<SparseTensorLevel>
sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
unsigned tid, Level lvl) {
@@ -1323,33 +1365,16 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
: b.create<tensor::DimOp>(l, t, lvl).getResult();
- switch (lt.getLvlFmt()) {
- case LevelFormat::Dense:
- return std::make_unique<DenseLevel>(tid, lvl, sz);
- case LevelFormat::Batch:
- return std::make_unique<BatchLevel>(tid, lvl, sz);
- case LevelFormat::Compressed: {
- Value pos = b.create<ToPositionsOp>(l, t, lvl);
- Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
- return std::make_unique<CompressedLevel>(tid, lvl, lt, sz, pos, crd);
- }
- case LevelFormat::LooseCompressed: {
+ SmallVector<Value, 2> buffers;
+ if (lt.isWithPosLT()) {
Value pos = b.create<ToPositionsOp>(l, t, lvl);
- Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
- return std::make_unique<LooseCompressedLevel>(tid, lvl, lt, sz, pos, crd);
- }
- case LevelFormat::Singleton: {
- Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
- return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
+ buffers.push_back(pos);
}
- case LevelFormat::NOutOfM: {
- Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
- return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
+ if (lt.isWithCrdLT()) {
+ Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
+ buffers.push_back(crd);
}
- case LevelFormat::Undef:
- llvm_unreachable("undefined level format");
- }
- llvm_unreachable("unrecognizable level format");
+ return makeSparseTensorLevel(lt, sz, buffers, tid, lvl);
}
std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index 9f92eecdf75cb6..46188fc112bd95 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -55,6 +55,7 @@ class SparseTensorLevel {
Level getLevel() const { return lvl; }
LevelType getLT() const { return lt; }
Value getSize() const { return lvlSize; }
+ virtual ValueRange getLvlBuffers() const = 0;
//
// Level properties
@@ -287,6 +288,11 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &builder,
Location loc, Value t,
unsigned tid, Level l);
+/// Helper function to create a SparseTensorLevel object for a level of type
+/// `lt` and size `sz` from its already materialized level `buffers`: the
+/// position buffer first (if the level has one), then the coordinate buffer
+/// (if the level has one).
+std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
+ ValueRange buffers,
+ unsigned tid, Level l);
+
/// Helper function to create a simple SparseIterator object that iterate over
/// the SparseTensorLevel.
std::unique_ptr<SparseIterator> makeSimpleIterator(const SparseTensorLevel &stl,
More information about the Mlir-commits
mailing list