[Mlir-commits] [mlir] 81c326c - [mlir][spirv] NFC: Merge ArithToSPIRV pattern decl and definition

Lei Zhang llvmlistbot at llvm.org
Sat Aug 12 16:30:47 PDT 2023


Author: Lei Zhang
Date: 2023-08-12T16:25:47-07:00
New Revision: 81c326ccdd9b8475b6b7180da36b24bb29ce4f42

URL: https://github.com/llvm/llvm-project/commit/81c326ccdd9b8475b6b7180da36b24bb29ce4f42
DIFF: https://github.com/llvm/llvm-project/commit/81c326ccdd9b8475b6b7180da36b24bb29ce4f42.diff

LOG: [mlir][spirv] NFC: Merge ArithToSPIRV pattern decl and definition

This makes the code easier to search and read.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D157782

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index f74c7e3490cd80..a8692a281366ba 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -32,228 +32,6 @@ namespace mlir {
 
 using namespace mlir;
 
-//===----------------------------------------------------------------------===//
-// Operation Conversion
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-/// Converts composite arith.constant operation to spirv.Constant.
-struct ConstantCompositeOpPattern final
-    : public OpConversionPattern<arith::ConstantOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts scalar arith.constant operation to spirv.Constant.
-struct ConstantScalarOpPattern final
-    : public OpConversionPattern<arith::ConstantOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.remsi to GLSL SPIR-V ops.
-///
-/// This cannot be merged into the template unary/binary pattern due to Vulkan
-/// restrictions over spirv.SRem and spirv.SMod.
-struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.remsi to OpenCL SPIR-V ops.
-struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts bitwise operations to SPIR-V operations. This is a special pattern
-/// other than the BinaryOpPatternPattern because if the operands are boolean
-/// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
-/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
-template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
-struct BitwiseOpPattern final : public OpConversionPattern<Op> {
-  using OpConversionPattern<Op>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.xori to SPIR-V operations.
-struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.xori to SPIR-V operations if the type of source is i1 or
-/// vector of i1.
-struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
-/// of i1.
-struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
-/// of i1.
-struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.extui to spirv.Select if the type of source is i1 or vector
-/// of i1.
-struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
-/// of i1.
-struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts type-casting standard operations to SPIR-V operations.
-template <typename Op, typename SPIRVOp>
-struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
-  using OpConversionPattern<Op>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts integer compare operation on i1 type operands to SPIR-V ops.
-class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts integer compare operation to SPIR-V ops.
-class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts floating-point comparison operations to SPIR-V ops.
-class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts floating point NaN check to SPIR-V ops. This pattern requires
-/// Kernel capability.
-class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts floating point NaN check to SPIR-V ops. This pattern does not
-/// require additional capability.
-class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
-public:
-  using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.addui_extended to spirv.IAddCarry.
-class AddUIExtendedOpPattern final
-    : public OpConversionPattern<arith::AddUIExtendedOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.mul*i_extended to spirv.*MulExtended.
-template <typename ArithMulOp, typename SPIRVMulOp>
-class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
-public:
-  using OpConversionPattern<ArithMulOp>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.select to spirv.Select.
-class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.maxf to spirv.GL.FMax or spirv.CL.fmax.
-template <typename Op, typename SPIRVOp>
-class MinMaxFOpPattern final : public OpConversionPattern<Op> {
-public:
-  using OpConversionPattern<Op>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-} // namespace
-
 //===----------------------------------------------------------------------===//
 // Conversion Helpers
 //===----------------------------------------------------------------------===//
@@ -362,157 +140,169 @@ getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
   return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
 }
 
+namespace {
+
 //===----------------------------------------------------------------------===//
-// ConstantOp with composite type
+// ConstantOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
-    arith::ConstantOp constOp, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  auto srcType = dyn_cast<ShapedType>(constOp.getType());
-  if (!srcType || srcType.getNumElements() == 1)
-    return failure();
-
-  // arith.constant should only have vector or tenor types.
-  assert((isa<VectorType, RankedTensorType>(srcType)));
-
-  Type dstType = getTypeConverter()->convertType(srcType);
-  if (!dstType)
-    return failure();
+/// Converts composite arith.constant operation to spirv.Constant.
+struct ConstantCompositeOpPattern final
+    : public OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-  auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
-  if (!dstElementsAttr)
-    return failure();
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcType = dyn_cast<ShapedType>(constOp.getType());
+    if (!srcType || srcType.getNumElements() == 1)
+      return failure();
 
-  ShapedType dstAttrType = dstElementsAttr.getType();
+    // arith.constant should only have vector or tensor types.
+    assert((isa<VectorType, RankedTensorType>(srcType)));
 
-  // If the composite type has more than one dimensions, perform linearization.
-  if (srcType.getRank() > 1) {
-    if (isa<RankedTensorType>(srcType)) {
-      dstAttrType = RankedTensorType::get(srcType.getNumElements(),
-                                          srcType.getElementType());
-      dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
-    } else {
-      // TODO: add support for large vectors.
+    Type dstType = getTypeConverter()->convertType(srcType);
+    if (!dstType)
       return failure();
-    }
-  }
 
-  Type srcElemType = srcType.getElementType();
-  Type dstElemType;
-  // Tensor types are converted to SPIR-V array types; vector types are
-  // converted to SPIR-V vector/array types.
-  if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
-    dstElemType = arrayType.getElementType();
-  else
-    dstElemType = cast<VectorType>(dstType).getElementType();
-
-  // If the source and destination element types are different, perform
-  // attribute conversion.
-  if (srcElemType != dstElemType) {
-    SmallVector<Attribute, 8> elements;
-    if (isa<FloatType>(srcElemType)) {
-      for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
-        FloatAttr dstAttr =
-            convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
-        if (!dstAttr)
-          return failure();
-        elements.push_back(dstAttr);
-      }
-    } else if (srcElemType.isInteger(1)) {
+    auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+    if (!dstElementsAttr)
       return failure();
-    } else {
-      for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
-        IntegerAttr dstAttr = convertIntegerAttr(
-            srcAttr, cast<IntegerType>(dstElemType), rewriter);
-        if (!dstAttr)
-          return failure();
-        elements.push_back(dstAttr);
+
+    ShapedType dstAttrType = dstElementsAttr.getType();
+
+    // If the composite type has more than one dimension, perform
+    // linearization.
+    if (srcType.getRank() > 1) {
+      if (isa<RankedTensorType>(srcType)) {
+        dstAttrType = RankedTensorType::get(srcType.getNumElements(),
+                                            srcType.getElementType());
+        dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
+      } else {
+        // TODO: add support for large vectors.
+        return failure();
       }
     }
 
-    // Unfortunately, we cannot use dialect-specific types for element
-    // attributes; element attributes only works with builtin types. So we need
-    // to prepare another converted builtin types for the destination elements
-    // attribute.
-    if (isa<RankedTensorType>(dstAttrType))
-      dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
+    Type srcElemType = srcType.getElementType();
+    Type dstElemType;
+    // Tensor types are converted to SPIR-V array types; vector types are
+    // converted to SPIR-V vector/array types.
+    if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
+      dstElemType = arrayType.getElementType();
     else
-      dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+      dstElemType = cast<VectorType>(dstType).getElementType();
+
+    // If the source and destination element types are different, perform
+    // attribute conversion.
+    if (srcElemType != dstElemType) {
+      SmallVector<Attribute, 8> elements;
+      if (isa<FloatType>(srcElemType)) {
+        for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
+          FloatAttr dstAttr =
+              convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
+          if (!dstAttr)
+            return failure();
+          elements.push_back(dstAttr);
+        }
+      } else if (srcElemType.isInteger(1)) {
+        return failure();
+      } else {
+        for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
+          IntegerAttr dstAttr = convertIntegerAttr(
+              srcAttr, cast<IntegerType>(dstElemType), rewriter);
+          if (!dstAttr)
+            return failure();
+          elements.push_back(dstAttr);
+        }
+      }
+
+      // Unfortunately, we cannot use dialect-specific types for element
+      // attributes; element attributes only work with builtin types. So we
+      // need to prepare another converted builtin type for the destination
+      // elements attribute.
+      if (isa<RankedTensorType>(dstAttrType))
+        dstAttrType =
+            RankedTensorType::get(dstAttrType.getShape(), dstElemType);
+      else
+        dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+
+      dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
+    }
 
-    dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
+                                                   dstElementsAttr);
+    return success();
   }
