[Mlir-commits] [mlir] 30753af - [mlir][llvm] Add support for memset.inline (#115711)

llvmlistbot at llvm.org
Tue Nov 12 11:17:53 PST 2024


Author: PikachuHy
Date: 2024-11-12T20:17:50+01:00
New Revision: 30753afc2a3171e962e261622781852a01fbec72

URL: https://github.com/llvm/llvm-project/commit/30753afc2a3171e962e261622781852a01fbec72
DIFF: https://github.com/llvm/llvm-project/commit/30753afc2a3171e962e261622781852a01fbec72.diff

LOG: [mlir][llvm] Add support for memset.inline (#115711)

Support `llvm.intr.memset.inline` in the llvm-project repo before we add
support for `__builtin_memset_inline` in ClangIR.

cc @bcardosolopes
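
For illustration (not part of the commit): in MLIR assembly, the new op
carries its length as an attribute rather than as an SSA operand. With
%dst and %val standing in for an !llvm.ptr and an i8 value:

  "llvm.intr.memset.inline"(%dst, %val) <{isVolatile = false, len = 16 : i64}> : (!llvm.ptr, i8) -> ()

This corresponds to the LLVM IR intrinsic call:

  call void @llvm.memset.inline.p0.i64(ptr %dst, i8 %val, i64 16, i1 false)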

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
    mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
    mlir/test/Dialect/LLVMIR/sroa-intrinsics.mlir
    mlir/test/Target/LLVMIR/Import/intrinsic.ll
    mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index d07ebbacc60434..85785938405859 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -256,6 +256,32 @@ def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2],
   ];
 }
 
+def LLVM_MemsetInlineOp : LLVM_ZeroResultIntrOp<"memset.inline", [0, 2],
+    [DeclareOpInterfaceMethods<PromotableMemOpInterface>,
+     DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
+     DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
+    /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
+    /*requiresOpBundles=*/0, /*immArgPositions=*/[2, 3],
+    /*immArgAttrNames=*/["len", "isVolatile"]> {
+  dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
+                  I8:$val, APIntAttr:$len, I1Attr:$isVolatile);
+  // Append the alias attributes defined by LLVM_IntrOpBase.
+  let arguments = !con(args, aliasAttrs);
+  let builders = [
+    OpBuilder<(ins "Value":$dst, "Value":$val, "IntegerAttr":$len,
+                    "bool":$isVolatile), [{
+      build($_builder, $_state, dst, val, len,
+            $_builder.getBoolAttr(isVolatile));
+    }]>,
+    OpBuilder<(ins "Value":$dst, "Value":$val, "IntegerAttr":$len,
+                    "IntegerAttr":$isVolatile), [{
+      build($_builder, $_state, dst, val, len, isVolatile,
+            /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
+            /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+    }]>
+  ];
+}
+
 def LLVM_NoAliasScopeDeclOp
     : LLVM_ZeroResultIntrOp<"experimental.noalias.scope.decl"> {
   let arguments = (ins LLVM_AliasScopeAttr:$scope);

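A rough sketch of using the builders declared above from C++ (illustrative
only, not code from this commit; assumes a Location `loc` and Values
`dst`/`val` of type !llvm.ptr and i8 are in scope at the insertion point):

  // Create a non-volatile 16-byte memset.inline via the bool-taking builder.
  builder.create<LLVM::MemsetInlineOp>(
      loc, dst, val, builder.getI64IntegerAttr(16), /*isVolatile=*/false);
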
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
index cff16afc73af3f..a59900745d026e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
@@ -94,6 +94,10 @@ SmallVector<Value> mlir::LLVM::MemsetOp::getAccessedOperands() {
   return {getDst()};
 }
 
+SmallVector<Value> mlir::LLVM::MemsetInlineOp::getAccessedOperands() {
+  return {getDst()};
+}
+
 SmallVector<Value> mlir::LLVM::CallOp::getAccessedOperands() {
   return llvm::to_vector(
       llvm::make_filter_range(getArgOperands(), [](Value arg) {

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 5dc506c14ef96f..8b164cf59a076d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -904,6 +904,76 @@ std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemcpyInlineOp op) {
   return memIntrLen.getZExtValue();
 }
 
+/// Returns the length of the given memory intrinsic in bytes if it can be known
+/// at compile-time on a best-effort basis, nothing otherwise.
+/// Because MemsetInlineOp has its length encoded as an attribute, this requires
+/// specialized handling.
+template <>
+std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemsetInlineOp op) {
+  APInt memIntrLen = op.getLen();
+  if (memIntrLen.getBitWidth() > 64)
+    return {};
+  return memIntrLen.getZExtValue();
+}
+
+/// Returns an integer attribute representing the length of a memset intrinsic.
+template <class MemsetIntr>
+IntegerAttr createMemsetLenAttr(MemsetIntr op) {
+  IntegerAttr memsetLenAttr;
+  bool successfulMatch =
+      matchPattern(op.getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
+  (void)successfulMatch;
+  assert(successfulMatch);
+  return memsetLenAttr;
+}
+
+/// Returns an integer attribute representing the length of a memset intrinsic.
+/// Because MemsetInlineOp has its length encoded as an attribute, this requires
+/// specialized handling.
+template <>
+IntegerAttr createMemsetLenAttr(LLVM::MemsetInlineOp op) {
+  return op.getLenAttr();
+}
+
+/// Creates a memset intrinsic that matches the `toReplace` intrinsic
+/// using the provided parameters. There are template specializations for
+/// MemsetOp and MemsetInlineOp.
+template <class MemsetIntr>
+void createMemsetIntr(OpBuilder &builder, MemsetIntr toReplace,
+                      IntegerAttr memsetLenAttr, uint64_t newMemsetSize,
+                      DenseMap<Attribute, MemorySlot> &subslots,
+                      Attribute index);
+
+template <>
+void createMemsetIntr(OpBuilder &builder, LLVM::MemsetOp toReplace,
+                      IntegerAttr memsetLenAttr, uint64_t newMemsetSize,
+                      DenseMap<Attribute, MemorySlot> &subslots,
+                      Attribute index) {
+  Value newMemsetSizeValue =
+      builder
+          .create<LLVM::ConstantOp>(
+              toReplace.getLen().getLoc(),
+              IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
+          .getResult();
+
+  builder.create<LLVM::MemsetOp>(toReplace.getLoc(), subslots.at(index).ptr,
+                                 toReplace.getVal(), newMemsetSizeValue,
+                                 toReplace.getIsVolatile());
+}
+
+template <>
+void createMemsetIntr(OpBuilder &builder, LLVM::MemsetInlineOp toReplace,
+                      IntegerAttr memsetLenAttr, uint64_t newMemsetSize,
+                      DenseMap<Attribute, MemorySlot> &subslots,
+                      Attribute index) {
+  auto newMemsetSizeValue =
+      IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize);
+
+  builder.create<LLVM::MemsetInlineOp>(
+      toReplace.getLoc(), subslots.at(index).ptr, toReplace.getVal(),
+      newMemsetSizeValue, toReplace.getIsVolatile());
+}
+
 } // namespace
 
 /// Returns whether one can be sure the memory intrinsic does not write outside
@@ -931,23 +1001,37 @@ static bool areAllIndicesI32(const DestructurableMemorySlot &slot) {
 }
 
 //===----------------------------------------------------------------------===//
-// Interfaces for memset
+// Interfaces for memset and memset.inline
 //===----------------------------------------------------------------------===//
 
-bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }
+template <class MemsetIntr>
+static bool memsetCanRewire(MemsetIntr op, const DestructurableMemorySlot &slot,
+                            SmallPtrSetImpl<Attribute> &usedIndices,
+                            SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+                            const DataLayout &dataLayout) {
+  if (&slot.elemType.getDialect() != op.getOperation()->getDialect())
+    return false;
 
-bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
-  return getDst() == slot.ptr;
+  if (op.getIsVolatile())
+    return false;
+
+  if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
+    return false;
+
+  if (!areAllIndicesI32(slot))
+    return false;
+
+  return definitelyWritesOnlyWithinSlot(op, slot, dataLayout);
 }
 
-Value LLVM::MemsetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
-                                Value reachingDef,
-                                const DataLayout &dataLayout) {
+template <class MemsetIntr>
+static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot,
+                             OpBuilder &builder) {
   // TODO: Support non-integer types.
   return TypeSwitch<Type, Value>(slot.elemType)
       .Case([&](IntegerType intType) -> Value {
         if (intType.getWidth() == 8)
-          return getVal();
+          return op.getVal();
 
         assert(intType.getWidth() % 8 == 0);
 
@@ -955,14 +1039,14 @@ Value LLVM::MemsetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
         // or-ing it with the previous value.
         uint64_t coveredBits = 8;
         Value currentValue =
-            builder.create<LLVM::ZExtOp>(getLoc(), intType, getVal());
+            builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
         while (coveredBits < intType.getWidth()) {
-          Value shiftBy =
-              builder.create<LLVM::ConstantOp>(getLoc(), intType, coveredBits);
+          Value shiftBy = builder.create<LLVM::ConstantOp>(op.getLoc(), intType,
+                                                           coveredBits);
           Value shifted =
-              builder.create<LLVM::ShlOp>(getLoc(), currentValue, shiftBy);
+              builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
           currentValue =
-              builder.create<LLVM::OrOp>(getLoc(), currentValue, shifted);
+              builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
           coveredBits *= 2;
         }
 
@@ -974,10 +1058,12 @@ Value LLVM::MemsetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
       });
 }
 
-bool LLVM::MemsetOp::canUsesBeRemoved(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses,
-    const DataLayout &dataLayout) {
+template <class MemsetIntr>
+static bool
+memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,
+                       const SmallPtrSetImpl<OpOperand *> &blockingUses,
+                       SmallVectorImpl<OpOperand *> &newBlockingUses,
+                       const DataLayout &dataLayout) {
   // TODO: Support non-integer types.
   bool canConvertType =
       TypeSwitch<Type, bool>(slot.elemType)
@@ -988,62 +1074,28 @@ bool LLVM::MemsetOp::canUsesBeRemoved(
   if (!canConvertType)
     return false;
 
-  if (getIsVolatile())
+  if (op.getIsVolatile())
     return false;
 
-  return getStaticMemIntrLen(*this) == dataLayout.getTypeSize(slot.elemType);
-}
-
-DeletionKind LLVM::MemsetOp::removeBlockingUses(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    OpBuilder &builder, Value reachingDefinition,
-    const DataLayout &dataLayout) {
-  return DeletionKind::Delete;
-}
-
-LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
-    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
-    const DataLayout &dataLayout) {
-  return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
+  return getStaticMemIntrLen(op) == dataLayout.getTypeSize(slot.elemType);
 }
 
-bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
-                               SmallPtrSetImpl<Attribute> &usedIndices,
-                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
-                               const DataLayout &dataLayout) {
-  if (&slot.elemType.getDialect() != getOperation()->getDialect())
-    return false;
-
-  if (getIsVolatile())
-    return false;
-
-  if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
-    return false;
-
-  if (!areAllIndicesI32(slot))
-    return false;
-
-  return definitelyWritesOnlyWithinSlot(*this, slot, dataLayout);
-}
+template <class MemsetIntr>
+static DeletionKind
+memsetRewire(MemsetIntr op, const DestructurableMemorySlot &slot,
+             DenseMap<Attribute, MemorySlot> &subslots, OpBuilder &builder,
+             const DataLayout &dataLayout) {
 
-DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
-                                    DenseMap<Attribute, MemorySlot> &subslots,
-                                    OpBuilder &builder,
-                                    const DataLayout &dataLayout) {
   std::optional<DenseMap<Attribute, Type>> types =
       cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();
 
-  IntegerAttr memsetLenAttr;
-  bool successfulMatch =
-      matchPattern(getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
-  (void)successfulMatch;
-  assert(successfulMatch);
+  IntegerAttr memsetLenAttr = createMemsetLenAttr(op);
 
   bool packed = false;
   if (auto structType = dyn_cast<LLVM::LLVMStructType>(slot.elemType))
     packed = structType.isPacked();
 
-  Type i32 = IntegerType::get(getContext(), 32);
+  Type i32 = IntegerType::get(op.getContext(), 32);
   uint64_t memsetLen = memsetLenAttr.getValue().getZExtValue();
   uint64_t covered = 0;
   for (size_t i = 0; i < types->size(); i++) {
@@ -1063,16 +1115,8 @@ DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
     // Otherwise, only compute its offset within the original memset.
     if (subslots.contains(index)) {
       uint64_t newMemsetSize = std::min(memsetLen - covered, typeSize);
-
-      Value newMemsetSizeValue =
-          builder
-              .create<LLVM::ConstantOp>(
-                  getLen().getLoc(),
-                  IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
-              .getResult();
-
-      builder.create<LLVM::MemsetOp>(getLoc(), subslots.at(index).ptr, getVal(),
-                                     newMemsetSizeValue, getIsVolatile());
+      createMemsetIntr(builder, op, memsetLenAttr, newMemsetSize, subslots,
+                       index);
     }
 
     covered += typeSize;
@@ -1081,6 +1125,103 @@ DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
   return DeletionKind::Delete;
 }
 
+bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }
+
+bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
+  return getDst() == slot.ptr;
+}
+
+Value LLVM::MemsetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
+                                Value reachingDef,
+                                const DataLayout &dataLayout) {
+  return memsetGetStored(*this, slot, builder);
+}
+
+bool LLVM::MemsetOp::canUsesBeRemoved(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    SmallVectorImpl<OpOperand *> &newBlockingUses,
+    const DataLayout &dataLayout) {
+  return memsetCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
+                                dataLayout);
+}
+
+DeletionKind LLVM::MemsetOp::removeBlockingUses(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    OpBuilder &builder, Value reachingDefinition,
+    const DataLayout &dataLayout) {
+  return DeletionKind::Delete;
+}
+
+LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
+    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+    const DataLayout &dataLayout) {
+  return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
+}
+
+bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
+                               SmallPtrSetImpl<Attribute> &usedIndices,
+                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+                               const DataLayout &dataLayout) {
+  return memsetCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
+                         dataLayout);
+}
+
+DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
+                                    DenseMap<Attribute, MemorySlot> &subslots,
+                                    OpBuilder &builder,
+                                    const DataLayout &dataLayout) {
+  return memsetRewire(*this, slot, subslots, builder, dataLayout);
+}
+
+bool LLVM::MemsetInlineOp::loadsFrom(const MemorySlot &slot) { return false; }
+
+bool LLVM::MemsetInlineOp::storesTo(const MemorySlot &slot) {
+  return getDst() == slot.ptr;
+}
+
+Value LLVM::MemsetInlineOp::getStored(const MemorySlot &slot,
+                                      OpBuilder &builder, Value reachingDef,
+                                      const DataLayout &dataLayout) {
+  return memsetGetStored(*this, slot, builder);
+}
+
+bool LLVM::MemsetInlineOp::canUsesBeRemoved(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    SmallVectorImpl<OpOperand *> &newBlockingUses,
+    const DataLayout &dataLayout) {
+  return memsetCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
+                                dataLayout);
+}
+
+DeletionKind LLVM::MemsetInlineOp::removeBlockingUses(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    OpBuilder &builder, Value reachingDefinition,
+    const DataLayout &dataLayout) {
+  return DeletionKind::Delete;
+}
+
+LogicalResult LLVM::MemsetInlineOp::ensureOnlySafeAccesses(
+    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+    const DataLayout &dataLayout) {
+  return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
+}
+
+bool LLVM::MemsetInlineOp::canRewire(
+    const DestructurableMemorySlot &slot,
+    SmallPtrSetImpl<Attribute> &usedIndices,
+    SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+    const DataLayout &dataLayout) {
+  return memsetCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
+                         dataLayout);
+}
+
+DeletionKind
+LLVM::MemsetInlineOp::rewire(const DestructurableMemorySlot &slot,
+                             DenseMap<Attribute, MemorySlot> &subslots,
+                             OpBuilder &builder, const DataLayout &dataLayout) {
+  return memsetRewire(*this, slot, subslots, builder, dataLayout);
+}
+
 //===----------------------------------------------------------------------===//
 // Interfaces for memcpy/memmove
 //===----------------------------------------------------------------------===//

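The template specializations above are needed because the two ops expose
their length differently: `llvm.intr.memset` takes it as an SSA operand
that has to be pattern-matched to a constant, whereas
`llvm.intr.memset.inline` stores it directly in its `len` attribute. Side
by side, with illustrative operands:

  %len = llvm.mlir.constant(16 : i64) : i64
  "llvm.intr.memset"(%dst, %val, %len) <{isVolatile = false}> : (!llvm.ptr, i8, i64) -> ()
  "llvm.intr.memset.inline"(%dst, %val) <{isVolatile = false, len = 16 : i64}> : (!llvm.ptr, i8) -> ()
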
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
index 4fc80a87f20df5..646667505a3732 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
@@ -23,6 +23,28 @@ llvm.func @basic_memset(%memset_value: i8) -> i32 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @basic_memset_inline
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @basic_memset_inline(%memset_value: i8) -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 4 : i32}> : (!llvm.ptr, i8) -> ()
+  // CHECK-NOT: "llvm.intr.memset.inline"
+  // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+  // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+  // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+  // CHECK-NOT: "llvm.intr.memset.inline"
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK: llvm.return %[[VALUE_32]] : i32
+  llvm.return %2 : i32
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @basic_memset_constant
 llvm.func @basic_memset_constant() -> i32 {
   %0 = llvm.mlir.constant(1 : i32) : i32
@@ -45,6 +67,27 @@ llvm.func @basic_memset_constant() -> i32 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @basic_memset_inline_constant
+llvm.func @basic_memset_inline_constant() -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 4}> : (!llvm.ptr, i8) -> ()
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32
+  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]]  : i32
+  // CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]]  : i32
+  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]]  : i32
+  // CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]]  : i32
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %2 : i32
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @exotic_target_memset
 // CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
 llvm.func @exotic_target_memset(%memset_value: i8) -> i40 {
@@ -71,6 +114,31 @@ llvm.func @exotic_target_memset(%memset_value: i8) -> i40 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @exotic_target_memset_inline
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @exotic_target_memset_inline(%memset_value: i8) -> i40 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 5}> : (!llvm.ptr, i8) -> ()
+  // CHECK-NOT: "llvm.intr.memset.inline"
+  // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i40
+  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
+  // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+  // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
+  // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+  // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+  // CHECK: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
+  // CHECK: %[[SHIFTED_COMPL:.*]] = llvm.shl %[[VALUE_32]], %[[C32]]
+  // CHECK: %[[VALUE_COMPL:.*]] = llvm.or %[[VALUE_32]], %[[SHIFTED_COMPL]]
+  // CHECK-NOT: "llvm.intr.memset.inline"
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i40
+  // CHECK: llvm.return %[[VALUE_COMPL]] : i40
+  llvm.return %2 : i40
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @no_volatile_memset
 llvm.func @no_volatile_memset() -> i32 {
   // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -89,6 +157,22 @@ llvm.func @no_volatile_memset() -> i32 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @no_volatile_memset_inline
+llvm.func @no_volatile_memset_inline() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // CHECK: "llvm.intr.memset.inline"(%[[ALLOCA]], %[[MEMSET_VALUE]]) <{isVolatile = true, len = 4 : i64}>
+  "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = true, len = 4}> : (!llvm.ptr, i8) -> ()
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  llvm.return %2 : i32
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @no_partial_memset
 llvm.func @no_partial_memset() -> i32 {
   // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -107,6 +191,22 @@ llvm.func @no_partial_memset() -> i32 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @no_partial_memset_inline
+llvm.func @no_partial_memset_inline() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // CHECK: "llvm.intr.memset.inline"(%[[ALLOCA]], %[[MEMSET_VALUE]]) <{isVolatile = false, len = 2 : i64}>
+  "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 2}> : (!llvm.ptr, i8) -> ()
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  llvm.return %2 : i32
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @no_overflowing_memset
 llvm.func @no_overflowing_memset() -> i32 {
   // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -125,6 +225,22 @@ llvm.func @no_overflowing_memset() -> i32 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @no_overflowing_memset_inline
+llvm.func @no_overflowing_memset_inline() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // CHECK: "llvm.intr.memset.inline"(%[[ALLOCA]], %[[MEMSET_VALUE]]) <{isVolatile = false, len = 6 : i64}>
+  "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 6}> : (!llvm.ptr, i8) -> ()
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  llvm.return %2 : i32
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @only_byte_aligned_integers_memset
 llvm.func @only_byte_aligned_integers_memset() -> i10 {
   // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -143,6 +259,22 @@ llvm.func @only_byte_aligned_integers_memset() -> i10 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @only_byte_aligned_integers_memset_inline
+llvm.func @only_byte_aligned_integers_memset_inline() -> i10 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i10
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i10 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // CHECK: "llvm.intr.memset.inline"(%[[ALLOCA]], %[[MEMSET_VALUE]]) <{isVolatile = false, len = 2 : i64}>
+  "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 2}> : (!llvm.ptr, i8) -> ()
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i10
+  llvm.return %2 : i10
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @basic_memcpy
 // CHECK-SAME: (%[[SOURCE:.*]]: !llvm.ptr)
 llvm.func @basic_memcpy(%source: !llvm.ptr) -> i32 {

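The test file's RUN line is outside the hunks shown; files of this shape
are typically driven by an invocation along the lines of (an assumption,
not taken from the diff):

  mlir-opt mem2reg-intrinsics.mlir --split-input-file --pass-pipeline="builtin.module(llvm.func(mem2reg))"
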
diff --git a/mlir/test/Dialect/LLVMIR/sroa-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/sroa-intrinsics.mlir
index ba73025814cc05..6dc8a97884ee19 100644
--- a/mlir/test/Dialect/LLVMIR/sroa-intrinsics.mlir
+++ b/mlir/test/Dialect/LLVMIR/sroa-intrinsics.mlir
@@ -21,6 +21,25 @@ llvm.func @memset() -> i32 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @memset_inline
+llvm.func @memset_inline() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, only one i32 will actually be used, so only 4 bytes will be set.
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // 16 bytes means it will span over the first 4 i32 entries.
+  // CHECK: "llvm.intr.memset.inline"(%[[ALLOCA]], %[[MEMSET_VALUE]]) <{isVolatile = false, len = 4 : i64}>
+  "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 16}> : (!llvm.ptr, i8) -> ()
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @memset_partial
 llvm.func @memset_partial() -> i32 {
   // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -43,6 +62,26 @@ llvm.func @memset_partial() -> i32 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @memset_inline_partial
+llvm.func @memset_inline_partial() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, only the second i32 will actually be used. As the memset writes up
+  // to half of it, only 2 bytes will be set.
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // 6 bytes means it will span over the first i32 and half of the second i32.
+  // CHECK: "llvm.intr.memset.inline"(%[[ALLOCA]], %[[MEMSET_VALUE]]) <{isVolatile = false, len = 2 : i64}>
+  "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 6}> : (!llvm.ptr, i8) -> ()
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @memset_full
 llvm.func @memset_full() -> i32 {
   // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -64,6 +103,25 @@ llvm.func @memset_full() -> i32 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @memset_inline_full
+llvm.func @memset_inline_full() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, only one i32 will actually be used, so only 4 bytes will be set.
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // 40 bytes means it will span over the entire array.
+  // CHECK: "llvm.intr.memset.inline"(%[[ALLOCA]], %[[MEMSET_VALUE]]) <{isVolatile = false, len = 4 : i64}>
+  "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 40}> : (!llvm.ptr, i8) -> ()
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @memset_too_much
 llvm.func @memset_too_much() -> i32 {
   // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -84,6 +142,24 @@ llvm.func @memset_too_much() -> i32 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @memset_inline_too_much
+llvm.func @memset_inline_too_much() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x !llvm.array<10 x i32>
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // 41 bytes means it will span over the entire array, and then some.
+  // CHECK: "llvm.intr.memset.inline"(%[[ALLOCA]], %[[MEMSET_VALUE]]) <{isVolatile = false, len = 41 : i64}>
+  "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 41}> : (!llvm.ptr, i8) -> ()
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @memset_no_volatile
 llvm.func @memset_no_volatile() -> i32 {
   // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -103,6 +179,23 @@ llvm.func @memset_no_volatile() -> i32 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @memset_inline_no_volatile
+llvm.func @memset_inline_no_volatile() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x !llvm.array<10 x i32>
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // CHECK: "llvm.intr.memset.inline"(%[[ALLOCA]], %[[MEMSET_VALUE]]) <{isVolatile = true, len = 16 : i64}>
+  "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = true, len = 16}> : (!llvm.ptr, i8) -> ()
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @indirect_memset
 llvm.func @indirect_memset() -> i32 {
   // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -123,6 +216,24 @@ llvm.func @indirect_memset() -> i32 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @indirect_memset_inline
+llvm.func @indirect_memset_inline() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32)> : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // This memset will only cover the selected element.
+  %2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)>
+  // CHECK: "llvm.intr.memset.inline"(%[[ALLOCA]], %[[MEMSET_VALUE]]) <{isVolatile = false, len = 4 : i64}>
+  "llvm.intr.memset.inline"(%2, %memset_value) <{isVolatile = false, len = 4}> : (!llvm.ptr, i8) -> ()
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @invalid_indirect_memset
 llvm.func @invalid_indirect_memset() -> i32 {
   // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -144,6 +255,25 @@ llvm.func @invalid_indirect_memset() -> i32 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @invalid_indirect_memset_inline
+llvm.func @invalid_indirect_memset_inline() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x !llvm.struct<"foo", (i32, i32)>
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32)> : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // This memset will go slightly beyond one of the elements.
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0]
+  %2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)>
+  // CHECK: "llvm.intr.memset.inline"(%[[GEP]], %[[MEMSET_VALUE]]) <{isVolatile = false, len = 6 : i64}>
+  "llvm.intr.memset.inline"(%2, %memset_value) <{isVolatile = false, len = 6}> : (!llvm.ptr, i8) -> ()
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @memset_double_use
 llvm.func @memset_double_use() -> i32 {
   // CHECK: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -176,6 +306,35 @@ llvm.func @memset_double_use() -> i32 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @memset_inline_double_use
+llvm.func @memset_inline_double_use() -> i32 {
+  // CHECK: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA_FLOAT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x f32
+  // CHECK: %[[ALLOCA_INT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // We expect two generated memsets, one for each field.
+  // CHECK-NOT: "llvm.intr.memset.inline"
+  // After SROA, each field is 4 bytes wide, so each generated memset sets 4 bytes.
+  // CHECK: "llvm.intr.memset.inline"(%[[ALLOCA_INT]], %[[MEMSET_VALUE]]) <{isVolatile = false, len = 4 : i64}>
+  // CHECK: "llvm.intr.memset.inline"(%[[ALLOCA_FLOAT]], %[[MEMSET_VALUE]]) <{isVolatile = false, len = 4 : i64}>
+  // CHECK-NOT: "llvm.intr.memset.inline"
+  // 8 bytes means it will span over both entries.
+  "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 8}> : (!llvm.ptr, i8) -> ()
+  %2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f32)>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  %4 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f32)>
+  %5 = llvm.load %4 : !llvm.ptr -> f32
+  // We use this exotic bitcast to use the f32 easily. Semantics do not matter here.
+  %6 = llvm.bitcast %5 : f32 to i32
+  %7 = llvm.add %3, %6 : i32
+  llvm.return %7 : i32
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @memset_considers_alignment
 llvm.func @memset_considers_alignment() -> i32 {
   // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -207,6 +366,35 @@ llvm.func @memset_considers_alignment() -> i32 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @memset_inline_considers_alignment
+llvm.func @memset_inline_considers_alignment() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA_INT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, only 32-bit values will actually be used, so only 4 bytes will be set.
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i8, i32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // 8 bytes means it will span over the i8 and the i32 entry.
+  // Because of padding, the f32 entry will not be touched.
+  // Even though both the i32 and the f32 are used, only one memset should be
+  // generated, as the f32 is not touched by the initial memset.
+  // CHECK-NOT: "llvm.intr.memset.inline"
+  // CHECK: "llvm.intr.memset.inline"(%[[ALLOCA_INT]], %[[MEMSET_VALUE]]) <{isVolatile = false, len = 4 : i64}>
+  // CHECK-NOT: "llvm.intr.memset.inline"
+  "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 8}> : (!llvm.ptr, i8) -> ()
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i8, i32, f32)>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  %4 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i8, i32, f32)>
+  %5 = llvm.load %4 : !llvm.ptr -> f32
+  // We use this exotic bitcast to use the f32 easily. Semantics do not matter here.
+  %6 = llvm.bitcast %5 : f32 to i32
+  %7 = llvm.add %3, %6 : i32
+  llvm.return %7 : i32
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @memset_considers_packing
 llvm.func @memset_considers_packing() -> i32 {
   // CHECK: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -239,6 +427,35 @@ llvm.func @memset_considers_packing() -> i32 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @memset_inline_considers_packing
+llvm.func @memset_inline_considers_packing() -> i32 {
+  // CHECK: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA_FLOAT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x f32
+  // CHECK: %[[ALLOCA_INT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", packed (i8, i32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // Now all fields are touched by the memset.
+  // CHECK-NOT: "llvm.intr.memset.inline"
+  // After SROA, only 32-bit values will actually be used, so each generated memset covers at most 4 bytes.
+  // CHECK: "llvm.intr.memset.inline"(%[[ALLOCA_INT]], %[[MEMSET_VALUE]]) <{isVolatile = false, len = 4 : i64}>
+  // CHECK: "llvm.intr.memset.inline"(%[[ALLOCA_FLOAT]], %[[MEMSET_VALUE]]) <{isVolatile = false, len = 3 : i64}>
+  // CHECK-NOT: "llvm.intr.memset.inline"
+  // 8 bytes means it will span over all the fields, because there is no padding as the struct is packed.
+  "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 8}> : (!llvm.ptr, i8) -> ()
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i8, i32, f32)>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  %4 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i8, i32, f32)>
+  %5 = llvm.load %4 : !llvm.ptr -> f32
+  // We use this exotic bitcast to use the f32 easily. Semantics do not matter here.
+  %6 = llvm.bitcast %5 : f32 to i32
+  %7 = llvm.add %3, %6 : i32
+  llvm.return %7 : i32
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @memcpy_dest
 // CHECK-SAME: (%[[OTHER_ARRAY:.*]]: !llvm.ptr)
 llvm.func @memcpy_dest(%other_array: !llvm.ptr) -> i32 {

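To make the expected SROA rewiring concrete, the first new case above boils
down to the following before/after, paraphrased from the CHECK lines:

  // Before: one memset.inline covering 16 bytes of an array<10 x i32> slot.
  "llvm.intr.memset.inline"(%slot, %v) <{isVolatile = false, len = 16}> : (!llvm.ptr, i8) -> ()
  // After: only a single i32 subslot is live, so the length shrinks to 4.
  "llvm.intr.memset.inline"(%subslot, %v) <{isVolatile = false, len = 4 : i64}> : (!llvm.ptr, i8) -> ()
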
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index 606b11175f572f..e857e252ff0839 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -505,6 +505,10 @@ define void @memmove_test(i32 %0, ptr %1, ptr %2) {
 define void @memset_test(i32 %0, ptr %1, i8 %2) {
   ; CHECK: "llvm.intr.memset"(%{{.*}}, %{{.*}}, %{{.*}}) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
   call void @llvm.memset.p0.i32(ptr %1, i8 %2, i32 %0, i1 false)
+  ; CHECK: "llvm.intr.memset.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i64}> : (!llvm.ptr, i8) -> ()
+  call void @llvm.memset.inline.p0.i64(ptr %1, i8 %2, i64 10, i1 false)
+  ; CHECK: "llvm.intr.memset.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i32}> : (!llvm.ptr, i8) -> ()
+  call void @llvm.memset.inline.p0.i32(ptr %1, i8 %2, i32 10, i1 false)
   ret void
 }
 

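The import direction can be exercised on its own with mlir-translate; a
minimal invocation over a standalone .ll file containing the calls above
would presumably be:

  mlir-translate -import-llvm intrinsic.ll
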
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index cb712eb4e1262d..9d45f219cf746e 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -533,6 +533,10 @@ llvm.func @memset_test(%arg0: i32, %arg2: !llvm.ptr, %arg3: i8) {
   %i1 = llvm.mlir.constant(false) : i1
   // CHECK: call void @llvm.memset.p0.i32(ptr %{{.*}}, i8 %{{.*}}, i32 %{{.*}}, i1 false
   "llvm.intr.memset"(%arg2, %arg3, %arg0) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  // CHECK: call void @llvm.memset.inline.p0.i32(ptr %{{.*}}, i8 %{{.*}}, i32 10, i1 true
+  "llvm.intr.memset.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i32}> : (!llvm.ptr, i8) -> ()
+  // CHECK: call void @llvm.memset.inline.p0.i64(ptr %{{.*}}, i8 %{{.*}}, i64 10, i1 true
+  "llvm.intr.memset.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i64}> : (!llvm.ptr, i8) -> ()
   llvm.return
 }
 

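And the export direction, which round-trips the op back into the
llvm.memset.inline.* intrinsic calls checked above, would presumably be:

  mlir-translate -mlir-to-llvmir llvmir-intrinsics.mlir
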

        

