[llvm-branch-commits] [mlir] 6551c9a - [mlir][spirv] Add parsing and printing support for SpecConstantOperation

Lei Zhang via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Dec 16 05:32:20 PST 2020


Author: ergawy
Date: 2020-12-16T08:26:48-05:00
New Revision: 6551c9ac365ca46e83354703d1a63c671a50258a

URL: https://github.com/llvm/llvm-project/commit/6551c9ac365ca46e83354703d1a63c671a50258a
DIFF: https://github.com/llvm/llvm-project/commit/6551c9ac365ca46e83354703d1a63c671a50258a.diff

LOG: [mlir][spirv] Add parsing and printing support for SpecConstantOperation

Adds more support for `SpecConstantOperation` by defining a custom
syntax for the op and implementing its parsing and printing.

Reviewed By: mravishankar, antiagainst

Differential Revision: https://reviews.llvm.org/D92919

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/structure-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index b8e76c3662ec..1ae7d285cd93 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -608,9 +608,12 @@ def SPV_SpecConstantCompositeOp : SPV_Op<"specConstantComposite", [InModuleScope
   let autogenSerialization = 0;
 }
 
-def SPV_YieldOp : SPV_Op<"mlir.yield", [NoSideEffect, Terminator]> {
-  let summary = "Yields the result computed in `spv.SpecConstantOperation`'s"
-                "region back to the parent op.";
+def SPV_YieldOp : SPV_Op<"mlir.yield", [
+    HasParent<"SpecConstantOperationOp">, NoSideEffect, Terminator]> {
+  let summary = [{
+    Yields the result computed in `spv.SpecConstantOperation`'s
+    region back to the parent op.
+  }];
 
   let description = [{
     This op is a special terminator whose only purpose is to terminate
@@ -639,12 +642,16 @@ def SPV_YieldOp : SPV_Op<"mlir.yield", [NoSideEffect, Terminator]> {
   let autogenSerialization = 0;
 
   let assemblyFormat = "attr-dict $operand `:` type($operand)";
+
+  let verifier = [{ return success(); }];
 }
 
 def SPV_SpecConstantOperationOp : SPV_Op<"SpecConstantOperation", [
-                                         InFunctionScope, NoSideEffect,
-                                         IsolatedFromAbove]> {
-  let summary = "Declare a new specialization constant that results from doing an operation.";
+       NoSideEffect, InFunctionScope,
+       SingleBlockImplicitTerminator<"YieldOp">]> {
+  let summary = [{
+    Declare a new specialization constant that results from doing an operation.
+  }];
 
   let description = [{
     This op declares a SPIR-V specialization constant that results from
@@ -653,12 +660,8 @@ def SPV_SpecConstantOperationOp : SPV_Op<"SpecConstantOperation", [
     In the `spv` dialect, this op is modelled as follows:
 
     ```
-    spv-spec-constant-operation-op ::= `"spv.SpecConstantOperation"`
-                                         `(`ssa-id (`, ` ssa-id)`)`
-                                       `({`
-                                         ssa-id = spirv-op
-                                         `spv.mlir.yield` ssa-id
-                                       `})` `:` function-type
+    spv-spec-constant-operation-op ::= `spv.SpecConstantOperation` `wraps`
+                                         generic-spirv-op `:` function-type
     ```
 
     In particular, an `spv.SpecConstantOperation` contains exactly one
@@ -712,17 +715,15 @@ def SPV_SpecConstantOperationOp : SPV_Op<"SpecConstantOperation", [
     #### Example:
     ```mlir
     %0 = spv.constant 1: i32
+    %1 = spv.constant 1: i32
 
-    %1 = "spv.SpecConstantOperation"(%0) ({
-      %ret = spv.IAdd %0, %0 : i32
-      spv.mlir.yield %ret : i32
-    }) : (i32) -> i32
+    %2 = spv.SpecConstantOperation wraps "spv.IAdd"(%0, %1) : (i32, i32) -> i32
     ```
   }];
 
-  let arguments = (ins Variadic<AnyType>:$operands);
+  let arguments = (ins);
 
-  let results = (outs AnyType:$results);
+  let results = (outs AnyType:$result);
 
   let regions = (region SizedRegion<1>:$body);
 

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 03e416e95441..43b3c517a4c6 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -3396,35 +3396,39 @@ static LogicalResult verify(spirv::SpecConstantCompositeOp constOp) {
 }
 
 //===----------------------------------------------------------------------===//
-// spv.mlir.yield
+// spv.SpecConstantOperation
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(spirv::YieldOp yieldOp) {
-  Operation *parentOp = yieldOp->getParentOp();
+static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser,
+                                                OperationState &state) {
+  Region *body = state.addRegion();
 
-  if (!parentOp || !isa<spirv::SpecConstantOperationOp>(parentOp))
-    return yieldOp.emitOpError(
-        "expected parent op to be 'spv.SpecConstantOperation'");
+  if (parser.parseKeyword("wraps"))
+    return failure();
 
-  Block &block = parentOp->getRegion(0).getBlocks().front();
-  Operation &enclosedOp = block.getOperations().front();
+  body->push_back(new Block);
+  Block &block = body->back();
+  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
 
-  if (yieldOp.getOperand().getDefiningOp() != &enclosedOp)
-    return yieldOp.emitOpError(
-        "expected operand to be defined by preceeding op");
+  if (!wrappedOp)
+    return failure();
 
-  return success();
-}
+  OpBuilder builder(parser.getBuilder().getContext());
+  builder.setInsertionPointToEnd(&block);
+  builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
+  state.location = wrappedOp->getLoc();
 
-static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser,
-                                                OperationState &state) {
-  // TODO: For now, only generic form is supported.
-  return failure();
+  state.addTypes(wrappedOp->getResult(0).getType());
+
+  if (parser.parseOptionalAttrDict(state.attributes))
+    return failure();
+
+  return success();
 }
 
 static void print(spirv::SpecConstantOperationOp op, OpAsmPrinter &printer) {
-  // TODO
-  printer.printGenericOp(op);
+  printer << op.getOperationName() << " wraps ";
+  printer.printGenericOp(&op.body().front().front());
 }
 
 static LogicalResult verify(spirv::SpecConstantOperationOp constOp) {
@@ -3433,11 +3437,6 @@ static LogicalResult verify(spirv::SpecConstantOperationOp constOp) {
   if (block.getOperations().size() != 2)
     return constOp.emitOpError("expected exactly 2 nested ops");
 
-  Operation &yieldOp = block.getOperations().back();
-
-  if (!isa<spirv::YieldOp>(yieldOp))
-    return constOp.emitOpError("expected terminator to be a yield op");
-
   Operation &enclosedOp = block.getOperations().front();
 
   // TODO Add a `UsableInSpecConstantOp` trait and mark ops from the list below
@@ -3457,21 +3456,12 @@ static LogicalResult verify(spirv::SpecConstantOperationOp constOp) {
            spirv::UGreaterThanEqualOp, spirv::SGreaterThanEqualOp>(enclosedOp))
     return constOp.emitOpError("invalid enclosed op");
 
-  if (enclosedOp.getNumOperands() != constOp.getOperands().size())
-    return constOp.emitOpError("invalid number of operands; expected ")
-           << enclosedOp.getNumOperands() << ", actual "
-           << constOp.getOperands().size();
-
-  if (enclosedOp.getNumOperands() != constOp.getRegion().getNumArguments())
-    return constOp.emitOpError("invalid number of region arguments; expected ")
-           << enclosedOp.getNumOperands() << ", actual "
-           << constOp.getRegion().getNumArguments();
-
-  for (auto operand : constOp.getOperands())
+  for (auto operand : enclosedOp.getOperands())
     if (!isa<spirv::ConstantOp, spirv::SpecConstantOp,
              spirv::SpecConstantCompositeOp, spirv::SpecConstantOperationOp>(
             operand.getDefiningOp()))
-      return constOp.emitOpError("invalid operand");
+      return constOp.emitOpError(
+          "invalid operand, must be defined by a constant operation");
 
   return success();
 }

diff  --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir
index 89a30e23dec9..c0b495115d6c 100644
--- a/mlir/test/Dialect/SPIRV/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir
@@ -757,6 +757,7 @@ spv.module Logical GLSL450 {
   // expected-error @+1 {{unsupported composite type}}
   spv.specConstantComposite @scc (@sc1) : !spv.coopmatrix<8x16xf32, Device>
 }
+
 //===----------------------------------------------------------------------===//
 // spv.SpecConstantOperation
 //===----------------------------------------------------------------------===//
@@ -765,34 +766,15 @@ spv.module Logical GLSL450 {
 
 spv.module Logical GLSL450 {
   spv.func @foo() -> i32 "None" {
+    // CHECK: [[LHS:%.*]] = spv.constant
     %0 = spv.constant 1: i32
-    %2 = spv.constant 1: i32
-
-    %1 = "spv.SpecConstantOperation"(%0, %0) ({
-    ^bb(%lhs : i32, %rhs : i32):
-      %ret = spv.IAdd %lhs, %rhs : i32
-      spv.mlir.yield %ret : i32
-    }) : (i32, i32) -> i32
-
-    spv.ReturnValue %1 : i32
-  }
-}
-
-// -----
-
-spv.module Logical GLSL450 {
-  spv.func @foo() -> i32 "None" {
-    %0 = spv.constant 1: i32
-    %2 = spv.constant 1: i32
+    // CHECK: [[RHS:%.*]] = spv.constant
+    %1 = spv.constant 1: i32
 
-    // expected-error @+1 {{invalid number of operands; expected 2, actual 1}}
-    %1 = "spv.SpecConstantOperation"(%0) ({
-    ^bb(%lhs : i32, %rhs : i32):
-      %ret = spv.IAdd %lhs, %rhs : i32
-      spv.mlir.yield %ret : i32
-    }) : (i32) -> i32
+    // CHECK: spv.SpecConstantOperation wraps "spv.IAdd"([[LHS]], [[RHS]]) : (i32, i32) -> i32
+    %2 = spv.SpecConstantOperation wraps "spv.IAdd"(%0, %1) : (i32, i32) -> i32
 
-    spv.ReturnValue %1 : i32
+    spv.ReturnValue %2 : i32
   }
 }
 
@@ -801,93 +783,20 @@ spv.module Logical GLSL450 {
 spv.module Logical GLSL450 {
   spv.func @foo() -> i32 "None" {
     %0 = spv.constant 1: i32
-    %2 = spv.constant 1: i32
-
-    // expected-error @+1 {{invalid number of region arguments; expected 2, actual 1}}
-    %1 = "spv.SpecConstantOperation"(%0, %0) ({
-    ^bb(%lhs : i32):
-      %ret = spv.IAdd %lhs, %lhs : i32
-      spv.mlir.yield %ret : i32
-    }) : (i32, i32) -> i32
-
-    spv.ReturnValue %1 : i32
-  }
-}
-
-// -----
-
-spv.module Logical GLSL450 {
-  spv.func @foo() -> i32 "None" {
-    %0 = spv.constant 1: i32
-    // expected-error @+1 {{expected parent op to be 'spv.SpecConstantOperation'}}
+    // expected-error @+1 {{op expects parent op 'spv.SpecConstantOperation'}}
     spv.mlir.yield %0 : i32
   }
 }
 
 // -----
 
-spv.module Logical GLSL450 {
-  spv.func @foo() -> i32 "None" {
-    %0 = spv.constant 1: i32
-
-    %1 = "spv.SpecConstantOperation"(%0, %0) ({
-    ^bb(%lhs : i32, %rhs : i32):
-      %ret = spv.ISub %lhs, %rhs : i32
-      // expected-error @+1 {{expected operand to be defined by preceeding op}}
-      spv.mlir.yield %lhs : i32
-    }) : (i32, i32) -> i32
-
-    spv.ReturnValue %1 : i32
-  }
-}
-
-// -----
-
-spv.module Logical GLSL450 {
-  spv.func @foo() -> i32 "None" {
-    %0 = spv.constant 1: i32
-
-    // expected-error @+1 {{expected exactly 2 nested ops}}
-    %1 = "spv.SpecConstantOperation"(%0, %0) ({
-    ^bb(%lhs : i32, %rhs : i32):
-      %ret = spv.IAdd %lhs, %rhs : i32
-      %ret2 = spv.IAdd %lhs, %rhs : i32
-      spv.mlir.yield %ret : i32
-    }) : (i32, i32) -> i32
-
-    spv.ReturnValue %1 : i32
-  }
-}
-
-// -----
-
-spv.module Logical GLSL450 {
-  spv.func @foo() -> i32 "None" {
-    %0 = spv.constant 1: i32
-
-    // expected-error @+1 {{expected terminator to be a yield op}}
-    %1 = "spv.SpecConstantOperation"(%0, %0) ({
-    ^bb(%lhs : i32, %rhs : i32):
-      %ret = spv.IAdd %lhs, %rhs : i32
-      spv.ReturnValue %ret : i32
-    }) : (i32, i32) -> i32
-
-    spv.ReturnValue %1 : i32
-  }
-}
-
-// -----
-
 spv.module Logical GLSL450 {
   spv.func @foo() -> () "None" {
     %0 = spv.Variable : !spv.ptr<i32, Function>
 
     // expected-error @+1 {{invalid enclosed op}}
-    %2 = "spv.SpecConstantOperation"(%0) ({
-    ^bb(%arg0 : !spv.ptr<i32, Function>):
-      %ret = spv.Load "Function" %arg0 : i32
-      spv.mlir.yield %ret : i32
-    }) : (!spv.ptr<i32, Function>) -> i32
+    %1 = spv.SpecConstantOperation wraps "spv.Load"(%0) {memory_access = 0 : i32} : (!spv.ptr<i32, Function>) -> i32
+    spv.Return
   }
 }
 
@@ -898,11 +807,9 @@ spv.module Logical GLSL450 {
     %0 = spv.Variable : !spv.ptr<i32, Function>
     %1 = spv.Load "Function" %0 : i32
 
-    // expected-error @+1 {{invalid operand}}
-    %2 = "spv.SpecConstantOperation"(%1, %1) ({
-    ^bb(%lhs: i32, %rhs: i32):
-      %ret = spv.IAdd %lhs, %lhs : i32
-      spv.mlir.yield %ret : i32
-    }) : (i32, i32) -> i32
+    // expected-error @+1 {{invalid operand, must be defined by a constant operation}}
+    %2 = spv.SpecConstantOperation wraps "spv.IAdd"(%1, %1) : (i32, i32) -> i32
+
+    spv.Return
   }
 }


        


More information about the llvm-branch-commits mailing list