[Mlir-commits] [mlir] [WIP][mlir][sparse] Setting up sparse_tensor.iterator-related Ops. (PR #85958)
Peiming Liu
llvmlistbot at llvm.org
Fri Mar 22 11:41:48 PDT 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/85958
>From 1e76a2a5417212bb2e5891a00c0069c4118ce778 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 18 Mar 2024 19:40:46 +0000
Subject: [PATCH 1/3] test parse iterate operation
---
.../SparseTensor/IR/SparseTensorOps.td | 17 +++++
.../SparseTensor/IR/SparseTensorTypes.td | 76 +++++++++++++++++++
2 files changed, 93 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 29cf8c32447ecf..6efeb6007d649e 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1418,6 +1418,23 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+def IterateOp : SparseTensor_Op<"iterate",
+ [RecursiveMemoryEffects]> {
+
+ let arguments = (ins AnySparseIterSpace:$iterSpace,
+ Variadic<AnyType>:$initArgs);
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region SizedRegion<1>:$region);
+
+ let extraClassDeclaration = [{}];
+
+ let hasCustomAssemblyFormat = 1;
+}
+
//===----------------------------------------------------------------------===//
// Sparse Tensor Debugging and Test-Only Operations.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index 185cff46ae25d5..54a8e4d7ecd398 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -72,4 +72,80 @@ def SparseTensorStorageSpecifier
: Type<CPred<"::llvm::isa<::mlir::sparse_tensor::StorageSpecifierType>($_self)">, "metadata",
"::mlir::sparse_tensor::StorageSpecifierType">;
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Types.
+//===----------------------------------------------------------------------===//
+
+def LevelTypeArrayParameter : ArrayRefParameter<"::mlir::sparse_tensor::LevelType", "level-types"> {
+ let printer = [{
+ auto lvlStrings = llvm::map_range($_self, [](auto lt){ return lt.toMLIRString(); });
+ $_printer << "[" << llvm::join(lvlStrings, ",") << "]";
+ }];
+
+ let parser = [{ [&]() -> FailureOr<SmallVector<::mlir::sparse_tensor::LevelType>> {
+ SmallVector<::mlir::sparse_tensor::LevelType> ret;
+
+ const auto res = $_parser.parseCommaSeparatedList(
+ mlir::OpAsmParser::Delimiter::Square,
+ [&]() -> ParseResult {
+ ::mlir::sparse_tensor::ir_detail::LvlTypeParser lParser;
+ auto lvlTpOrFail = lParser.parseLvlType($_parser);
+ if (failed(lvlTpOrFail))
+ return failure();
+ ret.emplace_back(*lvlTpOrFail);
+ return success();
+ }, " in level-type list");
+
+ if (failed(res))
+ return failure();
+ return ret;
+ }() }];
+}
+
+def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
+ let mnemonic = "iterator";
+
+ let parameters = (ins
+ LevelTypeArrayParameter: $lvlTypes
+ );
+
+ // let skipDefaultBuilders = 1;
+ // let hasCustomAssemblyFormat = 1;
+ let assemblyFormat="`<` $lvlTypes `>`";
+}
+
+
+def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
+ let mnemonic = "iter_space";
+
+ let parameters = (ins
+ LevelTypeArrayParameter: $lvlTypes
+ );
+
+ let extraClassDeclaration = [{
+ ::mlir::sparse_tensor::IteratorType getIteratorType() const {
+ return IteratorType::get(getContext(), getLvlTypes());
+ }
+ }];
+
+ // let skipDefaultBuilders = 1;
+ // let hasCustomAssemblyFormat = 1;
+ let assemblyFormat="`<` $lvlTypes `>`";
+}
+
+def IsSparseSparseIterSpaceTypePred
+ : CPred<"::llvm::isa<::mlir::sparse_tensor::IterSpaceType>($_self)">;
+
+def IsSparseSparseIteratorTypePred
+ : CPred<"::llvm::isa<::mlir::sparse_tensor::IteratorType>($_self)">;
+
+def AnySparseIterSpace
+ : Type<IsSparseSparseIterSpaceTypePred, "sparse iteration space",
+ "::mlir::sparse_tensor::IterSpaceType">;
+
+def AnySparseIterator
+ : Type<IsSparseSparseIteratorTypePred, "sparse iterator",
+ "::mlir::sparse_tensor::IteratorType">;
+
+
#endif // SPARSETENSOR_TYPES
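For orientation, with the assembly format and the level-type printer above, the two
new types render roughly as follows (the level types chosen here are just the standard
"dense"/"compressed" names; the exact rendering is still WIP and may change):

    // An iteration space covering two compressed levels of a sparse tensor.
    !sparse_tensor.iter_space<[compressed,compressed]>

    // An iterator over a single compressed level.
    !sparse_tensor.iterator<[compressed]>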
>From e64944645a4301b393ec4e8c77ee58aaf4cfc19d Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 20 Mar 2024 16:14:12 +0000
Subject: [PATCH 2/3] test sparse space collapse
---
.../Dialect/SparseTensor/IR/SparseTensor.h | 2 +
.../SparseTensor/IR/SparseTensorOps.td | 41 ++-
.../SparseTensor/IR/SparseTensorTypes.td | 15 +-
.../Dialect/SparseTensor/Transforms/Passes.h | 6 +
.../Dialect/SparseTensor/Transforms/Passes.td | 16 ++
.../SparseTensor/IR/SparseTensorDialect.cpp | 253 +++++++++++++++++-
.../SparseTensor/Transforms/CMakeLists.txt | 1 +
.../Transforms/SparseSpaceCollapse.cpp | 152 +++++++++++
8 files changed, 479 insertions(+), 7 deletions(-)
create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 5e523ec428aefb..78692307820bc5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -17,7 +17,9 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 6efeb6007d649e..f0eaf2191fdbd1 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
//===----------------------------------------------------------------------===//
// Base class.
@@ -1422,16 +1424,51 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
// Sparse Tensor Iteration Operations.
//===----------------------------------------------------------------------===//
+def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
+ [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+
+ let arguments = (ins AnySparseTensor:$tensor,
+ Optional<AnySparseIterator>:$parentIter,
+ LevelAttr:$loLvl, LevelAttr:$hiLvl);
+
+ let results = (outs AnySparseIterSpace:$resultSpace);
+
+ let extraClassDeclaration = [{
+ unsigned getSpaceDim() {
+ return getHiLvl() - getLoLvl();
+ }
+ }];
+
+ let hasVerifier = 1;
+ let assemblyFormat = "$tensor (`at`$parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
+ " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
+}
+
def IterateOp : SparseTensor_Op<"iterate",
- [RecursiveMemoryEffects]> {
+ [RecursiveMemoryEffects, RecursivelySpeculatable,
+ DeclareOpInterfaceMethods<LoopLikeOpInterface,
+ ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
+ "getSingleInductionVar", "getYieldedValuesMutable"]>,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getEntrySuccessorOperands"]>,
+ SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
let arguments = (ins AnySparseIterSpace:$iterSpace,
Variadic<AnyType>:$initArgs);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
- let extraClassDeclaration = [{}];
+ let extraClassDeclaration = [{
+ BlockArgument getIterator() {
+ return getRegion().getArguments().front();
+ }
+ unsigned getNumRegionIterArgs() {
+ return getRegion().getArguments().size() - 1;
+ }
+ }];
+ let hasVerifier = 1;
+ let hasRegionVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index 54a8e4d7ecd398..aa674b613e71db 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -109,8 +109,13 @@ def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
LevelTypeArrayParameter: $lvlTypes
);
- // let skipDefaultBuilders = 1;
- // let hasCustomAssemblyFormat = 1;
+ let extraClassDeclaration = [{
+ unsigned getSpaceDim() {
+ return getLvlTypes().size();
+ }
+ }];
+
+
let assemblyFormat="`<` $lvlTypes `>`";
}
@@ -123,13 +128,15 @@ def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
);
let extraClassDeclaration = [{
+ unsigned getSpaceDim() {
+ return getLvlTypes().size();
+ }
+
::mlir::sparse_tensor::IteratorType getIteratorType() const {
return IteratorType::get(getContext(), getLvlTypes());
}
}];
- // let skipDefaultBuilders = 1;
- // let hasCustomAssemblyFormat = 1;
let assemblyFormat="`<` $lvlTypes `>`";
}
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 61b07d222d156b..0e9f5120f7b3dc 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -247,6 +247,12 @@ std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
bool enableBufferInitialization, unsigned vectorLength,
bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen);
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass();
+
//===----------------------------------------------------------------------===//
// Registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 58e2d6f32386c3..3ab75c23dbefa0 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -454,4 +454,20 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "M
];
}
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
+ let summary = "(experimental) sparse space collpasing pass";
+ let description = [{
+ This pass collapses consecutive sparse spaces (extracted from the same tensor)
+ into one multi-dimensional space.
+ }];
+ let constructor = "mlir::createSparseSpaceCollapsePass()";
+ let dependentDialects = [
+ "sparse_tensor::SparseTensorDialect",
+ ];
+}
+
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
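As a rough illustration of the rewrite the pass description refers to, using the
iteration.extract_space / iterate syntax introduced earlier in this patch (the tensor
%t, the #sparse encoding, the 4x8xf32 shape, and the loop bodies are placeholders, and
the printed form is still WIP):

    // Before: two consecutive single-level spaces extracted from the same tensor.
    %s0 = sparse_tensor.iteration.extract_space %t lvls = 0
            : tensor<4x8xf32, #sparse>
    sparse_tensor.iterate %it0 in %s0 : !sparse_tensor.iter_space<[compressed]> {
      %s1 = sparse_tensor.iteration.extract_space %t at %it0 lvls = 1
              : tensor<4x8xf32, #sparse>, !sparse_tensor.iterator<[compressed]>
      sparse_tensor.iterate %it1 in %s1 : !sparse_tensor.iter_space<[compressed]> {
        ...
      }
    }

    // After: one collapsed space spanning both levels.
    %s = sparse_tensor.iteration.extract_space %t lvls = 0 to 2
           : tensor<4x8xf32, #sparse>
    sparse_tensor.iterate %it in %s
        : !sparse_tensor.iter_space<[compressed,compressed]> {
      ...
    }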
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6da51bb6b9cacf..6afa3e6309dc65 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1912,13 +1912,264 @@ LogicalResult YieldOp::verify() {
auto *parentOp = (*this)->getParentOp();
if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
- isa<ForeachOp>(parentOp))
+ isa<ForeachOp>(parentOp) || isa<IterateOp>(parentOp))
return success();
return emitOpError("expected parent op to be sparse_tensor unary, binary, "
"reduce, select or foreach");
}
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
+ IntegerAttr &lvlHiAttr) {
+ Level lvlLo, lvlHi;
+
+ if (parser.parseInteger(lvlLo))
+ return failure();
+
+ if (succeeded(parser.parseOptionalKeyword("to"))) {
+ if (parser.parseInteger(lvlHi))
+ return failure();
+ } else {
+ lvlHi = lvlLo + 1;
+ }
+ lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
+ lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
+ return success();
+}
+
+static void printLevelRange(OpAsmPrinter &p, ExtractIterSpaceOp op,
+ IntegerAttr lvlLo, IntegerAttr lvlHi) {
+ if (op.getLoLvl() + 1 == op.getHiLvl())
+ p << op.getLoLvl();
+ else
+ p << op.getLoLvl() << " to " << op.getHiLvl();
+}
+
+LogicalResult ExtractIterSpaceOp::inferReturnTypes(
+ MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
+ DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
+ SmallVectorImpl<mlir::Type> &ret) {
+
+ ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
+ SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
+ ArrayRef<LevelType> lts = stt.getLvlTypes().slice(
+ adaptor.getLoLvl(), adaptor.getHiLvl() - adaptor.getLoLvl());
+ ret.push_back(IterSpaceType::get(ctx, lts));
+ return success();
+}
+
+LogicalResult ExtractIterSpaceOp::verify() {
+ SparseTensorType stt = getSparseTensorType(getTensor());
+ if (getLoLvl() >= getHiLvl())
+ return emitOpError("expected smaller level low than level high");
+
+ ArrayRef<LevelType> lts = stt.getLvlTypes().slice(getLoLvl(), getSpaceDim());
+ if (!getResultSpace().getType().getLvlTypes().equals(lts)) {
+ return emitOpError(
+ "mismatch in iteration space level types and tensor level types");
+ }
+
+ TypedValue<IteratorType> pIter = getParentIter();
+ if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
+ return emitOpError("parent iterator is only needed iff level low equals 0");
+ }
+
+ if (pIter) {
+ unsigned pDim = pIter.getType().getSpaceDim();
+ if (getLoLvl() < pDim || !stt.getLvlTypes()
+ .slice(getLoLvl() - pDim, pDim)
+ .equals(pIter.getType().getLvlTypes())) {
+ return emitOpError(
+ "mismatch in parent iterator level types and tensor level types");
+ }
+ }
+
+ return success();
+}
+
+ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::Argument iterator;
+ OpAsmParser::UnresolvedOperand iterSpace;
+
+ // Parses %iters in %spaces
+ if (parser.parseArgument(iterator) || parser.parseKeyword("in") ||
+ parser.parseOperand(iterSpace)) {
+ return failure();
+ }
+
+ // Parse the optional initial iteration arguments.
+ SmallVector<OpAsmParser::Argument> regionArgs;
+ SmallVector<OpAsmParser::UnresolvedOperand> operands;
+ // Region arguments start with the iterator, followed by the optional
+ // user-provided iter_args.
+ regionArgs.push_back(iterator);
+ bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+ if (hasIterArgs)
+ if (parser.parseAssignmentList(regionArgs, operands))
+ return failure();
+
+ // parse ": sparse_tensor.iter_space -> ret"
+ Type iterSpaceTps;
+ if (parser.parseColon() || parser.parseType(iterSpaceTps))
+ return failure();
+ if (hasIterArgs)
+ if (parser.parseArrowTypeList(result.types))
+ return failure();
+
+ if (regionArgs.size() != result.types.size() + 1) {
+ return parser.emitError(
+ parser.getNameLoc(),
+ "mismatch in number of loop-carried values and defined values");
+ }
+
+ // Resolves input operands.
+ if (parser.resolveOperand(iterSpace, iterSpaceTps, result.operands))
+ return failure();
+
+ if (hasIterArgs) {
+ for (auto argOperandType :
+ llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
+ Type type = std::get<2>(argOperandType);
+ std::get<0>(argOperandType).type = type;
+ if (parser.resolveOperand(std::get<1>(argOperandType), type,
+ result.operands))
+ return failure();
+ }
+ }
+
+ Region *body = result.addRegion();
+ regionArgs.front().type =
+ iterSpaceTps.cast<IterSpaceType>().getIteratorType();
+ if (parser.parseRegion(*body, regionArgs))
+ return failure();
+
+ IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
+
+ // Parse the optional attribute list.
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+
+ return success();
+}
+
+/// Prints the initialization list in the form of
+/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
+/// where 'inner' values are assumed to be region arguments and 'outer' values
+/// are regular SSA values.
+static void printInitializationList(OpAsmPrinter &p,
+ Block::BlockArgListType blocksArgs,
+ ValueRange initializers,
+ StringRef prefix = "") {
+ assert(blocksArgs.size() == initializers.size() &&
+ "expected same length of arguments and initializers");
+ if (initializers.empty())
+ return;
+
+ p << prefix << '(';
+ llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
+ p << std::get<0>(it) << " = " << std::get<1>(it);
+ });
+ p << ")";
+}
+
+void IterateOp::print(OpAsmPrinter &p) {
+ p << " " << getIterator() << " in " << getIterSpace();
+
+ printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
+
+ p << " : " << getIterSpace().getType() << " ";
+ if (!getInitArgs().empty())
+ p << "-> (" << getInitArgs().getTypes() << ") ";
+
+ p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/!getInitArgs().empty());
+}
+
+LogicalResult IterateOp::verify() {
+ if (getInitArgs().size() != getNumResults()) {
+ return emitOpError(
+ "mismatch in number of loop-carried values and defined values");
+ }
+ return success();
+}
+
+LogicalResult IterateOp::verifyRegions() {
+ if (getIterator().getType() != getIterSpace().getType().getIteratorType())
+ return emitOpError("mismatch in iterator and iteration space type");
+ if (getNumRegionIterArgs() != getNumResults())
+ return emitOpError(
+ "mismatch in number of basic block args and defined values");
+
+ auto initArgs = getInitArgs();
+ auto iterArgs = getRegionIterArgs();
+ auto yieldVals = getYieldedValues();
+ auto opResults = getResults();
+ if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
+ opResults.size()})) {
+ return emitOpError() << "number mismatch between iter args and results.";
+ }
+
+ unsigned i = 0;
+ for (auto e : llvm::zip_equal(initArgs, iterArgs, yieldVals, opResults)) {
+ if (std::get<0>(e).getType() != std::get<3>(e).getType())
+ return emitOpError() << "types mismatch between " << i
+ << "th iter operand and defined value";
+ if (std::get<1>(e).getType() != std::get<3>(e).getType())
+ return emitOpError() << "types mismatch between " << i
+ << "th iter region arg and defined value";
+ if (std::get<2>(e).getType() != std::get<3>(e).getType())
+ return emitOpError() << "types mismatch between " << i
+ << "th yield value and defined value";
+
+ ++i;
+ }
+ return success();
+}
+
+/// IterateOp implemented interfaces' methods.
+SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
+
+std::optional<Value> IterateOp::getSingleInductionVar() {
+ return getIterator();
+}
+
+MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
+ return getInitArgsMutable();
+}
+
+Block::BlockArgListType IterateOp::getRegionIterArgs() {
+ return getRegion().getArguments().drop_front();
+}
+
+std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
+ return cast<sparse_tensor::YieldOp>(
+ getRegion().getBlocks().front().getTerminator())
+ .getResultMutable();
+}
+
+std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
+
+OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ return getInitArgs();
+}
+
+void IterateOp::getSuccessorRegions(RegionBranchPoint point,
+ SmallVectorImpl<RegionSuccessor> &regions) {
+ // Both the operation itself and the region may be branching into the body or
+ // back into the operation itself.
+ regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
+ // It is possible for the loop not to enter the body.
+ regions.push_back(RegionSuccessor(getResults()));
+}
+
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Dialect Setups.
+//===----------------------------------------------------------------------===//
+
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
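For reference, the custom parser/printer above is meant to round-trip forms roughly
like the following (the space %space, the initial value %init, and the f32 iter_arg
type are placeholders):

    %r = sparse_tensor.iterate %it in %space iter_args(%sum = %init)
           : !sparse_tensor.iter_space<[compressed]> -> (f32) {
      sparse_tensor.yield %sum : f32
    }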
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 3c0f82fc00bb9d..8840da9aa56ef7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
SparseGPUCodegen.cpp
SparseReinterpretMap.cpp
SparseStorageSpecifierToLLVM.cpp
+ SparseSpaceCollapse.cpp
SparseTensorCodegen.cpp
SparseTensorConversion.cpp
SparseTensorPasses.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
new file mode 100644
index 00000000000000..f3207ede9585b4
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -0,0 +1,152 @@
+//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+namespace mlir {
+
+#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
+
+namespace sparse_tensor {
+
+bool isCollapsableIterations(LoopLikeOpInterface parent,
+ LoopLikeOpInterface node) {
+ auto pIterArgs = parent.getRegionIterArgs();
+ auto nInitArgs = node.getInits();
+ if (pIterArgs.size() != nInitArgs.size())
+ return false;
+
+ auto pYields = parent.getYieldedValues();
+ auto nResult = node.getLoopResults().value();
+
+ bool yieldEq =
+ llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) {
+ return std::get<0>(zipped) == std::get<1>(zipped);
+ });
+
+ // Parent iter_args should be passed directly to the node's init_args.
+ bool iterArgEq =
+ llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
+ return std::get<0>(zipped) == std::get<1>(zipped);
+ });
+
+ return yieldEq && iterArgEq;
+}
+
+bool legalToCollapse(ExtractIterSpaceOp parent, ExtractIterSpaceOp node) {
+ auto pItOp = llvm::dyn_cast<IterateOp>(parent->getParentOp());
+ auto nItOp = llvm::dyn_cast<IterateOp>(node->getParentOp());
+
+ // Can only collapse spaces extracted from the same tensor.
+ if (parent.getTensor() != node.getTensor() || !parent->hasOneUse())
+ return false;
+
+ // Can only collapse consecutive simple iteration on one tensor (i.e., no
+ // coiteration).
+ if (!nItOp || nItOp.getIterSpace() != parent.getResult() ||
+ nItOp->getBlock() != parent->getBlock())
+ return false;
+
+ if (pItOp && !isCollapsableIterations(pItOp, nItOp))
+ return false;
+
+ // TODO: Make sure all other operations in the same basic block as `node` can
+ // be collapsed and sink into the collapsed iteration (through Interfaces
+ // defined in TD files).
+ return true;
+}
+
+void collapseSparseSpace(ArrayRef<ExtractIterSpaceOp> toCollapse) {
+ if (toCollapse.size() < 2)
+ return;
+
+ ExtractIterSpaceOp root = toCollapse.front();
+ ExtractIterSpaceOp leaf = toCollapse.back();
+ Location loc = root.getLoc();
+
+ if (!leaf->hasOneUse())
+ return;
+ assert(root->hasOneUse());
+
+ // Insert collapsed operation at the same scope as root operation.
+ OpBuilder builder(toCollapse.front());
+
+ // Construct the collapsed iteration space.
+ auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
+ loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
+ leaf.getHiLvl());
+
+ auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
+ auto pItOp = llvm::cast<IterateOp>(leaf->getParentOp());
+
+ // This could either be IterateOp or (TODO: in the future) CoIterateOp.
+ auto loop = llvm::dyn_cast<IterateOp>(*leaf->getUsers().begin());
+ if (!loop || !isCollapsableIterations(pItOp, loop))
+ return;
+
+ IRMapping mapper;
+ mapper.map(leaf, collapsedSpace.getResultSpace());
+ for (auto z : llvm::zip_equal(loop.getInitArgs(), rItOp.getInitArgs()))
+ mapper.map(std::get<0>(z), std::get<1>(z));
+
+ auto cloned = llvm::cast<IterateOp>(builder.clone(*loop, mapper));
+ cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
+
+ rItOp.replaceAllUsesWith(cloned.getResults());
+ // Erase collapsed loops.
+ rItOp.erase();
+ root.erase();
+}
+
+struct SparseSpaceCollapsePass
+ : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
+ SparseSpaceCollapsePass() = default;
+
+ void runOnOperation() override {
+ func::FuncOp func = getOperation();
+
+ // A naive (experimental) implementation to collapse consecutive sparse
+ // spaces. It does NOT handle complex cases where multiple spaces are
+ // extracted in the same basic block. E.g.,
+ //
+ // %space1 = extract_space %t1 ...
+ // %space2 = extract_space %t2 ...
+ // sparse_tensor.iterate(%sp1) ...
+ //
+ SmallVector<ExtractIterSpaceOp> toCollapse;
+ func->walk([&](ExtractIterSpaceOp op) {
+ if (toCollapse.empty()) {
+ // Root space to collapse.
+ toCollapse.push_back(op);
+ } else {
+ if (legalToCollapse(toCollapse.back(), op)) {
+ toCollapse.push_back(op);
+ } else {
+ collapseSparseSpace(toCollapse);
+ toCollapse.clear();
+ }
+ }
+ });
+
+ collapseSparseSpace(toCollapse);
+ }
+};
+
+} // namespace sparse_tensor
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass() {
+ return std::make_unique<sparse_tensor::SparseSpaceCollapsePass>();
+}
+
+} // namespace mlir
>From 69cf58f3d4375ab076ef16e98d9c513b9eea1205 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 22 Mar 2024 18:41:34 +0000
Subject: [PATCH 3/3] test collapsing coordinate extraction from iterator.
---
.../SparseTensor/IR/SparseTensorInterfaces.h | 2 +
.../SparseTensor/IR/SparseTensorInterfaces.td | 15 ++
.../SparseTensor/IR/SparseTensorOps.td | 17 +-
.../SparseTensor/IR/SparseTensorDialect.cpp | 167 +++++++++++-------
.../Transforms/SparseSpaceCollapse.cpp | 119 +++++++++----
5 files changed, 217 insertions(+), 103 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
index c0f31762ee071f..115e08b2cf8b14 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
namespace mlir {
class PatternRewriter;
@@ -20,6 +21,7 @@ class StageWithSortSparseOp;
namespace detail {
LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op,
PatternRewriter &rewriter, Value &tmpBufs);
+
} // namespace detail
} // namespace sparse_tensor
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
index 05eed0483f2c8a..ee1c0b52b47e45 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
@@ -42,5 +42,20 @@ def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> {
];
}
+def SparseCollapsableOpInterface : OpInterface<"SparseCollapsableOp"> {
+ let description = [{ TODO }];
+
+ let cppNamespace = "::mlir::sparse_tensor";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/"test",
+ /*retTy=*/"ValueRange",
+ /*methodName=*/"collaspeOpInto",
+ /*args=*/(ins "::mlir::OpBuilder &":$builder,
+ "::mlir::ArrayRef<::mlir::Operation *>":$loops,
+ "::mlir::Operation *":$collapsed)>,
+ ];
+}
#endif // SPARSETENSOR_IR_SPARSETENSORINTERFACES
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index f0eaf2191fdbd1..467030a8d221af 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1280,7 +1280,9 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
let hasVerifier = 1;
}
-def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
+def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
+ ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp", "ForeachOp",
+ "IterateOp"]>]>,
Arguments<(ins Optional<AnyType>:$result)> {
let summary = "Yield from sparse_tensor set-like operations";
let description = [{
@@ -1311,7 +1313,6 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
let assemblyFormat = [{
$result attr-dict `:` type($result)
}];
- let hasVerifier = 1;
}
def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
@@ -1440,10 +1441,20 @@ def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
}];
let hasVerifier = 1;
- let assemblyFormat = "$tensor (`at`$parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
+ let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
" attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
}
+def CoordinateOp : SparseTensor_Op<"iteration.coordinate",
+ [Pure, DeclareOpInterfaceMethods<SparseCollapsableOpInterface>]> {
+ let arguments = (ins AnySparseIterator:$iterator);
+ let results = (outs Variadic<Index>:$crds);
+
+ let extraClassDeclaration = [{ }];
+ // let hasVerifier = 1;
+ let assemblyFormat = " $iterator attr-dict `:` type($iterator) `->` type($crds)";
+}
+
def IterateOp : SparseTensor_Op<"iterate",
[RecursiveMemoryEffects, RecursivelySpeculatable,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
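The new coordinate op, per the assembly format above, reads roughly as follows (the
iterators %it and %it2 are placeholders; one index result is produced per level
covered by the iterator):

    %c = sparse_tensor.iteration.coordinate %it
           : !sparse_tensor.iterator<[compressed]> -> index
    %ci, %cj = sparse_tensor.iteration.coordinate %it2
           : !sparse_tensor.iterator<[compressed,compressed]> -> index, index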
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6afa3e6309dc65..54be3c3b4c3e5f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1907,18 +1907,6 @@ LogicalResult SortOp::verify() {
return success();
}
-LogicalResult YieldOp::verify() {
- // Check for compatible parent.
- auto *parentOp = (*this)->getParentOp();
- if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
- isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
- isa<ForeachOp>(parentOp) || isa<IterateOp>(parentOp))
- return success();
-
- return emitOpError("expected parent op to be sparse_tensor unary, binary, "
- "reduce, select or foreach");
-}
-
//===----------------------------------------------------------------------===//
// Sparse Tensor Iteration Operations.
//===----------------------------------------------------------------------===//
@@ -1936,17 +1924,82 @@ static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
} else {
lvlHi = lvlLo + 1;
}
+
+ if (lvlHi <= lvlLo)
+ return parser.emitError(parser.getNameLoc(),
+ "expect larger level upper bound than lower bound");
+
lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
return success();
}
-static void printLevelRange(OpAsmPrinter &p, ExtractIterSpaceOp op,
- IntegerAttr lvlLo, IntegerAttr lvlHi) {
- if (op.getLoLvl() + 1 == op.getHiLvl())
- p << op.getLoLvl();
+static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
+ IntegerAttr lvlHi) {
+ unsigned lo = lvlLo.getValue().getZExtValue();
+ unsigned hi = lvlHi.getValue().getZExtValue();
+ if (lo + 1 == hi)
+ p << lo;
else
- p << op.getLoLvl() << " to " << op.getHiLvl();
+ p << lo << " to " << hi;
+}
+
+ParseResult
+parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
+ SmallVectorImpl<OpAsmParser::Argument> &iterators,
+ SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
+ SmallVector<OpAsmParser::UnresolvedOperand> spaces;
+ SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
+ // Parses "%iters, ... in %spaces, ..."
+ if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
+ parser.parseOperandList(spaces))
+ return failure();
+
+ if (iterators.size() != spaces.size())
+ return parser.emitError(
+ parser.getNameLoc(),
+ "mismatch in number of sparse iterators and sparse spaces");
+
+ bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+ if (hasIterArgs)
+ if (parser.parseAssignmentList(iterArgs, initArgs))
+ return failure();
+
+ SmallVector<Type> iterSpaceTps;
+ // parse ": sparse_tensor.iter_space -> ret"
+ if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
+ return failure();
+ if (iterSpaceTps.size() != spaces.size())
+ return parser.emitError(parser.getNameLoc(),
+ "mismatch in number of iteration space operands "
+ "and iteration space types");
+
+ for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
+ IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
+ if (!spaceTp)
+ return parser.emitError(parser.getNameLoc(),
+ "expected sparse_tensor.iter_space type for "
+ "iteration space operands");
+ it.type = spaceTp.getIteratorType();
+ }
+
+ if (hasIterArgs)
+ if (parser.parseArrowTypeList(state.types))
+ return failure();
+
+ // Resolves input operands.
+ if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
+ state.operands))
+ return failure();
+
+ if (hasIterArgs) {
+ for (auto [it, init, tp] : llvm::zip(iterArgs, initArgs, state.types)) {
+ it.type = tp;
+ if (parser.resolveOperand(init, tp, state.operands))
+ return failure();
+ }
+ }
+ return success();
}
LogicalResult ExtractIterSpaceOp::inferReturnTypes(
@@ -1991,60 +2044,45 @@ LogicalResult ExtractIterSpaceOp::verify() {
return success();
}
+ValueRange CoordinateOp::collaspeOpInto(OpBuilder &builder,
+ ArrayRef<Operation *> loops,
+ Operation *collapsed) {
+ assert(llvm::all_of(loops,
+ [](Operation *l) { return llvm::isa<IterateOp>(l); }));
+ auto finalLoop = llvm::cast<IterateOp>(collapsed);
+ SmallVector<Type> retTps(finalLoop.getIterSpace().getType().getSpaceDim(),
+ builder.getIndexType());
+ auto collapsedCoords =
+ builder.create<CoordinateOp>(getLoc(), retTps, finalLoop.getIterator());
+
+ for (Operation *l : loops) {
+ if (getIterator().getParentBlock()->getParentOp() == l) {
+ auto space = llvm::cast<IterateOp>(l)
+ .getIterSpace()
+ .getDefiningOp<ExtractIterSpaceOp>();
+
+ return collapsedCoords.getResults().slice(space.getLoLvl(),
+ space.getSpaceDim());
+ }
+ }
+ llvm_unreachable(
+ "Can not find the corresponding iterate space for the collapsable op.");
+}
+
ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::Argument iterator;
OpAsmParser::UnresolvedOperand iterSpace;
- // Parses %iters in %spaces
- if (parser.parseArgument(iterator) || parser.parseKeyword("in") ||
- parser.parseOperand(iterSpace)) {
+ SmallVector<OpAsmParser::Argument> iters, iterArgs;
+ if (parseSparseSpaceLoop(parser, result, iters, iterArgs))
return failure();
- }
-
- // Parse the optional initial iteration arguments.
- SmallVector<OpAsmParser::Argument> regionArgs;
- SmallVector<OpAsmParser::UnresolvedOperand> operands;
- // Region arguments starts with iterators and follows by optional
- // user-provided iter_args.
- regionArgs.push_back(iterator);
- bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
- if (hasIterArgs)
- if (parser.parseAssignmentList(regionArgs, operands))
- return failure();
-
- // parse ": sparse_tensor.iter_space -> ret"
- Type iterSpaceTps;
- if (parser.parseColon() || parser.parseType(iterSpaceTps))
- return failure();
- if (hasIterArgs)
- if (parser.parseArrowTypeList(result.types))
- return failure();
-
- if (regionArgs.size() != result.types.size() + 1) {
- return parser.emitError(
- parser.getNameLoc(),
- "mismatch in number of loop-carried values and defined values");
- }
-
- // Resolves input operands.
- if (parser.resolveOperand(iterSpace, iterSpaceTps, result.operands))
- return failure();
-
- if (hasIterArgs) {
- for (auto argOperandType :
- llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
- Type type = std::get<2>(argOperandType);
- std::get<0>(argOperandType).type = type;
- if (parser.resolveOperand(std::get<1>(argOperandType), type,
- result.operands))
- return failure();
- }
- }
+ if (iters.size() != 1)
+ return parser.emitError(parser.getNameLoc(),
+ "expected only one iterator/iteration space");
+ iters.append(iterArgs);
Region *body = result.addRegion();
- regionArgs.front().type =
- iterSpaceTps.cast<IterSpaceType>().getIteratorType();
- if (parser.parseRegion(*body, regionArgs))
+ if (parser.parseRegion(*body, iters))
return failure();
IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
@@ -2078,7 +2116,6 @@ static void printInitializationList(OpAsmPrinter &p,
void IterateOp::print(OpAsmPrinter &p) {
p << " " << getIterator() << " in " << getIterSpace();
-
printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
p << " : " << getIterSpace().getType() << " ";
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
index f3207ede9585b4..752b2dfc2a0070 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -20,8 +20,14 @@ namespace mlir {
namespace sparse_tensor {
-bool isCollapsableIterations(LoopLikeOpInterface parent,
- LoopLikeOpInterface node) {
+struct CollapseSpaceInfo {
+ ExtractIterSpaceOp space;
+ // Coiteration as well (if it makes sense)?
+ IterateOp loop;
+ SmallVector<SparseCollapsableOp> collapseOps;
+};
+
+bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
auto pIterArgs = parent.getRegionIterArgs();
auto nInitArgs = node.getInits();
if (pIterArgs.size() != nInitArgs.size())
@@ -44,43 +50,81 @@ bool isCollapsableIterations(LoopLikeOpInterface parent,
return yieldEq && iterArgEq;
}
-bool legalToCollapse(ExtractIterSpaceOp parent, ExtractIterSpaceOp node) {
- auto pItOp = llvm::dyn_cast<IterateOp>(parent->getParentOp());
- auto nItOp = llvm::dyn_cast<IterateOp>(node->getParentOp());
+bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
+ ExtractIterSpaceOp curSpace) {
+
+ auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
+ Value spaceVal = space.getResultSpace();
+ if (spaceVal.hasOneUse())
+ return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
+ return nullptr;
+ };
+
+ if (toCollapse.empty()) {
+ // Collapse root.
+ if (auto itOp = getIterateOpOverSpace(curSpace)) {
+ CollapseSpaceInfo &info = toCollapse.emplace_back();
+ info.space = curSpace;
+ info.loop = itOp;
+ // No operations need to be collapsed at the root level;
+ info.collapseOps = {};
+ return true;
+ }
+ return false;
+ }
+
+ auto parent = toCollapse.back().space;
+ auto pItOp = toCollapse.back().loop;
+ auto nItOp = getIterateOpOverSpace(curSpace);
// Can only collapse spaces extracted from the same tensor.
- if (parent.getTensor() != node.getTensor() || !parent->hasOneUse())
+ if (parent.getTensor() != curSpace.getTensor())
return false;
// Can only collapse consecutive simple iteration on one tensor (i.e., no
// coiteration).
- if (!nItOp || nItOp.getIterSpace() != parent.getResult() ||
- nItOp->getBlock() != parent->getBlock())
+ if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
+ pItOp.getIterator() != curSpace.getParentIter() ||
+ curSpace->getParentOp() != pItOp.getOperation())
return false;
- if (pItOp && !isCollapsableIterations(pItOp, nItOp))
+ if (pItOp && !isCollapsableLoops(pItOp, nItOp))
return false;
// TODO: Make sure all other operations in the same basic block as `node` can
// be collapsed and sink into the collapsed iteration (through Interfaces
// defined in TD files).
+ SmallVector<SparseCollapsableOp> collapsableOps;
+ for (Operation &op : *pItOp.getBody()) {
+ if (&op == curSpace.getOperation() || &op == nItOp.getOperation() ||
+ &op == pItOp.getBody()->getTerminator())
+ continue;
+ // All other ops in parent loop need to be collapsable.
+ auto collapsableOp = llvm::dyn_cast<SparseCollapsableOp>(&op);
+ if (!collapsableOp)
+ return false;
+ collapsableOps.push_back(collapsableOp);
+ }
+
+ CollapseSpaceInfo &info = toCollapse.emplace_back();
+ info.space = curSpace;
+ info.loop = nItOp;
+ info.collapseOps = std::move(collapsableOps);
return true;
}
-void collapseSparseSpace(ArrayRef<ExtractIterSpaceOp> toCollapse) {
+void collapseSparseSpace(SmallVectorImpl<CollapseSpaceInfo> &toCollapse) {
if (toCollapse.size() < 2)
return;
- ExtractIterSpaceOp root = toCollapse.front();
- ExtractIterSpaceOp leaf = toCollapse.back();
+ ExtractIterSpaceOp root = toCollapse.front().space;
+ ExtractIterSpaceOp leaf = toCollapse.back().space;
Location loc = root.getLoc();
- if (!leaf->hasOneUse())
- return;
- assert(root->hasOneUse());
+ assert(root->hasOneUse() && leaf->hasOneUse());
// Insert collapsed operation at the same scope as root operation.
- OpBuilder builder(toCollapse.front());
+ OpBuilder builder(root);
// Construct the collapsed iteration space.
auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
@@ -88,19 +132,29 @@ void collapseSparseSpace(ArrayRef<ExtractIterSpaceOp> toCollapse) {
leaf.getHiLvl());
auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
- auto pItOp = llvm::cast<IterateOp>(leaf->getParentOp());
-
- // This could either be IterateOp or (TODO: in the future) CoIterateOp.
- auto loop = llvm::dyn_cast<IterateOp>(*leaf->getUsers().begin());
- if (!loop || !isCollapsableIterations(pItOp, loop))
- return;
+ auto innermost = toCollapse.back().loop;
IRMapping mapper;
mapper.map(leaf, collapsedSpace.getResultSpace());
- for (auto z : llvm::zip_equal(loop.getInitArgs(), rItOp.getInitArgs()))
+ for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
mapper.map(std::get<0>(z), std::get<1>(z));
- auto cloned = llvm::cast<IterateOp>(builder.clone(*loop, mapper));
+ auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
+ builder.setInsertionPointToStart(cloned.getBody());
+ SmallVector<Operation *> loops =
+ llvm::map_to_vector(toCollapse, [](CollapseSpaceInfo &info) {
+ return info.loop.getOperation();
+ });
+
+ for (const CollapseSpaceInfo &info : toCollapse) {
+ for (SparseCollapsableOp op : info.collapseOps) {
+ ValueRange colVals = op.collaspeOpInto(builder, loops, cloned);
+ for (auto [o, r] : llvm::zip(op->getResults(), colVals))
+ o.replaceAllUsesWith(r);
+ op.erase();
+ }
+ }
+
cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
rItOp.replaceAllUsesWith(cloned.getResults());
@@ -124,18 +178,13 @@ struct SparseSpaceCollapsePass
// %space2 = extract_space %t2 ...
// sparse_tensor.iterate(%sp1) ...
//
- SmallVector<ExtractIterSpaceOp> toCollapse;
+ SmallVector<CollapseSpaceInfo> toCollapse;
func->walk([&](ExtractIterSpaceOp op) {
- if (toCollapse.empty()) {
- // Root space to collapse.
- toCollapse.push_back(op);
- } else {
- if (legalToCollapse(toCollapse.back(), op)) {
- toCollapse.push_back(op);
- } else {
- collapseSparseSpace(toCollapse);
- toCollapse.clear();
- }
+ if (!legalToCollapse(toCollapse, op)) {
+ // If it is not legal to collapse one more space, collapse the existing
+ // ones and clear.
+ collapseSparseSpace(toCollapse);
+ toCollapse.clear();
}
});
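Tying the pieces together, the intended net effect of the pass with collapsable
coordinate ops is roughly the following body for the collapsed loop (placeholders as
in the earlier sketches; as written, the pass may still leave duplicated coordinate
ops behind, since each collapsed CoordinateOp re-queries the full coordinate list
before slicing):

    sparse_tensor.iterate %it in %s
        : !sparse_tensor.iter_space<[compressed,compressed]> {
      %i, %j = sparse_tensor.iteration.coordinate %it
                 : !sparse_tensor.iterator<[compressed,compressed]> -> index, index
      // ... original inner-loop body, now using %i and %j directly ...
    }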