[Mlir-commits] [mlir] c80484e - [mlir][StandardToSPIRV] Add support for lowering trunci to i1 types to SPIR-V.

Hanhan Wang llvmlistbot at llvm.org
Wed Feb 17 07:24:06 PST 2021


Author: Hanhan Wang
Date: 2021-02-17T07:23:41-08:00
New Revision: c80484e16ed8a69f1bef7bedc687a3b31707ac30

URL: https://github.com/llvm/llvm-project/commit/c80484e16ed8a69f1bef7bedc687a3b31707ac30
DIFF: https://github.com/llvm/llvm-project/commit/c80484e16ed8a69f1bef7bedc687a3b31707ac30.diff

LOG: [mlir][StandardToSPIRV] Add support for lowering trunci to i1 types to SPIR-V.

Add a pattern for converting an integer value to a boolean. spv.SConvert and
spv.UConvert do not work on i1 types, so the op is instead lowered to
spv.BitwiseAnd + spv.IEqual + spv.Select, i.e. computing (x & 1) == 1.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D96851

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
    mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 77f25ac935a8..2775f778719a 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -512,6 +512,35 @@ class ZeroExtendI1Pattern final : public OpConversionPattern<ZeroExtendIOp> {
   }
 };
 
+/// Converts std.trunci to spv.BitwiseAnd + spv.IEqual + spv.Select if the
+/// result type is i1 or vector of i1.
+class TruncI1Pattern final : public OpConversionPattern<TruncateIOp> {
+public:
+  using OpConversionPattern<TruncateIOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TruncateIOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstType =
+        this->getTypeConverter()->convertType(op.getResult().getType());
+    if (!isBoolScalarOrVector(dstType))
+      return failure();
+
+    Location loc = op.getLoc();
+    auto srcType = operands.front().getType();
+    // Check if (x & 1) == 1.
+    Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
+    Value maskedSrc =
+        rewriter.create<spirv::BitwiseAndOp>(loc, srcType, operands[0], mask);
+    Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
+
+    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+    Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
+    return success();
+  }
+};
+
 /// Converts std.uitofp to spv.Select if the type of source is i1 or vector of
 /// i1.
 class UIToFPI1Pattern final : public OpConversionPattern<UIToFPOp> {
@@ -547,10 +576,10 @@ class TypeCastingOpPattern final : public OpConversionPattern<StdOp> {
                   ConversionPatternRewriter &rewriter) const override {
     assert(operands.size() == 1);
     auto srcType = operands.front().getType();
-    if (isBoolScalarOrVector(srcType))
-      return failure();
     auto dstType =
         this->getTypeConverter()->convertType(operation.getResult().getType());
+    if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
+      return failure();
     if (dstType == srcType) {
       // Due to type conversion, we are seeing the same source and target type.
       // Then we can just erase this operation by forwarding its operand.
@@ -1178,7 +1207,7 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
       ReturnOpPattern, SelectOpPattern,
 
       // Type cast patterns
-      UIToFPI1Pattern, ZeroExtendI1Pattern,
+      UIToFPI1Pattern, ZeroExtendI1Pattern, TruncI1Pattern,
       TypeCastingOpPattern<IndexCastOp, spirv::SConvertOp>,
       TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
       TypeCastingOpPattern<UIToFPOp, spirv::ConvertUToFOp>,

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 1dc3678bfebf..6bb6d78d56af 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -744,6 +744,30 @@ func @trunci2(%arg0: i32) -> i16 {
   return %0 : i16
 }
 
+// CHECK-LABEL: @trunc_to_i1
+func @trunc_to_i1(%arg0: i32) -> i1 {
+  // CHECK: %[[MASK:.*]] = spv.constant 1 : i32
+  // CHECK: %[[MASKED_SRC:.*]] = spv.BitwiseAnd %{{.*}}, %[[MASK]] : i32
+  // CHECK: %[[IS_ONE:.*]] = spv.IEqual %[[MASKED_SRC]], %[[MASK]] : i32
+  // CHECK-DAG: %[[TRUE:.*]] = spv.constant true
+  // CHECK-DAG: %[[FALSE:.*]] = spv.constant false
+  // CHECK: spv.Select %[[IS_ONE]], %[[TRUE]], %[[FALSE]] : i1, i1
+  %0 = std.trunci %arg0 : i32 to i1
+  return %0 : i1
+}
+
+// CHECK-LABEL: @trunc_to_veci1
+func @trunc_to_veci1(%arg0: vector<4xi32>) -> vector<4xi1> {
+  // CHECK: %[[MASK:.*]] = spv.constant dense<1> : vector<4xi32>
+  // CHECK: %[[MASKED_SRC:.*]] = spv.BitwiseAnd %{{.*}}, %[[MASK]] : vector<4xi32>
+  // CHECK: %[[IS_ONE:.*]] = spv.IEqual %[[MASKED_SRC]], %[[MASK]] : vector<4xi32>
+  // CHECK-DAG: %[[TRUE:.*]] = spv.constant dense<true> : vector<4xi1>
+  // CHECK-DAG: %[[FALSE:.*]] = spv.constant dense<false> : vector<4xi1>
+  // CHECK: spv.Select %[[IS_ONE]], %[[TRUE]], %[[FALSE]] : vector<4xi1>, vector<4xi1>
+  %0 = std.trunci %arg0 : vector<4xi32> to vector<4xi1>
+  return %0 : vector<4xi1>
+}
+
 // CHECK-LABEL: @fptosi1
 func @fptosi1(%arg0 : f32) -> i32 {
   // CHECK: spv.ConvertFToS %{{.*}} : f32 to i32


        


More information about the Mlir-commits mailing list