+};
 
-  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
-                                                 dstElementsAttr);
-  return success();
-}
+/// Converts scalar arith.constant operation to spirv.Constant.
+struct ConstantScalarOpPattern final
+    : public OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-//===----------------------------------------------------------------------===//
-// ConstantOp with scalar type
-//===----------------------------------------------------------------------===//
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = constOp.getType();
+    if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
+      if (shapedType.getNumElements() != 1)
+        return failure();
+      srcType = shapedType.getElementType();
+    }
+    if (!srcType.isIntOrIndexOrFloat())
+      return failure();
 
-LogicalResult ConstantScalarOpPattern::matchAndRewrite(
-    arith::ConstantOp constOp, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Type srcType = constOp.getType();
-  if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
-    if (shapedType.getNumElements() != 1)
+    Attribute cstAttr = constOp.getValue();
+    if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
+      cstAttr = elementsAttr.getSplatValue<Attribute>();
+
+    Type dstType = getTypeConverter()->convertType(srcType);
+    if (!dstType)
       return failure();
-    srcType = shapedType.getElementType();
-  }
-  if (!srcType.isIntOrIndexOrFloat())
-    return failure();
 
-  Attribute cstAttr = constOp.getValue();
-  if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
-    cstAttr = elementsAttr.getSplatValue<Attribute>();
+    // Floating-point types.
+    if (isa<FloatType>(srcType)) {
+      auto srcAttr = cast<FloatAttr>(cstAttr);
+      auto dstAttr = srcAttr;
 
-  Type dstType = getTypeConverter()->convertType(srcType);
-  if (!dstType)
-    return failure();
+      // Floating-point types not supported in the target environment are all
+      // converted to float type.
+      if (srcType != dstType) {
+        dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
+        if (!dstAttr)
+          return failure();
+      }
 
-  // Floating-point types.
-  if (isa<FloatType>(srcType)) {
-    auto srcAttr = cast<FloatAttr>(cstAttr);
-    auto dstAttr = srcAttr;
+      rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
+      return success();
+    }
 
-    // Floating-point types not supported in the target environment are all
-    // converted to float type.
-    if (srcType != dstType) {
-      dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
+    // Bool type.
+    if (srcType.isInteger(1)) {
+      // arith.constant can use 0/1 instead of true/false for i1 values. We need
+      // to handle that here.
+      auto dstAttr = convertBoolAttr(cstAttr, rewriter);
       if (!dstAttr)
         return failure();
+      rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
+      return success();
     }
 
-    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
-    return success();
-  }
-
-  // Bool type.
-  if (srcType.isInteger(1)) {
-    // arith.constant can use 0/1 instead of true/false for i1 values. We need
-    // to handle that here.
-    auto dstAttr = convertBoolAttr(cstAttr, rewriter);
+    // IndexType or IntegerType. Index values are converted to 32-bit integer
+    // values when converting to SPIR-V.
+    auto srcAttr = cast<IntegerAttr>(cstAttr);
+    IntegerAttr dstAttr =
+        convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
     if (!dstAttr)
       return failure();
     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
     return success();
   }
-
-  // IndexType or IntegerType. Index values are converted to 32-bit integer
-  // values when converting to SPIR-V.
-  auto srcAttr = cast<IntegerAttr>(cstAttr);
-  IntegerAttr dstAttr =
-      convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
-  if (!dstAttr)
-    return failure();
-  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
-  return success();
-}
+};
 
 //===----------------------------------------------------------------------===//
