[Mlir-commits] [mlir] [mlir][spirv] Restructure code in `SPIRVConversion.cpp` (PR #99393)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 17 14:37:53 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Angel Zhang (angelz913)

<details>
<summary>Changes</summary>



---

Patch is 31.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/99393.diff


1 Files Affected:

- (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+245-244) 


``````````diff
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index e3a09ef1ff684..710e39692471a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -40,11 +40,13 @@
 
 using namespace mlir;
 
+namespace {
+
 //===----------------------------------------------------------------------===//
 // Utility functions
 //===----------------------------------------------------------------------===//
 
-static int getComputeVectorSize(int64_t size) {
+int getComputeVectorSize(int64_t size) {
   for (int i : {4, 3, 2}) {
     if (size % i == 0)
       return i;
@@ -52,7 +54,7 @@ static int getComputeVectorSize(int64_t size) {
   return 1;
 }
 
-static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
+std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
   LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
   if (vecType.isScalable()) {
     LLVM_DEBUG(llvm::dbgs()
@@ -88,7 +90,7 @@ static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
 /// convention.
 template <typename LabelT>
-static LogicalResult checkExtensionRequirements(
+LogicalResult checkExtensionRequirements(
     LabelT label, const spirv::TargetEnv &targetEnv,
     const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
   for (const auto &ors : candidates) {
@@ -116,7 +118,7 @@ static LogicalResult checkExtensionRequirements(
 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
 /// convention.
 template <typename LabelT>
-static LogicalResult checkCapabilityRequirements(
+LogicalResult checkCapabilityRequirements(
     LabelT label, const spirv::TargetEnv &targetEnv,
     const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
   for (const auto &ors : candidates) {
@@ -139,7 +141,7 @@ static LogicalResult checkCapabilityRequirements(
 
 /// Returns true if the given `storageClass` needs explicit layout when used in
 /// Shader environments.
-static bool needsExplicitLayout(spirv::StorageClass storageClass) {
+bool needsExplicitLayout(spirv::StorageClass storageClass) {
   switch (storageClass) {
   case spirv::StorageClass::PhysicalStorageBuffer:
   case spirv::StorageClass::PushConstant:
@@ -153,8 +155,8 @@ static bool needsExplicitLayout(spirv::StorageClass storageClass) {
 
 /// Wraps the given `elementType` in a struct and gets the pointer to the
 /// struct. This is used to satisfy Vulkan interface requirements.
-static spirv::PointerType
-wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
+spirv::PointerType wrapInStructAndGetPointer(Type elementType,
+                                             spirv::StorageClass storageClass) {
   auto structType = needsExplicitLayout(storageClass)
                         ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
                         : spirv::StructType::get(elementType);
@@ -165,28 +167,16 @@ wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
 // Type Conversion
 //===----------------------------------------------------------------------===//
 
-static spirv::ScalarType getIndexType(MLIRContext *ctx,
-                                      const SPIRVConversionOptions &options) {
+spirv::ScalarType getIndexType(MLIRContext *ctx,
+                               const SPIRVConversionOptions &options) {
   return cast<spirv::ScalarType>(
       IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
 }
 
-Type SPIRVTypeConverter::getIndexType() const {
-  return ::getIndexType(getContext(), options);
-}
-
-MLIRContext *SPIRVTypeConverter::getContext() const {
-  return targetEnv.getAttr().getContext();
-}
-
-bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
-  return targetEnv.allows(capability);
-}
-
 // TODO: This is a utility function that should probably be exposed by the
 // SPIR-V dialect. Keeping it local till the use case arises.
-static std::optional<int64_t>
-getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
+std::optional<int64_t> getTypeNumBytes(const SPIRVConversionOptions &options,
+                                       Type type) {
   if (isa<spirv::ScalarType>(type)) {
     auto bitWidth = type.getIntOrFloatBitWidth();
     // According to the SPIR-V spec:
@@ -266,10 +256,10 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
 }
 
 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
-static Type
-convertScalarType(const spirv::TargetEnv &targetEnv,
-                  const SPIRVConversionOptions &options, spirv::ScalarType type,
-                  std::optional<spirv::StorageClass> storageClass = {}) {
+Type convertScalarType(const spirv::TargetEnv &targetEnv,
+                       const SPIRVConversionOptions &options,
+                       spirv::ScalarType type,
+                       std::optional<spirv::StorageClass> storageClass = {}) {
   // Get extension and capability requirements for the given type.
   SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
   SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
@@ -311,8 +301,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
 /// the above given that these sub-byte types are not supported at all in
 /// SPIR-V; there are no compute/storage capability for them like other
 /// supported integer types.
-static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
-                                      IntegerType type) {
+Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
+                               IntegerType type) {
   if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
     LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
     return nullptr;
@@ -333,9 +323,8 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
 /// Returns a type with the same shape but with any index element type converted
 /// to the matching integer type. This is a noop when the element type is not
 /// the index type.
-static ShapedType
-convertIndexElementType(ShapedType type,
-                        const SPIRVConversionOptions &options) {
+ShapedType convertIndexElementType(ShapedType type,
+                                   const SPIRVConversionOptions &options) {
   Type indexType = dyn_cast<IndexType>(type.getElementType());
   if (!indexType)
     return type;
@@ -344,10 +333,9 @@ convertIndexElementType(ShapedType type,
 }
 
 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
-static Type
-convertVectorType(const spirv::TargetEnv &targetEnv,
-                  const SPIRVConversionOptions &options, VectorType type,
-                  std::optional<spirv::StorageClass> storageClass = {}) {
+Type convertVectorType(const spirv::TargetEnv &targetEnv,
+                       const SPIRVConversionOptions &options, VectorType type,
+                       std::optional<spirv::StorageClass> storageClass = {}) {
   type = cast<VectorType>(convertIndexElementType(type, options));
   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
   if (!scalarType) {
@@ -401,10 +389,9 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
   return nullptr;
 }
 
-static Type
-convertComplexType(const spirv::TargetEnv &targetEnv,
-                   const SPIRVConversionOptions &options, ComplexType type,
-                   std::optional<spirv::StorageClass> storageClass = {}) {
+Type convertComplexType(const spirv::TargetEnv &targetEnv,
+                        const SPIRVConversionOptions &options, ComplexType type,
+                        std::optional<spirv::StorageClass> storageClass = {}) {
   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
   if (!scalarType) {
     LLVM_DEBUG(llvm::dbgs()
@@ -431,9 +418,8 @@ convertComplexType(const spirv::TargetEnv &targetEnv,
 /// create composite constants with OpConstantComposite to embed relative large
 /// constant values and use OpCompositeExtract and OpCompositeInsert to
 /// manipulate, like what we do for vectors.
-static Type convertTensorType(const spirv::TargetEnv &targetEnv,
-                              const SPIRVConversionOptions &options,
-                              TensorType type) {
+Type convertTensorType(const spirv::TargetEnv &targetEnv,
+                       const SPIRVConversionOptions &options, TensorType type) {
   // TODO: Handle dynamic shapes.
   if (!type.hasStaticShape()) {
     LLVM_DEBUG(llvm::dbgs()
@@ -478,10 +464,9 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
   return spirv::ArrayType::get(arrayElemType, arrayElemCount);
 }
 
-static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
-                                  const SPIRVConversionOptions &options,
-                                  MemRefType type,
-                                  spirv::StorageClass storageClass) {
+Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
+                           const SPIRVConversionOptions &options,
+                           MemRefType type, spirv::StorageClass storageClass) {
   unsigned numBoolBits = options.boolNumBits;
   if (numBoolBits != 8) {
     LLVM_DEBUG(llvm::dbgs()
@@ -531,10 +516,10 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
   return wrapInStructAndGetPointer(arrayType, storageClass);
 }
 
-static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
-                                     const SPIRVConversionOptions &options,
-                                     MemRefType type,
-                                     spirv::StorageClass storageClass) {
+Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
+                              const SPIRVConversionOptions &options,
+                              MemRefType type,
+                              spirv::StorageClass storageClass) {
   IntegerType elementType = cast<IntegerType>(type.getElementType());
   Type arrayElemType = convertSubByteIntegerType(options, elementType);
   if (!arrayElemType)
@@ -569,9 +554,8 @@ static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
   return wrapInStructAndGetPointer(arrayType, storageClass);
 }
 
-static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
-                              const SPIRVConversionOptions &options,
-                              MemRefType type) {
+Type convertMemrefType(const spirv::TargetEnv &targetEnv,
+                       const SPIRVConversionOptions &options, MemRefType type) {
   auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
   if (!attr) {
     LLVM_DEBUG(
@@ -731,73 +715,134 @@ std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
 }
 
 //===----------------------------------------------------------------------===//
-// SPIRVTypeConverter
+// Builtin Variables
 //===----------------------------------------------------------------------===//
 
-SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
-                                       const SPIRVConversionOptions &options)
-    : targetEnv(targetAttr), options(options) {
-  // Add conversions. The order matters here: later ones will be tried earlier.
+spirv::GlobalVariableOp getBuiltinVariable(Block &body,
+                                           spirv::BuiltIn builtin) {
+  // Look through all global variables in the given `body` block and check if
+  // there is a spirv.GlobalVariable that has the same `builtin` attribute.
+  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
+    if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
+            spirv::SPIRVDialect::getAttributeName(
+                spirv::Decoration::BuiltIn))) {
+      auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
+      if (varBuiltIn && *varBuiltIn == builtin) {
+        return varOp;
+      }
+    }
+  }
+  return nullptr;
+}
 
-  // Allow all SPIR-V dialect specific types. This assumes all builtin types
-  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
-  // were tried before.
-  //
-  // TODO: This assumes that the SPIR-V types are valid to use in the given
-  // target environment, which should be the case if the whole pipeline is
-  // driven by the same target environment. Still, we probably still want to
-  // validate and convert to be safe.
-  addConversion([](spirv::SPIRVType type) { return type; });
+/// Gets the name of the global variable for a builtin.
+std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
+                              StringRef suffix) {
+  return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
+}
 
-  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
+/// Gets or inserts a global variable for a builtin within `body` block.
+spirv::GlobalVariableOp
+getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
+                           Type integerType, OpBuilder &builder,
+                           StringRef prefix, StringRef suffix) {
+  if (auto varOp = getBuiltinVariable(body, builtin))
+    return varOp;
 
-  addConversion([this](IntegerType intType) -> std::optional<Type> {
-    if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
-      return convertScalarType(this->targetEnv, this->options, scalarType);
-    if (intType.getWidth() < 8)
-      return convertSubByteIntegerType(this->options, intType);
-    return Type();
-  });
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(&body);
 
-  addConversion([this](FloatType floatType) -> std::optional<Type> {
-    if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
-      return convertScalarType(this->targetEnv, this->options, scalarType);
-    return Type();
-  });
+  spirv::GlobalVariableOp newVarOp;
+  switch (builtin) {
+  case spirv::BuiltIn::NumWorkgroups:
+  case spirv::BuiltIn::WorkgroupSize:
+  case spirv::BuiltIn::WorkgroupId:
+  case spirv::BuiltIn::LocalInvocationId:
+  case spirv::BuiltIn::GlobalInvocationId: {
+    auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
+                                           spirv::StorageClass::Input);
+    std::string name = getBuiltinVarName(builtin, prefix, suffix);
+    newVarOp =
+        builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
+    break;
+  }
+  case spirv::BuiltIn::SubgroupId:
+  case spirv::BuiltIn::NumSubgroups:
+  case spirv::BuiltIn::SubgroupSize: {
+    auto ptrType =
+        spirv::PointerType::get(integerType, spirv::StorageClass::Input);
+    std::string name = getBuiltinVarName(builtin, prefix, suffix);
+    newVarOp =
+        builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
+    break;
+  }
+  default:
+    emitError(loc, "unimplemented builtin variable generation for ")
+        << stringifyBuiltIn(builtin);
+  }
+  return newVarOp;
+}
 
-  addConversion([this](ComplexType complexType) {
-    return convertComplexType(this->targetEnv, this->options, complexType);
-  });
+//===----------------------------------------------------------------------===//
+// Push constant storage
+//===----------------------------------------------------------------------===//
 
-  addConversion([this](VectorType vectorType) {
-    return convertVectorType(this->targetEnv, this->options, vectorType);
-  });
+/// Returns the pointer type for the push constant storage containing
+/// `elementCount` 32-bit integer values.
+spirv::PointerType getPushConstantStorageType(unsigned elementCount,
+                                              Builder &builder,
+                                              Type indexType) {
+  auto arrayType = spirv::ArrayType::get(indexType, elementCount,
+                                         /*stride=*/4);
+  auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
+  return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
+}
 
-  addConversion([this](TensorType tensorType) {
-    return convertTensorType(this->targetEnv, this->options, tensorType);
-  });
+/// Returns the push constant variable containing `elementCount` 32-bit integer
+/// values in `body`. Returns null op if such an op does not exist.
+spirv::GlobalVariableOp getPushConstantVariable(Block &body,
+                                                unsigned elementCount) {
+  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
+    auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
+    if (!ptrType)
+      continue;
 
-  addConversion([this](MemRefType memRefType) {
-    return convertMemrefType(this->targetEnv, this->options, memRefType);
-  });
+    // Note that Vulkan requires "There must be no more than one push constant
+    // block statically used per shader entry point." So we should always reuse
+    // the existing one.
+    if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
+      auto numElements = cast<spirv::ArrayType>(
+                             cast<spirv::StructType>(ptrType.getPointeeType())
+                                 .getElementType(0))
+                             .getNumElements();
+      if (numElements == elementCount)
+        return varOp;
+    }
+  }
+  return nullptr;
+}
 
-  // Register some last line of defense casting logic.
-  addSourceMaterialization(
-      [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
-        return castToSourceType(this->targetEnv, builder, type, inputs, loc);
-      });
-  addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
-                              Location loc) {
-    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
-    return std::optional<Value>(cast.getResult(0));
-  });
+/// Gets or inserts a global variable for push constant storage containing
+/// `elementCount` 32-bit integer values in `block`.
+spirv::GlobalVariableOp getOrInsertPushConstantVariable(Location loc,
+                                                        Block &block,
+                                                        unsigned elementCount,
+                                                        OpBuilder &b,
+                                                        Type indexType) {
+  if (auto varOp = getPushConstantVariable(block, elementCount))
+    return varOp;
+
+  auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
+  auto type = getPushConstantStorageType(elementCount, builder, indexType);
+  const char *name = "__push_constant_var__";
+  return builder.create<spirv::GlobalVariableOp>(loc, type, name,
+                                                 /*initializer=*/nullptr);
 }
 
 //===----------------------------------------------------------------------===//
 // func::FuncOp Conversion Patterns
 //===----------------------------------------------------------------------===//
 
-namespace {
 /// A pattern for rewriting function signature to convert arguments of functions
 /// to be of valid SPIR-V types.
 class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
@@ -808,7 +853,6 @@ class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
   matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
-} // namespace
 
 LogicalResult
 FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
@@ -855,16 +899,6 @@ FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
   return success();
 }
 
-void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
-                                              RewritePatternSet &patterns) {
-  patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
-}
-
-//===----------------------------------------------------------------------===//
-// func::FuncOp Conversion Patterns
-//===----------------------------------------------------------------------===//
-
-namespace {
 /// A pattern for rewriting function signature to convert vect...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/99393


More information about the Mlir-commits mailing list