[Mlir-commits] [mlir] 76d0750 - [MLIR] Introduce inter-procedural memref layout normalization

Uday Bondhugula llvmlistbot at llvm.org
Thu Jul 30 05:44:47 PDT 2020


Author: Abhishek Varma
Date: 2020-07-30T18:12:56+05:30
New Revision: 76d07503f0c69f6632e6d8d4736e2a4cb4055a92

URL: https://github.com/llvm/llvm-project/commit/76d07503f0c69f6632e6d8d4736e2a4cb4055a92
DIFF: https://github.com/llvm/llvm-project/commit/76d07503f0c69f6632e6d8d4736e2a4cb4055a92.diff

LOG: [MLIR] Introduce inter-procedural memref layout normalization

-- Introduces a pass that normalizes the affine layout maps to the identity layout map both within and across functions by rewriting function arguments and call operands where necessary.
-- Memref normalization is now implemented entirely in the module pass '-normalize-memrefs' and the limited intra-procedural version has been removed from '-simplify-affine-structures'.
-- Run using -normalize-memrefs.
-- Return ops are not handled yet; they will be handled in subsequent revisions.

Signed-off-by: Abhishek Varma <abhishek.varma at polymagelabs.com>

Differential Revision: https://reviews.llvm.org/D84490

Added: 
    mlir/lib/Transforms/NormalizeMemRefs.cpp
    mlir/test/Transforms/normalize-memrefs.mlir

Modified: 
    mlir/include/mlir/Transforms/Passes.h
    mlir/include/mlir/Transforms/Passes.td
    mlir/include/mlir/Transforms/Utils.h
    mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
    mlir/lib/Transforms/CMakeLists.txt
    mlir/lib/Transforms/Utils/Utils.cpp

Removed: 
    mlir/test/Transforms/memref-normalize.mlir


################################################################################
diff  --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 955b0e99a1d1..1ffff1a25a6d 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -24,7 +24,9 @@ class AffineForOp;
 class FuncOp;
 class ModuleOp;
 class Pass;
-template <typename T> class OperationPass;
+
+template <typename T>
+class OperationPass;
 
 /// Creates an instance of the BufferPlacement pass.
 std::unique_ptr<Pass> createBufferPlacementPass();
@@ -89,6 +91,10 @@ std::unique_ptr<Pass> createSCCPPass();
 /// Creates a pass which delete symbol operations that are unreachable. This
 /// pass may *only* be scheduled on an operation that defines a SymbolTable.
 std::unique_ptr<Pass> createSymbolDCEPass();
+
+/// Creates an interprocedural pass to normalize memrefs to have a trivial
+/// (identity) layout map.
+std::unique_ptr<OperationPass<ModuleOp>> createNormalizeMemRefsPass();
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_PASSES_H

diff  --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 9e0d5c40d61f..bd905b0e20f6 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -309,6 +309,11 @@ def MemRefDataFlowOpt : FunctionPass<"memref-dataflow-opt"> {
   let constructor = "mlir::createMemRefDataFlowOptPass()";
 }
 
+def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
+  let summary = "Normalize memrefs";
+  let constructor = "mlir::createNormalizeMemRefsPass()";
+}
+
 def ParallelLoopCollapsing : Pass<"parallel-loop-collapsing"> {
   let summary = "Collapse parallel loops to use less induction variables";
   let constructor = "mlir::createParallelLoopCollapsingPass()";
@@ -405,5 +410,4 @@ def SymbolDCE : Pass<"symbol-dce"> {
   }];
   let constructor = "mlir::createSymbolDCEPass()";
 }
-
 #endif // MLIR_TRANSFORMS_PASSES

diff  --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h
index 6f29c1b41ae6..81b4dfd0a01b 100644
--- a/mlir/include/mlir/Transforms/Utils.h
+++ b/mlir/include/mlir/Transforms/Utils.h
@@ -45,10 +45,19 @@ class OpBuilder;
 /// operations that are dominated by the former; similarly, `postDomInstFilter`
 /// restricts replacement to only those operations that are postdominated by it.
 ///
