[Mlir-commits] [mlir] 046ebeb - [mlir][linalg] Relax convolution vectorization to support mixed types
Thomas Raoux
llvmlistbot at llvm.org
Thu Jun 16 09:36:02 PDT 2022
Author: Thomas Raoux
Date: 2022-06-16T16:29:46Z
New Revision: 046ebeb60504abfb053c24c1ee5ab6747cfa5a08
URL: https://github.com/llvm/llvm-project/commit/046ebeb60504abfb053c24c1ee5ab6747cfa5a08
DIFF: https://github.com/llvm/llvm-project/commit/046ebeb60504abfb053c24c1ee5ab6747cfa5a08.diff
LOG: [mlir][linalg] Relax convolution vectorization to support mixed types
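Support the case where the convolution body applies a floating-point extension to its inputs before the multiplication (e.g. f16 input and filter accumulated into an f32 output).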
Differential Revision: https://reviews.llvm.org/D127925
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorize-convolution.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ebf20acab4171..794fe97ead57e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1374,10 +1374,29 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
maybeKind = getCombinerOpKind(reduceOp);
if (!maybeKind || *maybeKind != vector::CombiningKind::ADD)
return;
- maybeKind = getCombinerOpKind(&(linalgOp->getRegion(0).front().front()));
- if (!maybeKind || *maybeKind != vector::CombiningKind::MUL)
+ // Check for a single `mul` predecessor of the reduction. The `mul` operands
+ // must be block arguments or floating-point extensions (`arith.extf`) of
+ // block arguments.
+ Operation *mulOp = nullptr;
+ for (Value operand : reduceOp->getOperands()) {
+ if (operand.isa<BlockArgument>())
+ continue;
+ if (mulOp)
+ return;
+ mulOp = operand.getDefiningOp();
+ if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
+ return;
+ }
+ if (!mulOp)
return;
-
+ for (Value operand : mulOp->getOperands()) {
+ if (Operation *def = operand.getDefiningOp()) {
+ if (!isa<arith::ExtFOp>(def))
+ return;
+ operand = def->getOperand(0);
+ }
+ if (!operand.isa<BlockArgument>())
+ return;
+ }
// The op is now known to be valid.
valid = true;
}
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index a4eb9d26e9c8e..7e1f39cbda3e9 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -224,3 +224,29 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filt
// Write the result back in one shot.
// CHECK: vector.transfer_write %[[FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+
+// -----
+
+// Mixed-type 1-D convolution: the input and filter are f16 while the output
+// accumulator is f32.
+func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter: memref<1x3x2xf16>, %output: memref<1x2x2xf32>) {
+ linalg.conv_1d_nwc_wcf
+ {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
+ ins(%input, %filter : memref<1x2x3xf16>, memref<1x3x2xf16>)
+ outs(%output : memref<1x2x2xf32>)
+ return
+}
+
+// CHECK: func @conv_1d_nwc_wcf_mixed_type_memref
+// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<1x2x3xf16>, %[[FILTER:[0-9a-z]+]]: memref<1x3x2xf16>, %[[OUTPUT:[0-9a-z]+]]: memref<1x2x2xf32>)
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+/// Read the whole data in one shot.
+// CHECK: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x2xf16>
+
+/// The contraction consumes the f16 vectors directly and yields the f32
+/// result; verify its operands and types on the same line as the op.
+// CHECK: %[[CONT:.*]] = vector.contract
+// CHECK-SAME: %[[V_INPUT_R]], %[[V_FILTER_1]], %[[V_OUTPUT_R]] : vector<1x2x3xf16>, vector<3x2xf16> into vector<1x2x2xf32>
+
+/// Write the result back in one shot.
+// CHECK: vector.transfer_write %[[CONT]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
More information about the Mlir-commits
mailing list