[Mlir-commits] [mlir] c80484e - [mlir][StandardToSPIRV] Add support for lowering trunci to SPIR-V to i1 types.
Hanhan Wang
llvmlistbot at llvm.org
Wed Feb 17 07:24:06 PST 2021
Author: Hanhan Wang
Date: 2021-02-17T07:23:41-08:00
New Revision: c80484e16ed8a69f1bef7bedc687a3b31707ac30
URL: https://github.com/llvm/llvm-project/commit/c80484e16ed8a69f1bef7bedc687a3b31707ac30
DIFF: https://github.com/llvm/llvm-project/commit/c80484e16ed8a69f1bef7bedc687a3b31707ac30.diff
LOG: [mlir][StandardToSPIRV] Add support for lowering trunci to SPIR-V to i1 types.
Add a pattern for converting an integer value to a boolean. spv.SConvert/spv.UConvert
do not work on i1 types. Thus, trunci to i1 is lowered to a bitwise AND with 1, an
integer compare (spv.IEqual), and a select.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D96851
Added:
Modified:
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 77f25ac935a8..2775f778719a 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -512,6 +512,35 @@ class ZeroExtendI1Pattern final : public OpConversionPattern<ZeroExtendIOp> {
}
};
+/// Converts std.trunci to spv.Select if the type of result is i1 or vector of
+/// i1.
+///
+/// spv.SConvert/spv.UConvert do not work on i1 types, so truncation to a
+/// boolean is expressed by testing the lowest bit of the (converted) source:
+///   %mask   = spv.constant 1
+///   %masked = spv.BitwiseAnd %src, %mask
+///   %isOne  = spv.IEqual %masked, %mask
+///   %result = spv.Select %isOne, true, false
+/// For any other result type the pattern returns failure() and leaves the op
+/// for other patterns.
+class TruncI1Pattern final : public OpConversionPattern<TruncateIOp> {
+public:
+  using OpConversionPattern<TruncateIOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TruncateIOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstType =
+        this->getTypeConverter()->convertType(op.getResult().getType());
+    // TypeConverter::convertType yields a null Type when the result type has
+    // no SPIR-V equivalent; reject that case before isBoolScalarOrVector
+    // inspects the type, so the conversion fails instead of touching null.
+    if (!dstType || !isBoolScalarOrVector(dstType))
+      return failure();
+
+    Location loc = op.getLoc();
+    Value src = operands.front();
+    auto srcType = src.getType();
+    // Truncating to i1 keeps only the lowest bit: check (x & 1) == 1.
+    Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
+    Value maskedSrc =
+        rewriter.create<spirv::BitwiseAndOp>(loc, srcType, src, mask);
+    Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
+
+    // NOTE(review): select(%isOne, true, false) is equivalent to %isOne; the
+    // explicit select is kept so the emitted IR stays unchanged.
+    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+    Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
+    return success();
+  }
+};
+
/// Converts std.uitofp to spv.Select if the type of source is i1 or vector of
/// i1.
class UIToFPI1Pattern final : public OpConversionPattern<UIToFPOp> {
@@ -547,10 +576,10 @@ class TypeCastingOpPattern final : public OpConversionPattern<StdOp> {
ConversionPatternRewriter &rewriter) const override {
assert(operands.size() == 1);
auto srcType = operands.front().getType();
- if (isBoolScalarOrVector(srcType))
- return failure();
auto dstType =
this->getTypeConverter()->convertType(operation.getResult().getType());
+ if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
+ return failure();
if (dstType == srcType) {
// Due to type conversion, we are seeing the same source and target type.
// Then we can just erase this operation by forwarding its operand.
@@ -1178,7 +1207,7 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
ReturnOpPattern, SelectOpPattern,
// Type cast patterns
- UIToFPI1Pattern, ZeroExtendI1Pattern,
+ UIToFPI1Pattern, ZeroExtendI1Pattern, TruncI1Pattern,
TypeCastingOpPattern<IndexCastOp, spirv::SConvertOp>,
TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
TypeCastingOpPattern<UIToFPOp, spirv::ConvertUToFOp>,
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 1dc3678bfebf..6bb6d78d56af 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -744,6 +744,30 @@ func @trunci2(%arg0: i32) -> i16 {
return %0 : i16
}
+// Truncating to i1 keeps only the lowest bit of the source: the lowering masks
+// the argument with 1, compares the masked value against 1, and selects a bool.
+// CHECK-LABEL: @trunc_to_i1
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func @trunc_to_i1(%arg0: i32) -> i1 {
+  // CHECK: %[[MASK:.*]] = spv.constant 1 : i32
+  // CHECK: %[[MASKED_SRC:.*]] = spv.BitwiseAnd %[[ARG]], %[[MASK]] : i32
+  // CHECK: %[[IS_ONE:.*]] = spv.IEqual %[[MASKED_SRC]], %[[MASK]] : i32
+  // CHECK-DAG: %[[TRUE:.*]] = spv.constant true
+  // CHECK-DAG: %[[FALSE:.*]] = spv.constant false
+  // CHECK: spv.Select %[[IS_ONE]], %[[TRUE]], %[[FALSE]] : i1, i1
+  %0 = std.trunci %arg0 : i32 to i1
+  return %0 : i1
+}
+
+// Vector variant: each lane keeps only its lowest bit, using splat constants
+// for the mask and for the true/false select operands.
+// CHECK-LABEL: @trunc_to_veci1
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
+func @trunc_to_veci1(%arg0: vector<4xi32>) -> vector<4xi1> {
+  // CHECK: %[[MASK:.*]] = spv.constant dense<1> : vector<4xi32>
+  // CHECK: %[[MASKED_SRC:.*]] = spv.BitwiseAnd %[[ARG]], %[[MASK]] : vector<4xi32>
+  // CHECK: %[[IS_ONE:.*]] = spv.IEqual %[[MASKED_SRC]], %[[MASK]] : vector<4xi32>
+  // CHECK-DAG: %[[TRUE:.*]] = spv.constant dense<true> : vector<4xi1>
+  // CHECK-DAG: %[[FALSE:.*]] = spv.constant dense<false> : vector<4xi1>
+  // CHECK: spv.Select %[[IS_ONE]], %[[TRUE]], %[[FALSE]] : vector<4xi1>, vector<4xi1>
+  %0 = std.trunci %arg0 : vector<4xi32> to vector<4xi1>
+  return %0 : vector<4xi1>
+}
+
// CHECK-LABEL: @fptosi1
func @fptosi1(%arg0 : f32) -> i32 {
// CHECK: spv.ConvertFToS %{{.*}} : f32 to i32
More information about the Mlir-commits
mailing list