[Mlir-commits] [mlir] 91f6a5f - [MLIR][SPIRV] Control attributes support for loop and selection

George Mitenkov llvmlistbot at llvm.org
Sun Aug 2 23:58:24 PDT 2020


Author: George Mitenkov
Date: 2020-08-03T09:31:37+03:00
New Revision: 91f6a5f7854a542611ed76442acb1ec375a9feb2

URL: https://github.com/llvm/llvm-project/commit/91f6a5f7854a542611ed76442acb1ec375a9feb2
DIFF: https://github.com/llvm/llvm-project/commit/91f6a5f7854a542611ed76442acb1ec375a9feb2.diff

LOG: [MLIR][SPIRV] Control attributes support for loop and selection

This patch handles loopControl and selectionControl in parsing and
printing. To reuse the functionality, and to avoid cases where the `{` of
the region is parsed as a dictionary attribute, a `control` keyword was
introduced. `None` is the default control attribute. This functionality can
later be extended to `spv.func`.
Also, loopControl and selectionControl can now be (de)serialized.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D84175

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
    mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
    mlir/test/Dialect/SPIRV/Serialization/loop.mlir
    mlir/test/Dialect/SPIRV/Serialization/selection.mlir
    mlir/test/Dialect/SPIRV/control-flow-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index bac65a02f63d..b7d36f4a9487 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -36,6 +36,7 @@ static constexpr const char kSourceAlignmentAttrName[] = "source_alignment";
 static constexpr const char kBranchWeightAttrName[] = "branch_weights";
 static constexpr const char kCallee[] = "callee";
 static constexpr const char kClusterSize[] = "cluster_size";
+static constexpr const char kControl[] = "control";
 static constexpr const char kDefaultValueAttrName[] = "default_value";
 static constexpr const char kExecutionScopeAttrName[] = "execution_scope";
 static constexpr const char kEqualSemanticsAttrName[] = "equal_semantics";
@@ -161,6 +162,25 @@ parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
   return success();
 }
 
+/// Parses the optional `control` `(` enum-keyword `)` clause of Function,
+/// Selection and Loop ops into `attrName`; "None" is used when it is absent.
+template <typename EnumClass>
+static ParseResult
+parseControlAttribute(OpAsmParser &parser, OperationState &state,
+                      StringRef attrName = spirv::attributeName<EnumClass>()) {
+  if (succeeded(parser.parseOptionalKeyword(kControl))) {
+    EnumClass control;
+    // Forward `attrName` so both branches store under the same attribute.
+    return failure(parser.parseLParen() ||
+                   parseEnumKeywordAttr(control, parser, state, attrName) ||
+                   parser.parseRParen());
+  }
+  // No `control` clause: default to "None" under the same attribute name.
+  state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
+                                   static_cast<uint32_t>(EnumClass::None)));
+  return success();
+}
+
 /// Parses optional memory access attributes attached to a memory access
 /// operand/pointer. Specifically, parses the following syntax:
 ///     (`[` memory-access `]`)?
@@ -2082,12 +2102,8 @@ void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
 }
 
 static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &state) {
-  // TODO: support loop control properly
-  Builder builder = parser.getBuilder();
-  state.addAttribute("loop_control",
-                     builder.getI32IntegerAttr(
-                         static_cast<uint32_t>(spirv::LoopControl::None)));
-
+  if (parseControlAttribute<spirv::LoopControl>(parser, state))
+    return failure();
   return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
                             /*argTypes=*/{});
 }
@@ -2096,6 +2112,9 @@ static void print(spirv::LoopOp loopOp, OpAsmPrinter &printer) {
   auto *op = loopOp.getOperation();
 
   printer << spirv::LoopOp::getOperationName();
+  auto control = loopOp.loop_control();
+  if (control != spirv::LoopControl::None)
+    printer << " control(" << spirv::stringifyLoopControl(control) << ")";
   printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
                       /*printBlockTerminators=*/true);
 }
