[Mlir-commits] [mlir] 21f1462 - [mlir][spirv] Add support for OpImageType

Lei Zhang llvmlistbot at llvm.org
Mon Feb 1 12:01:48 PST 2021


Author: Lei Zhang
Date: 2021-02-01T15:01:31-05:00
New Revision: 21f1462106b9ee1e646bf409c85528828320b34e

URL: https://github.com/llvm/llvm-project/commit/21f1462106b9ee1e646bf409c85528828320b34e
DIFF: https://github.com/llvm/llvm-project/commit/21f1462106b9ee1e646bf409c85528828320b34e.diff

LOG: [mlir][spirv] Add support for OpTypeImage

Support OpTypeImage in the SPIR-V dialect, including serialization and deserialization of spv.image types.

This change doesn't support the AccessQualifier operand since
it is optional and only enabled under the Kernel capability.

co-authored-by: Alan Liu <alanliu.yf at gmail.com>

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D95580

Added: 
    mlir/test/Target/SPIRV/image.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
    mlir/lib/Target/SPIRV/Serialization/Serialization.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 347b65a7739e..0c2d91133b27 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3157,6 +3157,7 @@ def SPV_OC_OpTypeInt                   : I32EnumAttrCase<"OpTypeInt", 21>;
 def SPV_OC_OpTypeFloat                 : I32EnumAttrCase<"OpTypeFloat", 22>;
 def SPV_OC_OpTypeVector                : I32EnumAttrCase<"OpTypeVector", 23>;
 def SPV_OC_OpTypeMatrix                : I32EnumAttrCase<"OpTypeMatrix", 24>;
+def SPV_OC_OpTypeImage                 : I32EnumAttrCase<"OpTypeImage", 25>;
 def SPV_OC_OpTypeArray                 : I32EnumAttrCase<"OpTypeArray", 28>;
 def SPV_OC_OpTypeRuntimeArray          : I32EnumAttrCase<"OpTypeRuntimeArray", 29>;
 def SPV_OC_OpTypeStruct                : I32EnumAttrCase<"OpTypeStruct", 30>;
@@ -3315,7 +3316,7 @@ def SPV_OpcodeAttr :
       SPV_OC_OpLine, SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst,
       SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode,
       SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt,
