[Mlir-commits] [mlir] 7b7df8e - [mlir][StandardToSPIRV] Add support for lowering std.xor on bool to SPIR-V
Hanhan Wang
llvmlistbot at llvm.org
Tue Apr 20 07:35:39 PDT 2021
Author: Hanhan Wang
Date: 2021-04-20T07:35:20-07:00
New Revision: 7b7df8e85eec445389e4b07915f16aa18332719d
URL: https://github.com/llvm/llvm-project/commit/7b7df8e85eec445389e4b07915f16aa18332719d
DIFF: https://github.com/llvm/llvm-project/commit/7b7df8e85eec445389e4b07915f16aa18332719d.diff
LOG: [mlir][StandardToSPIRV] Add support for lowering std.xor on bool to SPIR-V
std.xor ops on bool are lowered to spv.LogicalNotEqual. For Boolean values, xor
and not-equal are the same thing.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D100817
Added:
Modified:
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 0196a21f4a699..2a6e7f2818602 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -663,6 +663,17 @@ class XOrOpPattern final : public OpConversionPattern<XOrOp> {
ConversionPatternRewriter &rewriter) const override;
};
+/// Converts std.xor to spv.LogicalNotEqual if the type of the source operands
+/// is i1 or a vector of i1. For Boolean values, xor and not-equal compute the
+/// same result.
+class BoolXOrOpPattern final : public OpConversionPattern<XOrOp> {
+public:
+ using OpConversionPattern<XOrOp>::OpConversionPattern;
+
+ /// Rewrites `xorOp` to spv.LogicalNotEqual; returns failure without
+ /// modifying the IR when the operands are not boolean or the result type
+ /// cannot be converted.
+ LogicalResult
+ matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -1250,6 +1261,22 @@ XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
return success();
}
+/// Lowers a boolean std.xor (i1 or vector of i1) to spv.LogicalNotEqual.
+///
+/// Returns failure, leaving `xorOp` untouched, if the first operand is not a
+/// boolean scalar or vector, or if the result type has no SPIR-V conversion.
+LogicalResult
+BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ // Expect exactly two operands.
+ assert(operands.size() == 2);
+
+ // Only the first operand's type is inspected; this presumes both operands
+ // share a type (NOTE(review): confirm against the std.xor verifier).
+ if (!isBoolScalarOrVector(operands.front().getType()))
+ return failure();
+
+ auto dstType = getTypeConverter()->convertType(xorOp.getType());
+ if (!dstType)
+ return failure();
+ // On Booleans, xor and not-equal coincide, so both operands are forwarded
+ // unchanged to spv.LogicalNotEqual.
+ rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(xorOp, dstType,
+ operands);
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
@@ -1293,7 +1320,7 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
- SignedRemIOpPattern, XOrOpPattern,
+ SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern,
// Comparison patterns
BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern,
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 0148a0731dc9d..fe769482c787b 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -224,6 +224,8 @@ func @logical_scalar(%arg0 : i1, %arg1 : i1) {
%0 = and %arg0, %arg1 : i1
// CHECK: spv.LogicalOr
%1 = or %arg0, %arg1 : i1
+ // CHECK: spv.LogicalNotEqual
+ %2 = xor %arg0, %arg1 : i1
return
}
@@ -233,6 +235,8 @@ func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
%0 = and %arg0, %arg1 : vector<4xi1>
// CHECK: spv.LogicalOr
%1 = or %arg0, %arg1 : vector<4xi1>
+ // CHECK: spv.LogicalNotEqual
+ %2 = xor %arg0, %arg1 : vector<4xi1>
return
}
More information about the Mlir-commits mailing list