[Mlir-commits] [mlir] 91f6a5f - [MLIR][SPIRV] Control attributes support for loop and selection
George Mitenkov
llvmlistbot at llvm.org
Sun Aug 2 23:58:24 PDT 2020
Author: George Mitenkov
Date: 2020-08-03T09:31:37+03:00
New Revision: 91f6a5f7854a542611ed76442acb1ec375a9feb2
URL: https://github.com/llvm/llvm-project/commit/91f6a5f7854a542611ed76442acb1ec375a9feb2
DIFF: https://github.com/llvm/llvm-project/commit/91f6a5f7854a542611ed76442acb1ec375a9feb2.diff
LOG: [MLIR][SPIRV] Control attributes support for loop and selection
This patch handles loopControl and selectionControl in parsing and
printing. In order to reuse the functionality, and to avoid cases where the
`{` of the region is parsed as a dictionary attribute, a `control` keyword was
introduced. `None` is the default control attribute. This functionality can
later be extended to `spv.func`.
Also, loopControl and selectionControl can now be (de)serialized.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D84175
Added:
Modified:
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/test/Dialect/SPIRV/Serialization/loop.mlir
mlir/test/Dialect/SPIRV/Serialization/selection.mlir
mlir/test/Dialect/SPIRV/control-flow-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index bac65a02f63d..b7d36f4a9487 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -36,6 +36,7 @@ static constexpr const char kSourceAlignmentAttrName[] = "source_alignment";
static constexpr const char kBranchWeightAttrName[] = "branch_weights";
static constexpr const char kCallee[] = "callee";
static constexpr const char kClusterSize[] = "cluster_size";
+static constexpr const char kControl[] = "control";
static constexpr const char kDefaultValueAttrName[] = "default_value";
static constexpr const char kExecutionScopeAttrName[] = "execution_scope";
static constexpr const char kEqualSemanticsAttrName[] = "equal_semantics";
@@ -161,6 +162,25 @@ parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
return success();
}
+/// Parses an optional `control(<Enumerant>)` clause for Function, Selection
+/// and Loop ops into `attrName`; stores "None" when the clause is absent.
+template <typename EnumClass>
+static ParseResult
+parseControlAttribute(OpAsmParser &parser, OperationState &state,
+ StringRef attrName = spirv::attributeName<EnumClass>()) {
+ if (succeeded(parser.parseOptionalKeyword(kControl))) {
+ // Forward `attrName` so both branches record the same attribute.
+ EnumClass control;
+ return failure(parser.parseLParen() ||
+ parseEnumKeywordAttr(control, parser, state, attrName) ||
+ parser.parseRParen());
+ }
+ // Set control to "None" otherwise.
+ state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
+ static_cast<uint32_t>(EnumClass::None)));
+ return success();
+}
+
/// Parses optional memory access attributes attached to a memory access
/// operand/pointer. Specifically, parses the following syntax:
/// (`[` memory-access `]`)?
@@ -2082,12 +2102,8 @@ void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
}
static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &state) {
- // TODO: support loop control properly
- Builder builder = parser.getBuilder();
- state.addAttribute("loop_control",
- builder.getI32IntegerAttr(
- static_cast<uint32_t>(spirv::LoopControl::None)));
-
+ if (parseControlAttribute<spirv::LoopControl>(parser, state))
+ return failure();
return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
/*argTypes=*/{});
}
@@ -2096,6 +2112,9 @@ static void print(spirv::LoopOp loopOp, OpAsmPrinter &printer) {
auto *op = loopOp.getOperation();
printer << spirv::LoopOp::getOperationName();
+ auto control = loopOp.loop_control();
+ if (control != spirv::LoopControl::None)
+ printer << " control(" << spirv::stringifyLoopControl(control) << ")";
printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
}
@@ -2445,12 +2464,8 @@ static LogicalResult verify(spirv::SelectOp op) {
static ParseResult parseSelectionOp(OpAsmParser &parser,
OperationState &state) {
- // TODO: support selection control properly
- Builder builder = parser.getBuilder();
- state.addAttribute("selection_control",
- builder.getI32IntegerAttr(
- static_cast<uint32_t>(spirv::SelectionControl::None)));
-
+ if (parseControlAttribute<spirv::SelectionControl>(parser, state))
+ return failure();
return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
/*argTypes=*/{});
}
@@ -2459,6 +2474,9 @@ static void print(spirv::SelectionOp selectionOp, OpAsmPrinter &printer) {
auto *op = selectionOp.getOperation();
printer << spirv::SelectionOp::getOperationName();
+ auto control = selectionOp.selection_control();
+ if (control != spirv::SelectionControl::None)
+ printer << " control(" << spirv::stringifySelectionControl(control) << ")";
printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
}
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 4ba3f16feef0..eaa8f4d94833 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -64,11 +64,14 @@ struct BlockMergeInfo {
Block *mergeBlock;
Block *continueBlock; // nullptr for spv.selection
Location loc;
-
- BlockMergeInfo(Location location)
- : mergeBlock(nullptr), continueBlock(nullptr), loc(location) {}
- BlockMergeInfo(Location location, Block *m, Block *c = nullptr)
- : mergeBlock(m), continueBlock(c), loc(location) {}
+ uint32_t control;
+
+ BlockMergeInfo(Location location, uint32_t control)
+ : mergeBlock(nullptr), continueBlock(nullptr), loc(location),
+ control(control) {}
+ BlockMergeInfo(Location location, uint32_t control, Block *m,
+ Block *c = nullptr)
+ : mergeBlock(m), continueBlock(c), loc(location), control(control) {}
};
/// A struct for containing OpLine instruction information.
@@ -1681,16 +1684,12 @@ LogicalResult Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
"OpSelectionMerge must specify merge target and selection control");
}
- if (static_cast<uint32_t>(spirv::SelectionControl::None) != operands[1]) {
- return emitError(unknownLoc,
- "unimplmented OpSelectionMerge selection control: ")
- << operands[2];
- }
-
auto *mergeBlock = getOrCreateBlock(operands[0]);
auto loc = createFileLineColLoc(opBuilder);
+ auto selectionControl = operands[1];
- if (!blockMergeInfo.try_emplace(curBlock, loc, mergeBlock).second) {
+ if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
+ .second) {
return emitError(
unknownLoc,
"a block cannot have more than one OpSelectionMerge instruction");
@@ -1709,16 +1708,13 @@ LogicalResult Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
"continue target and loop control");
}
- if (static_cast<uint32_t>(spirv::LoopControl::None) != operands[2]) {
- return emitError(unknownLoc, "unimplmented OpLoopMerge loop control: ")
- << operands[2];
- }
-
auto *mergeBlock = getOrCreateBlock(operands[0]);
auto *continueBlock = getOrCreateBlock(operands[1]);
auto loc = createFileLineColLoc(opBuilder);
+ uint32_t loopControl = operands[2];
- if (!blockMergeInfo.try_emplace(curBlock, loc, mergeBlock, continueBlock)
+ if (!blockMergeInfo
+ .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
.second) {
return emitError(
unknownLoc,
@@ -1771,25 +1767,27 @@ class ControlFlowStructurizer {
/// the `headerBlock` will be redirected to the `mergeBlock`.
/// This method will also update `mergeInfo` by remapping all blocks inside to
/// the newly cloned ones inside structured control flow op's regions.
- static LogicalResult structurize(Location loc, BlockMergeInfoMap &mergeInfo,
+ static LogicalResult structurize(Location loc, uint32_t control,
+ BlockMergeInfoMap &mergeInfo,
Block *headerBlock, Block *mergeBlock,
Block *continueBlock) {
- return ControlFlowStructurizer(loc, mergeInfo, headerBlock, mergeBlock,
- continueBlock)
+ return ControlFlowStructurizer(loc, control, mergeInfo, headerBlock,
+ mergeBlock, continueBlock)
.structurizeImpl();
}
private:
- ControlFlowStructurizer(Location loc, BlockMergeInfoMap &mergeInfo,
- Block *header, Block *merge, Block *cont)
- : location(loc), blockMergeInfo(mergeInfo), headerBlock(header),
- mergeBlock(merge), continueBlock(cont) {}
+ ControlFlowStructurizer(Location loc, uint32_t control,
+ BlockMergeInfoMap &mergeInfo, Block *header,
+ Block *merge, Block *cont)
+ : location(loc), control(control), blockMergeInfo(mergeInfo),
+ headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
/// Creates a new spv.selection op at the beginning of the `mergeBlock`.
- spirv::SelectionOp createSelectionOp();
+ spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
/// Creates a new spv.loop op at the beginning of the `mergeBlock`.
- spirv::LoopOp createLoopOp();
+ spirv::LoopOp createLoopOp(uint32_t loopControl);
/// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
void collectBlocksInConstruct();
@@ -1797,6 +1795,7 @@ class ControlFlowStructurizer {
LogicalResult structurizeImpl();
Location location;
+ uint32_t control;
BlockMergeInfoMap &blockMergeInfo;
@@ -1808,26 +1807,26 @@ class ControlFlowStructurizer {
};
} // namespace
-spirv::SelectionOp ControlFlowStructurizer::createSelectionOp() {
+spirv::SelectionOp
+ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
// Create a builder and set the insertion point to the beginning of the
// merge block so that the newly created SelectionOp will be inserted there.
OpBuilder builder(&mergeBlock->front());
- auto control = builder.getI32IntegerAttr(
- static_cast<uint32_t>(spirv::SelectionControl::None));
+ auto control = builder.getI32IntegerAttr(selectionControl);
auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
selectionOp.addMergeBlock();
return selectionOp;
}
-spirv::LoopOp ControlFlowStructurizer::createLoopOp() {
+spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
// Create a builder and set the insertion point to the beginning of the
// merge block so that the newly created LoopOp will be inserted there.
OpBuilder builder(&mergeBlock->front());
- // TODO: handle loop control properly
- auto loopOp = builder.create<spirv::LoopOp>(location);
+ auto control = builder.getI32IntegerAttr(loopControl);
+ auto loopOp = builder.create<spirv::LoopOp>(location, control);
loopOp.addEntryAndMergeBlock();
return loopOp;
@@ -1852,10 +1851,10 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
Operation *op = nullptr;
bool isLoop = continueBlock != nullptr;
if (isLoop) {
- if (auto loopOp = createLoopOp())
+ if (auto loopOp = createLoopOp(control))
op = loopOp.getOperation();
} else {
- if (auto selectionOp = createSelectionOp())
+ if (auto selectionOp = createSelectionOp(control))
op = selectionOp.getOperation();
}
if (!op)
@@ -1992,7 +1991,10 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
// The iterator should be erased before adding a new entry into
// blockMergeInfo to avoid iterator invalidation.
+ // Read the control first: `it` must not be dereferenced after erase().
+ uint32_t nestedControl = it->second.control;
blockMergeInfo.erase(it);
- blockMergeInfo.try_emplace(newHeader, loc, newMerge, newContinue);
+ blockMergeInfo.try_emplace(newHeader, loc, nestedControl, newMerge,
+ newContinue);
}
// The structured selection/loop's entry block does not have arguments.
@@ -2096,9 +2096,9 @@ LogicalResult Deserializer::structurizeControlFlow() {
// Erase this case before calling into structurizer, who will update
// blockMergeInfo.
blockMergeInfo.erase(blockMergeInfo.begin());
- if (failed(ControlFlowStructurizer::structurize(mergeInfo.loc,
- blockMergeInfo, headerBlock,
- mergeBlock, continueBlock)))
+ if (failed(ControlFlowStructurizer::structurize(
+ mergeInfo.loc, mergeInfo.control, blockMergeInfo, headerBlock,
+ mergeBlock, continueBlock)))
return failure();
}
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index 859ea556f39f..223adf47ab2e 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -1573,10 +1573,9 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
auto emitSelectionMerge = [&]() {
emitDebugLine(functionBody, loc);
lastProcessedWasMergeInst = true;
- // TODO: properly support selection control here
encodeInstructionInto(
functionBody, spirv::Opcode::OpSelectionMerge,
- {mergeID, static_cast<uint32_t>(spirv::SelectionControl::None)});
+ {mergeID, static_cast<uint32_t>(selectionOp.selection_control())});
};
// For structured selection, we cannot have blocks in the selection construct
// branching to the selection header block. Entering the selection (and
@@ -1636,10 +1635,9 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
auto emitLoopMerge = [&]() {
emitDebugLine(functionBody, loc);
lastProcessedWasMergeInst = true;
- // TODO: properly support loop control here
encodeInstructionInto(
functionBody, spirv::Opcode::OpLoopMerge,
- {mergeID, continueID, static_cast<uint32_t>(spirv::LoopControl::None)});
+ {mergeID, continueID, static_cast<uint32_t>(loopOp.loop_control())});
};
if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
return failure();
diff --git a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir
index d6e2090f02bb..8f0b35ef6fc8 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir
@@ -119,8 +119,8 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// CHECK: spv.Branch ^bb1
// CHECK-NEXT: ^bb1:
-// CHECK-NEXT: spv.loop
- spv.loop {
+// CHECK-NEXT: spv.loop control(Unroll)
+ spv.loop control(Unroll) {
// CHECK-NEXT: spv.Branch ^bb1
spv.Branch ^header
@@ -140,8 +140,8 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
spv.Store "Function" %jvar, %zero : i32
// CHECK-NEXT: spv.Branch ^bb3
// CHECK-NEXT: ^bb3:
-// CHECK-NEXT: spv.loop
- spv.loop {
+// CHECK-NEXT: spv.loop control(DontUnroll)
+ spv.loop control(DontUnroll) {
// CHECK-NEXT: spv.Branch ^bb1
spv.Branch ^header
diff --git a/mlir/test/Dialect/SPIRV/Serialization/selection.mlir b/mlir/test/Dialect/SPIRV/Serialization/selection.mlir
index e391bae5b486..9e6ff5698f8a 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/selection.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/selection.mlir
@@ -11,10 +11,10 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
%two = spv.constant 2: i32
%var = spv.Variable init(%zero) : !spv.ptr<i32, Function>
-// CHECK-NEXT: spv.selection {
+// CHECK-NEXT: spv.selection control(Flatten)
// CHECK-NEXT: spv.constant 0
// CHECK-NEXT: spv.Variable
- spv.selection {
+ spv.selection control(Flatten) {
// CHECK-NEXT: spv.BranchConditional %{{.*}} [5, 10], ^bb1, ^bb2
spv.BranchConditional %cond [5, 10], ^then, ^else
diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
index 97ee02d45f24..267c45b080fa 100644
--- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
@@ -317,6 +317,16 @@ func @empty_region() -> () {
// -----
+// CHECK-LABEL: @loop_with_control
+func @loop_with_control() -> () {
+ // CHECK: spv.loop control(Unroll)
+ spv.loop control(Unroll) {
+ }
+ return
+}
+
+// -----
+
func @wrong_merge_block() -> () {
// expected-error @+1 {{last block must be the merge block with only one 'spv._merge' op}}
spv.loop {
@@ -718,6 +728,16 @@ func @empty_region() -> () {
// -----
+// CHECK-LABEL: @selection_with_control
+func @selection_with_control() -> () {
+ // CHECK: spv.selection control(Flatten)
+ spv.selection control(Flatten) {
+ }
+ return
+}
+
+// -----
+
func @wrong_merge_block() -> () {
// expected-error @+1 {{last block must be the merge block with only one 'spv._merge' op}}
spv.selection {
More information about the Mlir-commits
mailing list