[Mlir-commits] [mlir] [WIP][mlir][sparse] Setting up sparse_tensor.iterator-related Ops. (PR #85958)
Ingo Müller
llvmlistbot at llvm.org
Fri Mar 22 07:19:11 PDT 2024
================
@@ -1912,13 +1912,264 @@ LogicalResult YieldOp::verify() {
auto *parentOp = (*this)->getParentOp();
+ // Keep this list of accepted parent ops in sync with the diagnostic below.
if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
- isa<ForeachOp>(parentOp))
+ isa<ForeachOp>(parentOp) || isa<IterateOp>(parentOp))
return success();
- return emitOpError("expected parent op to be sparse_tensor unary, binary, "
- "reduce, select or foreach");
+ return emitOpError("expected parent op to be sparse_tensor unary, binary, "
+ "reduce, select, foreach or iterate");
}
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+/// Parses a level range written either as `<lo>` or as `<lo> to <hi>`.
+/// The single-integer form denotes the one-level range [lo, lo + 1).
+/// Both bounds are returned as index-typed IntegerAttrs.
+static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
+                                   IntegerAttr &lvlHiAttr) {
+  Level lo;
+  if (parser.parseInteger(lo))
+    return failure();
+
+  // Default to a single level; overwritten when an explicit `to <hi>` follows.
+  Level hi = lo + 1;
+  if (succeeded(parser.parseOptionalKeyword("to")) && parser.parseInteger(hi))
+    return failure();
+
+  IndexType indexTp = parser.getBuilder().getIndexType();
+  lvlLoAttr = IntegerAttr::get(indexTp, lo);
+  lvlHiAttr = IntegerAttr::get(indexTp, hi);
+  return success();
+}
+
+/// Prints the level range of `op` in the form read by parseLevelRange:
+/// `lo` alone when the range spans exactly one level, `lo to hi` otherwise.
+/// The attribute parameters are unused; the bounds are read from `op`.
+static void printLevelRange(OpAsmPrinter &p, ExtractIterSpaceOp op,
+                            IntegerAttr lvlLo, IntegerAttr lvlHi) {
+  p << op.getLoLvl();
+  if (op.getLoLvl() + 1 != op.getHiLvl())
+    p << " to " << op.getHiLvl();
+}
+
+/// Infers the result iteration-space type from the level types of the source
+/// tensor restricted to the level range [loLvl, hiLvl).
+LogicalResult ExtractIterSpaceOp::inferReturnTypes(
+    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
+    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
+    SmallVectorImpl<mlir::Type> &ret) {
+  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
+  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
+  Level lo = adaptor.getLoLvl();
+  Level hi = adaptor.getHiLvl();
+
+  // Type inference runs while parsing/building, before verify(). Reject
+  // ranges that would underflow `hi - lo` or run past the level rank; those
+  // would otherwise trip the bounds assertion in ArrayRef::slice.
+  if (lo >= hi)
+    return emitOptionalError(loc, "expected smaller level low than level high");
+  if (hi > stt.getLvlRank())
+    return emitOptionalError(
+        loc, "expected level high to not exceed the tensor level rank");
+
+  ArrayRef<LevelType> lts = stt.getLvlTypes().slice(lo, hi - lo);
+  ret.push_back(IterSpaceType::get(ctx, lts));
+  return success();
+}
+
+/// Verifies that the level range is well formed and fits the tensor, that the
+/// result type matches the tensor's level types over that range, and that a
+/// parent iterator is present exactly when the range does not start at 0.
+LogicalResult ExtractIterSpaceOp::verify() {
+  if (getLoLvl() >= getHiLvl())
+    return emitOpError("expected smaller level low than level high");
+
+  SparseTensorType stt = getSparseTensorType(getTensor());
+  // Guard the slices below; ArrayRef::slice asserts on out-of-range bounds.
+  if (getHiLvl() > stt.getLvlRank())
+    return emitOpError(
+        "expected level high to not exceed the tensor level rank");
+
+  ArrayRef<LevelType> lts = stt.getLvlTypes().slice(getLoLvl(), getSpaceDim());
+  if (!getResultSpace().getType().getLvlTypes().equals(lts)) {
+    return emitOpError(
+        "mismatch in iteration space level types and tensor level types");
+  }
+
+  // A parent iterator is required iff the range does not start at level 0.
+  // The message states exactly the condition that is checked here.
+  TypedValue<IteratorType> pIter = getParentIter();
+  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
+    return emitOpError(
+        "parent iterator should be specified iff level low does not equal 0");
+  }
+
+  if (pIter) {
+    // The parent iterator must cover the `pDim` levels immediately preceding
+    // `loLvl`; the first comparison also prevents `loLvl - pDim` underflow.
+    unsigned pDim = pIter.getType().getSpaceDim();
+    if (getLoLvl() < pDim || !stt.getLvlTypes()
+                                  .slice(getLoLvl() - pDim, pDim)
+                                  .equals(pIter.getType().getLvlTypes())) {
+      return emitOpError(
+          "mismatch in parent iterator level types and tensor level types");
+    }
+  }
+
+  return success();
+}
+
+/// Parses
+///   %it in %space [iter_args(%arg = %init, ...)]
+///     : !sparse_tensor.iter_space<...> [-> (types)] { region } [attr-dict]
+/// The region's block arguments are the iterator followed by the iter_args.
+ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::Argument iterator;
+  OpAsmParser::UnresolvedOperand iterSpace;
+
+  // Parses "%iter in %space".
+  if (parser.parseArgument(iterator) || parser.parseKeyword("in") ||
+      parser.parseOperand(iterSpace)) {
+    return failure();
+  }
+
+  // Region arguments start with the iterator, followed by the optional
+  // user-provided iter_args.
+  SmallVector<OpAsmParser::Argument> regionArgs;
+  SmallVector<OpAsmParser::UnresolvedOperand> operands;
+  regionArgs.push_back(iterator);
+  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+  if (hasIterArgs && parser.parseAssignmentList(regionArgs, operands))
+    return failure();
+
+  // Parses ": !sparse_tensor.iter_space<...>" and, only when iter_args are
+  // present, "-> (result types)". parseType<IterSpaceType> emits a diagnostic
+  // for any other type instead of asserting in an unchecked cast.
+  IterSpaceType iterSpaceTp;
+  if (parser.parseColon() || parser.parseType(iterSpaceTp))
+    return failure();
+  if (hasIterArgs && parser.parseArrowTypeList(result.types))
+    return failure();
+
+  if (regionArgs.size() != result.types.size() + 1) {
+    return parser.emitError(
+        parser.getNameLoc(),
+        "mismatch in number of loop-carried values and defined values");
+  }
+
+  // Resolves input operands: the iteration space first, then the init values.
+  if (parser.resolveOperand(iterSpace, iterSpaceTp, result.operands))
+    return failure();
+
+  // Each iter_arg block argument takes the type of its matching result. The
+  // zip elements are references, so assigning `arg.type` updates regionArgs.
+  for (auto [arg, operand, type] :
+       llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
+    arg.type = type;
+    if (parser.resolveOperand(operand, type, result.operands))
+      return failure();
+  }
+
+  Region *body = result.addRegion();
+  regionArgs.front().type = iterSpaceTp.getIteratorType();
+  if (parser.parseRegion(*body, regionArgs))
+    return failure();
+
+  IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+/// Prints `<prefix>(%inner = %outer, %inner2 = %outer2, ...)`, pairing each
+/// region argument in `blocksArgs` with the SSA value at the same position in
+/// `initializers`. Nothing at all (not even the prefix) is printed when
+/// `initializers` is empty.
+static void printInitializationList(OpAsmPrinter &p,
+                                    Block::BlockArgListType blocksArgs,
+                                    ValueRange initializers,
+                                    StringRef prefix = "") {
+  assert(blocksArgs.size() == initializers.size() &&
+         "expected same length of arguments and initializers");
+  if (initializers.empty())
+    return;
+
+  p << prefix << "(";
+  for (unsigned i = 0, e = initializers.size(); i < e; ++i) {
+    if (i != 0)
+      p << ", ";
+    p << blocksArgs[i] << " = " << initializers[i];
+  }
+  p << ')';
+}
+
+/// Prints the custom form read back by IterateOp::parse, including the
+/// trailing optional attribute dictionary.
+void IterateOp::print(OpAsmPrinter &p) {
+  p << " " << getIterator() << " in " << getIterSpace();
+
+  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
+
+  p << " : " << getIterSpace().getType() << " ";
+  if (!getInitArgs().empty())
+    p << "-> (" << getInitArgs().getTypes() << ") ";
+
+  // Block arguments are already printed in the header above. The terminator
+  // is elided when there are no loop-carried values; parse re-inserts it via
+  // ensureTerminator.
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/!getInitArgs().empty());
+
+  // parse accepts a trailing attribute dictionary after the region; print it
+  // so that attributes survive a print/parse round trip.
+  p.printOptionalAttrDict((*this)->getAttrs());
+}
+
+/// Checks that every loop-carried init value has exactly one corresponding
+/// op result.
+LogicalResult IterateOp::verify() {
+  if (getInitArgs().size() == getNumResults())
+    return success();
+  return emitOpError(
+      "mismatch in number of loop-carried values and defined values");
+}
+
+LogicalResult IterateOp::verifyRegions() {
+ if (getIterator().getType() != getIterSpace().getType().getIteratorType())
+ return emitOpError("mismatch in iterator and iteration space type");
+ if (getNumRegionIterArgs() != getNumResults())
+ return emitOpError(
+ "mismatch in number of basic block args and defined values");
+
+ auto initArgs = getInitArgs();
----------------
ingomueller-net wrote:
Here and below: I think the [LLVM coding style](https://llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable) says that this type should be spelled out. (Here, I guess the type is an `ArrayRef<...>`, which is neither "obvious from the context" nor "abstracted away anyways"...)
https://github.com/llvm/llvm-project/pull/85958
More information about the Mlir-commits
mailing list