[Mlir-commits] [mlir] 6d4f780 - [MLIR] Support for ReturnOps in memref map layout normalization

Uday Bondhugula llvmlistbot at llvm.org
Thu Aug 13 06:45:18 PDT 2020


Author: avarmapml
Date: 2020-08-13T19:10:47+05:30
New Revision: 6d4f7801b1d2a0ec6fbc0cb4eb9d3613df788d78

URL: https://github.com/llvm/llvm-project/commit/6d4f7801b1d2a0ec6fbc0cb4eb9d3613df788d78
DIFF: https://github.com/llvm/llvm-project/commit/6d4f7801b1d2a0ec6fbc0cb4eb9d3613df788d78.diff

LOG: [MLIR] Support for ReturnOps in memref map layout normalization

-- This commit handles ReturnOps in memref map layout normalization.
-- An initial filter is applied to FuncOps to determine which functions are
   suitable candidates for memref normalization without producing invalid IR.
-- Handles memref map normalization for external functions, assuming that
   external functions are normalizable.

Differential Revision: https://reviews.llvm.org/D85226

Added: 
    

Modified: 
    mlir/lib/Transforms/NormalizeMemRefs.cpp
    mlir/lib/Transforms/Utils/Utils.cpp
    mlir/test/Transforms/normalize-memrefs.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
index 1484dcb9e10c..1736fa989a83 100644
--- a/mlir/lib/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Transforms/Passes.h"
 #include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/SmallSet.h"
 
 #define DEBUG_TYPE "normalize-memrefs"
 
@@ -24,39 +25,45 @@ namespace {
 
 /// All memrefs passed across functions with non-trivial layout maps are
 /// converted to ones with trivial identity layout ones.
-
-// Input :-
-// #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
-// func @matmul(%A: memref<16xf64, #tile>, %B: index, %C: memref<16xf64>) ->
-// (memref<16xf64, #tile>) {
-//   affine.for %arg3 = 0 to 16 {
-//         %a = affine.load %A[%arg3] : memref<16xf64, #tile>
-//         %p = mulf %a, %a : f64
-//         affine.store %p, %A[%arg3] : memref<16xf64, #tile>
-//   }
-//   %c = alloc() : memref<16xf64, #tile>
-//   %d = affine.load %c[0] : memref<16xf64, #tile>
-//   return %A: memref<16xf64, #tile>
-// }
-
-// Output :-
-//   func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
-//   -> memref<4x4xf64> {
-//     affine.for %arg3 = 0 to 16 {
-//       %2 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
-//       %3 = mulf %2, %2 : f64
-//       affine.store %3, %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
-//     }
-//     %0 = alloc() : memref<16xf64, #map0>
-//     %1 = affine.load %0[0] : memref<16xf64, #map0>
-//     return %arg0 : memref<4x4xf64>
-//   }
-
+/// A function is treated as normalizable if all of its memref types/uses are
+/// normalizable. A function that calls, or is called by, a non-normalizable
+/// function is treated as non-normalizable as well. External functions are
+/// assumed to be normalizable.
+///
+/// Input :-
+/// #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
+/// func @matmul(%A: memref<16xf64, #tile>, %B: index, %C: memref<16xf64>) ->
+/// (memref<16xf64, #tile>) {
+///   affine.for %arg3 = 0 to 16 {
+///         %a = affine.load %A[%arg3] : memref<16xf64, #tile>
+///         %p = mulf %a, %a : f64
+///         affine.store %p, %A[%arg3] : memref<16xf64, #tile>
+///   }
+///   %c = alloc() : memref<16xf64, #tile>
+///   %d = affine.load %c[0] : memref<16xf64, #tile>
+///   return %A: memref<16xf64, #tile>
+/// }
+///
+/// Output :-
+///   func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
+///   -> memref<4x4xf64> {
+///     affine.for %arg3 = 0 to 16 {
+///       %2 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
+///       %3 = mulf %2, %2 : f64
+///       affine.store %3, %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
+///     }
+///     %0 = alloc() : memref<16xf64, #map0>
+///     %1 = affine.load %0[0] : memref<16xf64, #map0>
+///     return %arg0 : memref<4x4xf64>
+///   }
+///
 struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
   void runOnOperation() override;
-  void runOnFunction(FuncOp funcOp);
+  void normalizeFuncOpMemRefs(FuncOp funcOp, ModuleOp moduleOp);
   bool areMemRefsNormalizable(FuncOp funcOp);
-  void updateFunctionSignature(FuncOp funcOp);
+  void updateFunctionSignature(FuncOp funcOp, ModuleOp moduleOp);
+  void setCalleesAndCallersNonNormalizable(FuncOp funcOp, ModuleOp moduleOp,
+                                           DenseSet<FuncOp> &normalizableFuncs);
 };
 
 } // end anonymous namespace
