[Mlir-commits] [mlir] 6d4f780 - [MLIR] Support for ReturnOps in memref map layout normalization
Uday Bondhugula
llvmlistbot at llvm.org
Thu Aug 13 06:45:18 PDT 2020
Author: avarmapml
Date: 2020-08-13T19:10:47+05:30
New Revision: 6d4f7801b1d2a0ec6fbc0cb4eb9d3613df788d78
URL: https://github.com/llvm/llvm-project/commit/6d4f7801b1d2a0ec6fbc0cb4eb9d3613df788d78
DIFF: https://github.com/llvm/llvm-project/commit/6d4f7801b1d2a0ec6fbc0cb4eb9d3613df788d78.diff
LOG: [MLIR] Support for ReturnOps in memref map layout normalization
-- This commit handles ReturnOps in memref map layout normalization, as sketched below.
-- An initial filter is applied to FuncOps to identify which functions are suitable
candidates for memref normalization, i.e., those for which normalization does not lead to invalid IR.
-- Handles memref map normalization for external functions, assuming the external function
is normalizable.
Differential Revision: https://reviews.llvm.org/D85226
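To illustrate the effect, here is a minimal before/after sketch (hypothetical function
names; the #tile map is the one used in the updated tests) of how a returned memref
and its caller are rewritten:

  #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>

  // Before: the memref with the #tile layout is returned as-is.
  func @get_buf(%A: memref<16xf64, #tile>) -> memref<16xf64, #tile> {
    return %A : memref<16xf64, #tile>
  }
  func @caller(%A: memref<16xf64, #tile>) -> memref<16xf64, #tile> {
    %r = call @get_buf(%A) : (memref<16xf64, #tile>) -> memref<16xf64, #tile>
    return %r : memref<16xf64, #tile>
  }

  // After -normalize-memrefs: the argument, the ReturnOp operand, the CallOp
  // result type, and the caller's own signature are all rewritten.
  func @get_buf(%arg0: memref<4x4xf64>) -> memref<4x4xf64> {
    return %arg0 : memref<4x4xf64>
  }
  func @caller(%arg0: memref<4x4xf64>) -> memref<4x4xf64> {
    %0 = call @get_buf(%arg0) : (memref<4x4xf64>) -> memref<4x4xf64>
    return %0 : memref<4x4xf64>
  }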
Added:
Modified:
mlir/lib/Transforms/NormalizeMemRefs.cpp
mlir/lib/Transforms/Utils/Utils.cpp
mlir/test/Transforms/normalize-memrefs.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
index 1484dcb9e10c..1736fa989a83 100644
--- a/mlir/lib/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/SmallSet.h"
#define DEBUG_TYPE "normalize-memrefs"
@@ -24,39 +25,45 @@ namespace {
/// All memrefs passed across functions with non-trivial layout maps are
/// converted to ones with trivial identity layout ones.
-
-// Input :-
-// #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
-// func @matmul(%A: memref<16xf64, #tile>, %B: index, %C: memref<16xf64>) ->
-// (memref<16xf64, #tile>) {
-// affine.for %arg3 = 0 to 16 {
-// %a = affine.load %A[%arg3] : memref<16xf64, #tile>
-// %p = mulf %a, %a : f64
-// affine.store %p, %A[%arg3] : memref<16xf64, #tile>
-// }
-// %c = alloc() : memref<16xf64, #tile>
-// %d = affine.load %c[0] : memref<16xf64, #tile>
-// return %A: memref<16xf64, #tile>
-// }
-
-// Output :-
-// func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
-// -> memref<4x4xf64> {
-// affine.for %arg3 = 0 to 16 {
-// %2 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
-// %3 = mulf %2, %2 : f64
-// affine.store %3, %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
-// }
-// %0 = alloc() : memref<16xf64, #map0>
-// %1 = affine.load %0[0] : memref<16xf64, #map0>
-// return %arg0 : memref<4x4xf64>
-// }
-
+/// If all the memref types/uses in a function are normalizable, we treat
+/// such functions as normalizable. Also, if a normalizable function is known
+/// to call a non-normalizable function, we treat that function as
+/// non-normalizable as well. We assume external functions to be normalizable.
+///
+/// Input :-
+/// #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
+/// func @matmul(%A: memref<16xf64, #tile>, %B: index, %C: memref<16xf64>) ->
+/// (memref<16xf64, #tile>) {
+/// affine.for %arg3 = 0 to 16 {
+/// %a = affine.load %A[%arg3] : memref<16xf64, #tile>
+/// %p = mulf %a, %a : f64
+/// affine.store %p, %A[%arg3] : memref<16xf64, #tile>
+/// }
+/// %c = alloc() : memref<16xf64, #tile>
+/// %d = affine.load %c[0] : memref<16xf64, #tile>
+/// return %A: memref<16xf64, #tile>
+/// }
+///
+/// Output :-
+/// func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
+/// -> memref<4x4xf64> {
+/// affine.for %arg3 = 0 to 16 {
+/// %2 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
+/// %3 = mulf %2, %2 : f64
+/// affine.store %3, %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
+/// }
+/// %0 = alloc() : memref<16xf64, #map0>
+/// %1 = affine.load %0[0] : memref<16xf64, #map0>
+/// return %arg0 : memref<4x4xf64>
+/// }
+///
struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
void runOnOperation() override;
- void runOnFunction(FuncOp funcOp);
+ void normalizeFuncOpMemRefs(FuncOp funcOp, ModuleOp moduleOp);
bool areMemRefsNormalizable(FuncOp funcOp);
- void updateFunctionSignature(FuncOp funcOp);
+ void updateFunctionSignature(FuncOp funcOp, ModuleOp moduleOp);
+ void setCalleesAndCallersNonNormalizable(FuncOp funcOp, ModuleOp moduleOp,
+ DenseSet<FuncOp> &normalizableFuncs);
};
} // end anonymous namespace
@@ -67,41 +74,109 @@ std::unique_ptr<OperationPass<ModuleOp>> mlir::createNormalizeMemRefsPass() {
void NormalizeMemRefs::runOnOperation() {
ModuleOp moduleOp = getOperation();
- // We traverse each function within the module in order to normalize the
- // memref type arguments.
- // TODO: Handle external functions.
+ // We maintain all normalizable FuncOps in a DenseSet. It is initialized
+ // with all the functions within a module and then functions which are not
+ // normalizable are removed from this set.
+ // TODO: Change this to work on FuncLikeOp once there is an operation
+ // interface for it.
+ DenseSet<FuncOp> normalizableFuncs;
+ // Initialize `normalizableFuncs` with all the functions within a module.
+ moduleOp.walk([&](FuncOp funcOp) { normalizableFuncs.insert(funcOp); });
+
+  // Traverse all the functions, applying a filter that determines whether
+  // each function is normalizable or not. All callers/callees of a
+  // non-normalizable function also become non-normalizable, even if they
+  // aren't passing any non-normalizable memrefs. That is, functions which
+  // call or are called by a non-normalizable function become
+  // non-normalizable themselves.
moduleOp.walk([&](FuncOp funcOp) {
- if (areMemRefsNormalizable(funcOp))
- runOnFunction(funcOp);
+ if (normalizableFuncs.contains(funcOp)) {
+ if (!areMemRefsNormalizable(funcOp)) {
+ // Since this function is not normalizable, we set all the caller
+ // functions and the callees of this function as not normalizable.
+ // TODO: Drop this conservative assumption in the future.
+ setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
+ normalizableFuncs);
+ }
+ }
});
+
+ // Those functions which can be normalized are subjected to normalization.
+ for (FuncOp &funcOp : normalizableFuncs)
+ normalizeFuncOpMemRefs(funcOp, moduleOp);
}
-// Return true if this operation dereferences one or more memref's.
-// TODO: Temporary utility, will be replaced when this is modeled through
-// side-effects/op traits.
+/// Return true if this operation dereferences one or more memref's.
+/// TODO: Temporary utility, will be replaced when this is modeled through
+/// side-effects/op traits.
static bool isMemRefDereferencingOp(Operation &op) {
return isa<AffineReadOpInterface, AffineWriteOpInterface, AffineDmaStartOp,
AffineDmaWaitOp>(op);
}
-// Check whether all the uses of oldMemRef are either dereferencing uses or the
-// op is of type : DeallocOp, CallOp. Only if these constraints are satisfied
-// will the value become a candidate for replacement.
+/// Check whether all the uses of oldMemRef are either dereferencing uses or the
+/// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints
+/// are satisfied will the value become a candidate for replacement.
+/// TODO: Extend this for DimOps.
static bool isMemRefNormalizable(Value::user_range opUsers) {
if (llvm::any_of(opUsers, [](Operation *op) {
if (isMemRefDereferencingOp(*op))
return false;
- return !isa<DeallocOp, CallOp>(*op);
+ return !isa<DeallocOp, CallOp, ReturnOp>(*op);
}))
return false;
return true;
}
-// Check whether all the uses of AllocOps, CallOps and function arguments of a
-// function are either of dereferencing type or of type: DeallocOp, CallOp. Only
-// if these constraints are satisfied will the function become a candidate for
-// normalization.
+/// Set all the calling functions and the callees of the function as not
+/// normalizable.
+void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
+ FuncOp funcOp, ModuleOp moduleOp, DenseSet<FuncOp> &normalizableFuncs) {
+ if (!normalizableFuncs.contains(funcOp))
+ return;
+
+ normalizableFuncs.erase(funcOp);
+ // Caller of the function.
+ Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
+ for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
+ // TODO: Extend this for ops that are FunctionLike. This would require
+ // creating an OpInterface for FunctionLike ops.
+ FuncOp parentFuncOp = symbolUse.getUser()->getParentOfType<FuncOp>();
+ for (FuncOp &funcOp : normalizableFuncs) {
+ if (parentFuncOp == funcOp) {
+ setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
+ normalizableFuncs);
+ break;
+ }
+ }
+ }
+
+ // Functions called by this function.
+ funcOp.walk([&](CallOp callOp) {
+ StringRef callee = callOp.getCallee();
+ for (FuncOp &funcOp : normalizableFuncs) {
+ // We compare FuncOp and callee's name.
+ if (callee == funcOp.getName()) {
+ setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
+ normalizableFuncs);
+ break;
+ }
+ }
+ });
+}
+
+/// Check whether all the uses of AllocOps, CallOps and function arguments of a
+/// function are either of dereferencing type or are uses in: DeallocOp, CallOp
+/// or ReturnOp. Only if these constraints are satisfied will the function
+/// become a candidate for normalization. We follow a conservative approach here
+/// wherein even if the non-normalizable memref is not a part of the function's
+/// argument or return type, we still label the entire function as
+/// non-normalizable. We assume external functions to be normalizable.
bool NormalizeMemRefs::areMemRefsNormalizable(FuncOp funcOp) {
+ // We assume external functions to be normalizable.
+ if (funcOp.isExternal())
+ return true;
+
if (funcOp
.walk([&](AllocOp allocOp) -> WalkResult {
Value oldMemRef = allocOp.getResult();
@@ -136,28 +211,138 @@ bool NormalizeMemRefs::areMemRefsNormalizable(FuncOp funcOp) {
return true;
}
-// Fetch the updated argument list and result of the function and update the
-// function signature.
-void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp) {
+/// Fetch the updated argument list and result types of the function and update
+/// the function signature. This updates the function's return type at the
+/// caller site, and if the return type is a normalized memref, it also updates
+/// the calling function's signature.
+/// TODO: An update to the calling function signature is required only if the
+/// returned value is in turn used in ReturnOp of the calling function.
+void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
+ ModuleOp moduleOp) {
FunctionType functionType = funcOp.getType();
- SmallVector<Type, 8> argTypes;
SmallVector<Type, 4> resultTypes;
+ FunctionType newFuncType;
+ resultTypes = llvm::to_vector<4>(functionType.getResults());
- for (const auto &arg : llvm::enumerate(funcOp.getArguments()))
- argTypes.push_back(arg.value().getType());
+  // An external function's signature was already updated in
+ // 'normalizeFuncOpMemRefs()'.
+ if (!funcOp.isExternal()) {
+ SmallVector<Type, 8> argTypes;
+ for (const auto &argEn : llvm::enumerate(funcOp.getArguments()))
+ argTypes.push_back(argEn.value().getType());
- resultTypes = llvm::to_vector<4>(functionType.getResults());
- // We create a new function type and modify the function signature with this
- // new type.
- FunctionType newFuncType = FunctionType::get(/*inputs=*/argTypes,
- /*results=*/resultTypes,
- /*context=*/&getContext());
-
- // TODO: Handle ReturnOps to update function results the caller site.
- funcOp.setType(newFuncType);
+ // Traverse ReturnOps to check if an update to the return type in the
+ // function signature is required.
+ funcOp.walk([&](ReturnOp returnOp) {
+ for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
+ Type opType = operandEn.value().getType();
+ MemRefType memrefType = opType.dyn_cast<MemRefType>();
+        // If the type is not a memref or if the memref type is the same as
+        // that in the function's return signature, then no update is required.
+ if (!memrefType || memrefType == resultTypes[operandEn.index()])
+ continue;
+ // Update function's return type signature.
+ // Return type gets normalized either as a result of function argument
+ // normalization, AllocOp normalization or an update made at CallOp.
+ // There can be many call flows inside a function and an update to a
+ // specific ReturnOp has not yet been made. So we check that the result
+ // memref type is normalized.
+ // TODO: When selective normalization is implemented, handle multiple
+ // results case where some are normalized, some aren't.
+ if (memrefType.getAffineMaps().empty())
+ resultTypes[operandEn.index()] = memrefType;
+ }
+ });
+
+ // We create a new function type and modify the function signature with this
+ // new type.
+ newFuncType = FunctionType::get(/*inputs=*/argTypes,
+ /*results=*/resultTypes,
+ /*context=*/&getContext());
+ }
+
+  // Since we update the function signature, the result types at the caller
+  // site may change. Because such a result may in turn be used by the caller
+  // function in ReturnOps, the caller function's signature will also change.
+  // Hence we record the caller functions in 'funcOpsToUpdate' to update their
+  // signatures as well.
+ llvm::SmallDenseSet<FuncOp, 8> funcOpsToUpdate;
+ // We iterate over all symbolic uses of the function and update the return
+ // type at the caller site.
+ Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
+ for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
+ Operation *callOp = symbolUse.getUser();
+ OpBuilder builder(callOp);
+ StringRef callee = cast<CallOp>(callOp).getCallee();
+ Operation *newCallOp = builder.create<CallOp>(
+ callOp->getLoc(), resultTypes, builder.getSymbolRefAttr(callee),
+ callOp->getOperands());
+ bool replacingMemRefUsesFailed = false;
+ bool returnTypeChanged = false;
+ for (unsigned resIndex : llvm::seq<unsigned>(0, callOp->getNumResults())) {
+ OpResult oldResult = callOp->getResult(resIndex);
+ OpResult newResult = newCallOp->getResult(resIndex);
+      // This condition ensures that if the result is not of memref type or if
+      // the resulting memref already had a trivial map layout, then we need
+      // not perform any use replacement here.
+ if (oldResult.getType() == newResult.getType())
+ continue;
+ AffineMap layoutMap =
+ oldResult.getType().dyn_cast<MemRefType>().getAffineMaps().front();
+ if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
+ /*extraIndices=*/{},
+ /*indexRemap=*/layoutMap,
+ /*extraOperands=*/{},
+ /*symbolOperands=*/{},
+ /*domInstFilter=*/nullptr,
+ /*postDomInstFilter=*/nullptr,
+ /*allowDereferencingOps=*/true,
+ /*replaceInDeallocOp=*/true))) {
+        // If it failed (due to escapes for example), bail out.
+        // We should never hit this part of the code because it is invoked
+        // only for functions which are normalizable.
+ newCallOp->erase();
+ replacingMemRefUsesFailed = true;
+ break;
+ }
+ returnTypeChanged = true;
+ }
+ if (replacingMemRefUsesFailed)
+ continue;
+ // Replace all uses for other non-memref result types.
+ callOp->replaceAllUsesWith(newCallOp);
+ callOp->erase();
+ if (returnTypeChanged) {
+      // Since the return type changed, it might lead to a change in the
+      // calling function's signature.
+      // TODO: If funcOp doesn't return any memref type then there is no need
+      // to update the signature.
+      // TODO: Further optimization - check whether the memref is indeed part
+      // of a ReturnOp in the parentFuncOp, and only in that case update the
+      // signature.
+ // TODO: Extend this for ops that are FunctionLike. This would require
+ // creating an OpInterface for FunctionLike ops.
+ FuncOp parentFuncOp = newCallOp->getParentOfType<FuncOp>();
+ funcOpsToUpdate.insert(parentFuncOp);
+ }
+ }
+  // Because an external function's signature is already updated in
+ // 'normalizeFuncOpMemRefs()', we don't need to update it here again.
+ if (!funcOp.isExternal())
+ funcOp.setType(newFuncType);
+
+ // Updating the signature type of those functions which call the current
+ // function. Only if the return type of the current function has a normalized
+ // memref will the caller function become a candidate for signature update.
+ for (FuncOp parentFuncOp : funcOpsToUpdate)
+ updateFunctionSignature(parentFuncOp, moduleOp);
}
-void NormalizeMemRefs::runOnFunction(FuncOp funcOp) {
+/// Normalizes the memrefs within a function, which includes those arising as a
+/// result of AllocOps, CallOps and function arguments. The ModuleOp argument
+/// is used to help update the function's signature after normalization.
+void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
+ ModuleOp moduleOp) {
// Turn memrefs' non-identity layouts maps into ones with identity. Collect
// alloc ops first and then process since normalizeMemRef replaces/erases ops
// during memref rewriting.
@@ -169,22 +354,27 @@ void NormalizeMemRefs::runOnFunction(FuncOp funcOp) {
// We use this OpBuilder to create new memref layout later.
OpBuilder b(funcOp);
+ FunctionType functionType = funcOp.getType();
+ SmallVector<Type, 8> inputTypes;
// Walk over each argument of a function to perform memref normalization (if
- // any).
- for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
- Type argType = funcOp.getArgument(argIndex).getType();
+ for (unsigned argIndex :
+ llvm::seq<unsigned>(0, functionType.getNumInputs())) {
+ Type argType = functionType.getInput(argIndex);
MemRefType memrefType = argType.dyn_cast<MemRefType>();
// Check whether argument is of MemRef type. Any other argument type can
// simply be part of the final function signature.
- if (!memrefType)
+ if (!memrefType) {
+ inputTypes.push_back(argType);
continue;
+ }
// Fetch a new memref type after normalizing the old memref to have an
// identity map layout.
MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
/*numSymbolicOperands=*/0);
- if (newMemRefType == memrefType) {
+ if (newMemRefType == memrefType || funcOp.isExternal()) {
// Either memrefType already had an identity map or the map couldn't be
// transformed to an identity map.
+ inputTypes.push_back(newMemRefType);
continue;
}
@@ -202,7 +392,7 @@ void NormalizeMemRefs::runOnFunction(FuncOp funcOp) {
/*domInstFilter=*/nullptr,
/*postDomInstFilter=*/nullptr,
/*allowNonDereferencingOps=*/true,
- /*handleDeallocOp=*/true))) {
+ /*replaceInDeallocOp=*/true))) {
// If it failed (due to escapes for example), bail out. Removing the
// temporary argument inserted previously.
funcOp.front().eraseArgument(argIndex);
@@ -214,5 +404,36 @@ void NormalizeMemRefs::runOnFunction(FuncOp funcOp) {
funcOp.front().eraseArgument(argIndex + 1);
}
- updateFunctionSignature(funcOp);
+  // In a normal function, memrefs in the return type signature get normalized
+  // as a result of normalization of function arguments, AllocOps or CallOps'
+  // result types. Since an external function doesn't have a body, memrefs in
+  // the return type signature can only get normalized by iterating over the
+  // individual return types.
+ if (funcOp.isExternal()) {
+ SmallVector<Type, 4> resultTypes;
+ for (unsigned resIndex :
+ llvm::seq<unsigned>(0, functionType.getNumResults())) {
+ Type resType = functionType.getResult(resIndex);
+ MemRefType memrefType = resType.dyn_cast<MemRefType>();
+      // Check whether the result is of MemRef type. Any other result type can
+      // simply be part of the final function signature.
+ if (!memrefType) {
+ resultTypes.push_back(resType);
+ continue;
+ }
+ // Computing a new memref type after normalizing the old memref to have an
+ // identity map layout.
+ MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
+ /*numSymbolicOperands=*/0);
+ resultTypes.push_back(newMemRefType);
+ continue;
+ }
+
+ FunctionType newFuncType = FunctionType::get(/*inputs=*/inputTypes,
+ /*results=*/resultTypes,
+ /*context=*/&getContext());
+ // Setting the new function signature for this external function.
+ funcOp.setType(newFuncType);
+ }
+ updateFunctionSignature(funcOp, moduleOp);
}
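For the external-function path added above in normalizeFuncOpMemRefs (a declaration has
no body, so its input and result types are rewritten directly), a minimal sketch under
the same assumptions and with a hypothetical function name:

  #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>

  // Before: external declaration carrying a non-trivial layout map.
  func @ext_compute(memref<16xf64, #tile>) -> memref<16xf64, #tile>

  // After -normalize-memrefs: assuming the external function is normalizable,
  // its signature is rewritten and every CallOp to it is updated to match.
  func @ext_compute(memref<4x4xf64>) -> memref<4x4xf64>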
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index 73df4fa939bf..c310702745a2 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -274,12 +274,12 @@ LogicalResult mlir::replaceAllMemRefUsesWith(
// for the memref to be used in a non-dereferencing way outside of the
// region where this replacement is happening.
if (!isMemRefDereferencingOp(*op)) {
- // Currently we support the following non-dereferencing types to be a
- // candidate for replacement: Dealloc and CallOp.
- // TODO: Add support for other kinds of ops.
if (!allowNonDereferencingOps)
return failure();
- if (!(isa<DeallocOp, CallOp>(*op)))
+ // Currently we support the following non-dereferencing ops to be a
+ // candidate for replacement: Dealloc, CallOp and ReturnOp.
+ // TODO: Add support for other kinds of ops.
+ if (!isa<DeallocOp, CallOp, ReturnOp>(*op))
return failure();
}
diff --git a/mlir/test/Transforms/normalize-memrefs.mlir b/mlir/test/Transforms/normalize-memrefs.mlir
index 7d56c8893940..9c1e610c483a 100644
--- a/mlir/test/Transforms/normalize-memrefs.mlir
+++ b/mlir/test/Transforms/normalize-memrefs.mlir
@@ -126,14 +126,6 @@ func @symbolic_operands(%s : index) {
return
}
-// Memref escapes; no normalization.
-// CHECK-LABEL: func @escaping() -> memref<64xf32, #map{{[0-9]+}}>
-func @escaping() -> memref<64xf32, affine_map<(d0) -> (d0 + 2)>> {
- // CHECK: %{{.*}} = alloc() : memref<64xf32, #map{{[0-9]+}}>
- %A = alloc() : memref<64xf32, affine_map<(d0) -> (d0 + 2)>>
- return %A : memref<64xf32, affine_map<(d0) -> (d0 + 2)>>
-}
-
// Semi-affine maps, normalization not implemented yet.
// CHECK-LABEL: func @semi_affine_layout_map
func @semi_affine_layout_map(%s0: index, %s1: index) {
@@ -205,9 +197,125 @@ func @non_memref_ret(%A: memref<8xf64, #tile>) -> i1 {
return %d : i1
}
-// Test case 4: No normalization should take place because the function is returning the memref.
-// CHECK-LABEL: func @memref_used_in_return
-// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>) -> memref<8xf64, #map{{[0-9]+}}>
-func @memref_used_in_return(%A: memref<8xf64, #tile>) -> (memref<8xf64, #tile>) {
- return %A : memref<8xf64, #tile>
+// Test cases from here onwards deal with normalization of memrefs in function signatures and at call sites.
+
+// Test case 4: Check successful memref normalization in case of inter/intra-recursive calls.
+// CHECK-LABEL: func @ret_multiple_argument_type
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64, %[[C:arg[0-9]+]]: memref<2x4xf64>) -> (memref<2x4xf64>, f64)
+func @ret_multiple_argument_type(%A: memref<16xf64, #tile>, %B: f64, %C: memref<8xf64, #tile>) -> (memref<8xf64, #tile>, f64) {
+ %a = affine.load %A[0] : memref<16xf64, #tile>
+ %p = mulf %a, %a : f64
+ %cond = constant 1 : i1
+ cond_br %cond, ^bb1, ^bb2
+ ^bb1:
+ %res1, %res2 = call @ret_single_argument_type(%C) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
+ return %res2, %p: memref<8xf64, #tile>, f64
+ ^bb2:
+ return %C, %p: memref<8xf64, #tile>, f64
+}
+
+// CHECK: %[[a:[0-9]+]] = affine.load %[[A]][0, 0] : memref<4x4xf64>
+// CHECK: %[[p:[0-9]+]] = mulf %[[a]], %[[a]] : f64
+// CHECK: %true = constant true
+// CHECK: cond_br %true, ^bb1, ^bb2
+// CHECK: ^bb1: // pred: ^bb0
+// CHECK: %[[res:[0-9]+]]:2 = call @ret_single_argument_type(%[[C]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
+// CHECK: return %[[res]]#1, %[[p]] : memref<2x4xf64>, f64
+// CHECK: ^bb2: // pred: ^bb0
+// CHECK: return %{{.*}}, %{{.*}} : memref<2x4xf64>, f64
+
+// CHECK-LABEL: func @ret_single_argument_type
+// CHECK-SAME: (%[[C:arg[0-9]+]]: memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
+func @ret_single_argument_type(%C: memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>){
+ %a = alloc() : memref<8xf64, #tile>
+ %b = alloc() : memref<16xf64, #tile>
+ %d = constant 23.0 : f64
+ call @ret_single_argument_type(%a) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
+ call @ret_single_argument_type(%C) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
+ %res1, %res2 = call @ret_multiple_argument_type(%b, %d, %a) : (memref<16xf64, #tile>, f64, memref<8xf64, #tile>) -> (memref<8xf64, #tile>, f64)
+ %res3, %res4 = call @ret_single_argument_type(%res1) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
+ return %b, %a: memref<16xf64, #tile>, memref<8xf64, #tile>
+}
+
+// CHECK: %[[a:[0-9]+]] = alloc() : memref<2x4xf64>
+// CHECK: %[[b:[0-9]+]] = alloc() : memref<4x4xf64>
+// CHECK: %cst = constant 2.300000e+01 : f64
+// CHECK: %[[resA:[0-9]+]]:2 = call @ret_single_argument_type(%[[a]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
+// CHECK: %[[resB:[0-9]+]]:2 = call @ret_single_argument_type(%[[C]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
+// CHECK: %[[resC:[0-9]+]]:2 = call @ret_multiple_argument_type(%[[b]], %cst, %[[a]]) : (memref<4x4xf64>, f64, memref<2x4xf64>) -> (memref<2x4xf64>, f64)
+// CHECK: %[[resD:[0-9]+]]:2 = call @ret_single_argument_type(%[[resC]]#0) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
+// CHECK: return %{{.*}}, %{{.*}} : memref<4x4xf64>, memref<2x4xf64>
+
+// Test case set #5: To check normalization in a chain of interconnected functions.
+// CHECK-LABEL: func @func_A
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
+func @func_A(%A: memref<8xf64, #tile>) {
+ call @func_B(%A) : (memref<8xf64, #tile>) -> ()
+ return
+}
+// CHECK: call @func_B(%[[A]]) : (memref<2x4xf64>) -> ()
+
+// CHECK-LABEL: func @func_B
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
+func @func_B(%A: memref<8xf64, #tile>) {
+ call @func_C(%A) : (memref<8xf64, #tile>) -> ()
+ return
+}
+// CHECK: call @func_C(%[[A]]) : (memref<2x4xf64>) -> ()
+
+// CHECK-LABEL: func @func_C
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
+func @func_C(%A: memref<8xf64, #tile>) {
+ return
+}
+
+// Test case set #6: Checking if no normalization takes place in a scenario: A -> B -> C and B has an unsupported type.
+// CHECK-LABEL: func @some_func_A
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
+func @some_func_A(%A: memref<8xf64, #tile>) {
+ call @some_func_B(%A) : (memref<8xf64, #tile>) -> ()
+ return
+}
+// CHECK: call @some_func_B(%[[A]]) : (memref<8xf64, #map{{[0-9]+}}>) -> ()
+
+// CHECK-LABEL: func @some_func_B
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
+func @some_func_B(%A: memref<8xf64, #tile>) {
+ "test.test"(%A) : (memref<8xf64, #tile>) -> ()
+ call @some_func_C(%A) : (memref<8xf64, #tile>) -> ()
+ return
+}
+// CHECK: call @some_func_C(%[[A]]) : (memref<8xf64, #map{{[0-9]+}}>) -> ()
+
+// CHECK-LABEL: func @some_func_C
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
+func @some_func_C(%A: memref<8xf64, #tile>) {
+ return
+}
+
+// Test case set #7: Check normalization in case of external functions.
+// CHECK-LABEL: func @external_func_A
+// CHECK-SAME: (memref<4x4xf64>)
+func @external_func_A(memref<16xf64, #tile>) -> ()
+
+// CHECK-LABEL: func @external_func_B
+// CHECK-SAME: (memref<4x4xf64>, f64) -> memref<2x4xf64>
+func @external_func_B(memref<16xf64, #tile>, f64) -> (memref<8xf64, #tile>)
+
+// CHECK-LABEL: func @simply_call_external()
+func @simply_call_external() {
+ %a = alloc() : memref<16xf64, #tile>
+ call @external_func_A(%a) : (memref<16xf64, #tile>) -> ()
+ return
+}
+// CHECK: %[[a:[0-9]+]] = alloc() : memref<4x4xf64>
+// CHECK: call @external_func_A(%[[a]]) : (memref<4x4xf64>) -> ()
+
+// CHECK-LABEL: func @use_value_of_external
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64) -> memref<2x4xf64>
+func @use_value_of_external(%A: memref<16xf64, #tile>, %B: f64) -> (memref<8xf64, #tile>) {
+ %res = call @external_func_B(%A, %B) : (memref<16xf64, #tile>, f64) -> (memref<8xf64, #tile>)
+ return %res : memref<8xf64, #tile>
}
+// CHECK: %[[res:[0-9]+]] = call @external_func_B(%[[A]], %[[B]]) : (memref<4x4xf64>, f64) -> memref<2x4xf64>
+// CHECK: return %{{.*}} : memref<2x4xf64>