[Mlir-commits] [mlir] 03fe7eb - [MLIR][SPIRVToLLVM] Implementation of spv.BitFieldInsert pattern

Lei Zhang llvmlistbot at llvm.org
Thu Jul 2 09:23:07 PDT 2020


Author: George Mitenkov
Date: 2020-07-02T12:19:12-04:00
New Revision: 03fe7eb16fa224a95d4ba252e2a03cbb3fa244af

URL: https://github.com/llvm/llvm-project/commit/03fe7eb16fa224a95d4ba252e2a03cbb3fa244af
DIFF: https://github.com/llvm/llvm-project/commit/03fe7eb16fa224a95d4ba252e2a03cbb3fa244af.diff

LOG: [MLIR][SPIRVToLLVM] Implementation of spv.BitFieldInsert pattern

This patch introduces a conversion pattern for the `spv.BitFieldInsert`
op, as well as some utility functions that make the code easier to read.
Since `spv.BitFieldInsert` may take a vector as its `base` and `insert`
operands while `offset` and `count` remain scalars, that case is handled
by broadcasting `count` and `offset` to vectors. Moreover, the operand
types have to be brought to the same bit width to conform with LLVM
dialect rules. This is done with `zext` when extending (note that
`count` and `offset` are treated as unsigned) and with `trunc` in the
opposite case. Truncation is safe, since the op is only defined when
`count`, `offset`, and their sum are less than the bit width of the
result. This gives a natural bound of 64, which fits in `i8`.
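
For illustration, here is a scalar C++ model of the op sequence the
pattern emits (the function and variable names are mine, for
illustration only; it assumes `offset + count` stays within the bit
width, as the op requires):

    #include <cstdint>

    // Hypothetical scalar model of the emitted LLVM dialect ops for an
    // i32 base. Each statement mirrors one op created by the pattern.
    uint32_t bitFieldInsertModel(uint32_t base, uint32_t insert,
                                 uint32_t offset, uint32_t count) {
      uint32_t minusOne = ~0u;                          // llvm.mlir.constant(-1)
      uint32_t maskShiftedByCount = minusOne << count;  // llvm.shl
      uint32_t negated = maskShiftedByCount ^ minusOne; // low `count` bits set
      uint32_t shifted = negated << offset;             // field bits set
      uint32_t mask = shifted ^ minusOne;               // bits outside the field
      uint32_t baseAndMask = base & mask;               // keep bits outside field
      uint32_t shiftedInsert = insert << offset;        // llvm.shl
      return baseAndMask | shiftedInsert;               // llvm.or
    }

For example, `bitFieldInsertModel(0xFFFF0000, 0xAB, 8, 8)` preserves all
base bits outside positions [8, 15] and yields `0xFFFFAB00`.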

Reviewed By: antiagainst, ftynse

Differential Revision: https://reviews.llvm.org/D82639

Added: 
    

Modified: 
    mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
    mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index 771610e45b7d..1681fed23487 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -53,7 +53,15 @@ static unsigned getBitWidth(Type type) {
   return elementType.getIntOrFloatBitWidth();
 }
 
-/// Creates `IntegerAttribute` with all bits set for given type.
+/// Returns the bit width of an `LLVMType` integer or vector of integers.
+static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) {
+  return type.isVectorTy() ? type.getVectorElementType()
+                                 .getUnderlyingType()
+                                 ->getIntegerBitWidth()
+                           : type.getUnderlyingType()->getIntegerBitWidth();
+}
+
+/// Creates `IntegerAttribute` with all bits set for the given type.
 IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
   if (auto vecType = type.dyn_cast<VectorType>()) {
     auto integerType = vecType.getElementType().cast<IntegerType>();
@@ -63,12 +71,132 @@ IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
   return builder.getIntegerAttr(integerType, -1);
 }
 
+/// Creates `llvm.mlir.constant` with all bits set for the given type.
+static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
+                                      PatternRewriter &rewriter) {
+  if (srcType.isa<VectorType>())
+    return rewriter.create<LLVM::ConstantOp>(
+        loc, dstType,
+        SplatElementsAttr::get(srcType.cast<ShapedType>(),
+                               minusOneIntegerAttribute(srcType, rewriter)));
+  return rewriter.create<LLVM::ConstantOp>(
+      loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
+}
+
+/// Utility function for bit field ops:
+///   - `BitFieldInsert`
+///   - `BitFieldSExtract`
+///   - `BitFieldUExtract`
+/// Truncates or extends the value. If the bitwidth of the value is the same as
+/// `dstType` bitwidth, the value remains unchanged.
+static Value optionallyTruncateOrExtend(Location loc, Value value, Type dstType,
+                                        PatternRewriter &rewriter) {
+  auto srcType = value.getType();
+  auto llvmType = dstType.cast<LLVM::LLVMType>();
+  unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
+  unsigned valueBitWidth =
+      srcType.isa<LLVM::LLVMType>()
+          ? getLLVMTypeBitWidth(srcType.cast<LLVM::LLVMType>())
+          : getBitWidth(srcType);
+
+  if (valueBitWidth < targetBitWidth)
+    return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
+  // If the bit widths of `Count` and `Offset` are greater than the bit width
+  // of the target type, they are truncated. Truncation is safe since `Count`
+  // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
+  // both values can be expressed in 8 bits.
+  if (valueBitWidth > targetBitWidth)
+    return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
+  return value;
+}
+
+/// Broadcasts the value to a vector with `numElements` elements.
+static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
+                       LLVMTypeConverter &typeConverter,
+                       ConversionPatternRewriter &rewriter) {
+  auto vectorType = VectorType::get(numElements, toBroadcast.getType());
+  auto llvmVectorType = typeConverter.convertType(vectorType);
+  auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
+  Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
+  for (unsigned i = 0; i < numElements; ++i) {
+    auto index = rewriter.create<LLVM::ConstantOp>(
+        loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
+    broadcasted = rewriter.create<LLVM::InsertElementOp>(
+        loc, llvmVectorType, broadcasted, toBroadcast, index);
+  }
+  return broadcasted;
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
 
 namespace {
 
+class BitFieldInsertPattern
+    : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::BitFieldInsertOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcType = op.getType();
+    auto dstType = this->typeConverter.convertType(srcType);
+    if (!dstType)
+      return failure();
+    Location loc = op.getLoc();
+
+    // Broadcast `Offset` and `Count` to match the type of `Base` and `Insert`.
+    // If `Base` is of a vector type, construct a vector that has:
+    //  - same number of elements as `Base`
+    //  - each element has the type that is the same as the type of `Offset` or
+    //    `Count`
+    //  - each element has the same value as `Offset` or `Count`
+    Value offset;
+    Value count;
+    if (auto vectorType = srcType.dyn_cast<VectorType>()) {
+      unsigned numElements = vectorType.getNumElements();
+      offset =
+          broadcast(loc, op.offset(), numElements, typeConverter, rewriter);
+      count = broadcast(loc, op.count(), numElements, typeConverter, rewriter);
+    } else {
+      offset = op.offset();
+      count = op.count();
+    }
+
+    // Create a mask with all bits set, of the same type as `srcType`.
+    Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
+
+    // Need to cast `Offset` and `Count` if their bit width is different
+    // from `Base` bit width.
+    Value optionallyCastedCount =
+        optionallyTruncateOrExtend(loc, count, dstType, rewriter);
+    Value optionallyCastedOffset =
+        optionallyTruncateOrExtend(loc, offset, dstType, rewriter);
+
+    // Create a mask with bits set outside [Offset, Offset + Count - 1].
+    Value maskShiftedByCount = rewriter.create<LLVM::ShlOp>(
+        loc, dstType, minusOne, optionallyCastedCount);
+    Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
+                                                 maskShiftedByCount, minusOne);
+    Value maskShiftedByCountAndOffset = rewriter.create<LLVM::ShlOp>(
+        loc, dstType, negated, optionallyCastedOffset);
+    Value mask = rewriter.create<LLVM::XOrOp>(
+        loc, dstType, maskShiftedByCountAndOffset, minusOne);
+
+    // Extract the unchanged bits from `Base` that are outside of
+    // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
+    Value baseAndMask =
+        rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask);
+    Value insertShiftedByOffset = rewriter.create<LLVM::ShlOp>(
+        loc, dstType, op.insert(), optionallyCastedOffset);
+    rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
+                                            insertShiftedByOffset);
+    return success();
+  }
+};
+
 /// Converts SPIR-V operations that have straightforward LLVM equivalent
 /// into LLVM dialect operations.
 template <typename SPIRVOp, typename LLVMOp>
@@ -380,6 +508,7 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
       DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
 
       // Bitwise ops
+      BitFieldInsertPattern,
       DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
       DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
       DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,

diff --git a/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir
index 434430d660fd..e4958e3a43f0 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir
@@ -32,6 +32,84 @@ func @bitreverse_vector(%arg0: vector<4xi32>) {
 	return
 }
 
+//===----------------------------------------------------------------------===//
+// spv.BitFieldInsert
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @bitfield_insert_scalar_same_bit_width
+// CHECK-SAME: %[[BASE:.*]]: !llvm.i32, %[[INSERT:.*]]: !llvm.i32, %[[OFFSET:.*]]: !llvm.i32, %[[COUNT:.*]]: !llvm.i32
+func @bitfield_insert_scalar_same_bit_width(%base: i32, %insert: i32, %offset: i32, %count: i32) {
+    // CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(-1 : i32) : !llvm.i32
+    // CHECK: %[[T0:.*]] = llvm.shl %[[MINUS_ONE]], %[[COUNT]] : !llvm.i32
+    // CHECK: %[[T1:.*]] = llvm.xor %[[T0]], %[[MINUS_ONE]] : !llvm.i32
+    // CHECK: %[[T2:.*]] = llvm.shl %[[T1]], %[[OFFSET]] : !llvm.i32
+    // CHECK: %[[MASK:.*]] = llvm.xor %[[T2]], %[[MINUS_ONE]] : !llvm.i32
+    // CHECK: %[[NEW_BASE:.*]] = llvm.and %[[BASE]], %[[MASK]] : !llvm.i32
+    // CHECK: %[[SHIFTED_INSERT:.*]] = llvm.shl %[[INSERT]], %[[OFFSET]] : !llvm.i32
+    // CHECK: %{{.*}} = llvm.or %[[NEW_BASE]], %[[SHIFTED_INSERT]] : !llvm.i32
+    %0 = spv.BitFieldInsert %base, %insert, %offset, %count : i32, i32, i32
+    return
+}
+
+// CHECK-LABEL: func @bitfield_insert_scalar_smaller_bit_width
+// CHECK-SAME: %[[BASE:.*]]: !llvm.i64, %[[INSERT:.*]]: !llvm.i64, %[[OFFSET:.*]]: !llvm.i8, %[[COUNT:.*]]: !llvm.i8
+func @bitfield_insert_scalar_smaller_bit_width(%base: i64, %insert: i64, %offset: i8, %count: i8) {
+    // CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(-1 : i64) : !llvm.i64
+    // CHECK: %[[EXT_COUNT:.*]] = llvm.zext %[[COUNT]] : !llvm.i8 to !llvm.i64
+    // CHECK: %[[EXT_OFFSET:.*]] = llvm.zext %[[OFFSET]] : !llvm.i8 to !llvm.i64
+    // CHECK: %[[T0:.*]] = llvm.shl %[[MINUS_ONE]], %[[EXT_COUNT]] : !llvm.i64
+    // CHECK: %[[T1:.*]] = llvm.xor %[[T0]], %[[MINUS_ONE]] : !llvm.i64
+    // CHECK: %[[T2:.*]] = llvm.shl %[[T1]], %[[EXT_OFFSET]] : !llvm.i64
+    // CHECK: %[[MASK:.*]] = llvm.xor %[[T2]], %[[MINUS_ONE]] : !llvm.i64
+    // CHECK: %[[NEW_BASE:.*]] = llvm.and %[[BASE]], %[[MASK]] : !llvm.i64
+    // CHECK: %[[SHIFTED_INSERT:.*]] = llvm.shl %[[INSERT]], %[[EXT_OFFSET]] : !llvm.i64
+    // CHECK: %{{.*}} = llvm.or %[[NEW_BASE]], %[[SHIFTED_INSERT]] : !llvm.i64
+    %0 = spv.BitFieldInsert %base, %insert, %offset, %count : i64, i8, i8
+    return
+}
+
+// CHECK-LABEL: func @bitfield_insert_scalar_greater_bit_width
+// CHECK-SAME: %[[BASE:.*]]: !llvm.i16, %[[INSERT:.*]]: !llvm.i16, %[[OFFSET:.*]]: !llvm.i32, %[[COUNT:.*]]: !llvm.i64
+func @bitfield_insert_scalar_greater_bit_width(%base: i16, %insert: i16, %offset: i32, %count: i64) {
+    // CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(-1 : i16) : !llvm.i16
+    // CHECK: %[[TRUNC_COUNT:.*]] = llvm.trunc %[[COUNT]] : !llvm.i64 to !llvm.i16
+    // CHECK: %[[TRUNC_OFFSET:.*]] = llvm.trunc %[[OFFSET]] : !llvm.i32 to !llvm.i16
+    // CHECK: %[[T0:.*]] = llvm.shl %[[MINUS_ONE]], %[[TRUNC_COUNT]] : !llvm.i16
+    // CHECK: %[[T1:.*]] = llvm.xor %[[T0]], %[[MINUS_ONE]] : !llvm.i16
+    // CHECK: %[[T2:.*]] = llvm.shl %[[T1]], %[[TRUNC_OFFSET]] : !llvm.i16
+    // CHECK: %[[MASK:.*]] = llvm.xor %[[T2]], %[[MINUS_ONE]] : !llvm.i16
+    // CHECK: %[[NEW_BASE:.*]] = llvm.and %[[BASE]], %[[MASK]] : !llvm.i16
+    // CHECK: %[[SHIFTED_INSERT:.*]] = llvm.shl %[[INSERT]], %[[TRUNC_OFFSET]] : !llvm.i16
+    // CHECK: %{{.*}} = llvm.or %[[NEW_BASE]], %[[SHIFTED_INSERT]] : !llvm.i16
+    %0 = spv.BitFieldInsert %base, %insert, %offset, %count : i16, i32, i64
+    return
+}
+
+// CHECK-LABEL: func @bitfield_insert_vector
+// CHECK-SAME: %[[BASE:.*]]: !llvm<"<2 x i32>">, %[[INSERT:.*]]: !llvm<"<2 x i32>">, %[[OFFSET:.*]]: !llvm.i32, %[[COUNT:.*]]: !llvm.i32
+func @bitfield_insert_vector(%base: vector<2xi32>, %insert: vector<2xi32>, %offset: i32, %count: i32) {
+    // CHECK: %[[OFFSET_V0:.*]] = llvm.mlir.undef : !llvm<"<2 x i32>">
+    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+    // CHECK: %[[OFFSET_V1:.*]] = llvm.insertelement %[[OFFSET]], %[[OFFSET_V0]][%[[ZERO]] : !llvm.i32] : !llvm<"<2 x i32>">
+    // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+    // CHECK: %[[OFFSET_V2:.*]] = llvm.insertelement %[[OFFSET]], %[[OFFSET_V1]][%[[ONE]] : !llvm.i32] : !llvm<"<2 x i32>">
+    // CHECK: %[[COUNT_V0:.*]] = llvm.mlir.undef : !llvm<"<2 x i32>">
+    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+    // CHECK: %[[COUNT_V1:.*]] = llvm.insertelement %[[COUNT]], %[[COUNT_V0]][%[[ZERO]] : !llvm.i32] : !llvm<"<2 x i32>">
+    // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+    // CHECK: %[[COUNT_V2:.*]] = llvm.insertelement %[[COUNT]], %[[COUNT_V1]][%[[ONE]] : !llvm.i32] : !llvm<"<2 x i32>">
+    // CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(dense<-1> : vector<2xi32>) : !llvm<"<2 x i32>">
+    // CHECK: %[[T0:.*]] = llvm.shl %[[MINUS_ONE]], %[[COUNT_V2]] : !llvm<"<2 x i32>">
+    // CHECK: %[[T1:.*]] = llvm.xor %[[T0]], %[[MINUS_ONE]] : !llvm<"<2 x i32>">
+    // CHECK: %[[T2:.*]] = llvm.shl %[[T1]], %[[OFFSET_V2]] : !llvm<"<2 x i32>">
+    // CHECK: %[[MASK:.*]] = llvm.xor %[[T2]], %[[MINUS_ONE]] : !llvm<"<2 x i32>">
+    // CHECK: %[[NEW_BASE:.*]] = llvm.and %[[BASE]], %[[MASK]] : !llvm<"<2 x i32>">
+    // CHECK: %[[SHIFTED_INSERT:.*]] = llvm.shl %[[INSERT]], %[[OFFSET_V2]] : !llvm<"<2 x i32>">
+    // CHECK: %{{.*}} = llvm.or %[[NEW_BASE]], %[[SHIFTED_INSERT]] : !llvm<"<2 x i32>">
+    %0 = spv.BitFieldInsert %base, %insert, %offset, %count : vector<2xi32>, i32, i32
+    return
+}
+
 //===----------------------------------------------------------------------===//
 // spv.BitwiseAnd
 //===----------------------------------------------------------------------===//

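
As a note on the width handling exercised by the "smaller"/"greater"
bit width tests above, here is a rough C++ sketch of the rule
implemented by `optionallyTruncateOrExtend` (the name and the masking
below are illustrative, not taken from the patch):

    #include <cstdint>

    // Widening uses zero-extension, since `Offset` and `Count` are
    // treated as unsigned; narrowing keeps the low bits, which is safe
    // because both values are bounded by 64.
    uint64_t adjustToTargetWidth(uint64_t value, unsigned valueBits,
                                 unsigned targetBits) {
      if (valueBits < targetBits)
        return value;                      // zext: value is unchanged
      if (valueBits > targetBits)          // trunc: keep the low bits
        return targetBits < 64 ? value & ((1ull << targetBits) - 1) : value;
      return value;                        // same width: no cast emitted
    }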

        

