[Mlir-commits] [mlir] a40767e - [MLIR][SPIRV] Add (de-)serialization support for SpecConstantOperation.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jan 10 22:59:27 PST 2021


Author: ergawy
Date: 2021-01-11T07:37:50+01:00
New Revision: a40767ec8851b997e4dcc9987078bd02670f8c7f

URL: https://github.com/llvm/llvm-project/commit/a40767ec8851b997e4dcc9987078bd02670f8c7f
DIFF: https://github.com/llvm/llvm-project/commit/a40767ec8851b997e4dcc9987078bd02670f8c7f.diff

LOG: [MLIR][SPIRV] Add (de-)serialization support for SpecConstantOperation.

This commit adds support for (de-)serializing SpecConstantOperation. One
thing worth noting is that during deserialization, we assign a fake ID to
the op enclosed inside a SpecConstantOperation. We need to do this so that
the deserialization logic can properly update the ID-to-value map and
later reference the created value from the sibling 'spv::YieldOp'.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D93591

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Target/SPIRV/Deserialization.cpp
    mlir/lib/Target/SPIRV/Serialization.cpp
    mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
    mlir/test/Target/SPIRV/spec-constant.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 76374ca481fb..99b245563ca6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3170,6 +3170,7 @@ def SPV_OC_OpSpecConstantTrue          : I32EnumAttrCase<"OpSpecConstantTrue", 4
 def SPV_OC_OpSpecConstantFalse         : I32EnumAttrCase<"OpSpecConstantFalse", 49>;
 def SPV_OC_OpSpecConstant              : I32EnumAttrCase<"OpSpecConstant", 50>;
 def SPV_OC_OpSpecConstantComposite     : I32EnumAttrCase<"OpSpecConstantComposite", 51>;
+def SPV_OC_OpSpecConstantOperation     : I32EnumAttrCase<"OpSpecConstantOperation", 52>;
 def SPV_OC_OpFunction                  : I32EnumAttrCase<"OpFunction", 54>;
 def SPV_OC_OpFunctionParameter         : I32EnumAttrCase<"OpFunctionParameter", 55>;
 def SPV_OC_OpFunctionEnd               : I32EnumAttrCase<"OpFunctionEnd", 56>;
@@ -3314,7 +3315,8 @@ def SPV_OpcodeAttr :
       SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant,
       SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue,
       SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