-// RemSIOpGLPattern
+// RemSIOp
 //===----------------------------------------------------------------------===//
 
 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
@@ -545,303 +335,363 @@ static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
   return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
 }
 
-LogicalResult
-RemSIOpGLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
-                                  ConversionPatternRewriter &rewriter) const {
-  Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
-      op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
-      adaptor.getOperands()[0], rewriter);
-  rewriter.replaceOp(op, result);
+/// Converts arith.remsi to GLSL SPIR-V ops.
+///
+/// This cannot be merged into the template unary/binary pattern due to Vulkan
+/// restrictions over spirv.SRem and spirv.SMod.
+struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-  return success();
-}
+  LogicalResult
+  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
+        op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
+        adaptor.getOperands()[0], rewriter);
+    rewriter.replaceOp(op, result);
 
-//===----------------------------------------------------------------------===//
-// RemSIOpCLPattern
-//===----------------------------------------------------------------------===//
+    return success();
+  }
+};
 
-LogicalResult
-RemSIOpCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
-                                  ConversionPatternRewriter &rewriter) const {
-  Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
-      op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
-      adaptor.getOperands()[0], rewriter);
-  rewriter.replaceOp(op, result);
+/// Converts arith.remsi to OpenCL SPIR-V ops.
+struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-  return success();
-}
+  LogicalResult
+  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
+        op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
+        adaptor.getOperands()[0], rewriter);
+    rewriter.replaceOp(op, result);
+
+    return success();
+  }
+};
 
 //===----------------------------------------------------------------------===//
-// BitwiseOpPattern
+// BitwiseOp
 //===----------------------------------------------------------------------===//
 
