[Mlir-commits] [mlir] 1e89a76 - [MLIR] Refactor to create vectorization convOp precondition check (#130181)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 17 06:32:49 PDT 2025


Author: Zhuoran Yin
Date: 2025-03-17T09:32:45-04:00
New Revision: 1e89a76a0490b6c55a3e46ecf967da3e30c9112b

URL: https://github.com/llvm/llvm-project/commit/1e89a76a0490b6c55a3e46ecf967da3e30c9112b
DIFF: https://github.com/llvm/llvm-project/commit/1e89a76a0490b6c55a3e46ecf967da3e30c9112b.diff

LOG: [MLIR] Refactor to create vectorization convOp precondition check (#130181)

In corner cases, the vectorization pass may be asked to lower a conv2d
op and then assert at a completely unrelated location inside the
vectorizeConvolution() subroutine.

~~This PR rejects the conv2d op early and makes the asserting routine
return failure as a defensive workaround.~~

To address this, the PR moves all the condition checks out of the
`Conv1dGenerator` constructor and into the new
`vectorizeConvOpPrecondition()` function. This rejects unsupported ops
such as conv2d early and leaves a cleaner `Conv1dGenerator` constructor.
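
For illustration, here is a minimal sketch of a case that still satisfies the new precondition and vectorizes: a 1-D NWC/WCF convolution, in contrast to the 2-D/3-D cases added to vectorization-unsupported.mlir below. The function name and shapes are illustrative and not taken from the patch.

```mlir
// A 1-D conv whose LHS/RES are rank-3 (NWC / WCF) passes the new precondition
// check and is vectorized, unlike the rejected conv2d/conv3d cases.
func.func @conv1d_nwc_wcf(%input: tensor<1x8x5xf32>, %filter: tensor<3x5x4xf32>,
                          %output: tensor<1x6x4xf32>) -> tensor<1x6x4xf32> {
  %0 = linalg.conv_1d_nwc_wcf {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
         ins(%input, %filter : tensor<1x8x5xf32>, tensor<3x5x4xf32>)
         outs(%output : tensor<1x6x4xf32>) -> tensor<1x6x4xf32>
  return %0 : tensor<1x6x4xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 : !transform.any_op
    transform.yield
  }
}
```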

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization-unsupported.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ae04c2b6b2a5b..2dcd897330d1e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1939,6 +1939,130 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
   return success();
 }
 
+namespace {
+enum class ConvOperationKind { Conv, Pool };
+} // namespace
+
+static bool isCastOfBlockArgument(Operation *op) {
+  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
+         isa<BlockArgument>(op->getOperand(0));
+}
+
+// Returns the ConvOperationKind of the op using reduceOp of the generic
+// payload. If it is neither a convolution nor a pooling, it returns
+// std::nullopt.
+//
+// If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
+// + yield) and rhs is not used) then it is the body of a pooling
+// If conv, check for single `mul` predecessor. The `mul` operands must be
+// block arguments or extension of block arguments.
+// Otherwise, check for one or zero `ext` predecessor. The `ext` operands
+// must be block arguments or extension of block arguments.
+static std::optional<ConvOperationKind>
+getConvOperationKind(Operation *reduceOp) {
+  int numBlockArguments =
+      llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
+
+  switch (numBlockArguments) {
+  case 1: {
+    // Will be convolution if feeder is a MulOp.
+    // A strength reduced version of MulOp for i1 type is AndOp which is also
+    // supported. Otherwise, it can be pooling. This strength reduction logic
+    // is in `buildBinaryFn` helper in the Linalg dialect.
+    auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
+                                       llvm::IsaPred<BlockArgument>);
+    assert(feedValIt != reduceOp->operand_end() &&
+           "Expected a non-block argument operand");
+    Operation *feedOp = (*feedValIt).getDefiningOp();
+    if (isCastOfBlockArgument(feedOp)) {
+      return ConvOperationKind::Pool;
+    }
+
+    if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
+           (isa<arith::AndIOp>(feedOp) &&
+            feedOp->getResultTypes()[0].isInteger(1))) &&
+          llvm::all_of(feedOp->getOperands(), [](Value v) {
+            if (isa<BlockArgument>(v))
+              return true;
+            if (Operation *op = v.getDefiningOp())
+              return isCastOfBlockArgument(op);
+            return false;
+          }))) {
+      return std::nullopt;
+    }
+
+    return ConvOperationKind::Conv;
+  }
+  case 2:
+    // Must be pooling
+    return ConvOperationKind::Pool;
+  default:
+    return std::nullopt;
+  }
+}
+
+static bool isSupportedPoolKind(vector::CombiningKind kind) {
+  switch (kind) {
+  case vector::CombiningKind::ADD:
+  case vector::CombiningKind::MAXNUMF:
+  case vector::CombiningKind::MAXIMUMF:
+  case vector::CombiningKind::MAXSI:
+  case vector::CombiningKind::MAXUI:
+  case vector::CombiningKind::MINNUMF:
+  case vector::CombiningKind::MINIMUMF:
+  case vector::CombiningKind::MINSI:
+  case vector::CombiningKind::MINUI:
+    return true;
+  default:
+    return false;
+  }
+}
+
+static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
+  auto getOperandType = [&](auto operand) {
+    return dyn_cast<ShapedType>((operand->get()).getType());
+  };
+  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
+  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
+  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
+  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
+  // (non-channeled convolution -> LHS and RHS both have single dimensions).
+  // Note that this also ensures 2D and 3D convolutions are rejected.
+  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
+      (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
+    return failure();
+
+  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
+  if (!reduceOp)
+    return failure();
+
+  auto maybeOper = getConvOperationKind(reduceOp);
+  if (!maybeOper.has_value())
+    return failure();
+
+  auto maybeKind = getCombinerOpKind(reduceOp);
+  // Typically convolution will have a `Add` CombiningKind but for i1 type it
+  // can get strength reduced to `OR` which is also supported. This strength
+  // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
+  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
+                      *maybeKind != vector::CombiningKind::OR) &&
+                     (*maybeOper != ConvOperationKind::Pool ||
+                      !isSupportedPoolKind(*maybeKind)))) {
+    return failure();
+  }
+
+  auto rhsRank = rhsShapedType.getRank();
+  if (*maybeOper == ConvOperationKind::Pool) {
+    if (rhsRank != 1)
+      return failure();
+  } else {
+    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
+      return failure();
+  }
+
+  return success();
+}
+
 static LogicalResult vectorizeLinalgOpPrecondition(
     LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
     bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
@@ -1991,7 +2115,8 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   // features. But we will still need stride/dilation attributes that will be
   // annoying to reverse-engineer...
   if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
-    return success();
+    return vectorizeConvOpPrecondition(linalgOp);
+
   // TODO: the common vector shape is equal to the static loop sizes only when
   // all indexing maps are projected permutations. For convs and stencils the
   // logic will need to evolve.
@@ -3067,28 +3192,6 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
 }
 
 namespace {
-bool isCastOfBlockArgument(Operation *op) {
-  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
-         isa<BlockArgument>(op->getOperand(0));
-}
-
-bool isSupportedPoolKind(vector::CombiningKind kind) {
-  switch (kind) {
-  case vector::CombiningKind::ADD:
-  case vector::CombiningKind::MAXNUMF:
-  case vector::CombiningKind::MAXIMUMF:
-  case vector::CombiningKind::MAXSI:
-  case vector::CombiningKind::MAXUI:
-  case vector::CombiningKind::MINNUMF:
-  case vector::CombiningKind::MINIMUMF:
-  case vector::CombiningKind::MINSI:
-  case vector::CombiningKind::MINUI:
-    return true;
-  default:
-    return false;
-  }
-}
-
 /// Generate a vector implementation for either:
 /// ```
 ///   Op def: (     w,     kw  )
@@ -3125,58 +3228,32 @@ bool isSupportedPoolKind(vector::CombiningKind kind) {
 /// kw is unrolled, w is unrolled iff dilationW > 1.
 struct Conv1DGenerator
     : public StructuredGenerator<LinalgOp, utils::IteratorType> {
-  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
-                  int dilationW)
-      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
-        strideW(strideW), dilationW(dilationW) {
-    // Determine whether `linalgOp` can be generated with this generator
-    if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
-      return;
+  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
+      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
+
     lhsShaped = linalgOp.getDpsInputOperand(0)->get();
     rhsShaped = linalgOp.getDpsInputOperand(1)->get();
     resShaped = linalgOp.getDpsInitOperand(0)->get();
     lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
     rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
     resShapedType = dyn_cast<ShapedType>(resShaped.getType());
-    if (!lhsShapedType || !rhsShapedType || !resShapedType)
-      return;
-    // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
-    // (non-channeled convolution -> LHS and RHS both have single dimensions).
-    if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
-        (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
-      return;
 
     Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
-    if (!reduceOp)
-      return;
     redOp = reduceOp->getName().getIdentifier();
 
-    if (!setOperKind(reduceOp))
-      return;
+    setConvOperationKind(reduceOp);
+
     auto maybeKind = getCombinerOpKind(reduceOp);
-    // Typically convolution will have a `Add` CombiningKind but for i1 type it
-    // can get strength reduced to `OR` which is also supported. This strength
-    // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
-    if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
-                        *maybeKind != vector::CombiningKind::OR) &&
-                       (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
-      return;
-    }
     reductionKind = maybeKind.value();
 
-    auto rhsRank = rhsShapedType.getRank();
-    switch (oper) {
-    case Conv:
-      if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
-        return;
-      break;
-    case Pool:
-      if (rhsRank != 1)
-        return;
-      break;
-    }
-    // The op is now known to be valid.
-    valid = true;
+    // The ConvolutionOpInterface gives us guarantees of existence for
+    // strides/dilations. However, we do not need to rely on those, we can
+    // simply use them if present, otherwise use the default and let the generic
+    // conv. matcher in the ConvGenerator succeed or fail.
+    auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
+    auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
+    strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
+    dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
   }
 
   /// Generate a vector implementation for:
@@ -3198,9 +3275,6 @@ struct Conv1DGenerator
   /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
   /// > 1.
   FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
-    if (!valid)
-      return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
-
     int64_t nSize, wSize, cSize, kwSize, fSize;
     SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
     bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
@@ -3222,11 +3296,11 @@ struct Conv1DGenerator
       // out{n, w, f}
       bindShapeDims(resShapedType, nSize, wSize, fSize);
       switch (oper) {
-      case Conv:
+      case ConvOperationKind::Conv:
         // kernel{kw, c, f}
         bindShapeDims(rhsShapedType, kwSize, cSize);
         break;
-      case Pool:
+      case ConvOperationKind::Pool:
         // kernel{kw}
         bindShapeDims(rhsShapedType, kwSize);
         cSize = fSize;
@@ -3240,10 +3314,10 @@ struct Conv1DGenerator
                       1,
                   cSize};
       switch (oper) {
-      case Conv:
+      case ConvOperationKind::Conv:
         rhsShape = {kwSize, cSize, fSize};
         break;
-      case Pool:
+      case ConvOperationKind::Pool:
         rhsShape = {kwSize};
         break;
       }
@@ -3253,11 +3327,11 @@ struct Conv1DGenerator
       // out{n, f, w}
       bindShapeDims(resShapedType, nSize, fSize, wSize);
       switch (oper) {
-      case Conv:
+      case ConvOperationKind::Conv:
         // kernel{f, c, kw}
         bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
         break;
-      case Pool:
+      case ConvOperationKind::Pool:
         // kernel{kw}
         bindShapeDims(rhsShapedType, kwSize);
         cSize = fSize;
@@ -3270,10 +3344,10 @@ struct Conv1DGenerator
                   ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                       1};
       switch (oper) {
-      case Conv:
+      case ConvOperationKind::Conv:
         rhsShape = {fSize, cSize, kwSize};
         break;
-      case Pool:
+      case ConvOperationKind::Pool:
         rhsShape = {kwSize};
         break;
       }
@@ -3305,7 +3379,7 @@ struct Conv1DGenerator
                                                         lhsPadding);
     // This is needed only for Conv.
     Value rhs = nullptr;
-    if (oper == Conv)
+    if (oper == ConvOperationKind::Conv)
       rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                     rhsPadding);
     Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
@@ -3328,7 +3402,7 @@ struct Conv1DGenerator
       static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
 
       // This is needed only for Conv.
-      if (oper == Conv)
+      if (oper == ConvOperationKind::Conv)
         rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
       // nfw -> nwf
       static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
@@ -3346,7 +3420,7 @@ struct Conv1DGenerator
                                      kwSize, strideW, dilationW, wSizeStep,
                                      isSingleChanneled);
     // Do not do for pooling.
-    if (oper == Conv)
+    if (oper == ConvOperationKind::Conv)
       rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
     resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
                                       wSizeStep, isSingleChanneled);
@@ -3361,7 +3435,7 @@ struct Conv1DGenerator
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         switch (oper) {
-        case Conv:
+        case ConvOperationKind::Conv:
           if (isSingleChanneled) {
             resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
                                                    lhsVals[linearIndex(kw, w)],
@@ -3372,7 +3446,7 @@ struct Conv1DGenerator
                                                   rhsVals[kw], resVals[w]);
           }
           break;
-        case Pool:
+        case ConvOperationKind::Pool:
           resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
                                    resVals[w]);
           break;
@@ -3483,9 +3557,6 @@ struct Conv1DGenerator
   FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
                                        bool channelDimScalableFlag,
                                        bool flatten) {
-    if (!valid)
-      return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
-
     bool scalableChDim = false;
     bool useMasking = false;
     int64_t nSize, wSize, cSize, kwSize;
@@ -3830,9 +3901,7 @@ struct Conv1DGenerator
   }
 
 private:
-  enum OperKind { Conv, Pool };
-  bool valid = false;
-  OperKind oper = Conv;
+  ConvOperationKind oper = ConvOperationKind::Conv;
   StringAttr redOp;
   StringAttr poolExtOp;
   bool isPoolExt = false;
@@ -3842,18 +3911,10 @@ struct Conv1DGenerator
   vector::CombiningKind reductionKind;
 
   // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
-  // Returns true iff it is a valid conv/pooling op.
-  // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
-  // + yield) and rhs is not used) then it is the body of a pooling
-  // If conv, check for single `mul` predecessor. The `mul` operands must be
-  // block arguments or extension of block arguments.
-  // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
-  // must be block arguments or extension of block arguments.
-  bool setOperKind(Operation *reduceOp) {
+  void setConvOperationKind(Operation *reduceOp) {
     int numBlockArguments =
         llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
-    switch (numBlockArguments) {
-    case 1: {
+    if (numBlockArguments == 1) {
       // Will be convolution if feeder is a MulOp.
       // A strength reduced version of MulOp for i1 type is AndOp which is also
       // supported. Otherwise, it can be pooling. This strength reduction logic
@@ -3862,31 +3923,17 @@ struct Conv1DGenerator
                                          llvm::IsaPred<BlockArgument>);
       Operation *feedOp = (*feedValIt).getDefiningOp();
       if (isCastOfBlockArgument(feedOp)) {
-        oper = Pool;
+        oper = ConvOperationKind::Pool;
         isPoolExt = true;
         poolExtOp = feedOp->getName().getIdentifier();
-      } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
-                    (isa<arith::AndIOp>(feedOp) &&
-                     feedOp->getResultTypes()[0].isInteger(1))) &&
-                   llvm::all_of(feedOp->getOperands(), [](Value v) {
-                     if (isa<BlockArgument>(v))
-                       return true;
-                     if (Operation *op = v.getDefiningOp())
-                       return isCastOfBlockArgument(op);
-                     return false;
-                   }))) {
-        return false;
+        return;
       }
-      return true;
-    }
-    case 2:
-      // Must be pooling
-      oper = Pool;
-      isPoolExt = false;
-      return true;
-    default:
-      return false;
+      oper = ConvOperationKind::Conv;
+      return;
     }
+    // numBlockArguments == 2 and this is a pooling op.
+    oper = ConvOperationKind::Pool;
+    isPoolExt = false;
   }
 };
 } // namespace
@@ -3896,28 +3943,20 @@ struct Conv1DGenerator
 static FailureOr<Operation *> vectorizeConvolution(
     RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
     ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
-  // The ConvolutionOpInterface gives us guarantees of existence for
-  // strides/dilations. However, we do not need to rely on those, we can
-  // simply use them if present, otherwise use the default and let the generic
-  // conv. matcher in the ConvGenerator succeed or fail.
-  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
-  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
-  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
-  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
-  Conv1DGenerator e(rewriter, op, stride, dilation);
-  auto res = e.generateNonChanneledConv();
+  Conv1DGenerator conv1dGen(rewriter, op);
+  auto res = conv1dGen.generateNonChanneledConv();
   if (succeeded(res))
     return res;
-  res = e.generateNwcConv();
+  res = conv1dGen.generateNwcConv();
   if (succeeded(res))
     return res;
-  res = e.generateNcwConv();
+  res = conv1dGen.generateNcwConv();
   if (succeeded(res))
     return res;
-  res = e.generateNwcPooling();
+  res = conv1dGen.generateNwcPooling();
   if (succeeded(res))
     return res;
-  res = e.generateNcwPooling();
+  res = conv1dGen.generateNcwPooling();
   if (succeeded(res))
     return res;
 
@@ -3940,8 +3979,8 @@ static FailureOr<Operation *> vectorizeConvolution(
     vecChDimSize = inputVecSizes[chDimIdx];
     vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
   }
-  return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
-                               flatten1DDepthwiseConv);
+  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
+                                       flatten1DDepthwiseConv);
 }
 
 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {

diff  --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index 8f3b199145ce0..2d1f0191eb798 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -112,6 +112,55 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @conv2d_nchw_fchw(%input: tensor<1x5x8x8xf32>, %filter:  tensor<4x5x3x3xf32>, %output: tensor<1x4x6x6xf32>) {
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%input, %filter : tensor<1x5x8x8xf32>, tensor<4x5x3x3xf32>) outs(%output : tensor<1x4x6x6xf32>) -> tensor<1x4x6x6xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @conv2d_nhwc_fhwc(%input: tensor<1x8x8x5xf32>, %filter: tensor<4x3x3x5xf32>, %output: tensor<1x6x6x4xf32>) {
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%input, %filter : tensor<1x8x8x5xf32>, tensor<4x3x3x5xf32>) outs(%output : tensor<1x6x6x4xf32>) -> tensor<1x6x6x4xf32>
+  return
+}
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @conv3d_ncdhw_fcdhw(%input: tensor<1x5x8x8x8xf32>, %filter: tensor<4x5x3x3x3xf32>, %output: tensor<1x4x6x6x6xf32>) {
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  linalg.conv_3d_ncdhw_fcdhw {dilations = dense<1> : vector<3xi64>, strides = dense<1> : vector<3xi64>} ins(%input, %filter : tensor<1x5x8x8x8xf32>, tensor<4x5x3x3x3xf32>) outs(%output : tensor<1x4x6x6x6xf32>) -> tensor<1x4x6x6x6xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_3d_ncdhw_fcdhw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @test_pack_no_vectorize_dynamic_shape(%arg0: tensor<?xf32>, %arg1: tensor<4x16xf32>) -> tensor<4x16xf32> {
   %pad = arith.constant 0.000000e+00 : f32
   // expected-error @+1 {{Attempted to vectorize, but failed}}


        


More information about the Mlir-commits mailing list