@@ -2445,12 +2464,8 @@ static LogicalResult verify(spirv::SelectOp op) {
 
 static ParseResult parseSelectionOp(OpAsmParser &parser,
                                     OperationState &state) {
-  // TODO: support selection control properly
-  Builder builder = parser.getBuilder();
-  state.addAttribute("selection_control",
-                     builder.getI32IntegerAttr(
-                         static_cast<uint32_t>(spirv::SelectionControl::None)));
-
+  if (parseControlAttribute<spirv::SelectionControl>(parser, state))
+    return failure();
   return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
                             /*argTypes=*/{});
 }
@@ -2459,6 +2474,9 @@ static void print(spirv::SelectionOp selectionOp, OpAsmPrinter &printer) {
   auto *op = selectionOp.getOperation();
 
   printer << spirv::SelectionOp::getOperationName();
+  auto control = selectionOp.selection_control();
+  if (control != spirv::SelectionControl::None)
+    printer << " control(" << spirv::stringifySelectionControl(control) << ")";
   printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
                       /*printBlockTerminators=*/true);
 }

diff  --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 4ba3f16feef0..eaa8f4d94833 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -64,11 +64,14 @@ struct BlockMergeInfo {
   Block *mergeBlock;
   Block *continueBlock; // nullptr for spv.selection
   Location loc;
-
-  BlockMergeInfo(Location location)
-      : mergeBlock(nullptr), continueBlock(nullptr), loc(location) {}
-  BlockMergeInfo(Location location, Block *m, Block *c = nullptr)
-      : mergeBlock(m), continueBlock(c), loc(location) {}
+  uint32_t control;
+
+  BlockMergeInfo(Location location, uint32_t control)
+      : mergeBlock(nullptr), continueBlock(nullptr), loc(location),
+        control(control) {}
+  BlockMergeInfo(Location location, uint32_t control, Block *m,
+                 Block *c = nullptr)
+      : mergeBlock(m), continueBlock(c), loc(location), control(control) {}
 };
 
 /// A struct for containing OpLine instruction information.
@@ -1681,16 +1684,12 @@ LogicalResult Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
         "OpSelectionMerge must specify merge target and selection control");
   }
 
-  if (static_cast<uint32_t>(spirv::SelectionControl::None) != operands[1]) {
-    return emitError(unknownLoc,
-                     "unimplmented OpSelectionMerge selection control: ")
-           << operands[2];
-  }
-
   auto *mergeBlock = getOrCreateBlock(operands[0]);
   auto loc = createFileLineColLoc(opBuilder);
+  auto selectionControl = operands[1];
 
-  if (!blockMergeInfo.try_emplace(curBlock, loc, mergeBlock).second) {
+  if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
+           .second) {
     return emitError(
         unknownLoc,
         "a block cannot have more than one OpSelectionMerge instruction");
@@ -1709,16 +1708,13 @@ LogicalResult Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
                                  "continue target and loop control");
   }
 
-  if (static_cast<uint32_t>(spirv::LoopControl::None) != operands[2]) {
-    return emitError(unknownLoc, "unimplmented OpLoopMerge loop control: ")
-           << operands[2];
-  }
-
   auto *mergeBlock = getOrCreateBlock(operands[0]);
   auto *continueBlock = getOrCreateBlock(operands[1]);
   auto loc = createFileLineColLoc(opBuilder);
+  uint32_t loopControl = operands[2];
 
