[Mlir-commits] [mlir] e2b7161 - [MLIR] Add argument related API to Region

Rahul Joshi llvmlistbot at llvm.org
Tue Jul 14 09:28:40 PDT 2020


Author: Rahul Joshi
Date: 2020-07-14T09:28:29-07:00
New Revision: e2b716105be33bc1296b4a0c56f8cfc2e8595037

URL: https://github.com/llvm/llvm-project/commit/e2b716105be33bc1296b4a0c56f8cfc2e8595037
DIFF: https://github.com/llvm/llvm-project/commit/e2b716105be33bc1296b4a0c56f8cfc2e8595037.diff

LOG: [MLIR] Add argument related API to Region

- Arguments of the first (entry) block of a region are considered region arguments.
- Add an API on the Region class to access these arguments directly instead of
  going through the front() block.
- Change several existing call sites to use this API.
- Fixes https://bugs.llvm.org/show_bug.cgi?id=46535

Differential Revision: https://reviews.llvm.org/D83599

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/IR/FunctionSupport.h
    mlir/include/mlir/IR/Region.h
    mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
    mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/FunctionImplementation.cpp
    mlir/lib/IR/Operation.cpp
    mlir/lib/IR/Region.cpp
    mlir/lib/Transforms/SCCP.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index e7e67e24381d..c0a6ac101d7b 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -237,7 +237,7 @@ def GPU_GPUFuncOp : GPU_Op<"func", [HasParent<"GPUModuleOp">,
     /// the workgroup memory
     ArrayRef<BlockArgument> getWorkgroupAttributions() {
       auto begin =
-          std::next(getBody().front().args_begin(), getType().getNumInputs());
+          std::next(getBody().args_begin(), getType().getNumInputs());
       auto end = std::next(begin, getNumWorkgroupAttributions());
       return {begin, end};
     }
@@ -248,7 +248,7 @@ def GPU_GPUFuncOp : GPU_Op<"func", [HasParent<"GPUModuleOp">,
 
     /// Returns the number of buffers located in the private memory.
     unsigned getNumPrivateAttributions() {
-      return getBody().front().getNumArguments() - getType().getNumInputs() -
+      return getBody().getNumArguments() - getType().getNumInputs() -
           getNumWorkgroupAttributions();
     }
  
@@ -258,9 +258,9 @@ def GPU_GPUFuncOp : GPU_Op<"func", [HasParent<"GPUModuleOp">,
       // Buffers on the private memory always come after buffers on the workgroup
       // memory.
       auto begin =
-          std::next(getBody().front().args_begin(),
+          std::next(getBody().args_begin(),
                     getType().getNumInputs() + getNumWorkgroupAttributions());
-      return {begin, getBody().front().args_end()};
+      return {begin, getBody().args_end()};
     }
 
     /// Adds a new block argument that corresponds to buffers located in

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index b34dac4f38a7..452546f5da83 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -583,7 +583,7 @@ def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [
   let extraClassDeclaration = [{
     // The value stored in memref[ivs].
     Value getCurrentValue() {
-      return body().front().getArgument(0);
+      return body().getArgument(0);
     }
     MemRefType getMemRefType() {
       return memref().getType().cast<MemRefType>();

diff  --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
index 87e4b6780164..b358215ca962 100644
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -216,15 +216,13 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
   }
 
   /// Gets argument.
-  BlockArgument getArgument(unsigned idx) {
-    return getBlocks().front().getArgument(idx);
-  }
+  BlockArgument getArgument(unsigned idx) { return getBody().getArgument(idx); }
 
   /// Support argument iteration.
-  using args_iterator = Block::args_iterator;
-  args_iterator args_begin() { return front().args_begin(); }
-  args_iterator args_end() { return front().args_end(); }
-  Block::BlockArgListType getArguments() { return front().getArguments(); }
+  using args_iterator = Region::args_iterator;
+  args_iterator args_begin() { return getBody().args_begin(); }
+  args_iterator args_end() { return getBody().args_end(); }
+  Block::BlockArgListType getArguments() { return getBody().getArguments(); }
 
   //===--------------------------------------------------------------------===//
   // Argument Attributes

diff  --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 35e773d74385..5671f2b5581e 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -16,6 +16,9 @@
 #include "mlir/IR/Block.h"
 
 namespace mlir {
+class TypeRange;
+template <typename ValueRangeT>
+class ValueTypeRange;
 class BlockAndValueMapping;
 
 /// This class contains a list of basic blocks and a link to the parent
@@ -62,6 +65,48 @@ class Region {
     return &Region::blocks;
   }
 
+  //===--------------------------------------------------------------------===//
+  // Argument Handling
+  //===--------------------------------------------------------------------===//
+
+  // Arguments of the region's entry block; empty if the region has no blocks.
+  using BlockArgListType = MutableArrayRef<BlockArgument>;
+  BlockArgListType getArguments() {
+    return empty() ? BlockArgListType() : front().getArguments();
+  }
+  using args_iterator = BlockArgListType::iterator;
+  using reverse_args_iterator = BlockArgListType::reverse_iterator;
+  args_iterator args_begin() { return getArguments().begin(); }
+  args_iterator args_end() { return getArguments().end(); }
+  reverse_args_iterator args_rbegin() { return getArguments().rbegin(); }
+  reverse_args_iterator args_rend() { return getArguments().rend(); }
+
+  bool args_empty() { return getArguments().empty(); }
+
+  /// Add one value to the argument list. The region must not be empty.
+  BlockArgument addArgument(Type type) { return front().addArgument(type); }
+
+  /// Insert one value to the position in the argument list indicated by the
+  /// given iterator. The existing arguments are shifted. The entry block is
+  /// expected not to have predecessors.
+  BlockArgument insertArgument(args_iterator it, Type type) {
+    return front().insertArgument(it, type);
+  }
+
+  /// Add one argument to the argument list for each type specified in the list.
+  iterator_range<args_iterator> addArguments(TypeRange types);
+
+  /// Add one value to the argument list at the specified position.
+  BlockArgument insertArgument(unsigned index, Type type) {
+    return front().insertArgument(index, type);
+  }
+
+  /// Erase the argument at 'index' and remove it from the argument list.
+  void eraseArgument(unsigned index) { front().eraseArgument(index); }
+
+  unsigned getNumArguments() { return getArguments().size(); }
+  BlockArgument getArgument(unsigned i) { return getArguments()[i]; }
+
   //===--------------------------------------------------------------------===//
   // Operation list utilities
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index 0a657e5387b2..b1d5a854de80 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -417,8 +417,8 @@ static LogicalResult processParallelLoop(
 
     if (isMappedToProcessor(processor)) {
       // Use the corresponding thread/grid index as replacement for the loop iv.
-      Value operand = launchOp.body().front().getArgument(
-          getLaunchOpArgumentNum(processor));
+      Value operand =
+          launchOp.body().getArgument(getLaunchOpArgumentNum(processor));
       // Take the indexmap and add the lower bound and step computations in.
       // This computes operand * step + lowerBound.
       // Use an affine map here so that it composes nicely with the provided

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index fd0c6245e084..dd8200d3687b 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -127,9 +127,9 @@ static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
     return allReduce.emitError(
         "expected either an op attribute or a non-empty body");
   if (!allReduce.body().empty()) {
-    if (allReduce.body().front().getNumArguments() != 2)
+    if (allReduce.body().getNumArguments() != 2)
       return allReduce.emitError("expected two region arguments");
-    for (auto argument : allReduce.body().front().getArguments()) {
+    for (auto argument : allReduce.body().getArguments()) {
       if (argument.getType() != allReduce.getType())
         return allReduce.emitError("incorrect region argument type");
     }
@@ -219,25 +219,25 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
 
 KernelDim3 LaunchOp::getBlockIds() {
   assert(!body().empty() && "LaunchOp body must not be empty.");
-  auto args = body().front().getArguments();
+  auto args = body().getArguments();
   return KernelDim3{args[0], args[1], args[2]};
 }
 
 KernelDim3 LaunchOp::getThreadIds() {
   assert(!body().empty() && "LaunchOp body must not be empty.");
-  auto args = body().front().getArguments();
+  auto args = body().getArguments();
   return KernelDim3{args[3], args[4], args[5]};
 }
 
 KernelDim3 LaunchOp::getGridSize() {
   assert(!body().empty() && "LaunchOp body must not be empty.");
-  auto args = body().front().getArguments();
+  auto args = body().getArguments();
   return KernelDim3{args[6], args[7], args[8]};
 }
 
 KernelDim3 LaunchOp::getBlockSize() {
   assert(!body().empty() && "LaunchOp body must not be empty.");
-  auto args = body().getBlocks().front().getArguments();
+  auto args = body().getArguments();
   return KernelDim3{args[9], args[10], args[11]};
 }
 
@@ -254,8 +254,7 @@ static LogicalResult verify(LaunchOp op) {
   // sizes and transforms them into kNumConfigRegionAttributes region arguments
   // for block/thread identifiers and grid/block sizes.
   if (!op.body().empty()) {
-    Block &entryBlock = op.body().front();
-    if (entryBlock.getNumArguments() !=
+    if (op.body().getNumArguments() !=
         LaunchOp::kNumConfigOperands + op.getNumOperands())
       return op.emitOpError("unexpected number of region arguments");
   }
@@ -463,8 +462,8 @@ BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type) {
   auto attrName = getNumWorkgroupAttributionsAttrName();
   auto attr = getAttrOfType<IntegerAttr>(attrName);
   setAttr(attrName, IntegerAttr::get(attr.getType(), attr.getValue() + 1));
-  return getBody().front().insertArgument(
-      getType().getNumInputs() + attr.getInt(), type);
+  return getBody().insertArgument(getType().getNumInputs() + attr.getInt(),
+                                  type);
 }
 
 /// Adds a new block argument that corresponds to buffers located in
@@ -472,7 +471,7 @@ BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type) {
 BlockArgument GPUFuncOp::addPrivateAttribution(Type type) {
   // Buffers on the private memory always come after buffers on the workgroup
   // memory.
-  return getBody().front().addArgument(type);
+  return getBody().addArgument(type);
 }
 
 void GPUFuncOp::build(OpBuilder &builder, OperationState &result,

diff  --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index 6f6f1c27241c..38df9ef99154 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -181,8 +181,8 @@ struct GpuAllReduceRewriter {
 
       // Insert accumulator body between split block.
       BlockAndValueMapping mapping;
-      mapping.map(body.front().getArgument(0), lhs);
-      mapping.map(body.front().getArgument(1), rhs);
+      mapping.map(body.getArgument(0), lhs);
+      mapping.map(body.getArgument(1), rhs);
       rewriter.cloneRegionBefore(body, *split->getParent(),
                                  split->getIterator(), mapping);
 

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index 47440265239d..a2659d6a0eec 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -1102,8 +1102,7 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
                                                      unsigned argIndex,
                                                      NamedAttribute attribute) {
   return verifyRegionAttribute(
-      op->getLoc(),
-      op->getRegion(regionIndex).front().getArgument(argIndex).getType(),
+      op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(),
       attribute);
 }
 

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 3e71c48f0871..84c35c9fb7a5 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -525,22 +525,21 @@ void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
 
     Region *bodyRegion = result.addRegion();
     bodyRegion->push_back(new Block());
-    bodyRegion->front().addArgument(elementType);
+    bodyRegion->addArgument(elementType);
   }
 }
 
 static LogicalResult verify(GenericAtomicRMWOp op) {
-  auto &block = op.body().front();
-  if (block.getNumArguments() != 1)
+  auto &body = op.body();
+  if (body.getNumArguments() != 1)
     return op.emitOpError("expected single number of entry block arguments");
 
-  if (op.getResult().getType() != block.getArgument(0).getType())
+  if (op.getResult().getType() != body.getArgument(0).getType())
     return op.emitOpError(
         "expected block argument of the same type result type");
 
   bool hasSideEffects =
-      op.body()
-          .walk([&](Operation *nestedOp) {
+      body.walk([&](Operation *nestedOp) {
             if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
               return WalkResult::advance();
             nestedOp->emitError("body of 'generic_atomic_rmw' should contain "

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 09135021a732..372e4c93dc37 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -619,7 +619,7 @@ unsigned SSANameState::getBlockID(Block *block) {
 
 void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
   assert(!region.empty() && "cannot shadow arguments of an empty region");
-  assert(region.front().getNumArguments() == namesToUse.size() &&
+  assert(region.getNumArguments() == namesToUse.size() &&
          "incorrect number of names passed in");
   assert(region.getParentOp()->isKnownIsolatedFromAbove() &&
          "only KnownIsolatedFromAbove ops can shadow names");
@@ -629,7 +629,7 @@ void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
     auto nameToUse = namesToUse[i];
     if (nameToUse == nullptr)
       continue;
-    auto nameToReplace = region.front().getArgument(i);
+    auto nameToReplace = region.getArgument(i);
 
     nameStr.clear();
     llvm::raw_svector_ostream nameStream(nameStr);

diff  --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp
index 8b90ff13244f..13aee344bbdc 100644
--- a/mlir/lib/IR/FunctionImplementation.cpp
+++ b/mlir/lib/IR/FunctionImplementation.cpp
@@ -238,7 +238,7 @@ void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op,
       p << ", ";
 
     if (!isExternal) {
-      p.printOperand(body.front().getArgument(i));
+      p.printOperand(body.getArgument(i));
       p << ": ";
     }
 

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 1e2a47639fdb..8feab8ec903a 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1022,7 +1022,7 @@ LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) {
     if (region.empty())
       continue;
 
-    if (region.front().getNumArguments() != 0) {
+    if (region.getNumArguments() != 0) {
       if (op->getNumRegions() > 1)
         return op->emitOpError("region #")
                << region.getRegionNumber() << " should have no arguments";

diff  --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp
index aa2acc00dde4..b616eaa15422 100644
--- a/mlir/lib/IR/Region.cpp
+++ b/mlir/lib/IR/Region.cpp
@@ -33,6 +33,11 @@ Location Region::getLoc() {
   return container->getLoc();
 }
 
+/// Add one argument to the argument list for each type specified in the list.
+iterator_range<Region::args_iterator> Region::addArguments(TypeRange types) {
+  return front().addArguments(types);
+}
+
 Region *Region::getParentRegion() {
   assert(container && "region is not attached to a container");
   return container->getParentRegion();

diff  --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
index 25115fc8ffbd..95b035ee68cd 100644
--- a/mlir/lib/Transforms/SCCP.cpp
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -123,7 +123,7 @@ class CallableLatticeState {
   /// Build a lattice state with a given callable region, and a specified number
   /// of results to be initialized to the default lattice value (Unknown).
   CallableLatticeState(Region *callableRegion, unsigned numResults)
-      : callableArguments(callableRegion->front().getArguments()),
+      : callableArguments(callableRegion->getArguments()),
         resultLatticeValues(numResults) {}
 
   /// Returns the arguments to the callable region.
@@ -403,7 +403,7 @@ void SCCPSolver::initializeSymbolCallables(Operation *op) {
       // If not all of the uses of this symbol are visible, we can't track the
       // state of the arguments.
       if (symbol.isPublic() || (!allUsesVisible && symbol.isNested()))
-        markAllOverdefined(callableRegion->front().getArguments());
+        markAllOverdefined(callableRegion->getArguments());
     }
     if (callableLatticeState.empty())
       return;

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index c471cd3ead3e..255b1c152a36 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -284,7 +284,7 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
                   ConversionPatternRewriter &rewriter) const final {
     auto illegalOp =
         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
-    rewriter.replaceUsesOfBlockArgument(op->getRegion(0).front().getArgument(0),
+    rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
                                         illegalOp);
     rewriter.updateRootInPlace(op, [] {});
     return success();


        


More information about the Mlir-commits mailing list