[Mlir-commits] [mlir] c50d0fe - [mlir][arith][spirv] Clean up arith-to-spirv. NFC.
Jakub Kuderski
llvmlistbot at llvm.org
Mon Nov 14 17:55:14 PST 2022
Author: Jakub Kuderski
Date: 2022-11-14T20:54:27-05:00
New Revision: c50d0fe570e630c3b1e56a5aee17568e955134b3
URL: https://github.com/llvm/llvm-project/commit/c50d0fe570e630c3b1e56a5aee17568e955134b3
DIFF: https://github.com/llvm/llvm-project/commit/c50d0fe570e630c3b1e56a5aee17568e955134b3.diff
LOG: [mlir][arith][spirv] Clean up arith-to-spirv. NFC.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D137978
Added:
Modified:
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index a284be8ce939..d550e0e33f3e 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -9,8 +9,8 @@
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "../SPIRVCommon/Pattern.h"
-#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
@@ -21,6 +21,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Debug.h"
#include <cassert>
+#include <memory>
namespace mlir {
#define GEN_PASS_DEF_CONVERTARITHTOSPIRV
@@ -40,7 +41,7 @@ namespace {
/// Converts composite arith.constant operation to spirv.Constant.
struct ConstantCompositeOpPattern final
: public OpConversionPattern<arith::ConstantOp> {
- using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
@@ -50,7 +51,7 @@ struct ConstantCompositeOpPattern final
/// Converts scalar arith.constant operation to spirv.Constant.
struct ConstantScalarOpPattern final
: public OpConversionPattern<arith::ConstantOp> {
- using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
@@ -62,7 +63,7 @@ struct ConstantScalarOpPattern final
/// This cannot be merged into the template unary/binary pattern due to Vulkan
/// restrictions over spirv.SRem and spirv.SMod.
struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
- using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
@@ -71,7 +72,7 @@ struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
/// Converts arith.remsi to OpenCL SPIR-V ops.
struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
- using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
@@ -93,7 +94,7 @@ struct BitwiseOpPattern final : public OpConversionPattern<Op> {
/// Converts arith.xori to SPIR-V operations.
struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
- using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
@@ -103,7 +104,7 @@ struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
/// Converts arith.xori to SPIR-V operations if the type of source is i1 or
/// vector of i1.
struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
- using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
@@ -113,7 +114,7 @@ struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
/// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
/// of i1.
struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
- using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
@@ -123,7 +124,7 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
/// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
/// of i1.
struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
- using OpConversionPattern<arith::ExtSIOp>::OpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
@@ -133,7 +134,7 @@ struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
/// Converts arith.extui to spirv.Select if the type of source is i1 or vector
/// of i1.
struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
- using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
@@ -143,7 +144,7 @@ struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
/// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
/// of i1.
struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
- using OpConversionPattern<arith::TruncIOp>::OpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
@@ -163,7 +164,7 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
/// Converts integer compare operation on i1 type operands to SPIR-V ops.
class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
public:
- using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
@@ -173,7 +174,7 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
/// Converts integer compare operation to SPIR-V ops.
class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
public:
- using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
@@ -183,7 +184,7 @@ class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
/// Converts floating-point comparison operations to SPIR-V ops.
class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
public:
- using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
@@ -194,7 +195,7 @@ class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
/// Kernel capability.
class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
public:
- using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
@@ -216,7 +217,7 @@ class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
class AddICarryOpPattern final
: public OpConversionPattern<arith::AddUICarryOp> {
public:
- using OpConversionPattern<arith::AddUICarryOp>::OpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::AddUICarryOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
@@ -225,7 +226,7 @@ class AddICarryOpPattern final
/// Converts arith.select to spirv.Select.
class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
public:
- using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
@@ -254,7 +255,7 @@ static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
return boolAttr;
if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
return builder.getBoolAttr(intAttr.getValue().getBoolValue());
- return BoolAttr();
+ return {};
}
/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
@@ -281,7 +282,7 @@ static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
<< "' illegal: cannot fit into target type '"
<< dstType << "'\n");
- return IntegerAttr();
+ return {};
}
/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
@@ -346,7 +347,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
// arith.constant should only have vector or tenor types.
assert((srcType.isa<VectorType, RankedTensorType>()));
- auto dstType = getTypeConverter()->convertType(srcType);
+ Type dstType = getTypeConverter()->convertType(srcType);
if (!dstType)
return failure();
@@ -473,7 +474,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
// IndexType or IntegerType. Index values are converted to 32-bit integer
// values when converting to SPIR-V.
auto srcAttr = cstAttr.cast<IntegerAttr>();
- auto dstAttr =
+ IntegerAttr dstAttr =
convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
if (!dstAttr)
return failure();
@@ -577,7 +578,7 @@ LogicalResult XOrIOpLogicalPattern::matchAndRewrite(
if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
return failure();
- auto dstType = getTypeConverter()->convertType(op.getType());
+ Type dstType = getTypeConverter()->convertType(op.getType());
if (!dstType)
return failure();
rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
@@ -598,7 +599,7 @@ LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
return failure();
- auto dstType = getTypeConverter()->convertType(op.getType());
+ Type dstType = getTypeConverter()->convertType(op.getType());
if (!dstType)
return failure();
rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType,
@@ -613,16 +614,15 @@ LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
LogicalResult
UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- auto srcType = adaptor.getOperands().front().getType();
+ Type srcType = adaptor.getOperands().front().getType();
if (!isBoolScalarOrVector(srcType))
return failure();
- auto dstType =
- this->getTypeConverter()->convertType(op.getResult().getType());
+ Type dstType = getTypeConverter()->convertType(op.getResult().getType());
Location loc = op.getLoc();
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
- rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
+ rewriter.replaceOpWithNewOp<spirv::SelectOp>(
op, dstType, adaptor.getOperands().front(), one, zero);
return success();
}
@@ -670,16 +670,15 @@ ExtSII1Pattern::matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
LogicalResult
ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- auto srcType = adaptor.getOperands().front().getType();
+ Type srcType = adaptor.getOperands().front().getType();
if (!isBoolScalarOrVector(srcType))
return failure();
- auto dstType =
- this->getTypeConverter()->convertType(op.getResult().getType());
+ Type dstType = getTypeConverter()->convertType(op.getResult().getType());
Location loc = op.getLoc();
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
- rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
+ rewriter.replaceOpWithNewOp<spirv::SelectOp>(
op, dstType, adaptor.getOperands().front(), one, zero);
return success();
}
@@ -691,8 +690,7 @@ ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
LogicalResult
TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- auto dstType =
- this->getTypeConverter()->convertType(op.getResult().getType());
+ Type dstType = getTypeConverter()->convertType(op.getResult().getType());
if (!isBoolScalarOrVector(dstType))
return failure();
@@ -719,8 +717,8 @@ LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const {
assert(adaptor.getOperands().size() == 1);
- auto srcType = adaptor.getOperands().front().getType();
- auto dstType =
+ Type srcType = adaptor.getOperands().front().getType();
+ Type dstType =
this->getTypeConverter()->convertType(op.getResult().getType());
if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
return failure();
@@ -769,9 +767,9 @@ LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
Type type = rewriter.getI32Type();
if (auto vectorType = dstType.dyn_cast<VectorType>())
type = VectorType::get(vectorType.getShape(), type);
- auto extLhs =
+ Value extLhs =
rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
- auto extRhs =
+ Value extRhs =
rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
@@ -968,7 +966,7 @@ LogicalResult MinMaxFOpPattern<Op, SPIRVOp>::matchAndRewrite(
Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
- auto dstType = converter->convertType(op.getType());
+ Type dstType = converter->convertType(op.getType());
if (!dstType)
return failure();
@@ -1075,8 +1073,9 @@ struct ConvertArithToSPIRVPass
: public impl::ConvertArithToSPIRVBase<ConvertArithToSPIRVPass> {
void runOnOperation() override {
Operation *op = getOperation();
- auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
- auto target = SPIRVConversionTarget::get(targetAttr);
+ spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
+ std::unique_ptr<SPIRVConversionTarget> target =
+ SPIRVConversionTarget::get(targetAttr);
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
More information about the Mlir-commits
mailing list