[Mlir-commits] [mlir] [mlir][Linalg]: Optimize any structured linalg operation in transform::PromoteOp to avoid unnecessary copies (PR #69876)
Aviad Cohen
llvmlistbot at llvm.org
Fri Oct 27 01:47:15 PDT 2023
https://github.com/AviadCo updated https://github.com/llvm/llvm-project/pull/69876
>From 6dcc8645627adef2725e1071f64675b643ed417f Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Sun, 22 Oct 2023 14:37:57 +0300
Subject: [PATCH] [mlir][Linalg]: Optimize any structured linalg operation in
transform::PromoteOp to avoid unnecessary copies
Before promotion, there is no need to copy in outputs that are not
considered init tensors.
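For illustration, a minimal sketch of the intended effect on the promoted IR,
following the pattern of the tests updated below (the value names %sA/%pA etc.
and the exact memref types are hypothetical, not taken from the patch): the
output subview still gets a promoted buffer and a copy-back after the op, but
the copy-in for it is no longer emitted.

    // Copy-in is still emitted for the inputs of the promoted op.
    linalg.copy ins(%sA : memref<?x?xf32, strided<[?, 1], offset: ?>>)
                outs(%pA : memref<?x?xf32, strided<[?, 1], offset: ?>>)
    linalg.copy ins(%sB : memref<?x?xf32, strided<[?, 1], offset: ?>>)
                outs(%pB : memref<?x?xf32, strided<[?, 1], offset: ?>>)
    // No linalg.copy from %sC into %pC here anymore.
    linalg.matmul ins(%pA, %pB : memref<?x?xf32, strided<[?, 1], offset: ?>>,
                                 memref<?x?xf32, strided<[?, 1], offset: ?>>)
                  outs(%pC : memref<?x?xf32, strided<[?, 1], offset: ?>>)
    // The copy-back of the promoted output is unchanged.
    linalg.copy ins(%pC : memref<?x?xf32, strided<[?, 1], offset: ?>>)
                outs(%sC : memref<?x?xf32, strided<[?, 1], offset: ?>>)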
---
mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp | 6 ++----
mlir/test/Dialect/Linalg/promote.mlir | 9 ++++++---
mlir/test/Dialect/Linalg/promotion_options.mlir | 2 +-
mlir/test/Dialect/Linalg/transform-promotion.mlir | 3 ++-
4 files changed, 11 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 5c140a7d692a930..05248661e747971 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -177,10 +177,8 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
Operation *op = opOperand.get().getDefiningOp();
if (auto sv = dyn_cast_or_null<memref::SubViewOp>(op)) {
subViews[operandNumber] = sv;
- // In case of linalg generic, copy in only if subview is used in linalg
- // payload.
- if (!isa<linalg::GenericOp>(linalgOp) ||
- linalgOp.payloadUsesValueFromOperand(&opOperand))
+ // Copy in only if subview is being used by the linalg operation.
+ if (linalgOp.isDpsInput(&opOperand) || !linalgOp.isInitTensor(&opOperand))
operandsNumbersToCopyIn.insert(operandNumber);
useFullTileBuffers[sv] = vUseFullTileBuffers[operandNumber];
}
diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
index fb5f357f3faa8af..896175bb4b650dd 100644
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -54,7 +54,8 @@ func.func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
// CHECK: linalg.copy ins(%[[vA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
// CHECK: linalg.copy ins(%[[vB]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialB]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
-// CHECK: linalg.copy ins(%[[vC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
+// CHECK-NOT: linalg.copy ins(%[[vC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
+
//
// CHECK: linalg.matmul ins(%[[partialA]], %[[partialB]]{{.*}} outs(%[[partialC]]
//
@@ -124,7 +125,8 @@ func.func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
// CHECK: linalg.copy ins(%[[vA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
// CHECK: linalg.copy ins(%[[vB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
-// CHECK: linalg.copy ins(%[[vC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
+// CHECK-NOT: linalg.copy ins(%[[vC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
+
//
// CHECK: linalg.matmul ins(%[[partialA_f64]], %[[partialB_f64]]{{.*}} outs(%[[partialC_f64]]
//
@@ -259,7 +261,8 @@ func.func @promote_rank_reducing_subviews(%arg0: memref<?x?x?x64xf32, strided<[
// CHECK: %[[c_view:.+]] = memref.view
// CHECK: %[[c_pro_subview:.+]] = memref.subview %[[c_view]]
- // CHECK-COUNT-3: linalg.copy
+
+ // CHECK-COUNT-2: linalg.copy
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[a_pro_subview]], %[[b_pro_subview]]
// CHECK-SAME: outs(%[[c_pro_subview]]
diff --git a/mlir/test/Dialect/Linalg/promotion_options.mlir b/mlir/test/Dialect/Linalg/promotion_options.mlir
index 3bf74b708cb82fb..78ce0ad25914d57 100644
--- a/mlir/test/Dialect/Linalg/promotion_options.mlir
+++ b/mlir/test/Dialect/Linalg/promotion_options.mlir
@@ -28,7 +28,7 @@ func.func @gemm(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>
// CHECK: %[[svCC:.+]] = memref.subview %[[VC]]
// CHECK: linalg.copy ins(%[[svA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[svAA]] : memref<?x?xf32, strided<[16, 1]>>)
-// CHECK: linalg.copy ins(%[[svC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[svCC]] : memref<?x?xf32, strided<[16, 1]>>)
+// CHECK-NOT: linalg.copy ins(%[[svC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[svCC]] : memref<?x?xf32, strided<[16, 1]>>)
// CHECK: linalg.matmul ins(%[[VA]], %[[svB]]{{.*}} outs(%[[VC]]
// CHECK: linalg.copy ins(%[[svCC]] : memref<?x?xf32, strided<[16, 1]>>) outs(%[[svC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
// CHECK: memref.dealloc %[[tmpA]]
diff --git a/mlir/test/Dialect/Linalg/transform-promotion.mlir b/mlir/test/Dialect/Linalg/transform-promotion.mlir
index d6112db0f7772dd..0ef875462fa8205 100644
--- a/mlir/test/Dialect/Linalg/transform-promotion.mlir
+++ b/mlir/test/Dialect/Linalg/transform-promotion.mlir
@@ -51,9 +51,10 @@ func.func @promote_subview_matmul(%arg0: memref<?x?xf32, strided<[?, 1], offset:
// CHECK: %[[v2:.*]] = memref.view %[[a2]]{{.*}} : memref<24000000xi8> to memref<?x?xf32>
// CHECK: %[[l2:.*]] = memref.subview %[[v2]][0, 0] [%{{.*}}, %{{.*}}] [1, 1]
// CHECK-SAME: memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+
// CHECK: linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
// CHECK: linalg.copy ins(%[[s1]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l1]] : memref<?x?xf32, strided{{.*}}>)
-// CHECK: linalg.copy ins(%[[s2]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l2]] : memref<?x?xf32, strided{{.*}}>)
+// CHECK-NOT: linalg.copy ins(%[[s2]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l2]] : memref<?x?xf32, strided{{.*}}>)
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[v0]], %[[v1]] : memref<?x?xf32>, memref<?x?xf32>)
// CHECK-SAME: outs(%[[v2]] : memref<?x?xf32>)