[llvm-branch-commits] [mlir] b2391d5 - [MLIR] Normalize the results of normalizable operations

Uday Bondhugula via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Dec 3 06:09:33 PST 2020


Author: Haruki Imai
Date: 2020-12-03T19:34:07+05:30
New Revision: b2391d5f0da2a9122d337de6fa1a6aa9baad3fb4

URL: https://github.com/llvm/llvm-project/commit/b2391d5f0da2a9122d337de6fa1a6aa9baad3fb4
DIFF: https://github.com/llvm/llvm-project/commit/b2391d5f0da2a9122d337de6fa1a6aa9baad3fb4.diff

LOG: [MLIR] Normalize the results of normalizable operations

Memrefs with an affine map in the results of normalizable operations
were not normalized by the `--normalize-memrefs` option. This patch
normalizes them.

Differential Revision: https://reviews.llvm.org/D88719

Added: 
    

Modified: 
    mlir/lib/Transforms/NormalizeMemRefs.cpp
    mlir/test/Transforms/normalize-memrefs-ops.mlir
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
index 44b3ccbd2c3f..d7fa212baa73 100644
--- a/mlir/lib/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -36,6 +36,7 @@ struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
   void updateFunctionSignature(FuncOp funcOp, ModuleOp moduleOp);
   void setCalleesAndCallersNonNormalizable(FuncOp funcOp, ModuleOp moduleOp,
                                            DenseSet<FuncOp> &normalizableFuncs);
+  Operation *createOpResultsNormalized(FuncOp funcOp, Operation *oldOp);
 };
 
 } // end anonymous namespace
@@ -384,6 +385,59 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
     funcOp.front().eraseArgument(argIndex + 1);
   }
 
+  // Walk over normalizable operations to normalize memrefs of the operation
+  // results. When `op` has memrefs with an affine map in its results, a new
+  // operation containing normalized memrefs is created and the old memrefs
+  // are replaced. `CallOp` is skipped here because it is handled in
+  // `updateFunctionSignature()`.
+  funcOp.walk([&](Operation *op) {
+    if (op->hasTrait<OpTrait::MemRefsNormalizable>() &&
+        op->getNumResults() > 0 && !isa<CallOp>(op) && !funcOp.isExternal()) {
+      // Create newOp containing normalized memref in the operation result.
+      Operation *newOp = createOpResultsNormalized(funcOp, op);
+      // When normalization changes no result type, `newOp` is `op` itself and
+      // there is nothing to replace.
+      if (op != newOp) {
+        bool replacingMemRefUsesFailed = false;
+        for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) {
+          // Replace all uses of the old memrefs.
+          Value oldMemRef = op->getResult(resIndex);
+          Value newMemRef = newOp->getResult(resIndex);
+          MemRefType oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>();
+          // Check whether the operation result is MemRef type.
+          if (!oldMemRefType)
+            continue;
+          MemRefType newMemRefType = newMemRef.getType().cast<MemRefType>();
+          if (oldMemRefType == newMemRefType)
+            continue;
+          // TODO: Assume single layout map. Multiple maps not supported.
+          AffineMap layoutMap = oldMemRefType.getAffineMaps().front();
+          if (failed(replaceAllMemRefUsesWith(oldMemRef,
+                                              /*newMemRef=*/newMemRef,
+                                              /*extraIndices=*/{},
+                                              /*indexRemap=*/layoutMap,
+                                              /*extraOperands=*/{},
+                                              /*symbolOperands=*/{},
+                                              /*domInstFilter=*/nullptr,
+                                              /*postDomInstFilter=*/nullptr,
+                                              /*allowNonDereferencingOps=*/true,
+                                              /*replaceInDeallocOp=*/true))) {
+            // `newOp` is destroyed here; stop before the loop reads it again.
+            newOp->erase();
+            replacingMemRefUsesFailed = true;
+            break;
+          }
+        }
+        if (!replacingMemRefUsesFailed) {
+          // Memref uses are rewired; forward any remaining uses to `newOp` and
+          // erase the old op.
+          op->replaceAllUsesWith(newOp);
+          op->erase();
+        }
+      }
+    }
+  });
+
   // In a normal function, memrefs in the return type signature gets normalized
   // as a result of normalization of functions arguments, AllocOps or CallOps'
   // result types. Since an external function doesn't have a body, memrefs in
