[Mlir-commits] [mlir] 771b788 - [MLIR][SPIRVToLLVM] Support cast ops, some logical ops, UModOp

Lei Zhang llvmlistbot at llvm.org
Wed Jun 17 14:55:12 PDT 2020


Author: George Mitenkov
Date: 2020-06-17T17:46:45-04:00
New Revision: 771b7886872ec9f70b233554921c8e994e711cea

URL: https://github.com/llvm/llvm-project/commit/771b7886872ec9f70b233554921c8e994e711cea
DIFF: https://github.com/llvm/llvm-project/commit/771b7886872ec9f70b233554921c8e994e711cea.diff

LOG: [MLIR][SPIRVToLLVM] Support cast ops, some logical ops, UModOp

Added support for simple logical ops: `LogicalAnd`, `LogicalOr`,
`LogicalEqual` and `LogicalNotEqual`. Added a missing conversion
for the `UMod` op.

Also, implemented conversion of SPIR-V cast ops. There are 4 simple
cases where there is a clear equivalent in LLVM (e.g. `ConvertFToS`
is `fptosi`). For `FConvert`, `SConvert` and `UConvert` we
distinguish between truncation and extension by comparing the bit
widths of the operand and the result.

Differential Revision: https://reviews.llvm.org/D81812

Added: 
    mlir/test/Conversion/SPIRVToLLVM/cast-ops-to-llvm.mlir
    mlir/test/Conversion/SPIRVToLLVM/logical-to-llvm.mlir

Modified: 
    mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
    mlir/test/Conversion/SPIRVToLLVM/arithmetic-ops-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index ddaf1ca34861..85f6a113cff3 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -37,6 +37,19 @@ static bool isUnsignedIntegerOrVector(Type type) {
   return false;
 }
 
+/// Returns the bit width of an int/float type, or of a vector's element type.
+static unsigned getBitWidth(Type type) {
+  assert((type.isIntOrFloat() || type.isa<VectorType>()) &&
+         "bitwidth is not supported for this type");
+  // For vectors, the relevant width is that of a single element.
+  Type scalarType = type;
+  if (auto vecType = type.dyn_cast<VectorType>())
+    scalarType = vecType.getElementType();
+  assert(scalarType.isIntOrFloat() &&
+         "only integers and floats have a bitwidth");
+  return scalarType.getIntOrFloatBitWidth();
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -61,6 +74,38 @@ class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
   }
 };
 
+/// Converts SPIR-V cast ops that have no single LLVM dialect equivalent.
+/// Widening casts become `LLVMExtOp`, narrowing casts become `LLVMTruncOp`.
+template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
+class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
+public:
+  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type fromType = operation.operand().getType();
+    Type toType = operation.getType();
+    auto dstType = this->typeConverter.convertType(toType);
+    if (!dstType)
+      return failure();
+    // Compute the (element) bit widths once; they decide ext vs trunc.
+    unsigned fromWidth = getBitWidth(fromType);
+    unsigned toWidth = getBitWidth(toType);
+    if (fromWidth < toWidth) {
+      rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
+                                                      operands);
+      return success();
+    }
+    if (fromWidth > toWidth) {
+      rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
+                                                        operands);
+      return success();
+    }
+    return failure(); // Equal widths: this pattern does not apply.
+  }
+};
+
 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
@@ -168,12 +213,22 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
       DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
       DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
       DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
+      DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
 
       // Bitwise ops
       DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
       DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
       DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
 
+      // Cast ops
+      DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
+      DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
+      DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
+      DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
+      IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
+      IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
+      IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
+
       // Comparison ops
       IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
       IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
@@ -199,6 +254,12 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
       IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
       IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
 
+      // Logical ops
+      DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
+      DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
+      IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
+      IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
+
       // Shift ops
       ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
       ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,

diff  --git a/mlir/test/Conversion/SPIRVToLLVM/arithmetic-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/arithmetic-ops-to-llvm.mlir
index 5e658c969a69..30cfb55b262a 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/arithmetic-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/arithmetic-ops-to-llvm.mlir
@@ -144,6 +144,22 @@ func @udiv_vector(%arg0: vector<3xi64>, %arg1: vector<3xi64>) {
 	return
 }
 
