[Mlir-commits] [mlir] d93d8fc - [MLIR][SPIRVToLLVM] Implemented conversion for arithmetic ops and 3 bitwise ops.

Lei Zhang llvmlistbot at llvm.org
Wed Jun 10 16:14:06 PDT 2020


Author: George Mitenkov
Date: 2020-06-10T19:10:31-04:00
New Revision: d93d8fcdec68211fd3ac7f586fa67bc065acef6a

URL: https://github.com/llvm/llvm-project/commit/d93d8fcdec68211fd3ac7f586fa67bc065acef6a
DIFF: https://github.com/llvm/llvm-project/commit/d93d8fcdec68211fd3ac7f586fa67bc065acef6a.diff

LOG: [MLIR][SPIRVToLLVM] Implemented conversion for arithmetic ops and 3 bitwise ops.

Following the previous revision `D81100`, this commit implements a templated class
that provides conversion patterns for “straightforward” SPIR-V ops into the
LLVM dialect. Templating allows abstracting away from the concrete implementation
for each specific op. Those are mainly binary operations. Currently supported
and tested ops are:
- Arithmetic ops: `IAdd`, `ISub`, `IMul`, `FAdd`, `FSub`, `FMul`, `FDiv`, `FRem`,
  `FNegate`, `SDiv`, `SRem` and `UDiv`
- Bitwise ops: `BitwiseAnd`, `BitwiseOr`, `BitwiseXor`

The implementation relies on the `SPIRVToLLVMConversion` class, which derives from
`OpConversionPattern`.

Differential Revision: https://reviews.llvm.org/D81305

Added: 
    mlir/test/Conversion/SPIRVToLLVM/arithmetic-ops-to-llvm.mlir
    mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir

Modified: 
    mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
    mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp

Removed: 
    mlir/test/Conversion/SPIRVToLLVM/convert-to-llvm.mlir


################################################################################
diff  --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
index 2bf74d65b90d..e82efac3abe5 100644
--- a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
+++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
@@ -20,6 +20,22 @@ class LLVMTypeConverter;
 class MLIRContext;
 class ModuleOp;
 
+/// Base class for patterns that convert the SPIR-V operation `SPIRVOp` into
+/// the LLVM dialect. It keeps a reference to the LLVMTypeConverter so that
+/// derived patterns can convert types; the referenced converter must outlive
+/// every pattern constructed from it.
+template <typename SPIRVOp>
+class SPIRVToLLVMConversion : public OpConversionPattern<SPIRVOp> {
+public:
+  SPIRVToLLVMConversion(MLIRContext *context, LLVMTypeConverter &typeConverter,
+                        PatternBenefit benefit = 1)
+      : OpConversionPattern<SPIRVOp>(context, benefit),
+        typeConverter(typeConverter) {}
+
+protected:
+  LLVMTypeConverter &typeConverter;
+};
+
 /// Populates the given list with patterns that convert from SPIR-V to LLVM.
 void populateSPIRVToLLVMConversionPatterns(MLIRContext *context,
                                            LLVMTypeConverter &typeConverter,

diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index 2ccd2431ac6e..4b056a553b76 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -24,30 +24,54 @@
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Operation conversion
+//===----------------------------------------------------------------------===//
+
 namespace {
 
-class BitwiseAndOpConversion : public ConvertToLLVMPattern {
+/// Converts a SPIR-V operation that has a straightforward LLVM equivalent
+/// into the corresponding LLVM dialect operation. The operands supplied to
+/// the pattern are forwarded unchanged and the result type is converted with
+/// the LLVM type converter; the pattern fails if that type conversion fails.
+template <typename SPIRVOp, typename LLVMOp>
+class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
 public:
-  explicit BitwiseAndOpConversion(MLIRContext *context,
-                                  LLVMTypeConverter &typeConverter)
-      : ConvertToLLVMPattern(spirv::BitwiseAndOp::getOperationName(), context,
-                             typeConverter) {}
+  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto bitwiseAndOp = cast<spirv::BitwiseAndOp>(op);
-    auto dstType = typeConverter.convertType(bitwiseAndOp.getType());
+    auto dstType = this->typeConverter.convertType(operation.getType());
     if (!dstType)
       return failure();
-    rewriter.replaceOpWithNewOp<LLVM::AndOp>(bitwiseAndOp, dstType, operands);
+    rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType, operands);
     return success();
   }
 };
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// Pattern population
+//===----------------------------------------------------------------------===//
+
 void mlir::populateSPIRVToLLVMConversionPatterns(
     MLIRContext *context, LLVMTypeConverter &typeConverter,
     OwningRewritePatternList &patterns) {
-  patterns.insert<BitwiseAndOpConversion>(context, typeConverter);
+  patterns.insert<DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
+                  DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
+                  DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
+                  DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
+                  DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
+                  DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
+                  DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
+                  DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
+                  DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
+                  DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
+                  DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
+                  DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
+                  DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
+                  DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
+                  DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>>(
+      context, typeConverter);
 }

