[Mlir-commits] [mlir] 785a24f - [mlir][sparse] introduce `sparse_tensor.coiterate` operation. (#101100)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 31 15:14:30 PDT 2024
Author: Peiming Liu
Date: 2024-07-31T15:14:27-07:00
New Revision: 785a24f1561c610ecbce7cdfbff053e0a3a7caec
URL: https://github.com/llvm/llvm-project/commit/785a24f1561c610ecbce7cdfbff053e0a3a7caec
DIFF: https://github.com/llvm/llvm-project/commit/785a24f1561c610ecbce7cdfbff053e0a3a7caec.diff
LOG: [mlir][sparse] introduce `sparse_tensor.coiterate` operation. (#101100)
This PR introduces the `sparse_tensor.coiterate` operation, which represents
a loop that traverses multiple sparse iteration spaces.
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 68ca036121520..388efd1c454b1 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -61,37 +61,62 @@ struct COOSegment {
/// A simple wrapper to encode a bitset of (at most 64) levels, currently used
/// by `sparse_tensor.iterate` operation for the set of levels on which the
/// coordinates should be loaded.
-class LevelSet {
- uint64_t bits = 0;
+class I64BitSet {
+ uint64_t storage = 0;
public:
- LevelSet() = default;
- explicit LevelSet(uint64_t bits) : bits(bits) {}
- operator uint64_t() const { return bits; }
+ using const_set_bits_iterator = llvm::const_set_bits_iterator_impl<I64BitSet>;
+ const_set_bits_iterator begin() const {
+ return const_set_bits_iterator(*this);
+ }
+ const_set_bits_iterator end() const {
+ return const_set_bits_iterator(*this, -1);
+ }
+ iterator_range<const_set_bits_iterator> bits() const {
+ return make_range(begin(), end());
+ }
+
+ I64BitSet() = default;
+ explicit I64BitSet(uint64_t bits) : storage(bits) {}
+ operator uint64_t() const { return storage; }
- LevelSet &set(unsigned i) {
+ I64BitSet &set(unsigned i) {
assert(i < 64);
- bits |= static_cast<uint64_t>(0x01u) << i;
+ storage |= static_cast<uint64_t>(0x01u) << i;
return *this;
}
- LevelSet &operator|=(LevelSet lhs) {
- bits |= static_cast<uint64_t>(lhs);
+ I64BitSet &operator|=(I64BitSet lhs) {
+ storage |= static_cast<uint64_t>(lhs);
return *this;
}
- LevelSet &lshift(unsigned offset) {
- bits = bits << offset;
+ I64BitSet &lshift(unsigned offset) {
+ storage = storage << offset;
return *this;
}
+ // Needed by `llvm::const_set_bits_iterator_impl`, which uses -1 as the end
+ // position (see `end()` above). Return -1 rather than `min()` (which is 64
+ // for an empty set) so that iterating an empty set visits nothing.
+ int find_first() const { return empty() ? -1 : static_cast<int>(min()); }
+ // Returns the index of the next set bit after `prev`, or -1 if there is none.
+ int find_next(unsigned prev) const {
+ // `prev + 1 >= max()` means no set bit lies above `prev`; it also keeps the
+ // shift amount below 64 (shifting a uint64_t by 64 is undefined behavior).
+ // `prev >= 63` guards against `prev + 1` wrapping around.
+ if (prev >= 63 || prev + 1 >= max())
+ return -1;
+
+ uint64_t b = storage >> (prev + 1);
+ // The highest set bit (index max() - 1) is above `prev`, so b != 0.
+ assert(b != 0);
+
+ return llvm::countr_zero(b) + prev + 1;
+ }
+
bool operator[](unsigned i) const {
assert(i < 64);
- return (bits & (1 << i)) != 0;
+ // Shift a 64-bit one: `1 << i` is an `int` shift, undefined for i >= 31 and
+ // unable to reach the upper half of `storage`.
+ return (storage & (static_cast<uint64_t>(1) << i)) != 0;
}
- unsigned max() const { return 64 - llvm::countl_zero(bits); }
- unsigned count() const { return llvm::popcount(bits); }
- bool empty() const { return bits == 0; }
+ // Index of the lowest set bit (64 when the set is empty).
+ unsigned min() const { return llvm::countr_zero(storage); }
+ // One past the index of the highest set bit (0 when the set is empty).
+ unsigned max() const { return 64 - llvm::countl_zero(storage); }
+ unsigned count() const { return llvm::popcount(storage); }
+ bool empty() const { return storage == 0; }
};
} // namespace sparse_tensor
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 69b212cce4ceb..cb6c1b63e4e4b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -24,16 +24,17 @@ class SparseTensor_Attr<string name,
// sparse tensor levels.
//===----------------------------------------------------------------------===//
-def LevelSetAttr :
- TypedAttrBase<
- I64, "IntegerAttr",
+def I64BitSetAttr : TypedAttrBase<I64, "IntegerAttr",
And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
-"LevelSet attribute"> {
+"I64BitSet attribute"> {
- let returnType = [{::mlir::sparse_tensor::LevelSet}];
- let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
+ let returnType = [{::mlir::sparse_tensor::I64BitSet}];
+ let convertFromStorage = [{::mlir::sparse_tensor::I64BitSet($_self.getValue().getZExtValue())}];
}
+def I64BitSetArrayAttr :
+ TypedArrayAttrBase<I64BitSetAttr, "I64BitSet array attribute">;
+
//===----------------------------------------------------------------------===//
// These attributes are just like `IndexAttr` except that they clarify whether
// the index refers to a dimension (an axis of the semantic tensor) or a level
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index ff9858d5832ba..6e17f804993e2 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1306,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
- "ForeachOp", "IterateOp"]>]> {
+ "ForeachOp", "IterateOp", "CoIterateOp"]>]> {
let summary = "Yield from sparse_tensor set-like operations";
let description = [{
Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1629,14 +1629,14 @@ def IterateOp : SparseTensor_Op<"iterate",
let arguments = (ins AnySparseIterSpace:$iterSpace,
Variadic<AnyType>:$initArgs,
- LevelSetAttr:$crdUsedLvls);
+ I64BitSetAttr:$crdUsedLvls);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs)>,
- OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs, "LevelSet" :$crdUsedLvls)>
+ OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs, "I64BitSet" :$crdUsedLvls)>
];
let extraClassDeclaration = [{
@@ -1669,6 +1669,127 @@ def IterateOp : SparseTensor_Op<"iterate",
let hasCustomAssemblyFormat = 1;
}
+def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
+ [AttrSizedOperandSegments,
+ SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">,
+ RecursiveMemoryEffects]> {
+ let summary = "Co-iterates over a set of sparse iteration spaces";
+ let description = [{
+ The `sparse_tensor.coiterate` operation represents a loop (nest) over
+ a set of iteration spaces. The operation can have multiple regions,
+ with each of them defining a case to compute a result at the current iteration.
+ The case condition is defined solely by the pattern of specified iterators,
+ i.e., by which of the iteration spaces the current coordinate is defined in.
+ For example:
+ ```mlir
+ %ret = sparse_tensor.coiterate (%sp1, %sp2) at(%coord) iter_args(%arg = %init)
+ : (!sparse_tensor.iter_space<#CSR, lvls = 0>,
+ !sparse_tensor.iter_space<#COO, lvls = 0>)
+ -> index
+ case %it1, _ {
+ // %coord is specified in space %sp1 but *NOT* specified in space %sp2.
+ }
+ case %it1, %it2 {
+ // %coord is specified in *BOTH* spaces %sp1 and %sp2.
+ }
+ ```
+
+ `sparse_tensor.coiterate` can also operate on loop-carried variables.
+ It returns the final value for each loop-carried variable after loop termination.
+ The initial values of the variables are passed as additional SSA operands
+ following the iteration space operands.
+ The block arguments of each region are, in order: one `index` argument for
+ each used coordinate (as listed in `at(...)`), one argument for each
+ loop-carried variable, representing the value of the variable at the current
+ iteration, and one argument for each iterator specified by the region's case.
+ Each region must contain exactly one block that terminates with
+ `sparse_tensor.yield`.
+
+ The results of a `sparse_tensor.coiterate` hold the final values after
+ the last iteration. If the `sparse_tensor.coiterate` defines any values,
+ a yield must be explicitly present in every region defined in the operation.
+ The number and types of the `sparse_tensor.coiterate` results must match
+ the initial values in the iter_args binding and the yield operands.
+
+
+ A `sparse_tensor.coiterate` example that does elementwise addition between two
+ sparse vectors.
+
+
+ ```mlir
+ %ret = sparse_tensor.coiterate (%sp1, %sp2) at(%coord) iter_args(%arg = %init)
+ : (!sparse_tensor.iter_space<#CSR, lvls = 0>,
+ !sparse_tensor.iter_space<#CSR, lvls = 0>)
+ -> tensor<?xindex, #CSR>
+ case %it1, _ {
+ // v = v1 + 0 = v1
+ %v1 = sparse_tensor.extract_value %t1 at %it1 : index
+ %yield = sparse_tensor.insert %v1 into %arg[%coord]
+ sparse_tensor.yield %yield
+ }
+ case _, %it2 {
+ // v = 0 + v2 = v2
+ %v2 = sparse_tensor.extract_value %t2 at %it2 : index
+ %yield = sparse_tensor.insert %v2 into %arg[%coord]
+ sparse_tensor.yield %yield
+ }
+ case %it1, %it2 {
+ // v = v1 + v2
+ %v1 = sparse_tensor.extract_value %t1 at %it1 : index
+ %v2 = sparse_tensor.extract_value %t2 at %it2 : index
+ %v = arith.addi %v1, %v2 : index
+ %yield = sparse_tensor.insert %v into %arg[%coord]
+ sparse_tensor.yield %yield
+ }
+ ```
+ }];
+
+ // `cases` holds one bitset per region; bit i is set when that region's case
+ // defines (rather than `_`) the iterator over the i-th iteration space.
+ let arguments = (ins Variadic<AnySparseIterSpace>:$iterSpaces,
+ Variadic<AnyType>:$initArgs,
+ I64BitSetAttr:$crdUsedLvls,
+ I64BitSetArrayAttr:$cases);
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);
+
+ let extraClassDeclaration = [{
+ // The space dimension of the first iteration space operand.
+ unsigned getSpaceDim() {
+ return llvm::cast<::mlir::sparse_tensor::IterSpaceType>(
+ getIterSpaces().front().getType())
+ .getSpaceDim();
+ }
+ // The set of iteration spaces whose iterators region `regionIdx` defines.
+ I64BitSet getRegionDefinedSpace(unsigned regionIdx) {
+ return I64BitSet(llvm::cast<IntegerAttr>(getCases()[regionIdx])
+ .getValue().getZExtValue());
+ }
+ // The defined-space bitsets of all regions, in region order.
+ auto getRegionDefinedSpaces() {
+ return llvm::map_range(getCases().getValue(), [](Attribute attr) {
+ return I64BitSet(llvm::cast<IntegerAttr>(attr).getValue().getZExtValue());
+ });
+ }
+
+ // The block arguments start with the referenced coordinates, followed by
+ // the user-provided iteration arguments, and end with the iterators.
+ Block::BlockArgListType getCrds(unsigned regionIdx) {
+ return getRegion(regionIdx).getArguments()
+ .take_front(getCrdUsedLvls().count());
+ }
+ unsigned getNumRegionIterArgs(unsigned regionIdx) {
+ return getInitArgs().size();
+ }
+ Block::BlockArgListType getRegionIterArgs(unsigned regionIdx) {
+ return getRegion(regionIdx).getArguments()
+ .slice(getCrdUsedLvls().count(), getNumRegionIterArgs(regionIdx));
+ }
+ Block::BlockArgListType getRegionIterators(unsigned regionIdx) {
+ return getRegion(regionIdx).getArguments()
+ .take_back(getRegionDefinedSpace(regionIdx).count());
+ }
+ ValueRange getYieldedValues(unsigned regionIdx);
+ }];
+
+ let hasVerifier = 1;
+ let hasRegionVerifier = 1;
+ let hasCustomAssemblyFormat = 1;
+}
+
//===----------------------------------------------------------------------===//
// Sparse Tensor Debugging and Test-Only Operations.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 0a276d87f3bca..1135ea32fe1ab 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2131,10 +2131,82 @@ static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
printLevelRange(p, lo, hi);
}
+/// Parses a comma-separated list in the form of "(%val0, _, %val1, ...)",
+/// where `_` marks a position whose value is not defined (e.g., an undefined
+/// coordinate in the sparse iteration space).
+///
+/// For every position `i` that holds an SSA name, bit `i` of `definedSet` is
+/// set (`I64BitSet::set` asserts i < 64) and a new argument is appended to
+/// `definedArgs`; argument types are left for the caller to fill in.
+/// `maxCnt` bounds the number of list entries, and `delimiter` selects how
+/// the list is enclosed (parentheses by default).
+static ParseResult parseOptionalDefinedList(
+ OpAsmParser &parser, OperationState &state, I64BitSet &definedSet,
+ SmallVectorImpl<OpAsmParser::Argument> &definedArgs,
+ unsigned maxCnt = std::numeric_limits<unsigned>::max(),
+ OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) {
+ unsigned cnt = 0;
+ ParseResult crdList =
+ parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult {
+ // `parseOptionalKeyword` fails when the entry is not `_`; the entry
+ // must then be an SSA name that defines position `cnt`.
+ if (parser.parseOptionalKeyword("_")) {
+ if (parser.parseArgument(definedArgs.emplace_back()))
+ return failure();
+ definedSet.set(cnt);
+ }
+ cnt += 1;
+ return success();
+ });
+
+ // NOTE(review): the count check runs before `crdList` is inspected, so a
+ // list that is both too long and malformed reports the count error.
+ if (cnt > maxCnt)
+ return parser.emitError(parser.getNameLoc(),
+ "parsed more value than expected.");
+
+ // NOTE(review): this message mentions level coordinates, but the helper is
+ // also used for the iterator list of `coiterate` cases.
+ if (failed(crdList)) {
+ return parser.emitError(
+ parser.getNameLoc(),
+ "expecting SSA value or \"_\" for level coordinates");
+ }
+ assert(definedArgs.size() == definedSet.count());
+ return success();
+}
+
+/// Prints `size` comma-separated entries: the next argument of `blocksArgs`
+/// at every position set in `definedSet`, and `_` everywhere else. Prints
+/// nothing at all when `definedSet` is empty. Every argument in `blocksArgs`
+/// is expected to be consumed.
+static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size,
+ Block::BlockArgListType blocksArgs,
+ I64BitSet definedSet) {
+ if (definedSet.empty())
+ return;
+
+ for (unsigned i = 0; i < size; i++) {
+ if (definedSet[i]) {
+ p << blocksArgs.front();
+ blocksArgs = blocksArgs.drop_front();
+ } else {
+ p << "_";
+ }
+ if (i != size - 1)
+ p << ", ";
+ }
+ assert(blocksArgs.empty());
+}
+
+/// Parses the optional "at(%crd0, _, ...)" clause. SSA names are appended to
+/// `coords` (every entry of `coords` is then typed `index`), and the positions
+/// holding a name are stored in the `crdUsedLvls` attribute of `state`. The
+/// attribute is all zeros when the clause is absent.
static ParseResult
-parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
- SmallVectorImpl<OpAsmParser::Argument> &iterators,
- SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
+parseUsedCoordList(OpAsmParser &parser, OperationState &state,
+ SmallVectorImpl<OpAsmParser::Argument> &coords) {
+ // Parse "at(%crd0, _, ...)"
+ I64BitSet crdUsedLvlSet;
+ if (succeeded(parser.parseOptionalKeyword("at")) &&
+ failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords)))
+ return failure();
+
+ // Always use IndexType for the coordinate.
+ for (auto &coord : coords)
+ coord.type = parser.getBuilder().getIndexType();
+
+ // Set the CrdUsedLvl bitset.
+ state.addAttribute("crdUsedLvls",
+ parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
+ return success();
+}
+
+static ParseResult
+parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
+ SmallVectorImpl<OpAsmParser::Argument> &iterators,
+ SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
SmallVector<OpAsmParser::UnresolvedOperand> spaces;
SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
@@ -2148,37 +2220,14 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
parser.getNameLoc(),
"mismatch in number of sparse iterators and sparse spaces");
- // Parse "at(%crd0, _, ...)"
- LevelSet crdUsedLvlSet;
- bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at"));
- unsigned lvlCrdCnt = 0;
- if (hasUsedCrds) {
- ParseResult crdList = parser.parseCommaSeparatedList(
- OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
- if (parser.parseOptionalKeyword("_")) {
- if (parser.parseArgument(iterArgs.emplace_back()))
- return failure();
- // Always use IndexType for the coordinate.
- crdUsedLvlSet.set(lvlCrdCnt);
- iterArgs.back().type = parser.getBuilder().getIndexType();
- }
- lvlCrdCnt += 1;
- return success();
- });
- if (failed(crdList)) {
- return parser.emitError(
- parser.getNameLoc(),
- "expecting SSA value or \"_\" for level coordinates");
- }
- }
- // Set the CrdUsedLvl bitset.
- state.addAttribute("crdUsedLvls",
- parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
+ if (failed(parseUsedCoordList(parser, state, blockArgs)))
+ return failure();
+ size_t numCrds = blockArgs.size();
// Parse "iter_args(%arg = %init, ...)"
bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
if (hasIterArgs)
- if (parser.parseAssignmentList(iterArgs, initArgs))
+ if (parser.parseAssignmentList(blockArgs, initArgs))
return failure();
SmallVector<Type> iterSpaceTps;
@@ -2196,10 +2245,6 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
return parser.emitError(parser.getNameLoc(),
"expected sparse_tensor.iter_space type for "
"iteration space operands");
- if (hasUsedCrds && spaceTp.getSpaceDim() != lvlCrdCnt)
- return parser.emitError(parser.getNameLoc(),
- "mismatch in number of iteration space dimension "
- "and specified coordinates");
it.type = spaceTp.getIteratorType();
}
@@ -2213,9 +2258,68 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
return failure();
if (hasIterArgs) {
- unsigned numCrds = crdUsedLvlSet.count();
// Strip off leading args that used for coordinates.
- MutableArrayRef args = MutableArrayRef(iterArgs).drop_front(numCrds);
+ MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
+ if (args.size() != initArgs.size() || args.size() != state.types.size()) {
+ return parser.emitError(
+ parser.getNameLoc(),
+ "mismatch in number of iteration arguments and return values");
+ }
+
+ for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
+ it.type = tp;
+ if (parser.resolveOperand(init, tp, state.operands))
+ return failure();
+ }
+ }
+ return success();
+}
+
+static ParseResult
+parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
+ SmallVectorImpl<Value> &spacesVals,
+ SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
+
+ // Parse "(%spaces, ...)"
+ SmallVector<OpAsmParser::UnresolvedOperand> spaces;
+ if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren))
+ return failure();
+
+ if (failed(parseUsedCoordList(parser, state, blockArgs)))
+ return failure();
+ size_t numCrds = blockArgs.size();
+
+ // Parse "iter_args(%arg = %init, ...)"
+ SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
+ bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+ if (hasIterArgs)
+ if (parser.parseAssignmentList(blockArgs, initArgs))
+ return failure();
+
+ SmallVector<Type> iterSpaceTps;
+ // parse ": (sparse_tensor.iter_space, ...) -> ret"
+ if (parser.parseColon() || parser.parseLParen() ||
+ parser.parseTypeList(iterSpaceTps) || parser.parseRParen())
+ return failure();
+
+ if (iterSpaceTps.size() != spaces.size())
+ return parser.emitError(parser.getNameLoc(),
+ "mismatch in number of iteration space operands "
+ "and iteration space types");
+
+ if (hasIterArgs)
+ if (parser.parseArrowTypeList(state.types))
+ return failure();
+
+ // Resolves input sparse iteration spaces.
+ if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
+ spacesVals))
+ return failure();
+ state.operands.append(spacesVals);
+
+ if (hasIterArgs) {
+ // Strip off leading args that used for coordinates.
+ MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
if (args.size() != initArgs.size() || args.size() != state.types.size()) {
return parser.emitError(
parser.getNameLoc(),
@@ -2285,7 +2389,7 @@ struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
LogicalResult matchAndRewrite(IterateOp iterateOp,
PatternRewriter &rewriter) const override {
- LevelSet newUsedLvls(0);
+ I64BitSet newUsedLvls(0);
llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
if (auto crd = iterateOp.getLvlCrd(i)) {
@@ -2317,13 +2421,13 @@ void IterateOp::build(OpBuilder &builder, OperationState &odsState,
Value iterSpace, ValueRange initArgs) {
unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
// All ones.
- LevelSet set((1 << rank) - 1);
+ // Build the mask with a 64-bit shift: `1 << rank` is an `int` shift that is
+ // undefined for rank >= 31, and a 64-bit shift by 64 is undefined as well.
+ I64BitSet set(rank >= 64 ? ~static_cast<uint64_t>(0)
+ : (static_cast<uint64_t>(1) << rank) - 1);
return build(builder, odsState, iterSpace, initArgs, set);
}
void IterateOp::build(OpBuilder &builder, OperationState &odsState,
Value iterSpace, ValueRange initArgs,
- LevelSet crdUsedLvls) {
+ I64BitSet crdUsedLvls) {
OpBuilder::InsertionGuard guard(builder);
odsState.addOperands(iterSpace);
@@ -2353,7 +2457,7 @@ ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::UnresolvedOperand iterSpace;
SmallVector<OpAsmParser::Argument> iters, iterArgs;
- if (parseSparseSpaceLoop(parser, result, iters, iterArgs))
+ if (parseSparseIterateLoop(parser, result, iters, iterArgs))
return failure();
if (iters.size() != 1)
return parser.emitError(parser.getNameLoc(),
@@ -2393,51 +2497,39 @@ static void printInitializationList(OpAsmPrinter &p,
p << ")";
}
-static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim,
- Block::BlockArgListType blocksArgs,
- LevelSet crdUsedLvls) {
- if (crdUsedLvls.empty())
- return;
-
- p << " at(";
- for (unsigned i = 0; i < spaceDim; i++) {
- if (crdUsedLvls[i]) {
- p << blocksArgs.front();
- blocksArgs = blocksArgs.drop_front();
- } else {
- p << "_";
- }
- if (i != spaceDim - 1)
- p << ", ";
+/// Checks shared by `IterateOp` and `CoIterateOp`: the loop-carried operands
+/// must match the results in number, and the used-coordinate set must fit
+/// within the iteration space dimension.
+template <typename SparseLoopOp>
+static LogicalResult verifySparseLoopOp(SparseLoopOp op) {
+ if (op.getInitArgs().size() != op.getNumResults()) {
+ return op.emitOpError(
+ "mismatch in number of loop-carried values and defined values");
}
- assert(blocksArgs.empty());
- p << ")";
+ if (op.getCrdUsedLvls().max() > op.getSpaceDim())
+ return op.emitOpError("required out-of-bound coordinates");
+
+ return success();
}
+LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); }
+LogicalResult CoIterateOp::verify() {
+ // `getSpaceDim()` (used by the shared checks) reads the first iteration
+ // space, and the custom printer takes the coordinate and iter_args names
+ // from the first region. Reject ops lacking either instead of crashing.
+ if (getIterSpaces().empty())
+ return emitOpError("expected at least one sparse iteration space");
+ if (getNumRegions() == 0)
+ return emitOpError("expected at least one case region");
+ return verifySparseLoopOp(*this);
+}
+
void IterateOp::print(OpAsmPrinter &p) {
p << " " << getIterator() << " in " << getIterSpace();
- printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
+ if (!getCrdUsedLvls().empty()) {
+ p << " at(";
+ printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
+ p << ")";
+ }
printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
- p << " : " << getIterSpace().getType() << " ";
+ // No trailing space: `printArrowTypeList` emits its own leading " -> ", and
+ // the single space before the region is printed unconditionally below.
+ p << " : " << getIterSpace().getType();
if (!getInitArgs().empty())
- p << "-> (" << getInitArgs().getTypes() << ") ";
+ p.printArrowTypeList(getInitArgs().getTypes());
+ p << " ";
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/!getInitArgs().empty());
}
-LogicalResult IterateOp::verify() {
- if (getInitArgs().size() != getNumResults()) {
- return emitOpError(
- "mismatch in number of loop-carried values and defined values");
- }
- if (getCrdUsedLvls().max() > getSpaceDim())
- return emitOpError("required out-of-bound coordinates");
-
- return success();
-}
-
LogicalResult IterateOp::verifyRegions() {
if (getIterator().getType() != getIterSpace().getType().getIteratorType())
return emitOpError("mismatch in iterator and iteration space type");
@@ -2495,13 +2587,136 @@ OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
void IterateOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- // Both the operation itself and the region may be branching into the body or
- // back into the operation itself.
+ // Both the operation itself and the region may be branching into the body
+ // or back into the operation itself.
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
// It is possible for loop not to enter the body.
regions.push_back(RegionSuccessor(getResults()));
}
+/// Parses "(%spaces) [at(...)] [iter_args(...)] : (types) [-> results]",
+/// followed by any number of "case <pattern> { ... }" regions and an optional
+/// attribute dictionary.
+ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {
+
+ SmallVector<Value> spaces;
+ // The block argument list of each region is arranged in the order
+ // ([used coordinate list], [loop iteration args], [sparse iterator list]).
+ SmallVector<OpAsmParser::Argument> blockArgs;
+ if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs))
+ return failure();
+
+ // The init-args segment is sized by the parsed result types; the loop parser
+ // checks that the two counts agree when `iter_args` is present.
+ result.addAttribute("operandSegmentSizes",
+ parser.getBuilder().getDenseI32ArrayAttr(
+ {static_cast<int32_t>(spaces.size()),
+ static_cast<int32_t>(result.types.size())}));
+
+ SmallVector<Attribute> cases;
+ while (succeeded(parser.parseOptionalKeyword("case"))) {
+ // Parse one region per case.
+ I64BitSet definedItSet;
+ SmallVector<OpAsmParser::Argument> definedIts;
+ if (parseOptionalDefinedList(parser, result, definedItSet, definedIts,
+ spaces.size(), OpAsmParser::Delimiter::None))
+ return failure();
+
+ cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet));
+
+ for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) {
+ // Resolve the iterator type based on the iteration space type.
+ auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType());
+ definedIts[i].type = spaceTp.getIteratorType();
+ }
+ // Prepend the shared coordinate and iter_args arguments to this region's
+ // iterators.
+ definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
+ Region *body = result.addRegion();
+ if (parser.parseRegion(*body, definedIts))
+ return failure();
+
+ CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
+ }
+
+ result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases));
+
+ // Parse the optional attribute list.
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+
+ return success();
+}
+
+/// Prints "(%spaces) [at(...)] [iter_args(...)] : (types)", then " -> results"
+/// only when there are loop-carried values, then one "case <pattern> { ... }"
+/// per region, each on a new line.
+void CoIterateOp::print(OpAsmPrinter &p) {
+ p << " (";
+ llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; });
+ p << ")";
+
+ // Coordinate and iter_args names are taken from region 0's block arguments.
+ // NOTE(review): this assumes the op has at least one case region.
+ if (!getCrdUsedLvls().empty()) {
+ p << " at(";
+ printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls());
+ p << ")";
+ }
+
+ printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args");
+
+ p << " : (" << getIterSpaces().getTypes() << ")";
+ if (!getInitArgs().empty())
+ p.printArrowTypeList(getInitArgs().getTypes());
+
+ for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
+ p.printNewline();
+ p << "case ";
+ // NOTE(review): an all-`_` case has an empty defined set, for which
+ // `printOptionalDefinedList` prints nothing; confirm such cases are meant
+ // to be unrepresentable.
+ printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx),
+ getRegionDefinedSpace(idx));
+ p << " ";
+ p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/!getInitArgs().empty());
+ }
+}
+
+/// Returns the operands of the `sparse_tensor.yield` terminating the single
+/// block of region `regionIdx`.
+ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
+ return cast<sparse_tensor::YieldOp>(
+ getRegion(regionIdx).getBlocks().front().getTerminator())
+ .getResults();
+}
+
+/// Verifies every case region: its block-argument layout, iterator types and
+/// loop-carried value types. Also checks that no two regions share a case.
+LogicalResult CoIterateOp::verifyRegions() {
+ // `getRegionDefinedSpace(r)` indexes `cases` by region number, so the two
+ // lists must line up before any per-region accessor is used.
+ if (getCases().size() != getNumRegions())
+ return emitOpError("mismatch in number of cases and regions");
+
+ unsigned numSpaces = getIterSpaces().size();
+ for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
+ if (getNumRegionIterArgs(r) != getNumResults())
+ return emitOpError(
+ "mismatch in number of basic block args and defined values");
+
+ // The block-argument accessors slice the argument list assuming the layout
+ // [used coordinates, loop-carried values, iterators]. Check that the case
+ // only names existing spaces and that the total count matches before
+ // relying on them.
+ I64BitSet definedSpaces = getRegionDefinedSpace(r);
+ if (definedSpaces.max() > numSpaces)
+ return emitOpError() << "case on " << r
+ << "th region refers to out-of-bound iteration spaces";
+ unsigned expectedNumArgs = getCrdUsedLvls().count() +
+ getNumRegionIterArgs(r) + definedSpaces.count();
+ if (getRegion(r).getNumArguments() != expectedNumArgs)
+ return emitOpError() << "mismatch in number of block arguments on " << r
+ << "th region";
+
+ // Each iterator must have the iterator type of the space it traverses,
+ // mirroring the corresponding check in `IterateOp::verifyRegions`.
+ Block::BlockArgListType iterators = getRegionIterators(r);
+ for (unsigned s = 0, itIdx = 0, m = definedSpaces.max(); s < m; s++) {
+ if (!definedSpaces[s])
+ continue;
+ auto spaceTp = llvm::cast<IterSpaceType>(getIterSpaces()[s].getType());
+ if (iterators[itIdx++].getType() != spaceTp.getIteratorType())
+ return emitOpError()
+ << "mismatch in iterator and iteration space type on " << r
+ << "th region";
+ }
+
+ auto initArgs = getInitArgs();
+ auto iterArgs = getRegionIterArgs(r);
+ auto yieldVals = getYieldedValues(r);
+ auto opResults = getResults();
+ if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
+ opResults.size()})) {
+ return emitOpError()
+ << "number mismatch between iter args and results on " << r
+ << "th region";
+ }
+
+ for (auto [i, init, iter, yield, ret] :
+ llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
+ if (init.getType() != ret.getType())
+ return emitOpError()
+ << "types mismatch between " << i
+ << "th iter operand and defined value on " << r << "th region";
+ if (iter.getType() != ret.getType())
+ return emitOpError() << "types mismatch between " << i
+ << "th iter region arg and defined value on " << r
+ << "th region";
+ if (yield.getType() != ret.getType())
+ return emitOpError()
+ << "types mismatch between " << i
+ << "th yield value and defined value on " << r << "th region";
+ }
+ }
+
+ // Each region must correspond to a distinct case pattern.
+ auto cases = getRegionDefinedSpaces();
+ llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
+ if (set.size() != getNumRegions())
+ return emitOpError("contains duplicated cases.");
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Sparse Tensor Dialect Setups.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
index 924046fcd9961..f85c4761a8d52 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -141,10 +141,10 @@ void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
builder.setInsertionPointToStart(cloned.getBody());
- LevelSet crdUsedLvls;
+ I64BitSet crdUsedLvls;
unsigned shift = 0, argIdx = 1;
for (auto info : toCollapse.drop_back()) {
- LevelSet set = info.loop.getCrdUsedLvls();
+ I64BitSet set = info.loop.getCrdUsedLvls();
crdUsedLvls |= set.lshift(shift);
shift += info.loop.getSpaceDim();
for (BlockArgument crd : info.loop.getCrds()) {
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 61cc9be88685c..737b736ba795f 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1191,3 +1191,78 @@ func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -
}
return %r1 : index
}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+
+// Two regions with the same case pattern (`%it1, _`) are rejected.
+func.func @sparse_coiteration(%sp1 : !sparse_tensor.iter_space<#COO, lvls = 0>,
+ %sp2 : !sparse_tensor.iter_space<#COO, lvls = 1>) -> index {
+ %init = arith.constant 0 : index
+ // expected-error @+1 {{'sparse_tensor.coiterate' op contains duplicated cases.}}
+ %ret = sparse_tensor.coiterate (%sp1, %sp2) at (%coord) iter_args(%arg = %init)
+ : (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>)
+ -> index
+ case %it1, _ {
+ sparse_tensor.yield %arg : index
+ }
+ case %it1, _ {
+ sparse_tensor.yield %arg : index
+ }
+ return %ret : index
+}
+
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+
+// A yielded value whose type (i32) differs from the result type (index) is
+// rejected.
+func.func @sparse_coiteration(%sp1 : !sparse_tensor.iter_space<#COO, lvls = 0>,
+ %sp2 : !sparse_tensor.iter_space<#COO, lvls = 1>) -> index {
+ %init = arith.constant 0 : index
+ // expected-error @+1 {{'sparse_tensor.coiterate' op types mismatch between 0th yield value and defined value on 0th region}}
+ %ret = sparse_tensor.coiterate (%sp1, %sp2) at (%coord) iter_args(%arg = %init)
+ : (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>)
+ -> index
+ case %it1, _ {
+ %i = arith.constant 1 : i32
+ sparse_tensor.yield %i : i32
+ }
+ return %ret : index
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+
+// Requesting more used coordinates than the space dimension is rejected.
+func.func @sparse_coiteration(%sp1 : !sparse_tensor.iter_space<#COO, lvls = 0>,
+ %sp2 : !sparse_tensor.iter_space<#COO, lvls = 1>) -> index {
+ %init = arith.constant 0 : index
+ // expected-error @+1 {{'sparse_tensor.coiterate' op required out-of-bound coordinates}}
+ %ret = sparse_tensor.coiterate (%sp1, %sp2) at (%coord1, %coord2) iter_args(%arg = %init)
+ : (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>)
+ -> index
+ case %it1, _ {
+ // Yield a well-typed value so that the out-of-bound coordinate list is the
+ // only error in this op, independent of verifier ordering.
+ sparse_tensor.yield %arg : index
+ }
+ return %ret : index
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 055709ee69eb7..ab861a2019dfa 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -801,7 +801,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
// CHECK-SAME: %[[VAL_1:.*]]: index,
// CHECK-SAME: %[[VAL_2:.*]]: index) -> index {
// CHECK: %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 : tensor<4x8xf32, #sparse{{[0-9]*}}>
-// CHECK: %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]]) iter_args(%[[VAL_7:.*]] = %[[VAL_1]]) : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> -> (index) {
+// CHECK: %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]]) iter_args(%[[VAL_7:.*]] = %[[VAL_1]]) : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> -> index {
// CHECK: sparse_tensor.yield %[[VAL_7]] : index
// CHECK: }
// CHECK: return %[[VAL_4]] : index
@@ -813,3 +813,36 @@ func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -
}
return %r1 : index
}
+
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+
+// CHECK-LABEL: func.func @sparse_coiteration(
+// CHECK-SAME: %[[SP1:.*]]: !sparse_tensor.iter_space<#sparse, lvls = 0>,
+// CHECK-SAME: %[[SP2:.*]]: !sparse_tensor.iter_space<#sparse, lvls = 1>) -> index {
+// CHECK: %[[INIT:.*]] = arith.constant 0 : index
+// CHECK: %[[RET:.*]] = sparse_tensor.coiterate (%[[SP1]], %[[SP2]]) at(%[[COORD:.*]]) iter_args(%[[ARG:.*]] = %[[INIT]])
+// CHECK: case %[[VAL_6:.*]], _ {
+// CHECK: sparse_tensor.yield %[[ARG]] : index
+// CHECK: }
+// CHECK: return %[[RET]] : index
+// CHECK: }
+// Round-trip of a coiterate with one used coordinate, one loop-carried value,
+// and a single case that defines only the first iterator.
+func.func @sparse_coiteration(%sp1 : !sparse_tensor.iter_space<#COO, lvls = 0>,
+ %sp2 : !sparse_tensor.iter_space<#COO, lvls = 1>) -> index {
+ %init = arith.constant 0 : index
+ %ret = sparse_tensor.coiterate (%sp1, %sp2) at (%coord) iter_args(%arg = %init)
+ : (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>)
+ -> index
+ case %it1, _ {
+ sparse_tensor.yield %arg : index
+ }
+ return %ret : index
+}
More information about the Mlir-commits
mailing list