[Mlir-commits] [mlir] e1e0ecb - [mlir][spirv] Support more comparisons on boolean values

Lei Zhang llvmlistbot at llvm.org
Tue Jun 28 09:06:53 PDT 2022


Author: Lei Zhang
Date: 2022-06-28T11:58:42-04:00
New Revision: e1e0ecb96e0aa3b3846484f82e4c0e8c31c50341

URL: https://github.com/llvm/llvm-project/commit/e1e0ecb96e0aa3b3846484f82e4c0e8c31c50341
DIFF: https://github.com/llvm/llvm-project/commit/e1e0ecb96e0aa3b3846484f82e4c0e8c31c50341.diff

LOG: [mlir][spirv] Support more comparisons on boolean values

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D128692

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
    mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index 31d9023e4346d..4bf985e86283a 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "arith-to-spirv-pattern"
@@ -665,23 +666,44 @@ LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
 LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
     arith::CmpIOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
-  Type operandType = op.getLhs().getType();
-  if (!isBoolScalarOrVector(operandType))
+  Type srcType = op.getLhs().getType();
+  if (!isBoolScalarOrVector(srcType))
+    return failure();
+  Type dstType = getTypeConverter()->convertType(srcType);
+  if (!dstType)
     return failure();
 
   switch (op.getPredicate()) {
-#define DISPATCH(cmpPredicate, spirvOp)                                        \
-  case cmpPredicate: {                                                         \
-    rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(),                 \
-                                         adaptor.getRhs());                    \
-    return success();                                                          \
+  case arith::CmpIPredicate::eq: {
+    rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
+                                                       adaptor.getRhs());
+    return success();
   }
-
-    DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp);
-    DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp);
-
-#undef DISPATCH
-  default:;
+  case arith::CmpIPredicate::ne: {
+    rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, adaptor.getLhs(),
+                                                          adaptor.getRhs());
+    return success();
+  }
+  case arith::CmpIPredicate::uge:
+  case arith::CmpIPredicate::ugt:
+  case arith::CmpIPredicate::ule:
+  case arith::CmpIPredicate::ult: {
+    // SPIR-V has no direct instruction for unsigned ordering on booleans,
+    // so zero-extend both operands to 32-bit integers and compare those.
+    Type type = rewriter.getI32Type();
+    if (auto vectorType = dstType.dyn_cast<VectorType>())
+      type = VectorType::get(vectorType.getShape(), type);
+    auto extLhs =
+        rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
+    auto extRhs =
+        rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
+    // Re-emit the same predicate on the widened (i32) operands.
+    rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
+                                               extRhs);
+    return success();
+  }
+  default: // Predicates not listed above are not handled by this pattern.
+    break;
   }
   return failure();
 }

diff  --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
index 8925197512875..22e3dc48b98a9 100644
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -401,8 +401,8 @@ func.func @vec1cmpi(%arg0 : vector<1xi32>, %arg1 : vector<1xi32>) {
   return
 }
 
-// CHECK-LABEL: @boolcmpi
-func.func @boolcmpi(%arg0 : i1, %arg1 : i1) {
+// CHECK-LABEL: @boolcmpi_equality
+func.func @boolcmpi_equality(%arg0 : i1, %arg1 : i1) {
   // CHECK: spv.LogicalEqual
   %0 = arith.cmpi eq, %arg0, %arg1 : i1
   // CHECK: spv.LogicalNotEqual
@@ -410,8 +410,19 @@ func.func @boolcmpi(%arg0 : i1, %arg1 : i1) {
   return
 }
 
-// CHECK-LABEL: @vec1boolcmpi
-func.func @vec1boolcmpi(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
+// CHECK-LABEL: @boolcmpi_unsigned
+func.func @boolcmpi_unsigned(%arg0 : i1, %arg1 : i1) {
+  // CHECK-COUNT-2: spv.Select
+  // CHECK: spv.UGreaterThanEqual
+  %0 = arith.cmpi uge, %arg0, %arg1 : i1
+  // CHECK-COUNT-2: spv.Select
+  // CHECK: spv.ULessThan
+  %1 = arith.cmpi ult, %arg0, %arg1 : i1
+  return  // NOTE(review): ugt/ule share this lowering but are not exercised.
+}
+
+// CHECK-LABEL: @vec1boolcmpi_equality
+func.func @vec1boolcmpi_equality(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
   // CHECK: spv.LogicalEqual
   %0 = arith.cmpi eq, %arg0, %arg1 : vector<1xi1>
   // CHECK: spv.LogicalNotEqual
@@ -419,8 +430,19 @@ func.func @vec1boolcmpi(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
   return
 }
 
-// CHECK-LABEL: @vecboolcmpi
-func.func @vecboolcmpi(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
+// CHECK-LABEL: @vec1boolcmpi_unsigned
+func.func @vec1boolcmpi_unsigned(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
+  // CHECK-COUNT-2: spv.Select
+  // CHECK: spv.UGreaterThanEqual
+  %0 = arith.cmpi uge, %arg0, %arg1 : vector<1xi1>
+  // CHECK-COUNT-2: spv.Select
+  // CHECK: spv.ULessThan
+  %1 = arith.cmpi ult, %arg0, %arg1 : vector<1xi1>
+  return  // NOTE(review): ugt/ule share this lowering but are not exercised.
+}
+
+// CHECK-LABEL: @vecboolcmpi_equality
+func.func @vecboolcmpi_equality(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
   // CHECK: spv.LogicalEqual
   %0 = arith.cmpi eq, %arg0, %arg1 : vector<4xi1>
   // CHECK: spv.LogicalNotEqual
@@ -428,6 +450,18 @@ func.func @vecboolcmpi(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
   return
 }
 
+// CHECK-LABEL: @vecboolcmpi_unsigned
+func.func @vecboolcmpi_unsigned(%arg0 : vector<3xi1>, %arg1 : vector<3xi1>) {
+  // CHECK-COUNT-2: spv.Select
+  // CHECK: spv.UGreaterThanEqual
+  %0 = arith.cmpi uge, %arg0, %arg1 : vector<3xi1>
+  // CHECK-COUNT-2: spv.Select
+  // CHECK: spv.ULessThan
+  %1 = arith.cmpi ult, %arg0, %arg1 : vector<3xi1>
+  return  // NOTE(review): ugt/ule share this lowering but are not exercised.
+}
+
+
 } // end module
 
 // -----


        


More information about the Mlir-commits mailing list