[Mlir-commits] [mlir] 8404b23 - [mlir][llvm] Add memset support for mem2reg/sroa

Christian Ulmann llvmlistbot at llvm.org
Wed Jun 14 02:17:36 PDT 2023


Author: Théo Degioanni
Date: 2023-06-14T09:17:08Z
New Revision: 8404b23acd70b8db1411b98a04b4ea62eaeb48dd

URL: https://github.com/llvm/llvm-project/commit/8404b23acd70b8db1411b98a04b4ea62eaeb48dd
DIFF: https://github.com/llvm/llvm-project/commit/8404b23acd70b8db1411b98a04b4ea62eaeb48dd.diff

LOG: [mlir][llvm] Add memset support for mem2reg/sroa

This revision introduces support for memset intrinsics in SROA and
mem2reg for the LLVM dialect. This is achieved for SROA by breaking
memsets of aggregates into multiple memsets of scalars, and for mem2reg
by promoting memsets of single integer slots into the value the memset
operation would yield.

The SROA logic supports breaking memsets of static size operating at the
start of a memory slot. The most common intended case is a memset
covering the entirety of a struct, most often as a way to initialize it
to 0.

The mem2reg logic supports dynamic values and static sizes as input to
promotable memsets. This is achieved by lowering memsets into
`ceil(log_2(n))` LeftShift operations, `ceil(log_2(n))` Or operations
and up to one ZExt operation (where n is the byte width of the integer),
computing in registers the integer value the memset would create. Only
byte-aligned integers are supported; more types could easily be added
afterwards.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D152367

Added: 
    mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
    mlir/test/Dialect/LLVMIR/sroa-intrinsics.mlir

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
    mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
    mlir/lib/Transforms/Mem2Reg.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 3872f5a7339ae..ce59aaebf0ac9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -282,8 +282,11 @@ def LLVM_MemcpyInlineOp :
       # setAliasAnalysisMetadataCode;
 }
 
