[Mlir-commits] [mlir] fa0d044 - [mlir] Fix canonicalization of tiled_loop if not all op results fold.

Alexander Belyaev llvmlistbot at llvm.org
Wed Apr 28 11:03:54 PDT 2021


Author: Alexander Belyaev
Date: 2021-04-28T19:57:48+02:00
New Revision: fa0d044c4499535fb7960a5b7053bd043ad09e52

URL: https://github.com/llvm/llvm-project/commit/fa0d044c4499535fb7960a5b7053bd043ad09e52
DIFF: https://github.com/llvm/llvm-project/commit/fa0d044c4499535fb7960a5b7053bd043ad09e52.diff

LOG: [mlir] Fix canonicalization of tiled_loop if not all op results fold.

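The previous canonicalization did not remap operation results correctly:
it erased the original tiledLoop op outright, which is incorrect when only
some of its tensor results are folded and the remaining results still have
uses. The op is now replaced, with each surviving result mapped to the
corresponding result of the new tiled_loop.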

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 17ecab19c2a8..750d818bd261 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2094,7 +2094,7 @@ namespace {
 
 static constexpr int64_t kNoMatch = -1;
 
-// Folds away TiledLoopOp input tensors if they have no uses within the body.
+// Folds away TiledLoopOp inputs if they have no uses within the body.
 //
 // Example:
 //
@@ -2117,7 +2117,7 @@ struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
       Value in, bbArg;
       size_t index = en.index();
       std::tie(in, bbArg) = en.value();
-      if (!in.getType().isa<RankedTensorType>() || !bbArg.use_empty()) {
+      if (!bbArg.use_empty()) {
         oldInputIdToNew[index] = newInputs.size();
         newInputs.push_back(in);
       }
@@ -2142,7 +2142,7 @@ struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
         OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
     for (auto &op : *tiledLoop.getBody())
       innerBuilder.clone(op, bvm);
-    rewriter.eraseOp(tiledLoop);
+    rewriter.replaceOp(tiledLoop, newTiledLoop.getResults());
 
     return success();
   }
@@ -2184,6 +2184,10 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
     // Store ids of the corresponding old and new output operands.
     SmallVector<int64_t, 2> oldOutputIdToNew(tiledLoop.outputs().size(),
                                              kNoMatch);
+    // Store ids of the corresponding old and new results.
+    SmallVector<int64_t, 2> oldResultIdToNew(tiledLoop.getNumResults(),
+                                             kNoMatch);
+    SmallVector<Value, 2> resultReplacement(tiledLoop.getNumResults());
     for (auto en : llvm::enumerate(
              llvm::zip(tiledLoop.outputs(), tiledLoop.getRegionOutputArgs()))) {
       size_t index = en.index();
@@ -2199,6 +2203,8 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
       Value yieldArg = yieldOp.getOperand(resultId);
       if (yieldArg != outRegionArg || !result.use_empty()) {
         oldOutputIdToNew[index] = newOutputOperands.size();
+        oldResultIdToNew[resultId] = newYieldArgs.size();
+        resultReplacement[resultId] = out;
         newOutputOperands.push_back(out);
         newYieldArgs.push_back(yieldArg);
       }
@@ -2228,8 +2234,14 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
         OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
     for (auto &op : tiledLoop.getBody()->without_terminator())
       innerBuilder.clone(op, bvm);
-    innerBuilder.create<linalg::YieldOp>(loc, newYieldArgs);
-    rewriter.eraseOp(tiledLoop);
+    innerBuilder.create<linalg::YieldOp>(
+        loc, llvm::to_vector<2>(llvm::map_range(
+                 newYieldArgs, [&](Value arg) { return bvm.lookup(arg); })));
+
+    for (const auto &en : llvm::enumerate(oldResultIdToNew))
+      if (en.value() != kNoMatch)
+        resultReplacement[en.index()] = newTiledLoop.getResult(en.value());
+    rewriter.replaceOp(tiledLoop, resultReplacement);
 
     return success();
   }

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index e66ee388c65e..244b78f39e1c 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -867,75 +867,71 @@ func @fold_fill_reshape_dynamic(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32
   return %1 : tensor<?x?xf32>
 }
 
-// -----
 
-#map0 = affine_map<(d0) -> (24, -d0 + 192)>
-#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
-#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+// -----
 
-func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
-                  %C: memref<192x192xf32>) -> ()
+func private @foo(%A: memref<48xf32>, %B: tensor<48xf32>,
+                  %C: memref<48xf32>) -> (tensor<48xf32>)
 
-func @fold_tiled_loop_results(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
-                              %C: memref<192x192xf32>,
-                              %C_tensor: tensor<192x192xf32>) {
-  %cst = constant 0.000000e+00 : f32
-  %c24 = constant 24 : index
-  %c16 = constant 16 : index
+func @fold_tiled_loop_results(%A: memref<48xf32>, %B: tensor<48xf32>,
+    %C: memref<48xf32>, %C_tensor: tensor<48xf32>) -> tensor<48xf32> {
   %c0 = constant 0 : index
-  %c192 = constant 192 : index
-  %useless = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192)
-      step (%c24, %c16)
-      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
-      outs (%CT_ = %C_tensor: tensor<192x192xf32>,
-            %C_ = %C: memref<192x192xf32>) {
-        call @foo(%A_, %B_, %C_)
-          : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> ()
-    linalg.yield %CT_ : tensor<192x192xf32>
+  %c24 = constant 24 : index
+  %c48 = constant 48 : index
+  %useful, %useless = linalg.tiled_loop (%i) = (%c0) to (%c48) step (%c24)
+      ins (%A_ = %A: memref<48xf32>)
+      outs (%B_ = %B: tensor<48xf32>,
+            %CT_ = %C_tensor: tensor<48xf32>,
+            %C_ = %C: memref<48xf32>) {
+        %result = call @foo(%A_, %B_, %C_)
+          : (memref<48xf32>, tensor<48xf32>, memref<48xf32>)-> (tensor<48xf32>)
+    linalg.yield %result, %CT_ : tensor<48xf32>, tensor<48xf32>
   }
-  return
+  return %useful : tensor<48xf32>
 }
 
 // CHECK-LABEL: func @fold_tiled_loop_results(
-// CHECK-SAME:    %[[A:.*]]: [[TY:.*]], %[[B:.*]]: [[TY]], %[[C:.*]]: [[TY]],
-// CHECK-SAME:    %[[C_TENSOR:.*]]: tensor<{{.*}}>) {
-// CHECK:  %[[C24:.*]] = constant 24 : index
-// CHECK:  %[[C16:.*]] = constant 16 : index
+// CHECK-SAME:   %[[A:.*]]: [[BUF_TY:memref<48xf32>]], %[[B:.*]]: [[TY:tensor<48xf32>]],
+// CHECK-SAME:   %[[C:.*]]: [[BUF_TY]],  %[[C_TENSOR:.*]]: [[TY]]) -> [[TY]] {
+
 // CHECK:  %[[C0:.*]] = constant 0 : index
-// CHECK:  %[[C192:.*]] = constant 192 : index
+// CHECK:  %[[C24:.*]] = constant 24 : index
+// CHECK:  %[[C48:.*]] = constant 48 : index
 
 // CHECK-NOT: %{{.*}} = linalg.tiled_loop
-// CHECK:  linalg.tiled_loop (%{{.*}}, %{{.*}}) = (%[[C0]], %[[C0]])
-// CHECK-SAME: to (%[[C192]], %[[C192]]) step (%[[C24]], %[[C16]])
-// CHECK-SAME: ins (%[[A_:.*]] = %[[A]]: memref<192x192xf32>, %[[B_:.*]] = %[[B]]: memref<192x192xf32>)
-// CHECK-SAME: outs (%[[C_:.*]] = %[[C]]: memref<192x192xf32>) {
-// CHECK-NEXT:   call @foo(%[[A_]], %[[B_]], %[[C_]])
-// CHECK-NEXT:   linalg.yield
+// CHECK:  %[[RESULT:.*]] = linalg.tiled_loop (%{{.*}}) = (%[[C0]])
+// CHECK-SAME: to (%[[C48]]) step (%[[C24]])
+// CHECK-SAME: ins (%[[A_:.*]] = %[[A]]: [[BUF_TY]])
+// CHECK-SAME: outs (%[[B_:.*]] = %[[B]]: [[TY]], %[[C_:.*]] = %[[C]]: [[BUF_TY]]) {
+// CHECK-NEXT:   %[[RES:.*]] = call @foo(%[[A_]], %[[B_]], %[[C_]])
+// CHECK-NEXT:   linalg.yield %[[RES]] :
 
-// -----
+// CHECK: return %[[RESULT]]
 
-#map0 = affine_map<(d0) -> (24, -d0 + 192)>
-#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
-#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+// -----
 
-func private @foo(%A: memref<192xf32>) -> ()
+func private @foo(%A: memref<192xf32>, %B: tensor<192xf32>) -> tensor<192xf32>
 
-func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>) {
+func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>,
+                             %B_tensor: tensor<192xf32>) -> tensor<192xf32> {
   %c0 = constant 0 : index
   %c24 = constant 24 : index
   %c192 = constant 192 : index
-  linalg.tiled_loop (%i) = (%c0) to (%c192) step (%c24)
-      ins (%A_ = %A: memref<192xf32>, %AT_ = %A_tensor: tensor<192xf32>) {
-        call @foo(%A_) : (memref<192xf32>)-> ()
-    linalg.yield
+  %result = linalg.tiled_loop (%i) = (%c0) to (%c192) step (%c24)
+      ins (%A_ = %A: memref<192xf32>, %AT_ = %A_tensor: tensor<192xf32>)
+      outs (%BT_ = %B_tensor: tensor<192xf32>) {
+    %0 = call @foo(%A_, %BT_) : (memref<192xf32>, tensor<192xf32>) -> tensor<192xf32>
+    linalg.yield %0 : tensor<192xf32>
   }
-  return
+  return %result : tensor<192xf32>
 }
 
 // CHECK-LABEL: func @fold_tiled_loop_inputs
-// CHECK: linalg.tiled_loop
+// CHECK: %[[RESULT:.*]] = linalg.tiled_loop
 // CHECK-SAME: ins (%{{.*}} = %{{.*}}: memref<192xf32>)
 
+// CHECK: return %[[RESULT]]
+
 // -----
 
 func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index,


        


More information about the Mlir-commits mailing list