[Mlir-commits] [mlir] 7a4e39b - [MLIR][SPIRVToLLVM] Implementation of spv.BitFieldSExtract and spv.BitFieldUExtract patterns
George Mitenkov
llvmlistbot at llvm.org
Wed Jul 8 02:38:42 PDT 2020
Author: George Mitenkov
Date: 2020-07-08T12:37:37+03:00
New Revision: 7a4e39b326d0cc69e6b4fbe9010aaf5dc704a12f
URL: https://github.com/llvm/llvm-project/commit/7a4e39b326d0cc69e6b4fbe9010aaf5dc704a12f
DIFF: https://github.com/llvm/llvm-project/commit/7a4e39b326d0cc69e6b4fbe9010aaf5dc704a12f.diff
LOG: [MLIR][SPIRVToLLVM] Implementation of spv.BitFieldSExtract and spv.BitFieldUExtract patterns
This patch adds conversion patterns for `spv.BitFieldSExtract` and `spv.BitFieldUExtract`.
As in the patch for `spv.BitFieldInsert`, `offset` and `count` have to be broadcasted in
vector case and casted to match the type of the base.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D82640
Added:
Modified:
mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index 202fba592b86..52fa2091389e 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -119,7 +119,7 @@ static Value optionallyTruncateOrExtend(Location loc, Value value, Type dstType,
return value;
}
-/// Broadcasts the value to vector with `numElements` number of elements
+/// Broadcasts the value to vector with `numElements` number of elements.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter) {
@@ -136,6 +136,35 @@ static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
return broadcasted;
}
+/// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
+static Value optionallyBroadcast(Location loc, Value value, Type srcType,
+ LLVMTypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter) {
+ if (auto vectorType = srcType.dyn_cast<VectorType>()) {
+ unsigned numElements = vectorType.getNumElements();
+ return broadcast(loc, value, numElements, typeConverter, rewriter);
+ }
+ return value;
+}
+
+/// Utility function for bitfiled ops: `BitFieldInsert`, `BitFieldSExtract` and
+/// `BitFieldUExtract`.
+/// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
+/// a vector type, construct a vector that has:
+/// - same number of elements as `Base`
+/// - each element has the type that is the same as the type of `Offset` or
+/// `Count`
+/// - each element has the same value as `Offset` or `Count`
+/// Then cast `Offset` and `Count` if their bit width is different
+/// from `Base` bit width.
+static Value processCountOrOffset(Location loc, Value value, Type srcType,
+ Type dstType, LLVMTypeConverter &converter,
+ ConversionPatternRewriter &rewriter) {
+ Value broadcasted =
+ optionallyBroadcast(loc, value, srcType, converter, rewriter);
+ return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
+}
+
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
@@ -156,41 +185,20 @@ class BitFieldInsertPattern
return failure();
Location loc = op.getLoc();
- // Broadcast `Offset` and `Count` to match the type of `Base` and `Insert`.
- // If `Base` is of a vector type, construct a vector that has:
- // - same number of elements as `Base`
- // - each element has the type that is the same as the type of `Offset` or
- // `Count`
- // - each element has the same value as `Offset` or `Count`
- Value offset;
- Value count;
- if (auto vectorType = srcType.dyn_cast<VectorType>()) {
- unsigned numElements = vectorType.getNumElements();
- offset =
- broadcast(loc, op.offset(), numElements, typeConverter, rewriter);
- count = broadcast(loc, op.count(), numElements, typeConverter, rewriter);
- } else {
- offset = op.offset();
- count = op.count();
- }
-
- // Create a mask with all bits set of the same type as `srcType`
- Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
-
- // Need to cast `Offset` and `Count` if their bit width is different
- // from `Base` bit width.
- Value optionallyCastedCount =
- optionallyTruncateOrExtend(loc, count, dstType, rewriter);
- Value optionallyCastedOffset =
- optionallyTruncateOrExtend(loc, offset, dstType, rewriter);
+ // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
+ Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
+ typeConverter, rewriter);
+ Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
+ typeConverter, rewriter);
// Create a mask with bits set outside [Offset, Offset + Count - 1].
- Value maskShiftedByCount = rewriter.create<LLVM::ShlOp>(
- loc, dstType, minusOne, optionallyCastedCount);
+ Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
+ Value maskShiftedByCount =
+ rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
maskShiftedByCount, minusOne);
- Value maskShiftedByCountAndOffset = rewriter.create<LLVM::ShlOp>(
- loc, dstType, negated, optionallyCastedOffset);
+ Value maskShiftedByCountAndOffset =
+ rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
Value mask = rewriter.create<LLVM::XOrOp>(
loc, dstType, maskShiftedByCountAndOffset, minusOne);
@@ -198,8 +206,8 @@ class BitFieldInsertPattern
// [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
Value baseAndMask =
rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask);
- Value insertShiftedByOffset = rewriter.create<LLVM::ShlOp>(
- loc, dstType, op.insert(), optionallyCastedOffset);
+ Value insertShiftedByOffset =
+ rewriter.create<LLVM::ShlOp>(loc, dstType, op.insert(), offset);
rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
insertShiftedByOffset);
return success();
@@ -252,6 +260,94 @@ class ConstantScalarAndVectorPattern
}
};
+class BitFieldSExtractPattern
+ : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
+public:
+ using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;
+
+ LogicalResult
+ matchAndRewrite(spirv::BitFieldSExtractOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto srcType = op.getType();
+ auto dstType = this->typeConverter.convertType(srcType);
+ if (!dstType)
+ return failure();
+ Location loc = op.getLoc();
+
+ // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
+ Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
+ typeConverter, rewriter);
+ Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
+ typeConverter, rewriter);
+
+ // Create a constant that holds the size of the `Base`.
+ IntegerType integerType;
+ if (auto vecType = srcType.dyn_cast<VectorType>())
+ integerType = vecType.getElementType().cast<IntegerType>();
+ else
+ integerType = srcType.cast<IntegerType>();
+
+ auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
+ Value size =
+ srcType.isa<VectorType>()
+ ? rewriter.create<LLVM::ConstantOp>(
+ loc, dstType,
+ SplatElementsAttr::get(srcType.cast<ShapedType>(), baseSize))
+ : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
+
+ // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
+ // at Offset + Count - 1 is the most significant bit now.
+ Value countPlusOffset =
+ rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
+ Value amountToShiftLeft =
+ rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
+ Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
+ loc, dstType, op.base(), amountToShiftLeft);
+
+ // Shift the result right, filling the bits with the sign bit.
+ Value amountToShiftRight =
+ rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
+ rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
+ amountToShiftRight);
+ return success();
+ }
+};
+
+class BitFieldUExtractPattern
+ : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
+public:
+ using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;
+
+ LogicalResult
+ matchAndRewrite(spirv::BitFieldUExtractOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto srcType = op.getType();
+ auto dstType = this->typeConverter.convertType(srcType);
+ if (!dstType)
+ return failure();
+ Location loc = op.getLoc();
+
+ // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
+ Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
+ typeConverter, rewriter);
+ Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
+ typeConverter, rewriter);
+
+ // Create a mask with bits set at [0, Count - 1].
+ Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
+ Value maskShiftedByCount =
+ rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
+ Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
+ minusOne);
+
+ // Shift `Base` by `Offset` and apply the mask on it.
+ Value shiftedBase =
+ rewriter.create<LLVM::LShrOp>(loc, dstType, op.base(), offset);
+ rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
+ return success();
+ }
+};
+
/// Converts SPIR-V operations that have straightforward LLVM equivalent
/// into LLVM dialect operations.
template <typename SPIRVOp, typename LLVMOp>
@@ -586,7 +682,7 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
// Bitwise ops
- BitFieldInsertPattern,
+ BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
diff --git a/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir
index e4958e3a43f0..31ffc6dbf7dc 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir
@@ -54,9 +54,9 @@ func @bitfield_insert_scalar_same_bit_width(%base: i32, %insert: i32, %offset: i
// CHECK-LABEL: func @bitfield_insert_scalar_smaller_bit_width
// CHECK-SAME: %[[BASE:.*]]: !llvm.i64, %[[INSERT:.*]]: !llvm.i64, %[[OFFSET:.*]]: !llvm.i8, %[[COUNT:.*]]: !llvm.i8
func @bitfield_insert_scalar_smaller_bit_width(%base: i64, %insert: i64, %offset: i8, %count: i8) {
- // CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(-1 : i64) : !llvm.i64
- // CHECK: %[[EXT_COUNT:.*]] = llvm.zext %[[COUNT]] : !llvm.i8 to !llvm.i64
// CHECK: %[[EXT_OFFSET:.*]] = llvm.zext %[[OFFSET]] : !llvm.i8 to !llvm.i64
+ // CHECK: %[[EXT_COUNT:.*]] = llvm.zext %[[COUNT]] : !llvm.i8 to !llvm.i64
+ // CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(-1 : i64) : !llvm.i64
// CHECK: %[[T0:.*]] = llvm.shl %[[MINUS_ONE]], %[[EXT_COUNT]] : !llvm.i64
// CHECK: %[[T1:.*]] = llvm.xor %[[T0]], %[[MINUS_ONE]] : !llvm.i64
// CHECK: %[[T2:.*]] = llvm.shl %[[T1]], %[[EXT_OFFSET]] : !llvm.i64
@@ -71,9 +71,9 @@ func @bitfield_insert_scalar_smaller_bit_width(%base: i64, %insert: i64, %offset
// CHECK-LABEL: func @bitfield_insert_scalar_greater_bit_width
// CHECK-SAME: %[[BASE:.*]]: !llvm.i16, %[[INSERT:.*]]: !llvm.i16, %[[OFFSET:.*]]: !llvm.i32, %[[COUNT:.*]]: !llvm.i64
func @bitfield_insert_scalar_greater_bit_width(%base: i16, %insert: i16, %offset: i32, %count: i64) {
- // CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(-1 : i16) : !llvm.i16
- // CHECK: %[[TRUNC_COUNT:.*]] = llvm.trunc %[[COUNT]] : !llvm.i64 to !llvm.i16
// CHECK: %[[TRUNC_OFFSET:.*]] = llvm.trunc %[[OFFSET]] : !llvm.i32 to !llvm.i16
+ // CHECK: %[[TRUNC_COUNT:.*]] = llvm.trunc %[[COUNT]] : !llvm.i64 to !llvm.i16
+ // CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(-1 : i16) : !llvm.i16
// CHECK: %[[T0:.*]] = llvm.shl %[[MINUS_ONE]], %[[TRUNC_COUNT]] : !llvm.i16
// CHECK: %[[T1:.*]] = llvm.xor %[[T0]], %[[MINUS_ONE]] : !llvm.i16
// CHECK: %[[T2:.*]] = llvm.shl %[[T1]], %[[TRUNC_OFFSET]] : !llvm.i16
@@ -110,6 +110,141 @@ func @bitfield_insert_vector(%base: vector<2xi32>, %insert: vector<2xi32>, %offs
return
}
+//===----------------------------------------------------------------------===//
+// spv.BitFieldSExtract
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @bitfield_sextract_scalar_same_bit_width
+// CHECK-SAME: %[[BASE:.*]]: !llvm.i64, %[[OFFSET:.*]]: !llvm.i64, %[[COUNT:.*]]: !llvm.i64
+func @bitfield_sextract_scalar_same_bit_width(%base: i64, %offset: i64, %count: i64) {
+ // CHECK: %[[SIZE:.]] = llvm.mlir.constant(64 : i64) : !llvm.i64
+ // CHECK: %[[T0:.*]] = llvm.add %[[COUNT]], %[[OFFSET]] : !llvm.i64
+ // CHECK: %[[T1:.*]] = llvm.sub %[[SIZE]], %[[T0]] : !llvm.i64
+ // CHECK: %[[SHIFTED_LEFT:.*]] = llvm.shl %[[BASE]], %[[T1]] : !llvm.i64
+ // CHECK: %[[T2:.*]] = llvm.add %[[OFFSET]], %[[T1]] : !llvm.i64
+ // CHECK: %{{.*}} = llvm.ashr %[[SHIFTED_LEFT]], %[[T2]] : !llvm.i64
+ %0 = spv.BitFieldSExtract %base, %offset, %count : i64, i64, i64
+ return
+}
+
+// CHECK-LABEL: func @bitfield_sextract_scalar_smaller_bit_width
+// CHECK-SAME: %[[BASE:.*]]: !llvm.i32, %[[OFFSET:.*]]: !llvm.i8, %[[COUNT:.*]]: !llvm.i8
+func @bitfield_sextract_scalar_smaller_bit_width(%base: i32, %offset: i8, %count: i8) {
+ // CHECK: %[[EXT_OFFSET:.*]] = llvm.zext %[[OFFSET]] : !llvm.i8 to !llvm.i32
+ // CHECK: %[[EXT_COUNT:.*]] = llvm.zext %[[COUNT]] : !llvm.i8 to !llvm.i32
+ // CHECK: %[[SIZE:.]] = llvm.mlir.constant(32 : i32) : !llvm.i32
+ // CHECK: %[[T0:.*]] = llvm.add %[[EXT_COUNT]], %[[EXT_OFFSET]] : !llvm.i32
+ // CHECK: %[[T1:.*]] = llvm.sub %[[SIZE]], %[[T0]] : !llvm.i32
+ // CHECK: %[[SHIFTED_LEFT:.*]] = llvm.shl %[[BASE]], %[[T1]] : !llvm.i32
+ // CHECK: %[[T2:.*]] = llvm.add %[[EXT_OFFSET]], %[[T1]] : !llvm.i32
+ // CHECK: %{{.*}} = llvm.ashr %[[SHIFTED_LEFT]], %[[T2]] : !llvm.i32
+ %0 = spv.BitFieldSExtract %base, %offset, %count : i32, i8, i8
+ return
+}
+
+// CHECK-LABEL: func @bitfield_sextract_scalar_greater_bit_width
+// CHECK-SAME: %[[BASE:.*]]: !llvm.i32, %[[OFFSET:.*]]: !llvm.i64, %[[COUNT:.*]]: !llvm.i64
+func @bitfield_sextract_scalar_greater_bit_width(%base: i32, %offset: i64, %count: i64) {
+ // CHECK: %[[TRUNC_OFFSET:.*]] = llvm.trunc %[[OFFSET]] : !llvm.i64 to !llvm.i32
+ // CHECK: %[[TRUNC_COUNT:.*]] = llvm.trunc %[[COUNT]] : !llvm.i64 to !llvm.i32
+ // CHECK: %[[SIZE:.]] = llvm.mlir.constant(32 : i32) : !llvm.i32
+ // CHECK: %[[T0:.*]] = llvm.add %[[TRUNC_COUNT]], %[[TRUNC_OFFSET]] : !llvm.i32
+ // CHECK: %[[T1:.*]] = llvm.sub %[[SIZE]], %[[T0]] : !llvm.i32
+ // CHECK: %[[SHIFTED_LEFT:.*]] = llvm.shl %[[BASE]], %[[T1]] : !llvm.i32
+ // CHECK: %[[T2:.*]] = llvm.add %[[TRUNC_OFFSET]], %[[T1]] : !llvm.i32
+ // CHECK: %{{.*}} = llvm.ashr %[[SHIFTED_LEFT]], %[[T2]] : !llvm.i32
+ %0 = spv.BitFieldSExtract %base, %offset, %count : i32, i64, i64
+ return
+}
+
+// CHECK-LABEL: func @bitfield_sextract_vector
+// CHECK-SAME: %[[BASE:.*]]: !llvm<"<2 x i32>">, %[[OFFSET:.*]]: !llvm.i32, %[[COUNT:.*]]: !llvm.i32
+func @bitfield_sextract_vector(%base: vector<2xi32>, %offset: i32, %count: i32) {
+ // CHECK: %[[OFFSET_V0:.*]] = llvm.mlir.undef : !llvm<"<2 x i32>">
+ // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+ // CHECK: %[[OFFSET_V1:.*]] = llvm.insertelement %[[OFFSET]], %[[OFFSET_V0]][%[[ZERO]] : !llvm.i32] : !llvm<"<2 x i32>">
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+ // CHECK: %[[OFFSET_V2:.*]] = llvm.insertelement %[[OFFSET]], %[[OFFSET_V1]][%[[ONE]] : !llvm.i32] : !llvm<"<2 x i32>">
+ // CHECK: %[[COUNT_V0:.*]] = llvm.mlir.undef : !llvm<"<2 x i32>">
+ // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+ // CHECK: %[[COUNT_V1:.*]] = llvm.insertelement %[[COUNT]], %[[COUNT_V0]][%[[ZERO]] : !llvm.i32] : !llvm<"<2 x i32>">
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+ // CHECK: %[[COUNT_V2:.*]] = llvm.insertelement %[[COUNT]], %[[COUNT_V1]][%[[ONE]] : !llvm.i32] : !llvm<"<2 x i32>">
+ // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(dense<32> : vector<2xi32>) : !llvm<"<2 x i32>">
+ // CHECK: %[[T0:.*]] = llvm.add %[[COUNT_V2]], %[[OFFSET_V2]] : !llvm<"<2 x i32>">
+ // CHECK: %[[T1:.*]] = llvm.sub %[[SIZE]], %[[T0]] : !llvm<"<2 x i32>">
+ // CHECK: %[[SHIFTED_LEFT:.*]] = llvm.shl %[[BASE]], %[[T1]] : !llvm<"<2 x i32>">
+ // CHECK: %[[T2:.*]] = llvm.add %[[OFFSET_V2]], %[[T1]] : !llvm<"<2 x i32>">
+ // CHECK: %{{.*}} = llvm.ashr %[[SHIFTED_LEFT]], %[[T2]] : !llvm<"<2 x i32>">
+ %0 = spv.BitFieldSExtract %base, %offset, %count : vector<2xi32>, i32, i32
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.BitFieldUExtract
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @bitfield_uextract_scalar_same_bit_width
+// CHECK-SAME: %[[BASE:.*]]: !llvm.i32, %[[OFFSET:.*]]: !llvm.i32, %[[COUNT:.*]]: !llvm.i32
+func @bitfield_uextract_scalar_same_bit_width(%base: i32, %offset: i32, %count: i32) {
+ // CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(-1 : i32) : !llvm.i32
+ // CHECK: %[[T0:.*]] = llvm.shl %[[MINUS_ONE]], %[[COUNT]] : !llvm.i32
+ // CHECK: %[[MASK:.*]] = llvm.xor %[[T0]], %[[MINUS_ONE]] : !llvm.i32
+ // CHECK: %[[SHIFTED_BASE:.*]] = llvm.lshr %[[BASE]], %[[OFFSET]] : !llvm.i32
+ // CHECK: %{{.*}} = llvm.and %[[SHIFTED_BASE]], %[[MASK]] : !llvm.i32
+ %0 = spv.BitFieldUExtract %base, %offset, %count : i32, i32, i32
+ return
+}
+
+// CHECK-LABEL: func @bitfield_uextract_scalar_smaller_bit_width
+// CHECK-SAME: %[[BASE:.*]]: !llvm.i32, %[[OFFSET:.*]]: !llvm.i16, %[[COUNT:.*]]: !llvm.i8
+func @bitfield_uextract_scalar_smaller_bit_width(%base: i32, %offset: i16, %count: i8) {
+ // CHECK: %[[EXT_OFFSET:.*]] = llvm.zext %[[OFFSET]] : !llvm.i16 to !llvm.i32
+ // CHECK: %[[EXT_COUNT:.*]] = llvm.zext %[[COUNT]] : !llvm.i8 to !llvm.i32
+ // CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(-1 : i32) : !llvm.i32
+ // CHECK: %[[T0:.*]] = llvm.shl %[[MINUS_ONE]], %[[EXT_COUNT]] : !llvm.i32
+ // CHECK: %[[MASK:.*]] = llvm.xor %[[T0]], %[[MINUS_ONE]] : !llvm.i32
+ // CHECK: %[[SHIFTED_BASE:.*]] = llvm.lshr %[[BASE]], %[[EXT_OFFSET]] : !llvm.i32
+ // CHECK: %{{.*}} = llvm.and %[[SHIFTED_BASE]], %[[MASK]] : !llvm.i32
+ %0 = spv.BitFieldUExtract %base, %offset, %count : i32, i16, i8
+ return
+}
+
+// CHECK-LABEL: func @bitfield_uextract_scalar_greater_bit_width
+// CHECK-SAME: %[[BASE:.*]]: !llvm.i8, %[[OFFSET:.*]]: !llvm.i16, %[[COUNT:.*]]: !llvm.i8
+func @bitfield_uextract_scalar_greater_bit_width(%base: i8, %offset: i16, %count: i8) {
+ // CHECK: %[[TRUNC_OFFSET:.*]] = llvm.trunc %[[OFFSET]] : !llvm.i16 to !llvm.i8
+ // CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(-1 : i8) : !llvm.i8
+ // CHECK: %[[T0:.*]] = llvm.shl %[[MINUS_ONE]], %[[COUNT]] : !llvm.i8
+ // CHECK: %[[MASK:.*]] = llvm.xor %[[T0]], %[[MINUS_ONE]] : !llvm.i8
+ // CHECK: %[[SHIFTED_BASE:.*]] = llvm.lshr %[[BASE]], %[[TRUNC_OFFSET]] : !llvm.i8
+ // CHECK: %{{.*}} = llvm.and %[[SHIFTED_BASE]], %[[MASK]] : !llvm.i8
+ %0 = spv.BitFieldUExtract %base, %offset, %count : i8, i16, i8
+ return
+}
+
+// CHECK-LABEL: func @bitfield_uextract_vector
+// CHECK-SAME: %[[BASE:.*]]: !llvm<"<2 x i32>">, %[[OFFSET:.*]]: !llvm.i32, %[[COUNT:.*]]: !llvm.i32
+func @bitfield_uextract_vector(%base: vector<2xi32>, %offset: i32, %count: i32) {
+ // CHECK: %[[OFFSET_V0:.*]] = llvm.mlir.undef : !llvm<"<2 x i32>">
+ // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+ // CHECK: %[[OFFSET_V1:.*]] = llvm.insertelement %[[OFFSET]], %[[OFFSET_V0]][%[[ZERO]] : !llvm.i32] : !llvm<"<2 x i32>">
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+ // CHECK: %[[OFFSET_V2:.*]] = llvm.insertelement %[[OFFSET]], %[[OFFSET_V1]][%[[ONE]] : !llvm.i32] : !llvm<"<2 x i32>">
+ // CHECK: %[[COUNT_V0:.*]] = llvm.mlir.undef : !llvm<"<2 x i32>">
+ // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+ // CHECK: %[[COUNT_V1:.*]] = llvm.insertelement %[[COUNT]], %[[COUNT_V0]][%[[ZERO]] : !llvm.i32] : !llvm<"<2 x i32>">
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+ // CHECK: %[[COUNT_V2:.*]] = llvm.insertelement %[[COUNT]], %[[COUNT_V1]][%[[ONE]] : !llvm.i32] : !llvm<"<2 x i32>">
+ // CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(dense<-1> : vector<2xi32>) : !llvm<"<2 x i32>">
+ // CHECK: %[[T0:.*]] = llvm.shl %[[MINUS_ONE]], %[[COUNT_V2]] : !llvm<"<2 x i32>">
+ // CHECK: %[[MASK:.*]] = llvm.xor %[[T0]], %[[MINUS_ONE]] : !llvm<"<2 x i32>">
+ // CHECK: %[[SHIFTED_BASE:.*]] = llvm.lshr %[[BASE]], %[[OFFSET_V2]] : !llvm<"<2 x i32>">
+ // CHECK: %{{.*}} = llvm.and %[[SHIFTED_BASE]], %[[MASK]] : !llvm<"<2 x i32>">
+ %0 = spv.BitFieldUExtract %base, %offset, %count : vector<2xi32>, i32, i32
+ return
+}
+
//===----------------------------------------------------------------------===//
// spv.BitwiseAnd
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list