[Mlir-commits] [mlir] 2cb130f - [mlir][StandardToSPIRV] Add support for lowering uitofp to SPIR-V
Hanhan Wang
llvmlistbot at llvm.org
Thu Jan 21 22:20:51 PST 2021
Author: Hanhan Wang
Date: 2021-01-21T22:20:32-08:00
New Revision: 2cb130f7661176f2c2eaa7554f2a55863cfc0ed3
URL: https://github.com/llvm/llvm-project/commit/2cb130f7661176f2c2eaa7554f2a55863cfc0ed3
DIFF: https://github.com/llvm/llvm-project/commit/2cb130f7661176f2c2eaa7554f2a55863cfc0ed3.diff
LOG: [mlir][StandardToSPIRV] Add support for lowering uitofp to SPIR-V
- Extend spirv::ConstantOp::getZero/One to handle float, vector of int, and vector of float.
- Refactor ZeroExtendI1Pattern to use getZero/One methods.
- Add one more test for lowering std.zexti which extends vector<4xi1> to vector<4xi64>.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D95120
Added:
Modified:
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 72b8c5811695..95bb0eca4496 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -481,16 +481,32 @@ class ZeroExtendI1Pattern final : public OpConversionPattern<ZeroExtendIOp> {
auto dstType =
this->getTypeConverter()->convertType(op.getResult().getType());
Location loc = op.getLoc();
- Attribute zeroAttr, oneAttr;
- if (auto vectorType = dstType.dyn_cast<VectorType>()) {
- zeroAttr = DenseElementsAttr::get(vectorType, 0);
- oneAttr = DenseElementsAttr::get(vectorType, 1);
- } else {
- zeroAttr = IntegerAttr::get(dstType, 0);
- oneAttr = IntegerAttr::get(dstType, 1);
- }
- Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
- Value one = rewriter.create<ConstantOp>(loc, oneAttr);
+ Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+ Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+ rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
+ op, dstType, operands.front(), one, zero);
+ return success();
+ }
+};
+
+/// Converts std.uitofp to spv.Select if the type of source is i1 or vector of
+/// i1.
+class UIToFPI1Pattern final : public OpConversionPattern<UIToFPOp> {
+public:
+ using OpConversionPattern<UIToFPOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(UIToFPOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto srcType = operands.front().getType();
+ if (!isBoolScalarOrVector(srcType))
+ return failure();
+
+ auto dstType =
+ this->getTypeConverter()->convertType(op.getResult().getType());
+ Location loc = op.getLoc();
+ Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+ Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
op, dstType, operands.front(), one, zero);
return success();
@@ -1098,8 +1114,10 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
ReturnOpPattern, SelectOpPattern,
// Type cast patterns
- ZeroExtendI1Pattern, TypeCastingOpPattern<IndexCastOp, spirv::SConvertOp>,
+ UIToFPI1Pattern, ZeroExtendI1Pattern,
+ TypeCastingOpPattern<IndexCastOp, spirv::SConvertOp>,
TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
+ TypeCastingOpPattern<UIToFPOp, spirv::ConvertUToFOp>,
TypeCastingOpPattern<ZeroExtendIOp, spirv::UConvertOp>,
TypeCastingOpPattern<TruncateIOp, spirv::SConvertOp>,
TypeCastingOpPattern<FPToSIOp, spirv::ConvertFToSOp>,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index c90895197f43..3d99696d6882 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -25,6 +25,8 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/CallInterfaces.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/bit.h"
@@ -1581,6 +1583,25 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
return builder.create<spirv::ConstantOp>(
loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
}
+ if (auto floatType = type.dyn_cast<FloatType>()) {
+ return builder.create<spirv::ConstantOp>(
+ loc, type, builder.getFloatAttr(floatType, 0.0));
+ }
+ if (auto vectorType = type.dyn_cast<VectorType>()) {
+ Type elemType = vectorType.getElementType();
+ if (elemType.isa<IntegerType>()) {
+ return builder.create<spirv::ConstantOp>(
+ loc, type,
+ DenseElementsAttr::get(vectorType,
+ IntegerAttr::get(elemType, 0.0).getValue()));
+ }
+ if (elemType.isa<FloatType>()) {
+ return builder.create<spirv::ConstantOp>(
+ loc, type,
+ DenseFPElementsAttr::get(vectorType,
+ FloatAttr::get(elemType, 0.0).getValue()));
+ }
+ }
llvm_unreachable("unimplemented types for ConstantOp::getZero()");
}
@@ -1595,6 +1616,25 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
return builder.create<spirv::ConstantOp>(
loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
}
+ if (auto floatType = type.dyn_cast<FloatType>()) {
+ return builder.create<spirv::ConstantOp>(
+ loc, type, builder.getFloatAttr(floatType, 1.0));
+ }
+ if (auto vectorType = type.dyn_cast<VectorType>()) {
+ Type elemType = vectorType.getElementType();
+ if (elemType.isa<IntegerType>()) {
+ return builder.create<spirv::ConstantOp>(
+ loc, type,
+ DenseElementsAttr::get(vectorType,
+ IntegerAttr::get(elemType, 1.0).getValue()));
+ }
+ if (elemType.isa<FloatType>()) {
+ return builder.create<spirv::ConstantOp>(
+ loc, type,
+ DenseFPElementsAttr::get(vectorType,
+ FloatAttr::get(elemType, 1.0).getValue()));
+ }
+ }
llvm_unreachable("unimplemented types for ConstantOp::getOne()");
}
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 633fdbc03550..252bc3eb5095 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -568,6 +568,58 @@ func @sitofp2(%arg0 : i64) -> f64 {
return %0 : f64
}
+// CHECK-LABEL: @uitofp_i16_f32
+func @uitofp_i16_f32(%arg0: i16) -> f32 { // Non-i1 source: a plain spv.ConvertUToF is expected.
+ // CHECK: spv.ConvertUToF %{{.*}} : i16 to f32
+ %0 = std.uitofp %arg0 : i16 to f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: @uitofp_i32_f32
+func @uitofp_i32_f32(%arg0 : i32) -> f32 { // Non-i1 source: a plain spv.ConvertUToF is expected.
+ // CHECK: spv.ConvertUToF %{{.*}} : i32 to f32
+ %0 = std.uitofp %arg0 : i32 to f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: @uitofp_i1_f32
+func @uitofp_i1_f32(%arg0 : i1) -> f32 { // i1 source: lowered to spv.Select between f32 constants 1.0 and 0.0.
+ // CHECK: %[[ZERO:.+]] = spv.constant 0.000000e+00 : f32
+ // CHECK: %[[ONE:.+]] = spv.constant 1.000000e+00 : f32
+ // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : i1, f32
+ %0 = std.uitofp %arg0 : i1 to f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: @uitofp_i1_f64
+func @uitofp_i1_f64(%arg0 : i1) -> f64 { // i1 source: lowered to spv.Select between f64 constants 1.0 and 0.0.
+ // CHECK: %[[ZERO:.+]] = spv.constant 0.000000e+00 : f64
+ // CHECK: %[[ONE:.+]] = spv.constant 1.000000e+00 : f64
+ // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : i1, f64
+ %0 = std.uitofp %arg0 : i1 to f64
+ return %0 : f64
+}
+
+// CHECK-LABEL: @uitofp_vec_i1_f32
+func @uitofp_vec_i1_f32(%arg0 : vector<4xi1>) -> vector<4xf32> { // vector-of-i1 source: spv.Select between splat f32 constants.
+ // CHECK: %[[ZERO:.+]] = spv.constant dense<0.000000e+00> : vector<4xf32>
+ // CHECK: %[[ONE:.+]] = spv.constant dense<1.000000e+00> : vector<4xf32>
+ // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xf32>
+ %0 = std.uitofp %arg0 : vector<4xi1> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @uitofp_vec_i1_f64
+func @uitofp_vec_i1_f64(%arg0 : vector<4xi1>) -> vector<4xf64> {
+ // The input must be std.uitofp so the vector<4xi1> -> vector<4xf64>
+ // lowering is actually exercised rather than matched verbatim.
+ // CHECK: %[[ZERO:.+]] = spv.constant dense<0.000000e+00> : vector<4xf64>
+ // CHECK: %[[ONE:.+]] = spv.constant dense<1.000000e+00> : vector<4xf64>
+ // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xf64>
+ %0 = std.uitofp %arg0 : vector<4xi1> to vector<4xf64>
+ return %0 : vector<4xf64>
+}
+
// CHECK-LABEL: @zexti1
func @zexti1(%arg0: i16) -> i64 {
// CHECK: spv.UConvert %{{.*}} : i16 to i64
@@ -600,6 +652,15 @@ func @zexti4(%arg0 : vector<4xi1>) -> vector<4xi32> {
return %0 : vector<4xi32>
}
+// CHECK-LABEL: @zexti5
+func @zexti5(%arg0 : vector<4xi1>) -> vector<4xi64> { // vector-of-i1 zext: spv.Select between splat i64 constants 1 and 0.
+ // CHECK: %[[ZERO:.+]] = spv.constant dense<0> : vector<4xi64>
+ // CHECK: %[[ONE:.+]] = spv.constant dense<1> : vector<4xi64>
+ // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xi64>
+ %0 = std.zexti %arg0 : vector<4xi1> to vector<4xi64>
+ return %0 : vector<4xi64>
+}
+
// CHECK-LABEL: @trunci1
func @trunci1(%arg0 : i64) -> i16 {
// CHECK: spv.SConvert %{{.*}} : i64 to i16
More information about the Mlir-commits
mailing list