[Mlir-commits] [mlir] d93d8fc - [MLIR][SPIRVToLLVM] Implemented conversion for arithmetic ops and 3 bitwise ops.
Lei Zhang
llvmlistbot at llvm.org
Wed Jun 10 16:14:06 PDT 2020
Author: George Mitenkov
Date: 2020-06-10T19:10:31-04:00
New Revision: d93d8fcdec68211fd3ac7f586fa67bc065acef6a
URL: https://github.com/llvm/llvm-project/commit/d93d8fcdec68211fd3ac7f586fa67bc065acef6a
DIFF: https://github.com/llvm/llvm-project/commit/d93d8fcdec68211fd3ac7f586fa67bc065acef6a.diff
LOG: [MLIR][SPIRVToLLVM] Implemented conversion for arithmetic ops and 3 bitwise ops.
Following the previous revision `D81100`, this commit implements a templated class
that provides conversion patterns from “straightforward” SPIR-V ops into the
LLVM dialect. Templating makes it possible to abstract away from a concrete
implementation for each specific op; the ops handled this way are mainly binary
operations. The currently supported and tested ops are:
- Arithmetic ops: `IAdd`, `ISub`, `IMul`, `FAdd`, `FSub`, `FDiv`, `FRem`, `FNegate`,
`SDiv`, `SRem` and `UDiv`
- Bitwise ops: `BitwiseAnd`, `BitwiseOr`, `BitwiseXor`
The implementation relies on the `SPIRVToLLVMConversion` class, which derives from
`OpConversionPattern`.
Differential Revision: https://reviews.llvm.org/D81305
Added:
mlir/test/Conversion/SPIRVToLLVM/arithmetic-ops-to-llvm.mlir
mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir
Modified:
mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
Removed:
mlir/test/Conversion/SPIRVToLLVM/convert-to-llvm.mlir
################################################################################
diff --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
index 2bf74d65b90d..e82efac3abe5 100644
--- a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
+++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
@@ -20,6 +20,18 @@ class LLVMTypeConverter;
class MLIRContext;
class ModuleOp;
+/// Common base class for patterns that convert a SPIR-V operation of type
+/// `SPIRVOp` into the LLVM dialect. Derived patterns reach the LLVM type
+/// converter through the protected `typeConverter` member.
+template <typename SPIRVOp>
+class SPIRVToLLVMConversion : public OpConversionPattern<SPIRVOp> {
+public:
+ /// `typeConverter` is captured by reference (not copied), so it must
+ /// outlive every pattern constructed from it.
+ SPIRVToLLVMConversion(MLIRContext *context, LLVMTypeConverter &typeConverter,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern<SPIRVOp>(context, benefit),
+ typeConverter(typeConverter) {}
+
+protected:
+ /// Non-owning reference to the converter used by derived patterns to
+ /// translate SPIR-V types into LLVM dialect types.
+ LLVMTypeConverter &typeConverter;
+};
+
/// Populates the given list with patterns that convert from SPIR-V to LLVM.
void populateSPIRVToLLVMConversionPatterns(MLIRContext *context,
LLVMTypeConverter &typeConverter,
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index 2ccd2431ac6e..4b056a553b76 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -24,30 +24,51 @@
using namespace mlir;
+//===----------------------------------------------------------------------===//
+// Operation conversion
+//===----------------------------------------------------------------------===//
+
namespace {
-class BitwiseAndOpConversion : public ConvertToLLVMPattern {
+/// Converts a SPIR-V operation that has a one-to-one LLVM dialect equivalent
+/// into that LLVM dialect operation. The (already remapped) operands are
+/// passed through as-is and the op's result type is converted with the LLVM
+/// type converter; the pattern fails if that type conversion fails.
+template <typename SPIRVOp, typename LLVMOp>
+class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
public:
- explicit BitwiseAndOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(spirv::BitwiseAndOp::getOperationName(), context,
- typeConverter) {}
+ // Reuse the base (context, typeConverter, benefit) constructor unchanged.
+ using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto bitwiseAndOp = cast<spirv::BitwiseAndOp>(op);
- auto dstType = typeConverter.convertType(bitwiseAndOp.getType());
+ // `this->` is required: `typeConverter` is a member of a dependent base
+ // class, so unqualified lookup would not find it.
+ auto dstType = this->typeConverter.convertType(operation.getType());
if (!dstType)
return failure();
- rewriter.replaceOpWithNewOp<LLVM::AndOp>(bitwiseAndOp, dstType, operands);
+ rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType, operands);
return success();
}
};
} // namespace
+//===----------------------------------------------------------------------===//
+// Pattern population
+//===----------------------------------------------------------------------===//
+
void mlir::populateSPIRVToLLVMConversionPatterns(
MLIRContext *context, LLVMTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
- patterns.insert<BitwiseAndOpConversion>(context, typeConverter);
+ // Every op below maps one-to-one onto an LLVM dialect op, so all of them are
+ // handled by DirectConversionPattern and share the same type converter.
+ patterns.insert<
+ // Arithmetic ops
+ DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
+ DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
+ DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
+ DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
+ DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
+ DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
+ DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
+ DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
+ DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
+ DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
+ DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
+ DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
+ // Bitwise ops
+ DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
+ DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
+ DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>>(
+ context, typeConverter);
}
diff --git a/mlir/test/Conversion/SPIRVToLLVM/arithmetic-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/arithmetic-ops-to-llvm.mlir
new file mode 100644
index 000000000000..5e658c969a69
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/arithmetic-ops-to-llvm.mlir
@@ -0,0 +1,177 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+// Each function is preceded by a label directive so that the checks of one
+// function can only be satisfied by the output of that same function.
+
+//===----------------------------------------------------------------------===//
+// spv.IAdd
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @iadd_scalar
+func @iadd_scalar(%arg0: i32, %arg1: i32) {
+ // CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} : !llvm.i32
+ %0 = spv.IAdd %arg0, %arg1 : i32
+ return
+}
+
+// CHECK-LABEL: @iadd_vector
+func @iadd_vector(%arg0: vector<4xi64>, %arg1: vector<4xi64>) {
+ // CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
+ %0 = spv.IAdd %arg0, %arg1 : vector<4xi64>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ISub
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @isub_scalar
+func @isub_scalar(%arg0: i8, %arg1: i8) {
+ // CHECK: %{{.*}} = llvm.sub %{{.*}}, %{{.*}} : !llvm.i8
+ %0 = spv.ISub %arg0, %arg1 : i8
+ return
+}
+
+// CHECK-LABEL: @isub_vector
+func @isub_vector(%arg0: vector<2xi16>, %arg1: vector<2xi16>) {
+ // CHECK: %{{.*}} = llvm.sub %{{.*}}, %{{.*}} : !llvm<"<2 x i16>">
+ %0 = spv.ISub %arg0, %arg1 : vector<2xi16>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.IMul
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @imul_scalar
+func @imul_scalar(%arg0: i32, %arg1: i32) {
+ // CHECK: %{{.*}} = llvm.mul %{{.*}}, %{{.*}} : !llvm.i32
+ %0 = spv.IMul %arg0, %arg1 : i32
+ return
+}
+
+// CHECK-LABEL: @imul_vector
+func @imul_vector(%arg0: vector<3xi32>, %arg1: vector<3xi32>) {
+ // CHECK: %{{.*}} = llvm.mul %{{.*}}, %{{.*}} : !llvm<"<3 x i32>">
+ %0 = spv.IMul %arg0, %arg1 : vector<3xi32>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FAdd
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @fadd_scalar
+func @fadd_scalar(%arg0: f16, %arg1: f16) {
+ // CHECK: %{{.*}} = llvm.fadd %{{.*}}, %{{.*}} : !llvm.half
+ %0 = spv.FAdd %arg0, %arg1 : f16
+ return
+}
+
+// CHECK-LABEL: @fadd_vector
+func @fadd_vector(%arg0: vector<4xf32>, %arg1: vector<4xf32>) {
+ // CHECK: %{{.*}} = llvm.fadd %{{.*}}, %{{.*}} : !llvm<"<4 x float>">
+ %0 = spv.FAdd %arg0, %arg1 : vector<4xf32>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FSub
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @fsub_scalar
+func @fsub_scalar(%arg0: f32, %arg1: f32) {
+ // CHECK: %{{.*}} = llvm.fsub %{{.*}}, %{{.*}} : !llvm.float
+ %0 = spv.FSub %arg0, %arg1 : f32
+ return
+}
+
+// CHECK-LABEL: @fsub_vector
+func @fsub_vector(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
+ // CHECK: %{{.*}} = llvm.fsub %{{.*}}, %{{.*}} : !llvm<"<2 x float>">
+ %0 = spv.FSub %arg0, %arg1 : vector<2xf32>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @fdiv_scalar
+func @fdiv_scalar(%arg0: f32, %arg1: f32) {
+ // CHECK: %{{.*}} = llvm.fdiv %{{.*}}, %{{.*}} : !llvm.float
+ %0 = spv.FDiv %arg0, %arg1 : f32
+ return
+}
+
+// CHECK-LABEL: @fdiv_vector
+func @fdiv_vector(%arg0: vector<3xf64>, %arg1: vector<3xf64>) {
+ // CHECK: %{{.*}} = llvm.fdiv %{{.*}}, %{{.*}} : !llvm<"<3 x double>">
+ %0 = spv.FDiv %arg0, %arg1 : vector<3xf64>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FRem
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @frem_scalar
+func @frem_scalar(%arg0: f32, %arg1: f32) {
+ // CHECK: %{{.*}} = llvm.frem %{{.*}}, %{{.*}} : !llvm.float
+ %0 = spv.FRem %arg0, %arg1 : f32
+ return
+}
+
+// CHECK-LABEL: @frem_vector
+func @frem_vector(%arg0: vector<3xf64>, %arg1: vector<3xf64>) {
+ // CHECK: %{{.*}} = llvm.frem %{{.*}}, %{{.*}} : !llvm<"<3 x double>">
+ %0 = spv.FRem %arg0, %arg1 : vector<3xf64>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FNegate
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @fneg_scalar
+func @fneg_scalar(%arg: f64) {
+ // CHECK: %{{.*}} = llvm.fneg %{{.*}} : !llvm.double
+ %0 = spv.FNegate %arg : f64
+ return
+}
+
+// CHECK-LABEL: @fneg_vector
+func @fneg_vector(%arg: vector<2xf32>) {
+ // CHECK: %{{.*}} = llvm.fneg %{{.*}} : !llvm<"<2 x float>">
+ %0 = spv.FNegate %arg : vector<2xf32>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.UDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @udiv_scalar
+func @udiv_scalar(%arg0: i32, %arg1: i32) {
+ // CHECK: %{{.*}} = llvm.udiv %{{.*}}, %{{.*}} : !llvm.i32
+ %0 = spv.UDiv %arg0, %arg1 : i32
+ return
+}
+
+// CHECK-LABEL: @udiv_vector
+func @udiv_vector(%arg0: vector<3xi64>, %arg1: vector<3xi64>) {
+ // CHECK: %{{.*}} = llvm.udiv %{{.*}}, %{{.*}} : !llvm<"<3 x i64>">
+ %0 = spv.UDiv %arg0, %arg1 : vector<3xi64>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.SDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @sdiv_scalar
+func @sdiv_scalar(%arg0: i16, %arg1: i16) {
+ // CHECK: %{{.*}} = llvm.sdiv %{{.*}}, %{{.*}} : !llvm.i16
+ %0 = spv.SDiv %arg0, %arg1 : i16
+ return
+}
+
+// CHECK-LABEL: @sdiv_vector
+func @sdiv_vector(%arg0: vector<2xi64>, %arg1: vector<2xi64>) {
+ // CHECK: %{{.*}} = llvm.sdiv %{{.*}}, %{{.*}} : !llvm<"<2 x i64>">
+ %0 = spv.SDiv %arg0, %arg1 : vector<2xi64>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.SRem
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @srem_scalar
+func @srem_scalar(%arg0: i32, %arg1: i32) {
+ // CHECK: %{{.*}} = llvm.srem %{{.*}}, %{{.*}} : !llvm.i32
+ %0 = spv.SRem %arg0, %arg1 : i32
+ return
+}
+
+// CHECK-LABEL: @srem_vector
+func @srem_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) {
+ // CHECK: %{{.*}} = llvm.srem %{{.*}}, %{{.*}} : !llvm<"<4 x i32>">
+ %0 = spv.SRem %arg0, %arg1 : vector<4xi32>
+ return
+}
diff --git a/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir
new file mode 100644
index 000000000000..13410400253e
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+// Each function is preceded by a label directive so that the checks of one
+// function can only be satisfied by the output of that same function.
+
+//===----------------------------------------------------------------------===//
+// spv.BitwiseAnd
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_and_scalar
+func @bitwise_and_scalar(%arg0: i32, %arg1: i32) {
+ // CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm.i32
+ %0 = spv.BitwiseAnd %arg0, %arg1 : i32
+ return
+}
+
+// CHECK-LABEL: @bitwise_and_vector
+func @bitwise_and_vector(%arg0: vector<4xi64>, %arg1: vector<4xi64>) {
+ // CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
+ %0 = spv.BitwiseAnd %arg0, %arg1 : vector<4xi64>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.BitwiseOr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_or_scalar
+func @bitwise_or_scalar(%arg0: i64, %arg1: i64) {
+ // CHECK: %{{.*}} = llvm.or %{{.*}}, %{{.*}} : !llvm.i64
+ %0 = spv.BitwiseOr %arg0, %arg1 : i64
+ return
+}
+
+// CHECK-LABEL: @bitwise_or_vector
+func @bitwise_or_vector(%arg0: vector<3xi8>, %arg1: vector<3xi8>) {
+ // CHECK: %{{.*}} = llvm.or %{{.*}}, %{{.*}} : !llvm<"<3 x i8>">
+ %0 = spv.BitwiseOr %arg0, %arg1 : vector<3xi8>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.BitwiseXor
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_xor_scalar
+func @bitwise_xor_scalar(%arg0: i32, %arg1: i32) {
+ // CHECK: %{{.*}} = llvm.xor %{{.*}}, %{{.*}} : !llvm.i32
+ %0 = spv.BitwiseXor %arg0, %arg1 : i32
+ return
+}
+
+// CHECK-LABEL: @bitwise_xor_vector
+func @bitwise_xor_vector(%arg0: vector<2xi16>, %arg1: vector<2xi16>) {
+ // CHECK: %{{.*}} = llvm.xor %{{.*}}, %{{.*}} : !llvm<"<2 x i16>">
+ %0 = spv.BitwiseXor %arg0, %arg1 : vector<2xi16>
+ return
+}
diff --git a/mlir/test/Conversion/SPIRVToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/convert-to-llvm.mlir
deleted file mode 100644
index 326ef3a18854..000000000000
--- a/mlir/test/Conversion/SPIRVToLLVM/convert-to-llvm.mlir
+++ /dev/null
@@ -1,13 +0,0 @@
-// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
-
-func @bitwise_and_scalar(%arg0: i32, %arg1: i32) {
- // CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm.i32
- %0 = spv.BitwiseAnd %arg0, %arg1 : i32
- return
-}
-
-func @bitwise_and_vector(%arg0: vector<4xi64>, %arg1: vector<4xi64>) {
- // CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
- %0 = spv.BitwiseAnd %arg0, %arg1 : vector<4xi64>
- return
-}
More information about the Mlir-commits
mailing list