[Mlir-commits] [mlir] 046ebeb - [mlir][linalg] Relax convolution vectorization to support mixed types

Thomas Raoux llvmlistbot at llvm.org
Thu Jun 16 09:36:02 PDT 2022


Author: Thomas Raoux
Date: 2022-06-16T16:29:46Z
New Revision: 046ebeb60504abfb053c24c1ee5ab6747cfa5a08

URL: https://github.com/llvm/llvm-project/commit/046ebeb60504abfb053c24c1ee5ab6747cfa5a08
DIFF: https://github.com/llvm/llvm-project/commit/046ebeb60504abfb053c24c1ee5ab6747cfa5a08.diff

LOG: [mlir][linalg] Relax convolution vectorization to support mixed types

Support the case where convolution does float extension of the inputs.

Differential Revision: https://reviews.llvm.org/D127925

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorize-convolution.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ebf20acab4171..794fe97ead57e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1374,10 +1374,29 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
     maybeKind = getCombinerOpKind(reduceOp);
     if (!maybeKind || *maybeKind != vector::CombiningKind::ADD)
       return;
-    maybeKind = getCombinerOpKind(&(linalgOp->getRegion(0).front().front()));
-    if (!maybeKind || *maybeKind != vector::CombiningKind::MUL)
+    // Check for a single `mul` predecessor. The `mul` operands must be block
+    // arguments or extensions of block arguments.
+    Operation *mulOp = nullptr;
+    for (Value operand : reduceOp->getOperands()) {
+      if (operand.isa<BlockArgument>())
+        continue;
+      if (mulOp)
+        return;
+      mulOp = operand.getDefiningOp();
+      if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
+        return;
+    }
+    if (!mulOp)
       return;
-
+    for (Value operand : mulOp->getOperands()) {
+      if (Operation *def = operand.getDefiningOp()) {
+        if (!isa<arith::ExtFOp>(def))
+          return;
+        operand = def->getOperand(0);
+      }
+      if (!operand.isa<BlockArgument>())
+        return;
+    }
     // The op is now known to be valid.
     valid = true;
   }

diff  --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index a4eb9d26e9c8e..7e1f39cbda3e9 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -224,3 +224,29 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filt
 
 // Write the result back in one shot.
 //      CHECK:   vector.transfer_write %[[FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+
+// -----
+
+func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter: memref<1x3x2xf16>, %output: memref<1x2x2xf32>) {
+  linalg.conv_1d_nwc_wcf
+  {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
+   ins(%input, %filter : memref<1x2x3xf16>, memref<1x3x2xf16>)
+   outs(%output : memref<1x2x2xf32>)
+  return
+}
+
+//       CHECK: func @conv_1d_nwc_wcf_mixed_type_memref
+//  CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<1x2x3xf16>, %[[FILTER:[0-9a-z]+]]: memref<1x3x2xf16>, %[[OUTPUT:[0-9a-z]+]]: memref<1x2x2xf32>)
+
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+/// Read the whole data in one shot.
+//      CHECK:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
+//      CHECK:   %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]]
+//      CHECK:   %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+//      CHECK:   %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x2xf16>
+//      CHECK:   %[[CONT:.*]] = vector.contract
+//  CHECK-SAME:   %[[V_INPUT_R]], %[[V_FILTER_1]], %[[V_OUTPUT_R]] : vector<1x2x3xf16>, vector<3x2xf16> into vector<1x2x2xf32>
+//      CHECK:   vector.transfer_write %[[CONT]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]


        


More information about the Mlir-commits mailing list