[Mlir-commits] [mlir] 183c4a3 - [MLIR][normalize-memrefs] Non-normalizable operations with identity map layouts do not block normalization of the entire function

Uday Bondhugula llvmlistbot at llvm.org
Thu Aug 18 20:35:02 PDT 2022


Author: Tung D. Le
Date: 2022-08-19T08:27:45+05:30
New Revision: 183c4a391ef344220664d1d103d43639468bf103

URL: https://github.com/llvm/llvm-project/commit/183c4a391ef344220664d1d103d43639468bf103
DIFF: https://github.com/llvm/llvm-project/commit/183c4a391ef344220664d1d103d43639468bf103.diff

LOG: [MLIR][normalize-memrefs] Non-normalizable operations with identity map layouts do not block normalization of the entire function

The current approach is conservative: whenever a function contains a
non-normalizable operation, the entire function is labelled as
non-normalizable. In other words, it requires that all operations in the
function have the MemRefsNormalizable trait.

This patch relaxes that requirement: if the memref map layouts used by a
non-normalizable operation are all identity, that operation does not block
the normalization of the other operations in the same function.

Reviewed By: bondhugula

Differential Revision: https://reviews.llvm.org/D125854
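
For illustration, here is a minimal sketch of the relaxed behavior, modelled
on the test case added below (the function name and the exact tiling map
#map0 are illustrative, not copied verbatim from the test file):

#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 64, d2 mod 32, d3 mod 64)>

func.func @identity_layout_does_not_block(%arg0: memref<1x16x14x14xf32, #map0>) {
  // "test.op_nonnorm" lacks the MemRefsNormalizable trait, but it only uses
  // %0, whose memref has the identity layout, so it no longer marks the
  // whole function as non-normalizable.
  %0 = memref.alloc() : memref<1x16x14x14xf32>
  "test.op_nonnorm"(%0, %0) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> ()
  // %arg0 uses the non-identity layout #map0 and can still be normalized.
  "test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32>) -> ()
  memref.dealloc %0 : memref<1x16x14x14xf32>
  return
}

Before this patch, the presence of "test.op_nonnorm" would have made
-normalize-memrefs skip the whole function; with it, the type of %arg0 is
still rewritten to its normalized (identity-layout) form.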

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
    mlir/test/Transforms/normalize-memrefs-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index b0b31c9189039..55ce128ff4683 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -145,10 +145,10 @@ void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
 /// Check whether all the uses of AllocOps, CallOps and function arguments of a
 /// function are either of dereferencing type or are uses in: DeallocOp, CallOp
 /// or ReturnOp. Only if these constraints are satisfied will the function
-/// become a candidate for normalization. We follow a conservative approach here
-/// wherein even if the non-normalizable memref is not a part of the function's
-/// argument or return type, we still label the entire function as
-/// non-normalizable. We assume external functions to be normalizable.
+/// become a candidate for normalization. When the uses of a memref are
+/// non-normalizable and the memref map layout is trivial (identity), we can
+/// still label the entire function as normalizable. We assume external
+/// functions to be normalizable.
 bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
   // We assume external functions to be normalizable.
   if (funcOp.isExternal())
@@ -157,7 +157,11 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
   if (funcOp
           .walk([&](memref::AllocOp allocOp) -> WalkResult {
             Value oldMemRef = allocOp.getResult();
-            if (!isMemRefNormalizable(oldMemRef.getUsers()))
+            if (!oldMemRef.getType()
+                     .cast<MemRefType>()
+                     .getLayout()
+                     .isIdentity() &&
+                !isMemRefNormalizable(oldMemRef.getUsers()))
               return WalkResult::interrupt();
             return WalkResult::advance();
           })
@@ -170,7 +174,11 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
                  llvm::seq<unsigned>(0, callOp.getNumResults())) {
               Value oldMemRef = callOp.getResult(resIndex);
               if (oldMemRef.getType().isa<MemRefType>())
-                if (!isMemRefNormalizable(oldMemRef.getUsers()))
+                if (!oldMemRef.getType()
+                         .cast<MemRefType>()
+                         .getLayout()
+                         .isIdentity() &&
+                    !isMemRefNormalizable(oldMemRef.getUsers()))
                   return WalkResult::interrupt();
             }
             return WalkResult::advance();
@@ -181,7 +189,8 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
   for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
     BlockArgument oldMemRef = funcOp.getArgument(argIndex);
     if (oldMemRef.getType().isa<MemRefType>())
-      if (!isMemRefNormalizable(oldMemRef.getUsers()))
+      if (!oldMemRef.getType().cast<MemRefType>().getLayout().isIdentity() &&
+          !isMemRefNormalizable(oldMemRef.getUsers()))
         return false;
   }
 

diff --git a/mlir/test/Transforms/normalize-memrefs-ops.mlir b/mlir/test/Transforms/normalize-memrefs-ops.mlir
index a16ae149c1923..b45b62a92e4a6 100644
--- a/mlir/test/Transforms/normalize-memrefs-ops.mlir
+++ b/mlir/test/Transforms/normalize-memrefs-ops.mlir
@@ -41,6 +41,24 @@ func.func @test_nonnorm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
     return
 }
 
+// Test with op_nonnorm whose memref map layouts are identity. This op_nonnorm
+// does not block the normalization of other operations.
+
+// CHECK-LABEL: test_nonnorm_identity_layout
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x64xf32>)
+func.func @test_nonnorm_identity_layout(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
+    %0 = memref.alloc() : memref<1x16x14x14xf32>
+    "test.op_nonnorm"(%0, %0) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> ()
+    "test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32>) -> ()
+    memref.dealloc %0 :  memref<1x16x14x14xf32>
+
+    // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x14x14xf32>
+    // CHECK: "test.op_nonnorm"(%[[v0]], %[[v0]]) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> ()
+    // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32>) -> ()
+    // CHECK: memref.dealloc %[[v0]] : memref<1x16x14x14xf32>
+    return
+}
+
 // Test with op_norm, with maps in the operations in the function.
 
 // CHECK-LABEL: test_norm_mix


        

