[Mlir-commits] [mlir] 296e97a - [MLIR] Support for return values in Affine.For yield

Uday Bondhugula llvmlistbot at llvm.org
Thu Sep 17 11:05:40 PDT 2020


Author: Abhishek Varma
Date: 2020-09-17T23:34:59+05:30
New Revision: 296e97ae8f7183c2f8737b9e6e68df4904dbfadf

URL: https://github.com/llvm/llvm-project/commit/296e97ae8f7183c2f8737b9e6e68df4904dbfadf
DIFF: https://github.com/llvm/llvm-project/commit/296e97ae8f7183c2f8737b9e6e68df4904dbfadf.diff

LOG: [MLIR] Support for return values in Affine.For yield

Add support for return values in affine.for yield along the same lines
as scf.for and affine.parallel.

Signed-off-by: Abhishek Varma <abhishek.varma at polymagelabs.com>

Differential Revision: https://reviews.llvm.org/D87437

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/EDSC/Builders.h
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
    mlir/lib/Dialect/Affine/EDSC/Builders.cpp
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/test/Dialect/Affine/invalid.mlir
    mlir/test/Dialect/Affine/ops.mlir
    mlir/test/EDSC/builder-api-test.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/EDSC/Builders.h b/mlir/include/mlir/Dialect/Affine/EDSC/Builders.h
