[Mlir-commits] [mlir] cf2b4d5 - [MLIR][SPIRVToLLVM] Implemented shift conversion pattern
Lei Zhang
llvmlistbot at llvm.org
Fri Jun 12 16:08:19 PDT 2020
Author: George Mitenkov
Date: 2020-06-12T19:04:30-04:00
New Revision: cf2b4d5cb64cb0100473fca533edb51c2e6db6bf
URL: https://github.com/llvm/llvm-project/commit/cf2b4d5cb64cb0100473fca533edb51c2e6db6bf
DIFF: https://github.com/llvm/llvm-project/commit/cf2b4d5cb64cb0100473fca533edb51c2e6db6bf.diff
LOG: [MLIR][SPIRVToLLVM] Implemented shift conversion pattern
This patch implements the conversion of shift ops. In the SPIR-V dialect,
`Shift` and `Base` may have different bit widths. In contrast,
in the LLVM dialect both `Base` and `Shift` must have the same bit width.
This leads to the following cases:
- if `Base` has the same bit width as `Shift`, the conversion is
straightforward.
- if `Base` has a greater bit width than `Shift`, `Shift` is sign- or
zero-extended first. Then the extended value is passed to the shift.
- otherwise the conversion is considered to be illegal.
Differential Revision: https://reviews.llvm.org/D81546
Added:
mlir/test/Conversion/SPIRVToLLVM/shifts-to-llvm.mlir
Modified:
mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index f2942c17c0f5..ddaf1ca34861 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -24,6 +24,19 @@
using namespace mlir;
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Returns true if `type` is an unsigned integer, or a vector whose element
+/// type is an unsigned integer; returns false for every other type.
+static bool isUnsignedIntegerOrVector(Type type) {
+ Type elementType = type;
+ if (auto vecType = type.dyn_cast<VectorType>())
+ elementType = vecType.getElementType();
+ return elementType.isUnsignedInteger();
+}
+
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
@@ -91,6 +104,48 @@ class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
return success();
}
};
+
+/// Converts SPIR-V shift ops to LLVM shift ops. LLVM dialect requires `Shift`
+/// and `Base` to have the same bit width, so a `Shift` whose type differs from
+/// `Base` is zero or sign extended to the converted result type first.
+/// NOTE(review): `Shift` wider than `Base` is not rejected here - confirm.
+template <typename SPIRVOp, typename LLVMOp>
+class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
+public:
+ using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
+
+ LogicalResult
+ matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // Fail the match if the result type cannot be converted.
+ auto dstType = this->typeConverter.convertType(operation.getType());
+ if (!dstType)
+ return failure();
+
+ Type op1Type = operation.operand1().getType();
+ Type op2Type = operation.operand2().getType();
+ // Identical SPIR-V operand types: forward the converted operands as is.
+ if (op1Type == op2Type) {
+ rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
+ operands);
+ return success();
+ }
+ // Extend the converted `Shift`; signedness is read from its SPIR-V type.
+ Location loc = operation.getLoc();
+ Value extended;
+ if (isUnsignedIntegerOrVector(op2Type)) {
+ extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
+ operands[1]);
+ } else {
+ extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
+ operands[1]);
+ }
+ Value result = rewriter.template create<LLVMOp>(
+ loc, dstType, operands[0], extended);
+ rewriter.replaceOp(operation, result);
+ return success();
+ }
+};
} // namespace
//===----------------------------------------------------------------------===//
@@ -142,6 +197,11 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
- IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>>(
- context, typeConverter);
+ IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
+
+ // Shift ops
+ ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
+ ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
+ ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>>(context,
+ typeConverter);
}
diff --git a/mlir/test/Conversion/SPIRVToLLVM/shifts-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/shifts-to-llvm.mlir
new file mode 100644
index 000000000000..09e396d6400f
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/shifts-to-llvm.mlir
@@ -0,0 +1,115 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.ShiftRightArithmetic
+//===----------------------------------------------------------------------===//
+
+func @shift_right_arithmetic_scalar(%arg0: i32, %arg1: si32, %arg2 : i16, %arg3 : ui16) {
+ // CHECK: %{{.*}} = llvm.ashr %{{.*}}, %{{.*}} : !llvm.i32
+ %0 = spv.ShiftRightArithmetic %arg0, %arg0 : i32, i32
+
+ // CHECK: %{{.*}} = llvm.ashr %{{.*}}, %{{.*}} : !llvm.i32
+ %1 = spv.ShiftRightArithmetic %arg0, %arg1 : i32, si32
+
+ // CHECK: %[[EXT1:.*]] = llvm.sext %{{.*}} : !llvm.i16 to !llvm.i32
+ // CHECK: %{{.*}} = llvm.ashr %{{.*}}, %[[EXT1]] : !llvm.i32
+ %2 = spv.ShiftRightArithmetic %arg0, %arg2 : i32, i16
+
+ // CHECK: %[[EXT2:.*]] = llvm.zext %{{.*}} : !llvm.i16 to !llvm.i32
+ // CHECK: %{{.*}} = llvm.ashr %{{.*}}, %[[EXT2]] : !llvm.i32
+ %3 = spv.ShiftRightArithmetic %arg0, %arg3 : i32, ui16
+ return
+}
+
+func @shift_right_arithmetic_vector(%arg0: vector<4xi64>, %arg1: vector<4xui64>, %arg2: vector<4xi32>, %arg3: vector<4xui32>) {
+ // CHECK: %{{.*}} = llvm.ashr %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
+ %0 = spv.ShiftRightArithmetic %arg0, %arg0 : vector<4xi64>, vector<4xi64>
+
+ // CHECK: %{{.*}} = llvm.ashr %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
+ %1 = spv.ShiftRightArithmetic %arg0, %arg1 : vector<4xi64>, vector<4xui64>
+
+ // CHECK: %[[EXT1:.*]] = llvm.sext %{{.*}} : !llvm<"<4 x i32>"> to !llvm<"<4 x i64>">
+ // CHECK: %{{.*}} = llvm.ashr %{{.*}}, %[[EXT1]] : !llvm<"<4 x i64>">
+ %2 = spv.ShiftRightArithmetic %arg0, %arg2 : vector<4xi64>, vector<4xi32>
+
+ // CHECK: %[[EXT2:.*]] = llvm.zext %{{.*}} : !llvm<"<4 x i32>"> to !llvm<"<4 x i64>">
+ // CHECK: %{{.*}} = llvm.ashr %{{.*}}, %[[EXT2]] : !llvm<"<4 x i64>">
+ %3 = spv.ShiftRightArithmetic %arg0, %arg3 : vector<4xi64>, vector<4xui32>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ShiftRightLogical
+//===----------------------------------------------------------------------===//
+
+func @shift_right_logical_scalar(%arg0: i32, %arg1: si32, %arg2 : si16, %arg3 : ui16) {
+ // CHECK: %{{.*}} = llvm.lshr %{{.*}}, %{{.*}} : !llvm.i32
+ %0 = spv.ShiftRightLogical %arg0, %arg0 : i32, i32
+
+ // CHECK: %{{.*}} = llvm.lshr %{{.*}}, %{{.*}} : !llvm.i32
+ %1 = spv.ShiftRightLogical %arg0, %arg1 : i32, si32
+
+ // CHECK: %[[EXT1:.*]] = llvm.sext %{{.*}} : !llvm.i16 to !llvm.i32
+ // CHECK: %{{.*}} = llvm.lshr %{{.*}}, %[[EXT1]] : !llvm.i32
+ %2 = spv.ShiftRightLogical %arg0, %arg2 : i32, si16
+
+ // CHECK: %[[EXT2:.*]] = llvm.zext %{{.*}} : !llvm.i16 to !llvm.i32
+ // CHECK: %{{.*}} = llvm.lshr %{{.*}}, %[[EXT2]] : !llvm.i32
+ %3 = spv.ShiftRightLogical %arg0, %arg3 : i32, ui16
+ return
+}
+
+func @shift_right_logical_vector(%arg0: vector<4xi64>, %arg1: vector<4xsi64>, %arg2: vector<4xi32>, %arg3: vector<4xui32>) {
+ // CHECK: %{{.*}} = llvm.lshr %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
+ %0 = spv.ShiftRightLogical %arg0, %arg0 : vector<4xi64>, vector<4xi64>
+
+ // CHECK: %{{.*}} = llvm.lshr %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
+ %1 = spv.ShiftRightLogical %arg0, %arg1 : vector<4xi64>, vector<4xsi64>
+
+ // CHECK: %[[EXT1:.*]] = llvm.sext %{{.*}} : !llvm<"<4 x i32>"> to !llvm<"<4 x i64>">
+ // CHECK: %{{.*}} = llvm.lshr %{{.*}}, %[[EXT1]] : !llvm<"<4 x i64>">
+ %2 = spv.ShiftRightLogical %arg0, %arg2 : vector<4xi64>, vector<4xi32>
+
+ // CHECK: %[[EXT2:.*]] = llvm.zext %{{.*}} : !llvm<"<4 x i32>"> to !llvm<"<4 x i64>">
+ // CHECK: %{{.*}} = llvm.lshr %{{.*}}, %[[EXT2]] : !llvm<"<4 x i64>">
+ %3 = spv.ShiftRightLogical %arg0, %arg3 : vector<4xi64>, vector<4xui32>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ShiftLeftLogical
+//===----------------------------------------------------------------------===//
+
+func @shift_left_logical_scalar(%arg0: i32, %arg1: si32, %arg2 : i16, %arg3 : ui16) {
+ // CHECK: %{{.*}} = llvm.shl %{{.*}}, %{{.*}} : !llvm.i32
+ %0 = spv.ShiftLeftLogical %arg0, %arg0 : i32, i32
+
+ // CHECK: %{{.*}} = llvm.shl %{{.*}}, %{{.*}} : !llvm.i32
+ %1 = spv.ShiftLeftLogical %arg0, %arg1 : i32, si32
+
+ // CHECK: %[[EXT1:.*]] = llvm.sext %{{.*}} : !llvm.i16 to !llvm.i32
+ // CHECK: %{{.*}} = llvm.shl %{{.*}}, %[[EXT1]] : !llvm.i32
+ %2 = spv.ShiftLeftLogical %arg0, %arg2 : i32, i16
+
+ // CHECK: %[[EXT2:.*]] = llvm.zext %{{.*}} : !llvm.i16 to !llvm.i32
+ // CHECK: %{{.*}} = llvm.shl %{{.*}}, %[[EXT2]] : !llvm.i32
+ %3 = spv.ShiftLeftLogical %arg0, %arg3 : i32, ui16
+ return
+}
+
+func @shift_left_logical_vector(%arg0: vector<4xi64>, %arg1: vector<4xsi64>, %arg2: vector<4xi32>, %arg3: vector<4xui32>) {
+ // CHECK: %{{.*}} = llvm.shl %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
+ %0 = spv.ShiftLeftLogical %arg0, %arg0 : vector<4xi64>, vector<4xi64>
+
+ // CHECK: %{{.*}} = llvm.shl %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
+ %1 = spv.ShiftLeftLogical %arg0, %arg1 : vector<4xi64>, vector<4xsi64>
+
+ // CHECK: %[[EXT1:.*]] = llvm.sext %{{.*}} : !llvm<"<4 x i32>"> to !llvm<"<4 x i64>">
+ // CHECK: %{{.*}} = llvm.shl %{{.*}}, %[[EXT1]] : !llvm<"<4 x i64>">
+ %2 = spv.ShiftLeftLogical %arg0, %arg2 : vector<4xi64>, vector<4xi32>
+
+ // CHECK: %[[EXT2:.*]] = llvm.zext %{{.*}} : !llvm<"<4 x i32>"> to !llvm<"<4 x i64>">
+ // CHECK: %{{.*}} = llvm.shl %{{.*}}, %[[EXT2]] : !llvm<"<4 x i64>">
+ %3 = spv.ShiftLeftLogical %arg0, %arg3 : vector<4xi64>, vector<4xui32>
+ return
+}
More information about the Mlir-commits
mailing list