[Mlir-commits] [mlir] [mlir][Linalg]: Optimize any structured linalg operation in transform::PromoteOp to avoid unnecessary copies (PR #69876)

Aviad Cohen llvmlistbot at llvm.org
Sun Oct 22 04:58:35 PDT 2023


https://github.com/AviadCo created https://github.com/llvm/llvm-project/pull/69876

When promoting operands, there is no need to copy in outputs that are not init tensors, i.e. outputs whose current value the linalg payload never reads.


>From ac98c7a79b927eed62b272819276c553c9020da3 Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Sun, 22 Oct 2023 14:37:57 +0300
Subject: [PATCH] [mlir][Linalg]: Optimize any structured linalg operation in
 transform::PromoteOp to avoid unnecessary copies

When promoting operands, there is no need to copy in outputs that are
not init tensors, i.e. outputs whose value the payload never reads.
---
 mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index a131f3097666197..34f8bdf844b3fef 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -177,10 +177,9 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
     Operation *op = opOperand.get().getDefiningOp();
     if (auto sv = dyn_cast_or_null<memref::SubViewOp>(op)) {
       subViews[operandNumber] = sv;
-      // In case of linalg generic, copy in only if subview is used in linalg
-      // payload.
-      if (!isa<linalg::GenericOp>(linalgOp) ||
-          linalgOp.payloadUsesValueFromOperand(&opOperand))
+      // Copy in only if the payload reads the subview's current value. This
+      // covers used inputs and init outputs (e.g. a matmul accumulator).
+      if (linalgOp.payloadUsesValueFromOperand(&opOperand))
         operandsNumbersToCopyIn.insert(operandNumber);
       useFullTileBuffers[sv] = vUseFullTileBuffers[operandNumber];
     }



More information about the Mlir-commits mailing list