[Mlir-commits] [mlir] 2026501 - [MLIR] Make `OneShotModuleBufferize` use `OpInterface` (#110322)
llvmlistbot at llvm.org
Tue Oct 1 06:58:55 PDT 2024
Author: Tzung-Han Juang
Date: 2024-10-01T15:58:52+02:00
New Revision: 2026501cf107fcb3cbd51026ba25fda3af823941
URL: https://github.com/llvm/llvm-project/commit/2026501cf107fcb3cbd51026ba25fda3af823941
DIFF: https://github.com/llvm/llvm-project/commit/2026501cf107fcb3cbd51026ba25fda3af823941.diff
LOG: [MLIR] Make `OneShotModuleBufferize` use `OpInterface` (#110322)
**Description:**
This PR replaces some uses of `FuncOp` and `CallOp` with
`FunctionOpInterface` and `CallOpInterface` in `OneShotModuleBufferize`.
It also fixes the error from an integration test in a previous PR
attempt (https://github.com/llvm/llvm-project/pull/107295).
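With the interface in place, the analysis no longer hard-codes
`func.func`; any function-like op is handled uniformly. A minimal
sketch (the walk body is illustrative only):

```cpp
// Iterate over every function-like op in the module via the interface,
// regardless of which dialect defines it.
moduleOp.walk([&](FunctionOpInterface funcOp) {
  llvm::outs() << funcOp.getName() << ": " << funcOp.getNumArguments()
               << " args\n";
});
```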
The fixes below skip `CallOpInterface` ops whose callee cannot be
resolved, so that the assertions are not triggered:
https://github.com/llvm/llvm-project/blob/8d780007625108a7f34e40efb8604b858e04c60c/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp#L254-L259
https://github.com/llvm/llvm-project/blob/8d780007625108a7f34e40efb8604b858e04c60c/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp#L311-L315
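In essence, the skip pattern looks like this (a minimal sketch of the
change in `OneShotModuleBufferize.cpp`; the surrounding analysis code is
elided):

```cpp
funcOp->walk([&](CallOpInterface callOp) {
  FunctionOpInterface calledFunction = getCalledFunction(callOp);
  // Callees that cannot be resolved to a FunctionOpInterface (e.g.
  // indirect calls) are skipped instead of tripping an assertion.
  if (!calledFunction)
    return WalkResult::skip();
  // ... analyze the resolved callee ...
  return WalkResult::advance();
});
```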
**Related Discord Discussion:**
[Link](https://discord.com/channels/636084430946959380/642426447167881246/1280556809911799900)
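The PR also widens `BufferizationOptions::FunctionArgTypeConverterFn`
to take a `FunctionOpInterface`. A custom converter under the new
signature looks roughly like this (a sketch; the identity-layout choice
is only an example):

```cpp
BufferizationOptions options;
options.functionArgTypeConverterFn =
    [](TensorType tensorType, Attribute memorySpace,
       FunctionOpInterface funcOp, const BufferizationOptions &opts) {
      // Example policy: use a static identity layout at function
      // boundaries.
      return bufferization::getMemRefTypeWithStaticIdentityLayout(
          tensorType, memorySpace);
    };
```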
---------
Co-authored-by: erick-xanadu <110487834+erick-xanadu at users.noreply.github.com>
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
mlir/test/Dialect/LLVM/transform-e2e.mlir
mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir
mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir
mlir/test/Dialect/Vector/transform-vector.mlir
mlir/test/Examples/transform/ChH/full.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index aceb9d059b95f3..d19687ec9afee1 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -11,6 +11,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMapInfoVariant.h"
#include "llvm/ADT/SetVector.h"
@@ -260,9 +261,9 @@ struct BufferizationOptions {
using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
/// Tensor -> MemRef type converter.
/// Parameters: Value, memory space, func op, bufferization options
- using FunctionArgTypeConverterFn =
- std::function<BaseMemRefType(TensorType, Attribute memorySpace,
- func::FuncOp, const BufferizationOptions &)>;
+ using FunctionArgTypeConverterFn = std::function<BaseMemRefType(
+ TensorType, Attribute memorySpace, FunctionOpInterface,
+ const BufferizationOptions &)>;
/// Tensor -> MemRef type converter.
/// Parameters: Value, memory space, bufferization options
using UnknownTypeConverterFn = std::function<BaseMemRefType(
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
index 0b91d3d675b7c9..8bed0dfc5814b7 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
@@ -50,24 +50,24 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension {
/// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
/// indices.
- DenseMap<FuncOp, IndexMapping> equivalentFuncArgs;
+ DenseMap<FunctionOpInterface, IndexMapping> equivalentFuncArgs;
/// A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
- DenseMap<FuncOp, IndexToIndexListMapping> aliasingReturnVals;
+ DenseMap<FunctionOpInterface, IndexToIndexListMapping> aliasingReturnVals;
/// A set of all read BlockArguments of FuncOps.
- DenseMap<FuncOp, BbArgIndexSet> readBbArgs;
+ DenseMap<FunctionOpInterface, BbArgIndexSet> readBbArgs;
/// A set of all written-to BlockArguments of FuncOps.
- DenseMap<FuncOp, BbArgIndexSet> writtenBbArgs;
+ DenseMap<FunctionOpInterface, BbArgIndexSet> writtenBbArgs;
/// Keep track of which FuncOps are fully analyzed or currently being
/// analyzed.
- DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
+ DenseMap<FunctionOpInterface, FuncOpAnalysisState> analyzedFuncOps;
/// This function is called right before analyzing the given FuncOp. It
/// initializes the data structures for the FuncOp in this state object.
- void startFunctionAnalysis(FuncOp funcOp);
+ void startFunctionAnalysis(FunctionOpInterface funcOp);
};
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 85604eef2f2830..92f757111cbaf7 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
@@ -314,7 +315,7 @@ namespace {
/// Default function arg type converter: Use a fully dynamic layout map.
BaseMemRefType
defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
- func::FuncOp funcOp,
+ FunctionOpInterface funcOp,
const BufferizationOptions &options) {
return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
}
@@ -361,7 +362,7 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
void BufferizationOptions::setFunctionBoundaryTypeConversion(
LayoutMapOption layoutMapOption) {
functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
- func::FuncOp funcOp,
+ FunctionOpInterface funcOp,
const BufferizationOptions &options) {
if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 9fbe574ec392dc..9749a71f3514bc 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -22,7 +22,7 @@ namespace mlir {
namespace bufferization {
namespace func_ext {
-void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
+void FuncAnalysisState::startFunctionAnalysis(FunctionOpInterface funcOp) {
analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
auto createdAliasingResults =
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 0a4072605c265f..a0e5c7fff7690f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -75,7 +75,7 @@ using namespace mlir::bufferization;
using namespace mlir::bufferization::func_ext;
/// A mapping of FuncOps to their callers.
-using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;
+using FuncCallerMap = DenseMap<FunctionOpInterface, DenseSet<Operation *>>;
/// Get or create FuncAnalysisState.
static FuncAnalysisState &
@@ -88,10 +88,11 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
- func::ReturnOp returnOp;
- for (Block &b : funcOp.getBody()) {
- if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
+static Operation *getAssumedUniqueReturnOp(FunctionOpInterface funcOp) {
+ Operation *returnOp = nullptr;
+ for (Block &b : funcOp.getFunctionBody()) {
+ auto candidateOp = b.getTerminator();
+ if (candidateOp && candidateOp->hasTrait<OpTrait::ReturnLike>()) {
if (returnOp)
return nullptr;
returnOp = candidateOp;
@@ -126,16 +127,16 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
/// Store function BlockArguments that are equivalent to/aliasing a returned
/// value in FuncAnalysisState.
static LogicalResult
-aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
+aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp,
+ OneShotAnalysisState &state,
FuncAnalysisState &funcState) {
- if (funcOp.getBody().empty()) {
+ if (funcOp.getFunctionBody().empty()) {
// No function body available. Conservatively assume that every tensor
// return value may alias with any tensor bbArg.
- FunctionType type = funcOp.getFunctionType();
- for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
+ for (const auto &inputIt : llvm::enumerate(funcOp.getArgumentTypes())) {
if (!isa<TensorType>(inputIt.value()))
continue;
- for (const auto &resultIt : llvm::enumerate(type.getResults())) {
+ for (const auto &resultIt : llvm::enumerate(funcOp.getResultTypes())) {
if (!isa<TensorType>(resultIt.value()))
continue;
int64_t returnIdx = resultIt.index();
@@ -147,7 +148,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
}
// Support only single return-terminated block in the function.
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+ Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");
for (OpOperand &returnVal : returnOp->getOpOperands())
@@ -168,8 +169,8 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
return success();
}
-static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
- bool isWritten) {
+static void annotateFuncArgAccess(FunctionOpInterface funcOp, int64_t idx,
+ bool isRead, bool isWritten) {
OpBuilder b(funcOp.getContext());
Attribute accessType;
if (isRead && isWritten) {
@@ -189,12 +190,12 @@ static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
/// function with unknown ops, we conservatively assume that such ops bufferize
/// to a read + write.
static LogicalResult
-funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
+funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp,
+ OneShotAnalysisState &state,
FuncAnalysisState &funcState) {
- for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
- ++idx) {
+ for (int64_t idx = 0, e = funcOp.getNumArguments(); idx < e; ++idx) {
// Skip non-tensor arguments.
- if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
+ if (!isa<TensorType>(funcOp.getArgumentTypes()[idx]))
continue;
bool isRead;
bool isWritten;
@@ -204,7 +205,7 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
StringRef str = accessAttr.getValue();
isRead = str == "read" || str == "read-write";
isWritten = str == "write" || str == "read-write";
- } else if (funcOp.getBody().empty()) {
+ } else if (funcOp.getFunctionBody().empty()) {
// If the function has no body, conservatively assume that all args are
// read + written.
isRead = true;
@@ -230,20 +231,19 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
/// Remove bufferization attributes on FuncOp arguments.
static void removeBufferizationAttributes(BlockArgument bbArg) {
- auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
+ auto funcOp = cast<FunctionOpInterface>(bbArg.getOwner()->getParentOp());
funcOp.removeArgAttr(bbArg.getArgNumber(),
BufferizationDialect::kBufferLayoutAttrName);
funcOp.removeArgAttr(bbArg.getArgNumber(),
BufferizationDialect::kWritableAttrName);
}
-/// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(func::CallOp callOp) {
+static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
SymbolRefAttr sym =
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
- return dyn_cast_or_null<func::FuncOp>(
+ return dyn_cast_or_null<FunctionOpInterface>(
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}
@@ -251,12 +251,13 @@ static func::FuncOp getCalledFunction(func::CallOp callOp) {
/// Note: This only adds new equivalence info if the called function was already
/// analyzed.
// TODO: This does not handle cyclic function call graphs etc.
-static void equivalenceAnalysis(func::FuncOp funcOp,
+static void equivalenceAnalysis(FunctionOpInterface funcOp,
OneShotAnalysisState &state,
FuncAnalysisState &funcState) {
- funcOp->walk([&](func::CallOp callOp) {
- func::FuncOp calledFunction = getCalledFunction(callOp);
- assert(calledFunction && "could not retrieved called func::FuncOp");
+ funcOp->walk([&](CallOpInterface callOp) {
+ FunctionOpInterface calledFunction = getCalledFunction(callOp);
+ if (!calledFunction)
+ return WalkResult::skip();
// No equivalence info available for the called function.
if (!funcState.equivalentFuncArgs.count(calledFunction))
@@ -267,7 +268,7 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
int64_t bbargIdx = it.second;
if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
continue;
- Value returnVal = callOp.getResult(returnIdx);
+ Value returnVal = callOp->getResult(returnIdx);
Value argVal = callOp->getOperand(bbargIdx);
state.unionEquivalenceClasses(returnVal, argVal);
}
@@ -277,11 +278,9 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
}
/// Return "true" if the given function signature has tensor semantics.
-static bool hasTensorSignature(func::FuncOp funcOp) {
- return llvm::any_of(funcOp.getFunctionType().getInputs(),
- llvm::IsaPred<TensorType>) ||
- llvm::any_of(funcOp.getFunctionType().getResults(),
- llvm::IsaPred<TensorType>);
+static bool hasTensorSignature(FunctionOpInterface funcOp) {
+ return llvm::any_of(funcOp.getArgumentTypes(), llvm::IsaPred<TensorType>) ||
+ llvm::any_of(funcOp.getResultTypes(), llvm::IsaPred<TensorType>);
}
/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
@@ -291,16 +290,16 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
/// retrieve the called FuncOp from any func::CallOp.
static LogicalResult
getFuncOpsOrderedByCalls(ModuleOp moduleOp,
- SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+ SmallVectorImpl<FunctionOpInterface> &orderedFuncOps,
FuncCallerMap &callerMap) {
// For each FuncOp, the set of functions called by it (i.e. the union of
// symbols of all nested func::CallOp).
- DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
+ DenseMap<FunctionOpInterface, DenseSet<FunctionOpInterface>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
- DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
- WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
- if (!funcOp.getBody().empty()) {
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+ DenseMap<FunctionOpInterface, unsigned> numberCallOpsContainedInFuncOp;
+ WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult {
+ if (!funcOp.getFunctionBody().empty()) {
+ Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
if (!returnOp)
return funcOp->emitError()
<< "cannot bufferize a FuncOp with tensors and "
@@ -309,9 +308,10 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
- return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
- func::FuncOp calledFunction = getCalledFunction(callOp);
- assert(calledFunction && "could not retrieved called func::FuncOp");
+ return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
+ FunctionOpInterface calledFunction = getCalledFunction(callOp);
+ if (!calledFunction)
+ return WalkResult::skip();
// If the called function does not have any tensors in its signature, then
// it is not necessary to bufferize the callee before the caller.
if (!hasTensorSignature(calledFunction))
@@ -349,11 +349,11 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
/// most generic layout map as function return types. After bufferizing the
/// entire function body, a more concise memref type can potentially be used for
/// the return type of the function.
-static void foldMemRefCasts(func::FuncOp funcOp) {
- if (funcOp.getBody().empty())
+static void foldMemRefCasts(FunctionOpInterface funcOp) {
+ if (funcOp.getFunctionBody().empty())
return;
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+ Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
SmallVector<Type> resultTypes;
for (OpOperand &operand : returnOp->getOpOperands()) {
@@ -365,8 +365,8 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
}
}
- auto newFuncType = FunctionType::get(
- funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
+ auto newFuncType = FunctionType::get(funcOp.getContext(),
+ funcOp.getArgumentTypes(), resultTypes);
funcOp.setType(newFuncType);
}
@@ -379,7 +379,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
// A list of functions in the order in which they are analyzed + bufferized.
- SmallVector<func::FuncOp> orderedFuncOps;
+ SmallVector<FunctionOpInterface> orderedFuncOps;
// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
@@ -388,7 +388,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
return failure();
// Analyze ops.
- for (func::FuncOp funcOp : orderedFuncOps) {
+ for (FunctionOpInterface funcOp : orderedFuncOps) {
if (!state.getOptions().isOpAllowed(funcOp))
continue;
@@ -416,7 +416,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
void mlir::bufferization::removeBufferizationAttributesInModule(
ModuleOp moduleOp) {
- moduleOp.walk([&](func::FuncOp op) {
+ moduleOp.walk([&](FunctionOpInterface op) {
for (BlockArgument bbArg : op.getArguments())
removeBufferizationAttributes(bbArg);
});
@@ -430,7 +430,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
IRRewriter rewriter(moduleOp.getContext());
// A list of functions in the order in which they are analyzed + bufferized.
- SmallVector<func::FuncOp> orderedFuncOps;
+ SmallVector<FunctionOpInterface> orderedFuncOps;
// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
@@ -439,11 +439,11 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
return failure();
// Bufferize functions.
- for (func::FuncOp funcOp : orderedFuncOps) {
+ for (FunctionOpInterface funcOp : orderedFuncOps) {
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
// would be invalidated.
- if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
+ if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getName())) {
// This function was not analyzed and RaW conflicts were not resolved.
// Buffer copies must be inserted before every write.
OneShotBufferizationOptions updatedOptions = options;
@@ -463,7 +463,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
// Bufferize all other ops.
for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
// Functions were already bufferized.
- if (isa<func::FuncOp>(&op))
+ if (isa<FunctionOpInterface>(&op))
continue;
if (failed(bufferizeOp(&op, options, statistics)))
return failure();
@@ -490,12 +490,12 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
// FuncOps whose names are specified in options.noAnalysisFuncFilter will
// not be analyzed. Ops in these FuncOps will not be analyzed as well.
OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
- auto func = dyn_cast<func::FuncOp>(op);
+ auto func = dyn_cast<FunctionOpInterface>(op);
if (!func)
- func = op->getParentOfType<func::FuncOp>();
+ func = op->getParentOfType<FunctionOpInterface>();
if (func)
return llvm::is_contained(options.noAnalysisFuncFilter,
- func.getSymName());
+ func.getName());
return false;
};
OneShotBufferizationOptions updatedOptions(options);
diff --git a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
index 3c50a9e72d9d9b..588aa8a85a84e6 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --transform-interpreter %s -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt --transform-interpreter="debug-payload-root-tag=payload" %s -split-input-file -verify-diagnostics | FileCheck %s
// Test One-Shot Bufferize.
@@ -12,19 +12,21 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: func @test_function(
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
-func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
- %c0 = arith.constant 0 : index
+module @payload attributes { transform.target_tag = "payload" } {
+ func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+ %c0 = arith.constant 0 : index
- // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
- // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
- // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
- // CHECK: memref.copy %[[A_memref]], %[[alloc]]
- // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
- // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
- %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+ // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+ // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
+ // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
+ // CHECK: memref.copy %[[A_memref]], %[[alloc]]
+ // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+ // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
+ %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
- // CHECK: return %[[res_tensor]]
- return %0 : tensor<?xf32>
+ // CHECK: return %[[res_tensor]]
+ return %0 : tensor<?xf32>
+ }
}
// -----
@@ -42,19 +44,21 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: func @test_function(
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
// CHECK-NOT: memref.copy
-func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
- %c0 = arith.constant 0 : index
+module @payload attributes { transform.target_tag = "payload" } {
+ func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+ %c0 = arith.constant 0 : index
- // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
- // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
- // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
- // CHECK: linalg.copy ins(%[[A_memref]] : memref<{{.*}}>) outs(%[[alloc]]
- // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
- // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
- %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+ // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+ // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
+ // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
+ // CHECK: linalg.copy ins(%[[A_memref]] : memref<{{.*}}>) outs(%[[alloc]]
+ // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+ // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
+ %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
- // CHECK: return %[[res_tensor]]
- return %0 : tensor<?xf32>
+ // CHECK: return %[[res_tensor]]
+ return %0 : tensor<?xf32>
+ }
}
// -----
@@ -72,13 +76,15 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: func @test_function_analysis(
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
-func.func @test_function_analysis(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
- %c0 = arith.constant 0 : index
- // CHECK: vector.transfer_write
- // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false", "none"]}
- // CHECK-SAME: tensor<?xf32>
- %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
- return %0 : tensor<?xf32>
+module @payload attributes { transform.target_tag = "payload" } {
+ func.func @test_function_analysis(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+ %c0 = arith.constant 0 : index
+ // CHECK: vector.transfer_write
+ // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false", "none"]}
+ // CHECK-SAME: tensor<?xf32>
+ %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+ return %0 : tensor<?xf32>
+ }
}
// -----
@@ -95,10 +101,12 @@ module attributes {transform.with_named_sequence} {
}
}
-func.func @test_unknown_op_failure() -> (tensor<?xf32>) {
- // expected-error @+1 {{op was not bufferized}}
- %0 = "test.dummy_op"() : () -> (tensor<?xf32>)
- return %0 : tensor<?xf32>
+module @payload attributes { transform.target_tag = "payload" } {
+ func.func @test_unknown_op_failure() -> (tensor<?xf32>) {
+ // expected-error @+1 {{op was not bufferized}}
+ %0 = "test.dummy_op"() : () -> (tensor<?xf32>)
+ return %0 : tensor<?xf32>
+ }
}
// -----
@@ -111,7 +119,7 @@ module attributes {transform.with_named_sequence} {
}
}
-module {
+module @payload attributes { transform.target_tag = "payload" } {
// CHECK-LABEL: func @test_function(
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
@@ -146,11 +154,13 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[A:.*]]: memref<12x9xf32>,
// CHECK-SAME: %[[B:.*]]: memref<9x6xf32>,
// CHECK-SAME: %[[C:.*]]: memref<12x6xf32>) -> memref<12x6xf32> {
-func.func @matmul(%A: tensor<12x9xf32>, %B: tensor<9x6xf32>, %C: tensor<12x6xf32>) -> tensor<12x6xf32> {
- // CHECK: linalg.matmul ins(%[[A]], %[[B]] : memref<12x9xf32>, memref<9x6xf32>) outs(%[[C]] : memref<12x6xf32>)
- %D = linalg.matmul ins(%A, %B: tensor<12x9xf32>, tensor<9x6xf32>) outs(%C: tensor<12x6xf32>) -> tensor<12x6xf32>
- // CHECK: return %[[C]] : memref<12x6xf32>
- return %D : tensor<12x6xf32>
+module @payload attributes { transform.target_tag = "payload" } {
+ func.func @matmul(%A: tensor<12x9xf32>, %B: tensor<9x6xf32>, %C: tensor<12x6xf32>) -> tensor<12x6xf32> {
+ // CHECK: linalg.matmul ins(%[[A]], %[[B]] : memref<12x9xf32>, memref<9x6xf32>) outs(%[[C]] : memref<12x6xf32>)
+ %D = linalg.matmul ins(%A, %B: tensor<12x9xf32>, tensor<9x6xf32>) outs(%C: tensor<12x6xf32>) -> tensor<12x6xf32>
+ // CHECK: return %[[C]] : memref<12x6xf32>
+ return %D : tensor<12x6xf32>
+ }
}
// -----
@@ -165,10 +175,12 @@ module attributes {transform.with_named_sequence} {
}
// Expect `bufferization.empty_tensor_to_alloc_tensor` to replace the tensor.empty.
-func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> {
- // CHECK: bufferization.alloc_tensor
- %0 = tensor.empty() : tensor<2x2xf32>
- return %0 : tensor<2x2xf32>
+module @payload attributes { transform.target_tag = "payload" } {
+ func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> {
+ // CHECK: bufferization.alloc_tensor
+ %0 = tensor.empty() : tensor<2x2xf32>
+ return %0 : tensor<2x2xf32>
+ }
}
// -----
@@ -185,13 +197,15 @@ module attributes {transform.with_named_sequence} {
// CHECK: tensor.extract_slice
// CHECK: linalg.fill
// CHECK: tensor.insert_slice
-func.func @empty_tensor_elimination(
- %t: tensor<10xf32>, %f: f32) -> tensor<10xf32> {
- %0 = tensor.empty() : tensor<5xf32>
- %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
- %2 = tensor.insert_slice %1 into %t [1][5][1]
- : tensor<5xf32> into tensor<10xf32>
- return %2 : tensor<10xf32>
+module @payload attributes { transform.target_tag = "payload" } {
+ func.func @empty_tensor_elimination(
+ %t: tensor<10xf32>, %f: f32) -> tensor<10xf32> {
+ %0 = tensor.empty() : tensor<5xf32>
+ %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+ %2 = tensor.insert_slice %1 into %t [1][5][1]
+ : tensor<5xf32> into tensor<10xf32>
+ return %2 : tensor<10xf32>
+ }
}
// -----
@@ -208,12 +222,14 @@ module attributes {transform.with_named_sequence} {
// CHECK: memref.alloca
// CHECK: scf.for
// CHECK: memref.store
-func.func @buffer_loop_hoisting(%lb: index, %ub: index, %step: index, %f: f32, %pos: index) {
- scf.for %iv = %lb to %ub step %step {
- %0 = memref.alloca() : memref<5xf32>
- memref.store %f, %0[%pos] : memref<5xf32>
+module @payload attributes { transform.target_tag = "payload" } {
+ func.func @buffer_loop_hoisting(%lb: index, %ub: index, %step: index, %f: f32, %pos: index) {
+ scf.for %iv = %lb to %ub step %step {
+ %0 = memref.alloca() : memref<5xf32>
+ memref.store %f, %0[%pos] : memref<5xf32>
+ }
+ return
}
- return
}
// -----
@@ -231,10 +247,12 @@ module attributes {transform.with_named_sequence} {
// Expect `bufferization.bufferize_to_allocation` to create an alloc.
// CHECK-LABEL: func.func @empty_to_tensor_alloc()
-func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> {
- // CHECK-NEXT: %[[alloca:.*]] = memref.alloca() : memref<2x2xf32>
- // CHECK-NEXT: %[[tensor:.*]] = bufferization.to_tensor %[[alloca]] restrict writable : memref<2x2xf32>
- // CHECK-NEXT: return %[[tensor]] : tensor<2x2xf32>
- %0 = bufferization.alloc_tensor() : tensor<2x2xf32>
- return %0 : tensor<2x2xf32>
+module @payload attributes { transform.target_tag = "payload" } {
+ func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> {
+ // CHECK-NEXT: %[[alloca:.*]] = memref.alloca() : memref<2x2xf32>
+ // CHECK-NEXT: %[[tensor:.*]] = bufferization.to_tensor %[[alloca]] restrict writable : memref<2x2xf32>
+ // CHECK-NEXT: return %[[tensor]] : tensor<2x2xf32>
+ %0 = bufferization.alloc_tensor() : tensor<2x2xf32>
+ return %0 : tensor<2x2xf32>
+ }
}
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index c00b47fb936e97..3e637a3ec49a42 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -1,15 +1,17 @@
-// RUN: mlir-opt %s --transform-interpreter -test-transform-dialect-erase-schedule --test-lower-to-llvm --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter="debug-payload-root-tag=payload" -test-transform-dialect-erase-schedule --test-lower-to-llvm --split-input-file | FileCheck %s
// CHECK-LABEL: llvm.func @matmul_tensors
-func.func @matmul_tensors(
- %arg0: tensor<2x4xf32>, %arg1: tensor<4x6xf32>, %arg2: tensor<2x6xf32>)
- -> tensor<2x6xf32> {
-// CHECK-NOT: linalg
-// CHECK: llvm.intr.fmuladd{{.*}}
- %0 = linalg.matmul ins(%arg0, %arg1: tensor<2x4xf32>, tensor<4x6xf32>)
- outs(%arg2: tensor<2x6xf32>)
- -> tensor<2x6xf32>
- return %0 : tensor<2x6xf32>
+module @payload attributes { transform.target_tag = "payload" } {
+ func.func @matmul_tensors(
+ %arg0: tensor<2x4xf32>, %arg1: tensor<4x6xf32>, %arg2: tensor<2x6xf32>)
+ -> tensor<2x6xf32> {
+ // CHECK-NOT: linalg
+ // CHECK: llvm.intr.fmuladd{{.*}}
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<2x4xf32>, tensor<4x6xf32>)
+ outs(%arg2: tensor<2x6xf32>)
+ -> tensor<2x6xf32>
+ return %0 : tensor<2x6xf32>
+ }
}
module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir b/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir
index 3f8d2ea06641e1..9c223737750a9b 100644
--- a/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir
+++ b/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --split-input-file --transform-interpreter %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --transform-interpreter="debug-payload-root-tag=payload" %s | FileCheck %s
// CHECK-LABEL: func @matmul_divisible
// CHECK: scf.forall
@@ -24,19 +24,21 @@
// CHECK: scf.forall
// CHECK: vector.transfer_read
// CHECK: vector.transfer_write
-func.func @matmul_divisible(%A: tensor<1024x1024xf32>,
- %B: tensor<1024x1024xf32>,
- %C: tensor<1024x1024xf32>)
- -> tensor<1024x1024xf32>
-{
- %cst = arith.constant 0.000000e+00 : f32
- %0 = linalg.fill ins(%cst : f32)
- outs(%C : tensor<1024x1024xf32>)
+module @payload attributes { transform.target_tag = "payload" } {
+ func.func @matmul_divisible(%A: tensor<1024x1024xf32>,
+ %B: tensor<1024x1024xf32>,
+ %C: tensor<1024x1024xf32>)
-> tensor<1024x1024xf32>
- %1 = linalg.matmul ins(%A, %B : tensor<1024x1024xf32>, tensor<1024x1024xf32>)
- outs(%0 : tensor<1024x1024xf32>)
- -> tensor<1024x1024xf32>
- return %1 : tensor<1024x1024xf32>
+ {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = linalg.fill ins(%cst : f32)
+ outs(%C : tensor<1024x1024xf32>)
+ -> tensor<1024x1024xf32>
+ %1 = linalg.matmul ins(%A, %B : tensor<1024x1024xf32>, tensor<1024x1024xf32>)
+ outs(%0 : tensor<1024x1024xf32>)
+ -> tensor<1024x1024xf32>
+ return %1 : tensor<1024x1024xf32>
+ }
}
module attributes {transform.with_named_sequence} {
@@ -143,19 +145,21 @@ module attributes {transform.with_named_sequence} {
// CHECK: linalg.matmul
// CHECK: vector.transfer_read
// CHECK: vector.transfer_write
+module @payload attributes { transform.target_tag = "payload" } {
func.func @matmul_not_divisible(%A: tensor<1023x1023xf32>,
- %B: tensor<1023x1023xf32>,
- %C: tensor<1023x1023xf32>)
- -> tensor<1023x1023xf32>
-{
- %cst = arith.constant 0.000000e+00 : f32
- %0 = linalg.fill ins(%cst : f32)
- outs(%C : tensor<1023x1023xf32>)
+ %B: tensor<1023x1023xf32>,
+ %C: tensor<1023x1023xf32>)
-> tensor<1023x1023xf32>
- %1 = linalg.matmul ins(%A, %B : tensor<1023x1023xf32>, tensor<1023x1023xf32>)
- outs(%0 : tensor<1023x1023xf32>)
- -> tensor<1023x1023xf32>
- return %1 : tensor<1023x1023xf32>
+ {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = linalg.fill ins(%cst : f32)
+ outs(%C : tensor<1023x1023xf32>)
+ -> tensor<1023x1023xf32>
+ %1 = linalg.matmul ins(%A, %B : tensor<1023x1023xf32>, tensor<1023x1023xf32>)
+ outs(%0 : tensor<1023x1023xf32>)
+ -> tensor<1023x1023xf32>
+ return %1 : tensor<1023x1023xf32>
+ }
}
module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir b/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir
index f2e9e839b7c46b..5e5657980ba120 100644
--- a/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir
+++ b/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt --transform-interpreter -cse -canonicalize -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --transform-interpreter="debug-payload-root-tag=payload" -cse -canonicalize -split-input-file -verify-diagnostics %s | FileCheck %s
#map = affine_map<()[s0] -> (-s0 + 12, 7)>
@@ -7,43 +7,45 @@
// CHECK-SAME: %[[arg0:.*]]: memref<24x12xf32, strided<[?, ?], offset: ?>>,
// CHECK-SAME: %[[arg1:.*]]: memref<12x25xf32, strided<[?, ?], offset: ?>>,
// CHECK-SAME: %[[arg2:.*]]: memref<24x25xf32, strided<[?, ?], offset: ?>>,
-func.func @pad_to_memory_space(%arg0: tensor<24x12xf32>,
- %arg1: tensor<12x25xf32>,
- %arg2: tensor<24x25xf32>,
- %iv0 : index, %iv1 : index,
- %iv2 : index) -> tensor<24x25xf32> {
- %0 = affine.min #map()[%iv2]
-
- // CHECK: %[[s0:.*]] = memref.subview %[[arg0]]
- %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32>
- // CHECK: %[[s1:.*]] = memref.subview %[[arg1]]
- %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor<?x5xf32>
- // CHECK: %[[s2:.*]] = memref.subview %[[arg2]]
- %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32>
-
- // CHECK: %[[alloc0:.*]] = memref.alloc() : memref<4x7xf32, 3>
- // CHECK: linalg.fill {{.*}} outs(%[[alloc0]]
- // CHECK: %[[alloc0_view:.*]] = memref.subview %[[alloc0]][0, 0] [4, %{{.*}}] [1, 1]
- // CHECK: memref.copy %[[s0]], %[[alloc0_view]]
-
- // CHECK: %[[alloc1:.*]] = memref.alloc() : memref<7x5xf32, 3>
- // CHECK: linalg.fill {{.*}} outs(%[[alloc1]]
- // CHECK: %[[alloc1_view:.*]] = memref.subview %[[alloc1]][0, 0] [%{{.*}}, 5] [1, 1]
- // CHECK: memref.copy %[[s1]], %[[alloc1_view]]
-
- // CHECK: %[[alloc2:.*]] = memref.alloc() : memref<4x5xf32, 3>
- // CHECK-NOT: linalg.fill {{.*}} outs(%[[alloc2]]
- // No subview because there is 0 padding
- // CHECK: memref.copy %[[s2]], %[[alloc2]]
-
- // CHECK: linalg.matmul ins(%[[alloc0]], %[[alloc1]] : {{.*}}) outs(%[[alloc2]] : {{.*}})
- // Copy back result.
- // CHECK: memref.copy %[[alloc2]], %[[s2]]
- %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
-
- // insert_slice bufferizes to a no-op.
- %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
- func.return %5 : tensor<24x25xf32>
+module @payload attributes { transform.target_tag = "payload" } {
+ func.func @pad_to_memory_space(%arg0: tensor<24x12xf32>,
+ %arg1: tensor<12x25xf32>,
+ %arg2: tensor<24x25xf32>,
+ %iv0 : index, %iv1 : index,
+ %iv2 : index) -> tensor<24x25xf32> {
+ %0 = affine.min #map()[%iv2]
+
+ // CHECK: %[[s0:.*]] = memref.subview %[[arg0]]
+ %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32>
+ // CHECK: %[[s1:.*]] = memref.subview %[[arg1]]
+ %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor<?x5xf32>
+ // CHECK: %[[s2:.*]] = memref.subview %[[arg2]]
+ %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32>
+
+ // CHECK: %[[alloc0:.*]] = memref.alloc() : memref<4x7xf32, 3>
+ // CHECK: linalg.fill {{.*}} outs(%[[alloc0]]
+ // CHECK: %[[alloc0_view:.*]] = memref.subview %[[alloc0]][0, 0] [4, %{{.*}}] [1, 1]
+ // CHECK: memref.copy %[[s0]], %[[alloc0_view]]
+
+ // CHECK: %[[alloc1:.*]] = memref.alloc() : memref<7x5xf32, 3>
+ // CHECK: linalg.fill {{.*}} outs(%[[alloc1]]
+ // CHECK: %[[alloc1_view:.*]] = memref.subview %[[alloc1]][0, 0] [%{{.*}}, 5] [1, 1]
+ // CHECK: memref.copy %[[s1]], %[[alloc1_view]]
+
+ // CHECK: %[[alloc2:.*]] = memref.alloc() : memref<4x5xf32, 3>
+ // CHECK-NOT: linalg.fill {{.*}} outs(%[[alloc2]]
+ // No subview because there is 0 padding
+ // CHECK: memref.copy %[[s2]], %[[alloc2]]
+
+ // CHECK: linalg.matmul ins(%[[alloc0]], %[[alloc1]] : {{.*}}) outs(%[[alloc2]] : {{.*}})
+ // Copy back result.
+ // CHECK: memref.copy %[[alloc2]], %[[s2]]
+ %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
+
+ // insert_slice bufferizes to a no-op.
+ %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
+ func.return %5 : tensor<24x25xf32>
+ }
}
module attributes {transform.with_named_sequence} {
@@ -69,40 +71,42 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[arg0:.*]]: memref<24x12xf32, strided<[?, ?], offset: ?>>,
// CHECK-SAME: %[[arg1:.*]]: memref<12x25xf32, strided<[?, ?], offset: ?>>,
// CHECK-SAME: %[[arg2:.*]]: memref<24x25xf32, strided<[?, ?], offset: ?>>,
-func.func @vectorize_and_bufferize_pad(%arg0: tensor<24x12xf32>,
- %arg1: tensor<12x25xf32>,
- %arg2: tensor<24x25xf32>,
- %iv0 : index, %iv1 : index,
- %iv2 : index) -> tensor<24x25xf32> {
- %0 = affine.min #map()[%iv2]
-
- // CHECK: %[[s0:.*]] = memref.subview %[[arg0]]
- %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32>
- // CHECK: %[[s1:.*]] = memref.subview %[[arg1]]
- %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor<?x5xf32>
- // CHECK: %[[s2:.*]] = memref.subview %[[arg2]]
- %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32>
-
- // CHECK: %[[v0:.*]] = vector.mask {{.*}} { vector.transfer_read %[[s0]]
- // CHECK: %[[alloc0:.*]] = memref.alloc() : memref<4x7xf32, 3>
- // CHECK: vector.mask {{.*}} { vector.transfer_write %[[v0]], %[[alloc0]]
-
- // CHECK: %[[v1:.*]] = vector.mask {{.*}} { vector.transfer_read %[[s1]]
- // CHECK: %[[alloc1:.*]] = memref.alloc() : memref<7x5xf32, 3>
- // CHECK: vector.mask {{.*}} { vector.transfer_write %[[v1]], %[[alloc1]]
-
- // CHECK: %[[v2:.*]] = vector.mask {{.*}} { vector.transfer_read %[[s2]]
- // CHECK: %[[alloc2:.*]] = memref.alloc() : memref<4x5xf32, 3>
- // CHECK: vector.mask {{.*}} { vector.transfer_write %[[v2]], %[[alloc0]]
-
- // CHECK: linalg.matmul ins(%[[alloc0]], %[[alloc1]] : {{.*}}) outs(%[[alloc2]] : {{.*}})
- // Copy back result.
- // CHECK: memref.copy %[[alloc2]], %[[s2]]
- %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
-
- // insert_slice bufferizes to a no-op.
- %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
- func.return %5 : tensor<24x25xf32>
+module @payload attributes { transform.target_tag = "payload" } {
+ func.func @vectorize_and_bufferize_pad(%arg0: tensor<24x12xf32>,
+ %arg1: tensor<12x25xf32>,
+ %arg2: tensor<24x25xf32>,
+ %iv0 : index, %iv1 : index,
+ %iv2 : index) -> tensor<24x25xf32> {
+ %0 = affine.min #map()[%iv2]
+
+ // CHECK: %[[s0:.*]] = memref.subview %[[arg0]]
+ %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32>
+ // CHECK: %[[s1:.*]] = memref.subview %[[arg1]]
+ %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor<?x5xf32>
+ // CHECK: %[[s2:.*]] = memref.subview %[[arg2]]
+ %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32>
+
+ // CHECK: %[[v0:.*]] = vector.mask {{.*}} { vector.transfer_read %[[s0]]
+ // CHECK: %[[alloc0:.*]] = memref.alloc() : memref<4x7xf32, 3>
+ // CHECK: vector.mask {{.*}} { vector.transfer_write %[[v0]], %[[alloc0]]
+
+ // CHECK: %[[v1:.*]] = vector.mask {{.*}} { vector.transfer_read %[[s1]]
+ // CHECK: %[[alloc1:.*]] = memref.alloc() : memref<7x5xf32, 3>
+ // CHECK: vector.mask {{.*}} { vector.transfer_write %[[v1]], %[[alloc1]]
+
+ // CHECK: %[[v2:.*]] = vector.mask {{.*}} { vector.transfer_read %[[s2]]
+ // CHECK: %[[alloc2:.*]] = memref.alloc() : memref<4x5xf32, 3>
+ // CHECK: vector.mask {{.*}} { vector.transfer_write %[[v2]], %[[alloc0]]
+
+ // CHECK: linalg.matmul ins(%[[alloc0]], %[[alloc1]] : {{.*}}) outs(%[[alloc2]] : {{.*}})
+ // Copy back result.
+ // CHECK: memref.copy %[[alloc2]], %[[s2]]
+ %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
+
+ // insert_slice bufferizes to a no-op.
+ %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
+ func.return %5 : tensor<24x25xf32>
+ }
}
module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index 4b38db79bff3e1..0439844dc66cad 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -1,16 +1,18 @@
-// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+// RUN: mlir-opt --transform-interpreter="debug-payload-root-tag=payload" %s --split-input-file | FileCheck %s
// CHECK-LABEL: func @matmul_tensors
-func.func @matmul_tensors(
- %arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>, %arg2: tensor<8x32xf32>)
- -> tensor<8x32xf32> {
-// CHECK-NOT: linalg
-// CHECK: vector.extract {{.*}} : vector<4xf32> from vector<8x4xf32>
-// CHECK: vector.store {{.*}} : memref<8x32xf32>, vector<4xf32>
- %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x16xf32>, tensor<16x32xf32>)
- outs(%arg2: tensor<8x32xf32>)
- -> tensor<8x32xf32>
- return %0 : tensor<8x32xf32>
+module @payload attributes { transform.target_tag = "payload" } {
+ func.func @matmul_tensors(
+ %arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>, %arg2: tensor<8x32xf32>)
+ -> tensor<8x32xf32> {
+ // CHECK-NOT: linalg
+ // CHECK: vector.extract {{.*}} : vector<4xf32> from vector<8x4xf32>
+ // CHECK: vector.store {{.*}} : memref<8x32xf32>, vector<4xf32>
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x16xf32>, tensor<16x32xf32>)
+ outs(%arg2: tensor<8x32xf32>)
+ -> tensor<8x32xf32>
+ return %0 : tensor<8x32xf32>
+ }
}
module attributes {transform.with_named_sequence} {
@@ -76,11 +78,13 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xf16>, vector<64x64xf16> into vector<64x64xf32>
// CHECK-NEXT: return %[[R]] : vector<64x64xf32>
-func.func @fold_arith_extf_into_contract(%arg0: vector<64x64xf16>, %arg1: vector<64x64xf16>, %arg2: vector<64x64xf32>) -> vector<64x64xf32> {
- %lhs_f32 = arith.extf %arg0 : vector<64x64xf16> to vector<64x64xf32>
- %rhs_f32 = arith.extf %arg1 : vector<64x64xf16> to vector<64x64xf32>
- %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_f32, %rhs_f32, %arg2 : vector<64x64xf32>, vector<64x64xf32> into vector<64x64xf32>
- return %result : vector<64x64xf32>
+module @payload attributes { transform.target_tag = "payload" } {
+ func.func @fold_arith_extf_into_contract(%arg0: vector<64x64xf16>, %arg1: vector<64x64xf16>, %arg2: vector<64x64xf32>) -> vector<64x64xf32> {
+ %lhs_f32 = arith.extf %arg0 : vector<64x64xf16> to vector<64x64xf32>
+ %rhs_f32 = arith.extf %arg1 : vector<64x64xf16> to vector<64x64xf32>
+ %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_f32, %rhs_f32, %arg2 : vector<64x64xf32>, vector<64x64xf32> into vector<64x64xf32>
+ return %result : vector<64x64xf32>
+ }
}
module attributes {transform.with_named_sequence} {
@@ -95,30 +99,32 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: func.func @arith_to_outerproduct_scalable_i32
-// CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>,
-// CHECK-SAME: %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
-// CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32>
-// CHECK: return %[[RES]] : vector<[4]x[4]xi32>
-func.func @arith_to_outerproduct_scalable_i32(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
- %lhsBcast = vector.broadcast %lhs : vector<[4]xi32> to vector<[4]x[4]xi32>
- %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
- %rhsBcast = vector.broadcast %rhs : vector<[4]xi32> to vector<[4]x[4]xi32>
- %mul = arith.muli %lhsT, %rhsBcast : vector<[4]x[4]xi32>
- return %mul: vector<[4]x[4]xi32>
-}
+module @payload attributes { transform.target_tag = "payload" } {
+ // CHECK-LABEL: func.func @arith_to_outerproduct_scalable_i32
+ // CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>,
+ // CHECK-SAME: %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
+ // CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32>
+ // CHECK: return %[[RES]] : vector<[4]x[4]xi32>
+ func.func @arith_to_outerproduct_scalable_i32(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
+ %lhsBcast = vector.broadcast %lhs : vector<[4]xi32> to vector<[4]x[4]xi32>
+ %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
+ %rhsBcast = vector.broadcast %rhs : vector<[4]xi32> to vector<[4]x[4]xi32>
+ %mul = arith.muli %lhsT, %rhsBcast : vector<[4]x[4]xi32>
+ return %mul: vector<[4]x[4]xi32>
+ }
-// CHECK-LABEL: func.func @arith_to_outerproduct_trans_rhs_f32
-// CHECK-SAME: %[[LHS:.*]]: vector<16xf32>,
-// CHECK-SAME: %[[RHS:.*]]: vector<8xf32>) -> vector<8x16xf32> {
-// CHECK: %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<8xf32>, vector<16xf32>
-// CHECK: return %[[RES]] : vector<8x16xf32>
-func.func @arith_to_outerproduct_trans_rhs_f32(%lhs: vector<16xf32>, %rhs: vector<8xf32>) -> vector<8x16xf32> {
- %rhsBcast = vector.broadcast %rhs : vector<8xf32> to vector<16x8xf32>
- %rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x8xf32> to vector<8x16xf32>
- %lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<8x16xf32>
- %mul = arith.mulf %lhsBcast, %rhsT : vector<8x16xf32>
- return %mul: vector<8x16xf32>
+ // CHECK-LABEL: func.func @arith_to_outerproduct_trans_rhs_f32
+ // CHECK-SAME: %[[LHS:.*]]: vector<16xf32>,
+ // CHECK-SAME: %[[RHS:.*]]: vector<8xf32>) -> vector<8x16xf32> {
+ // CHECK: %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<8xf32>, vector<16xf32>
+ // CHECK: return %[[RES]] : vector<8x16xf32>
+ func.func @arith_to_outerproduct_trans_rhs_f32(%lhs: vector<16xf32>, %rhs: vector<8xf32>) -> vector<8x16xf32> {
+ %rhsBcast = vector.broadcast %rhs : vector<8xf32> to vector<16x8xf32>
+ %rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x8xf32> to vector<8x16xf32>
+ %lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<8x16xf32>
+ %mul = arith.mulf %lhsBcast, %rhsT : vector<8x16xf32>
+ return %mul: vector<8x16xf32>
+ }
}
module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Examples/transform/ChH/full.mlir b/mlir/test/Examples/transform/ChH/full.mlir
index 259475ebdbf49e..85dbf670233232 100644
--- a/mlir/test/Examples/transform/ChH/full.mlir
+++ b/mlir/test/Examples/transform/ChH/full.mlir
@@ -1,8 +1,6 @@
-// RUN: mlir-opt %s --transform-interpreter \
-// RUN: --test-transform-dialect-erase-schedule \
-// RUN: --math-uplift-to-fma \
-// RUN: --convert-bufferization-to-memref \
-// RUN: --test-lower-to-llvm |\
+// RUN: mlir-opt %s --transform-interpreter="debug-payload-root-tag=payload" \
+// RUN: --test-transform-dialect-erase-schedule |\
+// RUN: mlir-opt -pass-pipeline='builtin.module(builtin.module(math-uplift-to-fma,convert-bufferization-to-memref,test-lower-to-llvm))' - |\
// RUN: FileCheck %s
// Fixed-size tensor types to be used in convolution.
@@ -19,6 +17,7 @@
// tensors annotated with attributes from the `bufferization` dialect. These
// attributes hint the bufferization pass to assume buffers can be directly
// used for these tensors without reshaping.
+module @payload attributes { transform.target_tag = "payload" } {
func.func @conv(
%input: !tinput {bufferization.writable = false,
bufferization.access = "read",
@@ -84,7 +83,7 @@ func.func @conv(
return %relued : !toutput
}
-
+}
// Module containing the transformation script to be applied. The attribute
// is required to correctly verify the use of named (macro-like) sequences.
module attributes { transform.with_named_sequence } {