[Mlir-commits] [mlir] [mlir][sparse] allow YieldOp to yield multiple values. (PR #87261)

Peiming Liu llvmlistbot at llvm.org
Mon Apr 1 10:13:31 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/87261

>From ff870f6774ccaf534dff450b34c22821ddd2af91 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 1 Apr 2024 16:52:38 +0000
Subject: [PATCH 1/2] [mlir][sparse] allow YieldOp to yield multiple values.

---
 .../SparseTensor/IR/SparseTensorOps.td        | 22 ++++++++++++++-----
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 18 ++++-----------
 .../Transforms/SparseReinterpretMap.cpp       |  5 +++--
 .../lib/Dialect/SparseTensor/Utils/Merger.cpp |  4 ++--
 4 files changed, 26 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 29cf8c32447ecf..29e5ac749e9fa3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1278,8 +1278,10 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
   let hasVerifier = 1;
 }
 
-def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
-    Arguments<(ins Optional<AnyType>:$result)> {
+def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
+    ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
+                 "ForeachOp"]>]>,
+    Arguments<(ins Variadic<AnyType>:$results)> {
   let summary = "Yield from sparse_tensor set-like operations";
   let description = [{
       Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1302,14 +1304,24 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
   let builders = [
     OpBuilder<(ins),
     [{
-      build($_builder, $_state, Value());
+      build($_builder, $_state, ValueRange());
+    }]>,
+    OpBuilder<(ins "Value":$yieldVal),
+    [{
+      build($_builder, $_state, ValueRange(yieldVal));
     }]>
   ];
 
+  let extraClassDeclaration = [{
+     Value getSingleResult() {
+        assert(getResults().size() == 1);
+        return getResults().front();
+     }
+  }];
+
   let assemblyFormat = [{
-        $result attr-dict `:` type($result)
+        $results attr-dict `:` type($results)
   }];
-  let hasVerifier = 1;
 }
 
 def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6da51bb6b9cacf..3f385d24daf23f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1591,7 +1591,8 @@ static LogicalResult verifyNumBlockArgs(T *op, Region &region,
   if (!yield)
     return op->emitError() << regionName
                            << " region must end with sparse_tensor.yield";
-  if (!yield.getResult() || yield.getResult().getType() != outputType)
+  if (!yield.getSingleResult() ||
+      yield.getSingleResult().getType() != outputType)
     return op->emitError() << regionName << " region yield type mismatch";
 
   return success();
@@ -1654,7 +1655,8 @@ LogicalResult UnaryOp::verify() {
     // Absent branch can only yield invariant values.
     Block *absentBlock = &absent.front();
     Block *parent = getOperation()->getBlock();
-    Value absentVal = cast<YieldOp>(absentBlock->getTerminator()).getResult();
+    Value absentVal =
+        cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
     if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
       if (arg.getOwner() == parent)
         return emitError("absent region cannot yield linalg argument");
@@ -1907,18 +1909,6 @@ LogicalResult SortOp::verify() {
   return success();
 }
 
-LogicalResult YieldOp::verify() {
-  // Check for compatible parent.
-  auto *parentOp = (*this)->getParentOp();
-  if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
-      isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
-      isa<ForeachOp>(parentOp))
-    return success();
-
-  return emitOpError("expected parent op to be sparse_tensor unary, binary, "
-                     "reduce, select or foreach");
-}
-
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 14ea07f0b54b82..9c0fc60877d8a3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -764,9 +764,10 @@ struct ForeachOpDemapper
     if (numInitArgs != 0) {
       rewriter.setInsertionPointToEnd(body);
       auto yield = llvm::cast<YieldOp>(body->getTerminator());
-      if (auto stt = tryGetSparseTensorType(yield.getResult());
+      if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
           stt && !stt->isIdentity()) {
-        Value y = genDemap(rewriter, stt->getEncoding(), yield.getResult());
+        Value y =
+            genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
         rewriter.create<YieldOp>(loc, y);
         rewriter.eraseOp(yield);
       }
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 72b722c69ae34b..9c0aed3c18eff2 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -1031,7 +1031,7 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
       // invariant on the right.
       Block &absentBlock = absentRegion.front();
       YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
-      const Value absentVal = absentYield.getResult();
+      const Value absentVal = absentYield.getSingleResult();
       const ExprId rhs = addInvariantExp(absentVal);
       return disjSet(e, child0, buildLattices(rhs, i), unop);
     }
@@ -1500,7 +1500,7 @@ static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
   // Merge cloned block and return yield value.
   Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
-  Value val = clonedYield.getResult();
+  Value val = clonedYield.getSingleResult();
   rewriter.eraseOp(clonedYield);
   rewriter.eraseOp(placeholder);
   return val;

>From bc6b8d2b46f18909802caa0b2b3d29f95d078a0c Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 1 Apr 2024 17:12:55 +0000
Subject: [PATCH 2/2] address comments (add hasSingleResult)

---
 mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td | 3 +++
 mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp     | 2 +-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 29e5ac749e9fa3..15e3ae68f5f742 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1317,6 +1317,9 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
         assert(getResults().size() == 1);
         return getResults().front();
      }
+     bool hasSingleResult() {
+        return getResults().size() == 1;
+     }
   }];
 
   let assemblyFormat = [{
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 3f385d24daf23f..e4d93c5623b9c4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1591,7 +1591,7 @@ static LogicalResult verifyNumBlockArgs(T *op, Region &region,
   if (!yield)
     return op->emitError() << regionName
                            << " region must end with sparse_tensor.yield";
-  if (!yield.getSingleResult() ||
+  if (!yield.hasSingleResult() ||
       yield.getSingleResult().getType() != outputType)
     return op->emitError() << regionName << " region yield type mismatch";
 



More information about the Mlir-commits mailing list