+/// Converts bitwise operations to SPIR-V operations. This is a special pattern
+/// other than the generic binary op pattern because if the operands are boolean
+/// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
+/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
-LogicalResult
-BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite(
-    Op op, typename Op::Adaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  assert(adaptor.getOperands().size() == 2);
-  Type dstType = this->getTypeConverter()->convertType(op.getType());
-  if (!dstType)
-    return getTypeConversionFailure(rewriter, op);
-
-  if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
-    rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
-                                                         adaptor.getOperands());
-  } else {
-    rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType,
-                                                         adaptor.getOperands());
+struct BitwiseOpPattern final : public OpConversionPattern<Op> {
+  using OpConversionPattern<Op>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(adaptor.getOperands().size() == 2);
+    Type dstType = this->getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
+      rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
+          op, dstType, adaptor.getOperands());
+    } else {
+      rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
+          op, dstType, adaptor.getOperands());
+    }
+    return success();
   }
-  return success();
-}
+};
 
 //===----------------------------------------------------------------------===//
-// XOrIOpLogicalPattern
+// XOrIOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult XOrIOpLogicalPattern::matchAndRewrite(
-    arith::XOrIOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  assert(adaptor.getOperands().size() == 2);
+/// Converts arith.xori on non-boolean operands to spirv.BitwiseXor.
+struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-  if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
-    return failure();
+  LogicalResult
+  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(adaptor.getOperands().size() == 2);
 
-  Type dstType = getTypeConverter()->convertType(op.getType());
-  if (!dstType)
-    return getTypeConversionFailure(rewriter, op);
+    if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
+      return failure();
 
-  rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
-                                                   adaptor.getOperands());
+    Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
 
-  return success();
-}
+    rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
+                                                     adaptor.getOperands());
 
-//===----------------------------------------------------------------------===//
-// XOrIOpBooleanPattern
-//===----------------------------------------------------------------------===//
+    return success();
+  }
+};
+
+/// Converts arith.xori to spirv.LogicalNotEqual if the type of source is i1 or
+/// vector of i1.
+struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
-    arith::XOrIOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  assert(adaptor.getOperands().size() == 2);
+  LogicalResult
+  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(adaptor.getOperands().size() == 2);
 
-  if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
-    return failure();
+    if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
+      return failure();
 
-  Type dstType = getTypeConverter()->convertType(op.getType());
-  if (!dstType)
-    return getTypeConversionFailure(rewriter, op);
+    Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
 
-  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType,
-                                                        adaptor.getOperands());
-  return success();
-}
+    rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
+        op, dstType, adaptor.getOperands());
+    return success();
+  }
+};
 
 //===----------------------------------------------------------------------===//
-// UIToFPI1Pattern
+// UIToFPOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult
-UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
-                                 ConversionPatternRewriter &rewriter) const {
-  Type srcType = adaptor.getOperands().front().getType();
-  if (!isBoolScalarOrVector(srcType))
-    return failure();
+/// Converts arith.uitofp with an i1 (or vector of i1) source to spirv.Select,
+/// choosing between constant one and zero of the converted result type.
+struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-  Type dstType = getTypeConverter()->convertType(op.getType());
-  if (!dstType)
-    return getTypeConversionFailure(rewriter, op);
+  LogicalResult
+  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = adaptor.getOperands().front().getType();
+    if (!isBoolScalarOrVector(srcType))
+      return failure();
 
-  Location loc = op.getLoc();
-  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
-  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
-  rewriter.replaceOpWithNewOp<spirv::SelectOp>(
-      op, dstType, adaptor.getOperands().front(), one, zero);
-  return success();
-}
+    Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    Location loc = op.getLoc();
+    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+    Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(
+        op, dstType, adaptor.getOperands().front(), one, zero);
+    return success();
+  }
+};
 
 //===----------------------------------------------------------------------===//