+/// 'allowNonDereferencingOps', if set, allows replacement of non-dereferencing
+/// uses of a memref without any requirement for access index rewrites. The
+/// default value of this flag variable is false.
+///
+/// 'replaceInDeallocOp', if set, lets DeallocOp, a non-dereferencing user,
+/// also be a candidate for replacement. The default value of this flag is
+/// false.
+///
 /// Returns true on success and false if the replacement is not possible,
-/// whenever a memref is used as an operand in a non-dereferencing context,
-/// except for dealloc's on the memref which are left untouched. See comments at
-/// function definition for an example.
+/// whenever a memref is used as an operand in a non-dereferencing context and
+/// 'allowNonDereferencingOps' is false, except for dealloc's on the memref
+/// which are left untouched. See comments at function definition for an
+/// example.
 //
 //  Ex: to replace load %A[%i, %j] with load %Abuf[%t mod 2, %ii - %i, %j]:
 //  The SSA value corresponding to '%t mod 2' should be in 'extraIndices', and
@@ -57,28 +66,38 @@ class OpBuilder;
 //  extra operands, note that 'indexRemap' would just be applied to existing
 //  indices (%i, %j).
 //  TODO: allow extraIndices to be added at any position.
-LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
-                                       ArrayRef<Value> extraIndices = {},
-                                       AffineMap indexRemap = AffineMap(),
-                                       ArrayRef<Value> extraOperands = {},
-                                       ArrayRef<Value> symbolOperands = {},
-                                       Operation *domInstFilter = nullptr,
-                                       Operation *postDomInstFilter = nullptr);
+LogicalResult replaceAllMemRefUsesWith(
+    Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices = {},
+    AffineMap indexRemap = AffineMap(), ArrayRef<Value> extraOperands = {},
+    ArrayRef<Value> symbolOperands = {}, Operation *domInstFilter = nullptr,
+    Operation *postDomInstFilter = nullptr,
+    bool allowNonDereferencingOps = false, bool replaceInDeallocOp = false);
 
 /// Performs the same replacement as the other version above but only for the
-/// dereferencing uses of `oldMemRef` in `op`.
+/// dereferencing uses of `oldMemRef` in `op`, except in cases where
+/// 'allowNonDereferencingOps' is set to true where we replace the
+/// non-dereferencing uses as well.
 LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
                                        Operation *op,
                                        ArrayRef<Value> extraIndices = {},
                                        AffineMap indexRemap = AffineMap(),
                                        ArrayRef<Value> extraOperands = {},
-                                       ArrayRef<Value> symbolOperands = {});
+                                       ArrayRef<Value> symbolOperands = {},
+                                       bool allowNonDereferencingOps = false);
 
 /// Rewrites the memref defined by this alloc op to have an identity layout map
 /// and updates all its indexing uses. Returns failure if any of its uses
 /// escape (while leaving the IR in a valid state).
 LogicalResult normalizeMemRef(AllocOp op);
 
+/// Uses the layout map of the old memref type to compute a new memref type
+/// with a new shape and an identity layout map, i.e., the old layout map is
+/// normalized to the identity. It returns the old memref type unchanged if no
+/// normalization was needed or if transforming the old layout map to an
+/// identity layout map fails.
+MemRefType normalizeMemRefType(MemRefType memrefType, OpBuilder builder,
+                               unsigned numSymbolicOperands);
+
 /// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
 /// its results equal to the number of operands, as a composition
 /// of all other AffineApplyOps reachable from input parameter 'operands'. If

diff  --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index 0cd59b52d543..d8ffb9742fae 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -96,13 +96,4 @@ void SimplifyAffineStructures::runOnFunction() {
     if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
       applyOpPatternsAndFold(op, patterns);
   });
-
-  // Turn memrefs' non-identity layouts maps into ones with identity. Collect
-  // alloc ops first and then process since normalizeMemRef replaces/erases ops
-  // during memref rewriting.
-  SmallVector<AllocOp, 4> allocOps;
-  func.walk([&](AllocOp op) { allocOps.push_back(op); });
-  for (auto allocOp : allocOps) {
-    normalizeMemRef(allocOp);
-  }
 }

