[Mlir-commits] [mlir] 81c326c - [mlir][spirv] NFC: Merge ArithToSPIRV pattern decl and definition
Lei Zhang
llvmlistbot at llvm.org
Sat Aug 12 16:30:47 PDT 2023
Author: Lei Zhang
Date: 2023-08-12T16:25:47-07:00
New Revision: 81c326ccdd9b8475b6b7180da36b24bb29ce4f42
URL: https://github.com/llvm/llvm-project/commit/81c326ccdd9b8475b6b7180da36b24bb29ce4f42
DIFF: https://github.com/llvm/llvm-project/commit/81c326ccdd9b8475b6b7180da36b24bb29ce4f42.diff
LOG: [mlir][spirv] NFC: Merge ArithToSPIRV pattern decl and definition
This makes the code easier to search and read.
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D157782
Added:
Modified:
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index f74c7e3490cd80..a8692a281366ba 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -32,228 +32,6 @@ namespace mlir {
using namespace mlir;
-//===----------------------------------------------------------------------===//
-// Operation Conversion
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-/// Converts composite arith.constant operation to spirv.Constant.
-struct ConstantCompositeOpPattern final
- : public OpConversionPattern<arith::ConstantOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts scalar arith.constant operation to spirv.Constant.
-struct ConstantScalarOpPattern final
- : public OpConversionPattern<arith::ConstantOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.remsi to GLSL SPIR-V ops.
-///
-/// This cannot be merged into the template unary/binary pattern due to Vulkan
-/// restrictions over spirv.SRem and spirv.SMod.
-struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.remsi to OpenCL SPIR-V ops.
-struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts bitwise operations to SPIR-V operations. This is a special pattern
-/// other than the BinaryOpPatternPattern because if the operands are boolean
-/// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
-/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
-template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
-struct BitwiseOpPattern final : public OpConversionPattern<Op> {
- using OpConversionPattern<Op>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(Op op, typename Op::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.xori to SPIR-V operations.
-struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.xori to SPIR-V operations if the type of source is i1 or
-/// vector of i1.
-struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
-/// of i1.
-struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
-/// of i1.
-struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.extui to spirv.Select if the type of source is i1 or vector
-/// of i1.
-struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
-/// of i1.
-struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts type-casting standard operations to SPIR-V operations.
-template <typename Op, typename SPIRVOp>
-struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
- using OpConversionPattern<Op>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(Op op, typename Op::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts integer compare operation on i1 type operands to SPIR-V ops.
-class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
-public:
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts integer compare operation to SPIR-V ops.
-class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
-public:
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts floating-point comparison operations to SPIR-V ops.
-class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
-public:
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts floating point NaN check to SPIR-V ops. This pattern requires
-/// Kernel capability.
-class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
-public:
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts floating point NaN check to SPIR-V ops. This pattern does not
-/// require additional capability.
-class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
-public:
- using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.addui_extended to spirv.IAddCarry.
-class AddUIExtendedOpPattern final
- : public OpConversionPattern<arith::AddUIExtendedOp> {
-public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.mul*i_extended to spirv.*MulExtended.
-template <typename ArithMulOp, typename SPIRVMulOp>
-class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
-public:
- using OpConversionPattern<ArithMulOp>::OpConversionPattern;
- LogicalResult
- matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.select to spirv.Select.
-class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
-public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-/// Converts arith.maxf to spirv.GL.FMax or spirv.CL.fmax.
-template <typename Op, typename SPIRVOp>
-class MinMaxFOpPattern final : public OpConversionPattern<Op> {
-public:
- using OpConversionPattern<Op>::OpConversionPattern;
- LogicalResult
- matchAndRewrite(Op op, typename Op::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-} // namespace
-
//===----------------------------------------------------------------------===//
// Conversion Helpers
//===----------------------------------------------------------------------===//
@@ -362,157 +140,169 @@ getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
}
+namespace {
+
//===----------------------------------------------------------------------===//
-// ConstantOp with composite type
+// ConstantOp
//===----------------------------------------------------------------------===//
-LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
- arith::ConstantOp constOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- auto srcType = dyn_cast<ShapedType>(constOp.getType());
- if (!srcType || srcType.getNumElements() == 1)
- return failure();
-
- // arith.constant should only have vector or tenor types.
- assert((isa<VectorType, RankedTensorType>(srcType)));
-
- Type dstType = getTypeConverter()->convertType(srcType);
- if (!dstType)
- return failure();
+/// Converts composite arith.constant operation to spirv.Constant.
+struct ConstantCompositeOpPattern final
+ : public OpConversionPattern<arith::ConstantOp> {
+ using OpConversionPattern::OpConversionPattern;
- auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
- if (!dstElementsAttr)
- return failure();
+ LogicalResult
+ matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto srcType = dyn_cast<ShapedType>(constOp.getType());
+ if (!srcType || srcType.getNumElements() == 1)
+ return failure();
- ShapedType dstAttrType = dstElementsAttr.getType();
+ // arith.constant should only have vector or tenor types.
+ assert((isa<VectorType, RankedTensorType>(srcType)));
- // If the composite type has more than one dimensions, perform linearization.
- if (srcType.getRank() > 1) {
- if (isa<RankedTensorType>(srcType)) {
- dstAttrType = RankedTensorType::get(srcType.getNumElements(),
- srcType.getElementType());
- dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
- } else {
- // TODO: add support for large vectors.
+ Type dstType = getTypeConverter()->convertType(srcType);
+ if (!dstType)
return failure();
- }
- }
- Type srcElemType = srcType.getElementType();
- Type dstElemType;
- // Tensor types are converted to SPIR-V array types; vector types are
- // converted to SPIR-V vector/array types.
- if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
- dstElemType = arrayType.getElementType();
- else
- dstElemType = cast<VectorType>(dstType).getElementType();
-
- // If the source and destination element types are different, perform
- // attribute conversion.
- if (srcElemType != dstElemType) {
- SmallVector<Attribute, 8> elements;
- if (isa<FloatType>(srcElemType)) {
- for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
- FloatAttr dstAttr =
- convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
- if (!dstAttr)
- return failure();
- elements.push_back(dstAttr);
- }
- } else if (srcElemType.isInteger(1)) {
+ auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+ if (!dstElementsAttr)
return failure();
- } else {
- for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
- IntegerAttr dstAttr = convertIntegerAttr(
- srcAttr, cast<IntegerType>(dstElemType), rewriter);
- if (!dstAttr)
- return failure();
- elements.push_back(dstAttr);
+
+ ShapedType dstAttrType = dstElementsAttr.getType();
+
+ // If the composite type has more than one dimensions, perform
+ // linearization.
+ if (srcType.getRank() > 1) {
+ if (isa<RankedTensorType>(srcType)) {
+ dstAttrType = RankedTensorType::get(srcType.getNumElements(),
+ srcType.getElementType());
+ dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
+ } else {
+ // TODO: add support for large vectors.
+ return failure();
}
}
- // Unfortunately, we cannot use dialect-specific types for element
- // attributes; element attributes only works with builtin types. So we need
- // to prepare another converted builtin types for the destination elements
- // attribute.
- if (isa<RankedTensorType>(dstAttrType))
- dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
+ Type srcElemType = srcType.getElementType();
+ Type dstElemType;
+ // Tensor types are converted to SPIR-V array types; vector types are
+ // converted to SPIR-V vector/array types.
+ if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
+ dstElemType = arrayType.getElementType();
else
- dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+ dstElemType = cast<VectorType>(dstType).getElementType();
+
+ // If the source and destination element types are different, perform
+ // attribute conversion.
+ if (srcElemType != dstElemType) {
+ SmallVector<Attribute, 8> elements;
+ if (isa<FloatType>(srcElemType)) {
+ for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
+ FloatAttr dstAttr =
+ convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
+ if (!dstAttr)
+ return failure();
+ elements.push_back(dstAttr);
+ }
+ } else if (srcElemType.isInteger(1)) {
+ return failure();
+ } else {
+ for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
+ IntegerAttr dstAttr = convertIntegerAttr(
+ srcAttr, cast<IntegerType>(dstElemType), rewriter);
+ if (!dstAttr)
+ return failure();
+ elements.push_back(dstAttr);
+ }
+ }
+
+ // Unfortunately, we cannot use dialect-specific types for element
+ // attributes; element attributes only works with builtin types. So we
+ // need to prepare another converted builtin types for the destination
+ // elements attribute.
+ if (isa<RankedTensorType>(dstAttrType))
+ dstAttrType =
+ RankedTensorType::get(dstAttrType.getShape(), dstElemType);
+ else
+ dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+
+ dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
+ }
- dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
+ rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
+ dstElementsAttr);
+ return success();
}
+};
- rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
- dstElementsAttr);
- return success();
-}
+/// Converts scalar arith.constant operation to spirv.Constant.
+struct ConstantScalarOpPattern final
+ : public OpConversionPattern<arith::ConstantOp> {
+ using OpConversionPattern::OpConversionPattern;
-//===----------------------------------------------------------------------===//
-// ConstantOp with scalar type
-//===----------------------------------------------------------------------===//
+ LogicalResult
+ matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type srcType = constOp.getType();
+ if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
+ if (shapedType.getNumElements() != 1)
+ return failure();
+ srcType = shapedType.getElementType();
+ }
+ if (!srcType.isIntOrIndexOrFloat())
+ return failure();
-LogicalResult ConstantScalarOpPattern::matchAndRewrite(
- arith::ConstantOp constOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- Type srcType = constOp.getType();
- if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
- if (shapedType.getNumElements() != 1)
+ Attribute cstAttr = constOp.getValue();
+ if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
+ cstAttr = elementsAttr.getSplatValue<Attribute>();
+
+ Type dstType = getTypeConverter()->convertType(srcType);
+ if (!dstType)
return failure();
- srcType = shapedType.getElementType();
- }
- if (!srcType.isIntOrIndexOrFloat())
- return failure();
- Attribute cstAttr = constOp.getValue();
- if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
- cstAttr = elementsAttr.getSplatValue<Attribute>();
+ // Floating-point types.
+ if (isa<FloatType>(srcType)) {
+ auto srcAttr = cast<FloatAttr>(cstAttr);
+ auto dstAttr = srcAttr;
- Type dstType = getTypeConverter()->convertType(srcType);
- if (!dstType)
- return failure();
+ // Floating-point types not supported in the target environment are all
+ // converted to float type.
+ if (srcType != dstType) {
+ dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
+ if (!dstAttr)
+ return failure();
+ }
- // Floating-point types.
- if (isa<FloatType>(srcType)) {
- auto srcAttr = cast<FloatAttr>(cstAttr);
- auto dstAttr = srcAttr;
+ rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
+ return success();
+ }
- // Floating-point types not supported in the target environment are all
- // converted to float type.
- if (srcType != dstType) {
- dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
+ // Bool type.
+ if (srcType.isInteger(1)) {
+ // arith.constant can use 0/1 instead of true/false for i1 values. We need
+ // to handle that here.
+ auto dstAttr = convertBoolAttr(cstAttr, rewriter);
if (!dstAttr)
return failure();
+ rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
+ return success();
}
- rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
- return success();
- }
-
- // Bool type.
- if (srcType.isInteger(1)) {
- // arith.constant can use 0/1 instead of true/false for i1 values. We need
- // to handle that here.
- auto dstAttr = convertBoolAttr(cstAttr, rewriter);
+ // IndexType or IntegerType. Index values are converted to 32-bit integer
+ // values when converting to SPIR-V.
+ auto srcAttr = cast<IntegerAttr>(cstAttr);
+ IntegerAttr dstAttr =
+ convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
if (!dstAttr)
return failure();
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
return success();
}
-
- // IndexType or IntegerType. Index values are converted to 32-bit integer
- // values when converting to SPIR-V.
- auto srcAttr = cast<IntegerAttr>(cstAttr);
- IntegerAttr dstAttr =
- convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
- if (!dstAttr)
- return failure();
- rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
- return success();
-}
+};
//===----------------------------------------------------------------------===//
-// RemSIOpGLPattern
+// RemSIOp
//===----------------------------------------------------------------------===//
/// Returns signed remainder for `lhs` and `rhs` and lets the result follow
@@ -545,303 +335,363 @@ static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
}
-LogicalResult
-RemSIOpGLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
- op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
- adaptor.getOperands()[0], rewriter);
- rewriter.replaceOp(op, result);
+/// Converts arith.remsi to GLSL SPIR-V ops.
+///
+/// This cannot be merged into the template unary/binary pattern due to Vulkan
+/// restrictions over spirv.SRem and spirv.SMod.
+struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
+ using OpConversionPattern::OpConversionPattern;
- return success();
-}
+ LogicalResult
+ matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
+ op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
+ adaptor.getOperands()[0], rewriter);
+ rewriter.replaceOp(op, result);
-//===----------------------------------------------------------------------===//
-// RemSIOpCLPattern
-//===----------------------------------------------------------------------===//
+ return success();
+ }
+};
-LogicalResult
-RemSIOpCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
- op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
- adaptor.getOperands()[0], rewriter);
- rewriter.replaceOp(op, result);
+/// Converts arith.remsi to OpenCL SPIR-V ops.
+struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
+ using OpConversionPattern::OpConversionPattern;
- return success();
-}
+ LogicalResult
+ matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
+ op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
+ adaptor.getOperands()[0], rewriter);
+ rewriter.replaceOp(op, result);
+
+ return success();
+ }
+};
//===----------------------------------------------------------------------===//
-// BitwiseOpPattern
+// BitwiseOp
//===----------------------------------------------------------------------===//
+/// Converts bitwise operations to SPIR-V operations. This is a special pattern
+/// other than the BinaryOpPatternPattern because if the operands are boolean
+/// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
+/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
-LogicalResult
-BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite(
- Op op, typename Op::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- assert(adaptor.getOperands().size() == 2);
- Type dstType = this->getTypeConverter()->convertType(op.getType());
- if (!dstType)
- return getTypeConversionFailure(rewriter, op);
-
- if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
- rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
- adaptor.getOperands());
- } else {
- rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType,
- adaptor.getOperands());
+struct BitwiseOpPattern final : public OpConversionPattern<Op> {
+ using OpConversionPattern<Op>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ assert(adaptor.getOperands().size() == 2);
+ Type dstType = this->getTypeConverter()->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
+
+ if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
+ rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
+ op, dstType, adaptor.getOperands());
+ } else {
+ rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
+ op, dstType, adaptor.getOperands());
+ }
+ return success();
}
- return success();
-}
+};
//===----------------------------------------------------------------------===//
-// XOrIOpLogicalPattern
+// XOrIOp
//===----------------------------------------------------------------------===//
-LogicalResult XOrIOpLogicalPattern::matchAndRewrite(
- arith::XOrIOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- assert(adaptor.getOperands().size() == 2);
+/// Converts arith.xori to SPIR-V operations.
+struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
+ using OpConversionPattern::OpConversionPattern;
- if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
- return failure();
+ LogicalResult
+ matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ assert(adaptor.getOperands().size() == 2);
- Type dstType = getTypeConverter()->convertType(op.getType());
- if (!dstType)
- return getTypeConversionFailure(rewriter, op);
+ if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
+ return failure();
- rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
- adaptor.getOperands());
+ Type dstType = getTypeConverter()->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
- return success();
-}
+ rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
+ adaptor.getOperands());
-//===----------------------------------------------------------------------===//
-// XOrIOpBooleanPattern
-//===----------------------------------------------------------------------===//
+ return success();
+ }
+};
+
+/// Converts arith.xori to SPIR-V operations if the type of source is i1 or
+/// vector of i1.
+struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
+ using OpConversionPattern::OpConversionPattern;
-LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
- arith::XOrIOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- assert(adaptor.getOperands().size() == 2);
+ LogicalResult
+ matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ assert(adaptor.getOperands().size() == 2);
- if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
- return failure();
+ if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
+ return failure();
- Type dstType = getTypeConverter()->convertType(op.getType());
- if (!dstType)
- return getTypeConversionFailure(rewriter, op);
+ Type dstType = getTypeConverter()->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
- rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType,
- adaptor.getOperands());
- return success();
-}
+ rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
+ op, dstType, adaptor.getOperands());
+ return success();
+ }
+};
//===----------------------------------------------------------------------===//
-// UIToFPI1Pattern
+// UIToFPOp
//===----------------------------------------------------------------------===//
-LogicalResult
-UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- Type srcType = adaptor.getOperands().front().getType();
- if (!isBoolScalarOrVector(srcType))
- return failure();
+/// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
+/// of i1.
+struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
+ using OpConversionPattern::OpConversionPattern;
- Type dstType = getTypeConverter()->convertType(op.getType());
- if (!dstType)
- return getTypeConversionFailure(rewriter, op);
+ LogicalResult
+ matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type srcType = adaptor.getOperands().front().getType();
+ if (!isBoolScalarOrVector(srcType))
+ return failure();
- Location loc = op.getLoc();
- Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
- Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
- rewriter.replaceOpWithNewOp<spirv::SelectOp>(
- op, dstType, adaptor.getOperands().front(), one, zero);
- return success();
-}
+ Type dstType = getTypeConverter()->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
+
+ Location loc = op.getLoc();
+ Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+ Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+ rewriter.replaceOpWithNewOp<spirv::SelectOp>(
+ op, dstType, adaptor.getOperands().front(), one, zero);
+ return success();
+ }
+};
//===----------------------------------------------------------------------===//
-// ExtSII1Pattern
+// ExtSIOp
//===----------------------------------------------------------------------===//
-LogicalResult
-ExtSII1Pattern::matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- Value operand = adaptor.getIn();
- if (!isBoolScalarOrVector(operand.getType()))
- return failure();
+/// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
+/// of i1.
+struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
+ using OpConversionPattern::OpConversionPattern;
- Location loc = op.getLoc();
- Type dstType = getTypeConverter()->convertType(op.getType());
- if (!dstType)
- return getTypeConversionFailure(rewriter, op);
-
- Value allOnes;
- if (auto intTy = dyn_cast<IntegerType>(dstType)) {
- unsigned componentBitwidth = intTy.getWidth();
- allOnes = rewriter.create<spirv::ConstantOp>(
- loc, intTy,
- rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
- } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
- unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
- allOnes = rewriter.create<spirv::ConstantOp>(
- loc, vectorTy,
- SplatElementsAttr::get(vectorTy, APInt::getAllOnes(componentBitwidth)));
- } else {
- return rewriter.notifyMatchFailure(
- loc, llvm::formatv("unhandled type: {0}", dstType));
- }
+ LogicalResult
+ matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Value operand = adaptor.getIn();
+ if (!isBoolScalarOrVector(operand.getType()))
+ return failure();
- Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
- rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
- zero);
- return success();
-}
+ Location loc = op.getLoc();
+ Type dstType = getTypeConverter()->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
+
+ Value allOnes;
+ if (auto intTy = dyn_cast<IntegerType>(dstType)) {
+ unsigned componentBitwidth = intTy.getWidth();
+ allOnes = rewriter.create<spirv::ConstantOp>(
+ loc, intTy,
+ rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
+ } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
+ unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
+ allOnes = rewriter.create<spirv::ConstantOp>(
+ loc, vectorTy,
+ SplatElementsAttr::get(vectorTy,
+ APInt::getAllOnes(componentBitwidth)));
+ } else {
+ return rewriter.notifyMatchFailure(
+ loc, llvm::formatv("unhandled type: {0}", dstType));
+ }
+
+ Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+ rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
+ zero);
+ return success();
+ }
+};
//===----------------------------------------------------------------------===//
-// ExtUII1Pattern
+// ExtUIOp
//===----------------------------------------------------------------------===//
-LogicalResult
-ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- Type srcType = adaptor.getOperands().front().getType();
- if (!isBoolScalarOrVector(srcType))
- return failure();
+/// Converts arith.extui to spirv.Select if the type of source is i1 or vector
+/// of i1.
+struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type srcType = adaptor.getOperands().front().getType();
+ if (!isBoolScalarOrVector(srcType))
+ return failure();
- Type dstType = getTypeConverter()->convertType(op.getType());
- if (!dstType)
- return getTypeConversionFailure(rewriter, op);
+ Type dstType = getTypeConverter()->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
- Location loc = op.getLoc();
- Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
- Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
- rewriter.replaceOpWithNewOp<spirv::SelectOp>(
- op, dstType, adaptor.getOperands().front(), one, zero);
- return success();
-}
+ Location loc = op.getLoc();
+ Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+ Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+ rewriter.replaceOpWithNewOp<spirv::SelectOp>(
+ op, dstType, adaptor.getOperands().front(), one, zero);
+ return success();
+ }
+};
//===----------------------------------------------------------------------===//
-// TruncII1Pattern
+// TruncIOp
//===----------------------------------------------------------------------===//
-LogicalResult
-TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- Type dstType = getTypeConverter()->convertType(op.getType());
- if (!dstType)
- return getTypeConversionFailure(rewriter, op);
+/// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
+/// of i1.
+struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
+ using OpConversionPattern::OpConversionPattern;
- if (!isBoolScalarOrVector(dstType))
- return failure();
+ LogicalResult
+ matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type dstType = getTypeConverter()->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
- Location loc = op.getLoc();
- auto srcType = adaptor.getOperands().front().getType();
- // Check if (x & 1) == 1.
- Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
- Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
- loc, srcType, adaptor.getOperands()[0], mask);
- Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
-
- Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
- Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
- rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
- return success();
-}
+ if (!isBoolScalarOrVector(dstType))
+ return failure();
+
+ Location loc = op.getLoc();
+ auto srcType = adaptor.getOperands().front().getType();
+ // Check if (x & 1) == 1.
+ Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
+ Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
+ loc, srcType, adaptor.getOperands()[0], mask);
+ Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
+
+ Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+ Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+ rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
+ return success();
+ }
+};
//===----------------------------------------------------------------------===//
-// TypeCastingOpPattern
+// TypeCastingOp
//===----------------------------------------------------------------------===//
+/// Converts type-casting standard operations to SPIR-V operations.
template <typename Op, typename SPIRVOp>
-LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
- Op op, typename Op::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- assert(adaptor.getOperands().size() == 1);
- Type srcType = adaptor.getOperands().front().getType();
- Type dstType = this->getTypeConverter()->convertType(op.getType());
- if (!dstType)
- return getTypeConversionFailure(rewriter, op);
-
- if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
- return failure();
+struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
+ using OpConversionPattern<Op>::OpConversionPattern;
- if (dstType == srcType) {
- // Due to type conversion, we are seeing the same source and target type.
- // Then we can just erase this operation by forwarding its operand.
- rewriter.replaceOp(op, adaptor.getOperands().front());
- } else {
- rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
- adaptor.getOperands());
+ LogicalResult
+ matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ assert(adaptor.getOperands().size() == 1);
+ Type srcType = adaptor.getOperands().front().getType();
+ Type dstType = this->getTypeConverter()->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
+
+ if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
+ return failure();
+
+ if (dstType == srcType) {
+ // Due to type conversion, we are seeing the same source and target type.
+ // Then we can just erase this operation by forwarding its operand.
+ rewriter.replaceOp(op, adaptor.getOperands().front());
+ } else {
+ rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
+ adaptor.getOperands());
+ }
+ return success();
}
- return success();
-}
+};
//===----------------------------------------------------------------------===//
-// CmpIOpBooleanPattern
+// CmpIOp
//===----------------------------------------------------------------------===//
-LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
- arith::CmpIOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- Type srcType = op.getLhs().getType();
- if (!isBoolScalarOrVector(srcType))
- return failure();
- Type dstType = getTypeConverter()->convertType(srcType);
- if (!dstType)
- return getTypeConversionFailure(rewriter, op, srcType);
+/// Converts integer compare operation on i1 type operands to SPIR-V ops.
+class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
- switch (op.getPredicate()) {
- case arith::CmpIPredicate::eq: {
- rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
- adaptor.getRhs());
- return success();
- }
- case arith::CmpIPredicate::ne: {
- rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, adaptor.getLhs(),
- adaptor.getRhs());
- return success();
- }
- case arith::CmpIPredicate::uge:
- case arith::CmpIPredicate::ugt:
- case arith::CmpIPredicate::ule:
- case arith::CmpIPredicate::ult: {
- // There are no direct corresponding instructions in SPIR-V for such cases.
- // Extend them to 32-bit and do comparision then.
- Type type = rewriter.getI32Type();
- if (auto vectorType = dyn_cast<VectorType>(dstType))
- type = VectorType::get(vectorType.getShape(), type);
- Value extLhs =
- rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
- Value extRhs =
- rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
-
- rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
- extRhs);
- return success();
- }
- default:
- break;
+ LogicalResult
+ matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type srcType = op.getLhs().getType();
+ if (!isBoolScalarOrVector(srcType))
+ return failure();
+ Type dstType = getTypeConverter()->convertType(srcType);
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op, srcType);
+
+ switch (op.getPredicate()) {
+ case arith::CmpIPredicate::eq: {
+ rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
+ adaptor.getRhs());
+ return success();
+ }
+ case arith::CmpIPredicate::ne: {
+ rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
+ op, adaptor.getLhs(), adaptor.getRhs());
+ return success();
+ }
+ case arith::CmpIPredicate::uge:
+ case arith::CmpIPredicate::ugt:
+ case arith::CmpIPredicate::ule:
+ case arith::CmpIPredicate::ult: {
+ // There are no direct corresponding instructions in SPIR-V for such
+ // cases. Extend them to 32-bit and do comparison then.
+ Type type = rewriter.getI32Type();
+ if (auto vectorType = dyn_cast<VectorType>(dstType))
+ type = VectorType::get(vectorType.getShape(), type);
+ Value extLhs =
+ rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
+ Value extRhs =
+ rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
+
+ rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
+ extRhs);
+ return success();
+ }
+ default:
+ break;
+ }
+ return failure();
}
- return failure();
-}
+};
-//===----------------------------------------------------------------------===//
-// CmpIOpPattern
-//===----------------------------------------------------------------------===//
+/// Converts integer compare operation to SPIR-V ops.
+class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
-LogicalResult
-CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- Type srcType = op.getLhs().getType();
- if (isBoolScalarOrVector(srcType))
- return failure();
- Type dstType = getTypeConverter()->convertType(srcType);
- if (!dstType)
- return getTypeConversionFailure(rewriter, op, srcType);
+ LogicalResult
+ matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type srcType = op.getLhs().getType();
+ if (isBoolScalarOrVector(srcType))
+ return failure();
+ Type dstType = getTypeConverter()->convertType(srcType);
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op, srcType);
- switch (op.getPredicate()) {
+ switch (op.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp) \
case cmpPredicate: \
if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
@@ -854,216 +704,253 @@ CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
adaptor.getRhs()); \
return success();
- DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
- DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
- DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
- DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
- DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
- DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
- DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
- DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
- DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
- DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
+ DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
+ DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
+ DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
+ DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
+ DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
+ DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
+ DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
+ DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
+ DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
+ DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
#undef DISPATCH
+ }
+ return failure();
}
- return failure();
-}
+};
//===----------------------------------------------------------------------===//
// CmpFOpPattern
//===----------------------------------------------------------------------===//
-LogicalResult
-CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- switch (op.getPredicate()) {
+/// Converts floating-point comparison operations to SPIR-V ops.
+class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ switch (op.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp) \
case cmpPredicate: \
rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
adaptor.getRhs()); \
return success();
- // Ordered.
- DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
- DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
- DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
- DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
- DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
- DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
- // Unordered.
- DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
- DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
- DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
- DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
- DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
- DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
+ // Ordered.
+ DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
+ DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
+ DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
+ DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
+ DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
+ DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
+ // Unordered.
+ DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
+ DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
+ DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
+ DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
+ DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
+ DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
#undef DISPATCH
- default:
- break;
+ default:
+ break;
+ }
+ return failure();
}
- return failure();
-}
-
-//===----------------------------------------------------------------------===//
-// CmpFOpNanKernelPattern
-//===----------------------------------------------------------------------===//
+};
-LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
- arith::CmpFOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- if (op.getPredicate() == arith::CmpFPredicate::ORD) {
- rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
- adaptor.getRhs());
- return success();
- }
+/// Converts floating point NaN check to SPIR-V ops. This pattern requires
+/// Kernel capability.
+class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
- if (op.getPredicate() == arith::CmpFPredicate::UNO) {
- rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
+ LogicalResult
+ matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (op.getPredicate() == arith::CmpFPredicate::ORD) {
+ rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
adaptor.getRhs());
- return success();
- }
-
- return failure();
-}
+ return success();
+ }
-//===----------------------------------------------------------------------===//
-// CmpFOpNanNonePattern
-//===----------------------------------------------------------------------===//
+ if (op.getPredicate() == arith::CmpFPredicate::UNO) {
+ rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
+ adaptor.getRhs());
+ return success();
+ }
-LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
- arith::CmpFOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- if (op.getPredicate() != arith::CmpFPredicate::ORD &&
- op.getPredicate() != arith::CmpFPredicate::UNO)
return failure();
+ }
+};
- Location loc = op.getLoc();
- auto *converter = getTypeConverter<SPIRVTypeConverter>();
+/// Converts floating point NaN check to SPIR-V ops. This pattern does not
+/// require additional capability.
+class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
+public:
+ using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
- Value replace;
- if (converter->getOptions().enableFastMathMode) {
- if (op.getPredicate() == arith::CmpFPredicate::ORD) {
- // Ordered comparsion checks if neither operand is NaN.
- replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
+ LogicalResult
+ matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (op.getPredicate() != arith::CmpFPredicate::ORD &&
+ op.getPredicate() != arith::CmpFPredicate::UNO)
+ return failure();
+
+ Location loc = op.getLoc();
+ auto *converter = getTypeConverter<SPIRVTypeConverter>();
+
+ Value replace;
+ if (converter->getOptions().enableFastMathMode) {
+ if (op.getPredicate() == arith::CmpFPredicate::ORD) {
+ // Ordered comparison checks if neither operand is NaN.
+ replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
+ } else {
+ // Unordered comparison checks if either operand is NaN.
+ replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
+ }
} else {
- // Unordered comparsion checks if either operand is NaN.
- replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
+ Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
+ Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+
+ replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
+ if (op.getPredicate() == arith::CmpFPredicate::ORD)
+ replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
}
- } else {
- Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
- Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
- replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
- if (op.getPredicate() == arith::CmpFPredicate::ORD)
- replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
+ rewriter.replaceOp(op, replace);
+ return success();
}
-
- rewriter.replaceOp(op, replace);
- return success();
-}
+};
//===----------------------------------------------------------------------===//
-// AddUIExtendedOpPattern
+// AddUIExtendedOp
//===----------------------------------------------------------------------===//
-LogicalResult AddUIExtendedOpPattern::matchAndRewrite(
- arith::AddUIExtendedOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- Type dstElemTy = adaptor.getLhs().getType();
- Location loc = op->getLoc();
- Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
- adaptor.getRhs());
-
- Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
- loc, result, llvm::ArrayRef(0));
- Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
- loc, result, llvm::ArrayRef(1));
-
- // Convert the carry value to boolean.
- Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
- Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
-
- rewriter.replaceOp(op, {sumResult, carryResult});
- return success();
-}
+/// Converts arith.addui_extended to spirv.IAddCarry.
+class AddUIExtendedOpPattern final
+ : public OpConversionPattern<arith::AddUIExtendedOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type dstElemTy = adaptor.getLhs().getType();
+ Location loc = op->getLoc();
+ Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
+ adaptor.getRhs());
+
+ Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
+ loc, result, llvm::ArrayRef(0));
+ Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
+ loc, result, llvm::ArrayRef(1));
+
+ // Convert the carry value to boolean.
+ Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
+ Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
+
+ rewriter.replaceOp(op, {sumResult, carryResult});
+ return success();
+ }
+};
//===----------------------------------------------------------------------===//
-// MulIExtendedOpPattern
+// MulIExtendedOp
//===----------------------------------------------------------------------===//
+/// Converts arith.mul*i_extended to spirv.*MulExtended.
template <typename ArithMulOp, typename SPIRVMulOp>
-LogicalResult MulIExtendedOpPattern<ArithMulOp, SPIRVMulOp>::matchAndRewrite(
- ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- Location loc = op->getLoc();
- Value result =
- rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
-
- Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
- llvm::ArrayRef(0));
- Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result,
- llvm::ArrayRef(1));
-
- rewriter.replaceOp(op, {low, high});
- return success();
-}
+class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
+public:
+ using OpConversionPattern<ArithMulOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ Value result =
+ rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
+
+ Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
+ llvm::ArrayRef(0));
+ Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result,
+ llvm::ArrayRef(1));
+
+ rewriter.replaceOp(op, {low, high});
+ return success();
+ }
+};
//===----------------------------------------------------------------------===//
-// SelectOpPattern
+// SelectOp
//===----------------------------------------------------------------------===//
-LogicalResult
-SelectOpPattern::matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
- adaptor.getTrueValue(),
- adaptor.getFalseValue());
- return success();
-}
+/// Converts arith.select to spirv.Select.
+class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
+ adaptor.getTrueValue(),
+ adaptor.getFalseValue());
+ return success();
+ }
+};
//===----------------------------------------------------------------------===//
-// MaxFOpPattern
+// MaxFOp
//===----------------------------------------------------------------------===//
+/// Converts arith.maxf/minf to spirv.GL.FMax/FMin or spirv.CL.fmax/fmin.
template <typename Op, typename SPIRVOp>
-LogicalResult MinMaxFOpPattern<Op, SPIRVOp>::matchAndRewrite(
- Op op, typename Op::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
- Type dstType = converter->convertType(op.getType());
- if (!dstType)
- return getTypeConversionFailure(rewriter, op);
-
- // arith.maxf/minf:
- // "if one of the arguments is NaN, then the result is also NaN."
- // spirv.GL.FMax/FMin
- // "which operand is the result is undefined if one of the operands
- // is a NaN."
- // spirv.CL.fmax/fmin:
- // "If one argument is a NaN, Fmin returns the other argument."
-
- Location loc = op.getLoc();
- Value spirvOp = rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
-
- if (converter->getOptions().enableFastMathMode) {
- rewriter.replaceOp(op, spirvOp);
- return success();
- }
+class MinMaxFOpPattern final : public OpConversionPattern<Op> {
+public:
+ using OpConversionPattern<Op>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
+ Type dstType = converter->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
+
+ // arith.maxf/minf:
+ // "if one of the arguments is NaN, then the result is also NaN."
+ // spirv.GL.FMax/FMin
+ // "which operand is the result is undefined if one of the operands
+ // is a NaN."
+ // spirv.CL.fmax/fmin:
+ // "If one argument is a NaN, Fmin returns the other argument."
+
+ Location loc = op.getLoc();
+ Value spirvOp =
+ rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
+
+ if (converter->getOptions().enableFastMathMode) {
+ rewriter.replaceOp(op, spirvOp);
+ return success();
+ }
- Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
- Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+ Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
+ Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
- Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
- adaptor.getLhs(), spirvOp);
- Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
- adaptor.getRhs(), select1);
+ Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
+ adaptor.getLhs(), spirvOp);
+ Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
+ adaptor.getRhs(), select1);
- rewriter.replaceOp(op, select2);
- return success();
-}
+ rewriter.replaceOp(op, select2);
+ return success();
+ }
+};
+
+} // namespace
//===----------------------------------------------------------------------===//
// Pattern Population
More information about the Mlir-commits
mailing list