[Mlir-commits] [mlir] a54930e - [mlir][sparse] allow YieldOp to yield multiple values. (#87261)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 1 10:30:39 PDT 2024


Author: Peiming Liu
Date: 2024-04-01T10:30:36-07:00
New Revision: a54930e696a275ac3947484f44d770cd587ce147

URL: https://github.com/llvm/llvm-project/commit/a54930e696a275ac3947484f44d770cd587ce147
DIFF: https://github.com/llvm/llvm-project/commit/a54930e696a275ac3947484f44d770cd587ce147.diff

LOG: [mlir][sparse] allow YieldOp to yield multiple values. (#87261)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 29cf8c32447ecf..5df8a176459b7c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1278,8 +1278,10 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
   let hasVerifier = 1;
 }
 
-def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
-    Arguments<(ins Optional<AnyType>:$result)> {
+def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
+    ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
+                 "ForeachOp"]>]>,
+    Arguments<(ins Variadic<AnyType>:$results)> {
   let summary = "Yield from sparse_tensor set-like operations";
   let description = [{
       Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1302,14 +1304,27 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
   let builders = [
     OpBuilder<(ins),
     [{
-      build($_builder, $_state, Value());
+      build($_builder, $_state, ValueRange());
+    }]>,
+    OpBuilder<(ins "Value":$yieldVal),
+    [{
+      build($_builder, $_state, ValueRange(yieldVal));
     }]>
   ];
 
+  let extraClassDeclaration = [{
+     Value getSingleResult() {
+        assert(hasSingleResult());
+        return getResults().front();
+     }
+     bool hasSingleResult() {
+        return getResults().size() == 1;
+     }
+  }];
+
   let assemblyFormat = [{
-        $result attr-dict `:` type($result)
+        $results attr-dict `:` type($results)
   }];
-  let hasVerifier = 1;
 }
 
 def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6da51bb6b9cacf..e4d93c5623b9c4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1591,7 +1591,8 @@ static LogicalResult verifyNumBlockArgs(T *op, Region &region,
   if (!yield)
     return op->emitError() << regionName
                            << " region must end with sparse_tensor.yield";
-  if (!yield.getResult() || yield.getResult().getType() != outputType)
+  if (!yield.hasSingleResult() ||
+      yield.getSingleResult().getType() != outputType)
     return op->emitError() << regionName << " region yield type mismatch";
 
   return success();
@@ -1654,7 +1655,8 @@ LogicalResult UnaryOp::verify() {
     // Absent branch can only yield invariant values.
     Block *absentBlock = &absent.front();
     Block *parent = getOperation()->getBlock();
-    Value absentVal = cast<YieldOp>(absentBlock->getTerminator()).getResult();
+    Value absentVal =
+        cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
     if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
       if (arg.getOwner() == parent)
         return emitError("absent region cannot yield linalg argument");
@@ -1907,18 +1909,6 @@ LogicalResult SortOp::verify() {
   return success();
 }
 
-LogicalResult YieldOp::verify() {
-  // Check for compatible parent.
-  auto *parentOp = (*this)->getParentOp();
-  if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
-      isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
-      isa<ForeachOp>(parentOp))
-    return success();
-
-  return emitOpError("expected parent op to be sparse_tensor unary, binary, "
-                     "reduce, select or foreach");
-}
-
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 14ea07f0b54b82..9c0fc60877d8a3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -764,9 +764,10 @@ struct ForeachOpDemapper
     if (numInitArgs != 0) {
       rewriter.setInsertionPointToEnd(body);
       auto yield = llvm::cast<YieldOp>(body->getTerminator());
-      if (auto stt = tryGetSparseTensorType(yield.getResult());
+      if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
           stt && !stt->isIdentity()) {
-        Value y = genDemap(rewriter, stt->getEncoding(), yield.getResult());
+        Value y =
+            genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
         rewriter.create<YieldOp>(loc, y);
         rewriter.eraseOp(yield);
       }

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 72b722c69ae34b..9c0aed3c18eff2 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -1031,7 +1031,7 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
       // invariant on the right.
       Block &absentBlock = absentRegion.front();
       YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
-      const Value absentVal = absentYield.getResult();
+      const Value absentVal = absentYield.getSingleResult();
       const ExprId rhs = addInvariantExp(absentVal);
       return disjSet(e, child0, buildLattices(rhs, i), unop);
     }
@@ -1500,7 +1500,7 @@ static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
   // Merge cloned block and return yield value.
   Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
-  Value val = clonedYield.getResult();
+  Value val = clonedYield.getSingleResult();
   rewriter.eraseOp(clonedYield);
   rewriter.eraseOp(placeholder);
   return val;


        


More information about the Mlir-commits mailing list