diff  --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 3c6b3933de2a..58c5fa672088 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_library(MLIRTransforms
   LoopFusion.cpp
   LoopInvariantCodeMotion.cpp
   MemRefDataFlowOpt.cpp
+  NormalizeMemRefs.cpp
   OpStats.cpp
   ParallelLoopCollapsing.cpp
   PipelineDataTransfer.cpp

diff  --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
new file mode 100644
index 000000000000..1484dcb9e10c
--- /dev/null
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -0,0 +1,218 @@
+//===- NormalizeMemRefs.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements an interprocedural pass to normalize memrefs to have
+// identity layout maps.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+
+#define DEBUG_TYPE "normalize-memrefs"
+
+using namespace mlir;
+
+namespace {
+
+/// All memrefs passed across functions with non-trivial layout maps are
+/// converted to ones with trivial (identity) layout maps.
+
+// Input :-
+// #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
+// func @matmul(%A: memref<16xf64, #tile>, %B: index, %C: memref<16xf64>) ->
+// (memref<16xf64, #tile>) {
+//   affine.for %arg3 = 0 to 16 {
+//         %a = affine.load %A[%arg3] : memref<16xf64, #tile>
+//         %p = mulf %a, %a : f64
+//         affine.store %p, %A[%arg3] : memref<16xf64, #tile>
+//   }
+//   %c = alloc() : memref<16xf64, #tile>
+//   %d = affine.load %c[0] : memref<16xf64, #tile>
+//   return %A: memref<16xf64, #tile>
+// }
+
+// Output :-
+//   func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
+//   -> memref<4x4xf64> {
+//     affine.for %arg3 = 0 to 16 {
+//       %2 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
+//       %3 = mulf %2, %2 : f64
+//       affine.store %3, %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
+//     }
+//     %0 = alloc() : memref<16xf64, #map0>
+//     %1 = affine.load %0[0] : memref<16xf64, #map0>
+//     return %arg0 : memref<4x4xf64>
+//   }
+
+struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
+  void runOnOperation() override;
+  void runOnFunction(FuncOp funcOp);
+  bool areMemRefsNormalizable(FuncOp funcOp);
+  void updateFunctionSignature(FuncOp funcOp);
+};
+
+} // end anonymous namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createNormalizeMemRefsPass() {
+  return std::make_unique<NormalizeMemRefs>();
+}
+
+void NormalizeMemRefs::runOnOperation() {
+  ModuleOp moduleOp = getOperation();
+  // We traverse each function within the module in order to normalize the
+  // memref type arguments.
+  // TODO: Handle external functions.
+  moduleOp.walk([&](FuncOp funcOp) {
+    if (areMemRefsNormalizable(funcOp))
+      runOnFunction(funcOp);
+  });
+}
+
+// Return true if this operation dereferences one or more memref's.
+// TODO: Temporary utility, will be replaced when this is modeled through
+// side-effects/op traits.
+static bool isMemRefDereferencingOp(Operation &op) {
+  return isa<AffineReadOpInterface, AffineWriteOpInterface, AffineDmaStartOp,
+             AffineDmaWaitOp>(op);
+}
+
+// Check whether all the uses of oldMemRef are either dereferencing uses or the
+// op is of type : DeallocOp, CallOp. Only if these constraints are satisfied
+// will the value become a candidate for replacement.
+static bool isMemRefNormalizable(Value::user_range opUsers) {
+  if (llvm::any_of(opUsers, [](Operation *op) {
+        if (isMemRefDereferencingOp(*op))
+          return false;
+        return !isa<DeallocOp, CallOp>(*op);
+      }))
+    return false;
+  return true;
+}
+
+// Check whether all the uses of AllocOps, CallOps and function arguments of a
+// function are either of dereferencing type or of type: DeallocOp, CallOp. Only
+// if these constraints are satisfied will the function become a candidate for
+// normalization.
+bool NormalizeMemRefs::areMemRefsNormalizable(FuncOp funcOp) {
+  if (funcOp
+          .walk([&](AllocOp allocOp) -> WalkResult {
+            Value oldMemRef = allocOp.getResult();
+            if (!isMemRefNormalizable(oldMemRef.getUsers()))
+              return WalkResult::interrupt();
+            return WalkResult::advance();
+          })
+          .wasInterrupted())
+    return false;
+
+  if (funcOp
+          .walk([&](CallOp callOp) -> WalkResult {
+            for (unsigned resIndex :
+                 llvm::seq<unsigned>(0, callOp.getNumResults())) {
+              Value oldMemRef = callOp.getResult(resIndex);
+              if (oldMemRef.getType().isa<MemRefType>())
+                if (!isMemRefNormalizable(oldMemRef.getUsers()))
+                  return WalkResult::interrupt();
+            }
+            return WalkResult::advance();
+          })
+          .wasInterrupted())
+    return false;
+
+  for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
+    BlockArgument oldMemRef = funcOp.getArgument(argIndex);
+    if (oldMemRef.getType().isa<MemRefType>())
+      if (!isMemRefNormalizable(oldMemRef.getUsers()))
+        return false;
+  }
+
+  return true;
+}
+
+// Fetch the updated argument list and result of the function and update the
+// function signature.
+void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp) {
+  FunctionType functionType = funcOp.getType();
+  SmallVector<Type, 8> argTypes;
+  SmallVector<Type, 4> resultTypes;
+
+  for (const auto &arg : llvm::enumerate(funcOp.getArguments()))
+    argTypes.push_back(arg.value().getType());
+
+  resultTypes = llvm::to_vector<4>(functionType.getResults());
+  // We create a new function type and modify the function signature with this
+  // new type.
+  FunctionType newFuncType = FunctionType::get(/*inputs=*/argTypes,
+                                               /*results=*/resultTypes,
+                                               /*context=*/&getContext());
+
+  // TODO: Handle ReturnOps to update function results at the caller site.
+  funcOp.setType(newFuncType);
+}
+
+void NormalizeMemRefs::runOnFunction(FuncOp funcOp) {
+  // Turn memrefs' non-identity layout maps into identity ones. Collect
+  // alloc ops first and then process since normalizeMemRef replaces/erases ops
+  // during memref rewriting.
+  SmallVector<AllocOp, 4> allocOps;
+  funcOp.walk([&](AllocOp op) { allocOps.push_back(op); });
+  for (AllocOp allocOp : allocOps)
+    normalizeMemRef(allocOp);
+
+  // We use this OpBuilder to create new memref layout later.
+  OpBuilder b(funcOp);
+
+  // Walk over each argument of a function to perform memref normalization (if
+  // any).
+  for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
+    Type argType = funcOp.getArgument(argIndex).getType();
+    MemRefType memrefType = argType.dyn_cast<MemRefType>();
+    // Check whether argument is of MemRef type. Any other argument type can
+    // simply be part of the final function signature.
+    if (!memrefType)
+      continue;
+    // Fetch a new memref type after normalizing the old memref to have an
+    // identity map layout.
+    MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
+                                                   /*numSymbolicOperands=*/0);
+    if (newMemRefType == memrefType) {
+      // Either memrefType already had an identity map or the map couldn't be
+      // transformed to an identity map.
+      continue;
+    }
+
+    // Insert a new temporary argument with the new memref type.
+    BlockArgument newMemRef =
+        funcOp.front().insertArgument(argIndex, newMemRefType);
+    BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1);
+    AffineMap layoutMap = memrefType.getAffineMaps().front();
+    // Replace all uses of the old memref.
+    if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef,
+                                        /*extraIndices=*/{},
+                                        /*indexRemap=*/layoutMap,
+                                        /*extraOperands=*/{},
+                                        /*symbolOperands=*/{},
+                                        /*domInstFilter=*/nullptr,
+                                        /*postDomInstFilter=*/nullptr,
+                                        /*allowNonDereferencingOps=*/true,
+                                        /*replaceInDeallocOp=*/true))) {
+      // If it failed (due to escapes for example), bail out. Removing the
+      // temporary argument inserted previously.
+      funcOp.front().eraseArgument(argIndex);
+      continue;
+    }
+
+    // All uses for the argument with old memref type were replaced
+    // successfully. So we remove the old argument now.
+    funcOp.front().eraseArgument(argIndex + 1);
+  }
+
+  updateFunctionSignature(funcOp);
+}