-// ExtSII1Pattern
+// ExtSIOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult
-ExtSII1Pattern::matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
-                                ConversionPatternRewriter &rewriter) const {
-  Value operand = adaptor.getIn();
-  if (!isBoolScalarOrVector(operand.getType()))
-    return failure();
+/// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
+/// of i1.
+struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-  Location loc = op.getLoc();
-  Type dstType = getTypeConverter()->convertType(op.getType());
-  if (!dstType)
-    return getTypeConversionFailure(rewriter, op);
-
-  Value allOnes;
-  if (auto intTy = dyn_cast<IntegerType>(dstType)) {
-    unsigned componentBitwidth = intTy.getWidth();
-    allOnes = rewriter.create<spirv::ConstantOp>(
-        loc, intTy,
-        rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
-  } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
-    unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
-    allOnes = rewriter.create<spirv::ConstantOp>(
-        loc, vectorTy,
-        SplatElementsAttr::get(vectorTy, APInt::getAllOnes(componentBitwidth)));
-  } else {
-    return rewriter.notifyMatchFailure(
-        loc, llvm::formatv("unhandled type: {0}", dstType));
-  }
+  LogicalResult
+  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value operand = adaptor.getIn();
+    if (!isBoolScalarOrVector(operand.getType()))
+      return failure();
 
-  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
-  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
-                                               zero);
-  return success();
-}
+    Location loc = op.getLoc();
+    Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    Value allOnes;
+    if (auto intTy = dyn_cast<IntegerType>(dstType)) {
+      unsigned componentBitwidth = intTy.getWidth();
+      allOnes = rewriter.create<spirv::ConstantOp>(
+          loc, intTy,
+          rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
+    } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
+      unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
+      allOnes = rewriter.create<spirv::ConstantOp>(
+          loc, vectorTy,
+          SplatElementsAttr::get(vectorTy,
+                                 APInt::getAllOnes(componentBitwidth)));
+    } else {
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unhandled type: {0}", dstType));
+    }
+
+    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
+                                                 zero);
+    return success();
+  }
+};
 
 //===----------------------------------------------------------------------===//
-// ExtUII1Pattern
+// ExtUIOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult
-ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
-                                ConversionPatternRewriter &rewriter) const {
-  Type srcType = adaptor.getOperands().front().getType();
-  if (!isBoolScalarOrVector(srcType))
-    return failure();
+/// Converts arith.extui to spirv.Select if the type of source is i1 or vector
+/// of i1.
+struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = adaptor.getOperands().front().getType();
+    if (!isBoolScalarOrVector(srcType))
+      return failure();
 
-  Type dstType = getTypeConverter()->convertType(op.getType());
-  if (!dstType)
-    return getTypeConversionFailure(rewriter, op);
+    Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
 
-  Location loc = op.getLoc();
-  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
-  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
-  rewriter.replaceOpWithNewOp<spirv::SelectOp>(
-      op, dstType, adaptor.getOperands().front(), one, zero);
-  return success();
-}
+    Location loc = op.getLoc();
+    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+    Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(
+        op, dstType, adaptor.getOperands().front(), one, zero);
+    return success();
+  }
+};
 
 //===----------------------------------------------------------------------===//
-// TruncII1Pattern
+// TruncIOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult
-TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
-                                 ConversionPatternRewriter &rewriter) const {
-  Type dstType = getTypeConverter()->convertType(op.getType());
-  if (!dstType)
-    return getTypeConversionFailure(rewriter, op);
+/// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
+/// of i1.
+struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-  if (!isBoolScalarOrVector(dstType))
-    return failure();
+  LogicalResult
+  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
 
-  Location loc = op.getLoc();
-  auto srcType = adaptor.getOperands().front().getType();
-  // Check if (x & 1) == 1.
-  Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
-  Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
-      loc, srcType, adaptor.getOperands()[0], mask);
-  Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
-
-  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
-  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
-  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
-  return success();
-}
+    if (!isBoolScalarOrVector(dstType))
+      return failure();
+
+    Location loc = op.getLoc();
+    auto srcType = adaptor.getOperands().front().getType();
+    // Check if (x & 1) == 1.
+    Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
+    Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
+        loc, srcType, adaptor.getOperands()[0], mask);
+    Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
+
+    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+    Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
+    return success();
+  }
+};
 
 //===----------------------------------------------------------------------===//
-// TypeCastingOpPattern
+// TypeCastingOp
 //===----------------------------------------------------------------------===//
 
+/// Converts type-casting standard operations to SPIR-V operations.
 template <typename Op, typename SPIRVOp>
-LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
-    Op op, typename Op::Adaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  assert(adaptor.getOperands().size() == 1);
-  Type srcType = adaptor.getOperands().front().getType();
-  Type dstType = this->getTypeConverter()->convertType(op.getType());
-  if (!dstType)
-    return getTypeConversionFailure(rewriter, op);
-
-  if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
-    return failure();
+struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
+  using OpConversionPattern<Op>::OpConversionPattern;
 
-  if (dstType == srcType) {
-    // Due to type conversion, we are seeing the same source and target type.
-    // Then we can just erase this operation by forwarding its operand.
-    rewriter.replaceOp(op, adaptor.getOperands().front());
-  } else {
-    rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
-                                                  adaptor.getOperands());
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(adaptor.getOperands().size() == 1);
+    Type srcType = adaptor.getOperands().front().getType();
+    Type dstType = this->getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
+      return failure();
+
+    if (dstType == srcType) {
+      // Due to type conversion, we are seeing the same source and target type.
+      // Then we can just erase this operation by forwarding its operand.
+      rewriter.replaceOp(op, adaptor.getOperands().front());
+    } else {
+      rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
+                                                    adaptor.getOperands());
+    }
+    return success();
   }
