[Mlir-commits] [mlir] [MLIR][LLVM] Handle floats in Mem2Reg of memset intrinsics (PR #131621)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 17 07:39:39 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Johannes de Fine Licht (definelicht)

<details>
<summary>Changes</summary>

Mem2Reg of memset intrinsics previously lacked a bitcast from the shifted integer value to the float type of the slot. Non-struct types other than integers and floats will still not be promoted by Mem2Reg.

This change also adds special handling for the case where the memset value is a known constant. The memset result is then emitted directly as a constant instead of relying on follow-up canonicalization patterns. A `memset` of zero is a case that occurs in practice.

---
Full diff: https://github.com/llvm/llvm-project/pull/131621.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp (+48-26) 
- (modified) mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir (+51-18) 


``````````diff
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 655316cc5d66d..16109b5c59f7e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -1051,30 +1051,53 @@ static bool memsetCanRewire(MemsetIntr op, const DestructurableMemorySlot &slot,
 template <class MemsetIntr>
 static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot,
                              OpBuilder &builder) {
-  // TODO: Support non-integer types.
-  return TypeSwitch<Type, Value>(slot.elemType)
-      .Case([&](IntegerType intType) -> Value {
-        if (intType.getWidth() == 8)
-          return op.getVal();
-
-        assert(intType.getWidth() % 8 == 0);
-
-        // Build the memset integer by repeatedly shifting the value and
-        // or-ing it with the previous value.
-        uint64_t coveredBits = 8;
-        Value currentValue =
-            builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
-        while (coveredBits < intType.getWidth()) {
-          Value shiftBy = builder.create<LLVM::ConstantOp>(op.getLoc(), intType,
-                                                           coveredBits);
-          Value shifted =
-              builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
-          currentValue =
-              builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
-          coveredBits *= 2;
-        }
+  /// Returns an integer value that is `width` bits wide representing the value
+  /// assigned to the slot by memset.
+  auto buildMemsetValue = [&](unsigned width) -> Value {
+    if (width == 8)
+      return op.getVal();
+
+    assert(width % 8 == 0);
+
+    auto intType = IntegerType::get(op.getContext(), width);
+
+    // If we know the pattern at compile time, we can compute and assign a
+    // constant directly.
+    IntegerAttr constantPattern;
+    if (matchPattern(op.getVal(), m_Constant(&constantPattern))) {
+      APInt memsetVal(/*numBits=*/width, /*val=*/0);
+      unsigned patternWidth = op.getVal().getType().getWidth();
+      for (unsigned loBit = 0; loBit + patternWidth <= width;
+           loBit += patternWidth)
+        memsetVal.insertBits(constantPattern.getValue(), loBit);
+      return builder.create<LLVM::ConstantOp>(
+          op.getLoc(), IntegerAttr::get(intType, memsetVal));
+    }
+
+    // Otherwise build the memset integer at runtime by repeatedly shifting the
+    // value and or-ing it with the previous value.
+    uint64_t coveredBits = 8;
+    Value currentValue =
+        builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
+    while (coveredBits < width) {
+      Value shiftBy =
+          builder.create<LLVM::ConstantOp>(op.getLoc(), intType, coveredBits);
+      Value shifted =
+          builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
+      currentValue =
+          builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
+      coveredBits *= 2;
+    }
 
-        return currentValue;
+    return currentValue;
+  };
+  return TypeSwitch<Type, Value>(slot.elemType)
+      .Case([&](IntegerType type) -> Value {
+        return buildMemsetValue(type.getWidth());
+      })
+      .Case([&](FloatType type) -> Value {
+        Value intVal = buildMemsetValue(type.getWidth());
+        return builder.create<LLVM::BitcastOp>(op.getLoc(), type, intVal);
       })
       .Default([](Type) -> Value {
         llvm_unreachable(
@@ -1088,11 +1111,10 @@ memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,
                        const SmallPtrSetImpl<OpOperand *> &blockingUses,
                        SmallVectorImpl<OpOperand *> &newBlockingUses,
                        const DataLayout &dataLayout) {
-  // TODO: Support non-integer types.
   bool canConvertType =
       TypeSwitch<Type, bool>(slot.elemType)
-          .Case([](IntegerType intType) {
-            return intType.getWidth() % 8 == 0 && intType.getWidth() > 0;
+          .Case<IntegerType, FloatType>([](auto type) {
+            return type.getWidth() % 8 == 0 && type.getWidth() > 0;
           })
           .Default([](Type) { return false; });
   if (!canConvertType)
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
index 646667505a373..f3dca45265082 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
@@ -23,6 +23,30 @@ llvm.func @basic_memset(%memset_value: i8) -> i32 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @memset_float
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @memset_float(%memset_value: i8) -> f32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr // f32 slot: exercises the FloatType case of memsetGetStored.
+  %memset_len = llvm.mlir.constant(4 : i32) : i32
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  // CHECK-NOT: "llvm.intr.memset"
+  // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+  // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+  // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+  // CHECK: %[[VALUE_FLOAT:.+]] = llvm.bitcast %[[VALUE_32]] : i32 to f32
+  // CHECK-NOT: "llvm.intr.memset"
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
+  // CHECK: llvm.return %[[VALUE_FLOAT]] : f32
+  llvm.return %2 : f32
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @basic_memset_inline
 // CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
 llvm.func @basic_memset_inline(%memset_value: i8) -> i32 {
@@ -45,6 +69,29 @@ llvm.func @basic_memset_inline(%memset_value: i8) -> i32 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @memset_inline_float
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @memset_inline_float(%memset_value: i8) -> f32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr // f32 slot: exercises the FloatType case of memsetGetStored.
+  "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 4 : i32}> : (!llvm.ptr, i8) -> ()
+  // CHECK-NOT: "llvm.intr.memset.inline"
+  // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+  // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+  // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+  // CHECK: %[[VALUE_FLOAT:.+]] = llvm.bitcast %[[VALUE_32]] : i32 to f32
+  // CHECK-NOT: "llvm.intr.memset.inline"
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
+  // CHECK: llvm.return %[[VALUE_FLOAT]] : f32
+  llvm.return %2 : f32
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @basic_memset_constant
 llvm.func @basic_memset_constant() -> i32 {
   %0 = llvm.mlir.constant(1 : i32) : i32
@@ -53,15 +100,8 @@ llvm.func @basic_memset_constant() -> i32 {
   %memset_len = llvm.mlir.constant(4 : i32) : i32
   "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
   %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
-  // CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32
-  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
-  // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]]  : i32
-  // CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]]  : i32
-  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
-  // CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]]  : i32
-  // CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]]  : i32
-  // CHECK: llvm.return %[[RES]] : i32
+  // CHECK: %[[CONSTANT_VAL:..*]] = llvm.mlir.constant(707406378 : i32) : i32
+  // CHECK: llvm.return %[[CONSTANT_VAL]] : i32
   llvm.return %2 : i32
 }
 
@@ -74,15 +114,8 @@ llvm.func @basic_memset_inline_constant() -> i32 {
   %memset_value = llvm.mlir.constant(42 : i8) : i8
   "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 4}> : (!llvm.ptr, i8) -> ()
   %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
-  // CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32
-  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
-  // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]]  : i32
-  // CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]]  : i32
-  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
-  // CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]]  : i32
-  // CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]]  : i32
-  // CHECK: llvm.return %[[RES]] : i32
+  // CHECK: %[[CONSTANT_VAL:..*]] = llvm.mlir.constant(707406378 : i32) : i32
+  // CHECK: llvm.return %[[CONSTANT_VAL]] : i32
   llvm.return %2 : i32
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/131621


More information about the Mlir-commits mailing list