[Mlir-commits] [mlir] 1909b6a - [mlir][StandardToSPIRV] Handle vector of i1 case for lowering zexti to SPIR-V.
Hanhan Wang
llvmlistbot at llvm.org
Fri Sep 18 07:07:34 PDT 2020
Author: Hanhan Wang
Date: 2020-09-18T07:07:22-07:00
New Revision: 1909b6ac0dbc2f1306103a5ea7f5e59f2232b133
URL: https://github.com/llvm/llvm-project/commit/1909b6ac0dbc2f1306103a5ea7f5e59f2232b133
DIFF: https://github.com/llvm/llvm-project/commit/1909b6ac0dbc2f1306103a5ea7f5e59f2232b133.diff
LOG: [mlir][StandardToSPIRV] Handle vector of i1 case for lowering zexti to SPIR-V.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D87887
Added:
Modified:
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 6ae17c33070c..583f7836ae88 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -493,7 +493,8 @@ class StoreOpPattern final : public SPIRVOpLowering<StoreOp> {
ConversionPatternRewriter &rewriter) const override;
};
-/// Converts std.zexti to spv.Select if the type of source is i1.
+/// Converts std.zexti to spv.Select if the type of source is i1 or vector of
+/// i1.
class ZeroExtendI1Pattern final : public SPIRVOpLowering<ZeroExtendIOp> {
public:
using SPIRVOpLowering<ZeroExtendIOp>::SPIRVOpLowering;
@@ -502,13 +503,21 @@ class ZeroExtendI1Pattern final : public SPIRVOpLowering<ZeroExtendIOp> {
matchAndRewrite(ZeroExtendIOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto srcType = operands.front().getType();
- if (!srcType.isSignlessInteger() || srcType.getIntOrFloatBitWidth() != 1)
+ if (!isBoolScalarOrVector(srcType))
return failure();
auto dstType = this->typeConverter.convertType(op.getResult().getType());
Location loc = op.getLoc();
- Value zero = rewriter.create<ConstantIntOp>(loc, 0, dstType);
- Value one = rewriter.create<ConstantIntOp>(loc, 1, dstType);
+ Attribute zeroAttr, oneAttr;
+ if (auto vectorType = dstType.dyn_cast<VectorType>()) {
+ zeroAttr = DenseElementsAttr::get(vectorType, 0);
+ oneAttr = DenseElementsAttr::get(vectorType, 1);
+ } else {
+ zeroAttr = IntegerAttr::get(dstType, 0);
+ oneAttr = IntegerAttr::get(dstType, 1);
+ }
+ Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
+ Value one = rewriter.create<ConstantOp>(loc, oneAttr);
rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
op, dstType, operands.front(), one, zero);
return success();
@@ -526,7 +535,7 @@ class TypeCastingOpPattern final : public SPIRVOpLowering<StdOp> {
ConversionPatternRewriter &rewriter) const override {
assert(operands.size() == 1);
auto srcType = operands.front().getType();
- if (srcType.isSignlessInteger() && srcType.getIntOrFloatBitWidth() == 1)
+ if (isBoolScalarOrVector(srcType))
return failure();
auto dstType =
this->typeConverter.convertType(operation.getResult().getType());
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index ce38ba8b3f5e..5b62e54311b4 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -578,6 +578,15 @@ func @zexti3(%arg0 : i1) -> i32 {
return %0 : i32
}
+// CHECK-LABEL: @zexti4
+func @zexti4(%arg0 : vector<4xi1>) -> vector<4xi32> {
+ // CHECK: %[[ZERO:.+]] = spv.constant dense<0> : vector<4xi32>
+ // CHECK: %[[ONE:.+]] = spv.constant dense<1> : vector<4xi32>
+ // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xi32>
+ %0 = std.zexti %arg0 : vector<4xi1> to vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
// CHECK-LABEL: @trunci1
func @trunci1(%arg0 : i64) -> i16 {
// CHECK: spv.SConvert %{{.*}} : i64 to i16
More information about the Mlir-commits mailing list