[Mlir-commits] [mlir] [mlir][memref] Introduce `memref.distinct_objects` op (PR #156913)

Ivan Butygin llvmlistbot at llvm.org
Wed Sep 10 08:26:05 PDT 2025


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/156913

>From 313fd900e2f94260af5e856d1bf548dcfc93fe1c Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 4 Sep 2025 17:25:26 +0200
Subject: [PATCH 1/2] [mlir][memref] Introduce `memref.distinct_objects` op

---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       | 39 ++++++++++++++-
 .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp  | 49 +++++++++++++++++--
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 19 +++++++
 .../MemRefToLLVM/memref-to-llvm.mlir          | 19 +++++++
 mlir/test/Dialect/MemRef/ops.mlir             |  9 ++++
 5 files changed, 130 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index d6b7a97179b71..d4c48025c1a07 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -154,7 +154,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
       The `assume_alignment` operation takes a memref and an integer alignment
       value. It returns a new SSA value of the same memref type, but associated
       with the assumption that the underlying buffer is aligned to the given
-      alignment. 
+      alignment.
 
       If the buffer isn't aligned to the given alignment, its result is poison.
       This operation doesn't affect the semantics of a program where the
@@ -169,7 +169,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
   let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
   let extraClassDeclaration = [{
     MemRefType getType() { return ::llvm::cast<MemRefType>(getResult().getType()); }
-    
+
     Value getViewSource() { return getMemref(); }
   }];
 
@@ -177,6 +177,41 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// DistinctObjectsOp
+//===----------------------------------------------------------------------===//
+
+def DistinctObjectsOp : MemRef_Op<"distinct_objects", [
+      Pure,
+      DeclareOpInterfaceMethods<InferTypeOpInterface>
+      // TODO: Add ViewLikeOpInterface once it supports multiple operands.
+    ]> {
+  let summary = "assumption that accesses to specific memrefs will never alias";
+  let description = [{
+      The `distinct_objects` operation takes a list of memrefs and returns a list of
+      memrefs of the same types, with the additional assumption that accesses to
+      these memrefs will never alias with each other. This means that loads and
+      stores to different memrefs in the list can be safely reordered.
+
+      If the memrefs do alias, the behavior is undefined. This operation doesn't
+      affect the semantics of a program where the non-aliasing assumption holds
+      true. It is intended for optimization purposes, allowing the compiler to
+      generate more efficient code based on the non-aliasing assumption. The
+      optimization is best-effort.
+
+      Example:
+
+      ```mlir
+      %1, %2 = memref.distinct_objects %a, %b : memref<?xf32>, memref<?xf32>
+      ```
+  }];
+  let arguments = (ins Variadic<AnyMemRef>:$operands);
+  let results = (outs Variadic<AnyMemRef>:$results);
+
+  let assemblyFormat = "$operands attr-dict `:` type($operands)";
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // AllocOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 262e0e7a30c63..571e5000b3f51 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -465,6 +465,48 @@ struct AssumeAlignmentOpLowering
   }
 };
 
+/// Lowers `memref.distinct_objects` by emitting an LLVM `assume` carrying a
+/// `separate_storage` operand bundle for every unordered pair of operands,
+/// then forwarding the converted memref descriptors unchanged as results.
+struct DistinctObjectsOpLowering
+    : public ConvertOpToLLVMPattern<memref::DistinctObjectsOp> {
+  // The inherited base constructors suffice; no custom construction needed.
+  using ConvertOpToLLVMPattern<
+      memref::DistinctObjectsOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    ValueRange operands = adaptor.getOperands();
+    // Fewer than two operands means there is no pair to annotate: forward the
+    // (possibly empty) list without materializing unused pointer/constant ops.
+    if (operands.size() < 2) {
+      rewriter.replaceOp(op, operands);
+      return success();
+    }
+    Location loc = op.getLoc();
+    SmallVector<Value> ptrs;
+    ptrs.reserve(operands.size());
+    for (auto [operandType, descriptor] :
+         llvm::zip_equal(op.getOperandTypes(), operands))
+      ptrs.push_back(getStridedElementPtr(rewriter, loc,
+                                          cast<MemRefType>(operandType),
+                                          descriptor, /*indices=*/{}));
+
+    auto cond =
+        LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1);
+    // One separate_storage assumption per unordered pair (i, j) with i < j.
+    for (size_t i = 0, e = ptrs.size(); i + 1 < e; ++i)
+      for (size_t j = i + 1; j < e; ++j)
+        LLVM::AssumeOp::create(rewriter, loc, cond,
+                               LLVM::AssumeSeparateStorageTag{}, ptrs[i],
+                               ptrs[j]);
+
+    rewriter.replaceOp(op, operands);
+    return success();
+  }
+};
+
 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
 // The memref descriptor being an SSA value, there is no need to clean it up
 // in any way.
@@ -1997,22 +2039,23 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
   patterns.add<
       AllocaOpLowering,
       AllocaScopeOpLowering,
-      AtomicRMWOpLowering,
       AssumeAlignmentOpLowering,
+      AtomicRMWOpLowering,
       ConvertExtractAlignedPointerAsIndex,
       DimOpLowering,
+      DistinctObjectsOpLowering,
       ExtractStridedMetadataOpLowering,
       GenericAtomicRMWOpLowering,
       GetGlobalMemrefOpLowering,
       LoadOpLowering,
       MemRefCastOpLowering,
-      MemorySpaceCastOpLowering,
       MemRefReinterpretCastOpLowering,
       MemRefReshapeOpLowering,
+      MemorySpaceCastOpLowering,
       PrefetchOpLowering,
       RankOpLowering,
-      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
       ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