+//===----------------------------------------------------------------------===//
+// spv.UMod
+//===----------------------------------------------------------------------===//
+
+func @umod_scalar(%arg0: i32, %arg1: i32) {
+	// CHECK: %{{.*}} = llvm.urem %{{.*}}, %{{.*}} : !llvm.i32
+	%rem = spv.UMod %arg0, %arg1 : i32
+	return
+}
+
+func @umod_vector(%arg0: vector<3xi64>, %arg1: vector<3xi64>) {
+	// CHECK: %{{.*}} = llvm.urem %{{.*}}, %{{.*}} : !llvm<"<3 x i64>">
+	%rem = spv.UMod %arg0, %arg1 : vector<3xi64>
+	return
+}
+
 //===----------------------------------------------------------------------===//
 // spv.SDiv
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/SPIRVToLLVM/cast-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/cast-ops-to-llvm.mlir
new file mode 100644
index 000000000000..6a772245e869
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/cast-ops-to-llvm.mlir
@@ -0,0 +1,131 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.ConvertFToS
+//===----------------------------------------------------------------------===//
+
+func @convert_float_to_signed_scalar(%arg0: f32) {
+  // CHECK: %{{.*}} = llvm.fptosi %{{.*}} : !llvm.float to !llvm.i32
+  %result = spv.ConvertFToS %arg0 : f32 to i32
+  return
+}
+
+func @convert_float_to_signed_vector(%arg0: vector<2xf32>) {
+  // CHECK: %{{.*}} = llvm.fptosi %{{.*}} : !llvm<"<2 x float>"> to !llvm<"<2 x i32>">
+  %result = spv.ConvertFToS %arg0 : vector<2xf32> to vector<2xi32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ConvertFToU
+//===----------------------------------------------------------------------===//
+
+func @convert_float_to_unsigned_scalar(%arg0: f32) {
+  // CHECK: %{{.*}} = llvm.fptoui %{{.*}} : !llvm.float to !llvm.i32
+  %result = spv.ConvertFToU %arg0 : f32 to i32
+  return
+}
+
+func @convert_float_to_unsigned_vector(%arg0: vector<2xf32>) {
+  // CHECK: %{{.*}} = llvm.fptoui %{{.*}} : !llvm<"<2 x float>"> to !llvm<"<2 x i32>">
+  %result = spv.ConvertFToU %arg0 : vector<2xf32> to vector<2xi32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ConvertSToF
+//===----------------------------------------------------------------------===//
+
+func @convert_signed_to_float_scalar(%arg0: i32) {
+  // CHECK: %{{.*}} = llvm.sitofp %{{.*}} : !llvm.i32 to !llvm.float
+  %result = spv.ConvertSToF %arg0 : i32 to f32
+  return
+}
+
+func @convert_signed_to_float_vector(%arg0: vector<3xi32>) {
+  // CHECK: %{{.*}} = llvm.sitofp %{{.*}} : !llvm<"<3 x i32>"> to !llvm<"<3 x float>">
+  %result = spv.ConvertSToF %arg0 : vector<3xi32> to vector<3xf32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ConvertUToF
+//===----------------------------------------------------------------------===//
+
+func @convert_unsigned_to_float_scalar(%arg0: i32) {
+  // CHECK: %{{.*}} = llvm.uitofp %{{.*}} : !llvm.i32 to !llvm.float
+  %result = spv.ConvertUToF %arg0 : i32 to f32
+  return
+}
+
+func @convert_unsigned_to_float_vector(%arg0: vector<3xi32>) {
+  // CHECK: %{{.*}} = llvm.uitofp %{{.*}} : !llvm<"<3 x i32>"> to !llvm<"<3 x float>">
+  %result = spv.ConvertUToF %arg0 : vector<3xi32> to vector<3xf32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FConvert
+//===----------------------------------------------------------------------===//
+
+func @fconvert_scalar(%arg0: f32, %arg1: f64) {
+  // CHECK: %{{.*}} = llvm.fpext %{{.*}} : !llvm.float to !llvm.double
+  %ext = spv.FConvert %arg0 : f32 to f64
+
+  // CHECK: %{{.*}} = llvm.fptrunc %{{.*}} : !llvm.double to !llvm.float
+  %trunc = spv.FConvert %arg1 : f64 to f32
+  return
+}
+
+func @fconvert_vector(%arg0: vector<2xf32>, %arg1: vector<2xf64>) {
+  // CHECK: %{{.*}} = llvm.fpext %{{.*}} : !llvm<"<2 x float>"> to !llvm<"<2 x double>">
+  %ext = spv.FConvert %arg0 : vector<2xf32> to vector<2xf64>
+
+  // CHECK: %{{.*}} = llvm.fptrunc %{{.*}} : !llvm<"<2 x double>"> to !llvm<"<2 x float>">
+  %trunc = spv.FConvert %arg1 : vector<2xf64> to vector<2xf32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.SConvert
+//===----------------------------------------------------------------------===//
+
+func @sconvert_scalar(%arg0: i32, %arg1: i64) {
+  // CHECK: %{{.*}} = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
+  %ext = spv.SConvert %arg0 : i32 to i64
+
+  // CHECK: %{{.*}} = llvm.trunc %{{.*}} : !llvm.i64 to !llvm.i32
+  %trunc = spv.SConvert %arg1 : i64 to i32
+  return
+}
+
+func @sconvert_vector(%arg0: vector<3xi32>, %arg1: vector<3xi64>) {
+  // CHECK: %{{.*}} = llvm.sext %{{.*}} : !llvm<"<3 x i32>"> to !llvm<"<3 x i64>">
+  %ext = spv.SConvert %arg0 : vector<3xi32> to vector<3xi64>
+
+  // CHECK: %{{.*}} = llvm.trunc %{{.*}} : !llvm<"<3 x i64>"> to !llvm<"<3 x i32>">
+  %trunc = spv.SConvert %arg1 : vector<3xi64> to vector<3xi32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.UConvert
+//===----------------------------------------------------------------------===//
+
+func @uconvert_scalar(%arg0: i32, %arg1: i64) {
+  // CHECK: %{{.*}} = llvm.zext %{{.*}} : !llvm.i32 to !llvm.i64
+  %ext = spv.UConvert %arg0 : i32 to i64
+
+  // CHECK: %{{.*}} = llvm.trunc %{{.*}} : !llvm.i64 to !llvm.i32
+  %trunc = spv.UConvert %arg1 : i64 to i32
+  return
+}
+
+func @uconvert_vector(%arg0: vector<3xi32>, %arg1: vector<3xi64>) {
+  // CHECK: %{{.*}} = llvm.zext %{{.*}} : !llvm<"<3 x i32>"> to !llvm<"<3 x i64>">
+  %ext = spv.UConvert %arg0 : vector<3xi32> to vector<3xi64>
+
+  // CHECK: %{{.*}} = llvm.trunc %{{.*}} : !llvm<"<3 x i64>"> to !llvm<"<3 x i32>">
+  %trunc = spv.UConvert %arg1 : vector<3xi64> to vector<3xi32>
+  return
+}

diff  --git a/mlir/test/Conversion/SPIRVToLLVM/logical-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/logical-to-llvm.mlir
new file mode 100644
index 000000000000..a6ff260711de
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/logical-to-llvm.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.LogicalEqual
+//===----------------------------------------------------------------------===//
+
+func @logical_equal_scalar(%arg0: i1, %arg1: i1) {
+  // CHECK: %{{.*}} = llvm.icmp "eq" %{{.*}}, %{{.*}} : !llvm.i1
+  %0 = spv.LogicalEqual %arg0, %arg1 : i1
+  return
+}
+
+func @logical_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) {
+  // CHECK: %{{.*}} = llvm.icmp "eq" %{{.*}}, %{{.*}} : !llvm<"<4 x i1>">
+  %0 = spv.LogicalEqual %arg0, %arg1 : vector<4xi1>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.LogicalNotEqual
+//===----------------------------------------------------------------------===//
+
+func @logical_not_equal_scalar(%arg0: i1, %arg1: i1) {
+  // CHECK: %{{.*}} = llvm.icmp "ne" %{{.*}}, %{{.*}} : !llvm.i1
+  %0 = spv.LogicalNotEqual %arg0, %arg1 : i1
+  return
+}
+
+func @logical_not_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) {
+  // CHECK: %{{.*}} = llvm.icmp "ne" %{{.*}}, %{{.*}} : !llvm<"<4 x i1>">
+  %0 = spv.LogicalNotEqual %arg0, %arg1 : vector<4xi1>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.LogicalAnd
+//===----------------------------------------------------------------------===//
+
+func @logical_and_scalar(%arg0: i1, %arg1: i1) {
+  // CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm.i1
+  %0 = spv.LogicalAnd %arg0, %arg1 : i1
+  return
+}
+
+func @logical_and_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) {
+  // CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm<"<4 x i1>">
+  %0 = spv.LogicalAnd %arg0, %arg1 : vector<4xi1>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.LogicalOr
+//===----------------------------------------------------------------------===//
+
+func @logical_or_scalar(%arg0: i1, %arg1: i1) {
+  // CHECK: %{{.*}} = llvm.or %{{.*}}, %{{.*}} : !llvm.i1
+  %0 = spv.LogicalOr %arg0, %arg1 : i1
+  return
+}
+
+func @logical_or_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) {
+  // CHECK: %{{.*}} = llvm.or %{{.*}}, %{{.*}} : !llvm<"<4 x i1>">
+  %0 = spv.LogicalOr %arg0, %arg1 : vector<4xi1>
+  return
+}


        


More information about the Mlir-commits mailing list