@@ -67,41 +74,109 @@ std::unique_ptr<OperationPass<ModuleOp>> mlir::createNormalizeMemRefsPass() {
 
 void NormalizeMemRefs::runOnOperation() {
   ModuleOp moduleOp = getOperation();
-  // We traverse each function within the module in order to normalize the
-  // memref type arguments.
-  // TODO: Handle external functions.
+  // We maintain all normalizable FuncOps in a DenseSet. It is initialized
+  // with all the functions within a module and then functions which are not
+  // normalizable are removed from this set.
+  // TODO: Change this to work on FuncLikeOp once there is an operation
+  // interface for it.
+  DenseSet<FuncOp> normalizableFuncs;
+  // Initialize `normalizableFuncs` with all the functions within a module.
+  moduleOp.walk([&](FuncOp funcOp) { normalizableFuncs.insert(funcOp); });
+
+  // Traverse all the functions, applying a filter that decides whether each
+  // one is normalizable. All callers/callees of a non-normalizable function
+  // also become non-normalizable, even if they do not pass any
+  // non-normalizable memrefs: any function that calls or is called by a
+  // non-normalizable function is itself treated as non-normalizable. The
+  // propagation is transitive (setCalleesAndCallersNonNormalizable recurses).
   moduleOp.walk([&](FuncOp funcOp) {
-    if (areMemRefsNormalizable(funcOp))
-      runOnFunction(funcOp);
+    if (normalizableFuncs.contains(funcOp)) {
+      if (!areMemRefsNormalizable(funcOp)) {
+        // This function is not normalizable, so mark all of its callers and
+        // callees (transitively) as non-normalizable too.
+        // TODO: Drop this conservative assumption in the future.
+        setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
+                                            normalizableFuncs);
+      }
+    }
   });
+
+  // Normalize the surviving functions (DenseSet order is unspecified).
+  for (FuncOp &funcOp : normalizableFuncs)
+    normalizeFuncOpMemRefs(funcOp, moduleOp);
 }
 
-// Return true if this operation dereferences one or more memref's.
-// TODO: Temporary utility, will be replaced when this is modeled through
-// side-effects/op traits.
+/// Return true if this operation dereferences one or more memref's.
+/// TODO: Temporary utility, will be replaced when this is modeled through
+/// side-effects/op traits.
 static bool isMemRefDereferencingOp(Operation &op) {
   return isa<AffineReadOpInterface, AffineWriteOpInterface, AffineDmaStartOp,
              AffineDmaWaitOp>(op);
 }
 
-// Check whether all the uses of oldMemRef are either dereferencing uses or the
-// op is of type : DeallocOp, CallOp. Only if these constraints are satisfied
-// will the value become a candidate for replacement.
+/// Check whether all the uses of oldMemRef are either dereferencing uses or the
+/// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints
+/// are satisfied will the value become a candidate for replacement.
+/// TODO: Extend this for DimOps.
 static bool isMemRefNormalizable(Value::user_range opUsers) {
   if (llvm::any_of(opUsers, [](Operation *op) {
         if (isMemRefDereferencingOp(*op))
           return false;
-        return !isa<DeallocOp, CallOp>(*op);
+        return !isa<DeallocOp, CallOp, ReturnOp>(*op);
       }))
     return false;
   return true;
 }
 