+      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
       StoreOpLowering,
       SubViewOpLowering,
       TransposeOpLowering,
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b59d73d1291c8..9a4dec138319d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -542,6 +542,25 @@ OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
   return getMemref();
 }
 
+//===----------------------------------------------------------------------===//
+// DistinctObjectsOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DistinctObjectsOp::verify() {
+  // Results forward the operands one-to-one, so the type lists must agree.
+  if (!llvm::equal(getOperandTypes(), getResultTypes()))
+    return emitOpError("operand types and result types must match");
+  return success();
+}
+
+LogicalResult DistinctObjectsOp::inferReturnTypes(
+    MLIRContext *, std::optional<Location>, ValueRange operands,
+    DictionaryAttr, OpaqueProperties, RegionRange,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  llvm::append_range(inferredReturnTypes, operands.getTypes());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 45b1a1f1ca40c..3eb8df093af10 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -195,6 +195,37 @@ func.func @assume_alignment(%0 : memref<4x4xf16>) {
 
 // -----
 
+// ALL-LABEL: func @distinct_objects
+//  ALL-SAME:   (%[[ARG0:.*]]: memref<?xf16>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf64>)
+func.func @distinct_objects(%arg0: memref<?xf16>, %arg1: memref<?xf32>, %arg2: memref<?xf64>) -> (memref<?xf16>, memref<?xf32>, memref<?xf64>) {
+//   ALL-DAG:   %[[CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?xf16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+//   ALL-DAG:   %[[CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<?xf32> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+//   ALL-DAG:   %[[CAST_2:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : memref<?xf64> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+//       ALL:   %[[PTR_0:.*]] = llvm.extractvalue %[[CAST_0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+//       ALL:   %[[PTR_1:.*]] = llvm.extractvalue %[[CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+//       ALL:   %[[PTR_2:.*]] = llvm.extractvalue %[[CAST_2]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+//       ALL:   %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
+//       ALL:   llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_0]], %[[PTR_1]] : !llvm.ptr, !llvm.ptr)] : i1
+//       ALL:   llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_0]], %[[PTR_2]] : !llvm.ptr, !llvm.ptr)] : i1
+//       ALL:   llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_1]], %[[PTR_2]] : !llvm.ptr, !llvm.ptr)] : i1
+  %1, %2, %3 = memref.distinct_objects %arg0, %arg1, %arg2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
+  return %1, %2, %3 : memref<?xf16>, memref<?xf32>, memref<?xf64>
+}
+
+// -----
+
+// A single operand has nothing to be distinct from, so no assumption is emitted.
+// ALL-LABEL: func @distinct_objects_single
+//  ALL-SAME:   (%[[ARG0:.*]]: memref<?xf32>)
+func.func @distinct_objects_single(%arg0: memref<?xf32>) -> memref<?xf32> {
+//   ALL-NOT:   llvm.intr.assume
+  %0 = memref.distinct_objects %arg0 : memref<?xf32>
+//       ALL:   return
+  return %0 : memref<?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @assume_alignment_w_offset
 // CHECK-INTERFACE-LABEL: func @assume_alignment_w_offset
 func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset: ?>>) {
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 6c2298a3f8acb..a90c9505a8405 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -302,6 +302,24 @@ func.func @assume_alignment(%0: memref<4x4xf16>) {
   return
 }
 
+// CHECK-LABEL: func @distinct_objects
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf16>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf64>)
+func.func @distinct_objects(%arg0: memref<?xf16>, %arg1: memref<?xf32>, %arg2: memref<?xf64>) -> (memref<?xf16>, memref<?xf32>, memref<?xf64>) {
+  // CHECK:  %[[RES:.*]]:3 = memref.distinct_objects %[[ARG0]], %[[ARG1]], %[[ARG2]] : memref<?xf16>, memref<?xf32>, memref<?xf64>
+  %1, %2, %3 = memref.distinct_objects %arg0, %arg1, %arg2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
+  // CHECK:  return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
+  return %1, %2, %3 : memref<?xf16>, memref<?xf32>, memref<?xf64>
+}
+
+// CHECK-LABEL: func @distinct_objects_single
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>)
+func.func @distinct_objects_single(%arg0: memref<?xf32>) -> memref<?xf32> {
+  // CHECK:  %[[RES:.*]] = memref.distinct_objects %[[ARG0]] : memref<?xf32>
+  %0 = memref.distinct_objects %arg0 : memref<?xf32>
+  // CHECK:  return %[[RES]] : memref<?xf32>
+  return %0 : memref<?xf32>
+}
+
 // CHECK-LABEL: func @expand_collapse_shape_static
 func.func @expand_collapse_shape_static(
     %arg0: memref<3x4x5xf32>,

>From a3476bb40166baca3a999b7b79b499fcafca14ae Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 10 Sep 2025 17:25:38 +0200
Subject: [PATCH 2/2] verifier test

---
 mlir/test/Dialect/MemRef/invalid.mlir | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index b4476036d6513..fe3d9e5331a8c 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -1169,3 +1169,11 @@ func.func @expand_shape_invalid_output_shape(
       into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>
   return
 }
+
+// -----
+
+func.func @invalid_distinct_objects(%arg0: memref<?xf32>, %arg1: memref<?xi32>) -> (memref<?xi32>, memref<?xf32>) {
+  // expected-error @+1 {{operand types and result types must match}}
+  %0, %1 = "memref.distinct_objects"(%arg0, %arg1) : (memref<?xf32>, memref<?xi32>) -> (memref<?xi32>, memref<?xf32>)
+  return %0, %1 : memref<?xi32>, memref<?xf32>
+}



More information about the Mlir-commits mailing list