[Mlir-commits] [mlir] 0c8f9b8 - [MLIR][SPIRV] Add initial support for OpSpecConstantComposite.
    Lei Zhang 
    llvmlistbot at llvm.org
       
    Fri Oct  2 12:18:29 PDT 2020
    
    
  
Author: ergawy
Date: 2020-10-02T15:18:16-04:00
New Revision: 0c8f9b8099fd0500cd885bc699924e20371014ff
URL: https://github.com/llvm/llvm-project/commit/0c8f9b8099fd0500cd885bc699924e20371014ff
DIFF: https://github.com/llvm/llvm-project/commit/0c8f9b8099fd0500cd885bc699924e20371014ff.diff
LOG: [MLIR][SPIRV] Add initial support for OpSpecConstantComposite.
This commit adds support to SPIR-V's composite specialization constants.
These are specialization constants which are composed of other spec
constants (whether scalar or composite), regular constants, or undef
values.
This commit adds support for parsing, printing, verification, and
(De)serialization.
A few TODOs are still in order:
- Supporting more types of constituents; currently, only scalar spec constants are supported.
- Extending `spv._reference_of` to support composite spec constants.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D88568
Added: 
    
Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
    mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
    mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir
    mlir/test/Dialect/SPIRV/structure-ops.mlir
Removed: 
    
################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index 2ac28ef87ba9..0e866f02b011 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -491,6 +491,8 @@ def SPV_ReferenceOfOp : SPV_Op<"_reference_of", [NoSideEffect]> {
     ```mlir
     %0 = spv._reference_of @spec_const : f32
     ```
