[Mlir-commits] [mlir] [MLIR][LLVM] Handle floats in Mem2Reg of memset intrinsics (PR #131621)
Johannes de Fine Licht
llvmlistbot at llvm.org
Mon Mar 17 09:21:22 PDT 2025
https://github.com/definelicht updated https://github.com/llvm/llvm-project/pull/131621
>From 63777d8db95c243c5bbc91f60c46d95e85d479b2 Mon Sep 17 00:00:00 2001
From: Johannes de Fine Licht <johannes.definelicht at nextsilicon.com>
Date: Mon, 17 Mar 2025 14:35:23 +0000
Subject: [PATCH] [MLIR][LLVM] Handle floats in Mem2Reg of memset intrinsics.
Floats were previously rejected; handling them only requires a bitcast
from the shifted (splatted) integer value to the float type.
Non-struct types other than integers and floats are still not
Mem2Reg'ed.
Also adds special handling for constant memset values: the splatted
value is now emitted directly as a constant rather than relying on
follow-up canonicalization patterns (a `memset` of zero is a case that
appears in practice).
---
mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 74 ++++++++++++-------
.../Dialect/LLVMIR/mem2reg-intrinsics.mlir | 61 ++++++++++-----
2 files changed, 91 insertions(+), 44 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 655316cc5d66d..08046be86b2bb 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -1051,30 +1051,53 @@ static bool memsetCanRewire(MemsetIntr op, const DestructurableMemorySlot &slot,
template <class MemsetIntr>
static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot,
OpBuilder &builder) {
- // TODO: Support non-integer types.
- return TypeSwitch<Type, Value>(slot.elemType)
- .Case([&](IntegerType intType) -> Value {
- if (intType.getWidth() == 8)
- return op.getVal();
-
- assert(intType.getWidth() % 8 == 0);
-
- // Build the memset integer by repeatedly shifting the value and
- // or-ing it with the previous value.
- uint64_t coveredBits = 8;
- Value currentValue =
- builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
- while (coveredBits < intType.getWidth()) {
- Value shiftBy = builder.create<LLVM::ConstantOp>(op.getLoc(), intType,
- coveredBits);
- Value shifted =
- builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
- currentValue =
- builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
- coveredBits *= 2;
- }
+ /// Returns an integer value that is `width` bits wide representing the value
+ /// assigned to the slot by memset.
+ auto buildMemsetValue = [&](unsigned width) -> Value {
+ assert(width % 8 == 0);
+ auto intType = IntegerType::get(op.getContext(), width);
+
+ // If we know the pattern at compile time, we can compute and assign a
+ // constant directly.
+ IntegerAttr constantPattern;
+ if (matchPattern(op.getVal(), m_Constant(&constantPattern))) {
+ // The pattern must fit in a byte.
+ assert(constantPattern.getValue().getActiveBits() <= 8);
+ APInt memsetVal(/*numBits=*/width, /*val=*/0);
+ for (unsigned loBit = 0; loBit < width; loBit += 8)
+ memsetVal.insertBits(constantPattern.getValue(), loBit);
+ return builder.create<LLVM::ConstantOp>(
+ op.getLoc(), IntegerAttr::get(intType, memsetVal));
+ }
+
+ // If the output is a single byte, we can return the pattern directly.
+ if (width == 8)
+ return op.getVal();
+
+ // Otherwise build the memset integer at runtime by repeatedly shifting the
+ // value and or-ing it with the previous value.
+ uint64_t coveredBits = 8;
+ Value currentValue =
+ builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
+ while (coveredBits < width) {
+ Value shiftBy =
+ builder.create<LLVM::ConstantOp>(op.getLoc(), intType, coveredBits);
+ Value shifted =
+ builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
+ currentValue =
+ builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
+ coveredBits *= 2;
+ }
- return currentValue;
+ return currentValue;
+ };
+ return TypeSwitch<Type, Value>(slot.elemType)
+ .Case([&](IntegerType type) -> Value {
+ return buildMemsetValue(type.getWidth());
+ })
+ .Case([&](FloatType type) -> Value {
+ Value intVal = buildMemsetValue(type.getWidth());
+ return builder.create<LLVM::BitcastOp>(op.getLoc(), type, intVal);
})
.Default([](Type) -> Value {
llvm_unreachable(
@@ -1088,11 +1111,10 @@ memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,
const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
const DataLayout &dataLayout) {
- // TODO: Support non-integer types.
bool canConvertType =
TypeSwitch<Type, bool>(slot.elemType)
- .Case([](IntegerType intType) {
- return intType.getWidth() % 8 == 0 && intType.getWidth() > 0;
+ .Case<IntegerType, FloatType>([](auto type) {
+ return type.getWidth() % 8 == 0 && type.getWidth() > 0;
})
.Default([](Type) { return false; });
if (!canConvertType)
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
index 646667505a373..37c2f525a9dcb 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
@@ -23,6 +23,30 @@ llvm.func @basic_memset(%memset_value: i8) -> i32 {
// -----
+// CHECK-LABEL: llvm.func @memset_float
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @memset_float(%memset_value: i8) -> f32 {
+ %one = llvm.mlir.constant(1 : i32) : i32
+ %alloca = llvm.alloca %one x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %memset_len = llvm.mlir.constant(4 : i32) : i32
+ "llvm.intr.memset"(%alloca, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ // CHECK-NOT: "llvm.intr.memset"
+ // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+ // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+ // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+ // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+ // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+ // CHECK: %[[VALUE_FLOAT:.+]] = llvm.bitcast %[[VALUE_32]] : i32 to f32
+ // CHECK-NOT: "llvm.intr.memset"
+ %load = llvm.load %alloca {alignment = 4 : i64} : !llvm.ptr -> f32
+ // CHECK: llvm.return %[[VALUE_FLOAT]] : f32
+ llvm.return %load : f32
+}
+
+// -----
+
// CHECK-LABEL: llvm.func @basic_memset_inline
// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
llvm.func @basic_memset_inline(%memset_value: i8) -> i32 {
@@ -53,20 +77,28 @@ llvm.func @basic_memset_constant() -> i32 {
%memset_len = llvm.mlir.constant(4 : i32) : i32
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
- // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
- // CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32
- // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]] : i32
- // CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]] : i32
- // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]] : i32
- // CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]] : i32
- // CHECK: llvm.return %[[RES]] : i32
+ // CHECK: %[[CONSTANT_VAL:..*]] = llvm.mlir.constant(707406378 : i32) : i32
+ // CHECK: llvm.return %[[CONSTANT_VAL]] : i32
llvm.return %2 : i32
}
// -----
+// CHECK-LABEL: llvm.func @memset_one_byte_constant
+llvm.func @memset_one_byte_constant() -> i8 {
+ %one = llvm.mlir.constant(1 : i32) : i32
+ %alloca = llvm.alloca %one x i8 : (i32) -> !llvm.ptr
+ // CHECK: %{{.+}} = llvm.mlir.constant(42 : i8) : i8
+ %value = llvm.mlir.constant(42 : i8) : i8
+ "llvm.intr.memset"(%alloca, %value, %one) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ %load = llvm.load %alloca : !llvm.ptr -> i8
+ // CHECK: %[[CONSTANT_VAL:..*]] = llvm.mlir.constant(42 : i8) : i8
+ // CHECK: llvm.return %[[CONSTANT_VAL]] : i8
+ llvm.return %load : i8
+}
+
+// -----
+
// CHECK-LABEL: llvm.func @basic_memset_inline_constant
llvm.func @basic_memset_inline_constant() -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
@@ -74,15 +106,8 @@ llvm.func @basic_memset_inline_constant() -> i32 {
%memset_value = llvm.mlir.constant(42 : i8) : i8
"llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 4}> : (!llvm.ptr, i8) -> ()
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
- // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
- // CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32
- // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]] : i32
- // CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]] : i32
- // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]] : i32
- // CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]] : i32
- // CHECK: llvm.return %[[RES]] : i32
+ // CHECK: %[[CONSTANT_VAL:..*]] = llvm.mlir.constant(707406378 : i32) : i32
+ // CHECK: llvm.return %[[CONSTANT_VAL]] : i32
llvm.return %2 : i32
}
More information about the Mlir-commits
mailing list