[Mlir-commits] [mlir] 9527d77 - [mlir][spirv] Restructure code in `SPIRVConversion.cpp`. NFC. (#99393)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 18 11:31:18 PDT 2024
Author: Angel Zhang
Date: 2024-07-18T14:31:15-04:00
New Revision: 9527d77aefcf214944a4c8bd284dde3ffe9dff60
URL: https://github.com/llvm/llvm-project/commit/9527d77aefcf214944a4c8bd284dde3ffe9dff60
DIFF: https://github.com/llvm/llvm-project/commit/9527d77aefcf214944a4c8bd284dde3ffe9dff60.diff
LOG: [mlir][spirv] Restructure code in `SPIRVConversion.cpp`. NFC. (#99393)
Added:
Modified:
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index e3a09ef1ff684..bf5044437fd09 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -40,6 +40,8 @@
using namespace mlir;
+namespace {
+
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
@@ -171,18 +173,6 @@ static spirv::ScalarType getIndexType(MLIRContext *ctx,
IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
}
-Type SPIRVTypeConverter::getIndexType() const {
- return ::getIndexType(getContext(), options);
-}
-
-MLIRContext *SPIRVTypeConverter::getContext() const {
- return targetEnv.getAttr().getContext();
-}
-
-bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
- return targetEnv.allows(capability);
-}
-
// TODO: This is a utility function that should probably be exposed by the
// SPIR-V dialect. Keeping it local till the use case arises.
static std::optional<int64_t>
@@ -673,9 +663,9 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
/// This function is meant to handle the **compute** side; so it does not
/// involve storage classes in its logic. The storage side is expected to be
/// handled by MemRef conversion logic.
-std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
- OpBuilder &builder, Type type,
- ValueRange inputs, Location loc) {
+static std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
+ OpBuilder &builder, Type type,
+ ValueRange inputs, Location loc) {
// We can only cast one value in SPIR-V.
if (inputs.size() != 1) {
auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
@@ -731,140 +721,185 @@ std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
}
//===----------------------------------------------------------------------===//
-// SPIRVTypeConverter
+// Builtin Variables
//===----------------------------------------------------------------------===//
-SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
- const SPIRVConversionOptions &options)
- : targetEnv(targetAttr), options(options) {
- // Add conversions. The order matters here: later ones will be tried earlier.
+static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
+ spirv::BuiltIn builtin) {
+ // Look through all global variables in the given `body` block and check if
+ // there is a spirv.GlobalVariable that has the same `builtin` attribute.
+ for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
+ if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
+ spirv::SPIRVDialect::getAttributeName(
+ spirv::Decoration::BuiltIn))) {
+ auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
+ if (varBuiltIn && *varBuiltIn == builtin) {
+ return varOp;
+ }
+ }
+ }
+ return nullptr;
+}
- // Allow all SPIR-V dialect specific types. This assumes all builtin types
- // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
- // were tried before.
- //
- // TODO: This assumes that the SPIR-V types are valid to use in the given
- // target environment, which should be the case if the whole pipeline is
- // driven by the same target environment. Still, we probably still want to
- // validate and convert to be safe.
- addConversion([](spirv::SPIRVType type) { return type; });
+/// Returns the global variable name for `builtin`: prefix + name + suffix.
+static std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
+                                     StringRef suffix) {
+  return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
+}
- addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
+/// Gets or inserts a global variable for a builtin within `body` block.
+static spirv::GlobalVariableOp
+getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
+ Type integerType, OpBuilder &builder,
+ StringRef prefix, StringRef suffix) {
+ if (auto varOp = getBuiltinVariable(body, builtin))
+ return varOp;
- addConversion([this](IntegerType intType) -> std::optional<Type> {
- if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
- return convertScalarType(this->targetEnv, this->options, scalarType);
- if (intType.getWidth() < 8)
- return convertSubByteIntegerType(this->options, intType);
- return Type();
- });
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(&body);
- addConversion([this](FloatType floatType) -> std::optional<Type> {
- if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
- return convertScalarType(this->targetEnv, this->options, scalarType);
- return Type();
- });
+ spirv::GlobalVariableOp newVarOp;
+ switch (builtin) {
+ case spirv::BuiltIn::NumWorkgroups:
+ case spirv::BuiltIn::WorkgroupSize:
+ case spirv::BuiltIn::WorkgroupId:
+ case spirv::BuiltIn::LocalInvocationId:
+ case spirv::BuiltIn::GlobalInvocationId: {
+ auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
+ spirv::StorageClass::Input);
+ std::string name = getBuiltinVarName(builtin, prefix, suffix);
+ newVarOp =
+ builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
+ break;
+ }
+ case spirv::BuiltIn::SubgroupId:
+ case spirv::BuiltIn::NumSubgroups:
+ case spirv::BuiltIn::SubgroupSize: {
+ auto ptrType =
+ spirv::PointerType::get(integerType, spirv::StorageClass::Input);
+ std::string name = getBuiltinVarName(builtin, prefix, suffix);
+ newVarOp =
+ builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
+ break;
+ }
+ default:
+ emitError(loc, "unimplemented builtin variable generation for ")
+ << stringifyBuiltIn(builtin);
+ }
+ return newVarOp;
+}
- addConversion([this](ComplexType complexType) {
- return convertComplexType(this->targetEnv, this->options, complexType);
- });
+//===----------------------------------------------------------------------===//
+// Push constant storage
+//===----------------------------------------------------------------------===//
- addConversion([this](VectorType vectorType) {
- return convertVectorType(this->targetEnv, this->options, vectorType);
- });
+/// Returns the pointer type for the push constant storage containing
+/// `elementCount` 32-bit integer values.
+static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
+ Builder &builder,
+ Type indexType) {
+ auto arrayType = spirv::ArrayType::get(indexType, elementCount,
+ /*stride=*/4);
+ auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
+ return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
+}
- addConversion([this](TensorType tensorType) {
- return convertTensorType(this->targetEnv, this->options, tensorType);
- });
+/// Returns the push constant variable containing `elementCount` 32-bit integer
+/// values in `body`. Returns a null op if no such variable exists.
+static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
+ unsigned elementCount) {
+ for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
+ auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
+ if (!ptrType)
+ continue;
- addConversion([this](MemRefType memRefType) {
- return convertMemrefType(this->targetEnv, this->options, memRefType);
- });
+ // Note that Vulkan requires "There must be no more than one push constant
+ // block statically used per shader entry point." So we should always reuse
+ // the existing one.
+ if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
+ auto numElements = cast<spirv::ArrayType>(
+ cast<spirv::StructType>(ptrType.getPointeeType())
+ .getElementType(0))
+ .getNumElements();
+ if (numElements == elementCount)
+ return varOp;
+ }
+ }
+ return nullptr;
+}
- // Register some last line of defense casting logic.
- addSourceMaterialization(
- [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
- return castToSourceType(this->targetEnv, builder, type, inputs, loc);
- });
- addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
- Location loc) {
- auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
- return std::optional<Value>(cast.getResult(0));
- });
+/// Gets or inserts a global variable for push constant storage containing
+/// `elementCount` 32-bit integer values in `block`.
+static spirv::GlobalVariableOp
+getOrInsertPushConstantVariable(Location loc, Block &block,
+ unsigned elementCount, OpBuilder &b,
+ Type indexType) {
+ if (auto varOp = getPushConstantVariable(block, elementCount))
+ return varOp;
+
+ auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
+ auto type = getPushConstantStorageType(elementCount, builder, indexType);
+ const char *name = "__push_constant_var__";
+ return builder.create<spirv::GlobalVariableOp>(loc, type, name,
+ /*initializer=*/nullptr);
}
//===----------------------------------------------------------------------===//
// func::FuncOp Conversion Patterns
//===----------------------------------------------------------------------===//
-namespace {
/// A pattern for rewriting function signature to convert arguments of functions
/// to be of valid SPIR-V types.
-class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
-public:
+struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
using OpConversionPattern<func::FuncOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-} // namespace
-
-LogicalResult
-FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- auto fnType = funcOp.getFunctionType();
- if (fnType.getNumResults() > 1)
- return failure();
-
- TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
- for (const auto &argType : enumerate(fnType.getInputs())) {
- auto convertedType = getTypeConverter()->convertType(argType.value());
- if (!convertedType)
- return failure();
- signatureConverter.addInputs(argType.index(), convertedType);
- }
-
- Type resultType;
- if (fnType.getNumResults() == 1) {
- resultType = getTypeConverter()->convertType(fnType.getResult(0));
- if (!resultType)
+ ConversionPatternRewriter &rewriter) const override {
+ FunctionType fnType = funcOp.getFunctionType();
+ if (fnType.getNumResults() > 1)
return failure();
- }
-
- // Create the converted spirv.func op.
- auto newFuncOp = rewriter.create<spirv::FuncOp>(
- funcOp.getLoc(), funcOp.getName(),
- rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
- resultType ? TypeRange(resultType)
- : TypeRange()));
- // Copy over all attributes other than the function name and type.
- for (const auto &namedAttr : funcOp->getAttrs()) {
- if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
- namedAttr.getName() != SymbolTable::getSymbolAttrName())
- newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
- }
+ TypeConverter::SignatureConversion signatureConverter(
+ fnType.getNumInputs());
+ for (const auto &argType : enumerate(fnType.getInputs())) {
+ auto convertedType = getTypeConverter()->convertType(argType.value());
+ if (!convertedType)
+ return failure();
+ signatureConverter.addInputs(argType.index(), convertedType);
+ }
- rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
- newFuncOp.end());
- if (failed(rewriter.convertRegionTypes(
- &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
- return failure();
- rewriter.eraseOp(funcOp);
- return success();
-}
+ Type resultType;
+ if (fnType.getNumResults() == 1) {
+ resultType = getTypeConverter()->convertType(fnType.getResult(0));
+ if (!resultType)
+ return failure();
+ }
-void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
- RewritePatternSet &patterns) {
- patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
-}
+ // Create the converted spirv.func op.
+ auto newFuncOp = rewriter.create<spirv::FuncOp>(
+ funcOp.getLoc(), funcOp.getName(),
+ rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
+ resultType ? TypeRange(resultType)
+ : TypeRange()));
+
+ // Copy over all attributes other than the function name and type.
+ for (const auto &namedAttr : funcOp->getAttrs()) {
+ if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
+ namedAttr.getName() != SymbolTable::getSymbolAttrName())
+ newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
+ }
-//===----------------------------------------------------------------------===//
-// func::FuncOp Conversion Patterns
-//===----------------------------------------------------------------------===//
+ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+ if (failed(rewriter.convertRegionTypes(
+ &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
+ return failure();
+ rewriter.eraseOp(funcOp);
+ return success();
+ }
+};
-namespace {
/// A pattern for rewriting function signature to convert vector arguments of
/// functions to be of valid types
struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
@@ -1015,17 +1050,11 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
return success();
}
};
-} // namespace
-
-void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
- patterns.add<FuncOpVectorUnroll>(patterns.getContext());
-}
//===----------------------------------------------------------------------===//
// func::ReturnOp Conversion Patterns
//===----------------------------------------------------------------------===//
-namespace {
/// A pattern for rewriting function signature and the return op to convert
/// vectors to be of valid types.
struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
@@ -1097,81 +1126,13 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
return success();
}
};
-} // namespace
-void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
- patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
-}
+} // namespace
//===----------------------------------------------------------------------===//
-// Builtin Variables
+// Public function for builtin variables
//===----------------------------------------------------------------------===//
-static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
- spirv::BuiltIn builtin) {
- // Look through all global variables in the given `body` block and check if
- // there is a spirv.GlobalVariable that has the same `builtin` attribute.
- for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
- if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
- spirv::SPIRVDialect::getAttributeName(
- spirv::Decoration::BuiltIn))) {
- auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
- if (varBuiltIn && *varBuiltIn == builtin) {
- return varOp;
- }
- }
- }
- return nullptr;
-}
-
-/// Gets name of global variable for a builtin.
-static std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
- StringRef suffix) {
- return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
-}
-
-/// Gets or inserts a global variable for a builtin within `body` block.
-static spirv::GlobalVariableOp
-getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
- Type integerType, OpBuilder &builder,
- StringRef prefix, StringRef suffix) {
- if (auto varOp = getBuiltinVariable(body, builtin))
- return varOp;
-
- OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPointToStart(&body);
-
- spirv::GlobalVariableOp newVarOp;
- switch (builtin) {
- case spirv::BuiltIn::NumWorkgroups:
- case spirv::BuiltIn::WorkgroupSize:
- case spirv::BuiltIn::WorkgroupId:
- case spirv::BuiltIn::LocalInvocationId:
- case spirv::BuiltIn::GlobalInvocationId: {
- auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
- spirv::StorageClass::Input);
- std::string name = getBuiltinVarName(builtin, prefix, suffix);
- newVarOp =
- builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
- break;
- }
- case spirv::BuiltIn::SubgroupId:
- case spirv::BuiltIn::NumSubgroups:
- case spirv::BuiltIn::SubgroupSize: {
- auto ptrType =
- spirv::PointerType::get(integerType, spirv::StorageClass::Input);
- std::string name = getBuiltinVarName(builtin, prefix, suffix);
- newVarOp =
- builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
- break;
- }
- default:
- emitError(loc, "unimplemented builtin variable generation for ")
- << stringifyBuiltIn(builtin);
- }
- return newVarOp;
-}
-
Value mlir::spirv::getBuiltinVariableValue(Operation *op,
spirv::BuiltIn builtin,
Type integerType, OpBuilder &builder,
@@ -1190,60 +1151,9 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op,
}
//===----------------------------------------------------------------------===//
-// Push constant storage
+// Public function for push constant storage
//===----------------------------------------------------------------------===//
-/// Returns the pointer type for the push constant storage containing
-/// `elementCount` 32-bit integer values.
-static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
- Builder &builder,
- Type indexType) {
- auto arrayType = spirv::ArrayType::get(indexType, elementCount,
- /*stride=*/4);
- auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
- return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
-}
-
-/// Returns the push constant varible containing `elementCount` 32-bit integer
-/// values in `body`. Returns null op if such an op does not exit.
-static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
- unsigned elementCount) {
- for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
- auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
- if (!ptrType)
- continue;
-
- // Note that Vulkan requires "There must be no more than one push constant
- // block statically used per shader entry point." So we should always reuse
- // the existing one.
- if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
- auto numElements = cast<spirv::ArrayType>(
- cast<spirv::StructType>(ptrType.getPointeeType())
- .getElementType(0))
- .getNumElements();
- if (numElements == elementCount)
- return varOp;
- }
- }
- return nullptr;
-}
-
-/// Gets or inserts a global variable for push constant storage containing
-/// `elementCount` 32-bit integer values in `block`.
-static spirv::GlobalVariableOp
-getOrInsertPushConstantVariable(Location loc, Block &block,
- unsigned elementCount, OpBuilder &b,
- Type indexType) {
- if (auto varOp = getPushConstantVariable(block, elementCount))
- return varOp;
-
- auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
- auto type = getPushConstantStorageType(elementCount, builder, indexType);
- const char *name = "__push_constant_var__";
- return builder.create<spirv::GlobalVariableOp>(loc, type, name,
- /*initializer=*/nullptr);
-}
-
Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
unsigned offset, Type integerType,
OpBuilder &builder) {
@@ -1267,7 +1177,7 @@ Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
}
//===----------------------------------------------------------------------===//
-// Index calculation
+// Public functions for index calculation
//===----------------------------------------------------------------------===//
Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
@@ -1375,6 +1285,81 @@ Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
builder);
}
+//===----------------------------------------------------------------------===//
+// SPIR-V TypeConverter
+//===----------------------------------------------------------------------===//
+
+SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
+ const SPIRVConversionOptions &options)
+ : targetEnv(targetAttr), options(options) {
+ // Add conversions. The order matters here: later ones will be tried earlier.
+
+ // Allow all SPIR-V dialect specific types. This assumes all builtin types
+ // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
+ // were tried before.
+ //
+ // TODO: This assumes that the SPIR-V types are valid to use in the given
+ // target environment, which should be the case if the whole pipeline is
+ // driven by the same target environment. Still, we probably still want to
+ // validate and convert to be safe.
+ addConversion([](spirv::SPIRVType type) { return type; });
+
+ addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
+
+ addConversion([this](IntegerType intType) -> std::optional<Type> {
+ if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
+ return convertScalarType(this->targetEnv, this->options, scalarType);
+ if (intType.getWidth() < 8)
+ return convertSubByteIntegerType(this->options, intType);
+ return Type();
+ });
+
+ addConversion([this](FloatType floatType) -> std::optional<Type> {
+ if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
+ return convertScalarType(this->targetEnv, this->options, scalarType);
+ return Type();
+ });
+
+ addConversion([this](ComplexType complexType) {
+ return convertComplexType(this->targetEnv, this->options, complexType);
+ });
+
+ addConversion([this](VectorType vectorType) {
+ return convertVectorType(this->targetEnv, this->options, vectorType);
+ });
+
+ addConversion([this](TensorType tensorType) {
+ return convertTensorType(this->targetEnv, this->options, tensorType);
+ });
+
+ addConversion([this](MemRefType memRefType) {
+ return convertMemrefType(this->targetEnv, this->options, memRefType);
+ });
+
+ // Register some last line of defense casting logic.
+ addSourceMaterialization(
+ [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
+ return castToSourceType(this->targetEnv, builder, type, inputs, loc);
+ });
+ addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
+ Location loc) {
+ auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+ return std::optional<Value>(cast.getResult(0));
+ });
+}
+
+Type SPIRVTypeConverter::getIndexType() const {
+ return ::getIndexType(getContext(), options);
+}
+
+MLIRContext *SPIRVTypeConverter::getContext() const {
+ return targetEnv.getAttr().getContext();
+}
+
+bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
+ return targetEnv.allows(capability);
+}
+
//===----------------------------------------------------------------------===//
// SPIR-V ConversionTarget
//===----------------------------------------------------------------------===//
@@ -1468,3 +1453,20 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
return true;
}
+
+//===----------------------------------------------------------------------===//
+// Public functions for populating patterns
+//===----------------------------------------------------------------------===//
+
+void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
+ RewritePatternSet &patterns) {
+ patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
+}
+
+void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
+ patterns.add<FuncOpVectorUnroll>(patterns.getContext());
+}
+
+void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
+ patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
+}
More information about the Mlir-commits
mailing list