[Mlir-commits] [mlir] e1e0ecb - [mlir][spirv] Support more comparisons on boolean values
Lei Zhang
llvmlistbot at llvm.org
Tue Jun 28 09:06:53 PDT 2022
Author: Lei Zhang
Date: 2022-06-28T11:58:42-04:00
New Revision: e1e0ecb96e0aa3b3846484f82e4c0e8c31c50341
URL: https://github.com/llvm/llvm-project/commit/e1e0ecb96e0aa3b3846484f82e4c0e8c31c50341
DIFF: https://github.com/llvm/llvm-project/commit/e1e0ecb96e0aa3b3846484f82e4c0e8c31c50341.diff
LOG: [mlir][spirv] Support more comparisons on boolean values
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D128692
Added:
Modified:
mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index 31d9023e4346d..4bf985e86283a 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "arith-to-spirv-pattern"
@@ -665,23 +666,44 @@ LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
arith::CmpIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- Type operandType = op.getLhs().getType();
- if (!isBoolScalarOrVector(operandType))
+ Type srcType = op.getLhs().getType();
+ if (!isBoolScalarOrVector(srcType))
+ return failure();
+ Type dstType = getTypeConverter()->convertType(srcType);
+ if (!dstType)
return failure();
switch (op.getPredicate()) {
-#define DISPATCH(cmpPredicate, spirvOp) \
- case cmpPredicate: { \
- rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
- adaptor.getRhs()); \
- return success(); \
+ case arith::CmpIPredicate::eq: {
+ rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
+ adaptor.getRhs());
+ return success();
}
-
- DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp);
- DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp);
-
-#undef DISPATCH
- default:;
+ case arith::CmpIPredicate::ne: {
+ rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, adaptor.getLhs(),
+ adaptor.getRhs());
+ return success();
+ }
+ case arith::CmpIPredicate::uge:
+ case arith::CmpIPredicate::ugt:
+ case arith::CmpIPredicate::ule:
+ case arith::CmpIPredicate::ult: {
+ // SPIR-V has no direct instruction for these comparisons on booleans.
+ // Zero-extend both operands to i32 (or a vector of i32) and compare those.
+ Type type = rewriter.getI32Type();
+ if (auto vectorType = dstType.dyn_cast<VectorType>())
+ type = VectorType::get(vectorType.getShape(), type);
+ auto extLhs =
+ rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
+ auto extRhs =
+ rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
+
+ rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
+ extRhs);
+ return success();
+ }
+ default:
+ break;
}
return failure();
}
diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
index 8925197512875..22e3dc48b98a9 100644
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -401,8 +401,8 @@ func.func @vec1cmpi(%arg0 : vector<1xi32>, %arg1 : vector<1xi32>) {
return
}
-// CHECK-LABEL: @boolcmpi
-func.func @boolcmpi(%arg0 : i1, %arg1 : i1) {
+// CHECK-LABEL: @boolcmpi_equality
+func.func @boolcmpi_equality(%arg0 : i1, %arg1 : i1) {
// CHECK: spv.LogicalEqual
%0 = arith.cmpi eq, %arg0, %arg1 : i1
// CHECK: spv.LogicalNotEqual
@@ -410,8 +410,19 @@ func.func @boolcmpi(%arg0 : i1, %arg1 : i1) {
return
}
-// CHECK-LABEL: @vec1boolcmpi
-func.func @vec1boolcmpi(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
+// CHECK-LABEL: @boolcmpi_unsigned
+func.func @boolcmpi_unsigned(%arg0 : i1, %arg1 : i1) {
+ // CHECK-COUNT-2: spv.Select
+ // CHECK: spv.UGreaterThanEqual
+ %0 = arith.cmpi uge, %arg0, %arg1 : i1
+ // CHECK-COUNT-2: spv.Select
+ // CHECK: spv.ULessThan
+ %1 = arith.cmpi ult, %arg0, %arg1 : i1
+ return
+}
+
+// CHECK-LABEL: @vec1boolcmpi_equality
+func.func @vec1boolcmpi_equality(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
// CHECK: spv.LogicalEqual
%0 = arith.cmpi eq, %arg0, %arg1 : vector<1xi1>
// CHECK: spv.LogicalNotEqual
@@ -419,8 +430,19 @@ func.func @vec1boolcmpi(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
return
}
-// CHECK-LABEL: @vecboolcmpi
-func.func @vecboolcmpi(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
+// CHECK-LABEL: @vec1boolcmpi_unsigned
+func.func @vec1boolcmpi_unsigned(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
+ // CHECK-COUNT-2: spv.Select
+ // CHECK: spv.UGreaterThanEqual
+ %0 = arith.cmpi uge, %arg0, %arg1 : vector<1xi1>
+ // CHECK-COUNT-2: spv.Select
+ // CHECK: spv.ULessThan
+ %1 = arith.cmpi ult, %arg0, %arg1 : vector<1xi1>
+ return
+}
+
+// CHECK-LABEL: @vecboolcmpi_equality
+func.func @vecboolcmpi_equality(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
// CHECK: spv.LogicalEqual
%0 = arith.cmpi eq, %arg0, %arg1 : vector<4xi1>
// CHECK: spv.LogicalNotEqual
@@ -428,6 +450,18 @@ func.func @vecboolcmpi(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
return
}
+// CHECK-LABEL: @vecboolcmpi_unsigned
+func.func @vecboolcmpi_unsigned(%arg0 : vector<3xi1>, %arg1 : vector<3xi1>) {
+ // CHECK-COUNT-2: spv.Select
+ // CHECK: spv.UGreaterThanEqual
+ %0 = arith.cmpi uge, %arg0, %arg1 : vector<3xi1>
+ // CHECK-COUNT-2: spv.Select
+ // CHECK: spv.ULessThan
+ %1 = arith.cmpi ult, %arg0, %arg1 : vector<3xi1>
+ return
+}
+
+
} // end module
// -----
More information about the Mlir-commits mailing list