-      SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeMatrix,
+      SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeMatrix, SPV_OC_OpTypeImage,
       SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct,
       SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpTypeForwardPointer,
       SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant,

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index b68c75039f66..58f2a7eed593 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -157,6 +157,7 @@ LogicalResult spirv::Deserializer::processInstruction(
   case spirv::Opcode::OpTypeMatrix:
   case spirv::Opcode::OpTypeArray:
   case spirv::Opcode::OpTypeFunction:
+  case spirv::Opcode::OpTypeImage:
   case spirv::Opcode::OpTypeRuntimeArray:
   case spirv::Opcode::OpTypeStruct:
   case spirv::Opcode::OpTypePointer:

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 5ce169a0d47f..3b98c2efbb2a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -713,6 +713,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
     return processCooperativeMatrixType(operands);
   case spirv::Opcode::OpTypeFunction:
     return processFunctionType(operands);
+  case spirv::Opcode::OpTypeImage:
+    return processImageType(operands);
   case spirv::Opcode::OpTypeRuntimeArray:
     return processRuntimeArrayType(operands);
   case spirv::Opcode::OpTypeStruct:
@@ -1004,6 +1006,66 @@ spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+/// Deserializes an OpTypeImage instruction into a spirv::ImageType and
+/// records it in `typeMap` under the instruction's result <id>.
+///
+/// `operands` must hold exactly eight words, in this order: result <id>,
+/// Sampled Type <id>, Dim, Depth, Arrayed, MS, Sampled, and Image Format.
+/// Emits an error and returns failure if the operand count differs, if the
+/// Sampled Type <id> does not resolve to a known type, or if any enumerant
+/// word cannot be symbolized.
+LogicalResult
+spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
+  // TODO: Add support for Access Qualifier, the optional ninth operand. Until
+  // then, any operand count other than eight is rejected.
+  if (operands.size() != 8)
+    return emitError(
+        unknownLoc,
+        "OpTypeImage with non-eight operands are not supported yet");
+
+  // Sampled Type <id>: becomes the image's element type; unresolved is an error.
+  Type elementTy = getType(operands[1]);
+  if (!elementTy)
+    return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
+           << operands[1];
+
+  auto dim = spirv::symbolizeDim(operands[2]);
+  if (!dim)
+    return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
+           << operands[2];
+
+  auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
+  if (!depthInfo)
+    return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
+           << operands[3];
+
+  auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
+  if (!arrayedInfo)
+    return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
+           << operands[4];
+
+  // The SPIR-V "MS" operand is modeled as ImageSamplingInfo.
+  auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
+  if (!samplingInfo)
+    return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
+
+  // The SPIR-V "Sampled" operand is modeled as ImageSamplerUseInfo.
+  auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
+  if (!samplerUseInfo)
+    return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
+           << operands[6];
+
+  auto format = spirv::symbolizeImageFormat(operands[7]);
+  if (!format)
+    return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
+           << operands[7];
+
+  typeMap[operands[0]] = spirv::ImageType::get(
+      elementTy, dim.getValue(), depthInfo.getValue(), arrayedInfo.getValue(),
+      samplingInfo.getValue(), samplerUseInfo.getValue(), format.getValue());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Constant
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index c2e75d0a62ba..54e7eb0381d0 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -273,6 +273,8 @@ class Deserializer {
 
   LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
 
+  LogicalResult processImageType(ArrayRef<uint32_t> operands);
+
   LogicalResult processRuntimeArrayType(ArrayRef<uint32_t> operands);
 
   LogicalResult processStructType(ArrayRef<uint32_t> operands);

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
index 5eea6f1752dd..c8421293cdce 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
@@ -1192,6 +1192,22 @@ LogicalResult Serializer::prepareBasicType(
     return success();
   }
 
+  if (auto imageType = type.dyn_cast<spirv::ImageType>()) {
+    typeEnum = spirv::Opcode::OpTypeImage;
+    uint32_t sampledTypeID = 0;
+    if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
+      return failure();
+
+    operands.push_back(sampledTypeID);
+    operands.push_back(static_cast<uint32_t>(imageType.getDim()));
+    operands.push_back(static_cast<uint32_t>(imageType.getDepthInfo()));
+    operands.push_back(static_cast<uint32_t>(imageType.getArrayedInfo()));
+    operands.push_back(static_cast<uint32_t>(imageType.getSamplingInfo()));
+    operands.push_back(static_cast<uint32_t>(imageType.getSamplerUseInfo()));
+    operands.push_back(static_cast<uint32_t>(imageType.getImageFormat()));
+    return success();
+  }
+
   if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) {
     typeEnum = spirv::Opcode::OpTypeArray;
     uint32_t elementTypeID = 0;

diff  --git a/mlir/test/Target/SPIRV/image.mlir b/mlir/test/Target/SPIRV/image.mlir
new file mode 100644
index 000000000000..ebcbe4c1a191
--- /dev/null
+++ b/mlir/test/Target/SPIRV/image.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
+
+// Checks that spv.image types, used here as pointee types of global
+// variables, survive a SPIR-V serialization/deserialization round trip.
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+  // CHECK: !spv.ptr<!spv.image<f32, Dim1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>, UniformConstant>
+  spv.globalVariable @var0 bind(0, 1) : !spv.ptr<!spv.image<f32, Dim1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>, UniformConstant>
+
+  // CHECK: !spv.ptr<!spv.image<si32, Cube, IsDepth, NonArrayed, SingleSampled, NeedSampler, R8ui>, UniformConstant>
+  spv.globalVariable @var1 : !spv.ptr<!spv.image<si32, Cube, IsDepth, NonArrayed, SingleSampled, NeedSampler, R8ui>, UniformConstant>
+
+  // CHECK: !spv.ptr<!spv.image<i32, SubpassData, DepthUnknown, Arrayed, MultiSampled, NoSampler, Unknown>, UniformConstant>
+  spv.globalVariable @var2 : !spv.ptr<!spv.image<i32, SubpassData, DepthUnknown, Arrayed, MultiSampled, NoSampler, Unknown>, UniformConstant>
+}


        


More information about the Mlir-commits mailing list