[Mlir-commits] [mlir] 1cfaaf6 - [MLIR][SPIRVToLLVM] Convert spv.constant scalars and vectors

Lei Zhang llvmlistbot at llvm.org
Thu Jul 2 11:27:24 PDT 2020


Author: George Mitenkov
Date: 2020-07-02T14:26:58-04:00
New Revision: 1cfaaf645528cc2fed79617c8ca80945a1504021

URL: https://github.com/llvm/llvm-project/commit/1cfaaf645528cc2fed79617c8ca80945a1504021
DIFF: https://github.com/llvm/llvm-project/commit/1cfaaf645528cc2fed79617c8ca80945a1504021.diff

LOG: [MLIR][SPIRVToLLVM] Convert spv.constant scalars and vectors

This patch introduces a conversion pattern for `spv.constant` with scalar
and vector types. There is a special case when the constant value is a
signed/unsigned integer (or a vector of such integers). Since the LLVM
dialect has no signedness semantics, these types are converted to signless integers.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D82936

Added: 
    mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir

Modified: 
    mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index 3cb8342ca287..1ead619518b2 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -31,6 +31,15 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
+/// Returns true if `type` is a signed integer, or a vector whose element type
+/// is a signed integer.
+static bool isSignedIntegerOrVector(Type type) {
+  Type elementType = type;
+  if (auto vecType = type.dyn_cast<VectorType>())
+    elementType = vecType.getElementType();
+  return elementType.isSignedInteger();
+}
+
 /// Returns true if the given type is an unsigned integer or vector type
 static bool isUnsignedIntegerOrVector(Type type) {
   if (type.isUnsignedInteger())
@@ -197,6 +206,52 @@ class BitFieldInsertPattern
   }
 };
 
+/// Converts SPIR-V ConstantOp with scalar or vector type.
+class ConstantScalarAndVectorPattern
+    : public SPIRVToLLVMConversion<spirv::ConstantOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::ConstantOp constOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcType = constOp.getType();
+    if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
+      return failure();
+    auto dstType = typeConverter.convertType(srcType);
+    if (!dstType)
+      return failure();
+
+    // A SPIR-V constant can be a signed/unsigned integer (or vector of them),
+    // but LLVM dialect integers are signless, so the attribute is rebuilt with
+    // a signless type of the same bit width. The stored bits are unchanged;
+    // only the signedness annotation is dropped.
+    if (isSignedIntegerOrVector(srcType) ||
+        isUnsignedIntegerOrVector(srcType)) {
+      auto signlessType =
+          IntegerType::get(getBitWidth(srcType), rewriter.getContext());
+      if (srcType.isa<VectorType>()) {
+        // Fail the match instead of asserting if the value is not dense.
+        auto denseAttr = constOp.value().dyn_cast<DenseIntElementsAttr>();
+        if (!denseAttr)
+          return failure();
+        rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
+            constOp, dstType,
+            denseAttr.mapValues(
+                signlessType, [](const APInt &value) { return value; }));
+        return success();
+      }
+      auto srcAttr = constOp.value().cast<IntegerAttr>();
+      auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
+      rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
+      return success();
+    }
+    rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, operands,
+                                                  constOp.getAttrs());
+    return success();
+  }
+};
+
 /// Converts SPIR-V operations that have straightforward LLVM equivalent
 /// into LLVM dialect operations.
 template <typename SPIRVOp, typename LLVMOp>
@@ -573,6 +628,9 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
       IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
       IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
 
+      // Constant op
+      ConstantScalarAndVectorPattern,
+
       // Function Call op
       FunctionCallPattern,
 

diff  --git a/mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir
new file mode 100644
index 000000000000..b9605e791857
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.constant
+//===----------------------------------------------------------------------===//
+
+func @bool_constant_scalar() {
+  // CHECK: %{{.*}} = llvm.mlir.constant(true) : !llvm.i1
+  %true = spv.constant true
+  // CHECK: %{{.*}} = llvm.mlir.constant(false) : !llvm.i1
+  %false = spv.constant false
+  return
+}
+
+func @bool_constant_vector() {
+  // CHECK: {{.*}} = llvm.mlir.constant(dense<[true, false]> : vector<2xi1>) : !llvm<"<2 x i1>">
+  %0 = spv.constant dense<[true, false]> : vector<2xi1>
+  // CHECK: {{.*}} = llvm.mlir.constant(dense<false> : vector<3xi1>) : !llvm<"<3 x i1>">
+  %1 = spv.constant dense<false> : vector<3xi1>
+  return
+}
+
+func @integer_constant_scalar() {
+  // CHECK: %{{.*}} = llvm.mlir.constant(0 : i8) : !llvm.i8
+  %zero = spv.constant 0 : i8
+  // CHECK: %{{.*}} = llvm.mlir.constant(-5 : i64) : !llvm.i64
+  %minus_five = spv.constant -5 : si64
+  // CHECK: %{{.*}} = llvm.mlir.constant(10 : i16) : !llvm.i16
+  %ten = spv.constant 10 : ui16
+  return
+}
+
+func @integer_constant_vector() {
+  // CHECK: %{{.*}} = llvm.mlir.constant(dense<[2, 3]> : vector<2xi32>) : !llvm<"<2 x i32>">
+  %signless = spv.constant dense<[2, 3]> : vector<2xi32>
+  // CHECK: %{{.*}} = llvm.mlir.constant(dense<-4> : vector<2xi32>) : !llvm<"<2 x i32>">
+  %signed = spv.constant dense<-4> : vector<2xsi32>
+  // CHECK: %{{.*}} = llvm.mlir.constant(dense<[2, 3, 4]> : vector<3xi32>) : !llvm<"<3 x i32>">
+  %unsigned = spv.constant dense<[2, 3, 4]> : vector<3xui32>
+  return
+}
+
+func @float_constant_scalar() {
+  // CHECK: %{{.*}} = llvm.mlir.constant(5.000000e+00 : f16) : !llvm.half
+  %half = spv.constant 5.000000e+00 : f16
+  // CHECK: %{{.*}} = llvm.mlir.constant(5.000000e+00 : f64) : !llvm.double
+  %double = spv.constant 5.000000e+00 : f64
+  return
+}
+
+func @float_constant_vector() {
+  // CHECK: %{{.*}} = llvm.mlir.constant(dense<[2.000000e+00, 3.000000e+00]> : vector<2xf32>) : !llvm<"<2 x float>">
+  %vec = spv.constant dense<[2.000000e+00, 3.000000e+00]> : vector<2xf32>
+  return
+}


        


More information about the Mlir-commits mailing list