[Mlir-commits] [mlir] [mlir][transform] Add PromoteTensorOp (PR #158318)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 12 09:15:30 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Hendrik_Klug (Jimmy2027)

Changes:

Adds a transform op that requests that a tensor value live in a specific memory space after bufferization.
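
A minimal usage sketch, adapted from the tests added in this PR (the matmul payload and the memory space `1` are illustrative):

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    %mm = transform.structured.match ops{["linalg.matmul"]} in %root
        : (!transform.any_op) -> !transform.any_op
    %op0 = transform.get_operand %mm[0]
        : (!transform.any_op) -> !transform.any_value
    // Returns a handle to the promoted value: the materialization result if
    // one was inserted, the allocation otherwise.
    %promoted = transform.structured.promote_tensor to 1 %op0
        : !transform.any_value
    transform.yield
  }
}
```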



---
Full diff: https://github.com/llvm/llvm-project/pull/158318.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+45-4) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+90-26) 
- (added) mlir/test/Dialect/Transform/test-promote-tensors.mlir (+104) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index a19cce4b919a8..b4c62baad11bf 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -17,6 +17,7 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
 include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/RegionKindInterface.td"
 
@@ -236,11 +237,51 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
                       Transform_AnyOpType:$new_ops);
   let assemblyFormat = "$target attr-dict `:` type($target)";
   let hasVerifier = 1;
+}
 
-  let builders = [
-    OpBuilder<(ins "Value":$target, "Attribute":$memorySpace)>,
-    OpBuilder<(ins "Value":$target, "int64_t":$memorySpace)>
-  ];
+//===----------------------------------------------------------------------===//
+// PromoteTensorOp
+//===----------------------------------------------------------------------===//
+
+def PromoteTensorOp : Op<Transform_Dialect, "structured.promote_tensor",
+                         [DeclareOpInterfaceMethods<TransformOpInterface>,
+                          DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+                          SameOperandsAndResultType]> {
+  let summary = "Request that a tensor value live in a specific memory "
+                "space after bufferization";
+  let description = [{
+    Requests that a tensor value live in a specific memory space for its
+    lifetime. This is achieved by allocating a new tensor in the desired
+    memory space with `bufferization.alloc_tensor` and optionally materializing
+    the source value into that allocation with
+    `bufferization.materialize_in_destination`. All uses of the original value
+    are then redirected to the promoted value.
+
+    The generated code for promoting tensor value %0 resembles the following:
+
+      %1 = bufferization.alloc_tensor(<dynamic dims of %0>)
+           { memory_space = memory_space }
+      // Note: the materialization is omitted if %0 is never read and is only
+      // written into (i.e., it behaves as a result tensor).
+      %2 = bufferization.materialize_in_destination %0 in %1
+      // ...
+      <all users of %0 now use %2 instead>
+
+    Deallocation is not handled by this transform.
+
+    Return modes:
+    - Produces a silenceable failure if any value associated with the given
+      handle is not a ranked tensor.
+    - Succeeds otherwise and returns a handle to the promoted value(s), i.e.,
+      the result of materialization if present and the allocation otherwise.
+  }];
+
+  let arguments = (ins TransformValueHandleTypeInterface:$tensor,
+      OptionalAttr<AnyAttr>:$memory_space);
+  let results = (outs TransformValueHandleTypeInterface:$promoted);
+
+  let assemblyFormat =
+      "(`to` $memory_space^)? $tensor attr-dict `:` type($tensor)";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f0c1f4485b054..017886ef4fcd3 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -41,6 +41,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/LogicalResult.h"
@@ -272,32 +273,6 @@ void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
 
-void transform::BufferizeToAllocationOp::build(OpBuilder &b,
-                                               OperationState &result,
-                                               Value target,
-                                               Attribute memorySpace) {
-  SmallVector<Type> resultTypes;
-  resultTypes.push_back(b.getType<transform::AnyValueType>());
-  resultTypes.push_back(b.getType<transform::AnyOpType>());
-  return build(b, result,
-               /*resultTypes=*/resultTypes,
-               /*target=*/target,
-               /*memorySpace=*/memorySpace);
-}
-
-void transform::BufferizeToAllocationOp::build(OpBuilder &b,
-                                               OperationState &result,
-                                               Value target,
-                                               int64_t memorySpace) {
-  SmallVector<Type> resultTypes;
-  resultTypes.push_back(b.getType<transform::AnyValueType>());
-  resultTypes.push_back(b.getType<transform::AnyOpType>());
-  return build(b, result,
-               /*resultTypes=*/resultTypes,
-               /*target=*/target,
-               /*memorySpace=*/b.getI64IntegerAttr(memorySpace));
-}
-
 namespace {
 class NewOpsListener : public RewriterBase::ForwardingListener {
 public:
@@ -407,6 +382,95 @@ LogicalResult transform::BufferizeToAllocationOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// PromoteTensorOp
+//===----------------------------------------------------------------------===//
+
+/// Return true if the operand may be read by its owner. This is conservative:
+/// only linalg operations are analyzed in depth; all other ops are assumed to
+/// read their operands to prevent unintentional data loss.
+static bool mayBeRead(OpOperand &operand) {
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.getOwner());
+
+  // Be conservative about ops we cannot analyze deeper.
+  if (!linalgOp)
+    return true;
+
+  // Look inside linalg ops.
+  Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
+  return !blockArgument.use_empty();
+}
+
+/// Return true if the value may be read through any of its uses.
+static bool mayBeRead(Value value) {
+  // If the value has reference semantics, it may be read through any alias.
+  if (!isa<TensorType, FloatType, IntegerType>(value.getType()))
+    return true;
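+  // The cast picks the OpOperand overload of mayBeRead for llvm::any_of.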
+  return llvm::any_of(value.getUses(),
+                      static_cast<bool (&)(OpOperand &)>(mayBeRead));
+}
+
+DiagnosedSilenceableFailure
+transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
+                                  transform::TransformResults &results,
+                                  transform::TransformState &state) {
+  SmallVector<Value> promoted;
+  for (Value tensor : state.getPayloadValues(getTensor())) {
+    auto type = dyn_cast<RankedTensorType>(tensor.getType());
+    if (!type) {
+      return emitSilenceableError() << "non-tensor type: " << tensor;
+    }
+
+    Operation *definingOp = tensor.getDefiningOp();
+    if (definingOp)
+      rewriter.setInsertionPointAfter(definingOp);
+    else
+      rewriter.setInsertionPointToStart(cast<BlockArgument>(tensor).getOwner());
+
+    // Check this before creating any new ops: the tensor.dim ops emitted
+    // below would otherwise count as (conservatively assumed) reads.
+    bool needsMaterialization = mayBeRead(tensor);
+
+    SmallVector<Value> dynamicDims;
+    llvm::SmallPtrSet<Operation *, 4> preservedOps;
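+    // Ops created here must keep using the original tensor; track them so
+    // the replaceAllUsesExcept below leaves their operands intact.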
+    for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
+      if (!ShapedType::isDynamic(dim))
+        continue;
+      Value cst = rewriter.create<arith::ConstantIndexOp>(tensor.getLoc(), pos);
+      auto dimOp = rewriter.create<tensor::DimOp>(tensor.getLoc(), tensor, cst);
+      preservedOps.insert(dimOp);
+      dynamicDims.push_back(dimOp);
+    }
+    auto allocation = rewriter.create<bufferization::AllocTensorOp>(
+        tensor.getLoc(), type, dynamicDims);
+    // Set memory space if provided.
+    if (getMemorySpaceAttr())
+      allocation.setMemorySpaceAttr(getMemorySpaceAttr());
+    Value allocated = allocation;
+
+    // Only insert a materialization (which typically bufferizes to a copy)
+    // when the value may be read.
+    if (needsMaterialization) {
+      auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>(
+          tensor.getLoc(), tensor, allocated);
+      preservedOps.insert(copy);
+      promoted.push_back(copy.getResult());
+    } else {
+      promoted.push_back(allocated);
+    }
+    rewriter.replaceAllUsesExcept(tensor, promoted.back(), preservedOps);
+  }
+  results.setValues(cast<OpResult>(getPromoted()), promoted);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::PromoteTensorOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTensorMutable(), effects);
+  transform::producesHandle(getOperation()->getOpResults(), effects);
+  transform::modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // DecomposeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-promote-tensors.mlir b/mlir/test/Dialect/Transform/test-promote-tensors.mlir
new file mode 100644
index 0000000000000..bc9a05af64156
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-promote-tensors.mlir
@@ -0,0 +1,104 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @promote_in0
+// CHECK-SAME:  (%[[ARG0:.+]]: tensor<?x42xf32>, %{{.*}}, %{{.*}})
+// CHECK:  %[[C0:.+]] = arith.constant 0
+// CHECK:  %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK:  %[[ALLOC:.+]] = bufferization.alloc_tensor(%[[DIM]]) {memory_space = 1 : i64}
+// CHECK:  %[[MAT:.+]] = bufferization.materialize_in_destination %[[ARG0]] in %[[ALLOC]]
+// CHECK:  linalg.matmul ins(%[[MAT]], %{{.*}}
+func.func @promote_in0(%arg0: tensor<?x42xf32>, %arg1: tensor<42x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x42xf32>, tensor<42x?xf32>)
+                       outs(%arg2: tensor<?x?xf32>) -> tensor<?x?xf32>
+    return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%root: !transform.any_op) {
+        %mm = transform.structured.match ops{["linalg.matmul"]} in %root
+            : (!transform.any_op) -> !transform.any_op
+        %op0 = transform.get_operand %mm[0]
+            : (!transform.any_op) -> !transform.any_value
+        transform.structured.promote_tensor to 1 %op0 : !transform.any_value
+        transform.yield
+    }
+}
+
+// -----
+
+// CHECK-LABEL: @promote_out
+// CHECK-SAME: (%{{.*}}: tensor<?x42xf32>, %{{.*}}: tensor<?x42xf32>, %[[ARG2:.+]]: tensor<?x?xf32>)
+func.func @promote_out(%arg0: tensor<?x42xf32>, %arg1: tensor<?x42xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    // CHECK:  %[[C0:.+]] = arith.constant 0
+    // CHECK:  %[[DIM0:.+]] = tensor.dim %[[ARG2]], %[[C0]]
+    // CHECK:  %[[C1:.+]] = arith.constant 1
+    // CHECK:  %[[DIM1:.+]] = tensor.dim %[[ARG2]], %[[C1]]
+    // CHECK:  %[[ALLOC:.+]] = bufferization.alloc_tensor(%[[DIM0]], %[[DIM1]]) {memory_space = 1 : i64}
+    // CHECK-NOT: materialize_in_destination
+    // CHECK:  linalg.add {{.*}} outs(%[[ALLOC]]
+    %0 = linalg.add ins(%arg0, %arg1 : tensor<?x42xf32>, tensor<?x42xf32>)
+                    outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+    return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%root: !transform.any_op) {
+        %la = transform.structured.match ops{["linalg.add"]} in %root
+            : (!transform.any_op) -> !transform.any_op
+        %init = transform.get_operand %la[2]
+                : (!transform.any_op) -> !transform.any_value
+        transform.structured.promote_tensor to 1 %init : !transform.any_value
+
+        transform.yield
+    }
+}
+
+// -----
+
+// CHECK-LABEL: @promote_in0_out_bufferize
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x42xf32>, %{{.*}}: tensor<42x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>)
+func.func @promote_in0_out_bufferize(%arg0: tensor<?x42xf32>, %arg1: tensor<42x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    // CHECK:  %[[IN1:.+]] = bufferization.to_buffer %arg1 : tensor<42x?xf32> to memref<42x?xf32, strided<[?, ?], offset: ?>>
+    // CHECK:  %[[IN0:.+]] = bufferization.to_buffer %arg0 : tensor<?x42xf32> to memref<?x42xf32, strided<[?, ?], offset: ?>>
+    // CHECK:  %{{.+}} = bufferization.to_buffer %arg0 : tensor<?x42xf32> to memref<?x42xf32, strided<[?, ?], offset: ?>>
+    // CHECK:  %{{.+}} = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+    // CHECK:  %{{.+}} = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+    // CHECK:  %[[C0:.+]] = arith.constant 0 : index
+    // CHECK:  %{{.+}} = memref.dim %{{.+}}, %[[C0]] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+    // CHECK:  %[[C1:.+]] = arith.constant 1 : index
+    // CHECK:  %{{.+}} = memref.dim %{{.+}}, %[[C1]] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+    // CHECK:  %[[ALLOC_OUT:.+]] = memref.alloc(%{{.+}}, %{{.+}}) {alignment = 64 : i64} : memref<?x?xf32, 1>
+    // CHECK:  %{{.+}} = arith.constant 0 : index
+    // CHECK:  %{{.+}} = memref.dim %{{.+}}, %{{.+}} : memref<?x42xf32, strided<[?, ?], offset: ?>>
+    // CHECK:  %[[ALLOC_IN:.+]] = memref.alloc(%{{.+}}) {alignment = 64 : i64} : memref<?x42xf32, 1>
+    // CHECK:  memref.copy %[[IN0]], %[[ALLOC_IN]] : memref<?x42xf32, strided<[?, ?], offset: ?>> to memref<?x42xf32, 1>
+    // CHECK: linalg.add ins(%[[ALLOC_IN]], %[[IN1]] : memref<?x42xf32, 1>, memref<42x?xf32, strided<[?, ?], offset: ?>>) outs(%[[ALLOC_OUT]] : memref<?x?xf32, 1>)
+    %0 = linalg.add ins(%arg0, %arg1: tensor<?x42xf32>, tensor<42x?xf32>)
+                    outs(%arg2: tensor<?x?xf32>) -> tensor<?x?xf32>
+    return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%root: !transform.any_op) {
+        %la = transform.structured.match ops{["linalg.add"]} in %root
+            : (!transform.any_op) -> !transform.any_op
+        %op0 = transform.get_operand %la[0]
+            : (!transform.any_op) -> !transform.any_value
+        transform.structured.promote_tensor to 1 %op0 : !transform.any_value
+
+        %init = transform.get_operand %la[2]
+                : (!transform.any_op) -> !transform.any_value
+        transform.structured.promote_tensor to 1 %init : !transform.any_value
+
+        %func = transform.structured.match ops{["func.func"]} in %root
+                : (!transform.any_op) -> !transform.any_op
+
+        %bufferized = transform.bufferization.one_shot_bufferize %func
+            : (!transform.any_op) -> !transform.any_op
+
+        transform.yield
+    }
+}
+
``````````
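
For reference, the materialization decision above hinges on whether the promoted value may be read: for linalg ops, a `bufferization.materialize_in_destination` is only inserted when the operand's matching block argument is used in the body. A small sketch of the two cases exercised by the tests (shapes illustrative):

```mlir
// The out operand of linalg.add is never read in the body, so promoting
// %init yields a bare allocation with no materialization copy.
%sum = linalg.add ins(%a, %b : tensor<?xf32>, tensor<?xf32>)
                  outs(%init : tensor<?xf32>) -> tensor<?xf32>

// The accumulator of linalg.matmul is read in the body, so promoting %acc
// inserts a copy of %acc into the new allocation.
%res = linalg.matmul ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%acc : tensor<?x?xf32>) -> tensor<?x?xf32>
```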



https://github.com/llvm/llvm-project/pull/158318


More information about the Mlir-commits mailing list