[Mlir-commits] [mlir] [MLIR][LLVM] Handle floats in Mem2Reg of memset intrinsics (PR #131621)
llvmlistbot at llvm.org
Mon Mar 17 07:39:39 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-llvm
Author: Johannes de Fine Licht (definelicht)
Changes:
This was missing a bitcast from the shifted integer type to the float type. Non-struct types other than integers and floats are still not Mem2Reg'ed.
Also adds special handling for constant memset values so that the result is emitted directly as a constant rather than relying on follow-up canonicalization patterns (`memset` of zero is a case that appears in the wild).
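In effect, for a float slot the promoted value is still built by splatting the memset byte into an integer of the slot's width, and is then bitcast to the float type. A condensed sketch of the rewrite (SSA names abbreviated from the `@memset_float` test added below):

```mlir
// Before: a 4-byte memset into an f32 slot.
%slot = llvm.alloca %one x f32 : (i32) -> !llvm.ptr
"llvm.intr.memset"(%slot, %byte, %len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
%v = llvm.load %slot : !llvm.ptr -> f32

// After Mem2Reg: the byte is splatted into an i32 and bitcast to f32.
%w   = llvm.zext %byte : i8 to i32
%c8  = llvm.mlir.constant(8 : i32) : i32
%s8  = llvm.shl %w, %c8 : i32
%v16 = llvm.or %w, %s8 : i32
%c16 = llvm.mlir.constant(16 : i32) : i32
%s16 = llvm.shl %v16, %c16 : i32
%v32 = llvm.or %v16, %s16 : i32
%f   = llvm.bitcast %v32 : i32 to f32
```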
---
Full diff: https://github.com/llvm/llvm-project/pull/131621.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp (+48-26)
- (modified) mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir (+51-18)
``````````diff
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 655316cc5d66d..16109b5c59f7e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -1051,30 +1051,53 @@ static bool memsetCanRewire(MemsetIntr op, const DestructurableMemorySlot &slot,
template <class MemsetIntr>
static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot,
OpBuilder &builder) {
- // TODO: Support non-integer types.
- return TypeSwitch<Type, Value>(slot.elemType)
- .Case([&](IntegerType intType) -> Value {
- if (intType.getWidth() == 8)
- return op.getVal();
-
- assert(intType.getWidth() % 8 == 0);
-
- // Build the memset integer by repeatedly shifting the value and
- // or-ing it with the previous value.
- uint64_t coveredBits = 8;
- Value currentValue =
- builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
- while (coveredBits < intType.getWidth()) {
- Value shiftBy = builder.create<LLVM::ConstantOp>(op.getLoc(), intType,
- coveredBits);
- Value shifted =
- builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
- currentValue =
- builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
- coveredBits *= 2;
- }
+ /// Returns an integer value that is `width` bits wide representing the value
+ /// assigned to the slot by memset.
+ auto buildMemsetValue = [&](unsigned width) -> Value {
+ if (width == 8)
+ return op.getVal();
+
+ assert(width % 8 == 0);
+
+ auto intType = IntegerType::get(op.getContext(), width);
+
+ // If we know the pattern at compile time, we can compute and assign a
+ // constant directly.
+ IntegerAttr constantPattern;
+ if (matchPattern(op.getVal(), m_Constant(&constantPattern))) {
+ APInt memsetVal(/*numBits=*/width, /*val=*/0);
+ unsigned patternWidth = op.getVal().getType().getWidth();
+ for (unsigned loBit = 0; loBit + patternWidth <= width;
+ loBit += patternWidth)
+ memsetVal.insertBits(constantPattern.getValue(), loBit);
+ return builder.create<LLVM::ConstantOp>(
+ op.getLoc(), IntegerAttr::get(intType, memsetVal));
+ }
+
+ // Otherwise build the memset integer at runtime by repeatedly shifting the
+ // value and or-ing it with the previous value.
+ uint64_t coveredBits = 8;
+ Value currentValue =
+ builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
+ while (coveredBits < width) {
+ Value shiftBy =
+ builder.create<LLVM::ConstantOp>(op.getLoc(), intType, coveredBits);
+ Value shifted =
+ builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
+ currentValue =
+ builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
+ coveredBits *= 2;
+ }
- return currentValue;
+ return currentValue;
+ };
+ return TypeSwitch<Type, Value>(slot.elemType)
+ .Case([&](IntegerType type) -> Value {
+ return buildMemsetValue(type.getWidth());
+ })
+ .Case([&](FloatType type) -> Value {
+ Value intVal = buildMemsetValue(type.getWidth());
+ return builder.create<LLVM::BitcastOp>(op.getLoc(), type, intVal);
})
.Default([](Type) -> Value {
llvm_unreachable(
@@ -1088,11 +1111,10 @@ memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,
const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
const DataLayout &dataLayout) {
- // TODO: Support non-integer types.
bool canConvertType =
TypeSwitch<Type, bool>(slot.elemType)
- .Case([](IntegerType intType) {
- return intType.getWidth() % 8 == 0 && intType.getWidth() > 0;
+ .Case<IntegerType, FloatType>([](auto type) {
+ return type.getWidth() % 8 == 0 && type.getWidth() > 0;
})
.Default([](Type) { return false; });
if (!canConvertType)
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
index 646667505a373..f3dca45265082 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
@@ -23,6 +23,30 @@ llvm.func @basic_memset(%memset_value: i8) -> i32 {
// -----
+// CHECK-LABEL: llvm.func @memset_float
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @memset_float(%memset_value: i8) -> f32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %memset_len = llvm.mlir.constant(4 : i32) : i32
+ "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ // CHECK-NOT: "llvm.intr.memset"
+ // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+ // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+ // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+ // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+ // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+ // CHECK: %[[VALUE_FLOAT:.+]] = llvm.bitcast %[[VALUE_32]] : i32 to f32
+ // CHECK-NOT: "llvm.intr.memset"
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
+ // CHECK: llvm.return %[[VALUE_FLOAT]] : f32
+ llvm.return %2 : f32
+}
+
+// -----
+
// CHECK-LABEL: llvm.func @basic_memset_inline
// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
llvm.func @basic_memset_inline(%memset_value: i8) -> i32 {
@@ -45,6 +69,29 @@ llvm.func @basic_memset_inline(%memset_value: i8) -> i32 {
// -----
+// CHECK-LABEL: llvm.func @memset_inline_float
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @memset_inline_float(%memset_value: i8) -> f32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 4 : i32}> : (!llvm.ptr, i8) -> ()
+ // CHECK-NOT: "llvm.intr.memset.inline"
+ // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+ // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+ // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+ // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+ // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+ // CHECK: %[[VALUE_FLOAT:.+]] = llvm.bitcast %[[VALUE_32]] : i32 to f32
+ // CHECK-NOT: "llvm.intr.memset.inline"
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
+ // CHECK: llvm.return %[[VALUE_FLOAT]] : f32
+ llvm.return %2 : f32
+}
+
+// -----
+
// CHECK-LABEL: llvm.func @basic_memset_constant
llvm.func @basic_memset_constant() -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
@@ -53,15 +100,8 @@ llvm.func @basic_memset_constant() -> i32 {
%memset_len = llvm.mlir.constant(4 : i32) : i32
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
- // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
- // CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32
- // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]] : i32
- // CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]] : i32
- // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]] : i32
- // CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]] : i32
- // CHECK: llvm.return %[[RES]] : i32
+ // CHECK: %[[CONSTANT_VAL:..*]] = llvm.mlir.constant(707406378 : i32) : i32
+ // CHECK: llvm.return %[[CONSTANT_VAL]] : i32
llvm.return %2 : i32
}
@@ -74,15 +114,8 @@ llvm.func @basic_memset_inline_constant() -> i32 {
%memset_value = llvm.mlir.constant(42 : i8) : i8
"llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 4}> : (!llvm.ptr, i8) -> ()
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
- // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
- // CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32
- // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]] : i32
- // CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]] : i32
- // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]] : i32
- // CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]] : i32
- // CHECK: llvm.return %[[RES]] : i32
+ // CHECK: %[[CONSTANT_VAL:..*]] = llvm.mlir.constant(707406378 : i32) : i32
+ // CHECK: llvm.return %[[CONSTANT_VAL]] : i32
llvm.return %2 : i32
}
``````````
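For reference, the `707406378` constant in the updated CHECK lines is the byte pattern `0x2A` (42) replicated across the 32-bit slot. A minimal standalone sketch of that replication (illustrative only, not part of the patch), using `llvm::APInt::insertBits` as the constant path does:

```cpp
#include "llvm/ADT/APInt.h"
#include <cassert>

int main() {
  llvm::APInt pattern(/*numBits=*/8, /*val=*/0x2A);  // the i8 memset value (42)
  llvm::APInt memsetVal(/*numBits=*/32, /*val=*/0);  // the promoted slot value
  // Copy the 8-bit pattern into each byte of the 32-bit value.
  for (unsigned loBit = 0; loBit + 8 <= 32; loBit += 8)
    memsetVal.insertBits(pattern, loBit);            // 0x2A2A2A2A
  assert(memsetVal.getZExtValue() == 707406378u);    // the constant in the tests
  return 0;
}
```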
https://github.com/llvm/llvm-project/pull/131621
More information about the Mlir-commits mailing list