-  return success();
-}
+};
 
 //===----------------------------------------------------------------------===//
-// CmpIOpBooleanPattern
+// CmpIOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
-    arith::CmpIOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Type srcType = op.getLhs().getType();
-  if (!isBoolScalarOrVector(srcType))
-    return failure();
-  Type dstType = getTypeConverter()->convertType(srcType);
-  if (!dstType)
-    return getTypeConversionFailure(rewriter, op, srcType);
+/// Converts integer compare operation on i1 type operands to SPIR-V ops.
+class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
 
-  switch (op.getPredicate()) {
-  case arith::CmpIPredicate::eq: {
-    rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
-                                                       adaptor.getRhs());
-    return success();
-  }
-  case arith::CmpIPredicate::ne: {
-    rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, adaptor.getLhs(),
-                                                          adaptor.getRhs());
-    return success();
-  }
-  case arith::CmpIPredicate::uge:
-  case arith::CmpIPredicate::ugt:
-  case arith::CmpIPredicate::ule:
-  case arith::CmpIPredicate::ult: {
-    // There are no direct corresponding instructions in SPIR-V for such cases.
-    // Extend them to 32-bit and do comparision then.
-    Type type = rewriter.getI32Type();
-    if (auto vectorType = dyn_cast<VectorType>(dstType))
-      type = VectorType::get(vectorType.getShape(), type);
-    Value extLhs =
-        rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
-    Value extRhs =
-        rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
-
-    rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
-                                               extRhs);
-    return success();
-  }
-  default:
-    break;
+  LogicalResult
+  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = op.getLhs().getType();
+    if (!isBoolScalarOrVector(srcType))
+      return failure();
+    Type dstType = getTypeConverter()->convertType(srcType);
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op, srcType);
+
+    switch (op.getPredicate()) {
+    case arith::CmpIPredicate::eq: {
+      rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
+                                                         adaptor.getRhs());
+      return success();
+    }
+    case arith::CmpIPredicate::ne: {
+      rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
+          op, adaptor.getLhs(), adaptor.getRhs());
+      return success();
+    }
+    case arith::CmpIPredicate::uge:
+    case arith::CmpIPredicate::ugt:
+    case arith::CmpIPredicate::ule:
+    case arith::CmpIPredicate::ult: {
+      // There are no direct corresponding instructions in SPIR-V for such
+      // cases. Extend them to 32-bit and do comparison then.
+      Type type = rewriter.getI32Type();
+      if (auto vectorType = dyn_cast<VectorType>(dstType))
+        type = VectorType::get(vectorType.getShape(), type);
+      Value extLhs =
+          rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
+      Value extRhs =
+          rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
+
+      rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
+                                                 extRhs);
+      return success();
+    }
+    default:
+      break;
+    }
+    return failure();
   }
-  return failure();
-}
+};
 
-//===----------------------------------------------------------------------===//
-// CmpIOpPattern
-//===----------------------------------------------------------------------===//
+/// Converts integer compare operation to SPIR-V ops.
+class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
 
-LogicalResult
-CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
-                               ConversionPatternRewriter &rewriter) const {
-  Type srcType = op.getLhs().getType();
-  if (isBoolScalarOrVector(srcType))
-    return failure();
-  Type dstType = getTypeConverter()->convertType(srcType);
-  if (!dstType)
-    return getTypeConversionFailure(rewriter, op, srcType);
+  LogicalResult
+  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = op.getLhs().getType();
+    if (isBoolScalarOrVector(srcType))
+      return failure();
+    Type dstType = getTypeConverter()->convertType(srcType);
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op, srcType);
 
-  switch (op.getPredicate()) {
+    switch (op.getPredicate()) {
 #define DISPATCH(cmpPredicate, spirvOp)                                        \
   case cmpPredicate:                                                           \
     if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&            \
@@ -854,216 +704,253 @@ CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                                          adaptor.getRhs());                    \
     return success();
 
-    DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
-    DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
-    DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
-    DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
-    DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
-    DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
-    DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
-    DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
-    DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
-    DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
+      DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
+      DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
+      DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
+      DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
+      DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
+      DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
+      DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
+      DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
+      DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
+      DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
 
 #undef DISPATCH
+    }
+    return failure();
   }
-  return failure();
-}
+};
 
 //===----------------------------------------------------------------------===//
 // CmpFOpPattern
 //===----------------------------------------------------------------------===//
 