-  if (!blockMergeInfo.try_emplace(curBlock, loc, mergeBlock, continueBlock)
+  if (!blockMergeInfo
+           .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
            .second) {
     return emitError(
         unknownLoc,
@@ -1771,25 +1767,27 @@ class ControlFlowStructurizer {
   /// the `headerBlock` will be redirected to the `mergeBlock`.
   /// This method will also update `mergeInfo` by remapping all blocks inside to
   /// the newly cloned ones inside structured control flow op's regions.
-  static LogicalResult structurize(Location loc, BlockMergeInfoMap &mergeInfo,
+  static LogicalResult structurize(Location loc, uint32_t control,
+                                   BlockMergeInfoMap &mergeInfo,
                                    Block *headerBlock, Block *mergeBlock,
                                    Block *continueBlock) {
-    return ControlFlowStructurizer(loc, mergeInfo, headerBlock, mergeBlock,
-                                   continueBlock)
+    return ControlFlowStructurizer(loc, control, mergeInfo, headerBlock,
+                                   mergeBlock, continueBlock)
         .structurizeImpl();
   }
 
 private:
-  ControlFlowStructurizer(Location loc, BlockMergeInfoMap &mergeInfo,
-                          Block *header, Block *merge, Block *cont)
-      : location(loc), blockMergeInfo(mergeInfo), headerBlock(header),
-        mergeBlock(merge), continueBlock(cont) {}
+  ControlFlowStructurizer(Location loc, uint32_t control,
+                          BlockMergeInfoMap &mergeInfo, Block *header,
+                          Block *merge, Block *cont)
+      : location(loc), control(control), blockMergeInfo(mergeInfo),
+        headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
 
   /// Creates a new spv.selection op at the beginning of the `mergeBlock`.
-  spirv::SelectionOp createSelectionOp();
+  spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
 
   /// Creates a new spv.loop op at the beginning of the `mergeBlock`.
-  spirv::LoopOp createLoopOp();
+  spirv::LoopOp createLoopOp(uint32_t loopControl);
 
   /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
   void collectBlocksInConstruct();
@@ -1797,6 +1795,7 @@ class ControlFlowStructurizer {
   LogicalResult structurizeImpl();
 
   Location location;
+  uint32_t control;
 
   BlockMergeInfoMap &blockMergeInfo;
 
@@ -1808,26 +1807,26 @@ class ControlFlowStructurizer {
 };
 } // namespace
 
-spirv::SelectionOp ControlFlowStructurizer::createSelectionOp() {
+spirv::SelectionOp
+ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
   // Create a builder and set the insertion point to the beginning of the
   // merge block so that the newly created SelectionOp will be inserted there.
   OpBuilder builder(&mergeBlock->front());
 
-  auto control = builder.getI32IntegerAttr(
-      static_cast<uint32_t>(spirv::SelectionControl::None));
+  auto control = builder.getI32IntegerAttr(selectionControl);
   auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
   selectionOp.addMergeBlock();
 
   return selectionOp;
 }
 
-spirv::LoopOp ControlFlowStructurizer::createLoopOp() {
+spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
   // Create a builder and set the insertion point to the beginning of the
   // merge block so that the newly created LoopOp will be inserted there.
   OpBuilder builder(&mergeBlock->front());
 
-  // TODO: handle loop control properly
-  auto loopOp = builder.create<spirv::LoopOp>(location);
+  auto control = builder.getI32IntegerAttr(loopControl);
+  auto loopOp = builder.create<spirv::LoopOp>(location, control);
   loopOp.addEntryAndMergeBlock();
 
   return loopOp;
@@ -1852,10 +1851,10 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
   Operation *op = nullptr;
   bool isLoop = continueBlock != nullptr;
   if (isLoop) {
-    if (auto loopOp = createLoopOp())
+    if (auto loopOp = createLoopOp(control))
       op = loopOp.getOperation();
   } else {
-    if (auto selectionOp = createSelectionOp())
+    if (auto selectionOp = createSelectionOp(control))
       op = selectionOp.getOperation();
   }
   if (!op)
@@ -1992,7 +1991,8 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
       // The iterator should be erased before adding a new entry into
       // blockMergeInfo to avoid iterator invalidation.
       blockMergeInfo.erase(it);
-      blockMergeInfo.try_emplace(newHeader, loc, newMerge, newContinue);
+      blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
+                                 newContinue);
     }
 
     // The structured selection/loop's entry block does not have arguments.
@@ -2096,9 +2096,9 @@ LogicalResult Deserializer::structurizeControlFlow() {
     // Erase this case before calling into structurizer, who will update
     // blockMergeInfo.
     blockMergeInfo.erase(blockMergeInfo.begin());
-    if (failed(ControlFlowStructurizer::structurize(mergeInfo.loc,
-                                                    blockMergeInfo, headerBlock,
-                                                    mergeBlock, continueBlock)))
+    if (failed(ControlFlowStructurizer::structurize(
+            mergeInfo.loc, mergeInfo.control, blockMergeInfo, headerBlock,
+            mergeBlock, continueBlock)))
       return failure();
   }
 

diff  --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index 859ea556f39f..223adf47ab2e 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -1573,10 +1573,9 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
   auto emitSelectionMerge = [&]() {
     emitDebugLine(functionBody, loc);
     lastProcessedWasMergeInst = true;
-    // TODO: properly support selection control here
     encodeInstructionInto(
         functionBody, spirv::Opcode::OpSelectionMerge,
-        {mergeID, static_cast<uint32_t>(spirv::SelectionControl::None)});
+        {mergeID, static_cast<uint32_t>(selectionOp.selection_control())});
   };
   // For structured selection, we cannot have blocks in the selection construct
   // branching to the selection header block. Entering the selection (and
@@ -1636,10 +1635,9 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
   auto emitLoopMerge = [&]() {
     emitDebugLine(functionBody, loc);
     lastProcessedWasMergeInst = true;
-    // TODO: properly support loop control here
     encodeInstructionInto(
         functionBody, spirv::Opcode::OpLoopMerge,
-        {mergeID, continueID, static_cast<uint32_t>(spirv::LoopControl::None)});
+        {mergeID, continueID, static_cast<uint32_t>(loopOp.loop_control())});
   };
   if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
     return failure();

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir
index d6e2090f02bb..8f0b35ef6fc8 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir
@@ -119,8 +119,8 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
 
 // CHECK:        spv.Branch ^bb1
 // CHECK-NEXT: ^bb1:
-// CHECK-NEXT:   spv.loop
-    spv.loop {
+// CHECK-NEXT:   spv.loop control(Unroll)
+    spv.loop control(Unroll) {
 // CHECK-NEXT:     spv.Branch ^bb1
       spv.Branch ^header
 
@@ -140,8 +140,8 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
       spv.Store "Function" %jvar, %zero : i32
 // CHECK-NEXT:     spv.Branch ^bb3
 // CHECK-NEXT:   ^bb3:
-// CHECK-NEXT:     spv.loop
-      spv.loop {
+// CHECK-NEXT:     spv.loop control(DontUnroll)
+      spv.loop control(DontUnroll) {
 // CHECK-NEXT:       spv.Branch ^bb1
         spv.Branch ^header
 

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/selection.mlir b/mlir/test/Dialect/SPIRV/Serialization/selection.mlir
index e391bae5b486..9e6ff5698f8a 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/selection.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/selection.mlir
@@ -11,10 +11,10 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     %two = spv.constant 2: i32
     %var = spv.Variable init(%zero) : !spv.ptr<i32, Function>
 
-// CHECK-NEXT:   spv.selection {
+// CHECK-NEXT:   spv.selection control(Flatten)
 // CHECK-NEXT:     spv.constant 0
 // CHECK-NEXT:     spv.Variable
-    spv.selection {
+    spv.selection control(Flatten) {
 // CHECK-NEXT: spv.BranchConditional %{{.*}} [5, 10], ^bb1, ^bb2
       spv.BranchConditional %cond [5, 10], ^then, ^else
 

diff  --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
index 97ee02d45f24..267c45b080fa 100644
--- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
@@ -317,6 +317,16 @@ func @empty_region() -> () {
 
 // -----
 
+// CHECK-LABEL: @loop_with_control
+func @loop_with_control() -> () {
+  // CHECK: spv.loop control(Unroll)
+  spv.loop control(Unroll) {
+  }
+  return
+}
+
+// -----
+
 func @wrong_merge_block() -> () {
   // expected-error @+1 {{last block must be the merge block with only one 'spv._merge' op}}
   spv.loop {
@@ -718,6 +728,16 @@ func @empty_region() -> () {
 
 // -----
 
+// CHECK-LABEL: @selection_with_control
+func @selection_with_control() -> () {
+  // CHECK: spv.selection control(Flatten)
+  spv.selection control(Flatten) {
+  }
+  return
+}
+
+// -----
+
 func @wrong_merge_block() -> () {
   // expected-error @+1 {{last block must be the merge block with only one 'spv._merge' op}}
   spv.selection {


        


More information about the Mlir-commits mailing list