[Mlir-commits] [mlir] 21f1462 - [mlir][spirv] Add support for OpImageType
Lei Zhang
llvmlistbot at llvm.org
Mon Feb 1 12:01:48 PST 2021
Author: Lei Zhang
Date: 2021-02-01T15:01:31-05:00
New Revision: 21f1462106b9ee1e646bf409c85528828320b34e
URL: https://github.com/llvm/llvm-project/commit/21f1462106b9ee1e646bf409c85528828320b34e
DIFF: https://github.com/llvm/llvm-project/commit/21f1462106b9ee1e646bf409c85528828320b34e.diff
LOG: [mlir][spirv] Add support for OpImageType
Support OpTypeImage in the SPIR-V dialect.
This change doesn't support the AccessQualifier operand since
it is optional and only enabled under the Kernel capability.
co-authored-by: Alan Liu <alanliu.yf at gmail.com>
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D95580
Added:
mlir/test/Target/SPIRV/image.mlir
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 347b65a7739e..0c2d91133b27 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3157,6 +3157,7 @@ def SPV_OC_OpTypeInt : I32EnumAttrCase<"OpTypeInt", 21>;
def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>;
def SPV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>;
def SPV_OC_OpTypeMatrix : I32EnumAttrCase<"OpTypeMatrix", 24>;
+def SPV_OC_OpTypeImage : I32EnumAttrCase<"OpTypeImage", 25>;
def SPV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>;
def SPV_OC_OpTypeRuntimeArray : I32EnumAttrCase<"OpTypeRuntimeArray", 29>;
def SPV_OC_OpTypeStruct : I32EnumAttrCase<"OpTypeStruct", 30>;
@@ -3315,7 +3316,7 @@ def SPV_OpcodeAttr :
SPV_OC_OpLine, SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst,
SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode,
SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt,
- SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeMatrix,
+ SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeMatrix, SPV_OC_OpTypeImage,
SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct,
SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpTypeForwardPointer,
SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant,
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index b68c75039f66..58f2a7eed593 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -157,6 +157,7 @@ LogicalResult spirv::Deserializer::processInstruction(
case spirv::Opcode::OpTypeMatrix:
case spirv::Opcode::OpTypeArray:
case spirv::Opcode::OpTypeFunction:
+ case spirv::Opcode::OpTypeImage:
case spirv::Opcode::OpTypeRuntimeArray:
case spirv::Opcode::OpTypeStruct:
case spirv::Opcode::OpTypePointer:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 5ce169a0d47f..3b98c2efbb2a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -713,6 +713,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
return processCooperativeMatrixType(operands);
case spirv::Opcode::OpTypeFunction:
return processFunctionType(operands);
+ case spirv::Opcode::OpTypeImage:
+ return processImageType(operands);
case spirv::Opcode::OpTypeRuntimeArray:
return processRuntimeArrayType(operands);
case spirv::Opcode::OpTypeStruct:
@@ -1004,6 +1006,54 @@ spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult
+spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) { // Builds a spirv::ImageType from OpTypeImage operands and records it in typeMap.
+ // TODO: Add support for the optional Access Qualifier operand; until then only the 8-operand form is accepted.
+ if (operands.size() != 8) // operands[0]..operands[7] are all read below, so exactly 8 are required.
+ return emitError(
+ unknownLoc,
+ "OpTypeImage with non-eight operands are not supported yet");
+
+ Type elementTy = getType(operands[1]); // Sampled type <id>; must already have been deserialized.
+ if (!elementTy)
+ return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
+ << operands[1];
+
+ auto dim = spirv::symbolizeDim(operands[2]); // symbolize* yields no value for an unrecognized enum word; each such case is reported as an error.
+ if (!dim)
+ return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
+ << operands[2];
+
+ auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
+ if (!depthInfo)
+ return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
+ << operands[3];
+
+ auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
+ if (!arrayedInfo)
+ return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
+ << operands[4];
+
+ auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]); // The "MS" operand maps to ImageSamplingInfo.
+ if (!samplingInfo)
+ return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
+
+ auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]); // The "Sampled" operand maps to ImageSamplerUseInfo.
+ if (!samplerUseInfo)
+ return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
+ << operands[6];
+
+ auto format = spirv::symbolizeImageFormat(operands[7]);
+ if (!format)
+ return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
+ << operands[7];
+
+ typeMap[operands[0]] = spirv::ImageType::get( // Keyed by operands[0], the instruction's result <id>.
+ elementTy, dim.getValue(), depthInfo.getValue(), arrayedInfo.getValue(),
+ samplingInfo.getValue(), samplerUseInfo.getValue(), format.getValue());
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Constant
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index c2e75d0a62ba..54e7eb0381d0 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -273,6 +273,8 @@ class Deserializer {
LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
+ LogicalResult processImageType(ArrayRef<uint32_t> operands);
+
LogicalResult processRuntimeArrayType(ArrayRef<uint32_t> operands);
LogicalResult processStructType(ArrayRef<uint32_t> operands);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
index 5eea6f1752dd..c8421293cdce 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
@@ -1192,6 +1192,22 @@ LogicalResult Serializer::prepareBasicType(
return success();
}
+ if (auto imageType = type.dyn_cast<spirv::ImageType>()) {
+ typeEnum = spirv::Opcode::OpTypeImage;
+ uint32_t sampledTypeID = 0;
+ if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
+ return failure();
+
+ operands.push_back(sampledTypeID);
+ operands.push_back(static_cast<uint32_t>(imageType.getDim()));
+ operands.push_back(static_cast<uint32_t>(imageType.getDepthInfo()));
+ operands.push_back(static_cast<uint32_t>(imageType.getArrayedInfo()));
+ operands.push_back(static_cast<uint32_t>(imageType.getSamplingInfo()));
+ operands.push_back(static_cast<uint32_t>(imageType.getSamplerUseInfo()));
+ operands.push_back(static_cast<uint32_t>(imageType.getImageFormat()));
+ return success();
+ }
+
if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) {
typeEnum = spirv::Opcode::OpTypeArray;
uint32_t elementTypeID = 0;
diff --git a/mlir/test/Target/SPIRV/image.mlir b/mlir/test/Target/SPIRV/image.mlir
new file mode 100644
index 000000000000..ebcbe4c1a191
--- /dev/null
+++ b/mlir/test/Target/SPIRV/image.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> { // Round-trip fixture: each image type below is serialized, deserialized and re-printed.
+ // CHECK: !spv.ptr<!spv.image<f32, Dim1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>, UniformConstant>
+ spv.globalVariable @var0 bind(0, 1) : !spv.ptr<!spv.image<f32, Dim1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>, UniformConstant>
+
+ // CHECK: !spv.ptr<!spv.image<si32, Cube, IsDepth, NonArrayed, SingleSampled, NeedSampler, R8ui>, UniformConstant>
+ spv.globalVariable @var1 : !spv.ptr<!spv.image<si32, Cube, IsDepth, NonArrayed, SingleSampled, NeedSampler, R8ui>, UniformConstant>
+
+ // CHECK: !spv.ptr<!spv.image<i32, SubpassData, DepthUnknown, Arrayed, MultiSampled, NoSampler, Unknown>, UniformConstant>
+ spv.globalVariable @var2 : !spv.ptr<!spv.image<i32, SubpassData, DepthUnknown, Arrayed, MultiSampled, NoSampler, Unknown>, UniformConstant>
+}
More information about the Mlir-commits
mailing list