[Mlir-commits] [mlir] [mlir][sparse] allow YieldOp to yield multiple values. (PR #87261)
Peiming Liu
llvmlistbot at llvm.org
Mon Apr 1 09:53:35 PDT 2024
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/87261
None
>From ff870f6774ccaf534dff450b34c22821ddd2af91 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 1 Apr 2024 16:52:38 +0000
Subject: [PATCH] [mlir][sparse] allow YieldOp to yield multiple values.
---
.../SparseTensor/IR/SparseTensorOps.td | 22 ++++++++++++++-----
.../SparseTensor/IR/SparseTensorDialect.cpp | 18 ++++-----------
.../Transforms/SparseReinterpretMap.cpp | 5 +++--
.../lib/Dialect/SparseTensor/Utils/Merger.cpp | 4 ++--
4 files changed, 26 insertions(+), 23 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 29cf8c32447ecf..29e5ac749e9fa3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1278,8 +1278,10 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
let hasVerifier = 1;
}
-def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
- Arguments<(ins Optional<AnyType>:$result)> {
+def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
+ ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
+ "ForeachOp"]>]>,
+ Arguments<(ins Variadic<AnyType>:$results)> {
let summary = "Yield from sparse_tensor set-like operations";
let description = [{
Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1302,14 +1304,24 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
let builders = [
OpBuilder<(ins),
[{
- build($_builder, $_state, Value());
+ build($_builder, $_state, ValueRange());
+ }]>,
+ OpBuilder<(ins "Value":$yieldVal),
+ [{
+ build($_builder, $_state, ValueRange(yieldVal));
}]>
];
+ let extraClassDeclaration = [{
+ Value getSingleResult() {
+ assert(getResults().size() == 1);
+ return getResults().front();
+ }
+ }];
+
let assemblyFormat = [{
- $result attr-dict `:` type($result)
+ $results attr-dict `:` type($results)
}];
- let hasVerifier = 1;
}
def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6da51bb6b9cacf..3f385d24daf23f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1591,7 +1591,8 @@ static LogicalResult verifyNumBlockArgs(T *op, Region &region,
if (!yield)
return op->emitError() << regionName
<< " region must end with sparse_tensor.yield";
- if (!yield.getResult() || yield.getResult().getType() != outputType)
+ if (yield.getResults().size() != 1 ||
+ yield.getSingleResult().getType() != outputType)
return op->emitError() << regionName << " region yield type mismatch";
return success();
@@ -1654,7 +1655,8 @@ LogicalResult UnaryOp::verify() {
// Absent branch can only yield invariant values.
Block *absentBlock = &absent.front();
Block *parent = getOperation()->getBlock();
- Value absentVal = cast<YieldOp>(absentBlock->getTerminator()).getResult();
+ Value absentVal =
+ cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
if (arg.getOwner() == parent)
return emitError("absent region cannot yield linalg argument");
@@ -1907,18 +1909,6 @@ LogicalResult SortOp::verify() {
return success();
}
-LogicalResult YieldOp::verify() {
- // Check for compatible parent.
- auto *parentOp = (*this)->getParentOp();
- if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
- isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
- isa<ForeachOp>(parentOp))
- return success();
-
- return emitOpError("expected parent op to be sparse_tensor unary, binary, "
- "reduce, select or foreach");
-}
-
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 14ea07f0b54b82..9c0fc60877d8a3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -764,9 +764,10 @@ struct ForeachOpDemapper
if (numInitArgs != 0) {
rewriter.setInsertionPointToEnd(body);
auto yield = llvm::cast<YieldOp>(body->getTerminator());
- if (auto stt = tryGetSparseTensorType(yield.getResult());
+ if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
stt && !stt->isIdentity()) {
- Value y = genDemap(rewriter, stt->getEncoding(), yield.getResult());
+ Value y =
+ genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
rewriter.create<YieldOp>(loc, y);
rewriter.eraseOp(yield);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 72b722c69ae34b..9c0aed3c18eff2 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -1031,7 +1031,7 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
// invariant on the right.
Block &absentBlock = absentRegion.front();
YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
- const Value absentVal = absentYield.getResult();
+ const Value absentVal = absentYield.getSingleResult();
const ExprId rhs = addInvariantExp(absentVal);
return disjSet(e, child0, buildLattices(rhs, i), unop);
}
@@ -1500,7 +1500,7 @@ static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
// Merge cloned block and return yield value.
Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
- Value val = clonedYield.getResult();
+ Value val = clonedYield.getSingleResult();
rewriter.eraseOp(clonedYield);
rewriter.eraseOp(placeholder);
return val;
More information about the Mlir-commits
mailing list