@@ -417,3 +471,49 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
   }
   updateFunctionSignature(funcOp, moduleOp);
 }
+
+/// Create an operation whose memref results are normalized. The new op is
+/// inserted right before `oldOp` and reuses its name, operands and attributes;
+/// `oldOp` is neither modified nor erased. Returns `oldOp` itself when no
+/// result type changes, or when `oldOp` has regions or successors, which an
+/// OperationState rebuild would otherwise silently drop.
+Operation *NormalizeMemRefs::createOpResultsNormalized(FuncOp funcOp,
+                                                       Operation *oldOp) {
+  // Regions and successors are not carried over below, so leave ops that
+  // have any of them alone instead of producing a truncated copy.
+  if (oldOp->getNumRegions() != 0 || oldOp->getNumSuccessors() != 0)
+    return oldOp;
+  // Compute the result types first so that no OperationState is assembled
+  // when nothing would change.
+  SmallVector<Type, 4> resultTypes;
+  OpBuilder b(funcOp);
+  bool resultTypeNormalized = false;
+  for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) {
+    auto resultType = oldOp->getResult(resIndex).getType();
+    MemRefType memrefType = resultType.dyn_cast<MemRefType>();
+    // Non-memref results keep their type.
+    if (!memrefType) {
+      resultTypes.push_back(resultType);
+      continue;
+    }
+    // Fetch a new memref type after normalizing the old memref. It equals
+    // `memrefType` when the layout was already an identity map or could not
+    // be transformed into one; in both cases that type is kept unchanged.
+    MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
+                                                   /*numSymbolicOperands=*/0);
+    resultTypes.push_back(newMemRefType);
+    if (newMemRefType != memrefType)
+      resultTypeNormalized = true;
+  }
+  // Every result type is unchanged, so there is nothing to rebuild; hand
+  // `oldOp` back untouched.
+  if (!resultTypeNormalized)
+    return oldOp;
+  // Recreate `oldOp` with the normalized result types, right before it.
+  OperationState result(oldOp->getLoc(), oldOp->getName());
+  result.addOperands(oldOp->getOperands());
+  result.addAttributes(oldOp->getAttrs());
+  result.addTypes(resultTypes);
+  OpBuilder bb(oldOp);
+  return bb.createOperation(result);
+}

diff  --git a/mlir/test/Transforms/normalize-memrefs-ops.mlir b/mlir/test/Transforms/normalize-memrefs-ops.mlir
index 92b23c07887f..9567ab5f8f83 100644
--- a/mlir/test/Transforms/normalize-memrefs-ops.mlir
+++ b/mlir/test/Transforms/normalize-memrefs-ops.mlir
@@ -14,13 +14,13 @@
 // Test with op_norm and maps in arguments and in the operations in the function.
 
 // CHECK-LABEL: test_norm
-// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x1x1x32x64xf32>)
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x64xf32>)
 func @test_norm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
     %0 = alloc() : memref<1x16x14x14xf32, #map0>
     "test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
     dealloc %0 :  memref<1x16x14x14xf32, #map0>
 
-    // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x64xf32>
+    // CHECK: %[[v0:.*]] = alloc() : memref<1x16x1x1x32x64xf32>
     // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
     // CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x64xf32>
     return
@@ -29,13 +29,13 @@ func @test_norm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
 // Same test with op_nonnorm, with maps in the arguments and the operations in the function.
 
 // CHECK-LABEL: test_nonnorm
-// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x14x14xf32, #map>)
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x14x14xf32, #map>)
 func @test_nonnorm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
     %0 = alloc() : memref<1x16x14x14xf32, #map0>
     "test.op_nonnorm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
     dealloc %0 :  memref<1x16x14x14xf32, #map0>
 
-    // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x14x14xf32, #map>
+    // CHECK: %[[v0:.*]] = alloc() : memref<1x16x14x14xf32, #map>
     // CHECK: "test.op_nonnorm"(%[[ARG0]], %[[v0]]) : (memref<1x16x14x14xf32, #map>, memref<1x16x14x14xf32, #map>) -> ()
     // CHECK: dealloc %[[v0]] : memref<1x16x14x14xf32, #map>
     return