diff  --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index 35853c17232b..73df4fa939bf 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -48,7 +48,8 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
                                              ArrayRef<Value> extraIndices,
                                              AffineMap indexRemap,
                                              ArrayRef<Value> extraOperands,
-                                             ArrayRef<Value> symbolOperands) {
+                                             ArrayRef<Value> symbolOperands,
+                                             bool allowNonDereferencingOps) {
   unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
   (void)newMemRefRank; // unused in opt mode
   unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
@@ -67,11 +68,6 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
   assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
          newMemRef.getType().cast<MemRefType>().getElementType());
 
-  if (!isMemRefDereferencingOp(*op))
-    // Failure: memref used in a non-dereferencing context (potentially
-    // escapes); no replacement in these cases.
-    return failure();
-
   SmallVector<unsigned, 2> usePositions;
   for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
     if (opEntry.value() == oldMemRef)
@@ -91,6 +87,18 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
   unsigned memRefOperandPos = usePositions.front();
 
   OpBuilder builder(op);
+  // The following checks if op is dereferencing memref and performs the access
+  // index rewrites.
+  if (!isMemRefDereferencingOp(*op)) {
+    if (!allowNonDereferencingOps)
+      // Failure: memref used in a non-dereferencing context (potentially
+      // escapes); no replacement in these cases unless allowNonDereferencingOps
+      // is set.
+      return failure();
+    op->setOperand(memRefOperandPos, newMemRef);
+    return success();
+  }
+  // Perform index rewrites for the dereferencing op and then replace the op.
   NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef);
   AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
   unsigned oldMapNumInputs = oldMap.getNumInputs();
