[Mlir-commits] [mlir] cca662b - [mlir][linalg] add conv_2d_nhwc_fhwc named op
Christopher Bate
llvmlistbot at llvm.org
Mon Jun 6 12:18:53 PDT 2022
Author: Christopher Bate
Date: 2022-06-06T13:18:08-06:00
New Revision: cca662b8495537b67f7b70090a33fc36154de81c
URL: https://github.com/llvm/llvm-project/commit/cca662b8495537b67f7b70090a33fc36154de81c
DIFF: https://github.com/llvm/llvm-project/commit/cca662b8495537b67f7b70090a33fc36154de81c.diff
LOG: [mlir][linalg] add conv_2d_nhwc_fhwc named op
This operation should be supported as a named op because,
when the operands are viewed as having canonical layouts
with decreasing strides, the "reduction" dimensions of the
filter (h, w, and c) are contiguous for each output channel.
When the convolution is lowered to a matrix multiplication,
this layout is the simplest to deal with, so future
transforms/vectorizations of `conv2d` may find this
named op convenient.
Differential Revision: https://reviews.llvm.org/D126995
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/test/Dialect/Linalg/named-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 3cc8f8e32cc94..522bec689f122 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1311,6 +1311,104 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: K
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: conv_2d_nhwc_fhwc
+ cpp_class_name: Conv2DNhwcFhwcOp
+ doc: |-
+ Performs 2-D convolution.
+
+ Layout:
+ * Input: NHWC.
+ * Kernel: FHWC.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ implements:
+ - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: I
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
+ s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
+ - !LinalgOperandDefConfig
+ name: K
+ kind: input_tensor
+ type_var: T2
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s10, s3,
+ s7, s9)>
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: U
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
+ s1, s5, s10)>
+ - !LinalgOperandDefConfig
+ name: strides
+ kind: index_attr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+ (s2, s6)>
+ default_indices:
+ - 1
+ - 1
+ - !LinalgOperandDefConfig
+ name: dilations
+ kind: index_attr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+ (s4, s8)>
+ default_indices:
+ - 1
+ - 1
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+ s9, s10] -> (d0, d1 * s2 + d4 * s4, d2 * s6 + d5 * s8, d6)>
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+ s9, s10] -> (d3, d4, d5, d6)>
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+ s9, s10] -> (d0, d1, d2, d3)>
+ iterator_types:
+ - parallel
+ - parallel
+ - parallel
+ - parallel
+ - reduction
+ - reduction
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: O
+ - !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: I
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: K
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_2d_nhwc_hwcf_q
cpp_class_name: Conv2DNhwcHwcfQOp
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 992da7e80ad09..b984a1baaa94e 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -178,6 +178,38 @@ func.func @conv_2d_nhwc_hwcf(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x
// -----
+// CHECK-LABEL: func @conv_2d_nhwc_fhwc
+func.func @conv_2d_nhwc_fhwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ // CHECK: %{{.+}} = linalg.conv_2d_nhwc_fhwc
+ // CHECK-SAME: dilations = dense<1> : tensor<2xi64>
+ // CHECK-SAME: strides = dense<1> : tensor<2xi64>
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+ // CHECK-SAME: outs(%{{.+}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @conv_2d_nhwc_fhwc_static
+func.func @conv_2d_nhwc_fhwc_static(%input: tensor<?x128x128x32xf32>, %filter: tensor<64x3x3x32xf32>, %init: tensor<?x126x126x64xf32>) -> tensor<?x126x126x64xf32> {
+ // CHECK: %{{.+}} = linalg.conv_2d_nhwc_fhwc
+ // CHECK-SAME: dilations = dense<1> : tensor<2xi64>
+ // CHECK-SAME: strides = dense<1> : tensor<2xi64>
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<?x128x128x32xf32>, tensor<64x3x3x32xf32>)
+ // CHECK-SAME: outs(%{{.+}} : tensor<?x126x126x64xf32>) -> tensor<?x126x126x64xf32>
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x128x128x32xf32>, tensor<64x3x3x32xf32>)
+ outs (%init: tensor<?x126x126x64xf32>) -> tensor<?x126x126x64xf32>
+ return %0 : tensor<?x126x126x64xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @conv_2d_nhwc_hwcf
func.func @conv_2d_nhwc_hwcf(%input: memref<?x?x?x?xf32>, %filter: memref<?x?x?x?xf32>, %output: memref<?x?x?x?xf32>) {
// CHECK: linalg.conv_2d_nhwc_hwcf
More information about the Mlir-commits
mailing list