-// Check whether all the uses of AllocOps, CallOps and function arguments of a
-// function are either of dereferencing type or of type: DeallocOp, CallOp. Only
-// if these constraints are satisfied will the function become a candidate for
-// normalization.
+/// Remove `funcOp` and, transitively, all of its callers and callees from
+/// `normalizableFuncs`. If its callers cannot be enumerated, conservatively
+/// treat every function as non-normalizable.
+void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
+    FuncOp funcOp, ModuleOp moduleOp, DenseSet<FuncOp> &normalizableFuncs) {
+  if (!normalizableFuncs.contains(funcOp))
+    return;
+
+  normalizableFuncs.erase(funcOp);
+  // Callers of the function. An empty Optional means the symbol uses could
+  // not be computed, so callers may be hidden: give up on the whole module.
+  Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
+  if (!symbolUses) {
+    normalizableFuncs.clear();
+    return;
+  }
+  for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
+    // TODO: Extend this for ops that are FunctionLike. This would require
+    // creating an OpInterface for FunctionLike ops.
+    FuncOp parentFuncOp = symbolUse.getUser()->getParentOfType<FuncOp>();
+    // Uses outside of any FuncOp have no caller to invalidate.
+    if (parentFuncOp)
+      setCalleesAndCallersNonNormalizable(parentFuncOp, moduleOp,
+                                          normalizableFuncs);
+  }
+
+  // Functions called by this function, matched by symbol name.
+  funcOp.walk([&](CallOp callOp) {
+    StringRef callee = callOp.getCallee();
+    auto it = llvm::find_if(normalizableFuncs, [&](FuncOp candidate) {
+      return candidate.getName() == callee;
+    });
+    if (it != normalizableFuncs.end())
+      setCalleesAndCallersNonNormalizable(*it, moduleOp, normalizableFuncs);
+  });
+}
+
+/// Check whether all the uses of AllocOps, CallOps and function arguments of a
+/// function are either of dereferencing type or are uses in: DeallocOp, CallOp
+/// or ReturnOp. Only if these constraints are satisfied will the function
+/// become a candidate for normalization. We follow a conservative approach here
+/// wherein even if the non-normalizable memref is not a part of the function's
+/// argument or return type, we still label the entire function as
+/// non-normalizable. We assume external functions to be normalizable.
 bool NormalizeMemRefs::areMemRefsNormalizable(FuncOp funcOp) {
+  // We assume external functions to be normalizable.
+  if (funcOp.isExternal())
+    return true;
+
   if (funcOp
           .walk([&](AllocOp allocOp) -> WalkResult {
             Value oldMemRef = allocOp.getResult();
@@ -136,28 +211,138 @@ bool NormalizeMemRefs::areMemRefsNormalizable(FuncOp funcOp) {
   return true;
 }
 
-// Fetch the updated argument list and result of the function and update the
-// function signature.
-void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp) {
+/// Fetch the updated argument list and result of the function and update the
+/// function signature. This updates the function's return type at the caller
+/// site and in case the return type is a normalized memref then it updates
+/// the calling function's signature.
+/// TODO: An update to the calling function signature is required only if the
+/// returned value is in turn used in ReturnOp of the calling function.
+void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
+                                               ModuleOp moduleOp) {
   FunctionType functionType = funcOp.getType();
-  SmallVector<Type, 8> argTypes;
   SmallVector<Type, 4> resultTypes;
+  FunctionType newFuncType;
+  resultTypes = llvm::to_vector<4>(functionType.getResults());
 
-  for (const auto &arg : llvm::enumerate(funcOp.getArguments()))
-    argTypes.push_back(arg.value().getType());
+  // External function's signature was already updated in
+  // 'normalizeFuncOpMemRefs()'.
+  if (!funcOp.isExternal()) {
+    SmallVector<Type, 8> argTypes;
+    for (const auto &argEn : llvm::enumerate(funcOp.getArguments()))
+      argTypes.push_back(argEn.value().getType());
 
-  resultTypes = llvm::to_vector<4>(functionType.getResults());
-  // We create a new function type and modify the function signature with this
-  // new type.
-  FunctionType newFuncType = FunctionType::get(/*inputs=*/argTypes,
-                                               /*results=*/resultTypes,
-                                               /*context=*/&getContext());
-
-  // TODO: Handle ReturnOps to update function results the caller site.
-  funcOp.setType(newFuncType);
+    // Traverse ReturnOps to check if an update to the return type in the
+    // function signature is required.
+    funcOp.walk([&](ReturnOp returnOp) {
+      for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
+        Type opType = operandEn.value().getType();
+        MemRefType memrefType = opType.dyn_cast<MemRefType>();
+        // If type is not memref or if the memref type is same as that in
+        // function's return signature then no update is required.
+        if (!memrefType || memrefType == resultTypes[operandEn.index()])
+          continue;
+        // Update function's return type signature.
+        // Return type gets normalized either as a result of function argument
+        // normalization, AllocOp normalization or an update made at CallOp.
+        // There can be many call flows inside a function and an update to a
+        // specific ReturnOp has not yet been made. So we check that the result
+        // memref type is normalized.
+        // TODO: When selective normalization is implemented, handle multiple
+        // results case where some are normalized, some aren't.
+        if (memrefType.getAffineMaps().empty())
+          resultTypes[operandEn.index()] = memrefType;
+      }
+    });
+
+    // We create a new function type and modify the function signature with this
+    // new type.
+    newFuncType = FunctionType::get(/*inputs=*/argTypes,
+                                    /*results=*/resultTypes,
+                                    /*context=*/&getContext());
+  }
+
+  // Since we update the function signature, it might affect the result types at
+  // the caller site. Since this result might even be used by the caller
+  // function in ReturnOps, the caller function's signature will also change.
+  // Hence we record the caller function in 'funcOpsToUpdate' to update their
+  // signature as well.
+  llvm::SmallDenseSet<FuncOp, 8> funcOpsToUpdate;
+  // We iterate over all symbolic uses of the function and update the return
+  // type at the caller site.
+  Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
+  for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
+    Operation *callOp = symbolUse.getUser();
+    OpBuilder builder(callOp);
+    StringRef callee = cast<CallOp>(callOp).getCallee();
+    Operation *newCallOp = builder.create<CallOp>(
+        callOp->getLoc(), resultTypes, builder.getSymbolRefAttr(callee),
+        callOp->getOperands());
+    bool replacingMemRefUsesFailed = false;
+    bool returnTypeChanged = false;
+    for (unsigned resIndex : llvm::seq<unsigned>(0, callOp->getNumResults())) {
+      OpResult oldResult = callOp->getResult(resIndex);
+      OpResult newResult = newCallOp->getResult(resIndex);
+      // This condition ensures that if the result is not of type memref or if
+      // the resulting memref was already having a trivial map layout then we
+      // need not perform any use replacement here.
+      if (oldResult.getType() == newResult.getType())
+        continue;
+      AffineMap layoutMap =
+          oldResult.getType().dyn_cast<MemRefType>().getAffineMaps().front();
+      if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
+                                          /*extraIndices=*/{},
+                                          /*indexRemap=*/layoutMap,
+                                          /*extraOperands=*/{},
+                                          /*symbolOperands=*/{},
+                                          /*domInstFilter=*/nullptr,
+                                          /*postDomInstFilter=*/nullptr,
+                                          /*allowDereferencingOps=*/true,
+                                          /*replaceInDeallocOp=*/true))) {
+        // If it failed (due to escapes for example), bail out.
+        // It should never hit this part of the code because it is called by
+        // only those functions which are normalizable.
+        newCallOp->erase();
+        replacingMemRefUsesFailed = true;
+        break;
+      }
+      returnTypeChanged = true;
+    }
+    if (replacingMemRefUsesFailed)
+      continue;
+    // Replace all uses for other non-memref result types.
+    callOp->replaceAllUsesWith(newCallOp);
+    callOp->erase();
+    if (returnTypeChanged) {
+      // Since the return type changed it might lead to a change in function's
+      // signature.
+      // TODO: If funcOp doesn't return any memref type then no need to update
+      // signature.
+      // TODO: Further optimization - Check if the memref is indeed part of
+      // ReturnOp at the parentFuncOp and only then updation of signature is
+      // required.
+      // TODO: Extend this for ops that are FunctionLike. This would require
+      // creating an OpInterface for FunctionLike ops.
+      FuncOp parentFuncOp = newCallOp->getParentOfType<FuncOp>();
+      funcOpsToUpdate.insert(parentFuncOp);
+    }
+  }
+  // Because external function's signature is already updated in
+  // 'normalizeFuncOpMemRefs()', we don't need to update it here again.
+  if (!funcOp.isExternal())
+    funcOp.setType(newFuncType);
+
+  // Updating the signature type of those functions which call the current
+  // function. Only if the return type of the current function has a normalized
+  // memref will the caller function become a candidate for signature update.
+  for (FuncOp parentFuncOp : funcOpsToUpdate)
+    updateFunctionSignature(parentFuncOp, moduleOp);
 }
 