@@ -112,7 +120,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
       affineApplyOps.push_back(afOp);
     }
   } else {
-    oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end());
+    oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
   }
 
   // Construct new indices as a remap of the old ones if a remapping has been
@@ -141,14 +149,14 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
     }
   } else {
     // No remapping specified.
-    remapOutputs.append(remapOperands.begin(), remapOperands.end());
+    remapOutputs.assign(remapOperands.begin(), remapOperands.end());
   }
 
   SmallVector<Value, 4> newMapOperands;
   newMapOperands.reserve(newMemRefRank);
 
   // Prepend 'extraIndices' in 'newMapOperands'.
-  for (auto extraIndex : extraIndices) {
+  for (Value extraIndex : extraIndices) {
     assert(extraIndex.getDefiningOp()->getNumResults() == 1 &&
            "single result op's expected to generate these indices");
     assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
@@ -167,12 +175,12 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
   newMap = simplifyAffineMap(newMap);
   canonicalizeMapAndOperands(&newMap, &newMapOperands);
   // Remove any affine.apply's that became dead as a result of composition.
-  for (auto value : affineApplyOps)
+  for (Value value : affineApplyOps)
     if (value.use_empty())
       value.getDefiningOp()->erase();
 
-  // Construct the new operation using this memref.
   OperationState state(op->getLoc(), op->getName());
+  // Construct the new operation using this memref.
   state.operands.reserve(op->getNumOperands() + extraIndices.size());
   // Insert the non-memref operands.
   state.operands.append(op->operand_begin(),
@@ -196,11 +204,10 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
   // Add attribute for 'newMap', other Attributes do not change.
   auto newMapAttr = AffineMapAttr::get(newMap);
   for (auto namedAttr : op->getAttrs()) {
-    if (namedAttr.first == oldMapAttrPair.first) {
+    if (namedAttr.first == oldMapAttrPair.first)
       state.attributes.push_back({namedAttr.first, newMapAttr});
-    } else {
+    else
       state.attributes.push_back(namedAttr);
-    }
   }
 
   // Create the new operation.
@@ -211,13 +218,12 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
   return success();
 }
 
-LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
-                                             ArrayRef<Value> extraIndices,
-                                             AffineMap indexRemap,
-                                             ArrayRef<Value> extraOperands,
-                                             ArrayRef<Value> symbolOperands,
-                                             Operation *domInstFilter,
-                                             Operation *postDomInstFilter) {
+LogicalResult mlir::replaceAllMemRefUsesWith(
+    Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices,
+    AffineMap indexRemap, ArrayRef<Value> extraOperands,
+    ArrayRef<Value> symbolOperands, Operation *domInstFilter,
+    Operation *postDomInstFilter, bool allowNonDereferencingOps,
+    bool replaceInDeallocOp) {
   unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
   (void)newMemRefRank; // unused in opt mode
   unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
@@ -261,16 +267,21 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
 
     // Skip dealloc's - no replacement is necessary, and a memref replacement
     // at other uses doesn't hurt these dealloc's.
-    if (isa<DeallocOp>(op))
+    if (isa<DeallocOp>(op) && !replaceInDeallocOp)
       continue;
 
     // Check if the memref was used in a non-dereferencing context. It is fine
     // for the memref to be used in a non-dereferencing way outside of the
     // region where this replacement is happening.
-    if (!isMemRefDereferencingOp(*op))
-      // Failure: memref used in a non-dereferencing op (potentially escapes);
-      // no replacement in these cases.
-      return failure();
+    if (!isMemRefDereferencingOp(*op)) {
+      // Currently, only the following non-dereferencing ops are supported as
+      // candidates for replacement: DeallocOp and CallOp.
+      // TODO: Add support for other kinds of ops.
+      if (!allowNonDereferencingOps)
+        return failure();
+      if (!(isa<DeallocOp, CallOp>(*op)))
+        return failure();
+    }
 
     // We'll first collect and then replace --- since replacement erases the op
     // that has the use, and that op could be postDomFilter or domFilter itself!
@@ -278,9 +289,9 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
   }
 
   for (auto *op : opsToReplace) {
-    if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices,
-                                        indexRemap, extraOperands,
-                                        symbolOperands)))
+    if (failed(replaceAllMemRefUsesWith(
+            oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands,
+            symbolOperands, allowNonDereferencingOps)))
       llvm_unreachable("memref replacement guaranteed to succeed here");
   }
 
