[Mlir-commits] [mlir] [MLIR][LLVM] Handle floats in Mem2Reg of memset intrinsics (PR #131621)

Johannes de Fine Licht llvmlistbot at llvm.org
Mon Mar 17 09:26:15 PDT 2025


https://github.com/definelicht updated https://github.com/llvm/llvm-project/pull/131621

>From ce9af7449b59c668caca6fb1d0b7b65d0f7147af Mon Sep 17 00:00:00 2001
From: Johannes de Fine Licht <johannes.definelicht at nextsilicon.com>
Date: Mon, 17 Mar 2025 14:35:23 +0000
Subject: [PATCH] [MLIR][LLVM] Handle floats in Mem2Reg of memset intrinsics.

The memset promotion was lacking a bitcast from the shifted integer value
to the float type. Non-struct types other than integers and floats are
still not Mem2Reg'ed.

Also adds special handling for constant memset values: the stored value is
emitted directly as a single constant rather than relying on follow-up
canonicalization patterns (a `memset` of zero is a case that can appear in
the wild).
---
 mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 73 ++++++++++++-------
 .../Dialect/LLVMIR/mem2reg-intrinsics.mlir    | 61 +++++++++++-----
 2 files changed, 90 insertions(+), 44 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 655316cc5d66d..d1ccb487d2265 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -1051,30 +1051,52 @@ static bool memsetCanRewire(MemsetIntr op, const DestructurableMemorySlot &slot,
 template <class MemsetIntr>
 static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot,
                              OpBuilder &builder) {
-  // TODO: Support non-integer types.
-  return TypeSwitch<Type, Value>(slot.elemType)
-      .Case([&](IntegerType intType) -> Value {
-        if (intType.getWidth() == 8)
-          return op.getVal();
-
-        assert(intType.getWidth() % 8 == 0);
-
-        // Build the memset integer by repeatedly shifting the value and
-        // or-ing it with the previous value.
-        uint64_t coveredBits = 8;
-        Value currentValue =
-            builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
-        while (coveredBits < intType.getWidth()) {
-          Value shiftBy = builder.create<LLVM::ConstantOp>(op.getLoc(), intType,
-                                                           coveredBits);
-          Value shifted =
-              builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
-          currentValue =
-              builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
-          coveredBits *= 2;
-        }
+  /// Returns an integer value that is `width` bits wide representing the value
+  /// assigned to the slot by memset.
+  auto buildMemsetValue = [&](unsigned width) -> Value {
+    assert(width % 8 == 0);
+    auto intType = IntegerType::get(op.getContext(), width);
+
+    // If we know the pattern at compile time, we can compute and assign a
+    // constant directly.
+    IntegerAttr constantPattern;
+    if (matchPattern(op.getVal(), m_Constant(&constantPattern))) {
+      assert(constantPattern.getValue().getBitWidth() == 8);
+      APInt memsetVal(/*numBits=*/width, /*val=*/0);
+      for (unsigned loBit = 0; loBit < width; loBit += 8)
+        memsetVal.insertBits(constantPattern.getValue(), loBit);
+      return builder.create<LLVM::ConstantOp>(
+          op.getLoc(), IntegerAttr::get(intType, memsetVal));
+    }
+
+    // If the output is a single byte, we can return the pattern directly.
+    if (width == 8)
+      return op.getVal();
+
+    // Otherwise build the memset integer at runtime by repeatedly shifting the
+    // value and or-ing it with the previous value.
+    uint64_t coveredBits = 8;
+    Value currentValue =
+        builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
+    while (coveredBits < width) {
+      Value shiftBy =
+          builder.create<LLVM::ConstantOp>(op.getLoc(), intType, coveredBits);
+      Value shifted =
+          builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
+      currentValue =
+          builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
+      coveredBits *= 2;
+    }
 
-        return currentValue;
+    return currentValue;
+  };
+  return TypeSwitch<Type, Value>(slot.elemType)
+      .Case([&](IntegerType type) -> Value {
+        return buildMemsetValue(type.getWidth());
+      })
+      .Case([&](FloatType type) -> Value {
+        Value intVal = buildMemsetValue(type.getWidth());
+        return builder.create<LLVM::BitcastOp>(op.getLoc(), type, intVal);
       })
       .Default([](Type) -> Value {
         llvm_unreachable(
@@ -1088,11 +1110,10 @@ memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,
                        const SmallPtrSetImpl<OpOperand *> &blockingUses,
                        SmallVectorImpl<OpOperand *> &newBlockingUses,
                        const DataLayout &dataLayout) {
-  // TODO: Support non-integer types.
   bool canConvertType =
       TypeSwitch<Type, bool>(slot.elemType)
-          .Case([](IntegerType intType) {
-            return intType.getWidth() % 8 == 0 && intType.getWidth() > 0;
+          .Case<IntegerType, FloatType>([](auto type) {
+            return type.getWidth() % 8 == 0 && type.getWidth() > 0;
           })
           .Default([](Type) { return false; });
   if (!canConvertType)
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
index 646667505a373..37c2f525a9dcb 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
@@ -23,6 +23,30 @@ llvm.func @basic_memset(%memset_value: i8) -> i32 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @memset_float
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @memset_float(%memset_value: i8) -> f32 {
+  %one = llvm.mlir.constant(1 : i32) : i32
+  %alloca = llvm.alloca %one x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_len = llvm.mlir.constant(4 : i32) : i32
+  "llvm.intr.memset"(%alloca, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  // CHECK-NOT: "llvm.intr.memset"
+  // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+  // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+  // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+  // CHECK: %[[VALUE_FLOAT:.+]] = llvm.bitcast %[[VALUE_32]] : i32 to f32
+  // CHECK-NOT: "llvm.intr.memset"
+  %load = llvm.load %alloca {alignment = 4 : i64} : !llvm.ptr -> f32
+  // CHECK: llvm.return %[[VALUE_FLOAT]] : f32
+  llvm.return %load : f32
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @basic_memset_inline
 // CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
 llvm.func @basic_memset_inline(%memset_value: i8) -> i32 {
@@ -53,20 +77,28 @@ llvm.func @basic_memset_constant() -> i32 {
   %memset_len = llvm.mlir.constant(4 : i32) : i32
   "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
   %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
-  // CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32
-  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
-  // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]]  : i32
-  // CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]]  : i32
-  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
-  // CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]]  : i32
-  // CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]]  : i32
-  // CHECK: llvm.return %[[RES]] : i32
+  // CHECK: %[[CONSTANT_VAL:..*]] = llvm.mlir.constant(707406378 : i32) : i32
+  // CHECK: llvm.return %[[CONSTANT_VAL]] : i32
   llvm.return %2 : i32
 }
 
 // -----
 
+// CHECK-LABEL: llvm.func @memset_one_byte_constant
+llvm.func @memset_one_byte_constant() -> i8 {
+  %one = llvm.mlir.constant(1 : i32) : i32
+  %alloca = llvm.alloca %one x i8 : (i32) -> !llvm.ptr
+  // CHECK: %{{.+}} = llvm.mlir.constant(42 : i8) : i8
+  %value = llvm.mlir.constant(42 : i8) : i8
+  "llvm.intr.memset"(%alloca, %value, %one) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %load = llvm.load %alloca : !llvm.ptr -> i8
+  // CHECK: %[[CONSTANT_VAL:..*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK: llvm.return %[[CONSTANT_VAL]] : i8
+  llvm.return %load : i8
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @basic_memset_inline_constant
 llvm.func @basic_memset_inline_constant() -> i32 {
   %0 = llvm.mlir.constant(1 : i32) : i32
@@ -74,15 +106,8 @@ llvm.func @basic_memset_inline_constant() -> i32 {
   %memset_value = llvm.mlir.constant(42 : i8) : i8
   "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 4}> : (!llvm.ptr, i8) -> ()
   %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
-  // CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32
-  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
-  // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]]  : i32
-  // CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]]  : i32
-  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
-  // CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]]  : i32
-  // CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]]  : i32
-  // CHECK: llvm.return %[[RES]] : i32
+  // CHECK: %[[CONSTANT_VAL:..*]] = llvm.mlir.constant(707406378 : i32) : i32
+  // CHECK: llvm.return %[[CONSTANT_VAL]] : i32
   llvm.return %2 : i32
 }
 



More information about the Mlir-commits mailing list