index 96191e01296a..d99f29f3b5ba 100644
--- a/mlir/include/mlir/Dialect/Affine/EDSC/Builders.h
+++ b/mlir/include/mlir/Dialect/Affine/EDSC/Builders.h
@@ -47,6 +47,18 @@ void affineLoopNestBuilder(
 void affineLoopBuilder(ValueRange lbs, ValueRange ubs, int64_t step,
                        function_ref<void(Value)> bodyBuilderFn = nullptr);
 
+/// Creates a single affine "for" loop, iterating from max(lbs) to min(ubs) with
+/// the given step. Uses the OpBuilder and Location stored in ScopedContext and
+/// assumes they are non-null. "iterArgs" specifies the initial values of the
+/// results the affine "for" yields. The optional "bodyBuilderFn"
+/// callback is called to construct the body of the loop and is passed the
+/// induction variable and the iteration arguments. The function is expected to
+/// use the builder and location stored in ScopedContext at the moment of the
+/// call. The function will create the affine terminator op in case "iterArgs"
+/// is empty and "bodyBuilderFn" is not present.
+void affineLoopBuilder(
+    ValueRange lbs, ValueRange ubs, int64_t step, ValueRange iterArgs,
+    function_ref<void(Value, ValueRange)> bodyBuilderFn = nullptr);
 namespace op {
 
 Value operator+(Value lhs, Value rhs);

diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 480e1717c588..88c4a6fda7f4 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -174,30 +174,74 @@ def AffineForOp : Affine_Op<"for",
       return
     }
     ```
+    `affine.for` can also operate on loop-carried variables and return the final
+    values after loop termination. The initial values of the variables are
+    passed as additional SSA operands to the `affine.for`, following the two
+    loop control values (lower bound and upper bound). The operation region has
+    a corresponding argument for each variable, representing the value of the
+    variable at the current iteration.
+
+    The region must terminate with an `affine.yield` that passes all the current
+    iteration variables to the next iteration, or to the `affine.for` result, if
+    at the last iteration.
+
+    `affine.for` results hold the final values after the last iteration.
+    For example, to sum-reduce a memref:
+
+    ```mlir
+    func @reduce(%buffer: memref<1024xf32>) -> (f32) {
+      // Initial sum set to 0.
+      %sum_0 = constant 0.0 : f32
+      // iter_args binds initial values to the loop's region arguments.
+      %sum = affine.for %i = 0 to 10 step 2
+          iter_args(%sum_iter = %sum_0) -> (f32) {
+        %t = affine.load %buffer[%i] : memref<1024xf32>
+        %sum_next = addf %sum_iter, %t : f32
+        // Yield current iteration sum to next iteration %sum_iter or to %sum
+        // if final iteration.
+        affine.yield %sum_next : f32
+      }
+      return %sum : f32
+    }
+    ```
+    If the `affine.for` defines any values, a yield terminator must be
+    explicitly present. The number and types of the "affine.for" results must
+    match the initial values in the `iter_args` binding and the yield operands.
   }];
   let arguments = (ins Variadic<AnyType>);
+  let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
 
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<"OpBuilder &builder, OperationState &result, "
               "int64_t lowerBound, int64_t upperBound, int64_t step = 1, "
-              "function_ref<void(OpBuilder &, Location, Value)> bodyBuilder "
-              "    = nullptr">,
+              "ValueRange iterArgs = llvm::None, function_ref<void(OpBuilder "
+              "&, Location, Value, ValueRange)> bodyBuilder = nullptr">,
     OpBuilder<"OpBuilder &builder, OperationState &result, "
               "ValueRange lbOperands, AffineMap lbMap, "
               "ValueRange ubOperands, AffineMap ubMap, "
-              "int64_t step = 1, "
-              "function_ref<void(OpBuilder &, Location, Value)> bodyBuilder "
-              "    = nullptr">
+              "int64_t step = 1, ValueRange iterArgs = llvm::None, "
+              "function_ref<void(OpBuilder &, Location, Value, ValueRange)> "
+              "bodyBuilder = nullptr">
   ];
 
   let extraClassDeclaration = [{
+    /// Defining the function type we use for building the body of affine.for.
+    using BodyBuilderFn =
+        function_ref<void(OpBuilder &, Location, Value, ValueRange)>;
+
     static StringRef getStepAttrName() { return "step"; }
     static StringRef getLowerBoundAttrName() { return "lower_bound"; }
     static StringRef getUpperBoundAttrName() { return "upper_bound"; }
 
     Value getInductionVar() { return getBody()->getArgument(0); }
+    Block::BlockArgListType getRegionIterArgs() {
+      return getBody()->getArguments().drop_front();
+    }
+    Operation::operand_range getIterOperands() {
+      return getOperands().drop_front(getNumControlOperands());
+    }
 
     // TODO: provide iterators for the lower and upper bound operands
     // if the current access via getLowerBound(), getUpperBound() is too slow.
@@ -251,6 +295,17 @@ def AffineForOp : Affine_Op<"for",
               IntegerAttr::get(IndexType::get(context), step));
     }
 
+    /// Returns number of region arguments for loop-carried values.
+    unsigned getNumRegionIterArgs() {
+      return getBody()->getNumArguments() - 1;
+    }
+
+    /// Number of operands controlling the loop: lb and ub.
+    unsigned getNumControlOperands() { return getOperation()->getNumOperands() - getNumIterOperands(); }
+
+    /// Get the number of loop-carried values.
+    unsigned getNumIterOperands();
+
     /// Returns true if the lower bound is constant.
     bool hasConstantLowerBound();
     /// Returns true if the upper bound is constant.
@@ -540,7 +595,7 @@ def AffineMaxOp : AffineMinMaxOpBase<"max", [NoSideEffect]> {
   }];
 }
 
-def AffineParallelOp : Affine_Op<"parallel", 
+def AffineParallelOp : Affine_Op<"parallel",
     [ImplicitAffineTerminator, RecursiveSideEffects,
      DeclareOpInterfaceMethods<LoopLikeOpInterface>]> {
   let summary = "multi-index parallel band operation";
@@ -569,7 +624,7 @@ def AffineParallelOp : Affine_Op<"parallel",
 
     Note: Calling AffineParallelOp::build will create the required region and
     block, and insert the required terminator if it is trivial (i.e. no values
-    are yielded).  Parsing will also create the required region, block, and 
+    are yielded).  Parsing will also create the required region, block, and
     terminator, even when they are missing from the textual representation.
 
     Example (3x3 valid convolution):

diff  --git a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp
index a96ba970afde..11926d26368b 100644
--- a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp
@@ -47,8 +47,9 @@ void mlir::edsc::affineLoopBuilder(ValueRange lbs, ValueRange ubs, int64_t step,
   // updating the scoped context.
   builder.create<AffineForOp>(
       loc, lbs, builder.getMultiDimIdentityMap(lbs.size()), ubs,
-      builder.getMultiDimIdentityMap(ubs.size()), step,
-      [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv) {
+      builder.getMultiDimIdentityMap(ubs.size()), step, llvm::None,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
+          ValueRange itrArgs) {
         if (bodyBuilderFn) {
           ScopedContext nestedContext(nestedBuilder, nestedLoc);
           OpBuilder::InsertionGuard guard(nestedBuilder);
@@ -58,6 +59,30 @@ void mlir::edsc::affineLoopBuilder(ValueRange lbs, ValueRange ubs, int64_t step,
       });
 }
 
+void mlir::edsc::affineLoopBuilder(
+    ValueRange lbs, ValueRange ubs, int64_t step, ValueRange iterArgs,
+    function_ref<void(Value, ValueRange)> bodyBuilderFn) {
+  // Fetch the builder and location.
+  assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
+  OpBuilder &builder = ScopedContext::getBuilderRef();
+  Location loc = ScopedContext::getLocation();
+
+  // Create the actual loop and call the body builder, if provided, after
+  // updating the scoped context.
+  builder.create<AffineForOp>(
+      loc, lbs, builder.getMultiDimIdentityMap(lbs.size()), ubs,
+      builder.getMultiDimIdentityMap(ubs.size()), step, iterArgs,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
+          ValueRange itrArgs) {
+        if (bodyBuilderFn) {
+          ScopedContext nestedContext(nestedBuilder, nestedLoc);
+          OpBuilder::InsertionGuard guard(nestedBuilder);
+          bodyBuilderFn(iv, itrArgs);
+        } else if (itrArgs.empty())
+          nestedBuilder.create<AffineYieldOp>(nestedLoc);
+      });
+}
+
 static std::pair<AffineExpr, Value>
 categorizeValueByAffineType(MLIRContext *context, Value val, unsigned &numDims,
                             unsigned &numSymbols) {

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index f3473859e88c..440875db3918 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1173,10 +1173,12 @@ LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
 // AffineForOp
 //===----------------------------------------------------------------------===//
 
-void AffineForOp::build(
-    OpBuilder &builder, OperationState &result, ValueRange lbOperands,
-    AffineMap lbMap, ValueRange ubOperands, AffineMap ubMap, int64_t step,
-    function_ref<void(OpBuilder &, Location, Value)> bodyBuilder) {
+/// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
+/// bodyBuilder are empty/null, we include the default terminator op.
+void AffineForOp::build(OpBuilder &builder, OperationState &result,
+                        ValueRange lbOperands, AffineMap lbMap,
+                        ValueRange ubOperands, AffineMap ubMap, int64_t step,
+                        ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
   assert(((!lbMap && lbOperands.empty()) ||
           lbOperands.size() == lbMap.getNumInputs()) &&
          "lower bound operand count does not match the affine map");
@@ -1185,6 +1187,9 @@ void AffineForOp::build(
          "upper bound operand count does not match the affine map");
   assert(step > 0 && "step has to be a positive integer constant");
 
+  for (Value val : iterArgs)
+    result.addTypes(val.getType());
+
   // Add an attribute for the step.
   result.addAttribute(getStepAttrName(),
                       builder.getIntegerAttr(builder.getIndexType(), step));
@@ -1197,56 +1202,75 @@ void AffineForOp::build(
   result.addAttribute(getUpperBoundAttrName(), AffineMapAttr::get(ubMap));
   result.addOperands(ubOperands);
 
+  result.addOperands(iterArgs);
   // Create a region and a block for the body.  The argument of the region is
   // the loop induction variable.
   Region *bodyRegion = result.addRegion();
-  Block *body = new Block;
-  Value inductionVar = body->addArgument(IndexType::get(builder.getContext()));
-  bodyRegion->push_back(body);
-  if (bodyBuilder) {
-    OpBuilder::InsertionGuard guard(builder);
-    builder.setInsertionPointToStart(body);
-    bodyBuilder(builder, result.location, inductionVar);
-  } else {
+  bodyRegion->push_back(new Block);
+  Block &bodyBlock = bodyRegion->front();
+  Value inductionVar = bodyBlock.addArgument(builder.getIndexType());
+  for (Value val : iterArgs)
+    bodyBlock.addArgument(val.getType());
+
+  // Create the default terminator if the builder is not provided and if the
+  // iteration arguments are not provided. Otherwise, leave this to the caller
+  // because we don't know which values to return from the loop.
+  if (iterArgs.empty() && !bodyBuilder) {
     ensureTerminator(*bodyRegion, builder, result.location);
+  } else if (bodyBuilder) {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(&bodyBlock);
+    bodyBuilder(builder, result.location, inductionVar,
+                bodyBlock.getArguments().drop_front());
   }
 }
 
-void AffineForOp::build(
-    OpBuilder &builder, OperationState &result, int64_t lb, int64_t ub,
-    int64_t step,
-    function_ref<void(OpBuilder &, Location, Value)> bodyBuilder) {
+void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
+                        int64_t ub, int64_t step, ValueRange iterArgs,
+                        BodyBuilderFn bodyBuilder) {
   auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
   auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
-  return build(builder, result, {}, lbMap, {}, ubMap, step, bodyBuilder);
+  return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
+               bodyBuilder);
 }
 
 static LogicalResult verify(AffineForOp op) {
   // Check that the body defines as single block argument for the induction
   // variable.
   auto *body = op.getBody();
-  if (body->getNumArguments() != 1 || !body->getArgument(0).getType().isIndex())
+  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
     return op.emitOpError(
         "expected body to have a single index argument for the "
         "induction variable");
 
-  // Verify that there are enough operands for the bounds.
-  AffineMap lowerBoundMap = op.getLowerBoundMap(),
-            upperBoundMap = op.getUpperBoundMap();
-  if (op.getNumOperands() !=
-      (lowerBoundMap.getNumInputs() + upperBoundMap.getNumInputs()))
-    return op.emitOpError(
-        "operand count must match with affine map dimension and symbol count");
-
   // Verify that the bound operands are valid dimension/symbols.
   /// Lower bound.
-  if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(),
-                                           op.getLowerBoundMap().getNumDims())))
-    return failure();
+  if (op.getLowerBoundMap().getNumInputs() > 0)
+    if (failed(
+            verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(),
+                                          op.getLowerBoundMap().getNumDims())))
+      return failure();
   /// Upper bound.
-  if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(),
-                                           op.getUpperBoundMap().getNumDims())))
-    return failure();
+  if (op.getUpperBoundMap().getNumInputs() > 0)
+    if (failed(
+            verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(),
+                                          op.getUpperBoundMap().getNumDims())))
+      return failure();
+
+  unsigned opNumResults = op.getNumResults();
+  if (opNumResults == 0)
+    return success();
+
+  // If ForOp defines values, check that the number and types of the defined
+  // values match ForOp initial iter operands and backedge basic block
+  // arguments.
+  if (op.getNumIterOperands() != opNumResults)
+    return op.emitOpError(
+        "mismatch between the number of loop-carried values and results");
+  if (op.getNumRegionIterArgs() != opNumResults)
+    return op.emitOpError(
+        "mismatch between the number of basic block args and results");
+
   return success();
 }
 
@@ -1375,9 +1399,34 @@ static ParseResult parseAffineForOp(OpAsmParser &parser,
           "expected step to be representable as a positive signed integer");
   }
 
+  // Parse the optional initial iteration arguments.
+  SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
+  SmallVector<Type, 4> argTypes;
+  regionArgs.push_back(inductionVariable);
+
+  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
+    // Parse assignment list and results type list.
+    if (parser.parseAssignmentList(regionArgs, operands) ||
+        parser.parseArrowTypeList(result.types))
+      return failure();
+    // Resolve input operands.
+    for (auto operandType : llvm::zip(operands, result.types))
+      if (parser.resolveOperand(std::get<0>(operandType),
+                                std::get<1>(operandType), result.operands))
+        return failure();
+  }
+  // Induction variable.
+  Type indexType = builder.getIndexType();
+  argTypes.push_back(indexType);
+  // Loop carried variables.
+  argTypes.append(result.types.begin(), result.types.end());
   // Parse the body region.
   Region *body = result.addRegion();
-  if (parser.parseRegion(*body, inductionVariable, builder.getIndexType()))
+  if (regionArgs.size() != argTypes.size())
+    return parser.emitError(
+        parser.getNameLoc(),
+        "mismatch between the number of loop-carried values and results");
+  if (parser.parseRegion(*body, regionArgs, argTypes))
     return failure();
 
   AffineForOp::ensureTerminator(*body, builder, result.location);
@@ -1427,6 +1476,13 @@ static void printBound(AffineMapAttr boundMap,
                         map.getNumDims(), p);
 }
 
+unsigned AffineForOp::getNumIterOperands() {
+  AffineMap lbMap = getLowerBoundMapAttr().getValue();
+  AffineMap ubMap = getUpperBoundMapAttr().getValue();
+
+  return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
+}
+
 static void print(OpAsmPrinter &p, AffineForOp op) {
   p << op.getOperationName() << ' ';
   p.printOperand(op.getBody()->getArgument(0));
@@ -1437,9 +1493,22 @@ static void print(OpAsmPrinter &p, AffineForOp op) {
 
   if (op.getStep() != 1)
     p << " step " << op.getStep();
+
+  bool printBlockTerminators = false;
+  if (op.getNumIterOperands() > 0) {
+    p << " iter_args(";
+    auto regionArgs = op.getRegionIterArgs();
+    auto operands = op.getIterOperands();
+
+    llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
+      p << std::get<0>(it) << " = " << std::get<1>(it);
+    });
+    p << ") -> (" << op.getResultTypes() << ")";
+    printBlockTerminators = true;
+  }
+
   p.printRegion(op.region(),
-                /*printEntryBlockArgs=*/false,
-                /*printBlockTerminators=*/false);
+                /*printEntryBlockArgs=*/false, printBlockTerminators);
   p.printOptionalAttrDict(op.getAttrs(),
                           /*elidedAttrs=*/{op.getLowerBoundAttrName(),
                                            op.getUpperBoundAttrName(),
@@ -1555,8 +1624,8 @@ AffineBound AffineForOp::getLowerBound() {
 AffineBound AffineForOp::getUpperBound() {
   auto lbMap = getLowerBoundMap();
   auto ubMap = getUpperBoundMap();
-  return AffineBound(AffineForOp(*this), lbMap.getNumInputs(), getNumOperands(),
-                     ubMap);
+  return AffineBound(AffineForOp(*this), lbMap.getNumInputs(),
+                     lbMap.getNumInputs() + ubMap.getNumInputs(), ubMap);
 }
 
 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
@@ -1567,6 +1636,8 @@ void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
 
   auto ubOperands = getUpperBoundOperands();
   newOperands.append(ubOperands.begin(), ubOperands.end());
+  auto iterOperands = getIterOperands();
+  newOperands.append(iterOperands.begin(), iterOperands.end());
   getOperation()->setOperands(newOperands);
 
   setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
@@ -1578,6 +1649,8 @@ void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
 
   SmallVector<Value, 4> newOperands(getLowerBoundOperands());
   newOperands.append(ubOperands.begin(), ubOperands.end());
+  auto iterOperands = getIterOperands();
+  newOperands.append(iterOperands.begin(), iterOperands.end());
   getOperation()->setOperands(newOperands);
 
   setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
@@ -1630,7 +1703,9 @@ AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
 }
 
 AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
-  return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
+  return {operand_begin() + getLowerBoundMap().getNumInputs(),
+          operand_begin() + getLowerBoundMap().getNumInputs() +
+              getUpperBoundMap().getNumInputs()};
 }
 
 bool AffineForOp::matchingBoundOperandList() {
@@ -1710,8 +1785,8 @@ static void buildAffineLoopNestImpl(
   ivs.reserve(lbs.size());
   for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
     // Callback for creating the loop body, always creates the terminator.
-    auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc,
-                        Value iv) {
+    auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
+                        ValueRange iterArgs) {
       ivs.push_back(iv);
       // In the innermost loop, call the body builder.
       if (i == e - 1 && bodyBuilderFn) {
@@ -1729,16 +1804,19 @@ static void buildAffineLoopNestImpl(
 }
 
 /// Creates an affine loop from the bounds known to be constants.
-static AffineForOp buildAffineLoopFromConstants(
-    OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step,
-    function_ref<void(OpBuilder &, Location, Value)> bodyBuilderFn) {
-  return builder.create<AffineForOp>(loc, lb, ub, step, bodyBuilderFn);
+static AffineForOp
+buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
+                             int64_t ub, int64_t step,
+                             AffineForOp::BodyBuilderFn bodyBuilderFn) {
+  return builder.create<AffineForOp>(loc, lb, ub, step, /*iterArgs=*/llvm::None,
+                                     bodyBuilderFn);
 }
 
 /// Creates an affine loop from the bounds that may or may not be constants.
-static AffineForOp buildAffineLoopFromValues(
-    OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step,
-    function_ref<void(OpBuilder &, Location, Value)> bodyBuilderFn) {
+static AffineForOp
+buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
+                          int64_t step,
+                          AffineForOp::BodyBuilderFn bodyBuilderFn) {
   auto lbConst = lb.getDefiningOp<ConstantIndexOp>();
   auto ubConst = ub.getDefiningOp<ConstantIndexOp>();
   if (lbConst && ubConst)
@@ -1747,7 +1825,7 @@ static AffineForOp buildAffineLoopFromValues(
                                         bodyBuilderFn);
   return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
                                      builder.getDimIdentityMap(), step,
-                                     bodyBuilderFn);
+                                     /*iterArgs=*/llvm::None, bodyBuilderFn);
 }
 
 void mlir::buildAffineLoopNest(

diff  --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index 4d7c9c23edb6..c38a78060dc6 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -379,3 +379,14 @@ func @affine_if_with_else_region_args(%N: index) {
   return
 }
 
+// -----
+
+func @affine_for_iter_args_mismatch(%buffer: memref<1024xf32>) -> f32 {
+  %sum_0 = constant 0.0 : f32
+  // expected-error@+1 {{mismatch between the number of loop-carried values and results}}
+  %res = affine.for %i = 0 to 10 step 2 iter_args(%sum_iter = %sum_0) -> (f32, f32) {
+    %t = affine.load %buffer[%i] : memref<1024xf32>
+    affine.yield %t : f32
+  }
+  return %res : f32
+}

diff  --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
index cd6086910648..627104bae976 100644
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -184,3 +184,53 @@ func @affine_if() -> f32 {
   // CHECK: return %[[OUT]] : f32
   return %0 : f32
 }
+
+// -----
+
+//  Test affine.for with yield values.
+
+#set = affine_set<(d0): (d0 - 10 >= 0)>
+
+// CHECK-LABEL: func @yield_loop
+func @yield_loop(%buffer: memref<1024xf32>) -> f32 {
+  %sum_init_0 = constant 0.0 : f32
+  %res = affine.for %i = 0 to 10 step 2 iter_args(%sum_iter = %sum_init_0) -> f32 {
+    %t = affine.load %buffer[%i] : memref<1024xf32>
+    %sum_next = affine.if #set(%i) -> (f32) {
+      %new_sum = addf %sum_iter, %t : f32
+      affine.yield %new_sum : f32
+    } else {
+      affine.yield %sum_iter : f32
+    }
+    affine.yield %sum_next : f32
+  }
+  return %res : f32
+}
+// CHECK:      %[[const_0:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[output:.*]] = affine.for %{{.*}} = 0 to 10 step 2 iter_args(%{{.*}} = %[[const_0]]) -> (f32) {
+// CHECK:        affine.if #set0(%{{.*}}) -> f32 {
+// CHECK:          affine.yield %{{.*}} : f32
+// CHECK-NEXT:   } else {
+// CHECK-NEXT:     affine.yield %{{.*}} : f32
+// CHECK-NEXT:   }
+// CHECK-NEXT:   affine.yield %{{.*}} : f32
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[output]] : f32
+
+// CHECK-LABEL: func @affine_for_multiple_yield
+func @affine_for_multiple_yield(%buffer: memref<1024xf32>) -> (f32, f32) {
+  %init_0 = constant 0.0 : f32
+  %res1, %res2 = affine.for %i = 0 to 10 step 2 iter_args(%iter_arg1 = %init_0, %iter_arg2 = %init_0) -> (f32, f32) {
+    %t = affine.load %buffer[%i] : memref<1024xf32>
+    %ret1 = addf %t, %iter_arg1 : f32
+    %ret2 = addf %t, %iter_arg2 : f32
+    affine.yield %ret1, %ret2 : f32, f32
+  }
+  return %res1, %res2 : f32, f32
+}
+// CHECK:      %[[const_0:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[output:[0-9]+]]:2 = affine.for %{{.*}} = 0 to 10 step 2 iter_args(%[[iter_arg1:.*]] = %[[const_0]], %[[iter_arg2:.*]] = %[[const_0]]) -> (f32, f32) {
+// CHECK:        %[[res1:.*]] = addf %{{.*}}, %[[iter_arg1]] : f32
+// CHECK-NEXT:   %[[res2:.*]] = addf %{{.*}}, %[[iter_arg2]] : f32
+// CHECK-NEXT:   affine.yield %[[res1]], %[[res2]] : f32, f32
+// CHECK-NEXT: }

diff  --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index 4695090dacb5..ec22dd04dc4a 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -177,6 +177,38 @@ TEST_FUNC(builder_max_min_for) {
   f.erase();
 }
 
+TEST_FUNC(builder_affine_for_iter_args) {
+  auto indexType = IndexType::get(&globalContext());
+  auto f = makeFunction("builder_affine_for_iter_args", {},
+                        {indexType, indexType, indexType});
+
+  OpBuilder builder(f.getBody());
+  ScopedContext scope(builder, f.getLoc());
+  Value i, lb_1(f.getArgument(0)), ub_1(f.getArgument(1)),
+      ub_2(f.getArgument(2));
+  Value c32(std_constant_int(32, 32));
+  Value c42(std_constant_int(42, 32));
+  using namespace edsc::op;
+  affineLoopBuilder(
+      lb_1, {ub_1, ub_2}, 2, {c32, c42}, [&](Value iv, ValueRange args) {
+        Value sum(args[0] + args[1]);
+        builder.create<AffineYieldOp>(f.getLoc(), ValueRange({args[1], sum}));
+      });
+
+  // clang-format off
+  // CHECK-LABEL: func @builder_affine_for_iter_args
+  // CHECK:       (%[[lb_1:.*]]: index, %[[ub_1:.*]]: index, %[[ub_2:.*]]: index) {
+  // CHECK-NEXT:    %[[c32:.*]] = constant 32 : i32
+  // CHECK-NEXT:    %[[c42:.*]] = constant 42 : i32
+  // CHECK-NEXT:    %{{.*}} = affine.for %{{.*}} = affine_map<(d0) -> (d0)>(%{{.*}}) to min affine_map<(d0, d1) -> (d0, d1)>(%[[ub_1]], %[[ub_2]]) step 2 iter_args(%[[iarg_1:.*]] = %[[c32]], %[[iarg_2:.*]] = %[[c42]]) -> (i32, i32) {
+  // CHECK-NEXT:      %[[sum:.*]] = addi %[[iarg_1]], %[[iarg_2]] : i32
+  // CHECK-NEXT:      affine.yield %[[iarg_2]], %[[sum]] : i32, i32
+  // CHECK-NEXT:    }
+  // clang-format on
+  f.print(llvm::outs());
+  f.erase();
+}
+
 TEST_FUNC(builder_block_append) {
   using namespace edsc::op;
   auto f = makeFunction("builder_blocks");


        


More information about the Mlir-commits mailing list