-      SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
+      SPV_OC_OpSpecConstantComposite, SPV_OC_OpSpecConstantOperation,
+      SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
       SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
       SPV_OC_OpStore, SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
       SPV_OC_OpMemberDecorate, SPV_OC_OpVectorExtractDynamic,

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index ad3e78f618b7..c90895197f43 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -3445,9 +3445,9 @@ static LogicalResult verify(spirv::SpecConstantOperationOp constOp) {
     return constOp.emitOpError("invalid enclosed op");
 
   for (auto operand : enclosedOp.getOperands())
-    if (!isa<spirv::ConstantOp, spirv::SpecConstantOp,
-             spirv::SpecConstantCompositeOp, spirv::SpecConstantOperationOp>(
-            operand.getDefiningOp()))
+    if (!llvm::isa_and_nonnull<spirv::ConstantOp, spirv::ReferenceOfOp,
+                               spirv::SpecConstantOperationOp>(
+            operand.getDefiningOp()))
       return constOp.emitOpError(
           "invalid operand, must be defined by a constant operation");
 

diff  --git a/mlir/lib/Target/SPIRV/Deserialization.cpp b/mlir/lib/Target/SPIRV/Deserialization.cpp
index 30f46f6fc605..07eb3d35e0a4 100644
--- a/mlir/lib/Target/SPIRV/Deserialization.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Target/SPIRV/Deserialization.h"
 
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVModule.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
@@ -28,6 +29,7 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/bit.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/SaveAndRestore.h"
 #include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
@@ -132,6 +134,14 @@ struct DeferredStructTypeInfo {
   SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
 };
 
+/// Info recorded for an OpSpecConstantOperation so that it can be materialized
+/// (as a spv.SpecConstantOperation op) wherever its result <id> is used.
+struct SpecConstOperationMaterializationInfo {
+  spirv::Opcode enclosedOpcode;
+  uint32_t resultTypeID;
+  SmallVector<uint32_t> enclosedOpOperands;
+};
+
 //===----------------------------------------------------------------------===//
 // Deserializer Declaration
 //===----------------------------------------------------------------------===//
@@ -216,9 +226,14 @@ class Deserializer {
   /// Gets the constant's attribute and type associated with the given <id>.
   Optional<std::pair<Attribute, Type>> getConstant(uint32_t id);
 
-  /// Gets the constant's integer attribute with the given <id>. Returns a null
-  /// IntegerAttr if the given is not registered or does not correspond to an
-  /// integer constant.
+  /// Gets the info needed to materialize the spec constant operation op
+  /// associated with the given <id>.
+  Optional<SpecConstOperationMaterializationInfo>
+  getSpecConstantOperation(uint32_t id);
+
+  /// Gets the constant's integer attribute with the given <id>. Returns a
+  /// null IntegerAttr if the given is not registered or does not correspond
+  /// to an integer constant.
   IntegerAttr getConstantInt(uint32_t id);
 
   /// Returns a symbol to be used for the function name with the given
@@ -305,8 +320,20 @@ class Deserializer {
   /// `operands`.
   LogicalResult processConstantComposite(ArrayRef<uint32_t> operands);
 
+  /// Processes a SPIR-V OpSpecConstantComposite instruction with the given
+  /// `operands`.
   LogicalResult processSpecConstantComposite(ArrayRef<uint32_t> operands);
 
+  /// Processes a SPIR-V OpSpecConstantOperation instruction with the given
+  /// `operands`.
+  LogicalResult processSpecConstantOperation(ArrayRef<uint32_t> operands);
+
+  /// Materializes/emits an OpSpecConstantOperation instruction.
+  Value materializeSpecConstantOperation(uint32_t resultID,
+                                         spirv::Opcode enclosedOpcode,
+                                         uint32_t resultTypeID,
+                                         ArrayRef<uint32_t> enclosedOpOperands);
+
   /// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
   LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
 
@@ -534,6 +561,11 @@ class Deserializer {
   // Result <id> to composite spec constant mapping.
   DenseMap<uint32_t, spirv::SpecConstantCompositeOp> specConstCompositeMap;
 
+  /// Mapping from result <id> to the info needed to materialize an
+  /// OpSpecConstantOperation.
+  DenseMap<uint32_t, SpecConstOperationMaterializationInfo>
+      specConstOperationMap;
+
   // Result <id> to variable mapping.
   DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
 
@@ -1036,6 +1068,14 @@ Optional<std::pair<Attribute, Type>> Deserializer::getConstant(uint32_t id) {
   return constIt->getSecond();
 }
 
+Optional<SpecConstOperationMaterializationInfo>
+Deserializer::getSpecConstantOperation(uint32_t id) {
+  auto constIt = specConstOperationMap.find(id);
+  if (constIt == specConstOperationMap.end())
+    return llvm::None;
+  return constIt->getSecond();
+}
+
 std::string Deserializer::getFunctionSymbol(uint32_t id) {
   auto funcName = nameMap.lookup(id).str();
   if (funcName.empty()) {
@@ -1745,6 +1785,105 @@ Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+LogicalResult
+Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
+  if (operands.size() < 3)
+    return emitError(unknownLoc, "OpSpecConstantOperation must have type <id>, "
+                                 "result <id>, and operand opcode");
+
+  uint32_t resultTypeID = operands[0];
+
+  if (!getType(resultTypeID))
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << resultTypeID;
+
+  uint32_t resultID = operands[1];
+  spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
+  auto emplaceResult = specConstOperationMap.try_emplace(
+      resultID,
+      SpecConstOperationMaterializationInfo{
+          enclosedOpcode, resultTypeID,
+          SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
+
+  if (!emplaceResult.second)
+    return emitError(unknownLoc, "value with <id>: ")
+           << resultID << " is probably defined before.";
+
+  return success();
+}
+
+Value Deserializer::materializeSpecConstantOperation(
+    uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
+    ArrayRef<uint32_t> enclosedOpOperands) {
+
+  Type resultType = getType(resultTypeID);
+
+  // Instructions wrapped by OpSpecConstantOperation need an ID for their
+  // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
+  // dialect wrapped op. For that purpose, a new value map is created and "fake"
+  // ID in that map is assigned to the result of the enclosed instruction. Note
+  // that there is no need to update this fake ID since we only need to
+  // reference the created Value for the enclosed op from the spv::YieldOp
+  // created later in this method (both of which are the only values in their
+  // region: the SpecConstantOperation's region). If we encounter another
+  // SpecConstantOperation in the module, we simply re-use the fake ID since the
+  // previous Value assigned to it isn't visible in the current scope anyway.
+  DenseMap<uint32_t, Value> newValueMap;
+  llvm::SaveAndRestore<DenseMap<uint32_t, Value>> valueMapGuard(valueMap,
+                                                                newValueMap);
+  constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
+
+  SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
+  enclosedOpResultTypeAndOperands.push_back(resultTypeID);
+  enclosedOpResultTypeAndOperands.push_back(fakeID);
+  enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
+                                         enclosedOpOperands.end());
+
+  // Process enclosed instruction before creating the enclosing
+  // specConstantOperation (and its region). This way, references to constants,
+  // global variables, and spec constants will be materialized outside the new
+  // op's region. For more info, see Deserializer::getValue's implementation.
+  if (failed(
+          processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
+    return Value();
+
+  // The enclosed instruction must have defined the fake ID and the op that
+  // defines it must be the last op in the current block (materialized
+  // operands, if any, precede it). Bail out otherwise rather than move an
+  // unrelated op into the SpecConstantOperation's region.
+  Value enclosedValue = valueMap.lookup(fakeID);
+  Operation *enclosedOp =
+      enclosedValue ? enclosedValue.getDefiningOp() : nullptr;
+  if (!enclosedOp || enclosedOp->getBlock() != curBlock ||
+      enclosedOp != &curBlock->back()) {
+    emitError(unknownLoc, "invalid enclosed instruction in "
+                          "OpSpecConstantOperation");
+    return Value();
+  }
+
+  // Since the enclosed op is emitted in the current block, split it in a
+  // separate new block.
+  Block *enclosedBlock = curBlock->splitBlock(enclosedOp);
+
+  auto loc = createFileLineColLoc(opBuilder);
+  auto specConstOperationOp =
+      opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
+
+  Region &body = specConstOperationOp.body();
+  // Move the new block into SpecConstantOperation's body.
+  body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
+                          Region::iterator(enclosedBlock));
+  Block &block = body.back();
+
+  // RAII guard to restore the builder's previous insertion point after the
+  // terminating spv::YieldOp is emitted into the SpecConstantOperation's body.
+  OpBuilder::InsertionGuard insertionGuard(opBuilder);
+  opBuilder.setInsertionPointToEnd(&block);
+
+  opBuilder.create<spirv::YieldOp>(loc, enclosedValue);
+  return specConstOperationOp.getResult();
+}
+
 LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
   if (operands.size() != 2) {
     return emitError(unknownLoc,
@@ -2378,6 +2517,12 @@ Value Deserializer::getValue(uint32_t id) {
         opBuilder.getSymbolRefAttr(constCompositeOp.getOperation()));
     return referenceOfOp.reference();
   }
+  if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
+    return materializeSpecConstantOperation(
+        id, specConstOperationInfo->enclosedOpcode,
+        specConstOperationInfo->resultTypeID,
+        specConstOperationInfo->enclosedOpOperands);
+  }
   if (auto undef = getUndefType(id)) {
     return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
   }
@@ -2483,6 +2628,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
     return processConstantComposite(operands);
   case spirv::Opcode::OpSpecConstantComposite:
     return processSpecConstantComposite(operands);
+  case spirv::Opcode::OpSpecConstantOperation:
+    return processSpecConstantOperation(operands);
   case spirv::Opcode::OpConstantTrue:
     return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
   case spirv::Opcode::OpSpecConstantTrue:

diff  --git a/mlir/lib/Target/SPIRV/Serialization.cpp b/mlir/lib/Target/SPIRV/Serialization.cpp
index 227008151f87..fae509ff63f9 100644
--- a/mlir/lib/Target/SPIRV/Serialization.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization.cpp
@@ -204,6 +204,9 @@ class Serializer {
   LogicalResult
   processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op);
 
+  LogicalResult
+  processSpecConstantOperationOp(spirv::SpecConstantOperationOp op);
+
   /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA
   /// value to use with other operations. The SPIR-V spec recommends that
   /// OpUndef be generated at module level. The serialization generates an
@@ -711,6 +714,49 @@ Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
   return processName(resultID, op.sym_name());
 }
 
+LogicalResult
+Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
+  uint32_t typeID = 0;
+  if (failed(processType(op.getLoc(), op.getType(), typeID))) {
+    return failure();
+  }
+
+  auto resultID = getNextID();
+
+  SmallVector<uint32_t, 8> operands;
+  operands.push_back(typeID);
+  operands.push_back(resultID);
+
+  Block &block = op.getRegion().getBlocks().front();
+  Operation &enclosedOp = block.getOperations().front();
+
+  std::string enclosedOpName;
+  llvm::raw_string_ostream rss(enclosedOpName);
+  rss << "Op" << enclosedOp.getName().stripDialect();
+  auto enclosedOpcode = spirv::symbolizeOpcode(rss.str());
+
+  if (!enclosedOpcode) {
+    op.emitError("Couldn't find op code for op ")
+        << enclosedOp.getName().getStringRef();
+    return failure();
+  }
+
+  operands.push_back(static_cast<uint32_t>(enclosedOpcode.getValue()));
+
+  // Append the enclosed op's operand <id>s to the instruction's operands.
+  for (Value operand : enclosedOp.getOperands()) {
+    uint32_t id = getValueID(operand);
+    assert(id && "use before def!");
+    operands.push_back(id);
+  }
+
+  encodeInstructionInto(typesGlobalValues,
+                        spirv::Opcode::OpSpecConstantOperation, operands);
+  valueIDMap[op.getResult()] = resultID;
+
+  return success();
+}
+
 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
   auto undefType = op.getType();
   auto &id = undefValIDMap[undefType];
@@ -1929,6 +1975,9 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
       .Case([&](spirv::SpecConstantCompositeOp op) {
         return processSpecConstantCompositeOp(op);
       })
+      .Case([&](spirv::SpecConstantOperationOp op) {
+        return processSpecConstantOperationOp(op);
+      })
       .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
       .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
 

diff  --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index c0b495115d6c..c3f715f06ae2 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -780,6 +780,20 @@ spv.module Logical GLSL450 {
 
 // -----
 
+spv.module Logical GLSL450 {
+  spv.specConstant @sc = 42 : i32
+
+  spv.func @foo() -> i32 "None" {
+    // CHECK: [[SC:%.*]] = spv.mlir.referenceof @sc
+    %0 = spv.mlir.referenceof @sc : i32
+    // CHECK: spv.SpecConstantOperation wraps "spv.ISub"([[SC]], [[SC]]) : (i32, i32) -> i32
+    %1 = spv.SpecConstantOperation wraps "spv.ISub"(%0, %0) : (i32, i32) -> i32
+    spv.ReturnValue %1 : i32
+  }
+}
+
+// -----
+
 spv.module Logical GLSL450 {
   spv.func @foo() -> i32 "None" {
     %0 = spv.constant 1: i32

diff  --git a/mlir/test/Target/SPIRV/spec-constant.mlir b/mlir/test/Target/SPIRV/spec-constant.mlir
index 54b6e2a2eb12..88dc64d035be 100644
--- a/mlir/test/Target/SPIRV/spec-constant.mlir
+++ b/mlir/test/Target/SPIRV/spec-constant.mlir
@@ -85,3 +85,34 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   // CHECK: spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3xf32>
   spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3 x f32>
 }
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+
+  spv.specConstant @sc_i32_1 = 1 : i32
+
+  spv.func @use_spec_constant_operation() -> (i32) "None" {
+    // CHECK: [[USE1:%.*]] = spv.mlir.referenceof @sc_i32_1 : i32
+    // CHECK: [[USE2:%.*]] = spv.constant 0 : i32
+
+    // CHECK: [[RES1:%.*]] = spv.SpecConstantOperation wraps "spv.ISub"([[USE1]], [[USE2]]) : (i32, i32) -> i32
+
+    // CHECK: [[USE3:%.*]] = spv.mlir.referenceof @sc_i32_1 : i32
+    // CHECK: [[USE4:%.*]] = spv.constant 0 : i32
+
+    // CHECK: [[RES2:%.*]] = spv.SpecConstantOperation wraps "spv.ISub"([[USE3]], [[USE4]]) : (i32, i32) -> i32
+
+    %0 = spv.mlir.referenceof @sc_i32_1 : i32
+    %1 = spv.constant 0 : i32
+    %2 = spv.SpecConstantOperation wraps "spv.ISub"(%0, %1) : (i32, i32) -> i32
+
+    // CHECK: [[RES3:%.*]] = spv.SpecConstantOperation wraps "spv.IMul"([[RES1]], [[RES2]]) : (i32, i32) -> i32
+    %3 = spv.SpecConstantOperation wraps "spv.IMul"(%2, %2) : (i32, i32) -> i32
+
+    // Make sure deserialization continues from the right place after creating
+    // the previous op.
+    // CHECK: spv.ReturnValue [[RES3]]
+    spv.ReturnValue %3 : i32
+  }
+}


        


More information about the Mlir-commits mailing list