-void NormalizeMemRefs::runOnFunction(FuncOp funcOp) {
+/// Normalizes the memrefs within a function which includes those arising as a
+/// result of AllocOps, CallOps and function's argument. The ModuleOp argument
+/// is used to help update function's signature after normalization.
+void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
+                                              ModuleOp moduleOp) {
   // Turn memrefs' non-identity layouts maps into ones with identity. Collect
   // alloc ops first and then process since normalizeMemRef replaces/erases ops
   // during memref rewriting.
@@ -169,22 +354,27 @@ void NormalizeMemRefs::runOnFunction(FuncOp funcOp) {
   // We use this OpBuilder to create new memref layout later.
   OpBuilder b(funcOp);
 
+  FunctionType functionType = funcOp.getType();
+  SmallVector<Type, 8> inputTypes;
   // Walk over each argument of a function to perform memref normalization (if
-  // any).
-  for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
-    Type argType = funcOp.getArgument(argIndex).getType();
+  for (unsigned argIndex :
+       llvm::seq<unsigned>(0, functionType.getNumInputs())) {
+    Type argType = functionType.getInput(argIndex);
     MemRefType memrefType = argType.dyn_cast<MemRefType>();
     // Check whether argument is of MemRef type. Any other argument type can
     // simply be part of the final function signature.
-    if (!memrefType)
+    if (!memrefType) {
+      inputTypes.push_back(argType);
       continue;
+    }
     // Fetch a new memref type after normalizing the old memref to have an
     // identity map layout.
     MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
                                                    /*numSymbolicOperands=*/0);
-    if (newMemRefType == memrefType) {
+    if (newMemRefType == memrefType || funcOp.isExternal()) {
       // Either memrefType already had an identity map or the map couldn't be
       // transformed to an identity map.
+      inputTypes.push_back(newMemRefType);
       continue;
     }
 
@@ -202,7 +392,7 @@ void NormalizeMemRefs::runOnFunction(FuncOp funcOp) {
                                         /*domInstFilter=*/nullptr,
                                         /*postDomInstFilter=*/nullptr,
                                         /*allowNonDereferencingOps=*/true,
-                                        /*handleDeallocOp=*/true))) {
+                                        /*replaceInDeallocOp=*/true))) {
       // If it failed (due to escapes for example), bail out. Removing the
       // temporary argument inserted previously.
       funcOp.front().eraseArgument(argIndex);
@@ -214,5 +404,36 @@ void NormalizeMemRefs::runOnFunction(FuncOp funcOp) {
     funcOp.front().eraseArgument(argIndex + 1);
   }
 
-  updateFunctionSignature(funcOp);
+  // In a normal function, memrefs in the return type signature gets normalized
+  // as a result of normalization of functions arguments, AllocOps or CallOps'
+  // result types. Since an external function doesn't have a body, memrefs in
+  // the return type signature can only get normalized by iterating over the
+  // individual return types.
+  if (funcOp.isExternal()) {
+    SmallVector<Type, 4> resultTypes;
+    for (unsigned resIndex :
+         llvm::seq<unsigned>(0, functionType.getNumResults())) {
+      Type resType = functionType.getResult(resIndex);
+      MemRefType memrefType = resType.dyn_cast<MemRefType>();
+      // Check whether result is of MemRef type. Any other argument type can
+      // simply be part of the final function signature.
+      if (!memrefType) {
+        resultTypes.push_back(resType);
+        continue;
+      }
+      // Computing a new memref type after normalizing the old memref to have an
+      // identity map layout.
+      MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
+                                                     /*numSymbolicOperands=*/0);
+      resultTypes.push_back(newMemRefType);
+      continue;
+    }
+
+    FunctionType newFuncType = FunctionType::get(/*inputs=*/inputTypes,
+                                                 /*results=*/resultTypes,
+                                                 /*context=*/&getContext());
+    // Setting the new function signature for this external function.
+    funcOp.setType(newFuncType);
+  }
+  updateFunctionSignature(funcOp, moduleOp);
 }