@@ -385,85 +396,102 @@ void mlir::createAffineComputationSlice(
 // TODO: Currently works for static memrefs with a single layout map.
 LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
   MemRefType memrefType = allocOp.getType();
-  unsigned rank = memrefType.getRank();
-  if (rank == 0)
-    return success();
-
-  auto layoutMaps = memrefType.getAffineMaps();
   OpBuilder b(allocOp);
-  if (layoutMaps.size() != 1)
+
+  // Fetch a new memref type after normalizing the old memref to have an
+  // identity map layout.
+  MemRefType newMemRefType =
+      normalizeMemRefType(memrefType, b, allocOp.getNumSymbolicOperands());
+  if (newMemRefType == memrefType)
+    // Either memrefType already had an identity map or the map couldn't be
+    // transformed to an identity map.
     return failure();
 
-  AffineMap layoutMap = layoutMaps.front();
+  Value oldMemRef = allocOp.getResult();
 
-  // Nothing to do for identity layout maps.
-  if (layoutMap == b.getMultiDimIdentityMap(rank))
-    return success();
+  SmallVector<Value, 4> symbolOperands(allocOp.getSymbolicOperands());
+  AllocOp newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType,
+                                       llvm::None, allocOp.alignmentAttr());
+  AffineMap layoutMap = memrefType.getAffineMaps().front();
+  // Replace all uses of the old memref.
+  if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
+                                      /*extraIndices=*/{},
+                                      /*indexRemap=*/layoutMap,
+                                      /*extraOperands=*/{},
+                                      /*symbolOperands=*/symbolOperands,
+                                      /*domInstFilter=*/nullptr,
+                                      /*postDomInstFilter=*/nullptr,
+                                      /*allowNonDereferencingOps=*/true))) {
+    // If it failed (due to escapes for example), bail out.
+    newAlloc.erase();
+    return failure();
+  }
+  // Replace any uses of the original alloc op and erase it. All remaining uses
+  // have to be dealloc's; RAMUW above would've failed otherwise.
+  assert(llvm::all_of(oldMemRef.getUsers(),
+                      [](Operation *op) { return isa<DeallocOp>(op); }));
+  oldMemRef.replaceAllUsesWith(newAlloc);
+  allocOp.erase();
+  return success();
+}
+
+MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
+                                     unsigned numSymbolicOperands) {
+  unsigned rank = memrefType.getRank();
+  if (rank == 0)
+    return memrefType;
+
+  ArrayRef<AffineMap> layoutMaps = memrefType.getAffineMaps();
+  if (layoutMaps.empty() ||
+      layoutMaps.front() == b.getMultiDimIdentityMap(rank)) {
+    // Either no map is associated with this memref or this memref has
+    // a trivial (identity) map.
+    return memrefType;
+  }
 
   // We don't do any checks for one-to-one'ness; we assume that it is
   // one-to-one.
 
   // TODO: Only for static memref's for now.
   if (memrefType.getNumDynamicDims() > 0)
-    return failure();
+    return memrefType;
 
-  // We have a single map that is not an identity map. Create a new memref with
-  // the right shape and an identity layout map.
-  auto shape = memrefType.getShape();
-  FlatAffineConstraints fac(rank, allocOp.getNumSymbolicOperands());
+  // We have a single map that is not an identity map. Create a new memref
+  // with the right shape and an identity layout map.
+  ArrayRef<int64_t> shape = memrefType.getShape();
+  // FlatAffineConstraint may later on use symbolicOperands.
+  FlatAffineConstraints fac(rank, numSymbolicOperands);
   for (unsigned d = 0; d < rank; ++d) {
     fac.addConstantLowerBound(d, 0);
     fac.addConstantUpperBound(d, shape[d] - 1);
   }
-
-  // We compose this map with the original index (logical) space to derive the
-  // upper bounds for the new index space.
+  // We compose this map with the original index (logical) space to derive
+  // the upper bounds for the new index space.
+  AffineMap layoutMap = layoutMaps.front();
   unsigned newRank = layoutMap.getNumResults();
   if (failed(fac.composeMatchingMap(layoutMap)))
-    // TODO: semi-affine maps.
-    return failure();
-
+    return memrefType;
+  // TODO: Handle semi-affine maps.
   // Project out the old data dimensions.
   fac.projectOut(newRank, fac.getNumIds() - newRank - fac.getNumLocalIds());
   SmallVector<int64_t, 4> newShape(newRank);
   for (unsigned d = 0; d < newRank; ++d) {
     // The lower bound for the shape is always zero.
     auto ubConst = fac.getConstantUpperBound(d);
-    // For a static memref and an affine map with no symbols, this is always
-    // bounded.
+    // For a static memref and an affine map with no symbols, this is
+    // always bounded.
     assert(ubConst.hasValue() && "should always have an upper bound");
     if (ubConst.getValue() < 0)
       // This is due to an invalid map that maps to a negative space.
-      return failure();
+      return memrefType;
     newShape[d] = ubConst.getValue() + 1;
   }
 
-  auto oldMemRef = allocOp.getResult();
-  SmallVector<Value, 4> symbolOperands(allocOp.getSymbolicOperands());
-
+  // Create the new memref type after trivializing the old layout map.
   MemRefType newMemRefType =
       MemRefType::Builder(memrefType)
           .setShape(newShape)
           .setAffineMaps(b.getMultiDimIdentityMap(newRank));
 
-  auto newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType, llvm::None,
-                                    allocOp.alignmentAttr());
-
-  // Replace all uses of the old memref.
-  if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
-                                      /*extraIndices=*/{},
-                                      /*indexRemap=*/layoutMap,
-                                      /*extraOperands=*/{},
-                                      /*symbolOperands=*/symbolOperands))) {
-    // If it failed (due to escapes for example), bail out.
-    newAlloc.erase();
-    return failure();
-  }
-  // Replace any uses of the original alloc op and erase it. All remaining uses
-  // have to be dealloc's; RAMUW above would've failed otherwise.
-  assert(llvm::all_of(oldMemRef.getUsers(),
-                      [](Operation *op) { return isa<DeallocOp>(op); }));
-  oldMemRef.replaceAllUsesWith(newAlloc);
-  allocOp.erase();
-  return success();
+  return newMemRefType;
 }

