[Mlir-commits] [mlir] cf2b4d5 - [MLIR][SPIRVToLLVM] Implemented shift conversion pattern

Lei Zhang llvmlistbot at llvm.org
Fri Jun 12 16:08:19 PDT 2020


Author: George Mitenkov
Date: 2020-06-12T19:04:30-04:00
New Revision: cf2b4d5cb64cb0100473fca533edb51c2e6db6bf

URL: https://github.com/llvm/llvm-project/commit/cf2b4d5cb64cb0100473fca533edb51c2e6db6bf
DIFF: https://github.com/llvm/llvm-project/commit/cf2b4d5cb64cb0100473fca533edb51c2e6db6bf.diff

LOG: [MLIR][SPIRVToLLVM] Implemented shift conversion pattern

This patch implements the conversion pattern for shift ops. In the SPIR-V
dialect, `Shift` and `Base` may have different bit widths. In contrast,
in the LLVM dialect both `Base` and `Shift` must have the same bit width.
This leads to the following cases:
- if `Base` has the same bit width as `Shift`, the conversion is
  straightforward.
- if `Base` has a greater bit width than `Shift`, `Shift` is first sign- or
  zero-extended to `Base`'s bit width; the extended value is then used in the shift.
- otherwise the conversion is considered to be illegal.

Differential Revision: https://reviews.llvm.org/D81546

Added: 
    mlir/test/Conversion/SPIRVToLLVM/shifts-to-llvm.mlir

Modified: 
    mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index f2942c17c0f5..ddaf1ca34861 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -24,6 +24,19 @@
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Returns true if the given type is an unsigned integer or vector type
+static bool isUnsignedIntegerOrVector(Type type) {
+  if (type.isUnsignedInteger())
+    return true;
+  if (auto vecType = type.dyn_cast<VectorType>())
+    return vecType.getElementType().isUnsignedInteger();
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -91,6 +104,48 @@ class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
     return success();
   }
 };
+
+/// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
+/// puts a restriction on `Shift` and `Base` to have the same bit width,
+/// `Shift` is zero or sign extended to match this specification. Cases when
+/// `Shift` bit width > `Base` bit width are considered to be illegal.
+template <typename SPIRVOp, typename LLVMOp>
+class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
+public:
+  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto dstType = this->typeConverter.convertType(operation.getType());
+    if (!dstType)
+      return failure();
+
+    Type op1Type = operation.operand1().getType();
+    Type op2Type = operation.operand2().getType();
+
+    if (op1Type == op2Type) {
+      rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
+                                                   operands);
+      return success();
+    }
+
+    Location loc = operation.getLoc();
+    Value extended;
+    if (isUnsignedIntegerOrVector(op2Type)) {
+      extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
+                                                        operation.operand2());
+    } else {
+      extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
+                                                        operation.operand2());
+    }
+    Value result = rewriter.template create<LLVMOp>(
+        loc, dstType, operation.operand1(), extended);
+    rewriter.replaceOp(operation, result);
+    return success();
+  }
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -142,6 +197,11 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
       IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
       IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
       IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
-      IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>>(
-      context, typeConverter);
+      IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
+
+      // Shift ops
+      ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
+      ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
+      ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>>(context,
+                                                            typeConverter);
 }

