[Mlir-commits] [mlir] [WIP][mlir][sparse] Setting up sparse_tensor.iterator-related Ops. (PR #85958)
Yinying Li
llvmlistbot at llvm.org
Wed Mar 20 10:37:24 PDT 2024
================
@@ -1912,13 +1912,264 @@ LogicalResult YieldOp::verify() {
auto *parentOp = (*this)->getParentOp();
if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
- isa<ForeachOp>(parentOp))
+ isa<ForeachOp>(parentOp) || isa<IterateOp>(parentOp))
return success();
return emitOpError("expected parent op to be sparse_tensor unary, binary, "
"reduce, select or foreach");
}
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+/// Parses a level range written either as `lo` (shorthand for the single-level
+/// range whose high bound is `lo + 1`) or as `lo to hi`, and materializes both
+/// bounds as index-typed IntegerAttrs.
+static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
+                                   IntegerAttr &lvlHiAttr) {
+  Level lo, hi;
+  if (failed(parser.parseInteger(lo)))
+    return failure();
+
+  // Without an explicit `to` bound the range covers exactly one level.
+  hi = lo + 1;
+  if (succeeded(parser.parseOptionalKeyword("to")) &&
+      failed(parser.parseInteger(hi)))
+    return failure();
+
+  Type indexTy = parser.getBuilder().getIndexType();
+  lvlLoAttr = IntegerAttr::get(indexTy, lo);
+  lvlHiAttr = IntegerAttr::get(indexTy, hi);
+  return success();
+}
+
+/// Prints the level range of an ExtractIterSpaceOp: the short form `lo` when
+/// the high bound is exactly `lo + 1`, and `lo to hi` otherwise. The bounds are
+/// read from the op's accessors; the attribute parameters are not consulted.
+static void printLevelRange(OpAsmPrinter &p, ExtractIterSpaceOp op,
+                            IntegerAttr lvlLo, IntegerAttr lvlHi) {
+  const auto lo = op.getLoLvl();
+  const auto hi = op.getHiLvl();
+  p << lo;
+  if (hi != lo + 1)
+    p << " to " << hi;
+}
+
+/// Infers the result type of the op: an IterSpaceType carrying the tensor's
+/// level types for levels [loLvl, hiLvl).
+///
+/// Type inference may run before the op's verify() (e.g. from the parser or a
+/// builder), so the level range is validated here first. Otherwise a reversed
+/// range would underflow `hi - lo`, and an out-of-rank range would reach
+/// ArrayRef::slice, whose bounds are only checked by an assertion.
+LogicalResult ExtractIterSpaceOp::inferReturnTypes(
+    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
+    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
+    SmallVectorImpl<mlir::Type> &ret) {
+  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
+  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
+  Level lo = adaptor.getLoLvl();
+  Level hi = adaptor.getHiLvl();
+  if (lo >= hi || hi > stt.getLvlRank())
+    return emitOptionalError(loc, "invalid level range [", lo, ", ", hi,
+                             ") for tensor of level rank ", stt.getLvlRank());
+
+  ArrayRef<LevelType> lts = stt.getLvlTypes().slice(lo, hi - lo);
+  ret.push_back(IterSpaceType::get(ctx, lts));
+  return success();
+}
+
+/// Verifies the op:
+///  - the level range [loLvl, hiLvl) is non-empty and lies within the tensor's
+///    level rank;
+///  - the result iteration space carries the tensor's level types for exactly
+///    that range;
+///  - a parent iterator is present exactly when loLvl != 0, and its level types
+///    match the tensor levels immediately preceding loLvl.
+LogicalResult ExtractIterSpaceOp::verify() {
+  SparseTensorType stt = getSparseTensorType(getTensor());
+  if (getLoLvl() >= getHiLvl())
+    return emitOpError("expected smaller level low than level high");
+
+  // Diagnose out-of-range levels instead of tripping the bounds assertion in
+  // ArrayRef::slice below.
+  if (getHiLvl() > stt.getLvlRank())
+    return emitOpError("expected level high to not exceed the tensor level "
+                       "rank of ")
+           << stt.getLvlRank();
+
+  // Same slice that inferReturnTypes uses to build the result type.
+  ArrayRef<LevelType> lts =
+      stt.getLvlTypes().slice(getLoLvl(), getHiLvl() - getLoLvl());
+  if (!getResultSpace().getType().getLvlTypes().equals(lts)) {
+    return emitOpError(
+        "mismatch in iteration space level types and tensor level types");
+  }
+
+  // A parent iterator is required iff the range does not start at level 0.
+  TypedValue<IteratorType> pIter = getParentIter();
+  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
+    return emitOpError(
+        "parent iterator should be specified iff level low is not 0");
+  }
+
+  if (pIter) {
+    // The parent's levels must be the tensor levels directly before loLvl.
+    unsigned pDim = pIter.getType().getSpaceDim();
+    if (getLoLvl() < pDim || !stt.getLvlTypes()
+                                  .slice(getLoLvl() - pDim, pDim)
+                                  .equals(pIter.getType().getLvlTypes())) {
+      return emitOpError(
+          "mismatch in parent iterator level types and tensor level types");
+    }
+  }
+
+  return success();
+}
+
+ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::Argument iterator;
+ OpAsmParser::UnresolvedOperand iterSpace;
+
+ // Parses %iters in %spaces
----------------
yinying-lisa-li wrote:
nit: use "Parse" (imperative mood) and end the comment with a period.
https://github.com/llvm/llvm-project/pull/85958
More information about the Mlir-commits
mailing list