-LogicalResult
-CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
-                               ConversionPatternRewriter &rewriter) const {
-  switch (op.getPredicate()) {
+/// Converts floating-point comparison operations to SPIR-V ops.
+class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    switch (op.getPredicate()) {
 #define DISPATCH(cmpPredicate, spirvOp)                                        \
   case cmpPredicate:                                                           \
     rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(),                 \
                                          adaptor.getRhs());                    \
     return success();
 
-    // Ordered.
-    DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
-    DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
-    DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
-    DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
-    DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
-    DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
-    // Unordered.
-    DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
-    DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
-    DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
-    DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
-    DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
-    DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
+      // Ordered.
+      DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
+      DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
+      DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
+      DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
+      DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
+      DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
+      // Unordered.
+      DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
+      DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
+      DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
+      DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
+      DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
+      DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
 
 #undef DISPATCH
 
-  default:
-    break;
+    default:
+      break;
+    }
+    return failure();
   }
-  return failure();
-}
-
-//===----------------------------------------------------------------------===//
-// CmpFOpNanKernelPattern
-//===----------------------------------------------------------------------===//
+};
 
-LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
-    arith::CmpFOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  if (op.getPredicate() == arith::CmpFPredicate::ORD) {
-    rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
-                                                  adaptor.getRhs());
-    return success();
-  }
+/// Converts floating point NaN check to SPIR-V ops. This pattern requires
+/// Kernel capability.
+class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
 
-  if (op.getPredicate() == arith::CmpFPredicate::UNO) {
-    rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
+  LogicalResult
+  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (op.getPredicate() == arith::CmpFPredicate::ORD) {
+      rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
                                                     adaptor.getRhs());
-    return success();
-  }
-
-  return failure();
-}
+      return success();
+    }
 
-//===----------------------------------------------------------------------===//
-// CmpFOpNanNonePattern
-//===----------------------------------------------------------------------===//
+    if (op.getPredicate() == arith::CmpFPredicate::UNO) {
+      rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
+                                                      adaptor.getRhs());
+      return success();
+    }
 
-LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
-    arith::CmpFOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  if (op.getPredicate() != arith::CmpFPredicate::ORD &&
-      op.getPredicate() != arith::CmpFPredicate::UNO)
     return failure();
+  }
+};
 
-  Location loc = op.getLoc();
-  auto *converter = getTypeConverter<SPIRVTypeConverter>();
+/// Converts floating point NaN check to SPIR-V ops. This pattern does not
+/// require additional capability.
+class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
+public:
+  using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
 
-  Value replace;
-  if (converter->getOptions().enableFastMathMode) {
-    if (op.getPredicate() == arith::CmpFPredicate::ORD) {
-      // Ordered comparsion checks if neither operand is NaN.
-      replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
+  LogicalResult
+  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (op.getPredicate() != arith::CmpFPredicate::ORD &&
+        op.getPredicate() != arith::CmpFPredicate::UNO)
+      return failure();
+
+    Location loc = op.getLoc();
+    auto *converter = getTypeConverter<SPIRVTypeConverter>();
+
+    Value replace;
+    if (converter->getOptions().enableFastMathMode) {
+      if (op.getPredicate() == arith::CmpFPredicate::ORD) {
+        // Ordered comparison checks if neither operand is NaN.
+        replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
+      } else {
+        // Unordered comparison checks if either operand is NaN.
+        replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
+      }
     } else {
-      // Unordered comparsion checks if either operand is NaN.
-      replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
+      Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
+      Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+
+      replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
+      if (op.getPredicate() == arith::CmpFPredicate::ORD)
+        replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
     }
-  } else {
-    Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
-    Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
 
-    replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
-    if (op.getPredicate() == arith::CmpFPredicate::ORD)
-      replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
+    rewriter.replaceOp(op, replace);
+    return success();
   }
-
-  rewriter.replaceOp(op, replace);
-  return success();
-}
+};
 
 //===----------------------------------------------------------------------===//
-// AddUIExtendedOpPattern
+// AddUIExtendedOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult AddUIExtendedOpPattern::matchAndRewrite(
-    arith::AddUIExtendedOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Type dstElemTy = adaptor.getLhs().getType();
-  Location loc = op->getLoc();
-  Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
-                                                     adaptor.getRhs());
-
-  Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
-      loc, result, llvm::ArrayRef(0));
-  Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
-      loc, result, llvm::ArrayRef(1));
-
-  // Convert the carry value to boolean.
-  Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
-  Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
-
-  rewriter.replaceOp(op, {sumResult, carryResult});
-  return success();
-}
+/// Converts arith.addui_extended to spirv.IAddCarry.
+class AddUIExtendedOpPattern final
+    : public OpConversionPattern<arith::AddUIExtendedOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type dstElemTy = adaptor.getLhs().getType();
+    Location loc = op->getLoc();
+    Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
+                                                       adaptor.getRhs());
+
+    Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
+        loc, result, llvm::ArrayRef(0));
+    Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
+        loc, result, llvm::ArrayRef(1));
+
+    // Convert the carry value to boolean.
+    Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
+    Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
+
+    rewriter.replaceOp(op, {sumResult, carryResult});
+    return success();
+  }
+};
 
 //===----------------------------------------------------------------------===//