@@ -44,13 +44,13 @@ func @test_nonnorm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
 // Test with op_norm, with maps in the operations in the function.
 
 // CHECK-LABEL: test_norm_mix
-// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x1x1x32x64xf32>
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x64xf32>
 func @test_norm_mix(%arg0 : memref<1x16x1x1x32x64xf32>) -> () {
     %0 = alloc() : memref<1x16x14x14xf32, #map0>
     "test.op_norm"(%arg0, %0) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32, #map0>) -> ()
     dealloc %0 :  memref<1x16x14x14xf32, #map0>
 
-    // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x64xf32>
+    // CHECK: %[[v0:.*]] = alloc() : memref<1x16x1x1x32x64xf32>
     // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
     // CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x64xf32>
     return
@@ -61,12 +61,12 @@ func @test_norm_mix(%arg0 : memref<1x16x1x1x32x64xf32>) -> () {
 #map_tile = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 32, d2 mod 32, d3 mod 32)>
 
 // CHECK-LABEL: test_load_store
-// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x14x14xf32>
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x14x14xf32>
 func @test_load_store(%arg0 : memref<1x16x14x14xf32>) -> () {
     %0 = alloc() : memref<1x16x14x14xf32, #map_tile>
-    // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x32xf32>
+    // CHECK: %[[v0:.*]] = alloc() : memref<1x16x1x1x32x32xf32>
     %1 = alloc() : memref<1x16x14x14xf32>
-    // CHECK: %[[v1:[a-z0-9]*]] = alloc() : memref<1x16x14x14xf32>
+    // CHECK: %[[v1:.*]] = alloc() : memref<1x16x14x14xf32>
     "test.op_norm"(%0, %1) : (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>) -> ()
     // CHECK: "test.op_norm"(%[[v0]], %[[v1]]) : (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>) -> ()
     %cst = constant 3.0 : f32
@@ -90,6 +90,25 @@ func @test_load_store(%arg0 : memref<1x16x14x14xf32>) -> () {
     return
 }
 
+// Test with op_norm_ret, with maps in the results of a normalizable operation.
+
+// CHECK-LABEL: test_norm_ret
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x32xf32>) -> (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>) {
+func @test_norm_ret(%arg0: memref<1x16x14x14xf32, #map_tile>) -> (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>) {
+    %0 = alloc() : memref<1x16x14x14xf32, #map_tile>
+    // CHECK-NEXT: %[[v0:.*]] = alloc() : memref<1x16x1x1x32x32xf32>
+    %1, %2 = "test.op_norm_ret"(%arg0) : (memref<1x16x14x14xf32, #map_tile>) -> (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>)
+    // CHECK-NEXT: %[[v1:.*]], %[[v2:.*]] = "test.op_norm_ret"
+    // CHECK-SAME: (memref<1x16x1x1x32x32xf32>) -> (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>)
+    "test.op_norm"(%1, %0) : (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32, #map_tile>) -> ()
+    // CHECK-NEXT: "test.op_norm"
+    // CHECK-SAME: : (memref<1x16x1x1x32x32xf32>, memref<1x16x1x1x32x32xf32>) -> ()
+    dealloc %0 : memref<1x16x14x14xf32, #map_tile>
+    // CHECK-NEXT: dealloc %[[v0]] : memref<1x16x1x1x32x32xf32>
+    return %1, %2 : memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>
+    // CHECK-NEXT: return %[[v1]], %[[v2]] : memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>
+}
+
 // Test with an arbitrary op that references the function symbol.
 
 "test.op_funcref"() {func = @test_norm_mix} : () -> ()

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 995022ab2115..7d9274fbddee 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -666,6 +666,11 @@ def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
 def OpNonNorm : TEST_Op<"op_nonnorm"> {
   let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
 }
+// Test op for memref normalization of the memref results of a normalizable op.
+def OpNormRet : TEST_Op<"op_norm_ret", [MemRefsNormalizable]> {
+  let arguments = (ins AnyMemRef:$X);
+  let results = (outs AnyMemRef:$Y, AnyMemRef:$Z);
+}
 
 // Test for memrefs normalization of an op with a reference to a function
 // symbol.


        


More information about the llvm-branch-commits mailing list