[Mlir-commits] [mlir] [SCF][Transform] Add support for scf.for in LoopFuseSibling op (PR #81495)
Kunwar Grover
llvmlistbot at llvm.org
Tue Feb 13 04:20:20 PST 2024
================
@@ -119,3 +119,184 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// Two sibling linalg.matmul ops read the same %A and the same filled init
+// %out, differing only in the right-hand operand (%B1 vs %B2). The transform
+// script below tiles each matmul with tile_using_for [32]. The expected output
+// is a single scf.for from 0 to 128 step 32 that carries both results as
+// iter_args (both initialized from the linalg.fill result). On each iteration
+// it takes a 32x128 slice at offset [iv, 0] of %A (one slice feeding both
+// matmuls) and of each iter_arg, runs both tiled matmuls, inserts each result
+// back, and yields both.
+// CHECK: func.func @test([[A:%.*]]: {{.*}}, [[B1:%.*]]: {{.*}}, [[B2:%.*]]: {{.*}} {{.*}}
+func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+ %zero = arith.constant 0.0 : f32
+ %out_alloc = tensor.empty() : tensor<128x128xf32>
+ %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+ // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+ // CHECK-DAG: [[C32:%.*]] = arith.constant 32 : index
+ // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+ // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: [[EMPTY:%.*]] = tensor.empty() : tensor<128x128xf32>
+ // CHECK-DAG: [[BUF:%.*]] = linalg.fill ins([[ZERO]] : {{.*}}) outs([[EMPTY]] : {{.*}}) {{.*}}
+ // CHECK: [[RST:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C32]] iter_args([[IA0:%.*]] = [[BUF]], [[IA1:%.*]] = [[BUF]]) {{.*}}
+ // CHECK-DAG: [[ASLICE:%.*]] = tensor.extract_slice [[A]][[[IV]], 0] [32, 128] [1, 1]
+ // CHECK-DAG: [[SLICE0:%.*]] = tensor.extract_slice [[IA0]][[[IV]], 0] [32, 128] [1, 1]
+ // CHECK: [[OUT1:%.*]] = linalg.matmul ins([[ASLICE]], [[B1]] : {{.*}}) outs([[SLICE0]]
+ // CHECK-NEXT: [[INS0:%.*]] = tensor.insert_slice [[OUT1]] into [[IA0]][[[IV]], 0] [32, 128] [1, 1]
+ // CHECK-DAG: [[SLICE1:%.*]] = tensor.extract_slice [[IA1]][[[IV]], 0] [32, 128] [1, 1]
+ // CHECK: [[OUT2:%.*]] = linalg.matmul ins([[ASLICE]], [[B2]] : {{.*}}) outs([[SLICE1]]
+ // CHECK-NEXT: [[INS1:%.*]] = tensor.insert_slice [[OUT2]] into [[IA1]][[[IV]], 0] [32, 128] [1, 1]
+ // CHECK: scf.yield [[INS0]], [[INS1]] : {{.*}}
+ // The two sibling producers. They share %A and the %out init; the transform
+ // script matches them by op name and splits the handle into %mm1 and %mm2.
+ %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+ // CHECK: return [[RST]]#0, [[RST]]#1 : {{.*}}
+ func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
+ %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+
+ %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ %tiled_mm1, %loop1 = transform.structured.tile_using_for %mm1 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %tiled_mm2, %loop2 = transform.structured.tile_using_for %mm2 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ //transform.print %variant_op : !transform.any_op
----------------
Groverkss wrote:
nit: remove debug prints
https://github.com/llvm/llvm-project/pull/81495
More information about the Mlir-commits
mailing list