[Mlir-commits] [mlir] [mlir][linalg] Replace CopyOp from memref to linalg in linalg PromoteOp (PR #69154)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 16 01:07:02 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Aviad Cohen (AviadCo)

<details>
<summary>Changes</summary>

linalg::CopyOp is much more generic and more useful for promoting buffers. In addition, since this is a linalg transform, it makes more sense to use linalg operations when possible.


---
Full diff: https://github.com/llvm/llvm-project/pull/69154.diff


4 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp (+1-1) 
- (modified) mlir/test/Dialect/Linalg/promote.mlir (+13-16) 
- (modified) mlir/test/Dialect/Linalg/promotion_options.mlir (+3-3) 
- (modified) mlir/test/Dialect/Linalg/transform-promotion.mlir (+7-7) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index a131f3097666197..5c140a7d692a930 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -209,7 +209,7 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
   Location loc = linalgOp.getLoc();
   auto defaultCopyCallBack = [loc](OpBuilder &b, Value src,
                                    Value dst) -> LogicalResult {
-    b.create<memref::CopyOp>(loc, src, dst);
+    b.create<linalg::CopyOp>(loc, src, dst);
     return success();
   };
   copyInFn = (options.copyInFn ? *(options.copyInFn) : defaultCopyCallBack);
diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
index 31b29c0e105d99d..a6f32fea814dd7d 100644
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -52,20 +52,19 @@ func.func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
 //       CHECK:         %[[fullC:.*]] = memref.view %[[tmpC]][{{.*}}][{{.*}}] : memref<24xi8> to memref<?x?xf32>
 //       CHECK:         %[[partialC:.*]] = memref.subview %[[fullC]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
 
-//       CHECK:         memref.copy %[[vA]], %[[partialA]] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
-//       CHECK:         memref.copy %[[vB]], %[[partialB]] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
-//       CHECK:         memref.copy %[[vC]], %[[partialC]] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+//       CHECK:         linalg.copy ins(%[[vA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
+//       CHECK:         linalg.copy ins(%[[vB]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialB]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
+//       CHECK:         linalg.copy ins(%[[vC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
 //
 //       CHECK:         linalg.matmul ins(%[[partialA]], %[[partialB]]{{.*}} outs(%[[partialC]]
 //
-//       CHECK:         memref.copy %[[partialC]], %[[vC]] :
-//       CHECK:           memref<?x?xf32, strided<[?, 1], offset: ?>> to
-//       CHECK:           memref<?x?xf32, strided<[?, 1], offset: ?>>
+//       CHECK:         linalg.copy ins(%[[partialC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[vC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
 //
 //   CHECK-NOT:         memref.dealloc %[[tmpA]] : memref<32xi8>
 //   CHECK-NOT:         memref.dealloc %[[tmpB]] : memref<48xi8>
 //   CHECK-NOT:         memref.dealloc %[[tmpC]] : memref<24xi8>
 
+
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
@@ -122,15 +121,13 @@ func.func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
 //       CHECK:         %[[fullC_f64:.*]] = memref.view %[[tmpC_f64]][{{.*}}][{{.*}}] : memref<48xi8> to memref<?x?xf64>
 //       CHECK:         %[[partialC_f64:.*]] = memref.subview %[[fullC_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf64> to memref<?x?xf64, strided<[?, 1], offset: ?>>
 
-//       CHECK:         memref.copy %[[vA_f64]], %[[partialA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>> to memref<?x?xf64, strided<[?, 1], offset: ?>>
-//       CHECK:         memref.copy %[[vB_f64]], %[[partialB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>> to memref<?x?xf64, strided<[?, 1], offset: ?>>
-//       CHECK:         memref.copy %[[vC_f64]], %[[partialC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>> to memref<?x?xf64, strided<[?, 1], offset: ?>>
+//       CHECK:         linalg.copy ins(%[[vA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
+//       CHECK:         linalg.copy ins(%[[vB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
+//       CHECK:         linalg.copy ins(%[[vC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
 //
 //       CHECK:         linalg.matmul ins(%[[partialA_f64]], %[[partialB_f64]]{{.*}} outs(%[[partialC_f64]]
 //
-//       CHECK:         memref.copy %[[partialC_f64]], %[[vC_f64]] :
-//       CHECK:           memref<?x?xf64, strided<[?, 1], offset: ?>> to
-//       CHECK:           memref<?x?xf64, strided<[?, 1], offset: ?>>
+//       CHECK:         linalg.copy ins(%[[partialC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[vC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
 //
 //       CHECK:         memref.dealloc %[[tmpA_f64]] : memref<64xi8>
 //       CHECK:         memref.dealloc %[[tmpB_f64]] : memref<96xi8>
@@ -255,7 +252,7 @@ func.func @promote_rank_reducing_subviews(%arg0:  memref<?x?x?x64xf32, strided<[
   // CHECK: %[[c_view:.+]] = memref.view
   // CHECK: %[[c_pro_subview:.+]] = memref.subview %[[c_view]]
 
-  // CHECK-COUNT-3: memref.copy
+  // CHECK-COUNT-3: linalg.copy
   // CHECK: linalg.generic
   // CHECK-SAME: ins(%[[a_pro_subview]], %[[b_pro_subview]]
   // CHECK-SAME: outs(%[[c_pro_subview]]
@@ -351,8 +348,8 @@ func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<3x4xf
   // CHECK:           %[[VAL_60:.*]] = memref.alloc() : memref<48xi8, #gpu.address_space<workgroup>>
   // CHECK:           %[[VAL_61:.*]] = memref.view %[[VAL_60]]{{\[}}%[[VAL_56]]]{{\[}}%[[VAL_50]], %[[VAL_53]]] : memref<48xi8, #gpu.address_space<workgroup>> to memref<?x?xf32, #gpu.address_space<workgroup>>
   // CHECK:           %[[VAL_62:.*]] = memref.subview %[[VAL_61]][0, 0] {{\[}}%[[VAL_52]], %[[VAL_55]]] [1, 1] : memref<?x?xf32, #gpu.address_space<workgroup>> to memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
-  // CHECK:           memref.copy %[[VAL_3]], %[[VAL_24]] : memref<4x3xf32, strided<[4, 1]>, 1> to memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
-  // CHECK:           memref.copy %[[VAL_4]], %[[VAL_43]] : memref<4x3xf32, strided<[4, 1]>, 1> to memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
+// CHECK:           linalg.copy ins(%[[VAL_3]] : memref<4x3xf32, strided<[4, 1]>, 1>) outs(%[[VAL_24]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>)
+// CHECK:           linalg.copy ins(%[[VAL_4]] : memref<4x3xf32, strided<[4, 1]>, 1>) outs(%[[VAL_43]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>)
   // CHECK:           linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"], library_call = ""} ins(%[[VAL_24]], %[[VAL_43]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>, memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>) outs(%[[VAL_62]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>) {
   // CHECK:           ^bb0(%[[VAL_63:.*]]: f32, %[[VAL_64:.*]]: f32, %[[VAL_65:.*]]: f32):
   // CHECK:             %[[VAL_66:.*]] = arith.addf %[[VAL_63]], %[[VAL_64]] : f32
@@ -366,7 +363,7 @@ func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<3x4xf
     linalg.yield %1 : f32
   }
 
-  // CHECK:           memref.copy %[[VAL_62]], %[[VAL_5]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>> to memref<4x3xf32, strided<[4, 1]>, 1>
+  // CHECK:           linalg.copy ins(%[[VAL_62]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>) outs(%[[VAL_5]] : memref<4x3xf32, strided<[4, 1]>, 1>)
   // CHECK:           memref.dealloc %[[VAL_22]] : memref<48xi8, #gpu.address_space<workgroup>>
   // CHECK:           memref.dealloc %[[VAL_41]] : memref<48xi8, #gpu.address_space<workgroup>>
   // CHECK:           memref.dealloc %[[VAL_60]] : memref<48xi8, #gpu.address_space<workgroup>>
diff --git a/mlir/test/Dialect/Linalg/promotion_options.mlir b/mlir/test/Dialect/Linalg/promotion_options.mlir
index a6daa9af2f37cec..6e5c1c78b329e02 100644
--- a/mlir/test/Dialect/Linalg/promotion_options.mlir
+++ b/mlir/test/Dialect/Linalg/promotion_options.mlir
@@ -27,10 +27,10 @@ func.func @gemm(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>
 //      CHECK:       %[[VC:.*]] = memref.view %[[tmpC]][%[[C0]]][] : memref<1024xi8> to memref<16x16xf32>
 //      CHECK:       %[[svCC:.+]] = memref.subview %[[VC]]
 
-//      CHECK:       memref.copy %[[svA]], %[[svAA]]
-//      CHECK:       memref.copy %[[svC]], %[[svCC]]
+//      CHECK:       linalg.copy ins(%[[svA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[svAA]] : memref<?x?xf32, strided<[16, 1]>>)
+//      CHECK:       linalg.copy ins(%[[svC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[svCC]] : memref<?x?xf32, strided<[16, 1]>>)
 //      CHECK:       linalg.matmul ins(%[[VA]], %[[svB]]{{.*}} outs(%[[VC]]
-//      CHECK:       memref.copy %[[svCC]], %[[svC]]
+//      CHECK:       linalg.copy ins(%[[svCC]] : memref<?x?xf32, strided<[16, 1]>>) outs(%[[svC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
 //      CHECK:       memref.dealloc %[[tmpA]]
 //      CHECK:       memref.dealloc %[[tmpC]]
 
diff --git a/mlir/test/Dialect/Linalg/transform-promotion.mlir b/mlir/test/Dialect/Linalg/transform-promotion.mlir
index 2f98e394fe05198..c311d471f6c4ae9 100644
--- a/mlir/test/Dialect/Linalg/transform-promotion.mlir
+++ b/mlir/test/Dialect/Linalg/transform-promotion.mlir
@@ -51,9 +51,9 @@ func.func @promote_subview_matmul(%arg0: memref<?x?xf32, strided<[?, 1], offset:
 // CHECK:               %[[v2:.*]] = memref.view %[[a2]]{{.*}} : memref<24000000xi8> to memref<?x?xf32>
 // CHECK:               %[[l2:.*]] = memref.subview %[[v2]][0, 0] [%{{.*}}, %{{.*}}] [1, 1]
 // CHECK-SAME:            memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
-// CHECK:               memref.copy %[[s0]], %[[l0]] : memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
-// CHECK:               memref.copy %[[s1]], %[[l1]] : memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
-// CHECK:               memref.copy %[[s2]], %[[l2]] : memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
+// CHECK:               linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
+// CHECK:               linalg.copy ins(%[[s1]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l1]] : memref<?x?xf32, strided{{.*}}>)
+// CHECK:               linalg.copy ins(%[[s2]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l2]] : memref<?x?xf32, strided{{.*}}>)
 // CHECK:               linalg.matmul
 // CHECK-SAME:                 ins(%[[v0]], %[[v1]] : memref<?x?xf32>, memref<?x?xf32>)
 // CHECK-SAME:                outs(%[[v2]] : memref<?x?xf32>)
@@ -112,8 +112,8 @@ func.func @promote_first_subview_matmul(%arg0: memref<?x?xf32, strided<[?, 1], o
 // CHECK-NOT:     memref.alloc
 // CHECK-NOT:     memref.view
 // CHECK-NOT:     memref.subview
-// CHECK:         memref.copy %[[s0]], %[[l0]] : memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
-// CHECK-NOT:     memref.copy
+// CHECK:         linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
+// CHECK-NOT:     linalg.copy
 // CHECK:         linalg.matmul
 // CHECK-SAME:           ins(%[[v0]], %[[s1]] : memref<?x?xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>)
 // CHECK-SAME:          outs(%[[s2]] : memref<?x?xf32, strided<[?, ?], offset: ?>>)
@@ -148,7 +148,7 @@ func.func @aligned_promote_fill(%arg0: memref<?x?xf32, strided<[?, 1], offset: ?
 // CHECK:         %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref<?x?xf32>
 // CHECK:         %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
 // CHECK:         linalg.fill ins({{.*}} : f32) outs(%[[v0]] : memref<?x?xf32>)
-// CHECK:         memref.copy %[[s0]], %[[l0]] : memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
+// CHECK:         linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
 // CHECK:         linalg.fill ins(%[[cf]] : f32) outs(%[[v0]] : memref<?x?xf32>)
 
 transform.with_pdl_patterns {
@@ -182,7 +182,7 @@ func.func @aligned_promote_fill_complex(%arg0: memref<?x?xcomplex<f32>, strided<
 // CHECK:         %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<64000000xi8> to memref<?x?xcomplex<f32>>
 // CHECK:         %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xcomplex<f32>> to memref<?x?xcomplex<f32>, strided<[?, 1], offset: ?>>
 // CHECK:         linalg.fill ins({{.*}} : complex<f32>) outs(%[[v0]] : memref<?x?xcomplex<f32>>)
-// CHECK:         memref.copy %[[s0]], %[[l0]] : memref<?x?xcomplex<f32>, strided{{.*}}> to memref<?x?xcomplex<f32>, strided{{.*}}>
+// CHECK:         linalg.copy ins(%[[s0]] : memref<?x?xcomplex<f32>, strided{{.*}}>) outs(%[[l0]] : memref<?x?xcomplex<f32>, strided{{.*}}>)
 // CHECK:         linalg.fill ins(%[[cc]] : complex<f32>) outs(%[[v0]] : memref<?x?xcomplex<f32>>)
 
 transform.with_pdl_patterns {

``````````

</details>


https://github.com/llvm/llvm-project/pull/69154


More information about the Mlir-commits mailing list