[Mlir-commits] [mlir] 1c12a95 - [mlir][StandardToSPIRV] Handle conversion of cmpi operation with i1
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 29 10:09:40 PDT 2020
Author: MaheshRavishankar
Date: 2020-04-29T10:09:03-07:00
New Revision: 1c12a95d9c52de8980ea9979350c7eabc1b9fd01
URL: https://github.com/llvm/llvm-project/commit/1c12a95d9c52de8980ea9979350c7eabc1b9fd01
DIFF: https://github.com/llvm/llvm-project/commit/1c12a95d9c52de8980ea9979350c7eabc1b9fd01.diff
LOG: [mlir][StandardToSPIRV] Handle conversion of cmpi operation with i1
type operands.
The instructions used to convert std.cmpi cannot have i1 operands
according to the SPIR-V specification, which specifies a different set
of instructions for comparing boolean values. Enhance the
StandardToSPIRV lowering to target these instructions when the
operands of a std.cmpi operation are of i1 type.
Differential Revision: https://reviews.llvm.org/D79049
Added:
Modified:
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index b53128a33018..12b22cacdee2 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -184,6 +184,16 @@ class CmpFOpPattern final : public SPIRVOpLowering<CmpFOp> {
ConversionPatternRewriter &rewriter) const override;
};
+/// Converts integer compare operations on i1-typed operands to SPIR-V ops.
+///
+/// Only the `eq` and `ne` predicates are lowered (to spv.LogicalEqual and
+/// spv.LogicalNotEqual respectively); for any other predicate, or for
+/// operands that are not scalar i1, matchAndRewrite returns failure().
+/// CmpIOpPattern declines scalar i1 operands, so this pattern is the one
+/// that handles them.
+class BoolCmpIOpPattern final : public SPIRVOpLowering<CmpIOp> {
+public:
+ using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;
+
+ LogicalResult
+ matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
/// Converts integer compare operation to SPIR-V ops.
class CmpIOpPattern final : public SPIRVOpLowering<CmpIOp> {
public:
@@ -453,11 +463,43 @@ CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
// CmpIOp
//===----------------------------------------------------------------------===//
+/// Lowers a std.cmpi whose operands are i1 to SPIR-V logical comparisons:
+///   eq -> spv.LogicalEqual, ne -> spv.LogicalNotEqual.
+///
+/// Returns failure() when the lhs type is not an i1 IntegerType, or when the
+/// predicate is neither eq nor ne.
+///
+/// NOTE(review): the applicability check uses `isa<IntegerType>()`, so
+/// vector-of-i1 operands are rejected here, while CmpIOpPattern only rejects
+/// *scalar* i1 and would therefore try to lower vector<...xi1> compares with
+/// its integer-compare ops. Confirm that path is valid for SPIR-V, or extend
+/// this check to the vector element type.
+///
+/// NOTE(review): among the patterns visible in this change, an i1 cmpi with
+/// any predicate other than eq/ne (e.g. slt, ult) is matched by neither this
+/// pattern nor CmpIOpPattern, so it is left unconverted.
+LogicalResult
+BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ CmpIOpOperandAdaptor cmpIOpOperands(operands);
+
+ // Decide applicability from the op's own lhs type (not from the adaptor's
+ // `operands`): only scalar i1 integer operands are handled here.
+ Type operandType = cmpIOp.lhs().getType();
+ if (!operandType.isa<IntegerType>() ||
+ operandType.cast<IntegerType>().getWidth() != 1)
+ return failure();
+
+ // Each DISPATCH case replaces cmpIOp with `spirvOp`, keeping the op's
+ // result type and taking lhs/rhs from the adaptor, then returns success.
+ switch (cmpIOp.getPredicate()) {
+#define DISPATCH(cmpPredicate, spirvOp) \
+ case cmpPredicate: \
+ rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
+ cmpIOpOperands.lhs(), \
+ cmpIOpOperands.rhs()); \
+ return success();
+
+ DISPATCH(CmpIPredicate::eq, spirv::LogicalEqualOp);
+ DISPATCH(CmpIPredicate::ne, spirv::LogicalNotEqualOp);
+
+#undef DISPATCH
+ // No logical-op lowering exists here for the remaining predicates.
+ default:;
+ }
+ return failure();
+}
+
LogicalResult
CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
CmpIOpOperandAdaptor cmpIOpOperands(operands);
+ Type operandType = cmpIOp.lhs().getType();
+ if (operandType.isa<IntegerType>() &&
+ operandType.cast<IntegerType>().getWidth() == 1)
+ return failure();
+
switch (cmpIOp.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp) \
case cmpPredicate: \
@@ -599,9 +641,10 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
- ConstantCompositeOpPattern, ConstantScalarOpPattern, CmpFOpPattern,
- CmpIOpPattern, LoadOpPattern, ReturnOpPattern, SelectOpPattern,
- StoreOpPattern, TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
+ BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern,
+ CmpFOpPattern, CmpIOpPattern, LoadOpPattern, ReturnOpPattern,
+ SelectOpPattern, StoreOpPattern,
+ TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>,
TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>, XOrOpPattern>(
context, typeConverter);
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 6abdde44e3e5..e7ad95a1a173 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -285,6 +285,15 @@ func @cmpi(%arg0 : i32, %arg1 : i32) {
return
}
+// cmpi on i1 operands must lower to the SPIR-V logical comparison ops
+// (eq -> spv.LogicalEqual, ne -> spv.LogicalNotEqual) rather than the
+// integer comparison ops used for wider integer types.
+// NOTE(review): only scalar i1 is exercised here; consider also covering a
+// vector<...xi1> operand case.
+// CHECK-LABEL: @boolcmpi
+func @boolcmpi(%arg0 : i1, %arg1 : i1) {
+ // CHECK: spv.LogicalEqual
+ %0 = cmpi "eq", %arg0, %arg1 : i1
+ // CHECK: spv.LogicalNotEqual
+ %1 = cmpi "ne", %arg0, %arg1 : i1
+ return
+}
+
} // end module
// -----
More information about the Mlir-commits mailing list