[Mlir-commits] [mlir] c50d0fe - [mlir][arith][spirv] Clean up arith-to-spirv. NFC.

Jakub Kuderski llvmlistbot at llvm.org
Mon Nov 14 17:55:14 PST 2022


Author: Jakub Kuderski
Date: 2022-11-14T20:54:27-05:00
New Revision: c50d0fe570e630c3b1e56a5aee17568e955134b3

URL: https://github.com/llvm/llvm-project/commit/c50d0fe570e630c3b1e56a5aee17568e955134b3
DIFF: https://github.com/llvm/llvm-project/commit/c50d0fe570e630c3b1e56a5aee17568e955134b3.diff

LOG: [mlir][arith][spirv] Clean up arith-to-spirv. NFC.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D137978

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index a284be8ce939..d550e0e33f3e 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -9,8 +9,8 @@
 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
 
 #include "../SPIRVCommon/Pattern.h"
-#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
@@ -21,6 +21,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Debug.h"
 #include <cassert>
+#include <memory>
 
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTARITHTOSPIRV
@@ -40,7 +41,7 @@ namespace {
 /// Converts composite arith.constant operation to spirv.Constant.
 struct ConstantCompositeOpPattern final
     : public OpConversionPattern<arith::ConstantOp> {
-  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
@@ -50,7 +51,7 @@ struct ConstantCompositeOpPattern final
 /// Converts scalar arith.constant operation to spirv.Constant.
 struct ConstantScalarOpPattern final
     : public OpConversionPattern<arith::ConstantOp> {
-  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
@@ -62,7 +63,7 @@ struct ConstantScalarOpPattern final
 /// This cannot be merged into the template unary/binary pattern due to Vulkan
 /// restrictions over spirv.SRem and spirv.SMod.
 struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
-  using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
@@ -71,7 +72,7 @@ struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
 
 /// Converts arith.remsi to OpenCL SPIR-V ops.
 struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
-  using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
@@ -93,7 +94,7 @@ struct BitwiseOpPattern final : public OpConversionPattern<Op> {
 
 /// Converts arith.xori to SPIR-V operations.
 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
-  using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
@@ -103,7 +104,7 @@ struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or
 /// vector of i1.
 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
-  using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
@@ -113,7 +114,7 @@ struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
 /// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
 /// of i1.
 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
-  using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
@@ -123,7 +124,7 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
 /// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
 /// of i1.
 struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
-  using OpConversionPattern<arith::ExtSIOp>::OpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
@@ -133,7 +134,7 @@ struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
 /// Converts arith.extui to spirv.Select if the type of source is i1 or vector
 /// of i1.
 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
-  using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
@@ -143,7 +144,7 @@ struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
 /// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
 /// of i1.
 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
-  using OpConversionPattern<arith::TruncIOp>::OpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
@@ -163,7 +164,7 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
 /// Converts integer compare operation on i1 type operands to SPIR-V ops.
 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
 public:
-  using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
@@ -173,7 +174,7 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
 /// Converts integer compare operation to SPIR-V ops.
 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
 public:
-  using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
@@ -183,7 +184,7 @@ class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
 /// Converts floating-point comparison operations to SPIR-V ops.
 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
 public:
-  using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
@@ -194,7 +195,7 @@ class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
 /// Kernel capability.
 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
 public:
-  using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
@@ -216,7 +217,7 @@ class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
 class AddICarryOpPattern final
     : public OpConversionPattern<arith::AddUICarryOp> {
 public:
-  using OpConversionPattern<arith::AddUICarryOp>::OpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(arith::AddUICarryOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
@@ -225,7 +226,7 @@ class AddICarryOpPattern final
 /// Converts arith.select to spirv.Select.
 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
 public:
-  using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
@@ -254,7 +255,7 @@ static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
     return boolAttr;
   if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
     return builder.getBoolAttr(intAttr.getValue().getBoolValue());
-  return BoolAttr();
+  return {};
 }
 
 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
@@ -281,7 +282,7 @@ static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
   LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
                           << "' illegal: cannot fit into target type '"
                           << dstType << "'\n");
-  return IntegerAttr();
+  return {};
 }
 
 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
@@ -346,7 +347,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
   // arith.constant should only have vector or tenor types.
   assert((srcType.isa<VectorType, RankedTensorType>()));
 
-  auto dstType = getTypeConverter()->convertType(srcType);
+  Type dstType = getTypeConverter()->convertType(srcType);
   if (!dstType)
     return failure();
 
@@ -473,7 +474,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
   // IndexType or IntegerType. Index values are converted to 32-bit integer
   // values when converting to SPIR-V.
   auto srcAttr = cstAttr.cast<IntegerAttr>();
-  auto dstAttr =
+  IntegerAttr dstAttr =
       convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
   if (!dstAttr)
     return failure();
@@ -577,7 +578,7 @@ LogicalResult XOrIOpLogicalPattern::matchAndRewrite(
   if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
     return failure();
 
-  auto dstType = getTypeConverter()->convertType(op.getType());
+  Type dstType = getTypeConverter()->convertType(op.getType());
   if (!dstType)
     return failure();
   rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
@@ -598,7 +599,7 @@ LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
   if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
     return failure();
 
-  auto dstType = getTypeConverter()->convertType(op.getType());
+  Type dstType = getTypeConverter()->convertType(op.getType());
   if (!dstType)
     return failure();
   rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType,
@@ -613,16 +614,15 @@ LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
 LogicalResult
 UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
-  auto srcType = adaptor.getOperands().front().getType();
+  Type srcType = adaptor.getOperands().front().getType();
   if (!isBoolScalarOrVector(srcType))
     return failure();
 
-  auto dstType =
-      this->getTypeConverter()->convertType(op.getResult().getType());
+  Type dstType = getTypeConverter()->convertType(op.getResult().getType());
   Location loc = op.getLoc();
   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
-  rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
+  rewriter.replaceOpWithNewOp<spirv::SelectOp>(
       op, dstType, adaptor.getOperands().front(), one, zero);
   return success();
 }
@@ -670,16 +670,15 @@ ExtSII1Pattern::matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
 LogicalResult
 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
-  auto srcType = adaptor.getOperands().front().getType();
+  Type srcType = adaptor.getOperands().front().getType();
   if (!isBoolScalarOrVector(srcType))
     return failure();
 
-  auto dstType =
-      this->getTypeConverter()->convertType(op.getResult().getType());
+  Type dstType = getTypeConverter()->convertType(op.getResult().getType());
   Location loc = op.getLoc();
   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
-  rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
+  rewriter.replaceOpWithNewOp<spirv::SelectOp>(
       op, dstType, adaptor.getOperands().front(), one, zero);
   return success();
 }
@@ -691,8 +690,7 @@ ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
 LogicalResult
 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
-  auto dstType =
-      this->getTypeConverter()->convertType(op.getResult().getType());
+  Type dstType = getTypeConverter()->convertType(op.getResult().getType());
   if (!isBoolScalarOrVector(dstType))
     return failure();
 
@@ -719,8 +717,8 @@ LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
     Op op, typename Op::Adaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   assert(adaptor.getOperands().size() == 1);
-  auto srcType = adaptor.getOperands().front().getType();
-  auto dstType =
+  Type srcType = adaptor.getOperands().front().getType();
+  Type dstType =
       this->getTypeConverter()->convertType(op.getResult().getType());
   if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
     return failure();
@@ -769,9 +767,9 @@ LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
     Type type = rewriter.getI32Type();
     if (auto vectorType = dstType.dyn_cast<VectorType>())
       type = VectorType::get(vectorType.getShape(), type);
-    auto extLhs =
+    Value extLhs =
         rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
-    auto extRhs =
+    Value extRhs =
         rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
 
     rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
@@ -968,7 +966,7 @@ LogicalResult MinMaxFOpPattern<Op, SPIRVOp>::matchAndRewrite(
     Op op, typename Op::Adaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
-  auto dstType = converter->convertType(op.getType());
+  Type dstType = converter->convertType(op.getType());
   if (!dstType)
     return failure();
 
@@ -1075,8 +1073,9 @@ struct ConvertArithToSPIRVPass
     : public impl::ConvertArithToSPIRVBase<ConvertArithToSPIRVPass> {
   void runOnOperation() override {
     Operation *op = getOperation();
-    auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
-    auto target = SPIRVConversionTarget::get(targetAttr);
+    spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
+    std::unique_ptr<SPIRVConversionTarget> target =
+        SPIRVConversionTarget::get(targetAttr);
 
     SPIRVConversionOptions options;
     options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;


        


More information about the Mlir-commits mailing list