[Mlir-commits] [mlir] 7b7df8e - [mlir][StandardToSPIRV] Add support for lowering std.xor on bool to SPIR-V

Hanhan Wang llvmlistbot at llvm.org
Tue Apr 20 07:35:39 PDT 2021


Author: Hanhan Wang
Date: 2021-04-20T07:35:20-07:00
New Revision: 7b7df8e85eec445389e4b07915f16aa18332719d

URL: https://github.com/llvm/llvm-project/commit/7b7df8e85eec445389e4b07915f16aa18332719d
DIFF: https://github.com/llvm/llvm-project/commit/7b7df8e85eec445389e4b07915f16aa18332719d.diff

LOG: [mlir][StandardToSPIRV] Add support for lowering std.xor on bool to SPIR-V

std.xor ops on booleans (i1 or vector of i1) are lowered to spv.LogicalNotEqual.
For Boolean values, xor and not-equal compute the same result.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D100817

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
    mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 0196a21f4a699..2a6e7f2818602 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -663,6 +663,17 @@ class XOrOpPattern final : public OpConversionPattern<XOrOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts std.xor to SPIR-V operations if the type of source is i1 or vector
+/// of i1.
+class BoolXOrOpPattern final : public OpConversionPattern<XOrOp> {
+public:
+  using OpConversionPattern<XOrOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1250,6 +1261,22 @@ XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
   return success();
 }
 
+LogicalResult
+BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
+                                  ConversionPatternRewriter &rewriter) const {
+  assert(operands.size() == 2);
+
+  if (!isBoolScalarOrVector(operands.front().getType()))
+    return failure();
+
+  auto dstType = getTypeConverter()->convertType(xorOp.getType());
+  if (!dstType)
+    return failure();
+  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(xorOp, dstType,
+                                                        operands);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Pattern population
 //===----------------------------------------------------------------------===//
@@ -1293,7 +1320,7 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
       UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
       UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
-      SignedRemIOpPattern, XOrOpPattern,
+      SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern,
 
       // Comparison patterns
       BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern,

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 0148a0731dc9d..fe769482c787b 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -224,6 +224,8 @@ func @logical_scalar(%arg0 : i1, %arg1 : i1) {
   %0 = and %arg0, %arg1 : i1
   // CHECK: spv.LogicalOr
   %1 = or %arg0, %arg1 : i1
+  // CHECK: spv.LogicalNotEqual
+  %2 = xor %arg0, %arg1 : i1
   return
 }
 
@@ -233,6 +235,8 @@ func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
   %0 = and %arg0, %arg1 : vector<4xi1>
   // CHECK: spv.LogicalOr
   %1 = or %arg0, %arg1 : vector<4xi1>
+  // CHECK: spv.LogicalNotEqual
+  %2 = xor %arg0, %arg1 : vector<4xi1>
   return
 }
 


        


More information about the Mlir-commits mailing list