[llvm-branch-commits] [mlir] 7c3ae48 - [mlir][spirv] Replace SPIRVOpLowering with OpConversionPattern
Lei Zhang via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sat Jan 9 05:12:10 PST 2021
Author: Lei Zhang
Date: 2021-01-09T08:04:53-05:00
New Revision: 7c3ae48fe85f535a5e35db9898c7bf2e3baeb6b4
URL: https://github.com/llvm/llvm-project/commit/7c3ae48fe85f535a5e35db9898c7bf2e3baeb6b4
DIFF: https://github.com/llvm/llvm-project/commit/7c3ae48fe85f535a5e35db9898c7bf2e3baeb6b4.diff
LOG: [mlir][spirv] Replace SPIRVOpLowering with OpConversionPattern
The dialect conversion framework was enhanced to handle type
conversion automatically. OpConversionPattern already contains
a pointer to the TypeConverter. There is no need to duplicate it
in a separate subclass. This removes the only reason for a
SPIRVOpLowering subclass. It adapts to use core infrastructure
and simplifies the code.
Also added a utility function to OpConversionPattern for getting
TypeConverter as a certain subclass.
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D94080
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index fddf84859bc2..4143091543d6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -63,19 +63,6 @@ class SPIRVTypeConverter : public TypeConverter {
spirv::TargetEnv targetEnv;
};
-/// Base class to define a conversion pattern to lower `SourceOp` into SPIR-V.
-template <typename SourceOp>
-class SPIRVOpLowering : public OpConversionPattern<SourceOp> {
-public:
- SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter,
- PatternBenefit benefit = 1)
- : OpConversionPattern<SourceOp>(context, benefit),
- typeConverter(typeConverter) {}
-
-protected:
- SPIRVTypeConverter &typeConverter;
-};
-
/// Appends to a pattern list additional patterns for translating the builtin
/// `func` op to the SPIR-V dialect. These patterns do not handle shader
/// interface/ABI; they convert function parameters to be of SPIR-V allowed
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e02cf8fe4c0a..51c7788ffb14 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -341,6 +341,13 @@ class ConversionPattern : public RewritePattern {
/// does not require type conversion.
TypeConverter *getTypeConverter() const { return typeConverter; }
+ template <typename ConverterTy>
+ std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
+ ConverterTy *>
+ getTypeConverter() const {
+ return static_cast<ConverterTy *>(typeConverter);
+ }
+
protected:
/// See `RewritePattern::RewritePattern` for information on the other
/// available constructors.
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index e84269e9418d..d66f9c66c1da 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -17,6 +17,8 @@
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/StringSwitch.h"
using namespace mlir;
@@ -26,9 +28,9 @@ namespace {
/// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
/// builtin variables.
template <typename SourceOp, spirv::BuiltIn builtin>
-class LaunchConfigConversion : public SPIRVOpLowering<SourceOp> {
+class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
- using SPIRVOpLowering<SourceOp>::SPIRVOpLowering;
+ using OpConversionPattern<SourceOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
@@ -38,9 +40,9 @@ class LaunchConfigConversion : public SPIRVOpLowering<SourceOp> {
/// Pattern lowering subgroup size/id to loading SPIR-V invocation
/// builtin variables.
template <typename SourceOp, spirv::BuiltIn builtin>
-class SingleDimLaunchConfigConversion : public SPIRVOpLowering<SourceOp> {
+class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
- using SPIRVOpLowering<SourceOp>::SPIRVOpLowering;
+ using OpConversionPattern<SourceOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
@@ -51,9 +53,9 @@ class SingleDimLaunchConfigConversion : public SPIRVOpLowering<SourceOp> {
/// a constant with WorkgroupSize decoration. So here we cannot generate a
/// builtin variable; instead the information in the `spv.entry_point_abi`
/// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp.
-class WorkGroupSizeConversion : public SPIRVOpLowering<gpu::BlockDimOp> {
+class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
public:
- using SPIRVOpLowering<gpu::BlockDimOp>::SPIRVOpLowering;
+ using OpConversionPattern<gpu::BlockDimOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(gpu::BlockDimOp op, ArrayRef<Value> operands,
@@ -61,9 +63,9 @@ class WorkGroupSizeConversion : public SPIRVOpLowering<gpu::BlockDimOp> {
};
/// Pattern to convert a kernel function in GPU dialect within a spv.module.
-class GPUFuncOpConversion final : public SPIRVOpLowering<gpu::GPUFuncOp> {
+class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
public:
- using SPIRVOpLowering<gpu::GPUFuncOp>::SPIRVOpLowering;
+ using OpConversionPattern<gpu::GPUFuncOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
@@ -74,9 +76,9 @@ class GPUFuncOpConversion final : public SPIRVOpLowering<gpu::GPUFuncOp> {
};
/// Pattern to convert a gpu.module to a spv.module.
-class GPUModuleConversion final : public SPIRVOpLowering<gpu::GPUModuleOp> {
+class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
public:
- using SPIRVOpLowering<gpu::GPUModuleOp>::SPIRVOpLowering;
+ using OpConversionPattern<gpu::GPUModuleOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(gpu::GPUModuleOp moduleOp, ArrayRef<Value> operands,
@@ -85,9 +87,9 @@ class GPUModuleConversion final : public SPIRVOpLowering<gpu::GPUModuleOp> {
/// Pattern to convert a gpu.return into a SPIR-V return.
// TODO: This can go to DRR when GPU return has operands.
-class GPUReturnOpConversion final : public SPIRVOpLowering<gpu::ReturnOp> {
+class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
public:
- using SPIRVOpLowering<gpu::ReturnOp>::SPIRVOpLowering;
+ using OpConversionPattern<gpu::ReturnOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef<Value> operands,
@@ -102,17 +104,14 @@ class GPUReturnOpConversion final : public SPIRVOpLowering<gpu::ReturnOp> {
static Optional<int32_t> getLaunchConfigIndex(Operation *op) {
auto dimAttr = op->getAttrOfType<StringAttr>("dimension");
- if (!dimAttr) {
- return {};
- }
- if (dimAttr.getValue() == "x") {
- return 0;
- } else if (dimAttr.getValue() == "y") {
- return 1;
- } else if (dimAttr.getValue() == "z") {
- return 2;
- }
- return {};
+ if (!dimAttr)
+ return llvm::None;
+
+ return llvm::StringSwitch<Optional<int32_t>>(dimAttr.getValue())
+ .Case("x", 0)
+ .Case("y", 1)
+ .Case("z", 2)
+ .Default(llvm::None);
}
template <typename SourceOp, spirv::BuiltIn builtin>
@@ -150,7 +149,8 @@ LogicalResult WorkGroupSizeConversion::matchAndRewrite(
auto workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
auto val = workGroupSizeAttr.getValue<int32_t>(index.getValue());
- auto convertedType = typeConverter.convertType(op.getResult().getType());
+ auto convertedType =
+ getTypeConverter()->convertType(op.getResult().getType());
if (!convertedType)
return failure();
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
@@ -164,7 +164,7 @@ LogicalResult WorkGroupSizeConversion::matchAndRewrite(
// Legalizes a GPU function as an entry SPIR-V function.
static spirv::FuncOp
-lowerAsEntryFunction(gpu::GPUFuncOp funcOp, SPIRVTypeConverter &typeConverter,
+lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter,
ConversionPatternRewriter &rewriter,
spirv::EntryPointABIAttr entryPointInfo,
ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
@@ -266,7 +266,7 @@ LogicalResult GPUFuncOpConversion::matchAndRewrite(
return failure();
}
spirv::FuncOp newFuncOp = lowerAsEntryFunction(
- funcOp, typeConverter, rewriter, entryPointAttr, argABI);
+ funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
if (!newFuncOp)
return failure();
newFuncOp.removeAttr(Identifier::get(gpu::GPUDialect::getKernelFuncAttrName(),
@@ -344,5 +344,5 @@ void mlir::populateGPUToSPIRVPatterns(MLIRContext *context,
spirv::BuiltIn::NumSubgroups>,
SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
spirv::BuiltIn::SubgroupSize>,
- WorkGroupSizeConversion>(context, typeConverter);
+ WorkGroupSizeConversion>(typeConverter, context);
}
diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
index 8133a37aa7ad..0db760b17d7c 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
+#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
@@ -44,10 +45,9 @@ namespace {
/// A pattern to convert a linalg.generic op to SPIR-V ops under the condition
/// that the linalg.generic op is performing reduction with a workload size that
/// can fit in one workgroup.
-class SingleWorkgroupReduction final
- : public SPIRVOpLowering<linalg::GenericOp> {
-public:
- using SPIRVOpLowering<linalg::GenericOp>::SPIRVOpLowering;
+struct SingleWorkgroupReduction final
+ : public OpConversionPattern<linalg::GenericOp> {
+ using OpConversionPattern::OpConversionPattern;
/// Matches the given linalg.generic op as performing reduction and returns
/// the binary op kind if successful.
@@ -142,9 +142,11 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
// TODO: Load to Workgroup storage class first.
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+
// Get the input element accessed by this invocation.
Value inputElementPtr = spirv::getElementPtr(
- typeConverter, originalInputType, convertedInput, {x}, loc, rewriter);
+ *typeConverter, originalInputType, convertedInput, {x}, loc, rewriter);
Value inputElement = rewriter.create<spirv::LoadOp>(loc, inputElementPtr);
// Perform the group reduction operation.
@@ -163,10 +165,10 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
// Get the output element accessed by this reduction.
Value zero = spirv::ConstantOp::getZero(
- typeConverter.getIndexType(rewriter.getContext()), loc, rewriter);
+ typeConverter->getIndexType(rewriter.getContext()), loc, rewriter);
SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero);
Value outputElementPtr =
- spirv::getElementPtr(typeConverter, originalOutputType, convertedOutput,
+ spirv::getElementPtr(*typeConverter, originalOutputType, convertedOutput,
zeroIndices, loc, rewriter);
// Write out the final reduction result. This should be only conducted by one
@@ -204,5 +206,5 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
void mlir::populateLinalgToSPIRVPatterns(MLIRContext *context,
SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
- patterns.insert<SingleWorkgroupReduction>(context, typeConverter);
+ patterns.insert<SingleWorkgroupReduction>(typeConverter, context);
}
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index da2488db1182..93caa3294408 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -16,9 +16,14 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
+//===----------------------------------------------------------------------===//
+// Context
+//===----------------------------------------------------------------------===//
+
namespace mlir {
struct ScfToSPIRVContextImpl {
// Map between the spirv region control flow operation (spv.loop or
@@ -37,20 +42,40 @@ struct ScfToSPIRVContextImpl {
ScfToSPIRVContext::ScfToSPIRVContext() {
impl = std::make_unique<ScfToSPIRVContextImpl>();
}
+
ScfToSPIRVContext::~ScfToSPIRVContext() = default;
+//===----------------------------------------------------------------------===//
+// Pattern Declarations
+//===----------------------------------------------------------------------===//
+
namespace {
/// Common class for all vector to GPU patterns.
template <typename OpTy>
-class SCFToSPIRVPattern : public SPIRVOpLowering<OpTy> {
+class SCFToSPIRVPattern : public OpConversionPattern<OpTy> {
public:
SCFToSPIRVPattern<OpTy>(MLIRContext *context, SPIRVTypeConverter &converter,
ScfToSPIRVContextImpl *scfToSPIRVContext)
- : SPIRVOpLowering<OpTy>::SPIRVOpLowering(context, converter),
- scfToSPIRVContext(scfToSPIRVContext) {}
+ : OpConversionPattern<OpTy>::OpConversionPattern(context),
+ scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
protected:
ScfToSPIRVContextImpl *scfToSPIRVContext;
+ // FIXME: We explicitly keep a reference of the type converter here instead of
+ // passing it to OpConversionPattern during construction. This effectively
+ // bypasses the conversion framework's automation on type conversion. This is
+ // needed right now because the conversion framework will unconditionally
+ // legalize all types used by SCF ops upon discovering them, for example, the
+ // types of loop carried values. We use SPIR-V variables for those loop
+ // carried values. Depending on the available capabilities, the SPIR-V
+ // variable can be different, for example, cooperative matrix or normal
+ // variable. We'd like to detach the conversion of the loop carried values
+ // from the SCF ops (which is mainly a region). So we need to "mark" types
+ // used by SCF ops as legal, if to use the conversion framework for type
+ // conversion. There isn't a straightforward way to do that yet, as when
+ // converting types, ops aren't taken into consideration. Therefore, we just
+ // bypass the framework's type conversion for now.
+ SPIRVTypeConverter &typeConverter;
};
/// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
@@ -90,7 +115,6 @@ class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> {
/// we load the value from the allocation and use it as the SCF op result.
template <typename ScfOp, typename OpTy>
static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
- SPIRVTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter,
ScfToSPIRVContextImpl *scfToSPIRVContext,
ArrayRef<Type> returnTypes) {
@@ -117,7 +141,7 @@ static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
}
//===----------------------------------------------------------------------===//
-// scf::ForOp.
+// scf::ForOp
//===----------------------------------------------------------------------===//
LogicalResult
@@ -196,13 +220,12 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
SmallVector<Type, 8> initTypes;
for (auto arg : forOperands.initArgs())
initTypes.push_back(arg.getType());
- replaceSCFOutputValue(forOp, loopOp, typeConverter, rewriter,
- scfToSPIRVContext, initTypes);
+ replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext, initTypes);
return success();
}
//===----------------------------------------------------------------------===//
-// scf::IfOp.
+// scf::IfOp
//===----------------------------------------------------------------------===//
LogicalResult
@@ -255,11 +278,15 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
auto convertedType = typeConverter.convertType(result.getType());
returnTypes.push_back(convertedType);
}
- replaceSCFOutputValue(ifOp, selectionOp, typeConverter, rewriter,
- scfToSPIRVContext, returnTypes);
+ replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
+ returnTypes);
return success();
}
+//===----------------------------------------------------------------------===//
+// scf::YieldOp
+//===----------------------------------------------------------------------===//
+
/// Yield is lowered to stores to the VariableOp created during lowering of the
/// parent region. For loops we also need to update the branch looping back to
/// the header with the loop carried values.
@@ -290,6 +317,10 @@ LogicalResult TerminatorOpConversion::matchAndRewrite(
return success();
}
+//===----------------------------------------------------------------------===//
+// Hooks
+//===----------------------------------------------------------------------===//
+
void mlir::populateSCFToSPIRVPatterns(MLIRContext *context,
SPIRVTypeConverter &typeConverter,
ScfToSPIRVContext &scfToSPIRVContext,
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 88d0a818b230..4010484a8e89 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -234,9 +234,9 @@ namespace {
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spv.module scope since it wil
/// ladd global variables into the spv.module.
-class AllocOpPattern final : public SPIRVOpLowering<AllocOp> {
+class AllocOpPattern final : public OpConversionPattern<AllocOp> {
public:
- using SPIRVOpLowering<AllocOp>::SPIRVOpLowering;
+ using OpConversionPattern<AllocOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(AllocOp operation, ArrayRef<Value> operands,
@@ -246,7 +246,7 @@ class AllocOpPattern final : public SPIRVOpLowering<AllocOp> {
return operation.emitError("unhandled allocation type");
// Get the SPIR-V type for the allocation.
- Type spirvType = typeConverter.convertType(allocType);
+ Type spirvType = getTypeConverter()->convertType(allocType);
// Insert spv.globalVariable for this allocation.
Operation *parent =
@@ -276,9 +276,9 @@ class AllocOpPattern final : public SPIRVOpLowering<AllocOp> {
/// Removed a deallocation if it is a supported allocation. Currently only
/// removes deallocation if the memory space is workgroup memory.
-class DeallocOpPattern final : public SPIRVOpLowering<DeallocOp> {
+class DeallocOpPattern final : public OpConversionPattern<DeallocOp> {
public:
- using SPIRVOpLowering<DeallocOp>::SPIRVOpLowering;
+ using OpConversionPattern<DeallocOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(DeallocOp operation, ArrayRef<Value> operands,
@@ -293,15 +293,15 @@ class DeallocOpPattern final : public SPIRVOpLowering<DeallocOp> {
/// Converts unary and binary standard operations to SPIR-V operations.
template <typename StdOp, typename SPIRVOp>
-class UnaryAndBinaryOpPattern final : public SPIRVOpLowering<StdOp> {
+class UnaryAndBinaryOpPattern final : public OpConversionPattern<StdOp> {
public:
- using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
+ using OpConversionPattern<StdOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
assert(operands.size() <= 2);
- auto dstType = this->typeConverter.convertType(operation.getType());
+ auto dstType = this->getTypeConverter()->convertType(operation.getType());
if (!dstType)
return failure();
if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
@@ -318,9 +318,9 @@ class UnaryAndBinaryOpPattern final : public SPIRVOpLowering<StdOp> {
///
/// This cannot be merged into the template unary/binary pattern due to
/// Vulkan restrictions over spv.SRem and spv.SMod.
-class SignedRemIOpPattern final : public SPIRVOpLowering<SignedRemIOp> {
+class SignedRemIOpPattern final : public OpConversionPattern<SignedRemIOp> {
public:
- using SPIRVOpLowering<SignedRemIOp>::SPIRVOpLowering;
+ using OpConversionPattern<SignedRemIOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(SignedRemIOp remOp, ArrayRef<Value> operands,
@@ -332,16 +332,16 @@ class SignedRemIOpPattern final : public SPIRVOpLowering<SignedRemIOp> {
/// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
template <typename StdOp, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
-class BitwiseOpPattern final : public SPIRVOpLowering<StdOp> {
+class BitwiseOpPattern final : public OpConversionPattern<StdOp> {
public:
- using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
+ using OpConversionPattern<StdOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
assert(operands.size() == 2);
auto dstType =
- this->typeConverter.convertType(operation.getResult().getType());
+ this->getTypeConverter()->convertType(operation.getResult().getType());
if (!dstType)
return failure();
if (isBoolScalarOrVector(operands.front().getType())) {
@@ -356,9 +356,10 @@ class BitwiseOpPattern final : public SPIRVOpLowering<StdOp> {
};
/// Converts composite std.constant operation to spv.constant.
-class ConstantCompositeOpPattern final : public SPIRVOpLowering<ConstantOp> {
+class ConstantCompositeOpPattern final
+ : public OpConversionPattern<ConstantOp> {
public:
- using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
+ using OpConversionPattern<ConstantOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands,
@@ -366,9 +367,9 @@ class ConstantCompositeOpPattern final : public SPIRVOpLowering<ConstantOp> {
};
/// Converts scalar std.constant operation to spv.constant.
-class ConstantScalarOpPattern final : public SPIRVOpLowering<ConstantOp> {
+class ConstantScalarOpPattern final : public OpConversionPattern<ConstantOp> {
public:
- using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
+ using OpConversionPattern<ConstantOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands,
@@ -376,9 +377,9 @@ class ConstantScalarOpPattern final : public SPIRVOpLowering<ConstantOp> {
};
/// Converts floating-point comparison operations to SPIR-V ops.
-class CmpFOpPattern final : public SPIRVOpLowering<CmpFOp> {
+class CmpFOpPattern final : public OpConversionPattern<CmpFOp> {
public:
- using SPIRVOpLowering<CmpFOp>::SPIRVOpLowering;
+ using OpConversionPattern<CmpFOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
@@ -386,9 +387,9 @@ class CmpFOpPattern final : public SPIRVOpLowering<CmpFOp> {
};
/// Converts integer compare operation on i1 type operands to SPIR-V ops.
-class BoolCmpIOpPattern final : public SPIRVOpLowering<CmpIOp> {
+class BoolCmpIOpPattern final : public OpConversionPattern<CmpIOp> {
public:
- using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;
+ using OpConversionPattern<CmpIOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
@@ -396,9 +397,9 @@ class BoolCmpIOpPattern final : public SPIRVOpLowering<CmpIOp> {
};
/// Converts integer compare operation to SPIR-V ops.
-class CmpIOpPattern final : public SPIRVOpLowering<CmpIOp> {
+class CmpIOpPattern final : public OpConversionPattern<CmpIOp> {
public:
- using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;
+ using OpConversionPattern<CmpIOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
@@ -406,9 +407,9 @@ class CmpIOpPattern final : public SPIRVOpLowering<CmpIOp> {
};
/// Converts std.load to spv.Load.
-class IntLoadOpPattern final : public SPIRVOpLowering<LoadOp> {
+class IntLoadOpPattern final : public OpConversionPattern<LoadOp> {
public:
- using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
+ using OpConversionPattern<LoadOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
@@ -416,9 +417,9 @@ class IntLoadOpPattern final : public SPIRVOpLowering<LoadOp> {
};
/// Converts std.load to spv.Load.
-class LoadOpPattern final : public SPIRVOpLowering<LoadOp> {
+class LoadOpPattern final : public OpConversionPattern<LoadOp> {
public:
- using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
+ using OpConversionPattern<LoadOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
@@ -426,9 +427,9 @@ class LoadOpPattern final : public SPIRVOpLowering<LoadOp> {
};
/// Converts std.return to spv.Return.
-class ReturnOpPattern final : public SPIRVOpLowering<ReturnOp> {
+class ReturnOpPattern final : public OpConversionPattern<ReturnOp> {
public:
- using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;
+ using OpConversionPattern<ReturnOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
@@ -436,18 +437,18 @@ class ReturnOpPattern final : public SPIRVOpLowering<ReturnOp> {
};
/// Converts std.select to spv.Select.
-class SelectOpPattern final : public SPIRVOpLowering<SelectOp> {
+class SelectOpPattern final : public OpConversionPattern<SelectOp> {
public:
- using SPIRVOpLowering<SelectOp>::SPIRVOpLowering;
+ using OpConversionPattern<SelectOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts std.store to spv.Store on integers.
-class IntStoreOpPattern final : public SPIRVOpLowering<StoreOp> {
+class IntStoreOpPattern final : public OpConversionPattern<StoreOp> {
public:
- using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
+ using OpConversionPattern<StoreOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
@@ -455,9 +456,9 @@ class IntStoreOpPattern final : public SPIRVOpLowering<StoreOp> {
};
/// Converts std.store to spv.Store.
-class StoreOpPattern final : public SPIRVOpLowering<StoreOp> {
+class StoreOpPattern final : public OpConversionPattern<StoreOp> {
public:
- using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
+ using OpConversionPattern<StoreOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
@@ -466,9 +467,9 @@ class StoreOpPattern final : public SPIRVOpLowering<StoreOp> {
/// Converts std.zexti to spv.Select if the type of source is i1 or vector of
/// i1.
-class ZeroExtendI1Pattern final : public SPIRVOpLowering<ZeroExtendIOp> {
+class ZeroExtendI1Pattern final : public OpConversionPattern<ZeroExtendIOp> {
public:
- using SPIRVOpLowering<ZeroExtendIOp>::SPIRVOpLowering;
+ using OpConversionPattern<ZeroExtendIOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ZeroExtendIOp op, ArrayRef<Value> operands,
@@ -477,7 +478,8 @@ class ZeroExtendI1Pattern final : public SPIRVOpLowering<ZeroExtendIOp> {
if (!isBoolScalarOrVector(srcType))
return failure();
- auto dstType = this->typeConverter.convertType(op.getResult().getType());
+ auto dstType =
+ this->getTypeConverter()->convertType(op.getResult().getType());
Location loc = op.getLoc();
Attribute zeroAttr, oneAttr;
if (auto vectorType = dstType.dyn_cast<VectorType>()) {
@@ -497,9 +499,9 @@ class ZeroExtendI1Pattern final : public SPIRVOpLowering<ZeroExtendIOp> {
/// Converts type-casting standard operations to SPIR-V operations.
template <typename StdOp, typename SPIRVOp>
-class TypeCastingOpPattern final : public SPIRVOpLowering<StdOp> {
+class TypeCastingOpPattern final : public OpConversionPattern<StdOp> {
public:
- using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
+ using OpConversionPattern<StdOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
@@ -509,7 +511,7 @@ class TypeCastingOpPattern final : public SPIRVOpLowering<StdOp> {
if (isBoolScalarOrVector(srcType))
return failure();
auto dstType =
- this->typeConverter.convertType(operation.getResult().getType());
+ this->getTypeConverter()->convertType(operation.getResult().getType());
if (dstType == srcType) {
// Due to type conversion, we are seeing the same source and target type.
// Then we can just erase this operation by forwarding its operand.
@@ -523,9 +525,9 @@ class TypeCastingOpPattern final : public SPIRVOpLowering<StdOp> {
};
/// Converts std.xor to SPIR-V operations.
-class XOrOpPattern final : public SPIRVOpLowering<XOrOp> {
+class XOrOpPattern final : public OpConversionPattern<XOrOp> {
public:
- using SPIRVOpLowering<XOrOp>::SPIRVOpLowering;
+ using OpConversionPattern<XOrOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
@@ -562,7 +564,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
// std.constant should only have vector or tenor types.
assert((srcType.isa<VectorType, RankedTensorType>()));
- auto dstType = typeConverter.convertType(srcType);
+ auto dstType = getTypeConverter()->convertType(srcType);
if (!dstType)
return failure();
@@ -645,7 +647,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
if (!srcType.isIntOrIndexOrFloat())
return failure();
- Type dstType = typeConverter.convertType(srcType);
+ Type dstType = getTypeConverter()->convertType(srcType);
if (!dstType)
return failure();
@@ -771,7 +773,7 @@ CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
#define DISPATCH(cmpPredicate, spirvOp) \
case cmpPredicate: \
if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
- operandType != this->typeConverter.convertType(operandType)) { \
+ operandType != this->getTypeConverter()->convertType(operandType)) { \
return cmpIOp.emitError( \
"bitwidth emulation is not implemented yet on unsigned op"); \
} \
@@ -808,6 +810,8 @@ IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
auto memrefType = loadOp.memref().getType().cast<MemRefType>();
if (!memrefType.getElementType().isSignlessInteger())
return failure();
+
+ auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
spirv::AccessChainOp accessChainOp =
spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
loadOperands.indices(), loc, rewriter);
@@ -881,9 +885,9 @@ LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
auto memrefType = loadOp.memref().getType().cast<MemRefType>();
if (memrefType.getElementType().isSignlessInteger())
return failure();
- auto loadPtr =
- spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
- loadOperands.indices(), loadOp.getLoc(), rewriter);
+ auto loadPtr = spirv::getElementPtr(
+ *getTypeConverter<SPIRVTypeConverter>(), memrefType,
+ loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
return success();
}
@@ -933,6 +937,7 @@ IntStoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
return failure();
auto loc = storeOp.getLoc();
+ auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
spirv::AccessChainOp accessChainOp =
spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
storeOperands.indices(), loc, rewriter);
@@ -1010,8 +1015,9 @@ StoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
if (memrefType.getElementType().isSignlessInteger())
return failure();
auto storePtr =
- spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
- storeOperands.indices(), storeOp.getLoc(), rewriter);
+ spirv::getElementPtr(*getTypeConverter<SPIRVTypeConverter>(), memrefType,
+ storeOperands.memref(), storeOperands.indices(),
+ storeOp.getLoc(), rewriter);
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
storeOperands.value());
return success();
@@ -1029,7 +1035,7 @@ XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
if (isBoolScalarOrVector(operands.front().getType()))
return failure();
- auto dstType = typeConverter.convertType(xorOp.getType());
+ auto dstType = getTypeConverter()->convertType(xorOp.getType());
if (!dstType)
return failure();
rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(xorOp, dstType, operands);
@@ -1096,7 +1102,7 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
TypeCastingOpPattern<TruncateIOp, spirv::SConvertOp>,
TypeCastingOpPattern<FPToSIOp, spirv::ConvertFToSOp>,
TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>,
- TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>>(context,
- typeConverter);
+ TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>>(typeConverter,
+ context);
}
} // namespace mlir
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index a2735e646bec..1509836ef2e2 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -24,8 +24,9 @@ using namespace mlir;
namespace {
struct VectorBroadcastConvert final
- : public SPIRVOpLowering<vector::BroadcastOp> {
- using SPIRVOpLowering<vector::BroadcastOp>::SPIRVOpLowering;
+ : public OpConversionPattern<vector::BroadcastOp> {
+ using OpConversionPattern::OpConversionPattern;
+
LogicalResult
matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
@@ -43,8 +44,9 @@ struct VectorBroadcastConvert final
};
struct VectorExtractOpConvert final
- : public SPIRVOpLowering<vector::ExtractOp> {
- using SPIRVOpLowering<vector::ExtractOp>::SPIRVOpLowering;
+ : public OpConversionPattern<vector::ExtractOp> {
+ using OpConversionPattern::OpConversionPattern;
+
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
@@ -60,8 +62,10 @@ struct VectorExtractOpConvert final
}
};
-struct VectorInsertOpConvert final : public SPIRVOpLowering<vector::InsertOp> {
- using SPIRVOpLowering<vector::InsertOp>::SPIRVOpLowering;
+struct VectorInsertOpConvert final
+ : public OpConversionPattern<vector::InsertOp> {
+ using OpConversionPattern::OpConversionPattern;
+
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
@@ -78,8 +82,9 @@ struct VectorInsertOpConvert final : public SPIRVOpLowering<vector::InsertOp> {
};
struct VectorExtractElementOpConvert final
- : public SPIRVOpLowering<vector::ExtractElementOp> {
- using SPIRVOpLowering<vector::ExtractElementOp>::SPIRVOpLowering;
+ : public OpConversionPattern<vector::ExtractElementOp> {
+ using OpConversionPattern::OpConversionPattern;
+
LogicalResult
matchAndRewrite(vector::ExtractElementOp extractElementOp,
ArrayRef<Value> operands,
@@ -96,8 +101,9 @@ struct VectorExtractElementOpConvert final
};
struct VectorInsertElementOpConvert final
- : public SPIRVOpLowering<vector::InsertElementOp> {
- using SPIRVOpLowering<vector::InsertElementOp>::SPIRVOpLowering;
+ : public OpConversionPattern<vector::InsertElementOp> {
+ using OpConversionPattern::OpConversionPattern;
+
LogicalResult
matchAndRewrite(vector::InsertElementOp insertElementOp,
ArrayRef<Value> operands,
@@ -120,5 +126,5 @@ void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert,
VectorInsertOpConvert, VectorExtractElementOpConvert,
- VectorInsertElementOpConvert>(context, typeConverter);
+ VectorInsertElementOpConvert>(typeConverter, context);
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index aeff47a831ef..9b62b4289c77 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -151,9 +151,10 @@ namespace {
/// variable ABI attributes attached to function arguments and converts all
/// function argument uses to those global variables. This is necessary because
/// Vulkan requires all shader entry points to be of void(void) type.
-class ProcessInterfaceVarABI final : public SPIRVOpLowering<spirv::FuncOp> {
+class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
public:
- using SPIRVOpLowering<spirv::FuncOp>::SPIRVOpLowering;
+ using OpConversionPattern<spirv::FuncOp>::OpConversionPattern;
+
LogicalResult
matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
@@ -214,7 +215,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
}
signatureConverter.remapInput(argType.index(), replacement);
}
- if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), typeConverter,
+ if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *getTypeConverter(),
&signatureConverter)))
return failure();
@@ -246,7 +247,7 @@ void LowerABIAttributesPass::runOnOperation() {
});
OwningRewritePatternList patterns;
- patterns.insert<ProcessInterfaceVarABI>(context, typeConverter);
+ patterns.insert<ProcessInterfaceVarABI>(typeConverter, context);
ConversionTarget target(*context);
// "Legal" function ops should have no interface variable ABI attributes.
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 9393f3df6425..1c0445290402 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
@@ -459,9 +460,9 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
namespace {
/// A pattern for rewriting function signature to convert arguments of functions
/// to be of valid SPIR-V types.
-class FuncOpConversion final : public SPIRVOpLowering<FuncOp> {
+class FuncOpConversion final : public OpConversionPattern<FuncOp> {
public:
- using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
+ using OpConversionPattern<FuncOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
@@ -478,7 +479,7 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
for (auto argType : enumerate(fnType.getInputs())) {
- auto convertedType = typeConverter.convertType(argType.value());
+ auto convertedType = getTypeConverter()->convertType(argType.value());
if (!convertedType)
return failure();
signatureConverter.addInputs(argType.index(), convertedType);
@@ -486,7 +487,7 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
Type resultType;
if (fnType.getNumResults() == 1)
- resultType = typeConverter.convertType(fnType.getResult(0));
+ resultType = getTypeConverter()->convertType(fnType.getResult(0));
// Create the converted spv.func op.
auto newFuncOp = rewriter.create<spirv::FuncOp>(
@@ -504,8 +505,8 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
- if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
- &signatureConverter)))
+ if (failed(rewriter.convertRegionTypes(
+ &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
return failure();
rewriter.eraseOp(funcOp);
return success();
@@ -514,7 +515,7 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
void mlir::populateBuiltinFuncToSPIRVPatterns(
MLIRContext *context, SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
- patterns.insert<FuncOpConversion>(context, typeConverter);
+ patterns.insert<FuncOpConversion>(typeConverter, context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 850e22465d44..9e972c3a6c57 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-std-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-std-to-spirv -verify-diagnostics %s -o - | FileCheck %s
//===----------------------------------------------------------------------===//
// std arithmetic ops
@@ -628,49 +628,59 @@ func @fptosi2(%arg0 : f16) -> i16 {
// -----
-// Checks that cast types will be adjusted when no special capabilities for
-// non-32-bit scalar types.
+// Checks that cast types will be adjusted when missing special capabilities for
+// certain non-32-bit scalar types.
module attributes {
- spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
+ spv.target_env = #spv.target_env<#spv.vce<v1.0, [Float64], []>, {}>
} {
// CHECK-LABEL: @fpext1
// CHECK-SAME: %[[ARG:.*]]: f32
-func @fpext1(%arg0: f16) {
- // CHECK-NEXT: "use"(%[[ARG]])
+func @fpext1(%arg0: f16) -> f64 {
+ // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f64
%0 = std.fpext %arg0 : f16 to f64
- "use"(%0) : (f64) -> ()
+ return %0: f64
}
// CHECK-LABEL: @fpext2
// CHECK-SAME: %[[ARG:.*]]: f32
-func @fpext2(%arg0 : f32) {
- // CHECK-NEXT: "use"(%[[ARG]])
+func @fpext2(%arg0 : f32) -> f64 {
+ // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f64
%0 = std.fpext %arg0 : f32 to f64
- "use"(%0) : (f64) -> ()
+ return %0: f64
}
+} // end module
+
+// -----
+
+// Checks that cast types will be adjusted when missing special capabilities for
+// certain non-32-bit scalar types.
+module attributes {
+ spv.target_env = #spv.target_env<#spv.vce<v1.0, [Float16], []>, {}>
+} {
+
// CHECK-LABEL: @fptrunc1
// CHECK-SAME: %[[ARG:.*]]: f32
-func @fptrunc1(%arg0 : f64) {
- // CHECK-NEXT: "use"(%[[ARG]])
+func @fptrunc1(%arg0 : f64) -> f16 {
+ // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f16
%0 = std.fptrunc %arg0 : f64 to f16
- "use"(%0) : (f16) -> ()
+ return %0: f16
}
// CHECK-LABEL: @fptrunc2
// CHECK-SAME: %[[ARG:.*]]: f32
-func @fptrunc2(%arg0: f32) {
- // CHECK-NEXT: "use"(%[[ARG]])
+func @fptrunc2(%arg0: f32) -> f16 {
+ // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f16
%0 = std.fptrunc %arg0 : f32 to f16
- "use"(%0) : (f16) -> ()
+ return %0: f16
}
// CHECK-LABEL: @sitofp
-func @sitofp(%arg0 : i64) {
+func @sitofp(%arg0 : i64) -> f64 {
// CHECK: spv.ConvertSToF %{{.*}} : i32 to f32
%0 = std.sitofp %arg0 : i64 to f64
- "use"(%0) : (f64) -> ()
+ return %0: f64
}
} // end module
More information about the llvm-branch-commits
mailing list