-def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2], [],
-                      /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1> {
+def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2], 
+    [DeclareOpInterfaceMethods<PromotableMemOpInterface>,
+     DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
+     DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
+    /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1> {
   dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
                   I8:$val, AnySignlessInteger:$len, I1Attr:$isVolatile);
   // Append the alias attributes defined by LLVM_IntrOpBase.

diff  --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 5ef646b24ff9f..9ffa709cc5bfd 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -103,16 +103,32 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
       "bool", "loadsFrom",
       (ins "const ::mlir::MemorySlot &":$slot)
     >,
+    InterfaceMethod<[{
+        Gets whether this operation stores to the specified slot.
+
+        No IR mutation is allowed in this method.
+      }],
+      "bool", "storesTo",
+      (ins "const ::mlir::MemorySlot &":$slot)
+    >,
     InterfaceMethod<[{
         Gets the value stored to the provided memory slot, or returns a null
         value if this operation does not store to this slot. An operation
         storing a value to a slot must always be able to provide the value it
-        stores. This method is only called on operations that use the slot.
-
-        No IR mutation is allowed in this method.
+        stores. This method is only called once per slot promotion, and only
+        on operations that store to the slot according to the `storesTo` method.
+        The returned value must dominate all operations dominated by the storing
+        operation.
+
+        If IR must be mutated to extract a concrete value being stored, mutation
+        must happen through the provided rewriter. The rewriter is located
+        immediately after the memory operation on call. No IR deletion is
+        allowed in this method. IR mutations must not introduce new uses of the
+        memory slot. Existing control flow must not be modified.
       }],
       "::mlir::Value", "getStored",
-      (ins "const ::mlir::MemorySlot &":$slot)
+      (ins "const ::mlir::MemorySlot &":$slot,
+           "::mlir::RewriterBase &":$rewriter)
     >,
     InterfaceMethod<[{
         Checks that this operation can be promoted to no longer use the provided

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index b36ed78d6685c..65c3d0290b269 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -13,10 +13,13 @@
 
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/ValueRange.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 
@@ -119,12 +122,20 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
   return getAddr() == slot.ptr;
 }
 
-Value LLVM::LoadOp::getStored(const MemorySlot &slot) { return {}; }
+bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
+
+Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
+  llvm_unreachable("getStored should not be called on LoadOp");
+}
 
 bool LLVM::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
 
-Value LLVM::StoreOp::getStored(const MemorySlot &slot) {
-  return getAddr() == slot.ptr ? getValue() : Value();
+bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
+  return getAddr() == slot.ptr;
+}
+
+Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
+  return getValue();
 }
 
 bool LLVM::LoadOp::canUsesBeRemoved(
@@ -310,7 +321,8 @@ static Type computeReachedGEPType(LLVM::GEPOp gep) {
   for (IntegerAttr index : llvm::drop_begin(indices)) {
     // Ensure the structure of the type being indexed can be reasoned about.
     // This includes rejecting any potential typed pointer.
-    auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(selectedType);
+    auto destructurable =
+        llvm::dyn_cast<DestructurableTypeInterface>(selectedType);
     if (!destructurable)
       return {};
 
@@ -369,7 +381,8 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
 DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
                                  DenseMap<Attribute, MemorySlot> &subslots,
                                  RewriterBase &rewriter) {
-  IntegerAttr firstLevelIndex = llvm::dyn_cast_if_present<IntegerAttr>(getIndices()[1]);
+  IntegerAttr firstLevelIndex =
+      llvm::dyn_cast_if_present<IntegerAttr>(getIndices()[1]);
   const MemorySlot &newSlot = subslots.at(firstLevelIndex);
 
   ArrayRef<int32_t> remainingIndices = getRawConstantIndices().slice(2);
@@ -400,6 +413,180 @@ DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
   return DeletionKind::Keep;
 }
 
+//===----------------------------------------------------------------------===//
+// Interfaces for memset
+//===----------------------------------------------------------------------===//
+
+/// Returns the length of the given memset in bytes if it can be known at
+/// compile-time on a best-effort basis, nothing otherwise.
+static std::optional<uint64_t> getStaticMemsetLen(LLVM::MemsetOp op) {
+  IntegerAttr::ValueType memsetLen;
+  if (!matchPattern(op.getLen(), m_ConstantInt(&memsetLen)))
+    return {};
+
+  if (memsetLen.getBitWidth() > 64)
+    return {};
+
+  return memsetLen.getZExtValue();
+}
+
+/// Returns whether one can be sure the memset does not write outside of the
+/// bounds of the given slot, on a best-effort basis.
+static bool definitelyWritesOnlyWithinSlot(LLVM::MemsetOp op,
+                                           const MemorySlot &slot,
+                                           DataLayout &dataLayout) {
+  if (!isa<LLVM::LLVMPointerType>(slot.ptr.getType()) ||
+      op.getDst() != slot.ptr)
+    return false;
+
+  std::optional<uint64_t> memsetLen = getStaticMemsetLen(op);
+  return memsetLen && *memsetLen <= dataLayout.getTypeSize(slot.elemType);
+}
+
+bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }
+
+bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
+  return getDst() == slot.ptr;
+}
+
+Value LLVM::MemsetOp::getStored(const MemorySlot &slot,
+                                RewriterBase &rewriter) {
+  // TODO: Support non-integer types.
+  return TypeSwitch<Type, Value>(slot.elemType)
+      .Case([&](IntegerType intType) -> Value {
+        if (intType.getWidth() == 8)
+          return getVal();
+
+        assert(intType.getWidth() % 8 == 0);
+
+        // Build the memset integer by repeatedly shifting the value and or-ing
+        // it with the previous value.
+        uint64_t coveredBits = 8;
+        Value currentValue =
+            rewriter.create<LLVM::ZExtOp>(getLoc(), intType, getVal());
+        while (coveredBits < intType.getWidth()) {
+          Value shiftBy =
+              rewriter.create<LLVM::ConstantOp>(getLoc(), intType, coveredBits);
+          Value shifted =
+              rewriter.create<LLVM::ShlOp>(getLoc(), currentValue, shiftBy);
+          currentValue =
+              rewriter.create<LLVM::OrOp>(getLoc(), currentValue, shifted);
+          coveredBits *= 2;
+        }
+
+        return currentValue;
+      })
+      .Default([](Type) -> Value {
+        llvm_unreachable(
+            "getStored should not be called on memset to unsupported type");
+      });
+}
+
+bool LLVM::MemsetOp::canUsesBeRemoved(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+  // TODO: Support non-integer types.
+  bool canConvertType =
+      TypeSwitch<Type, bool>(slot.elemType)
+          .Case([](IntegerType intType) {
+            return intType.getWidth() % 8 == 0 && intType.getWidth() > 0;
+          })
+          .Default([](Type) { return false; });
+  if (!canConvertType)
+    return false;
+
+  if (getIsVolatile())
+    return false;
+
+  DataLayout layout = DataLayout::closest(*this);
+  return getStaticMemsetLen(*this) == layout.getTypeSize(slot.elemType);
+}
+
+DeletionKind LLVM::MemsetOp::removeBlockingUses(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    RewriterBase &rewriter, Value reachingDefinition) {
+  return DeletionKind::Delete;
+}
+
+LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
+    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+  DataLayout dataLayout = DataLayout::closest(*this);
+  return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
+}
+
+bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
+                               SmallPtrSetImpl<Attribute> &usedIndices,
+                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+  if (&slot.elemType.getDialect() != getOperation()->getDialect())
+    return false;
+
+  if (getIsVolatile())
+    return false;
+
+  if (!slot.elemType.cast<DestructurableTypeInterface>()
+           .getSubelementIndexMap())
+    return false;
+
+  DataLayout dataLayout = DataLayout::closest(*this);
+  return definitelyWritesOnlyWithinSlot(*this, slot, dataLayout);
+}
+
+DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
+                                    DenseMap<Attribute, MemorySlot> &subslots,
+                                    RewriterBase &rewriter) {
+  std::optional<DenseMap<Attribute, Type>> types =
+      slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
+
+  IntegerAttr memsetLenAttr;
+  bool successfulMatch =
+      matchPattern(getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
+  (void)successfulMatch;
+  assert(successfulMatch);
+
+  bool packed = false;
+  if (auto structType = dyn_cast<LLVM::LLVMStructType>(slot.elemType))
+    packed = structType.isPacked();
+
+  Type i32 = IntegerType::get(getContext(), 32);
+  DataLayout dataLayout = DataLayout::closest(*this);
+  uint64_t memsetLen = memsetLenAttr.getValue().getZExtValue();
+  uint64_t covered = 0;
+  for (size_t i = 0; i < types->size(); i++) {
+    // Create indices on the fly to get elements in the right order.
+    Attribute index = IntegerAttr::get(i32, i);
+    Type elemType = types->at(index);
+    uint64_t typeSize = dataLayout.getTypeSize(elemType);
+
+    if (!packed)
+      covered =
+          llvm::alignTo(covered, dataLayout.getTypeABIAlignment(elemType));
+
+    if (covered >= memsetLen)
+      break;
+
+    // If this subslot is used, apply a new memset to it.
+    // Otherwise, only compute its offset within the original memset.
+    if (subslots.contains(index)) {
+      uint64_t newMemsetSize = std::min(memsetLen - covered, typeSize);
+
+      Value newMemsetSizeValue =
+          rewriter
+              .create<LLVM::ConstantOp>(
+                  getLen().getLoc(),
+                  IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
+              .getResult();
+
+      rewriter.create<LLVM::MemsetOp>(getLoc(), subslots.at(index).ptr,
+                                      getVal(), newMemsetSizeValue,
+                                      getIsVolatile());
+    }
+
+    covered += typeSize;
+  }
+
+  return DeletionKind::Delete;
+}
+
 //===----------------------------------------------------------------------===//
 // Interfaces for destructurable types
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index bd15d6455f5ab..93ec2bcdf58fa 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
 
 using namespace mlir;
 
@@ -160,7 +161,12 @@ bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
   return getMemRef() == slot.ptr;
 }
 
-Value memref::LoadOp::getStored(const MemorySlot &slot) { return {}; }
+bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
+
+Value memref::LoadOp::getStored(const MemorySlot &slot,
+                                RewriterBase &rewriter) {
+  llvm_unreachable("getStored should not be called on LoadOp");
+}
 
 bool memref::LoadOp::canUsesBeRemoved(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
@@ -222,9 +228,12 @@ DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
 
 bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
 
-Value memref::StoreOp::getStored(const MemorySlot &slot) {
-  if (getMemRef() != slot.ptr)
-    return {};
+bool memref::StoreOp::storesTo(const MemorySlot &slot) {
+  return getMemRef() == slot.ptr;
+}
+
+Value memref::StoreOp::getStored(const MemorySlot &slot,
+                                 RewriterBase &rewriter) {
   return getValue();
 }
 

diff  --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 5152d10fbf6b4..65de25dd2f326 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -172,12 +172,13 @@ class MemorySlotPromoter {
   /// Computes the reaching definition for all the operations that require
   /// promotion. `reachingDef` is the value the slot should contain at the
   /// beginning of the block. This method returns the reached definition at the
-  /// end of the block.
+  /// end of the block. This method must be called at most once per block.
   Value computeReachingDefInBlock(Block *block, Value reachingDef);
 
   /// Computes the reaching definition for all the operations that require
   /// promotion. `reachingDef` corresponds to the initial value the
   /// slot will contain before any write, typically a poison value.
+  /// This method must be called at most once per region.
   void computeReachingDefInRegion(Region *region, Value reachingDef);
 
   /// Removes the blocking uses of the slot, in topological order.
@@ -326,7 +327,7 @@ SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn(
 
         // If we store to the slot, further loads will see that value.
         // Because we did not meet any load before, the value is not live-in.
-        if (memOp.getStored(slot))
+        if (memOp.storesTo(slot))
           break;
       }
     }
@@ -365,7 +366,7 @@ void MemorySlotPromotionAnalyzer::computeMergePoints(
   SmallPtrSet<Block *, 16> definingBlocks;
   for (Operation *user : slot.ptr.getUsers())
     if (auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
-      if (storeOp.getStored(slot))
+      if (storeOp.storesTo(slot))
         definingBlocks.insert(user->getBlock());
 
   idfCalculator.setDefiningBlocks(definingBlocks);
@@ -416,13 +417,21 @@ MemorySlotPromotionAnalyzer::computeInfo() {
 
 Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
                                                     Value reachingDef) {
-  for (Operation &op : block->getOperations()) {
+  SmallVector<Operation *> blockOps;
+  for (Operation &op : block->getOperations())
+    blockOps.push_back(&op);
+  for (Operation *op : blockOps) {
     if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
       if (info.userToBlockingUses.contains(memOp))
         reachingDefs.insert({memOp, reachingDef});
 
-      if (Value stored = memOp.getStored(slot))
+      if (memOp.storesTo(slot)) {
+        rewriter.setInsertionPointAfter(memOp);
+        Value stored = memOp.getStored(slot, rewriter);
+        assert(stored && "a memory operation storing to a slot must provide a "
+                         "new definition of the slot");
         reachingDef = stored;
+      }
     }
   }
 

diff  --git a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
new file mode 100644
index 0000000000000..e0b0d297e7470
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
@@ -0,0 +1,145 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg{region-simplify=false}))" --split-input-file | FileCheck %s
+
+// CHECK-LABEL: llvm.func @basic_memset
+llvm.func @basic_memset() -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  %memset_len = llvm.mlir.constant(4 : i32) : i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  // CHECK-NOT: "llvm.intr.memset"
+  // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+  // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+  // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+  // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+  // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+  // CHECK-NOT: "llvm.intr.memset"
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK: llvm.return %[[VALUE_32]] : i32
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @allow_dynamic_value_memset
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @allow_dynamic_value_memset(%memset_value: i8) -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_len = llvm.mlir.constant(4 : i32) : i32
+  // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  // CHECK-NOT: "llvm.intr.memset"
+  // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+  // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+  // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+  // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+  // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+  // CHECK-NOT: "llvm.intr.memset"
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK: llvm.return %[[VALUE_32]] : i32
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @exotic_target_memset
+llvm.func @exotic_target_memset() -> i40 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  %memset_len = llvm.mlir.constant(5 : i32) : i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
+  // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
+  // CHECK-DAG: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  // CHECK-NOT: "llvm.intr.memset"
+  // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i40
+  // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+  // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+  // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+  // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+  // CHECK: %[[SHIFTED_COMPL:.*]] = llvm.shl %[[VALUE_32]], %[[C32]]
+  // CHECK: %[[VALUE_COMPL:.*]] = llvm.or %[[VALUE_32]], %[[SHIFTED_COMPL]]
+  // CHECK-NOT: "llvm.intr.memset"
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i40
+  // CHECK: llvm.return %[[VALUE_COMPL]] : i40
+  llvm.return %2 : i40
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @no_volatile_memset
+llvm.func @no_volatile_memset() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  %memset_len = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = true}>
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = true}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @no_partial_memset
+llvm.func @no_partial_memset() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(2 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  %memset_len = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @no_overflowing_memset
+llvm.func @no_overflowing_memset() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(6 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  %memset_len = llvm.mlir.constant(6 : i32) : i32
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @only_byte_aligned_integers_memset
+llvm.func @only_byte_aligned_integers_memset() -> i10 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i10
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(2 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i10 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  %memset_len = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i10
+  llvm.return %2 : i10
+}

diff  --git a/mlir/test/Dialect/LLVMIR/sroa-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/sroa-intrinsics.mlir
new file mode 100644
index 0000000000000..4b4905f67dc00
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/sroa-intrinsics.mlir
@@ -0,0 +1,237 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(sroa))" --split-input-file | FileCheck %s
+
+// CHECK-LABEL: llvm.func @memset
+llvm.func @memset() -> i32 {  // A memset over the start of an array is narrowed to the single element that is read.
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, only one i32 is actually used, so only 4 bytes are expected to be set.
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // 16 bytes means the memset spans the first 4 i32 entries, including the one loaded below.
+  %memset_len = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memset_partial
+llvm.func @memset_partial() -> i32 {  // The memset ends in the middle of the element that is read.
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, only the second i32 is actually used. As the memset writes only
+  // the first half of it, only 2 bytes are expected to be set.
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(2 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // 6 bytes means the memset spans the first i32 and half of the second i32.
+  %memset_len = llvm.mlir.constant(6 : i32) : i32
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memset_full
+llvm.func @memset_full() -> i32 {  // The memset covers exactly the whole array; only the element that is read keeps one.
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, only one i32 is actually used, so only 4 bytes are expected to be set.
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // 40 bytes means the memset spans the entire array (10 x 4 bytes).
+  %memset_len = llvm.mlir.constant(40 : i32) : i32
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memset_too_much
+llvm.func @memset_too_much() -> i32 {  // Negative case: the memset is longer than the array, so the alloca is expected to stay whole.
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x !llvm.array<10 x i32>
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(41 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // 41 bytes means the memset spans the entire 40-byte array plus one byte past its end.
+  %memset_len = llvm.mlir.constant(41 : i32) : i32
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memset_no_volatile
+llvm.func @memset_no_volatile() -> i32 {  // Negative case: with a volatile memset the alloca is expected to stay whole.
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x !llvm.array<10 x i32>
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(16 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  %memset_len = llvm.mlir.constant(16 : i32) : i32  // Same length as in @memset; only the volatile flag differs.
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = true}>
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = true}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @indirect_memset
+llvm.func @indirect_memset() -> i32 {  // The memset targets one struct field through a GEP and fits inside that field.
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32)> : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // This memset covers exactly the selected 4-byte element and nothing else.
+  %memset_len = llvm.mlir.constant(4 : i32) : i32
+  %2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)>
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  "llvm.intr.memset"(%2, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @invalid_indirect_memset
+llvm.func @invalid_indirect_memset() -> i32 {  // Negative case: the memset through the GEP overruns the field it points to.
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x !llvm.struct<"foo", (i32, i32)>
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(6 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32)> : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // This memset goes 2 bytes beyond the selected 4-byte element.
+  %memset_len = llvm.mlir.constant(6 : i32) : i32
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0]
+  %2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)>
+  // CHECK: "llvm.intr.memset"(%[[GEP]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  "llvm.intr.memset"(%2, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memset_double_use
+llvm.func @memset_double_use() -> i32 {  // One memset covering two used fields is expected to become one memset per field.
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA_INT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[ALLOCA_FLOAT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x f32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, each field gets its own 4-byte slot, so each memset sets only 4 bytes.
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // 8 bytes means the memset spans both 4-byte fields (the i32 and the f32).
+  %memset_len = llvm.mlir.constant(8 : i32) : i32
+  // We expect exactly two generated memsets, one for each field.
+  // CHECK-NOT: "llvm.intr.memset"
+  // CHECK-DAG: "llvm.intr.memset"(%[[ALLOCA_INT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  // CHECK-DAG: "llvm.intr.memset"(%[[ALLOCA_FLOAT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  // CHECK-NOT: "llvm.intr.memset"
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f32)>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  %4 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f32)>
+  %5 = llvm.load %4 : !llvm.ptr -> f32
+  // This bitcast only exists so that the f32 has a use; its semantics do not matter here.
+  %6 = llvm.bitcast %5 : f32 to i32
+  %7 = llvm.add %3, %6 : i32
+  llvm.return %7 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memset_considers_alignment
+llvm.func @memset_considers_alignment() -> i32 {  // Padding after the i8 keeps the 8-byte memset from reaching the f32.
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA_INT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, only 32-bit values are actually used, so only 4 bytes are expected to be set.
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i8, i32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // 8 bytes means the memset spans the i8, its padding, and the i32 entry.
+  // Because of that padding, the f32 entry will not be touched.
+  %memset_len = llvm.mlir.constant(8 : i32) : i32
+  // Even though both the i32 and the f32 are used, only one memset should be
+  // generated, as the f32 is not touched by the initial memset.
+  // CHECK-NOT: "llvm.intr.memset"
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA_INT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  // CHECK-NOT: "llvm.intr.memset"
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i8, i32, f32)>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  %4 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i8, i32, f32)>
+  %5 = llvm.load %4 : !llvm.ptr -> f32
+  // This bitcast only exists so that the f32 has a use; its semantics do not matter here.
+  %6 = llvm.bitcast %5 : f32 to i32
+  %7 = llvm.add %3, %6 : i32
+  llvm.return %7 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memset_considers_packing
+llvm.func @memset_considers_packing() -> i32 {  // Packed counterpart of @memset_considers_alignment: without padding the memset reaches the f32.
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA_INT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[ALLOCA_FLOAT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x f32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, only 32-bit values are actually used, so at most 4 bytes are set per slot.
+  // CHECK-DAG: %[[MEMSET_LEN_WHOLE:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK-DAG: %[[MEMSET_LEN_PARTIAL:.*]] = llvm.mlir.constant(3 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", packed (i8, i32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // 8 bytes means the memset reaches all the fields, because there is no padding as the struct is packed.
+  %memset_len = llvm.mlir.constant(8 : i32) : i32
+  // Now all fields are touched by the memset: the i32 entirely, the f32 only on its first 3 bytes.
+  // CHECK-NOT: "llvm.intr.memset"
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA_INT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN_WHOLE]]) <{isVolatile = false}>
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA_FLOAT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN_PARTIAL]]) <{isVolatile = false}>
+  // CHECK-NOT: "llvm.intr.memset"
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i8, i32, f32)>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  %4 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i8, i32, f32)>
+  %5 = llvm.load %4 : !llvm.ptr -> f32
+  // This bitcast only exists so that the f32 has a use; its semantics do not matter here.
+  %6 = llvm.bitcast %5 : f32 to i32
+  %7 = llvm.add %3, %6 : i32
+  llvm.return %7 : i32
+}


        


More information about the Mlir-commits mailing list