+
+    TODO Add support for composite specialization constants.
   }];
 
   let arguments = (ins
@@ -541,8 +543,6 @@ def SPV_SpecConstantOp : SPV_Op<"specConstant", [InModuleScope, Symbol]> {
     spv.specConstant @spec_const1 = true
     spv.specConstant @spec_const2 spec_id(5) = 42 : i32
     ```
-
-    TODO: support composite spec constants with another op
   }];
 
   let arguments = (ins
@@ -557,6 +557,56 @@ def SPV_SpecConstantOp : SPV_Op<"specConstant", [InModuleScope, Symbol]> {
   let autogenSerialization = 0;
 }
 
+def SPV_SpecConstantCompositeOp : SPV_Op<"specConstantComposite", [InModuleScope, Symbol]> {
+  let summary = "Declare a new composite specialization constant.";
+
+  let description = [{
+    This op declares a SPIR-V composite specialization constant. This covers
+    the `OpSpecConstantComposite` SPIR-V instruction. Scalar constants are
+    covered by `spv.specConstant`.
+
+    Per the SPIR-V specification, a constituent of a spec constant composite
+    can be:
+    - A symbol referring to another spec constant.
+    - The SSA ID of a non-specialization constant (i.e. defined through
+      `spv.constant`).
+    - The SSA ID of a `spv.undef`.
+
+    Currently only symbol references to scalar spec constants (declared with
+    `spv.specConstant`) are accepted; see the TODO below.
+
+    ```
+    spv-spec-constant-composite-op ::= `spv.specConstantComposite` symbol-ref-id ` (`
+                                       symbol-ref-id (`, ` symbol-ref-id)*
+                                       `) :` composite-type
+    ```
+
+    where `composite-type` is some non-scalar type that can be represented in
+    the `spv` dialect: `spv.struct`, `spv.array`, or `vector`. Cooperative
+    matrix types are not supported.
+
+    #### Example:
+
+    ```mlir
+    spv.specConstant @sc1 = 1   : i32
+    spv.specConstant @sc2 = 2.5 : f32
+    spv.specConstant @sc3 = 3.5 : f32
+    spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<i32, f32, f32>
+    ```
+
+    TODO: Add support for constituents that are:
+    - regular constants.
+    - undef.
+    - spec constant composite.
+  }];
+
+  let arguments = (ins
+    TypeAttr:$type,
+    StrAttr:$sym_name,
+    SymbolRefArrayAttr:$constituents
+  );
+
+  let results = (outs);
+
+  let hasOpcode = 0;
+
+  let autogenSerialization = 0;
+}
+
 // -----
 
 #endif // SPIRV_STRUCTURE_OPS
diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index a01177132b27..363785e2b782 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -53,6 +53,7 @@ static constexpr const char kTypeAttrName[] = "type";
 static constexpr const char kUnequalSemanticsAttrName[] = "unequal_semantics";
 static constexpr const char kValueAttrName[] = "value";
 static constexpr const char kValuesAttrName[] = "values";
+// Name of the ArrayAttr that holds the constituent symbol references of
+// spv.specConstantComposite.
+static constexpr const char kCompositeSpecConstituentsName[] = "constituents";
 
 //===----------------------------------------------------------------------===//
 // Common utility functions
@@ -3287,6 +3288,95 @@ static LogicalResult verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spv.specConstantComposite
+//===----------------------------------------------------------------------===//
+
+/// Parses `spv.specConstantComposite @name (@c0, @c1, ...) : type`. Records
+/// the symbol name, the constituent references (as an ArrayAttr) and the
+/// declared type as attributes on `state`.
+static ParseResult parseSpecConstantCompositeOp(OpAsmParser &parser,
+                                                OperationState &state) {
+  // Symbol name of the composite being declared, followed by `(`.
+  StringAttr symName;
+  if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(),
+                             state.attributes) ||
+      parser.parseLParen())
+    return failure();
+
+  // Non-empty, comma-separated list of constituent symbol references.
+  SmallVector<Attribute, 4> constituentRefs;
+  do {
+    // parseAttribute wants a name and a list to record the attribute in; both
+    // are throwaway here because the reference is collected manually.
+    NamedAttrList scratchAttrs;
+    FlatSymbolRefAttr ref;
+    if (parser.parseAttribute(ref, Type(), "spec_const", scratchAttrs))
+      return failure();
+    constituentRefs.push_back(ref);
+  } while (succeeded(parser.parseOptionalComma()));
+
+  if (parser.parseRParen())
+    return failure();
+  state.addAttribute(kCompositeSpecConstituentsName,
+                     parser.getBuilder().getArrayAttr(constituentRefs));
+
+  // Trailing `: composite-type`.
+  Type compositeType;
+  if (parser.parseColonType(compositeType))
+    return failure();
+  state.addAttribute(kTypeAttrName, TypeAttr::get(compositeType));
+
+  return success();
+}
+
+/// Prints `spv.specConstantComposite @name (@c0, @c1, ...) : type`, the form
+/// accepted by parseSpecConstantCompositeOp.
+static void print(spirv::SpecConstantCompositeOp op, OpAsmPrinter &printer) {
+  printer << spirv::SpecConstantCompositeOp::getOperationName() << " ";
+  printer.printSymbolName(op.sym_name());
+  printer << " (";
+  // interleaveComma prints nothing for an empty range, so no guard is needed.
+  llvm::interleaveComma(op.constituents().getValue(), printer);
+  printer << ") : " << op.type();
+}
+
+/// Verifies a spv.specConstantComposite op:
+///  - the declared type is a composite type other than a cooperative matrix;
+///  - the number of constituents equals the composite's element count;
+///  - every constituent is a flat symbol reference to a spv.specConstant
+///    whose default value has the corresponding element type.
+static LogicalResult verify(spirv::SpecConstantCompositeOp constOp) {
+  auto cType = constOp.type().dyn_cast<spirv::CompositeType>();
+  if (!cType)
+    return constOp.emitError(
+               "result type must be a composite type, but provided ")
+           << constOp.type();
+
+  if (cType.isa<spirv::CooperativeMatrixNVType>())
+    return constOp.emitError("unsupported composite type ") << cType;
+
+  auto constituents = constOp.constituents().getValue();
+  if (constituents.size() != cType.getNumElements())
+    return constOp.emitError("has incorrect number of operands: expected ")
+           << cType.getNumElements() << ", but provided "
+           << constituents.size();
+
+  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
+    // SymbolRefArrayAttr also admits nested references; reject them instead
+    // of dereferencing a null FlatSymbolRefAttr.
+    auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
+    if (!constituent)
+      return constOp.emitError("expects constituent #")
+             << index << " to be a flat symbol reference, but provided "
+             << constituents[index];
+
+    // The lookup yields null for an undefined symbol, and the symbol may name
+    // an op other than spv.specConstant; dyn_cast_or_null handles both cases
+    // without asserting.
+    auto constituentSpecConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(
+        SymbolTable::lookupNearestSymbolFrom(constOp.getParentOp(),
+                                             constituent.getValue()));
+    if (!constituentSpecConstOp)
+      return constOp.emitError("expects constituent #")
+             << index << " to reference a spv.specConstant, but provided "
+             << constituent;
+
+    Type elementType = cType.getElementType(index);
+    Type constituentType = constituentSpecConstOp.default_value().getType();
+    if (constituentType != elementType)
+      return constOp.emitError("has incorrect types of operands: expected ")
+             << elementType << ", but provided " << constituentType;
+  }
+
+  return success();
+}
+
 namespace mlir {
 namespace spirv {
 
diff  --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index b5eea4333824..153540ddb281 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -249,6 +249,8 @@ class Deserializer {
   /// `operands`.
   LogicalResult processConstantComposite(ArrayRef<uint32_t> operands);
 
+  /// Processes a SPIR-V OpSpecConstantComposite instruction with the given
+  /// `operands`.
+  LogicalResult processSpecConstantComposite(ArrayRef<uint32_t> operands);
+
   /// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
   LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
 
@@ -1546,6 +1548,39 @@ Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+/// Processes a SPIR-V OpSpecConstantComposite instruction.
+///
+/// The operand layout is:
+///  - `operands[0]`: the result type <id>;
+///  - `operands[1]`: the result <id>;
+///  - the remaining operands: one constituent <id> each.
+///
+/// Each constituent must currently resolve through getSpecConstant. The
+/// result is a spv.specConstantComposite op referring to those constituents
+/// by symbol.
+LogicalResult
+Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
+  if (operands.size() < 2) {
+    return emitError(
+        unknownLoc,
+        "OpSpecConstantComposite must have type <id> and result <id>");
+  }
+  if (operands.size() < 3) {
+    return emitError(unknownLoc,
+                     "OpSpecConstantComposite must have at least 1 parameter");
+  }
+
+  Type resultType = getType(operands[0]);
+  if (!resultType) {
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << operands[0];
+  }
+
+  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(operands[1]));
+
+  SmallVector<Attribute, 4> elements;
+  elements.reserve(operands.size() - 2);
+  for (uint32_t elementID : operands.drop_front(2)) {
+    // NOTE(review): getSpecConstant is presumably null for <id>s that are not
+    // known spec constants (e.g. regular constants or undef, which are not
+    // supported yet). Bail out instead of building a symbol ref from a null op.
+    auto elementInfo = getSpecConstant(elementID);
+    if (!elementInfo) {
+      return emitError(unknownLoc, "unknown specialization constant <id> ")
+             << elementID << " used as OpSpecConstantComposite constituent";
+    }
+    elements.push_back(opBuilder.getSymbolRefAttr(elementInfo));
+  }
+
+  opBuilder.create<spirv::SpecConstantCompositeOp>(
+      unknownLoc, TypeAttr::get(resultType), symName,
+      opBuilder.getArrayAttr(elements));
+
+  return success();
+}
+
 LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
   if (operands.size() != 2) {
     return emitError(unknownLoc,
@@ -2276,6 +2311,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
     return processConstant(operands, /*isSpec=*/true);
   case spirv::Opcode::OpConstantComposite:
     return processConstantComposite(operands);
+  case spirv::Opcode::OpSpecConstantComposite:
+    return processSpecConstantComposite(operands);
   case spirv::Opcode::OpConstantTrue:
     return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
   case spirv::Opcode::OpSpecConstantTrue:
diff  --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index 1eda166a0325..426c838a7e5d 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -200,6 +200,9 @@ class Serializer {
 
   LogicalResult processSpecConstantOp(spirv::SpecConstantOp op);
 
+  /// Serializes a spv.specConstantComposite op into an
+  /// OpSpecConstantComposite instruction.
+  LogicalResult
+  processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op);
+
   /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA
   /// value to use with other operations. The SPIR-V spec recommends that
   /// OpUndef be generated at module level. The serialization generates an
@@ -645,6 +648,42 @@ LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
   return failure();
 }
 
+/// Serializes `op` as an OpSpecConstantComposite instruction into the
+/// types/global-values section. It then records the new result <id> under the
+/// op's symbol name and emits an OpName for it.
+///
+/// Every constituent must be a flat symbol reference to a specialization
+/// constant whose <id> is already known to getSpecConstID.
+LogicalResult
+Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
+  uint32_t typeID = 0;
+  if (failed(processType(op.getLoc(), op.type(), typeID))) {
+    return failure();
+  }
+
+  auto resultID = getNextID();
+
+  ArrayRef<Attribute> constituents = op.constituents().getValue();
+  SmallVector<uint32_t, 8> operands;
+  operands.reserve(constituents.size() + 2);
+  operands.push_back(typeID);
+  operands.push_back(resultID);
+
+  for (Attribute attr : constituents) {
+    // SymbolRefArrayAttr also admits nested references; reject them instead
+    // of dereferencing a null FlatSymbolRefAttr.
+    auto constituent = attr.dyn_cast<FlatSymbolRefAttr>();
+    if (!constituent) {
+      return op.emitError("expected flat symbol reference as constituent, "
+                          "but provided ")
+             << attr;
+    }
+
+    auto constituentName = constituent.getValue();
+    auto constituentID = getSpecConstID(constituentName);
+
+    if (!constituentID) {
+      return op.emitError("unknown result <id> for specialization constant ")
+             << constituentName;
+    }
+
+    operands.push_back(constituentID);
+  }
+
+  encodeInstructionInto(typesGlobalValues,
+                        spirv::Opcode::OpSpecConstantComposite, operands);
+  specConstIDMap[op.sym_name()] = resultID;
+
+  return processName(resultID, op.sym_name());
+}
+
 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
   auto undefType = op.getType();
   auto &id = undefValIDMap[undefType];
@@ -1765,6 +1804,9 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
       .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
       .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
       .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
+      .Case([&](spirv::SpecConstantCompositeOp op) {
+        return processSpecConstantCompositeOp(op);
+      })
       .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
       .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
 
diff  --git a/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir b/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir
index 03cc85b8c087..0df930162c74 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
+// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
 
 spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   // CHECK: spv.specConstant @sc_true = true
@@ -25,3 +25,23 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     spv.ReturnValue %1 : i32
   }
 }
+
+// -----
+
+// Round-trip coverage for spv.specConstantComposite: one composite per
+// supported composite kind (array, struct, vector), built from scalar spec
+// constants.
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+
+  spv.specConstant @sc_f32_1 = 1.5 : f32
+  spv.specConstant @sc_f32_2 = 2.5 : f32
+  spv.specConstant @sc_f32_3 = 3.5 : f32
+
+  spv.specConstant @sc_i32_1 = 1   : i32
+
+  // CHECK: spv.specConstantComposite @scc_array (@sc_f32_1, @sc_f32_2, @sc_f32_3) : !spv.array<3 x f32>
+  spv.specConstantComposite @scc_array (@sc_f32_1, @sc_f32_2, @sc_f32_3) : !spv.array<3 x f32>
+
+  // CHECK: spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct<i32, f32, f32>
+  spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct<i32, f32, f32>
+
+  // The input uses `3 x f32`; the printer normalizes it to `3xf32`.
+  // CHECK: spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3xf32>
+  spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3 x f32>
+}
diff  --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir
index 98da480b83ff..765eba959a26 100644
--- a/mlir/test/Dialect/SPIRV/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir
@@ -596,3 +596,130 @@ func @use_in_function() -> () {
   spv.specConstant @sc = false
   return
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.specConstantComposite
+//===----------------------------------------------------------------------===//
+
+// A scalar result type is rejected.
+spv.module Logical GLSL450 {
+  // expected-error @+1 {{result type must be a composite type}}
+  spv.specConstantComposite @scc2 (@sc1, @sc2, @sc3) : i32
+}
+
+//===----------------------------------------------------------------------===//
+// spv.specConstantComposite (spv.array)
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// Well-formed array composite.
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1.5 : f32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // CHECK: spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.array<3 x f32>
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.array<3 x f32>
+}
+
+// -----
+
+// The array length must equal the number of constituents.
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = false
+  spv.specConstant @sc2 spec_id(5) = 42 : i64
+  spv.specConstant @sc3 = 1.5 : f32
+  // expected-error @+1 {{has incorrect number of operands: expected 4, but provided 3}}
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.array<4 x f32>
+
+}
+
+// -----
+
+// Each constituent's type must match the array element type.
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1   : i32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // expected-error @+1 {{has incorrect types of operands: expected 'f32', but provided 'i32'}}
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.array<3 x f32>
+}
+
+//===----------------------------------------------------------------------===//
+// spv.specConstantComposite (spv.struct)
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// Well-formed struct composite with mixed member types.
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1   : i32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // CHECK: spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<i32, f32, f32>
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<i32, f32, f32>
+}
+
+// -----
+
+// The struct member count must equal the number of constituents.
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1   : i32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // expected-error @+1 {{has incorrect number of operands: expected 2, but provided 3}}
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<i32, f32>
+}
+
+// -----
+
+// Each constituent's type must match the corresponding struct member type.
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1.5 : f32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // expected-error @+1 {{has incorrect types of operands: expected 'i32', but provided 'f32'}}
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<i32, f32, f32>
+}
+
+//===----------------------------------------------------------------------===//
+// spv.specConstantComposite (vector)
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// Well-formed vector composite.
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1.5 : f32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // CHECK: spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : vector<3xf32>
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : vector<3 x f32>
+}
+
+// -----
+
+// The vector length must equal the number of constituents.
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = false
+  spv.specConstant @sc2 spec_id(5) = 42 : i64
+  spv.specConstant @sc3 = 1.5 : f32
+  // expected-error @+1 {{has incorrect number of operands: expected 4, but provided 3}}
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : vector<4xf32>
+
+}
+
+// -----
+
+// Each constituent's type must match the vector element type.
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1   : i32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // expected-error @+1 {{has incorrect types of operands: expected 'f32', but provided 'i32'}}
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : vector<3xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spv.specConstantComposite (spv.coopmatrix)
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// Cooperative matrices are rejected as the composite type.
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1.5 : f32
+  // expected-error @+1 {{unsupported composite type}}
+  spv.specConstantComposite @scc (@sc1) : !spv.coopmatrix<8x16xf32, Device>
+}
        
    
    
More information about the Mlir-commits
mailing list