-// MulIExtendedOpPattern
+// MulIExtendedOp
 //===----------------------------------------------------------------------===//
 
+/// Converts arith.mul*i_extended to spirv.*MulExtended.
 template <typename ArithMulOp, typename SPIRVMulOp>
-LogicalResult MulIExtendedOpPattern<ArithMulOp, SPIRVMulOp>::matchAndRewrite(
-    ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Location loc = op->getLoc();
-  Value result =
-      rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
-
-  Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
-                                                         llvm::ArrayRef(0));
-  Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result,
-                                                          llvm::ArrayRef(1));
-
-  rewriter.replaceOp(op, {low, high});
-  return success();
-}
+class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
+public:
+  using OpConversionPattern<ArithMulOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    Value result =
+        rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
+
+    Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
+                                                           llvm::ArrayRef(0));
+    Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result,
+                                                            llvm::ArrayRef(1));
+
+    rewriter.replaceOp(op, {low, high});
+    return success();
+  }
+};
 
 //===----------------------------------------------------------------------===//
-// SelectOpPattern
+// SelectOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult
-SelectOpPattern::matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
-                                 ConversionPatternRewriter &rewriter) const {
-  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
-                                               adaptor.getTrueValue(),
-                                               adaptor.getFalseValue());
-  return success();
-}
+/// Converts arith.select to spirv.Select.
+class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
+                                                 adaptor.getTrueValue(),
+                                                 adaptor.getFalseValue());
+    return success();
+  }
+};
 
 //===----------------------------------------------------------------------===//
-// MaxFOpPattern
+// MaxFOp / MinFOp
 //===----------------------------------------------------------------------===//
 
+/// Converts arith.maxf/minf to spirv.GL.FMax/FMin or spirv.CL.fmax/fmin.
 template <typename Op, typename SPIRVOp>
-LogicalResult MinMaxFOpPattern<Op, SPIRVOp>::matchAndRewrite(
-    Op op, typename Op::Adaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
-  Type dstType = converter->convertType(op.getType());
-  if (!dstType)
-    return getTypeConversionFailure(rewriter, op);
-
-  // arith.maxf/minf:
-  //   "if one of the arguments is NaN, then the result is also NaN."
-  // spirv.GL.FMax/FMin
-  //   "which operand is the result is undefined if one of the operands
-  //   is a NaN."
-  // spirv.CL.fmax/fmin:
-  //   "If one argument is a NaN, Fmin returns the other argument."
-
-  Location loc = op.getLoc();
-  Value spirvOp = rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
-
-  if (converter->getOptions().enableFastMathMode) {
-    rewriter.replaceOp(op, spirvOp);
-    return success();
-  }
+class MinMaxFOpPattern final : public OpConversionPattern<Op> {
+public:
+  using OpConversionPattern<Op>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
+    Type dstType = converter->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    // arith.maxf/minf:
+    //   "if one of the arguments is NaN, then the result is also NaN."
+    // spirv.GL.FMax/FMin
+    //   "which operand is the result is undefined if one of the operands
+    //   is a NaN."
+    // spirv.CL.fmax/fmin:
+    //   "If one argument is a NaN, Fmin returns the other argument."
+
+    Location loc = op.getLoc();
+    Value spirvOp =
+        rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
+
+    if (converter->getOptions().enableFastMathMode) {
+      rewriter.replaceOp(op, spirvOp);
+      return success();
+    }
 
-  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
-  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+    Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
+    Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
 
-  Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
-                                                   adaptor.getLhs(), spirvOp);
-  Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
-                                                   adaptor.getRhs(), select1);
+    Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
+                                                     adaptor.getLhs(), spirvOp);
+    Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
+                                                     adaptor.getRhs(), select1);
 
-  rewriter.replaceOp(op, select2);
-  return success();
-}
+    rewriter.replaceOp(op, select2);
+    return success();
+  }
+};
+
+} // namespace
 
 //===----------------------------------------------------------------------===//
 // Pattern Population


        


More information about the Mlir-commits mailing list