diff  --git a/mlir/test/Conversion/SPIRVToLLVM/shifts-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/shifts-to-llvm.mlir
new file mode 100644
index 000000000000..09e396d6400f
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/shifts-to-llvm.mlir
@@ -0,0 +1,115 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.ShiftRightArithmetic
+//===----------------------------------------------------------------------===//
+
+func @shift_right_arithmetic_scalar(%arg0: i32, %arg1: si32, %arg2 : i16, %arg3 : ui16) {
+	// CHECK: %{{.*}} = llvm.ashr %{{.*}}, %{{.*}} : !llvm.i32
+	%0 = spv.ShiftRightArithmetic %arg0, %arg0 : i32, i32
+
+    // CHECK: %{{.*}} = llvm.ashr %{{.*}}, %{{.*}} : !llvm.i32
+	%1 = spv.ShiftRightArithmetic %arg0, %arg1 : i32, si32
+
+    // CHECK: %[[EXT1:.*]] = llvm.sext %{{.*}} : !llvm.i16 to !llvm.i32
+    // CHECK: %{{.*}} = llvm.ashr %{{.*}}, %[[EXT1:.*]]: !llvm.i32
+	%2 = spv.ShiftRightArithmetic %arg0, %arg2 : i32, i16
+
+    // CHECK: %[[EXT2:.*]] = llvm.zext %{{.*}} : !llvm.i16 to !llvm.i32
+    // CHECK: %{{.*}} = llvm.ashr %{{.*}}, %[[EXT2:.*]]: !llvm.i32
+	%3 = spv.ShiftRightArithmetic %arg0, %arg3 : i32, ui16
+	return
+}
+
+func @shift_right_arithmetic_vector(%arg0: vector<4xi64>, %arg1: vector<4xui64>, %arg2: vector<4xi32>, %arg3: vector<4xui32>) {
+	// CHECK: %{{.*}} = llvm.ashr %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
+	%0 = spv.ShiftRightArithmetic %arg0, %arg0 : vector<4xi64>, vector<4xi64>
+
+    // CHECK: %{{.*}} = llvm.ashr %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
+	%1 = spv.ShiftRightArithmetic %arg0, %arg1 : vector<4xi64>, vector<4xui64>
+
+    // CHECK: %[[EXT1:.*]] = llvm.sext %{{.*}} : !llvm<"<4 x i32>"> to !llvm<"<4 x i64>">
+    // CHECK: %{{.*}} = llvm.ashr %{{.*}}, %[[EXT1:.*]]: !llvm<"<4 x i64>">
+	%2 = spv.ShiftRightArithmetic %arg0, %arg2 : vector<4xi64>,  vector<4xi32>
+
+    // CHECK: %[[EXT2:.*]] = llvm.zext %{{.*}} : !llvm<"<4 x i32>"> to !llvm<"<4 x i64>">
+    // CHECK: %{{.*}} = llvm.ashr %{{.*}}, %[[EXT2:.*]]: !llvm<"<4 x i64>">
+	%3 = spv.ShiftRightArithmetic %arg0, %arg3 : vector<4xi64>, vector<4xui32>
+	return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ShiftRightLogical
+//===----------------------------------------------------------------------===//
+
+func @shift_right_logical_scalar(%arg0: i32, %arg1: si32, %arg2 : si16, %arg3 : ui16) {
+	// CHECK: %{{.*}} = llvm.lshr %{{.*}}, %{{.*}} : !llvm.i32
+	%0 = spv.ShiftRightLogical %arg0, %arg0 : i32, i32
+
+    // CHECK: %{{.*}} = llvm.lshr %{{.*}}, %{{.*}} : !llvm.i32
+	%1 = spv.ShiftRightLogical %arg0, %arg1 : i32, si32
+
+    // CHECK: %[[EXT1:.*]] = llvm.sext %{{.*}} : !llvm.i16 to !llvm.i32
+    // CHECK: %{{.*}} = llvm.lshr %{{.*}}, %[[EXT1:.*]]: !llvm.i32
+	%2 = spv.ShiftRightLogical %arg0, %arg2 : i32, si16
+
+    // CHECK: %[[EXT2:.*]] = llvm.zext %{{.*}} : !llvm.i16 to !llvm.i32
+    // CHECK: %{{.*}} = llvm.lshr %{{.*}}, %[[EXT2:.*]]: !llvm.i32
+	%3 = spv.ShiftRightLogical %arg0, %arg3 : i32, ui16
+	return
+}
+
+func @shift_right_logical_vector(%arg0: vector<4xi64>, %arg1: vector<4xsi64>, %arg2: vector<4xi32>, %arg3: vector<4xui32>) {
+	// CHECK: %{{.*}} = llvm.lshr %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
+	%0 = spv.ShiftRightLogical %arg0, %arg0 : vector<4xi64>, vector<4xi64>
+
+    // CHECK: %{{.*}} = llvm.lshr %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
+	%1 = spv.ShiftRightLogical %arg0, %arg1 : vector<4xi64>, vector<4xsi64>
+
+    // CHECK: %[[EXT1:.*]] = llvm.sext %{{.*}} : !llvm<"<4 x i32>"> to !llvm<"<4 x i64>">
+    // CHECK: %{{.*}} = llvm.lshr %{{.*}}, %[[EXT1:.*]]: !llvm<"<4 x i64>">
+	%2 = spv.ShiftRightLogical %arg0, %arg2 : vector<4xi64>,  vector<4xi32>
+
+    // CHECK: %[[EXT2:.*]] = llvm.zext %{{.*}} : !llvm<"<4 x i32>"> to !llvm<"<4 x i64>">
+    // CHECK: %{{.*}} = llvm.lshr %{{.*}}, %[[EXT2:.*]]: !llvm<"<4 x i64>">
+	%3 = spv.ShiftRightLogical %arg0, %arg3 : vector<4xi64>, vector<4xui32>
+	return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ShiftLeftLogical
+//===----------------------------------------------------------------------===//
+
+func @shift_left_logical_scalar(%arg0: i32, %arg1: si32, %arg2 : i16, %arg3 : ui16) {
+	// CHECK: %{{.*}} = llvm.shl %{{.*}}, %{{.*}} : !llvm.i32
+	%0 = spv.ShiftLeftLogical %arg0, %arg0 : i32, i32
+
+    // CHECK: %{{.*}} = llvm.shl %{{.*}}, %{{.*}} : !llvm.i32
+	%1 = spv.ShiftLeftLogical %arg0, %arg1 : i32, si32
+
+    // CHECK: %[[EXT1:.*]] = llvm.sext %{{.*}} : !llvm.i16 to !llvm.i32
+    // CHECK: %{{.*}} = llvm.shl %{{.*}}, %[[EXT1:.*]]: !llvm.i32
+	%2 = spv.ShiftLeftLogical %arg0, %arg2 : i32, i16
+
+    // CHECK: %[[EXT2:.*]] = llvm.zext %{{.*}} : !llvm.i16 to !llvm.i32
+    // CHECK: %{{.*}} = llvm.shl %{{.*}}, %[[EXT2:.*]]: !llvm.i32
+	%3 = spv.ShiftLeftLogical %arg0, %arg3 : i32, ui16
+	return
+}
+
+func @shift_left_logical_vector(%arg0: vector<4xi64>, %arg1: vector<4xsi64>, %arg2: vector<4xi32>, %arg3: vector<4xui32>) {
+	// CHECK: %{{.*}} = llvm.shl %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
+	%0 = spv.ShiftLeftLogical %arg0, %arg0 : vector<4xi64>, vector<4xi64>
+
+    // CHECK: %{{.*}} = llvm.shl %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
+	%1 = spv.ShiftLeftLogical %arg0, %arg1 : vector<4xi64>, vector<4xsi64>
+
+    // CHECK: %[[EXT1:.*]] = llvm.sext %{{.*}} : !llvm<"<4 x i32>"> to !llvm<"<4 x i64>">
+    // CHECK: %{{.*}} = llvm.shl %{{.*}}, %[[EXT1:.*]]: !llvm<"<4 x i64>">
+	%2 = spv.ShiftLeftLogical %arg0, %arg2 : vector<4xi64>,  vector<4xi32>
+
+    // CHECK: %[[EXT2:.*]] = llvm.zext %{{.*}} : !llvm<"<4 x i32>"> to !llvm<"<4 x i64>">
+    // CHECK: %{{.*}} = llvm.shl %{{.*}}, %[[EXT2:.*]]: !llvm<"<4 x i64>">
+	%3 = spv.ShiftLeftLogical %arg0, %arg3 : vector<4xi64>, vector<4xui32>
+	return
+}


        


More information about the Mlir-commits mailing list