[Mlir-commits] [mlir] [mlir][transform] Add PromoteTensorOp (PR #158318)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 12 09:15:30 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Hendrik_Klug (Jimmy2027)
<details>
<summary>Changes</summary>
Adds a transform op that requests a tensor value to live in a specific memory space after bufferization.
---
Full diff: https://github.com/llvm/llvm-project/pull/158318.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+45-4)
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+90-26)
- (added) mlir/test/Dialect/Transform/test-promote-tensors.mlir (+104)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index a19cce4b919a8..b4c62baad11bf 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -17,6 +17,7 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/RegionKindInterface.td"
@@ -236,11 +237,51 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
Transform_AnyOpType:$new_ops);
let assemblyFormat = "$target attr-dict `:` type($target)";
let hasVerifier = 1;
+}
- let builders = [
- OpBuilder<(ins "Value":$target, "Attribute":$memorySpace)>,
- OpBuilder<(ins "Value":$target, "int64_t":$memorySpace)>
- ];
+//===----------------------------------------------------------------------===//
+// PromoteTensorOp
+//===----------------------------------------------------------------------===//
+
+def PromoteTensorOp : Op<Transform_Dialect, "structured.promote_tensor",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     SameOperandsAndResultType]> {
+  let summary = "Request a tensor value to live in a specific memory space "
+                "after bufferization";
+  let description = [{
+    Requests that a tensor value live in a specific memory space for its
+    lifetime. This is achieved by allocating a new tensor in the desired
+    memory space with `bufferization.alloc_tensor` and optionally materializing
+    the source value into that allocation with
+    `bufferization.materialize_in_destination`. All uses of the original value
+    are then redirected to the promoted value, except for the `tensor.dim`
+    operations computing the dynamic sizes of the allocation and the
+    materialization itself, which keep using the original value. When
+    `memory_space` is omitted, no memory space is set on the allocation.
+
+    The generated code for promoting tensor value `%0` resembles the following:
+
+    ```mlir
+    %1 = bufferization.alloc_tensor(<dynamic dims of %0>)
+        { memory_space = memory_space }
+    // Note: the materialization is omitted if %0 is provably never read and
+    // is only written into (i.e., it behaves as a result tensor). This is
+    // currently only detected when every user of %0 is a Linalg operation
+    // whose region does not use the block argument matching %0.
+    %2 = bufferization.materialize_in_destination %0 in %1
+    // ...
+    <all users of %0 now use %2 instead>
+    ```
+
+    Deallocation is not handled by this transform.
+
+    #### Return modes
+
+    - Produces a silenceable failure if the given handle does not point to
+      values of ranked tensor type.
+    - Succeeds otherwise and returns a handle to the promoted value(s), i.e.,
+      the result of materialization if present and the allocation otherwise.
+  }];
+
+  let arguments = (ins TransformValueHandleTypeInterface:$tensor,
+                       OptionalAttr<AnyAttr>:$memory_space);
+  let results = (outs TransformValueHandleTypeInterface:$promoted);
+
+  let assemblyFormat =
+    "(`to` $memory_space^)? $tensor attr-dict `:` type($tensor)";
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f0c1f4485b054..017886ef4fcd3 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -41,6 +41,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/LogicalResult.h"
@@ -272,32 +273,6 @@ void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
-void transform::BufferizeToAllocationOp::build(OpBuilder &b,
- OperationState &result,
- Value target,
- Attribute memorySpace) {
- SmallVector<Type> resultTypes;
- resultTypes.push_back(b.getType<transform::AnyValueType>());
- resultTypes.push_back(b.getType<transform::AnyOpType>());
- return build(b, result,
- /*resultTypes=*/resultTypes,
- /*target=*/target,
- /*memorySpace=*/memorySpace);
-}
-
-void transform::BufferizeToAllocationOp::build(OpBuilder &b,
- OperationState &result,
- Value target,
- int64_t memorySpace) {
- SmallVector<Type> resultTypes;
- resultTypes.push_back(b.getType<transform::AnyValueType>());
- resultTypes.push_back(b.getType<transform::AnyOpType>());
- return build(b, result,
- /*resultTypes=*/resultTypes,
- /*target=*/target,
- /*memorySpace=*/b.getI64IntegerAttr(memorySpace));
-}
-
namespace {
class NewOpsListener : public RewriterBase::ForwardingListener {
public:
@@ -407,6 +382,95 @@ LogicalResult transform::BufferizeToAllocationOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// PromoteTensorOp
+//===----------------------------------------------------------------------===//
+
+/// Returns true if `operand` may be read by the operation that owns it.
+/// The analysis is deliberately conservative to prevent unintentional data
+/// loss: only Linalg operations are inspected, where an operand counts as read
+/// iff the region block argument matching it has at least one use. Any other
+/// owner is assumed to read the operand.
+static bool mayBeRead(OpOperand &operand) {
+  // Linalg ops expose the block argument matching each operand, so we can
+  // look inside their region.
+  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.getOwner()))
+    return !linalgOp.getMatchingBlockArgument(&operand).use_empty();
+
+  // Ops we cannot analyze deeper: assume the worst.
+  return true;
+}
+
+/// Returns true if `value` may be read through any of its uses.
+static bool mayBeRead(Value value) {
+  // Values without plain value semantics (anything but tensors, floats and
+  // integers) may have reference semantics and be read through an alias.
+  bool hasValueSemantics =
+      isa<TensorType, FloatType, IntegerType>(value.getType());
+  if (!hasValueSemantics)
+    return true;
+  // Otherwise the value may be read iff one of its direct uses may read it.
+  for (OpOperand &use : value.getUses()) {
+    if (mayBeRead(use))
+      return true;
+  }
+  return false;
+}
+
+/// Promotes `tensor`, a value of ranked tensor type `type`, into a fresh
+/// `bufferization.alloc_tensor` placed in `memorySpace` (when non-null), and
+/// redirects all uses of `tensor` to the promoted value except for the ops
+/// created here, which must keep using the original value. New IR is inserted
+/// right after the defining op of `tensor`, or at the start of the owning
+/// block for block arguments. Returns the materialization result when `tensor`
+/// may be read, and the bare allocation otherwise.
+static Value promoteTensorValue(RewriterBase &rewriter, Value tensor,
+                                RankedTensorType type, Attribute memorySpace) {
+  if (Operation *definingOp = tensor.getDefiningOp())
+    rewriter.setInsertionPointAfter(definingOp);
+  else
+    rewriter.setInsertionPointToStart(cast<BlockArgument>(tensor).getOwner());
+
+  // Check this before we emit operations using this value: the `tensor.dim`
+  // ops created below are not Linalg ops and would be conservatively treated
+  // as readers.
+  bool needsMaterialization = mayBeRead(tensor);
+
+  Location loc = tensor.getLoc();
+  // Ops created here that must keep using the original value.
+  llvm::SmallPtrSet<Operation *, 4> preservedOps;
+  SmallVector<Value> dynamicDims;
+  for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
+    if (!ShapedType::isDynamic(dim))
+      continue;
+    Value cst = rewriter.create<arith::ConstantIndexOp>(loc, pos);
+    auto dimOp = rewriter.create<tensor::DimOp>(loc, tensor, cst);
+    preservedOps.insert(dimOp);
+    dynamicDims.push_back(dimOp);
+  }
+
+  auto allocation =
+      rewriter.create<bufferization::AllocTensorOp>(loc, type, dynamicDims);
+  // Set memory space if provided.
+  if (memorySpace)
+    allocation.setMemorySpaceAttr(memorySpace);
+  Value promoted = allocation;
+
+  // Only insert a materialization (typically bufferizes to a copy) when the
+  // value may be read from.
+  if (needsMaterialization) {
+    auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>(
+        loc, tensor, promoted);
+    preservedOps.insert(copy);
+    promoted = copy.getResult();
+  }
+  rewriter.replaceAllUsesExcept(tensor, promoted, preservedOps);
+  return promoted;
+}
+
+DiagnosedSilenceableFailure
+transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
+                                  transform::TransformResults &results,
+                                  transform::TransformState &state) {
+  // Validate every payload value before modifying the IR so that a
+  // silenceable failure does not leave the payload partially rewritten.
+  for (Value tensor : state.getPayloadValues(getTensor())) {
+    if (!isa<RankedTensorType>(tensor.getType()))
+      return emitSilenceableError()
+             << "expected a value of ranked tensor type, got: " << tensor;
+  }
+
+  // A value may appear several times in the handle (e.g., an operand shared
+  // by several matched ops). Promote it only once, since the first promotion
+  // already redirects all of its uses; promoting again would chain redundant
+  // allocations and copies. Every occurrence maps to the same promoted value
+  // so the result handle stays 1:1 with the payload.
+  llvm::DenseMap<Value, Value> promotedValues;
+  SmallVector<Value> promoted;
+  for (Value tensor : state.getPayloadValues(getTensor())) {
+    Value &entry = promotedValues[tensor];
+    if (!entry)
+      entry = promoteTensorValue(rewriter, tensor,
+                                 cast<RankedTensorType>(tensor.getType()),
+                                 getMemorySpaceAttr());
+    promoted.push_back(entry);
+  }
+  results.setValues(cast<OpResult>(getPromoted()), promoted);
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Declares the effects of this transform: the tensor handle is only read
+/// (not consumed), the result handle is produced, and the payload IR is
+/// modified.
+void transform::PromoteTensorOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTensorMutable(), effects);
+  transform::producesHandle((*this)->getOpResults(), effects);
+  transform::modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// DecomposeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-promote-tensors.mlir b/mlir/test/Dialect/Transform/test-promote-tensors.mlir
new file mode 100644
index 0000000000000..bc9a05af64156
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-promote-tensors.mlir
@@ -0,0 +1,104 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+
+// Promoting the first input of a matmul: the input is read by the matmul, so
+// a materialization into the new allocation is expected.
+// CHECK-LABEL: @promote_in0
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x42xf32>, %{{.*}}, %{{.*}})
+// CHECK: %[[C0:.+]] = arith.constant 0
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor(%[[DIM]]) {memory_space = 1 : i64}
+// CHECK: %[[MAT:.+]] = bufferization.materialize_in_destination %[[ARG0]] in %[[ALLOC]]
+// CHECK: linalg.matmul ins(%[[MAT]], %{{.*}}
+func.func @promote_in0(%lhs: tensor<?x42xf32>, %rhs: tensor<42x?xf32>, %acc: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %product = linalg.matmul ins(%lhs, %rhs : tensor<?x42xf32>, tensor<42x?xf32>)
+                           outs(%acc : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %product : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %root
+      : (!transform.any_op) -> !transform.any_op
+    %lhs = transform.get_operand %matmul[0]
+      : (!transform.any_op) -> !transform.any_value
+    transform.structured.promote_tensor to 1 %lhs : !transform.any_value
+    transform.yield
+  }
+}
+
+// -----
+
+// Promoting the init of a linalg.add: its region never uses the init block
+// argument, so only a new allocation is expected, without materialization.
+// CHECK-LABEL: @promote_out
+// CHECK-SAME: (%{{.*}}: tensor<?x42xf32>, %{{.*}}: tensor<?x42xf32>, %[[ARG2:.+]]: tensor<?x?xf32>)
+func.func @promote_out(%lhs: tensor<?x42xf32>, %rhs: tensor<?x42xf32>, %init: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK: %[[C0:.+]] = arith.constant 0
+  // CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG2]], %[[C0]]
+  // CHECK: %[[C1:.+]] = arith.constant 1
+  // CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG2]], %[[C1]]
+  // CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor(%[[DIM0]], %[[DIM1]]) {memory_space = 1 : i64}
+  // CHECK-NOT: materialize_in_destination
+  // CHECK: linalg.add {{.*}} outs(%[[ALLOC]]
+  %sum = linalg.add ins(%lhs, %rhs : tensor<?x42xf32>, tensor<?x42xf32>)
+                    outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %sum : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op) {
+    %add = transform.structured.match ops{["linalg.add"]} in %root
+      : (!transform.any_op) -> !transform.any_op
+    %out = transform.get_operand %add[2]
+      : (!transform.any_op) -> !transform.any_value
+    transform.structured.promote_tensor to 1 %out : !transform.any_value
+
+    transform.yield
+  }
+}
+
+// -----
+
+// Promote the first input and the init of a linalg.add, then bufferize. The
+// input is copied into a memory-space-1 allocation; the init is only
+// allocated in memory space 1.
+// CHECK-LABEL: @promote_in0_out_bufferize
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x42xf32>, %[[ARG1:.+]]: tensor<42x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>)
+func.func @promote_in0_out_bufferize(%arg0: tensor<?x42xf32>, %arg1: tensor<42x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK: %[[IN1:.+]] = bufferization.to_buffer %[[ARG1]] : tensor<42x?xf32> to memref<42x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK: %[[IN0:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<?x42xf32> to memref<?x42xf32, strided<[?, ?], offset: ?>>
+  // CHECK: %{{.+}} = bufferization.to_buffer %[[ARG0]] : tensor<?x42xf32> to memref<?x42xf32, strided<[?, ?], offset: ?>>
+  // CHECK: %{{.+}} = bufferization.to_buffer %[[ARG2]] : tensor<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK: %{{.+}} = bufferization.to_buffer %[[ARG2]] : tensor<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK: %{{.+}} = memref.dim %{{.+}}, %[[C0]] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK: %[[C1:.+]] = arith.constant 1 : index
+  // CHECK: %{{.+}} = memref.dim %{{.+}}, %[[C1]] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK: %[[ALLOC_OUT:.+]] = memref.alloc(%{{.+}}, %{{.+}}) {alignment = 64 : i64} : memref<?x?xf32, 1>
+  // CHECK: %{{.+}} = arith.constant 0 : index
+  // CHECK: %{{.+}} = memref.dim %{{.+}}, %{{.+}} : memref<?x42xf32, strided<[?, ?], offset: ?>>
+  // CHECK: %[[ALLOC_IN:.+]] = memref.alloc(%{{.+}}) {alignment = 64 : i64} : memref<?x42xf32, 1>
+  // CHECK: memref.copy %[[IN0]], %[[ALLOC_IN]] : memref<?x42xf32, strided<[?, ?], offset: ?>> to memref<?x42xf32, 1>
+  // CHECK: linalg.add ins(%[[ALLOC_IN]], %[[IN1]] : memref<?x42xf32, 1>, memref<42x?xf32, strided<[?, ?], offset: ?>>) outs(%[[ALLOC_OUT]] : memref<?x?xf32, 1>)
+  %0 = linalg.add ins(%arg0, %arg1: tensor<?x42xf32>, tensor<42x?xf32>)
+                  outs(%arg2: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op) {
+    %la = transform.structured.match ops{["linalg.add"]} in %root
+      : (!transform.any_op) -> !transform.any_op
+    %op0 = transform.get_operand %la[0]
+      : (!transform.any_op) -> !transform.any_value
+    transform.structured.promote_tensor to 1 %op0 : !transform.any_value
+
+    %init = transform.get_operand %la[2]
+      : (!transform.any_op) -> !transform.any_value
+    transform.structured.promote_tensor to 1 %init : !transform.any_value
+
+    %func = transform.structured.match ops{["func.func"]} in %root
+      : (!transform.any_op) -> !transform.any_op
+
+    %bufferized = transform.bufferization.one_shot_bufferize %func
+      : (!transform.any_op) -> !transform.any_op
+
+    transform.yield
+  }
+}
+
+
+
``````````
</details>
https://github.com/llvm/llvm-project/pull/158318
More information about the Mlir-commits mailing list