[Mlir-commits] [mlir] 1909b6a - [mlir][StandardToSPIRV] Handle vector of i1 case for lowering zexti to SPIR-V.

Hanhan Wang llvmlistbot at llvm.org
Fri Sep 18 07:07:34 PDT 2020


Author: Hanhan Wang
Date: 2020-09-18T07:07:22-07:00
New Revision: 1909b6ac0dbc2f1306103a5ea7f5e59f2232b133

URL: https://github.com/llvm/llvm-project/commit/1909b6ac0dbc2f1306103a5ea7f5e59f2232b133
DIFF: https://github.com/llvm/llvm-project/commit/1909b6ac0dbc2f1306103a5ea7f5e59f2232b133.diff

LOG: [mlir][StandardToSPIRV] Handle vector of i1 case for lowering zexti to SPIR-V.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D87887

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
    mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 6ae17c33070c..583f7836ae88 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -493,7 +493,8 @@ class StoreOpPattern final : public SPIRVOpLowering<StoreOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Converts std.zexti to spv.Select if the type of source is i1.
+/// Converts std.zexti to spv.Select if the type of source is i1 or vector of
+/// i1.
 class ZeroExtendI1Pattern final : public SPIRVOpLowering<ZeroExtendIOp> {
 public:
   using SPIRVOpLowering<ZeroExtendIOp>::SPIRVOpLowering;
@@ -502,13 +503,21 @@ class ZeroExtendI1Pattern final : public SPIRVOpLowering<ZeroExtendIOp> {
   matchAndRewrite(ZeroExtendIOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto srcType = operands.front().getType();
-    if (!srcType.isSignlessInteger() || srcType.getIntOrFloatBitWidth() != 1)
+    if (!isBoolScalarOrVector(srcType))
       return failure();
 
     auto dstType = this->typeConverter.convertType(op.getResult().getType());
     Location loc = op.getLoc();
-    Value zero = rewriter.create<ConstantIntOp>(loc, 0, dstType);
-    Value one = rewriter.create<ConstantIntOp>(loc, 1, dstType);
+    Attribute zeroAttr, oneAttr;
+    if (auto vectorType = dstType.dyn_cast<VectorType>()) {
+      zeroAttr = DenseElementsAttr::get(vectorType, 0);
+      oneAttr = DenseElementsAttr::get(vectorType, 1);
+    } else {
+      zeroAttr = IntegerAttr::get(dstType, 0);
+      oneAttr = IntegerAttr::get(dstType, 1);
+    }
+    Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
+    Value one = rewriter.create<ConstantOp>(loc, oneAttr);
     rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
         op, dstType, operands.front(), one, zero);
     return success();
@@ -526,7 +535,7 @@ class TypeCastingOpPattern final : public SPIRVOpLowering<StdOp> {
                   ConversionPatternRewriter &rewriter) const override {
     assert(operands.size() == 1);
     auto srcType = operands.front().getType();
-    if (srcType.isSignlessInteger() && srcType.getIntOrFloatBitWidth() == 1)
+    if (isBoolScalarOrVector(srcType))
       return failure();
     auto dstType =
         this->typeConverter.convertType(operation.getResult().getType());

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index ce38ba8b3f5e..5b62e54311b4 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -578,6 +578,15 @@ func @zexti3(%arg0 : i1) -> i32 {
   return %0 : i32
 }
 
+// CHECK-LABEL: @zexti4
+func @zexti4(%arg0 : vector<4xi1>) -> vector<4xi32> {
+  // CHECK: %[[ZERO:.+]] = spv.constant dense<0> : vector<4xi32>
+  // CHECK: %[[ONE:.+]] = spv.constant dense<1> : vector<4xi32>
+  // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xi32>
+  %0 = std.zexti %arg0 : vector<4xi1> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
 // CHECK-LABEL: @trunci1
 func @trunci1(%arg0 : i64) -> i16 {
   // CHECK: spv.SConvert %{{.*}} : i64 to i16


        


More information about the Mlir-commits mailing list