diff  --git a/mlir/test/Conversion/SPIRVToLLVM/arithmetic-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/arithmetic-ops-to-llvm.mlir
new file mode 100644
index 000000000000..5e658c969a69
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/arithmetic-ops-to-llvm.mlir
@@ -0,0 +1,217 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.IAdd
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @iadd_scalar
+func @iadd_scalar(%arg0: i32, %arg1: i32) {
+  // CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} : !llvm.i32
+  %0 = spv.IAdd %arg0, %arg1 : i32
+  return
+}
+
+// CHECK-LABEL: @iadd_vector
+func @iadd_vector(%arg0: vector<4xi64>, %arg1: vector<4xi64>) {
+  // CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
+  %0 = spv.IAdd %arg0, %arg1 : vector<4xi64>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ISub
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @isub_scalar
+func @isub_scalar(%arg0: i8, %arg1: i8) {
+  // CHECK: %{{.*}} = llvm.sub %{{.*}}, %{{.*}} : !llvm.i8
+  %0 = spv.ISub %arg0, %arg1 : i8
+  return
+}
+
+// CHECK-LABEL: @isub_vector
+func @isub_vector(%arg0: vector<2xi16>, %arg1: vector<2xi16>) {
+  // CHECK: %{{.*}} = llvm.sub %{{.*}}, %{{.*}} : !llvm<"<2 x i16>">
+  %0 = spv.ISub %arg0, %arg1 : vector<2xi16>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.IMul
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @imul_scalar
+func @imul_scalar(%arg0: i32, %arg1: i32) {
+  // CHECK: %{{.*}} = llvm.mul %{{.*}}, %{{.*}} : !llvm.i32
+  %0 = spv.IMul %arg0, %arg1 : i32
+  return
+}
+
+// CHECK-LABEL: @imul_vector
+func @imul_vector(%arg0: vector<3xi32>, %arg1: vector<3xi32>) {
+  // CHECK: %{{.*}} = llvm.mul %{{.*}}, %{{.*}} : !llvm<"<3 x i32>">
+  %0 = spv.IMul %arg0, %arg1 : vector<3xi32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FAdd
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @fadd_scalar
+func @fadd_scalar(%arg0: f16, %arg1: f16) {
+  // CHECK: %{{.*}} = llvm.fadd %{{.*}}, %{{.*}} : !llvm.half
+  %0 = spv.FAdd %arg0, %arg1 : f16
+  return
+}
+
+// CHECK-LABEL: @fadd_vector
+func @fadd_vector(%arg0: vector<4xf32>, %arg1: vector<4xf32>) {
+  // CHECK: %{{.*}} = llvm.fadd %{{.*}}, %{{.*}} : !llvm<"<4 x float>">
+  %0 = spv.FAdd %arg0, %arg1 : vector<4xf32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FSub
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @fsub_scalar
+func @fsub_scalar(%arg0: f32, %arg1: f32) {
+  // CHECK: %{{.*}} = llvm.fsub %{{.*}}, %{{.*}} : !llvm.float
+  %0 = spv.FSub %arg0, %arg1 : f32
+  return
+}
+
+// CHECK-LABEL: @fsub_vector
+func @fsub_vector(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
+  // CHECK: %{{.*}} = llvm.fsub %{{.*}}, %{{.*}} : !llvm<"<2 x float>">
+  %0 = spv.FSub %arg0, %arg1 : vector<2xf32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FMul
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @fmul_scalar
+func @fmul_scalar(%arg0: f32, %arg1: f32) {
+  // CHECK: %{{.*}} = llvm.fmul %{{.*}}, %{{.*}} : !llvm.float
+  %0 = spv.FMul %arg0, %arg1 : f32
+  return
+}
+
+// CHECK-LABEL: @fmul_vector
+func @fmul_vector(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
+  // CHECK: %{{.*}} = llvm.fmul %{{.*}}, %{{.*}} : !llvm<"<2 x float>">
+  %0 = spv.FMul %arg0, %arg1 : vector<2xf32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @fdiv_scalar
+func @fdiv_scalar(%arg0: f32, %arg1: f32) {
+  // CHECK: %{{.*}} = llvm.fdiv %{{.*}}, %{{.*}} : !llvm.float
+  %0 = spv.FDiv %arg0, %arg1 : f32
+  return
+}
+
+// CHECK-LABEL: @fdiv_vector
+func @fdiv_vector(%arg0: vector<3xf64>, %arg1: vector<3xf64>) {
+  // CHECK: %{{.*}} = llvm.fdiv %{{.*}}, %{{.*}} : !llvm<"<3 x double>">
+  %0 = spv.FDiv %arg0, %arg1 : vector<3xf64>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FRem
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @frem_scalar
+func @frem_scalar(%arg0: f32, %arg1: f32) {
+  // CHECK: %{{.*}} = llvm.frem %{{.*}}, %{{.*}} : !llvm.float
+  %0 = spv.FRem %arg0, %arg1 : f32
+  return
+}
+
+// CHECK-LABEL: @frem_vector
+func @frem_vector(%arg0: vector<3xf64>, %arg1: vector<3xf64>) {
+  // CHECK: %{{.*}} = llvm.frem %{{.*}}, %{{.*}} : !llvm<"<3 x double>">
+  %0 = spv.FRem %arg0, %arg1 : vector<3xf64>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FNegate
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @fneg_scalar
+func @fneg_scalar(%arg: f64) {
+  // CHECK: %{{.*}} = llvm.fneg %{{.*}} : !llvm.double
+  %0 = spv.FNegate %arg : f64
+  return
+}
+
+// CHECK-LABEL: @fneg_vector
+func @fneg_vector(%arg: vector<2xf32>) {
+  // CHECK: %{{.*}} = llvm.fneg %{{.*}} : !llvm<"<2 x float>">
+  %0 = spv.FNegate %arg : vector<2xf32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.UDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @udiv_scalar
+func @udiv_scalar(%arg0: i32, %arg1: i32) {
+  // CHECK: %{{.*}} = llvm.udiv %{{.*}}, %{{.*}} : !llvm.i32
+  %0 = spv.UDiv %arg0, %arg1 : i32
+  return
+}
+
+// CHECK-LABEL: @udiv_vector
+func @udiv_vector(%arg0: vector<3xi64>, %arg1: vector<3xi64>) {
+  // CHECK: %{{.*}} = llvm.udiv %{{.*}}, %{{.*}} : !llvm<"<3 x i64>">
+  %0 = spv.UDiv %arg0, %arg1 : vector<3xi64>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.SDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @sdiv_scalar
+func @sdiv_scalar(%arg0: i16, %arg1: i16) {
+  // CHECK: %{{.*}} = llvm.sdiv %{{.*}}, %{{.*}} : !llvm.i16
+  %0 = spv.SDiv %arg0, %arg1 : i16
+  return
+}
+
+// CHECK-LABEL: @sdiv_vector
+func @sdiv_vector(%arg0: vector<2xi64>, %arg1: vector<2xi64>) {
+  // CHECK: %{{.*}} = llvm.sdiv %{{.*}}, %{{.*}} : !llvm<"<2 x i64>">
+  %0 = spv.SDiv %arg0, %arg1 : vector<2xi64>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.SRem
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @srem_scalar
+func @srem_scalar(%arg0: i32, %arg1: i32) {
+  // CHECK: %{{.*}} = llvm.srem %{{.*}}, %{{.*}} : !llvm.i32
+  %0 = spv.SRem %arg0, %arg1 : i32
+  return
+}
+
+// CHECK-LABEL: @srem_vector
+func @srem_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) {
+  // CHECK: %{{.*}} = llvm.srem %{{.*}}, %{{.*}} : !llvm<"<4 x i32>">
+  %0 = spv.SRem %arg0, %arg1 : vector<4xi32>
+  return
+}

diff  --git a/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir
new file mode 100644
index 000000000000..13410400253e
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.BitwiseAnd
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_and_scalar
+func @bitwise_and_scalar(%arg0: i32, %arg1: i32) {
+  // CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm.i32
+  %0 = spv.BitwiseAnd %arg0, %arg1 : i32
+  return
+}
+
+// CHECK-LABEL: @bitwise_and_vector
+func @bitwise_and_vector(%arg0: vector<4xi64>, %arg1: vector<4xi64>) {
+  // CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
+  %0 = spv.BitwiseAnd %arg0, %arg1 : vector<4xi64>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.BitwiseOr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_or_scalar
+func @bitwise_or_scalar(%arg0: i64, %arg1: i64) {
+  // CHECK: %{{.*}} = llvm.or %{{.*}}, %{{.*}} : !llvm.i64
+  %0 = spv.BitwiseOr %arg0, %arg1 : i64
+  return
+}
+
+// CHECK-LABEL: @bitwise_or_vector
+func @bitwise_or_vector(%arg0: vector<3xi8>, %arg1: vector<3xi8>) {
+  // CHECK: %{{.*}} = llvm.or %{{.*}}, %{{.*}} : !llvm<"<3 x i8>">
+  %0 = spv.BitwiseOr %arg0, %arg1 : vector<3xi8>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.BitwiseXor
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_xor_scalar
+func @bitwise_xor_scalar(%arg0: i32, %arg1: i32) {
+  // CHECK: %{{.*}} = llvm.xor %{{.*}}, %{{.*}} : !llvm.i32
+  %0 = spv.BitwiseXor %arg0, %arg1 : i32
+  return
+}
+
+// CHECK-LABEL: @bitwise_xor_vector
+func @bitwise_xor_vector(%arg0: vector<2xi16>, %arg1: vector<2xi16>) {
+  // CHECK: %{{.*}} = llvm.xor %{{.*}}, %{{.*}} : !llvm<"<2 x i16>">
+  %0 = spv.BitwiseXor %arg0, %arg1 : vector<2xi16>
+  return
+}

diff  --git a/mlir/test/Conversion/SPIRVToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/convert-to-llvm.mlir
deleted file mode 100644
index 326ef3a18854..000000000000
--- a/mlir/test/Conversion/SPIRVToLLVM/convert-to-llvm.mlir
+++ /dev/null
@@ -1,13 +0,0 @@
-// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
-
-func @bitwise_and_scalar(%arg0: i32, %arg1: i32) {
-	// CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm.i32
-	%0 = spv.BitwiseAnd %arg0, %arg1 : i32
-	return
-}
-
-func @bitwise_and_vector(%arg0: vector<4xi64>, %arg1: vector<4xi64>) {
-	// CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
-	%0 = spv.BitwiseAnd %arg0, %arg1 : vector<4xi64>
-	return
-}


        


More information about the Mlir-commits mailing list