diff  --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index 73df4fa939bf..c310702745a2 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -274,12 +274,12 @@ LogicalResult mlir::replaceAllMemRefUsesWith(
     // for the memref to be used in a non-dereferencing way outside of the
     // region where this replacement is happening.
     if (!isMemRefDereferencingOp(*op)) {
-      // Currently we support the following non-dereferencing types to be a
-      // candidate for replacement: Dealloc and CallOp.
-      // TODO: Add support for other kinds of ops.
       if (!allowNonDereferencingOps)
         return failure();
-      if (!(isa<DeallocOp, CallOp>(*op)))
+      // Currently we support the following non-dereferencing ops to be a
+      // candidate for replacement: Dealloc, CallOp and ReturnOp.
+      // TODO: Add support for other kinds of ops.
+      if (!isa<DeallocOp, CallOp, ReturnOp>(*op))
         return failure();
     }
 

diff  --git a/mlir/test/Transforms/normalize-memrefs.mlir b/mlir/test/Transforms/normalize-memrefs.mlir
index 7d56c8893940..9c1e610c483a 100644
--- a/mlir/test/Transforms/normalize-memrefs.mlir
+++ b/mlir/test/Transforms/normalize-memrefs.mlir
@@ -126,14 +126,6 @@ func @symbolic_operands(%s : index) {
   return
 }
 
-// Memref escapes; no normalization.
-// CHECK-LABEL: func @escaping() -> memref<64xf32, #map{{[0-9]+}}>
-func @escaping() ->  memref<64xf32, affine_map<(d0) -> (d0 + 2)>> {
-  // CHECK: %{{.*}} = alloc() : memref<64xf32, #map{{[0-9]+}}>
-  %A = alloc() : memref<64xf32, affine_map<(d0) -> (d0 + 2)>>
-  return %A : memref<64xf32, affine_map<(d0) -> (d0 + 2)>>
-}
-
 // Semi-affine maps, normalization not implemented yet.
 // CHECK-LABEL: func @semi_affine_layout_map
 func @semi_affine_layout_map(%s0: index, %s1: index) {
@@ -205,9 +197,125 @@ func @non_memref_ret(%A: memref<8xf64, #tile>) -> i1 {
   return %d : i1
 }
 
-// Test case 4: No normalization should take place because the function is returning the memref.
-// CHECK-LABEL: func @memref_used_in_return
-// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>) -> memref<8xf64, #map{{[0-9]+}}>
-func @memref_used_in_return(%A: memref<8xf64, #tile>) -> (memref<8xf64, #tile>) {
-  return %A : memref<8xf64, #tile>
+// From here onwards, test cases deal with normalization of memrefs in function signatures and at call sites.
+
+// Test case 4: Check successful memref normalization in case of inter/intra-recursive calls.
+// CHECK-LABEL: func @ret_multiple_argument_type
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64, %[[C:arg[0-9]+]]: memref<2x4xf64>) -> (memref<2x4xf64>, f64)
+func @ret_multiple_argument_type(%A: memref<16xf64, #tile>, %B: f64, %C: memref<8xf64, #tile>) -> (memref<8xf64, #tile>, f64) {
+  %a = affine.load %A[0] : memref<16xf64, #tile>
+  %p = mulf %a, %a : f64
+  %cond = constant 1 : i1
+  cond_br %cond, ^bb1, ^bb2
+  ^bb1:
+    %res1, %res2 = call @ret_single_argument_type(%C) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
+    return %res2, %p: memref<8xf64, #tile>, f64
+  ^bb2:
+    return %C, %p: memref<8xf64, #tile>, f64
+}
+
+// CHECK:   %[[a:[0-9]+]] = affine.load %[[A]][0, 0] : memref<4x4xf64>
+// CHECK:   %[[p:[0-9]+]] = mulf %[[a]], %[[a]] : f64
+// CHECK:   %true = constant true
+// CHECK:   cond_br %true, ^bb1, ^bb2
+// CHECK: ^bb1:  // pred: ^bb0
+// CHECK:   %[[res:[0-9]+]]:2 = call @ret_single_argument_type(%[[C]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
+// CHECK:   return %[[res]]#1, %[[p]] : memref<2x4xf64>, f64
+// CHECK: ^bb2:  // pred: ^bb0
+// CHECK:   return %{{.*}}, %{{.*}} : memref<2x4xf64>, f64
+
+// CHECK-LABEL: func @ret_single_argument_type
+// CHECK-SAME: (%[[C:arg[0-9]+]]: memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
+func @ret_single_argument_type(%C: memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>){
+  %a = alloc() : memref<8xf64, #tile>
+  %b = alloc() : memref<16xf64, #tile>
+  %d = constant 23.0 : f64
+  call @ret_single_argument_type(%a) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
+  call @ret_single_argument_type(%C) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
+  %res1, %res2 = call @ret_multiple_argument_type(%b, %d, %a) : (memref<16xf64, #tile>, f64, memref<8xf64, #tile>) -> (memref<8xf64, #tile>, f64)
+  %res3, %res4 = call @ret_single_argument_type(%res1) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
+  return %b, %a: memref<16xf64, #tile>, memref<8xf64, #tile>
+}
+
+// CHECK: %[[a:[0-9]+]] = alloc() : memref<2x4xf64>
+// CHECK: %[[b:[0-9]+]] = alloc() : memref<4x4xf64>
+// CHECK: %cst = constant 2.300000e+01 : f64
+// CHECK: %[[resA:[0-9]+]]:2 = call @ret_single_argument_type(%[[a]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
+// CHECK: %[[resB:[0-9]+]]:2 = call @ret_single_argument_type(%[[C]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
+// CHECK: %[[resC:[0-9]+]]:2 = call @ret_multiple_argument_type(%[[b]], %cst, %[[a]]) : (memref<4x4xf64>, f64, memref<2x4xf64>) -> (memref<2x4xf64>, f64)
+// CHECK: %[[resD:[0-9]+]]:2 = call @ret_single_argument_type(%[[resC]]#0) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
+// CHECK: return %{{.*}}, %{{.*}} : memref<4x4xf64>, memref<2x4xf64>
+
+// Test case set #5: To check normalization in a chain of interconnected functions.
+// CHECK-LABEL: func @func_A
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
+func @func_A(%A: memref<8xf64, #tile>) {
+  call @func_B(%A) : (memref<8xf64, #tile>) -> ()
+  return
+}
+// CHECK: call @func_B(%[[A]]) : (memref<2x4xf64>) -> ()
+
+// CHECK-LABEL: func @func_B
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
+func @func_B(%A: memref<8xf64, #tile>) {
+  call @func_C(%A) : (memref<8xf64, #tile>) -> ()
+  return
+}
+// CHECK: call @func_C(%[[A]]) : (memref<2x4xf64>) -> ()
+
+// CHECK-LABEL: func @func_C
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
+func @func_C(%A: memref<8xf64, #tile>) {
+  return
+}
+
+// Test case set #6: Check that no normalization takes place in the scenario A -> B -> C, where B has an unsupported memref use.
+// CHECK-LABEL: func @some_func_A
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
+func @some_func_A(%A: memref<8xf64, #tile>) {
+  call @some_func_B(%A) : (memref<8xf64, #tile>) -> ()
+  return
+}
+// CHECK: call @some_func_B(%[[A]]) : (memref<8xf64, #map{{[0-9]+}}>) -> ()
+
+// CHECK-LABEL: func @some_func_B
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
+func @some_func_B(%A: memref<8xf64, #tile>) {
+  "test.test"(%A) : (memref<8xf64, #tile>) -> ()
+  call @some_func_C(%A) : (memref<8xf64, #tile>) -> ()
+  return
+}
+// CHECK: call @some_func_C(%[[A]]) : (memref<8xf64, #map{{[0-9]+}}>) -> ()
+
+// CHECK-LABEL: func @some_func_C
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
+func @some_func_C(%A: memref<8xf64, #tile>) {
+  return
+}
+
+// Test case set #7: Check normalization in case of external functions.
+// CHECK-LABEL: func @external_func_A
+// CHECK-SAME: (memref<4x4xf64>)
+func @external_func_A(memref<16xf64, #tile>) -> ()
+
+// CHECK-LABEL: func @external_func_B
+// CHECK-SAME: (memref<4x4xf64>, f64) -> memref<2x4xf64>
+func @external_func_B(memref<16xf64, #tile>, f64) -> (memref<8xf64, #tile>)
+
+// CHECK-LABEL: func @simply_call_external()
+func @simply_call_external() {
+  %a = alloc() : memref<16xf64, #tile>
+  call @external_func_A(%a) : (memref<16xf64, #tile>) -> ()
+  return
+}
+// CHECK: %[[a:[0-9]+]] = alloc() : memref<4x4xf64>
+// CHECK: call @external_func_A(%[[a]]) : (memref<4x4xf64>) -> ()
+
+// CHECK-LABEL: func @use_value_of_external
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64) -> memref<2x4xf64>
+func @use_value_of_external(%A: memref<16xf64, #tile>, %B: f64) -> (memref<8xf64, #tile>) {
+  %res = call @external_func_B(%A, %B) : (memref<16xf64, #tile>, f64) -> (memref<8xf64, #tile>)
+  return %res : memref<8xf64, #tile>
 }
+// CHECK: %[[res:[0-9]+]] = call @external_func_B(%[[A]], %[[B]]) : (memref<4x4xf64>, f64) -> memref<2x4xf64>
+// CHECK: return %{{.*}} : memref<2x4xf64>


        


More information about the Mlir-commits mailing list