[Mlir-commits] [mlir] [MLIR][LLVM] Handle floats in Mem2Reg of memset intrinsics (PR #131621)

Johannes de Fine Licht llvmlistbot at llvm.org
Mon Mar 17 07:41:15 PDT 2025


================
@@ -1051,30 +1051,53 @@ static bool memsetCanRewire(MemsetIntr op, const DestructurableMemorySlot &slot,
 template <class MemsetIntr>
 static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot,
                              OpBuilder &builder) {
-  // TODO: Support non-integer types.
-  return TypeSwitch<Type, Value>(slot.elemType)
-      .Case([&](IntegerType intType) -> Value {
-        if (intType.getWidth() == 8)
-          return op.getVal();
-
-        assert(intType.getWidth() % 8 == 0);
-
-        // Build the memset integer by repeatedly shifting the value and
-        // or-ing it with the previous value.
-        uint64_t coveredBits = 8;
-        Value currentValue =
-            builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
-        while (coveredBits < intType.getWidth()) {
-          Value shiftBy = builder.create<LLVM::ConstantOp>(op.getLoc(), intType,
-                                                           coveredBits);
-          Value shifted =
-              builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
-          currentValue =
-              builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
-          coveredBits *= 2;
-        }
+  /// Returns an integer value that is `width` bits wide representing the value
+  /// assigned to the slot by memset.
+  auto buildMemsetValue = [&](unsigned width) -> Value {
+    if (width == 8)
+      return op.getVal();
+
+    assert(width % 8 == 0);
+
+    auto intType = IntegerType::get(op.getContext(), width);
+
+    // If we know the pattern at compile time, we can compute and assign a
+    // constant directly.
+    IntegerAttr constantPattern;
+    if (matchPattern(op.getVal(), m_Constant(&constantPattern))) {
+      APInt memsetVal(/*numBits=*/width, /*val=*/0);
+      unsigned patternWidth = op.getVal().getType().getWidth();
+      for (unsigned loBit = 0; loBit + patternWidth <= width;
+           loBit += patternWidth)
----------------
definelicht wrote:

This is slightly over-generalized/paranoid: it doesn't hard-assume that the pattern is exactly 8 bits wide, because that doesn't seem to be verified on the operation.
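Both ways of building the memset value can be sketched in plain C++. The first is the constant-folding loop from the new code, which places whole copies of a `patternWidth`-bit pattern and so does not assume the pattern is 8 bits. The second is the runtime shift-and-OR doubling from the removed lines. This is a minimal sketch under stated assumptions: it uses `uint64_t` instead of `llvm::APInt`, so widths are capped at 64 bits. The names `replicatePattern` and `shiftOrSplat` are hypothetical. Because the loop body is truncated in the quoted hunk, the "OR each copy into place" step is an assumption about what it does.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (assumption: uint64_t stands in for llvm::APInt, so
// width <= 64). Places whole copies of a `patternWidth`-bit pattern starting
// at bit 0. If width is not a multiple of patternWidth, the leftover high bits
// stay zero, matching the `loBit + patternWidth <= width` loop bound.
uint64_t replicatePattern(uint64_t pattern, unsigned patternWidth,
                          unsigned width) {
  assert(patternWidth > 0 && patternWidth <= 64 && width <= 64);
  uint64_t patternMask =
      patternWidth == 64 ? ~0ULL : ((1ULL << patternWidth) - 1);
  pattern &= patternMask;
  uint64_t result = 0;
  for (unsigned loBit = 0; loBit + patternWidth <= width;
       loBit += patternWidth)
    result |= pattern << loBit; // loBit < 64 here, so the shift is defined.
  return result;
}

// Hypothetical helper mirroring the removed runtime lowering: zero-extend the
// byte, then repeatedly shift by the number of bits already covered and OR.
// The mask models truncation to the `width`-bit integer type.
uint64_t shiftOrSplat(uint8_t byte, unsigned width) {
  assert(width >= 8 && width % 8 == 0 && width <= 64);
  uint64_t mask = width == 64 ? ~0ULL : ((1ULL << width) - 1);
  uint64_t value = byte; // Corresponds to the ZExtOp.
  unsigned coveredBits = 8;
  while (coveredBits < width) {
    // Corresponds to the ShlOp + OrOp pair; coveredBits < width <= 64.
    value = (value | (value << coveredBits)) & mask;
    coveredBits *= 2;
  }
  return value;
}
```

For an 8-bit pattern the two agree. For example, both `replicatePattern(0xAB, 8, 32)` and `shiftOrSplat(0xAB, 32)` yield `0xABABABAB`. Only the first also handles a pattern narrower than 8 bits, which is the generality the comment above refers to.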

https://github.com/llvm/llvm-project/pull/131621

