[Mlir-commits] [mlir] 1cfaaf6 - [MLIR][SPIRVToLLVM] Convert spv.constant scalars and vectors
Lei Zhang
llvmlistbot at llvm.org
Thu Jul 2 11:27:24 PDT 2020
Author: George Mitenkov
Date: 2020-07-02T14:26:58-04:00
New Revision: 1cfaaf645528cc2fed79617c8ca80945a1504021
URL: https://github.com/llvm/llvm-project/commit/1cfaaf645528cc2fed79617c8ca80945a1504021
DIFF: https://github.com/llvm/llvm-project/commit/1cfaaf645528cc2fed79617c8ca80945a1504021.diff
LOG: [MLIR][SPIRVToLLVM] Convert spv.constant scalars and vectors
This patch introduces a conversion pattern for `spv.constant` with scalar
and vector types. There is a special case when the constant value is a
signed/unsigned integer (or a vector of such integers). Since the LLVM dialect
does not have signedness semantics, these types are converted to signless ints.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D82936
Added:
mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir
Modified:
mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index 3cb8342ca287..1ead619518b2 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -31,6 +31,15 @@ using namespace mlir;
// Utility functions
//===----------------------------------------------------------------------===//
+/// Returns true if `type` is a signed integer or a vector of signed integers.
+static bool isSignedIntegerOrVector(Type type) {
+ if (type.isSignedInteger())
+ return true;
+ if (auto vecType = type.dyn_cast<VectorType>())
+ return vecType.getElementType().isSignedInteger();
+ return false;
+}
+
/// Returns true if the given type is an unsigned integer or vector type
static bool isUnsignedIntegerOrVector(Type type) {
if (type.isUnsignedInteger())
@@ -197,6 +206,52 @@ class BitFieldInsertPattern
}
};
+/// Converts a scalar or vector SPIR-V ConstantOp into an LLVM::ConstantOp.
+class ConstantScalarAndVectorPattern
+ : public SPIRVToLLVMConversion<spirv::ConstantOp> {
+public:
+ using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
+
+ LogicalResult
+ matchAndRewrite(spirv::ConstantOp constOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto srcType = constOp.getType();
+ if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
+ return failure(); // Only vectors and int/float scalars are handled.
+
+ auto dstType = typeConverter.convertType(srcType);
+ if (!dstType)
+ return failure(); // The type converter produced no type.
+
+ // A SPIR-V constant may have a signed or unsigned integer type, which has
+ // to be cast to a signless integer when converting to the LLVM dialect.
+ // Dropping the signedness may lead to unexpected behaviour. However, it is
+ // better to handle such cases one by one, given that the purpose of the
+ // conversion is not to cover all possible corner cases.
+ if (isSignedIntegerOrVector(srcType) ||
+ isUnsignedIntegerOrVector(srcType)) {
+ auto *context = rewriter.getContext();
+ auto signlessType = IntegerType::get(getBitWidth(srcType), context);
+
+ if (srcType.isa<VectorType>()) { // Re-type the elements; values are kept.
+ auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
+ rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
+ constOp, dstType,
+ dstElementsAttr.mapValues(
+ signlessType, [&](const APInt &value) { return value; }));
+ return success();
+ }
+ auto srcAttr = constOp.value().cast<IntegerAttr>();
+ auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
+ rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
+ return success();
+ }
+ rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, operands,
+ constOp.getAttrs());
+ return success();
+ }
+};
+
/// Converts SPIR-V operations that have straightforward LLVM equivalent
/// into LLVM dialect operations.
template <typename SPIRVOp, typename LLVMOp>
@@ -573,6 +628,9 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
+ // Constant op
+ ConstantScalarAndVectorPattern,
+
// Function Call op
FunctionCallPattern,
diff --git a/mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir
new file mode 100644
index 000000000000..b9605e791857
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.constant
+//===----------------------------------------------------------------------===//
+
+func @bool_constant_scalar() { // Scalar i1 constants map to !llvm.i1.
+ // CHECK: {{.*}} = llvm.mlir.constant(true) : !llvm.i1
+ %0 = spv.constant true
+ // CHECK: {{.*}} = llvm.mlir.constant(false) : !llvm.i1
+ %1 = spv.constant false
+ return
+}
+
+func @bool_constant_vector() { // Uses spv.constant so the SPIR-V pattern runs.
+ // CHECK: {{.*}} = llvm.mlir.constant(dense<[true, false]> : vector<2xi1>) : !llvm<"<2 x i1>">
+ %0 = spv.constant dense<[true, false]> : vector<2xi1>
+ // CHECK: {{.*}} = llvm.mlir.constant(dense<false> : vector<3xi1>) : !llvm<"<3 x i1>">
+ %1 = spv.constant dense<false> : vector<3xi1>
+ return
+}
+
+func @integer_constant_scalar() {
+ // CHECK: {{.*}} = llvm.mlir.constant(0 : i8) : !llvm.i8
+ %0 = spv.constant 0 : i8
+ // CHECK: {{.*}} = llvm.mlir.constant(-5 : i64) : !llvm.i64
+ %1 = spv.constant -5 : si64 // Signed: expected to become signless i64.
+ // CHECK: {{.*}} = llvm.mlir.constant(10 : i16) : !llvm.i16
+ %2 = spv.constant 10 : ui16 // Unsigned: expected to become signless i16.
+ return
+}
+
+func @integer_constant_vector() {
+ // CHECK: {{.*}} = llvm.mlir.constant(dense<[2, 3]> : vector<2xi32>) : !llvm<"<2 x i32>">
+ %0 = spv.constant dense<[2, 3]> : vector<2xi32>
+ // CHECK: {{.*}} = llvm.mlir.constant(dense<-4> : vector<2xi32>) : !llvm<"<2 x i32>">
+ %1 = spv.constant dense<-4> : vector<2xsi32> // Signed elements become signless i32.
+ // CHECK: {{.*}} = llvm.mlir.constant(dense<[2, 3, 4]> : vector<3xi32>) : !llvm<"<3 x i32>">
+ %2 = spv.constant dense<[2, 3, 4]> : vector<3xui32> // Unsigned elements become signless i32.
+ return
+}
+
+func @float_constant_scalar() { // f16 and f64 map to !llvm.half and !llvm.double.
+ // CHECK: {{.*}} = llvm.mlir.constant(5.000000e+00 : f16) : !llvm.half
+ %0 = spv.constant 5.000000e+00 : f16
+ // CHECK: {{.*}} = llvm.mlir.constant(5.000000e+00 : f64) : !llvm.double
+ %1 = spv.constant 5.000000e+00 : f64
+ return
+}
+
+func @float_constant_vector() { // vector<2xf32> maps to an LLVM <2 x float> vector.
+ // CHECK: {{.*}} = llvm.mlir.constant(dense<[2.000000e+00, 3.000000e+00]> : vector<2xf32>) : !llvm<"<2 x float>">
+ %0 = spv.constant dense<[2.000000e+00, 3.000000e+00]> : vector<2xf32>
+ return
+}
More information about the Mlir-commits
mailing list