[Mlir-commits] [mlir] [mlir][linalg] Enable masked vectorisation for depthwise convolutions (PR #81625)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Mar 6 07:24:07 PST 2024
================
@@ -0,0 +1,187 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter -cse %s | FileCheck %s
+
+func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(%input: tensor<1x8x?xi8>,
+ %filter: tensor<1x?xi8>,
+ %output: tensor<1x8x?xi8>) -> (tensor<1x8x?xi8>) {
+ %res = linalg.depthwise_conv_1d_nwc_wc
+ {dilations = dense<1> : vector<1xi64>,
+ strides = dense<1> : vector<1xi64>}
+ ins(%input, %filter : tensor<1x8x?xi8>, tensor<1x?xi8>)
+ outs(%output : tensor<1x8x?xi8>) -> tensor<1x8x?xi8>
+ return %res : tensor<1x8x?xi8>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [1, 8, 4, 1] : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(
+// CHECK-SAME: %[[INPUT:.*]]: tensor<1x8x?xi8>,
+// CHECK-SAME: %[[FILTER:.*]]: tensor<1x?xi8>,
+// CHECK-SAME: %[[OUTPUT:.*]]: tensor<1x8x?xi8>) -> tensor<1x8x?xi8> {
+
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[PAD:.*]] = arith.constant 0 : i8
+
+/// Create a mask for the input tensor
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[CH_DIM_IN:.*]] = tensor.dim %[[INPUT]], %[[C2]] : tensor<1x8x?xi8>
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[MASK_IN:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_IN]] : vector<1x8x4xi1>
+/// Read the input tensor
+// CHECK: %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x4xi8> } : vector<1x8x4xi1> -> vector<1x8x4xi8>
+
+/// Create a mask for the filter tensor
+// CHECK: %[[CH_DIM_FLT:.*]] = tensor.dim %[[FILTER]], %[[C1]] : tensor<1x?xi8>
+// CHECK: %[[MASK_FLT:.*]] = vector.create_mask %[[C1]], %[[CH_DIM_FLT]] : vector<1x4xi1>
+/// Read the filter tensor
+// CHECK: %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<1x?xi8>, vector<1x4xi8> } : vector<1x4xi1> -> vector<1x4xi8>
+
+/// Create a mask for the output tensor
+// CHECK: %[[CH_DIM_OUT:.*]] = tensor.dim %[[OUTPUT]], %[[C2]] : tensor<1x8x?xi8>
+// CHECK: %[[MASK_OUT:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_OUT]] : vector<1x8x4xi1>
+// CHECK: %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x4xi8> } : vector<1x8x4xi1> -> vector<1x8x4xi8>
+
+/// Convolution
+// CHECK: %[[IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x4xi8> to vector<1x8x4xi8>
+// CHECK: %[[FLT_1:.*]] = vector.extract %[[VEC_FLT]][0] : vector<4xi8> from vector<1x4xi8>
+// CHECK: %[[OUT_1:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x4xi8> to vector<1x8x4xi8>
+// CHECK: %[[FLT_1_B:.*]] = vector.broadcast %[[FLT_1]] : vector<4xi8> to vector<1x8x4xi8>
+// CHECK: %[[MULI:.*]] = arith.muli %[[IN_1]], %[[FLT_1_B]] : vector<1x8x4xi8>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[MULI]], %[[OUT_1]] : vector<1x8x4xi8>
+// CHECK: %[[OUT_INS:.*]] = vector.insert_strided_slice %[[ADDI]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x8x4xi8> into vector<1x8x4xi8>
+// CHECK: %[[OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<1x8x4xi8>, tensor<1x8x?xi8> } : vector<1x8x4xi1> -> tensor<1x8x?xi8>
+// CHECK: return %[[OUT]] : tensor<1x8x?xi8>
+
+// -----
+
+func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor_scalable(
+ %input: tensor<1x8x?xi8>,
+ %filter: tensor<1x?xi8>,
+ %output: tensor<1x8x?xi8>) -> (tensor<1x8x?xi8>) {
+ %res = linalg.depthwise_conv_1d_nwc_wc
+ {dilations = dense<1> : vector<1xi64>,
+ strides = dense<1> : vector<1xi64>}
+ ins(%input, %filter : tensor<1x8x?xi8>, tensor<1x?xi8>)
+ outs(%output : tensor<1x8x?xi8>) -> tensor<1x8x?xi8>
+ return %res : tensor<1x8x?xi8>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [1, 8, [4], 1] : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor_scalable(
+// CHECK-SAME: %[[INPUT:.*]]: tensor<1x8x?xi8>,
+// CHECK-SAME: %[[FILTER:.*]]: tensor<1x?xi8>,
+// CHECK-SAME: %[[OUTPUT:.*]]: tensor<1x8x?xi8>) -> tensor<1x8x?xi8> {
+
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[PAD:.*]] = arith.constant 0 : i8
+
+/// Create a mask for the input tensor
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[CH_DIM_IN:.*]] = tensor.dim %[[INPUT]], %[[C2]] : tensor<1x8x?xi8>
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[MASK_IN:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_IN]] : vector<1x8x[4]xi1>
+/// Read the input tensor
+// CHECK: %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
+
+/// Create a mask for the filter tensor
+// CHECK: %[[CH_DIM_FLT:.*]] = tensor.dim %[[FILTER]], %[[C1]] : tensor<1x?xi8>
+// CHECK: %[[MASK_FLT:.*]] = vector.create_mask %[[C1]], %[[CH_DIM_FLT]] : vector<1x[4]xi1>
+/// Read the filter tensor
+// CHECK: %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<1x?xi8>, vector<1x[4]xi8> } : vector<1x[4]xi1> -> vector<1x[4]xi8>
+
+/// Create a mask for the output tensor
+// CHECK: %[[CH_DIM_OUT:.*]] = tensor.dim %[[OUTPUT]], %[[C2]] : tensor<1x8x?xi8>
+// CHECK: %[[MASK_OUT:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_OUT]] : vector<1x8x[4]xi1>
+/// Read the output tensor
+// CHECK: %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
+
+/// Convolution
+// CHECK: %[[IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
+// CHECK: %[[FLT_1:.*]] = vector.extract %[[VEC_FLT]][0] : vector<[4]xi8> from vector<1x[4]xi8>
+// CHECK: %[[OUT_1:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
+// CHECK: %[[FLT_1_B:.*]] = vector.broadcast %[[FLT_1]] : vector<[4]xi8> to vector<1x8x[4]xi8>
+// CHECK: %[[MULI:.*]] = arith.muli %[[IN_1]], %[[FLT_1_B]] : vector<1x8x[4]xi8>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[MULI]], %[[OUT_1]] : vector<1x8x[4]xi8>
+// CHECK: %[[OUT_INS:.*]] = vector.insert_strided_slice %[[ADDI]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x8x[4]xi8> into vector<1x8x[4]xi8>
+// CHECK: %[[OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<1x8x[4]xi8>, tensor<1x8x?xi8> } : vector<1x8x[4]xi1> -> tensor<1x8x?xi8>
+// CHECK: return %[[OUT]] : tensor<1x8x?xi8>
+
+
+
+// -----
+
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
+ %input: memref<3x5x?xf32>,
+ %filter: memref<2x?xf32>,
+ %output: memref<3x2x?xf32>) {
+ linalg.depthwise_conv_1d_nwc_wc
+ {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
----------------
banach-space wrote:
I've been struggling to find good documentation for this. This book provides a nice description:
* https://udlbook.github.io/udlbook/
See Figure 10.3.
The reason for adding this case here is to demonstrate that this patch does not rely on anything that only works for the default case (dilation = 1). I appreciate that it's a bit tricky to follow without knowing the semantics of the Op itself. Is there anything I can do to make it clearer? You could try comparing against: https://github.com/llvm/llvm-project/blob/982e9022ca4abaad58c693d2b0aba0e908ee2d39/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir#L51-L108
But note that I wasn't trying to keep the CHECK lines consistent with that file and, more importantly, that file tests "flattened" convs.
https://github.com/llvm/llvm-project/pull/81625
More information about the Mlir-commits
mailing list