[llvm-branch-commits] [mlir] 6551c9a - [mlir][spirv] Add parsing and printing support for SpecConstantOperation
Lei Zhang via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Dec 16 05:32:20 PST 2020
Author: ergawy
Date: 2020-12-16T08:26:48-05:00
New Revision: 6551c9ac365ca46e83354703d1a63c671a50258a
URL: https://github.com/llvm/llvm-project/commit/6551c9ac365ca46e83354703d1a63c671a50258a
DIFF: https://github.com/llvm/llvm-project/commit/6551c9ac365ca46e83354703d1a63c671a50258a.diff
LOG: [mlir][spirv] Add parsing and printing support for SpecConstantOperation
Adds more support for `SpecConstantOperation` by defining a custom
syntax for the op and implementing its parsing and printing.
Reviewed By: mravishankar, antiagainst
Differential Revision: https://reviews.llvm.org/D92919
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/structure-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index b8e76c3662ec..1ae7d285cd93 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -608,9 +608,12 @@ def SPV_SpecConstantCompositeOp : SPV_Op<"specConstantComposite", [InModuleScope
let autogenSerialization = 0;
}
-def SPV_YieldOp : SPV_Op<"mlir.yield", [NoSideEffect, Terminator]> {
- let summary = "Yields the result computed in `spv.SpecConstantOperation`'s"
- "region back to the parent op.";
+def SPV_YieldOp : SPV_Op<"mlir.yield", [
+ HasParent<"SpecConstantOperationOp">, NoSideEffect, Terminator]> {
+ let summary = [{
+ Yields the result computed in `spv.SpecConstantOperation`'s
+ region back to the parent op.
+ }];
let description = [{
This op is a special terminator whose only purpose is to terminate
@@ -639,12 +642,16 @@ def SPV_YieldOp : SPV_Op<"mlir.yield", [NoSideEffect, Terminator]> {
let autogenSerialization = 0;
let assemblyFormat = "attr-dict $operand `:` type($operand)";
+
+ let verifier = [{ return success(); }];
}
def SPV_SpecConstantOperationOp : SPV_Op<"SpecConstantOperation", [
- InFunctionScope, NoSideEffect,
- IsolatedFromAbove]> {
- let summary = "Declare a new specialization constant that results from doing an operation.";
+ NoSideEffect, InFunctionScope,
+ SingleBlockImplicitTerminator<"YieldOp">]> {
+ let summary = [{
+ Declare a new specialization constant that results from doing an operation.
+ }];
let description = [{
This op declares a SPIR-V specialization constant that results from
@@ -653,12 +660,8 @@ def SPV_SpecConstantOperationOp : SPV_Op<"SpecConstantOperation", [
In the `spv` dialect, this op is modelled as follows:
```
- spv-spec-constant-operation-op ::= `"spv.SpecConstantOperation"`
- `(`ssa-id (`, ` ssa-id)`)`
- `({`
- ssa-id = spirv-op
- `spv.mlir.yield` ssa-id
- `})` `:` function-type
+ spv-spec-constant-operation-op ::= `spv.SpecConstantOperation` `wraps`
+ generic-spirv-op `:` function-type
```
In particular, an `spv.SpecConstantOperation` contains exactly one
@@ -712,17 +715,15 @@ def SPV_SpecConstantOperationOp : SPV_Op<"SpecConstantOperation", [
#### Example:
```mlir
%0 = spv.constant 1: i32
+ %1 = spv.constant 1: i32
- %1 = "spv.SpecConstantOperation"(%0) ({
- %ret = spv.IAdd %0, %0 : i32
- spv.mlir.yield %ret : i32
- }) : (i32) -> i32
+ %2 = spv.SpecConstantOperation wraps "spv.IAdd"(%0, %1) : (i32, i32) -> i32
```
}];
- let arguments = (ins Variadic<AnyType>:$operands);
+ let arguments = (ins);
- let results = (outs AnyType:$results);
+ let results = (outs AnyType:$result);
let regions = (region SizedRegion<1>:$body);
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 03e416e95441..43b3c517a4c6 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -3396,35 +3396,39 @@ static LogicalResult verify(spirv::SpecConstantCompositeOp constOp) {
}
//===----------------------------------------------------------------------===//
-// spv.mlir.yield
+// spv.SpecConstantOperation
//===----------------------------------------------------------------------===//
-static LogicalResult verify(spirv::YieldOp yieldOp) {
- Operation *parentOp = yieldOp->getParentOp();
+static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser,
+ OperationState &state) {
+ Region *body = state.addRegion();
- if (!parentOp || !isa<spirv::SpecConstantOperationOp>(parentOp))
- return yieldOp.emitOpError(
- "expected parent op to be 'spv.SpecConstantOperation'");
+ if (parser.parseKeyword("wraps"))
+ return failure();
- Block &block = parentOp->getRegion(0).getBlocks().front();
- Operation &enclosedOp = block.getOperations().front();
+ body->push_back(new Block);
+ Block &block = body->back();
+ Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
- if (yieldOp.getOperand().getDefiningOp() != &enclosedOp)
- return yieldOp.emitOpError(
- "expected operand to be defined by preceeding op");
+ if (!wrappedOp)
+ return failure();
- return success();
-}
+ OpBuilder builder(parser.getBuilder().getContext());
+ builder.setInsertionPointToEnd(&block);
+ builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
+ state.location = wrappedOp->getLoc();
-static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser,
- OperationState &state) {
- // TODO: For now, only generic form is supported.
- return failure();
+ state.addTypes(wrappedOp->getResult(0).getType());
+
+ if (parser.parseOptionalAttrDict(state.attributes))
+ return failure();
+
+ return success();
}
static void print(spirv::SpecConstantOperationOp op, OpAsmPrinter &printer) {
- // TODO
- printer.printGenericOp(op);
+ printer << op.getOperationName() << " wraps ";
+ printer.printGenericOp(&op.body().front().front());
}
static LogicalResult verify(spirv::SpecConstantOperationOp constOp) {
@@ -3433,11 +3437,6 @@ static LogicalResult verify(spirv::SpecConstantOperationOp constOp) {
if (block.getOperations().size() != 2)
return constOp.emitOpError("expected exactly 2 nested ops");
- Operation &yieldOp = block.getOperations().back();
-
- if (!isa<spirv::YieldOp>(yieldOp))
- return constOp.emitOpError("expected terminator to be a yield op");
-
Operation &enclosedOp = block.getOperations().front();
// TODO Add a `UsableInSpecConstantOp` trait and mark ops from the list below
@@ -3457,21 +3456,12 @@ static LogicalResult verify(spirv::SpecConstantOperationOp constOp) {
spirv::UGreaterThanEqualOp, spirv::SGreaterThanEqualOp>(enclosedOp))
return constOp.emitOpError("invalid enclosed op");
- if (enclosedOp.getNumOperands() != constOp.getOperands().size())
- return constOp.emitOpError("invalid number of operands; expected ")
- << enclosedOp.getNumOperands() << ", actual "
- << constOp.getOperands().size();
-
- if (enclosedOp.getNumOperands() != constOp.getRegion().getNumArguments())
- return constOp.emitOpError("invalid number of region arguments; expected ")
- << enclosedOp.getNumOperands() << ", actual "
- << constOp.getRegion().getNumArguments();
-
- for (auto operand : constOp.getOperands())
+ for (auto operand : enclosedOp.getOperands())
if (!isa<spirv::ConstantOp, spirv::SpecConstantOp,
spirv::SpecConstantCompositeOp, spirv::SpecConstantOperationOp>(
operand.getDefiningOp()))
- return constOp.emitOpError("invalid operand");
+ return constOp.emitOpError(
+ "invalid operand, must be defined by a constant operation");
return success();
}
diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir
index 89a30e23dec9..c0b495115d6c 100644
--- a/mlir/test/Dialect/SPIRV/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir
@@ -757,6 +757,7 @@ spv.module Logical GLSL450 {
// expected-error @+1 {{unsupported composite type}}
spv.specConstantComposite @scc (@sc1) : !spv.coopmatrix<8x16xf32, Device>
}
+
//===----------------------------------------------------------------------===//
// spv.SpecConstantOperation
//===----------------------------------------------------------------------===//
@@ -765,34 +766,15 @@ spv.module Logical GLSL450 {
spv.module Logical GLSL450 {
spv.func @foo() -> i32 "None" {
+ // CHECK: [[LHS:%.*]] = spv.constant
%0 = spv.constant 1: i32
- %2 = spv.constant 1: i32
-
- %1 = "spv.SpecConstantOperation"(%0, %0) ({
- ^bb(%lhs : i32, %rhs : i32):
- %ret = spv.IAdd %lhs, %rhs : i32
- spv.mlir.yield %ret : i32
- }) : (i32, i32) -> i32
-
- spv.ReturnValue %1 : i32
- }
-}
-
-// -----
-
-spv.module Logical GLSL450 {
- spv.func @foo() -> i32 "None" {
- %0 = spv.constant 1: i32
- %2 = spv.constant 1: i32
+ // CHECK: [[RHS:%.*]] = spv.constant
+ %1 = spv.constant 1: i32
- // expected-error @+1 {{invalid number of operands; expected 2, actual 1}}
- %1 = "spv.SpecConstantOperation"(%0) ({
- ^bb(%lhs : i32, %rhs : i32):
- %ret = spv.IAdd %lhs, %rhs : i32
- spv.mlir.yield %ret : i32
- }) : (i32) -> i32
+ // CHECK: spv.SpecConstantOperation wraps "spv.IAdd"([[LHS]], [[RHS]]) : (i32, i32) -> i32
+ %2 = spv.SpecConstantOperation wraps "spv.IAdd"(%0, %1) : (i32, i32) -> i32
- spv.ReturnValue %1 : i32
+ spv.ReturnValue %2 : i32
}
}
@@ -801,93 +783,20 @@ spv.module Logical GLSL450 {
spv.module Logical GLSL450 {
spv.func @foo() -> i32 "None" {
%0 = spv.constant 1: i32
- %2 = spv.constant 1: i32
-
- // expected-error @+1 {{invalid number of region arguments; expected 2, actual 1}}
- %1 = "spv.SpecConstantOperation"(%0, %0) ({
- ^bb(%lhs : i32):
- %ret = spv.IAdd %lhs, %lhs : i32
- spv.mlir.yield %ret : i32
- }) : (i32, i32) -> i32
-
- spv.ReturnValue %1 : i32
- }
-}
-
-// -----
-
-spv.module Logical GLSL450 {
- spv.func @foo() -> i32 "None" {
- %0 = spv.constant 1: i32
- // expected-error @+1 {{expected parent op to be 'spv.SpecConstantOperation'}}
+ // expected-error @+1 {{op expects parent op 'spv.SpecConstantOperation'}}
spv.mlir.yield %0 : i32
}
}
// -----
-spv.module Logical GLSL450 {
- spv.func @foo() -> i32 "None" {
- %0 = spv.constant 1: i32
-
- %1 = "spv.SpecConstantOperation"(%0, %0) ({
- ^bb(%lhs : i32, %rhs : i32):
- %ret = spv.ISub %lhs, %rhs : i32
- // expected-error @+1 {{expected operand to be defined by preceeding op}}
- spv.mlir.yield %lhs : i32
- }) : (i32, i32) -> i32
-
- spv.ReturnValue %1 : i32
- }
-}
-
-// -----
-
-spv.module Logical GLSL450 {
- spv.func @foo() -> i32 "None" {
- %0 = spv.constant 1: i32
-
- // expected-error @+1 {{expected exactly 2 nested ops}}
- %1 = "spv.SpecConstantOperation"(%0, %0) ({
- ^bb(%lhs : i32, %rhs : i32):
- %ret = spv.IAdd %lhs, %rhs : i32
- %ret2 = spv.IAdd %lhs, %rhs : i32
- spv.mlir.yield %ret : i32
- }) : (i32, i32) -> i32
-
- spv.ReturnValue %1 : i32
- }
-}
-
-// -----
-
-spv.module Logical GLSL450 {
- spv.func @foo() -> i32 "None" {
- %0 = spv.constant 1: i32
-
- // expected-error @+1 {{expected terminator to be a yield op}}
- %1 = "spv.SpecConstantOperation"(%0, %0) ({
- ^bb(%lhs : i32, %rhs : i32):
- %ret = spv.IAdd %lhs, %rhs : i32
- spv.ReturnValue %ret : i32
- }) : (i32, i32) -> i32
-
- spv.ReturnValue %1 : i32
- }
-}
-
-// -----
-
spv.module Logical GLSL450 {
spv.func @foo() -> () "None" {
%0 = spv.Variable : !spv.ptr<i32, Function>
// expected-error @+1 {{invalid enclosed op}}
- %2 = "spv.SpecConstantOperation"(%0) ({
- ^bb(%arg0 : !spv.ptr<i32, Function>):
- %ret = spv.Load "Function" %arg0 : i32
- spv.mlir.yield %ret : i32
- }) : (!spv.ptr<i32, Function>) -> i32
+ %1 = spv.SpecConstantOperation wraps "spv.Load"(%0) {memory_access = 0 : i32} : (!spv.ptr<i32, Function>) -> i32
+ spv.Return
}
}
@@ -898,11 +807,9 @@ spv.module Logical GLSL450 {
%0 = spv.Variable : !spv.ptr<i32, Function>
%1 = spv.Load "Function" %0 : i32
- // expected-error @+1 {{invalid operand}}
- %2 = "spv.SpecConstantOperation"(%1, %1) ({
- ^bb(%lhs: i32, %rhs: i32):
- %ret = spv.IAdd %lhs, %lhs : i32
- spv.mlir.yield %ret : i32
- }) : (i32, i32) -> i32
+ // expected-error @+1 {{invalid operand, must be defined by a constant operation}}
+ %2 = spv.SpecConstantOperation wraps "spv.IAdd"(%1, %1) : (i32, i32) -> i32
+
+ spv.Return
}
}
More information about the llvm-branch-commits mailing list