[Mlir-commits] [mlir] [mlir][sparse] introduce `sparse_tensor.iterate` operation (PR #88955)
Aart Bik
llvmlistbot at llvm.org
Wed Apr 17 14:42:44 PDT 2024
================
@@ -2063,6 +2163,161 @@ LogicalResult ExtractIterSpaceOp::verify() {
return success();
}
+/// Custom assembly parser for IterateOp.
+///
+/// The loop header is handled by parseSparseSpaceLoop, which fills `iters` and
+/// `iterArgs`. Exactly one entry is accepted in `iters`; anything else is
+/// reported as a parse error. The body region is then parsed with `iters`
+/// followed by `iterArgs` as its entry arguments, a terminator is ensured, and
+/// an optional attribute dictionary is consumed.
+ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
+  SmallVector<OpAsmParser::Argument> iters, iterArgs;
+  if (parseSparseSpaceLoop(parser, result, iters, iterArgs))
+    return failure();
+  if (iters.size() != 1)
+    return parser.emitError(parser.getNameLoc(),
+                            "expected only one iterator/iteration space");
+
+  // The region's entry arguments are `iters` followed by `iterArgs`.
+  iters.append(iterArgs);
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, iters))
+    return failure();
+
+  IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+/// Emits an initialization list shaped like
+///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
+/// pairing each entry of `blocksArgs` ('inner', a region argument) with the
+/// entry of `initializers` at the same position ('outer', a regular SSA
+/// value). Nothing is emitted when `initializers` is empty.
+static void printInitializationList(OpAsmPrinter &p,
+                                    Block::BlockArgListType blocksArgs,
+                                    ValueRange initializers,
+                                    StringRef prefix = "") {
+  assert(blocksArgs.size() == initializers.size() &&
+         "expected same length of arguments and initializers");
+  if (initializers.empty())
+    return;
+
+  p << prefix << '(';
+  bool needsSeparator = false;
+  for (auto [inner, outer] : llvm::zip(blocksArgs, initializers)) {
+    // Separate consecutive pairs with a comma, never a trailing one.
+    if (needsSeparator)
+      p << ", ";
+    needsSeparator = true;
+    p << inner << " = " << outer;
+  }
+  p << ")";
+}
+
+/// Emits the ` at(...)` clause with one comma-separated entry per level in
+/// [0, spaceDim): the next unconsumed value of `blocksArgs` when the level is
+/// set in `crdUsedLvls`, and `_` otherwise. Nothing is emitted when
+/// `crdUsedLvls` is empty. All of `blocksArgs` must be consumed by the end.
+static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim,
+                              Block::BlockArgListType blocksArgs,
+                              LevelSet crdUsedLvls) {
+  if (crdUsedLvls.empty())
+    return;
+
+  p << " at(";
+  for (unsigned lvl = 0; lvl < spaceDim; ++lvl) {
+    // Put the separator in front of every entry except the first.
+    if (lvl != 0)
+      p << ", ";
+    if (!crdUsedLvls[lvl]) {
+      p << "_";
+      continue;
+    }
+    p << blocksArgs.front();
+    blocksArgs = blocksArgs.drop_front();
+  }
+  assert(blocksArgs.empty());
+  p << ")";
+}
+
+/// Custom assembly printer for IterateOp: the iterator and iteration space,
+/// the optional ` at(...)` and ` iter_args(...)` clauses, the iteration space
+/// type, the init-arg types (only when there are init args), and finally the
+/// body region without its entry block arguments.
+void IterateOp::print(OpAsmPrinter &p) {
+  auto iterSpace = getIterSpace();
+  auto initArgs = getInitArgs();
+  const bool hasInitArgs = !initArgs.empty();
+
+  p << " " << getIterator() << " in " << iterSpace;
+  printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
+  printInitializationList(p, getRegionIterArgs(), initArgs, " iter_args");
+
+  p << " : " << iterSpace.getType() << " ";
+  if (hasInitArgs)
+    p << "-> (" << initArgs.getTypes() << ") ";
+
+  // The terminator is printed only when there are init args.
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/hasInitArgs);
+}
+
+/// Checks that the op has exactly as many init args as results.
+LogicalResult IterateOp::verify() {
+  if (getInitArgs().size() == getNumResults())
+    return success();
+
+  return emitOpError(
+      "mismatch in number of loop-carried values and defined values");
+}
+
+/// Region-level verification. The iterator's type must equal the iterator
+/// type reported by the iteration space type, and the init args, region iter
+/// args, yielded values and results must agree in number and, position by
+/// position, in type.
+LogicalResult IterateOp::verifyRegions() {
+  if (getIterator().getType() != getIterSpace().getType().getIteratorType())
+    return emitOpError("mismatch in iterator and iteration space type");
+  if (getNumRegionIterArgs() != getNumResults())
+    return emitOpError(
+        "mismatch in number of basic block args and defined values");
+
+  auto initArgs = getInitArgs();
+  auto iterArgs = getRegionIterArgs();
+  auto yieldVals = getYieldedValues();
+  auto opResults = getResults();
+  const bool sameCount = llvm::all_equal(
+      {initArgs.size(), iterArgs.size(), yieldVals.size(), opResults.size()});
+  if (!sameCount)
+    return emitOpError() << "number mismatch between iter args and results.";
+
+  // All four ranges have the same length here, so one index walks them all.
+  for (size_t idx = 0, numResults = opResults.size(); idx != numResults;
+       ++idx) {
+    auto resultTy = opResults[idx].getType();
+    if (initArgs[idx].getType() != resultTy)
+      return emitOpError() << "types mismatch between " << idx
+                           << "th iter operand and defined value";
+    if (iterArgs[idx].getType() != resultTy)
+      return emitOpError() << "types mismatch between " << idx
+                           << "th iter region arg and defined value";
+    if (yieldVals[idx].getType() != resultTy)
+      return emitOpError() << "types mismatch between " << idx
+                           << "th yield value and defined value";
+  }
+
+  return success();
+}
+
+/// IterateOp implemented OpInterfaces' methods.
----------------
aartbik wrote:
why past tense?
https://github.com/llvm/llvm-project/pull/88955
More information about the Mlir-commits
mailing list