[Mlir-commits] [mlir] [mlir][Linalg]: Optimize any structured linalg operation in transform::PromoteOp to avoid unnecessary copies (PR #69876)

Aviad Cohen llvmlistbot at llvm.org
Fri Oct 27 01:47:15 PDT 2023


https://github.com/AviadCo updated https://github.com/llvm/llvm-project/pull/69876

>From 6dcc8645627adef2725e1071f64675b643ed417f Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Sun, 22 Oct 2023 14:37:57 +0300
Subject: [PATCH] [mlir][Linalg]: Optimize any structured linalg operation in
 transform::PromoteOp to avoid unnecessary copies

Before promotion, there is no need to copy outputs that are not
considered to be init tensors.
---
 mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp  | 6 ++----
 mlir/test/Dialect/Linalg/promote.mlir             | 9 ++++++---
 mlir/test/Dialect/Linalg/promotion_options.mlir   | 2 +-
 mlir/test/Dialect/Linalg/transform-promotion.mlir | 3 ++-
 4 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 5c140a7d692a930..05248661e747971 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -177,10 +177,8 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
     Operation *op = opOperand.get().getDefiningOp();
     if (auto sv = dyn_cast_or_null<memref::SubViewOp>(op)) {
       subViews[operandNumber] = sv;
-      // In case of linalg generic, copy in only if subview is used in linalg
-      // payload.
-      if (!isa<linalg::GenericOp>(linalgOp) ||
-          linalgOp.payloadUsesValueFromOperand(&opOperand))
+      // Copy in only if subview is being used by the linalg operation.
+      if (linalgOp.isDpsInput(&opOperand) || !linalgOp.isInitTensor(&opOperand))
         operandsNumbersToCopyIn.insert(operandNumber);
       useFullTileBuffers[sv] = vUseFullTileBuffers[operandNumber];
     }
diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
index fb5f357f3faa8af..896175bb4b650dd 100644
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -54,7 +54,8 @@ func.func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
 
 //       CHECK:         linalg.copy ins(%[[vA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
 //       CHECK:         linalg.copy ins(%[[vB]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialB]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
-//       CHECK:         linalg.copy ins(%[[vC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
+//       CHECK-NOT:     linalg.copy ins(%[[vC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
+
 //
 //       CHECK:         linalg.matmul ins(%[[partialA]], %[[partialB]]{{.*}} outs(%[[partialC]]
 //
@@ -124,7 +125,8 @@ func.func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
 
 //       CHECK:         linalg.copy ins(%[[vA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
 //       CHECK:         linalg.copy ins(%[[vB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
-//       CHECK:         linalg.copy ins(%[[vC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
+//       CHECK-NOT:     linalg.copy ins(%[[vC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
+
 //
 //       CHECK:         linalg.matmul ins(%[[partialA_f64]], %[[partialB_f64]]{{.*}} outs(%[[partialC_f64]]
 //
@@ -259,7 +261,8 @@ func.func @promote_rank_reducing_subviews(%arg0:  memref<?x?x?x64xf32, strided<[
   // CHECK: %[[c_view:.+]] = memref.view
   // CHECK: %[[c_pro_subview:.+]] = memref.subview %[[c_view]]
 
-  // CHECK-COUNT-3: linalg.copy
+
+  // CHECK-COUNT-2: linalg.copy
   // CHECK: linalg.generic
   // CHECK-SAME: ins(%[[a_pro_subview]], %[[b_pro_subview]]
   // CHECK-SAME: outs(%[[c_pro_subview]]
diff --git a/mlir/test/Dialect/Linalg/promotion_options.mlir b/mlir/test/Dialect/Linalg/promotion_options.mlir
index 3bf74b708cb82fb..78ce0ad25914d57 100644
--- a/mlir/test/Dialect/Linalg/promotion_options.mlir
+++ b/mlir/test/Dialect/Linalg/promotion_options.mlir
@@ -28,7 +28,7 @@ func.func @gemm(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>
 //      CHECK:       %[[svCC:.+]] = memref.subview %[[VC]]
 
 //      CHECK:       linalg.copy ins(%[[svA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[svAA]] : memref<?x?xf32, strided<[16, 1]>>)
-//      CHECK:       linalg.copy ins(%[[svC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[svCC]] : memref<?x?xf32, strided<[16, 1]>>)
+//      CHECK-NOT:   linalg.copy ins(%[[svC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[svCC]] : memref<?x?xf32, strided<[16, 1]>>)
 //      CHECK:       linalg.matmul ins(%[[VA]], %[[svB]]{{.*}} outs(%[[VC]]
 //      CHECK:       linalg.copy ins(%[[svCC]] : memref<?x?xf32, strided<[16, 1]>>) outs(%[[svC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
 //      CHECK:       memref.dealloc %[[tmpA]]
diff --git a/mlir/test/Dialect/Linalg/transform-promotion.mlir b/mlir/test/Dialect/Linalg/transform-promotion.mlir
index d6112db0f7772dd..0ef875462fa8205 100644
--- a/mlir/test/Dialect/Linalg/transform-promotion.mlir
+++ b/mlir/test/Dialect/Linalg/transform-promotion.mlir
@@ -51,9 +51,10 @@ func.func @promote_subview_matmul(%arg0: memref<?x?xf32, strided<[?, 1], offset:
 // CHECK:               %[[v2:.*]] = memref.view %[[a2]]{{.*}} : memref<24000000xi8> to memref<?x?xf32>
 // CHECK:               %[[l2:.*]] = memref.subview %[[v2]][0, 0] [%{{.*}}, %{{.*}}] [1, 1]
 // CHECK-SAME:            memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+
 // CHECK:               linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
 // CHECK:               linalg.copy ins(%[[s1]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l1]] : memref<?x?xf32, strided{{.*}}>)
-// CHECK:               linalg.copy ins(%[[s2]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l2]] : memref<?x?xf32, strided{{.*}}>)
+// CHECK-NOT:           linalg.copy ins(%[[s2]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l2]] : memref<?x?xf32, strided{{.*}}>)
 // CHECK:               linalg.matmul
 // CHECK-SAME:                 ins(%[[v0]], %[[v1]] : memref<?x?xf32>, memref<?x?xf32>)
 // CHECK-SAME:                outs(%[[v2]] : memref<?x?xf32>)



More information about the Mlir-commits mailing list