[Mlir-commits] [mlir] [mlir][spirv] Restructure code in `SPIRVConversion.cpp` (PR #99393)
Jakub Kuderski
llvmlistbot at llvm.org
Wed Jul 17 15:03:57 PDT 2024
================
@@ -731,73 +715,134 @@ std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
}
//===----------------------------------------------------------------------===//
-// SPIRVTypeConverter
+// Builtin Variables
//===----------------------------------------------------------------------===//
-SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
- const SPIRVConversionOptions &options)
- : targetEnv(targetAttr), options(options) {
- // Add conversions. The order matters here: later ones will be tried earlier.
+spirv::GlobalVariableOp getBuiltinVariable(Block &body,
+ spirv::BuiltIn builtin) {
+ // Look through all global variables in the given `body` block and check if
+ // there is a spirv.GlobalVariable that has the same `builtin` attribute.
+ for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
+ if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
+ spirv::SPIRVDialect::getAttributeName(
+ spirv::Decoration::BuiltIn))) {
+ auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
+ if (varBuiltIn && *varBuiltIn == builtin) {
+ return varOp;
+ }
+ }
+ }
+ return nullptr;
+}
- // Allow all SPIR-V dialect specific types. This assumes all builtin types
- // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
- // were tried before.
- //
- // TODO: This assumes that the SPIR-V types are valid to use in the given
- // target environment, which should be the case if the whole pipeline is
- // driven by the same target environment. Still, we probably still want to
- // validate and convert to be safe.
- addConversion([](spirv::SPIRVType type) { return type; });
+/// Gets the name of the global variable for a builtin.
+std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
+ StringRef suffix) {
+ return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
+}
- addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
+/// Gets or inserts a global variable for a builtin within `body` block.
+spirv::GlobalVariableOp
+getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
+ Type integerType, OpBuilder &builder,
+ StringRef prefix, StringRef suffix) {
+ if (auto varOp = getBuiltinVariable(body, builtin))
+ return varOp;
- addConversion([this](IntegerType intType) -> std::optional<Type> {
- if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
- return convertScalarType(this->targetEnv, this->options, scalarType);
- if (intType.getWidth() < 8)
- return convertSubByteIntegerType(this->options, intType);
- return Type();
- });
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(&body);
- addConversion([this](FloatType floatType) -> std::optional<Type> {
- if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
- return convertScalarType(this->targetEnv, this->options, scalarType);
- return Type();
- });
+ spirv::GlobalVariableOp newVarOp;
+ switch (builtin) {
+ case spirv::BuiltIn::NumWorkgroups:
+ case spirv::BuiltIn::WorkgroupSize:
+ case spirv::BuiltIn::WorkgroupId:
+ case spirv::BuiltIn::LocalInvocationId:
+ case spirv::BuiltIn::GlobalInvocationId: {
+ auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
+ spirv::StorageClass::Input);
+ std::string name = getBuiltinVarName(builtin, prefix, suffix);
+ newVarOp =
+ builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
+ break;
+ }
+ case spirv::BuiltIn::SubgroupId:
+ case spirv::BuiltIn::NumSubgroups:
+ case spirv::BuiltIn::SubgroupSize: {
+ auto ptrType =
+ spirv::PointerType::get(integerType, spirv::StorageClass::Input);
+ std::string name = getBuiltinVarName(builtin, prefix, suffix);
+ newVarOp =
+ builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
+ break;
+ }
+ default:
+ emitError(loc, "unimplemented builtin variable generation for ")
+ << stringifyBuiltIn(builtin);
+ }
+ return newVarOp;
+}
- addConversion([this](ComplexType complexType) {
- return convertComplexType(this->targetEnv, this->options, complexType);
- });
+//===----------------------------------------------------------------------===//
+// Push constant storage
+//===----------------------------------------------------------------------===//
- addConversion([this](VectorType vectorType) {
- return convertVectorType(this->targetEnv, this->options, vectorType);
- });
+/// Returns the pointer type for the push constant storage containing
+/// `elementCount` 32-bit integer values.
+spirv::PointerType getPushConstantStorageType(unsigned elementCount,
+ Builder &builder,
+ Type indexType) {
+ auto arrayType = spirv::ArrayType::get(indexType, elementCount,
+ /*stride=*/4);
+ auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
+ return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
+}
- addConversion([this](TensorType tensorType) {
- return convertTensorType(this->targetEnv, this->options, tensorType);
- });
+/// Returns the push constant variable containing `elementCount` 32-bit integer
+/// values in `body`. Returns a null op if no such variable exists.
+spirv::GlobalVariableOp getPushConstantVariable(Block &body,
+ unsigned elementCount) {
+ for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
+ auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
+ if (!ptrType)
+ continue;
- addConversion([this](MemRefType memRefType) {
- return convertMemrefType(this->targetEnv, this->options, memRefType);
- });
+ // Note that Vulkan requires "There must be no more than one push constant
+ // block statically used per shader entry point." So we should always reuse
+ // the existing one.
+ if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
+ auto numElements = cast<spirv::ArrayType>(
+ cast<spirv::StructType>(ptrType.getPointeeType())
+ .getElementType(0))
+ .getNumElements();
+ if (numElements == elementCount)
+ return varOp;
+ }
+ }
+ return nullptr;
+}
- // Register some last line of defense casting logic.
- addSourceMaterialization(
- [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
- return castToSourceType(this->targetEnv, builder, type, inputs, loc);
- });
- addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
- Location loc) {
- auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
- return std::optional<Value>(cast.getResult(0));
- });
+/// Gets or inserts a global variable for push constant storage containing
+/// `elementCount` 32-bit integer values in `block`.
+spirv::GlobalVariableOp getOrInsertPushConstantVariable(Location loc,
+ Block &block,
+ unsigned elementCount,
+ OpBuilder &b,
+ Type indexType) {
+ if (auto varOp = getPushConstantVariable(block, elementCount))
+ return varOp;
+
+ auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
+ auto type = getPushConstantStorageType(elementCount, builder, indexType);
+ const char *name = "__push_constant_var__";
+ return builder.create<spirv::GlobalVariableOp>(loc, type, name,
+ /*initializer=*/nullptr);
}
//===----------------------------------------------------------------------===//
// func::FuncOp Conversion Patterns
//===----------------------------------------------------------------------===//
-namespace {
/// A pattern for rewriting function signature to convert arguments of functions
/// to be of valid SPIR-V types.
class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
----------------
kuhar wrote:
```suggestion
struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
```
and drop the `public:` access specifier below. Also, please inline the implementation of `matchAndRewrite`.
https://github.com/llvm/llvm-project/pull/99393
More information about the Mlir-commits
mailing list