[Mlir-commits] [mlir] 40f5f3d - [mlir][linalg][bufferize] Use memref.copy instead of linalg.copy

Matthias Springer llvmlistbot at llvm.org
Fri Jan 14 05:29:18 PST 2022


Author: Matthias Springer
Date: 2022-01-14T22:29:05+09:00
New Revision: 40f5f3d62dcd7b5e74911431c0974c762836b80f

URL: https://github.com/llvm/llvm-project/commit/40f5f3d62dcd7b5e74911431c0974c762836b80f
DIFF: https://github.com/llvm/llvm-project/commit/40f5f3d62dcd7b5e74911431c0974c762836b80f.diff

LOG: [mlir][linalg][bufferize] Use memref.copy instead of linalg.copy

Differential Revision: https://reviews.llvm.org/D117220
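
In the bufferized output, copies now use the memref dialect's syntax rather than linalg's. Roughly, the updated FileCheck patterns in the tests below reflect the following before/after shape (a sketch with operand types taken from the tests, not an exhaustive mapping):

    // Before:
    linalg.copy(%from, %to) : memref<?xf32>, memref<?xf32>
    // After:
    memref.copy %from, %to : memref<?xf32> to memref<?xf32>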

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index e92f852cd6c8f..f0f1beb53ab01 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -81,11 +81,6 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
     options->allocationFns->deallocationFn = [](OpBuilder &b, Location loc,
                                                 Value v) {};
   }
-  // TODO: Change to memref::CopyOp (default memCpyFn).
-  options->allocationFns->memCpyFn = [](OpBuilder &b, Location loc, Value from,
-                                        Value to) {
-    b.create<linalg::CopyOp>(loc, from, to);
-  };
 
   options->allowReturnMemref = allowReturnMemref;
   options->allowUnknownOps = allowUnknownOps;

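With the override above deleted, the pass falls back to the default memCpyFn from the bufferization options. Judging by the removed TODO, that default emits a memref::CopyOp; a minimal sketch of the assumed behavior (the actual default is defined elsewhere in the bufferization infrastructure):

    // Assumed equivalent of the default memCpyFn that takes effect once the
    // linalg::CopyOp override is gone: it creates a memref::CopyOp instead.
    options->allocationFns->memCpyFn = [](OpBuilder &b, Location loc, Value from,
                                          Value to) {
      b.create<memref::CopyOp>(loc, from, to);
    };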
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
index d30ab5ac4f9a9..de1b8321ed5eb 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
@@ -26,12 +26,12 @@ func @buffer_forwarding_conflict(
   //     CHECK: linalg.fill({{.*}}, %[[EXTRACT_SLICE_ALLOC]]) : f32, memref<?xf32>
   %f = linalg.fill(%f0, %a) : f32, tensor<?xf32> -> tensor<?xf32>
 
-  //     CHECK: linalg.copy(%[[FUNC_ARG]], %[[ALLOC]]) : memref<?xf32>, memref<?xf32>
+  //     CHECK: memref.copy %[[FUNC_ARG]], %[[ALLOC]] : memref<?xf32> to memref<?xf32>
   //     CHECK: %[[SV0_ALLOC:.*]] = memref.subview %[[ALLOC]][0] [%[[sz]]] [1] : memref<?xf32> to memref<?xf32>
-  //     CHECK: linalg.copy(%[[EXTRACT_SLICE_ALLOC]], %[[SV0_ALLOC]]) : memref<?xf32>, memref<?xf32>
+  //     CHECK: memref.copy %[[EXTRACT_SLICE_ALLOC]], %[[SV0_ALLOC]] : memref<?xf32> to memref<?xf32>
   %r0 = tensor.insert_slice %f into %t[0][%sz][1]: tensor<?xf32> into tensor<?xf32>
 
-  //     CHECK: linalg.copy(%[[EXTRACT_SLICE_ALLOC]], %[[T_SUBVIEW]])
+  //     CHECK: memref.copy %[[EXTRACT_SLICE_ALLOC]], %[[T_SUBVIEW]]
   %r1 = tensor.insert_slice %f into %t[42][%sz][1]: tensor<?xf32> into tensor<?xf32>
 
   return %r0, %r1: tensor<?xf32>, tensor<?xf32>

diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
index e8f484c263261..909c41cef97f2 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
@@ -141,7 +141,7 @@ func @unknown_op_not_writable(
   // introducing a RaW conflict.
   // CHECK: %[[dim:.*]] = tensor.dim %[[dummy]]
   // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
-  // CHECK: linalg.copy(%[[dummy_memref]], %[[alloc]])
+  // CHECK: memref.copy %[[dummy_memref]], %[[alloc]]
   // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
   %1 = vector.transfer_write %v, %0[%idx] : vector<5xf32>, tensor<?xf32>
 

diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 753e099354cfa..b37c3066a0d65 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -150,7 +150,7 @@ func @vec_not_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %vec : vec
 
   /// Cross-op multiple uses of %A, the first vector.transfer which has interfering reads must alloc.
   //      CHECK: %[[ALLOC:.*]] = memref.alloc
-  //      CHECK: linalg.copy({{.*}}, %[[ALLOC]])
+  //      CHECK: memref.copy {{.*}}, %[[ALLOC]]
   // CHECK-NEXT: vector.transfer_write {{.*}}, %[[ALLOC]]
   %r0 = vector.transfer_write %vec, %A[%c0] : vector<4xf32>, tensor<?xf32>
 
@@ -185,27 +185,27 @@ func @insert_slice_fun(%A0 : tensor<?xf32> {linalg.inplaceable = false},
   //      CHECK: %[[REALLOC1:.*]] = memref.alloc
 
   // Alloc and copy the whole result tensor. Copy the tensor.extract_slice.
-  //      CHECK: linalg.copy(%[[A0]], %[[REALLOC3]]
+  //      CHECK: memref.copy %[[A0]], %[[REALLOC3]]
   //      CHECK: %[[SV_A0:.*]] = memref.subview %[[REALLOC3]]
-  //      CHECK: linalg.copy(%[[t0]], %[[SV_A0]])
+  //      CHECK: memref.copy %[[t0]], %[[SV_A0]]
   %r0 = tensor.insert_slice %t0 into %A0[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
   // Alloc and copy the whole result tensor. Copy the tensor.extract_slice.
-  //      CHECK: linalg.copy(%[[A0]]
+  //      CHECK: memref.copy %[[A0]]
   //      CHECK: %[[SV_A0_2:.*]] = memref.subview %[[REALLOC2]]
-  //      CHECK: linalg.copy(%[[t1]], %[[SV_A0_2]])
+  //      CHECK: memref.copy %[[t1]], %[[SV_A0_2]]
   %r1 = tensor.insert_slice %t1 into %A0[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
   //  Still alloc the large tensor because %A1 is read after. Copy the tensor.extract_slice.
-  //      CHECK: linalg.copy(%[[A1]]
+  //      CHECK: memref.copy %[[A1]]
   //      CHECK: %[[SV_A1:.*]] = memref.subview %[[REALLOC1]]
-  //      CHECK: linalg.copy(%[[t0]], %[[SV_A1]])
+  //      CHECK: memref.copy %[[t0]], %[[SV_A1]]
   %r2 = tensor.insert_slice %t0 into %A1[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
   //  Do not realloc the large tensor. Copy the tensor.extract_slice.
   //  CHECK-NOT: alloc
   //      CHECK: %[[SV_A1_2:.*]] = memref.subview %[[A1]]
-  //      CHECK: linalg.copy(%[[t1]], %[[SV_A1_2]])
+  //      CHECK: memref.copy %[[t1]], %[[SV_A1_2]]
   %r3 = tensor.insert_slice %t1 into %A1[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
   //      CHECK: return %[[REALLOC3]], %[[REALLOC2]], %[[REALLOC1]] :
@@ -229,7 +229,7 @@ func @insert_slice_fun(
 
   //  CHECK-NOT: alloc
   //      CHECK: %[[SV_A:.*]] = memref.subview %[[A]]
-  //      CHECK: linalg.copy(%[[t]], %[[SV_A]])
+  //      CHECK: memref.copy %[[t]], %[[SV_A]]
   %r0 = tensor.insert_slice %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
   /// Overwrite A inplace.
@@ -261,7 +261,7 @@ func @insert_slice_fun(
   //  CHECK-NOT: alloc
   //      CHECK: %[[SV_A:.*]] = memref.subview %[[A]]
   /// Overwrite A inplace by copying into the subview.
-  //      CHECK: linalg.copy(%[[t]], %[[SV_A]])
+  //      CHECK: memref.copy %[[t]], %[[SV_A]]
   %r1 = tensor.insert_slice %t into %r0[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
   //     CHECK: return
@@ -282,9 +282,9 @@ func @insert_slice_fun_not_inplace(
   -> tensor<?xf32>
 {
   //      CHECK: %[[ALLOC:.*]] = memref.alloc(%{{.*}}) {alignment = 128 : i64} : memref<?xf32>
-  //      CHECK: linalg.copy(%[[A]], %[[ALLOC]]) : memref<?xf32{{.*}}, memref<?xf32>
+  //      CHECK: memref.copy %[[A]], %[[ALLOC]] : memref<?xf32{{.*}} to memref<?xf32>
   //      CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][0] [4] [1] : memref<?xf32> to memref<4xf32>
-  //      CHECK: linalg.copy(%[[t]], %[[SV]]) : memref<4xf32, #map>, memref<4xf32>
+  //      CHECK: memref.copy %[[t]], %[[SV]] : memref<4xf32, #map> to memref<4xf32>
   //      CHECK: memref.dealloc %[[ALLOC]] : memref<?xf32>
   %r0 = tensor.insert_slice %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
@@ -310,7 +310,7 @@ func @scf_for_yield_only(%A : tensor<?xf32> {linalg.inplaceable = false},
 {
   //     CHECK:   %[[ALLOC_FOR_A:.*]] = memref.alloc
   //     CHECK:   %[[CASTED:.*]] = memref.cast %[[ALLOC_FOR_A]]
-  //     CHECK:   linalg.copy(%[[A]], %[[ALLOC_FOR_A]])
+  //     CHECK:   memref.copy %[[A]], %[[ALLOC_FOR_A]]
 
   // The first scf.for remains but just turns into dead code.
   %r0 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor<?xf32>) {
@@ -366,7 +366,7 @@ func @scf_for_with_tensor.insert_slice(
 {
   //     CHECK:   %[[ALLOC_FOR_A:.*]] = memref.alloc
   //     CHECK:   %[[CASTED:.*]] = memref.cast %[[ALLOC_FOR_A]]
-  //     CHECK:   linalg.copy(%[[A]], %[[ALLOC_FOR_A]])
+  //     CHECK:   memref.copy %[[A]], %[[ALLOC_FOR_A]]
 
   //     CHECK: %[[svA:.*]] = memref.subview %[[ALLOC_FOR_A]][0] [4] [1]
   //     CHECK: %[[svB:.*]] = memref.subview %[[B]][0] [4] [1]
@@ -377,11 +377,11 @@ func @scf_for_with_tensor.insert_slice(
       -> (tensor<?xf32>, tensor<?xf32>)
   {
     // %ttA bufferizes to direct copy of %BUFFER_CAST_C into %svA
-    //     CHECK: linalg.copy(%[[C]], %[[svA]])
+    //     CHECK: memref.copy %[[C]], %[[svA]]
     %ttA = tensor.insert_slice %C into %tA[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
     // %ttB bufferizes to direct copy of %BUFFER_CAST_C into %BUFFER_CAST_B
-    //     CHECK:   linalg.copy(%[[C]], %[[svB]])
+    //     CHECK:   memref.copy %[[C]], %[[svB]]
     %ttB = tensor.insert_slice %C into %tB[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
     // CHECK-NOT:   scf.yield
@@ -412,7 +412,7 @@ func @main() {
 
 //      CHECK:   %[[alloc:.*]] = memref.alloc
 //      CHECK:   %[[B:.*]] = memref.cast %[[alloc]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]>
-//      CHECK:   linalg.copy(%[[A]], %[[alloc]])
+//      CHECK:   memref.copy %[[A]], %[[alloc]]
 //      CHECK:   call @some_external_func(%[[B]]) : (memref<4xi32, #[[$DYN_1D_MAP]]>) -> ()
   call @some_external_func(%A) : (tensor<4xi32>) -> ()
 
@@ -434,7 +434,7 @@ func @main() {
 
 //      CHECK:   %[[alloc:.*]] = memref.alloc
 //      CHECK:   %[[B:.*]] = memref.cast %[[alloc]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]>
-//      CHECK:   linalg.copy(%[[A]], %[[alloc]])
+//      CHECK:   memref.copy %[[A]], %[[alloc]]
 //      CHECK:   call @some_external_func_within_scf_execute(%[[B]]) : (memref<4xi32, #[[$DYN_1D_MAP]]>) -> ()
   scf.execute_region {
     call @some_external_func_within_scf_execute(%A) : (tensor<4xi32>) -> ()
@@ -465,11 +465,11 @@ func @scf_for_with_tensor_insert_slice(
       -> (tensor<?xf32>, tensor<?xf32>)
   {
     // CHECK-NEXT:   %[[SVA:.*]] = memref.subview %[[A]]
-    // CHECK-NEXT:   linalg.copy(%[[C]], %[[SVA]]) : memref<4xf32, #[[$DYN_1D_MAP]]>, memref<4xf32, #[[$DYN_1D_MAP]]>
+    // CHECK-NEXT:   memref.copy %[[C]], %[[SVA]] : memref<4xf32, #[[$DYN_1D_MAP]]> to memref<4xf32, #[[$DYN_1D_MAP]]>
     %ttA = tensor.insert_slice %C into %tA[%i][4][1] : tensor<4xf32> into tensor<?xf32>
 
     // CHECK-NEXT:   %[[SVB:.*]] = memref.subview %[[B]]
-    // CHECK-NEXT:   linalg.copy(%[[C]], %[[SVB]]) : memref<4xf32, #[[$DYN_1D_MAP]]>, memref<4xf32, #[[$DYN_1D_MAP]]>
+    // CHECK-NEXT:   memref.copy %[[C]], %[[SVB]] : memref<4xf32, #[[$DYN_1D_MAP]]> to memref<4xf32, #[[$DYN_1D_MAP]]>
     %ttB = tensor.insert_slice %C into %tB[%i][4][1] : tensor<4xf32> into tensor<?xf32>
 
     // scf.yield is empty and is elided
@@ -500,7 +500,7 @@ func @bar(
   // %r0#0 requires a copy because we have no idea what the function is doing.
 //      CHECK:   %[[alloc:.*]] = memref.alloc
 //      CHECK:   %[[casted:.*]] = memref.cast %[[alloc]]
-//      CHECK:   linalg.copy(%[[B]], %[[alloc]])
+//      CHECK:   memref.copy %[[B]], %[[alloc]]
 // CHECK-NEXT:   call @some_external_func(%[[casted]]) : (memref<?xf32, #[[$DYN_1D_MAP]]>) -> ()
   call @some_external_func(%r0#0) : (tensor<?xf32>) -> ()
 
@@ -707,7 +707,7 @@ func @tiled_loop_yield_out_of_place(
       iterators["parallel"]
   {
     // CHECK-NOT:   alloc
-    //     CHECK:   linalg.copy(%[[B]], %[[A]])
+    //     CHECK:   memref.copy %[[B]], %[[A]]
     linalg.yield %B : tensor<?xf32>
     //     CHECK:   linalg.yield
     // CHECK-NOT:   tensor
@@ -762,9 +762,9 @@ func @entry(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] ->
 //      CHECK: %[[ALLOC_B:.*]] = memref.alloc
 //      CHECK: %[[CASTED_B:.*]] = memref.cast %[[ALLOC_B]]
 //      CHECK: %[[ALLOC_A:.*]] = memref.alloc
-//      CHECK: linalg.copy(%[[A]], %[[ALLOC_A]])
-//      CHECK: linalg.copy(%[[B]], %[[ALLOC_B]])
-//      CHECK: linalg.copy(%[[C]], %[[ALLOC_C]])
+//      CHECK: memref.copy %[[A]], %[[ALLOC_A]]
+//      CHECK: memref.copy %[[B]], %[[ALLOC_B]]
+//      CHECK: memref.copy %[[C]], %[[ALLOC_C]]
 //      CHECK: %[[CASTED_A:.*]] = memref.cast %[[ALLOC_A]]
 // CHECK-NEXT: call @callee(%[[CASTED_A]], %[[CASTED_B]], %[[CASTED_C]])
   call @callee(%A, %B, %C) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> ()
@@ -831,7 +831,7 @@ func @matmul(
       // insert_slice is inplace but its source comes from an equivalent buffer
       // that is not in place. So we must insert a copy of the small buffer into
       // the bigger buffer.
-      // CHECK: linalg.copy(%[[ALLOC]], %[[T]])
+      // CHECK: memref.copy %[[ALLOC]], %[[T]]
       %7 = tensor.insert_slice %6 into %arg6[%arg3, %arg5] [8, 16] [1, 1] :
         tensor<8x16xf32> into tensor<128x192xf32>
 
@@ -848,8 +848,9 @@ func @matmul(
 // CHECK-LABEL: func @tensor_cast_not_in_place(
 //  CHECK-SAME:     %[[A:.*]]: memref<?xf32{{.*}}>, %[[B:.*]]: memref<?xf32{{.*}}>
 //       CHECK:   %[[alloc:.*]] = memref.alloc
-//       CHECK:   linalg.copy(%[[A]], %[[alloc]])
-//       CHECK:   %[[cast:.*]] = memref.cast %[[alloc]]
+//       CHECK:   memref.copy %[[A]], %[[alloc]]
+//       CHECK:   %[[subview:.*]] = memref.subview %[[A]][{{.*}}] [4] [1] : {{.*}} to memref<4xf32
+//       CHECK:   memref.copy %[[alloc]], %[[subview]]
 func @tensor_cast_not_in_place(
     %A : tensor<?xf32> {linalg.inplaceable = true},
     %B : tensor<?xf32> {linalg.inplaceable = false}, %idx: index)
@@ -1014,7 +1015,7 @@ func @equivalent_func_arg_2(%t0: tensor<?xf32> {linalg.inplaceable = true},
   %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
     // CHECK: %[[alloc:.*]] = memref.alloc
     // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
-    // CHECK: linalg.copy(%[[arg0]], %[[alloc]])
+    // CHECK: memref.copy %[[arg0]], %[[alloc]]
     // CHECK: call @inner_func_2(%[[casted]])
     %3 = call @inner_func_2(%t1) : (tensor<?xf32>) -> tensor<?xf32>
     scf.yield %t1 : tensor<?xf32>
@@ -1143,7 +1144,7 @@ func @linalg_op_bufferizes_out_of_place_with_input(
     %t3: tensor<?x?xf32> {linalg.inplaceable = false},
     %s1: index, %s2: index, %cst: f32) -> tensor<?x?xf32> {
   // CHECK: %[[alloc:.*]] = memref.alloc
-  // CHECK: linalg.copy(%[[t1]], %[[alloc]])
+  // CHECK: memref.copy %[[t1]], %[[alloc]]
   // CHECK: linalg.generic {{.*}} ins(%[[t1]], %[[t2]] : {{.*}}) outs(%[[alloc]] : {{.*}})
   %r = linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
@@ -1203,7 +1204,7 @@ func @op_is_reading_but_following_ops_are_not(
 {
   // Make sure that a copy is inserted here.
   // CHECK: %[[ALLOC:.*]] = memref.alloc
-  // CHECK: linalg.copy(%[[t0]], %[[ALLOC]])
+  // CHECK: memref.copy %[[t0]], %[[ALLOC]]
   // CHECK: linalg.generic {{.*}} outs(%[[ALLOC]] : memref
   %r0 =linalg.generic #trait outs (%t0 : tensor<?xf32>) {
       ^bb(%0: f32) :
@@ -1257,7 +1258,7 @@ func @write_to_select_op_source(
   %cst = arith.constant 0.0 : f32
   %idx = arith.constant 0 : index
   // CHECK: %[[alloc:.*]] = memref.alloc
-  // CHECK: linalg.copy(%[[t1]], %[[alloc]])
+  // CHECK: memref.copy %[[t1]], %[[alloc]]
   // CHECK: memref.store %{{.*}}, %[[alloc]]
   %w = tensor.insert %cst into %t1[%idx] : tensor<?xf32>
   // CHECK: %[[select:.*]] = select %{{.*}}, %[[t1]], %[[t2]]
@@ -1281,7 +1282,7 @@ func @write_after_select_read_one(
 
   // CHECK: %[[alloc:.*]] = memref.alloc
   // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
-  // CHECK: linalg.copy(%[[t1]], %[[alloc]])
+  // CHECK: memref.copy %[[t1]], %[[alloc]]
   // CHECK: %[[select:.*]] = select %{{.*}}, %[[casted]], %[[t2]]
   %s = std.select %c, %t1, %t2 : tensor<?xf32>
 

diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
index dc6a8ab63ef35..648a7ee2832bc 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
@@ -3,7 +3,7 @@
 // RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
 
 // RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
-// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext |\
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext,%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext |\
 // RUN: FileCheck %s
 
 #map0 = affine_map<(d0, d1)[s0] -> ((d1 - d0) ceildiv s0)>
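
The extra shared library added to the RUN line above is presumably needed because memref.copy, unlike linalg.copy, can lower to a call to the memrefCopy runtime helper, which lives in libmlir_c_runner_utils.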