diff  --git a/mlir/test/Transforms/memref-normalize.mlir b/mlir/test/Transforms/normalize-memrefs.mlir
similarity index 65%
rename from mlir/test/Transforms/memref-normalize.mlir
rename to mlir/test/Transforms/normalize-memrefs.mlir
index 375bd3ef0e6c..7d56c8893940 100644
--- a/mlir/test/Transforms/memref-normalize.mlir
+++ b/mlir/test/Transforms/normalize-memrefs.mlir
@@ -1,4 +1,7 @@
-// RUN: mlir-opt -allow-unregistered-dialect -simplify-affine-structures %s | FileCheck %s
+// RUN: mlir-opt -normalize-memrefs -allow-unregistered-dialect %s | FileCheck %s
+
+// This file tests whether memref types with non-trivial layout maps
+// are normalized to trivial (identity) layouts.
 
 // CHECK-LABEL: func @permute()
 func @permute() {
@@ -150,3 +153,61 @@ func @alignment() {
   // CHECK-NEXT: alloc() {alignment = 32 : i64} : memref<256x64x128xf32>
   return
 }
+
+#tile = affine_map < (i)->(i floordiv 4, i mod 4) >
+
+// The following test cases check inter-procedural memref normalization.
+
+// Test case 1: Check normalization for multiple memrefs in a function argument list.
+// CHECK-LABEL: func @multiple_argument_type
+// CHECK-SAME:  (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64, %[[C:arg[0-9]+]]: memref<2x4xf64>, %[[D:arg[0-9]+]]: memref<24xf64>) -> f64
+func @multiple_argument_type(%A: memref<16xf64, #tile>, %B: f64, %C: memref<8xf64, #tile>, %D: memref<24xf64>) -> f64 {
+  %a = affine.load %A[0] : memref<16xf64, #tile>
+  %p = mulf %a, %a : f64
+  affine.store %p, %A[10] : memref<16xf64, #tile>
+  call @single_argument_type(%C): (memref<8xf64, #tile>) -> ()
+  return %B : f64
+}
+
+// CHECK: %[[a:[0-9]+]] = affine.load %[[A]][0, 0] : memref<4x4xf64>
+// CHECK: %[[p:[0-9]+]] = mulf %[[a]], %[[a]] : f64
+// CHECK: affine.store %[[p]], %[[A]][2, 2] : memref<4x4xf64>
+// CHECK: call @single_argument_type(%[[C]]) : (memref<2x4xf64>) -> ()
+// CHECK: return %[[B]] : f64
+
+// Test case 2: Check normalization for single memref argument in a function.
+// CHECK-LABEL: func @single_argument_type
+// CHECK-SAME: (%[[C:arg[0-9]+]]: memref<2x4xf64>)
+func @single_argument_type(%C : memref<8xf64, #tile>) {
+  %a = alloc(): memref<8xf64, #tile>
+  %b = alloc(): memref<16xf64, #tile>
+  %d = constant 23.0 : f64
+  %e = alloc(): memref<24xf64>
+  call @single_argument_type(%a): (memref<8xf64, #tile>) -> ()
+  call @single_argument_type(%C): (memref<8xf64, #tile>) -> ()
+  call @multiple_argument_type(%b, %d, %a, %e): (memref<16xf64, #tile>, f64, memref<8xf64, #tile>, memref<24xf64>) -> f64
+  return
+}
+
+// CHECK: %[[a:[0-9]+]] = alloc() : memref<2x4xf64>
+// CHECK: %[[b:[0-9]+]] = alloc() : memref<4x4xf64>
+// CHECK: %cst = constant 2.300000e+01 : f64
+// CHECK: %[[e:[0-9]+]] = alloc() : memref<24xf64>
+// CHECK: call @single_argument_type(%[[a]]) : (memref<2x4xf64>) -> ()
+// CHECK: call @single_argument_type(%[[C]]) : (memref<2x4xf64>) -> ()
+// CHECK: call @multiple_argument_type(%[[b]], %cst, %[[a]], %[[e]]) : (memref<4x4xf64>, f64, memref<2x4xf64>, memref<24xf64>) -> f64
+
+// Test case 3: Check normalization of a function that returns a non-memref type.
+// CHECK-LABEL: func @non_memref_ret
+// CHECK-SAME: (%[[C:arg[0-9]+]]: memref<2x4xf64>) -> i1
+func @non_memref_ret(%A: memref<8xf64, #tile>) -> i1 {
+  %d = constant 1 : i1
+  return %d : i1
+}
+
+// Test case 4: No normalization should take place because the function returns the memref (return ops are not yet handled).
+// CHECK-LABEL: func @memref_used_in_return
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>) -> memref<8xf64, #map{{[0-9]+}}>
+func @memref_used_in_return(%A: memref<8xf64, #tile>) -> (memref<8xf64, #tile>) {
+  return %A : memref<8xf64, #tile>
+}


        


More information about the Mlir-commits mailing list