[Mlir-commits] [mlir] c050dd4 - [mlir][linalg] Add support for vectorizing convs that have different types.
Hanhan Wang
llvmlistbot at llvm.org
Wed Nov 2 11:03:26 PDT 2022
Author: Hanhan Wang
Date: 2022-11-02T11:03:14-07:00
New Revision: c050dd4717ec4317bd45adfca8243cb9ea7b6370
URL: https://github.com/llvm/llvm-project/commit/c050dd4717ec4317bd45adfca8243cb9ea7b6370
DIFF: https://github.com/llvm/llvm-project/commit/c050dd4717ec4317bd45adfca8243cb9ea7b6370.diff
LOG: [mlir][linalg] Add support for vectorizing convs that have different types.
Reviewed By: dcaballe
Differential Revision: https://reviews.llvm.org/D137208
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorize-convolution.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index d565efb30241d..cedec72b9cb33 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1465,7 +1465,7 @@ struct Conv1DGenerator : public StructuredGenerator<LinalgOp> {
return;
for (Value operand : mulOp->getOperands()) {
if (Operation *def = operand.getDefiningOp()) {
- if (!isa<arith::ExtFOp>(def))
+ if (!isa<CastOpInterface>(def))
return;
operand = def->getOperand(0);
}
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index e7495765b3ec7..1374c996128a1 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -61,6 +61,70 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x
// -----
+// The i8i8i32 case is similar to the f32 case, so checking one case is enough
+// for test coverage.
+func.func @conv1d_nwc_4x2x8_i8i8i32_memref(%input: memref<4x6x3xi8>, %filter: memref<1x3x8xi8>, %output: memref<4x2x8xi32>) {
+ linalg.conv_1d_nwc_wcf
+ {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+ ins(%input, %filter : memref<4x6x3xi8>, memref<1x3x8xi8>)
+ outs(%output : memref<4x2x8xi32>)
+ return
+}
+
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK: func @conv1d_nwc_4x2x8_i8i8i32_memref
+// CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xi8>, %[[FILTER:.+]]: memref<1x3x8xi8>, %[[OUTPUT:.+]]: memref<4x2x8xi32>)
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK-DAG: %[[C0_I32:.+]] = arith.constant 0 : i32
+
+/// Read all of the data in one shot.
+// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[C0_I8]]
+// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[C0_I8]]
+// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[C0_I32]]
+
+// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8>
+// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8>
+
+// CHECK: %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x8xi8>
+
+// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xi32> to vector<4x1x8xi32>
+// CHECK: %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xi32> to vector<4x1x8xi32>
+
+/// w == 0, kw == 0
+// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
+// CHECK-SAME: : vector<4x1x3xi8>, vector<3x8xi8> into vector<4x1x8xi32>
+
+/// w == 1, kw == 0
+// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
+// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
+// CHECK-SAME: : vector<4x1x3xi8>, vector<3x8xi8> into vector<4x1x8xi32>
+
+/// w == 0, kw == 0
+// CHECK: %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_0]], %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x8xi32> into vector<4x2x8xi32>
+/// w == 1, kw == 0
+// CHECK: %[[RES_1:.+]] = vector.insert_strided_slice %[[CONTRACT_1]], %[[RES_0]]
+// CHECK-SAME: {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x8xi32> into vector<4x2x8xi32>
+
+/// Write the result back in one shot.
+// CHECK: vector.transfer_write %[[RES_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+// -----
+
func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf32>, %output: memref<4x2x8xf32>) {
linalg.conv_1d_nwc_wcf
{dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
More information about the Mlir-commits mailing list