[Mlir-commits] [mlir] [mlir][linalg] Introduce new `linalg.conv` op (PR #117688)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 26 01:28:10 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Felix Schneider (ubfx)

<details>
<summary>Changes</summary>

This patch lays the groundwork for the new `linalg.conv` op, which is designed to replace the multitude of `linalg.conv_...` and `linalg.depthwise_conv_...` ops.

A test pass is implemented that converts the old conv ops to the new op. The `linalg-generalize-named-ops` pass can then be used to convert both the old and the new ops to `linalg.generic` ops for comparison.

---

Patch is 71.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/117688.diff


10 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td (+39) 
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td (+26) 
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h (+27) 
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+116) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+246) 
- (added) mlir/test/Dialect/Linalg/generalize-new-conv.mlir (+656) 
- (modified) mlir/test/Dialect/Linalg/roundtrip.mlir (+57) 
- (modified) mlir/test/lib/Dialect/Linalg/CMakeLists.txt (+2-1) 
- (added) mlir/test/lib/Dialect/Linalg/TestNewConv.cpp (+187) 
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 73f984dc072d31..b659241b5ed5b7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -81,4 +81,43 @@ def IteratorTypeEnum : EnumAttr<Linalg_Dialect, IteratorType, "iterator_type"> {
 def IteratorTypeArrayAttr : TypedArrayAttrBase<IteratorTypeEnum,
   "Iterator type should be an enum.">;
 
+
+/// Array parameter holding the `ConvDimEnum` values of one convolution tensor
+/// layout. In the IR it is written in braces, e.g. `{N, C, "1", "0"}`.
+def ConvolutionDimArray : ArrayRefParameter<"ConvDimEnum"> {
+  // Prints `{`, then each element via `printStrippedAttrOrType` separated by
+  // `llvm::interleaveComma`, then `}`.
+  let printer = [{
+    $_printer << '{';
+    llvm::interleaveComma($_self, $_printer, [&](ConvDimEnum en) {
+        $_printer.printStrippedAttrOrType(en);
+    });
+    $_printer << '}';
+  }];
+
+  // Mirrors the printer: expects `{`, parses the element list with
+  // `FieldParser<SmallVector<ConvDimEnum>>` (presumably comma-separated to
+  // match the printer -- confirm), then expects `}`. Any failure is propagated.
+  let parser = [{
+    [&]() -> FailureOr<SmallVector<ConvDimEnum>> {
+        using Result = SmallVector<ConvDimEnum>;
+        if ($_parser.parseLBrace())
+            return failure();
+        FailureOr<Result> result = FieldParser<Result>::parse($_parser);
+        if (failed(result))
+            return failure();
+        if ($_parser.parseRBrace())
+            return failure();
+        return result;
+    }()
+  }];
+}
+
+/// Attribute that represents an ordered set of tensor dimensions involved in
+/// convolution, e.g. `#linalg<conv_dims {N, C, "1", "0"}>`.
+///
+/// Op accessors return it as `mlir::linalg::ConvDims`, a non-owning view over
+/// the attribute's stored `dims` array (see `convertFromStorage`).
+def ConvDimsAttr : AttrDef<Linalg_Dialect, "ConvDims", [], "::mlir::Attribute"> {
+  let mnemonic = "conv_dims";
+
+  let parameters = (ins
+    ConvolutionDimArray:$dims
+  );
+
+  let assemblyFormat = "$dims";
+
+  let returnType = "mlir::linalg::ConvDims";
+  let convertFromStorage = "mlir::linalg::ConvDims($_self.getDims())";
+}
 #endif // LINALG_BASE
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index e615876a95d057..ef9e00822fbe3b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -63,4 +63,30 @@ def TypeFn : I32EnumAttr<"TypeFn", "", [
   let cppNamespace = "::mlir::linalg";
 }
 
+
+/// One case of `ConvDimEnum`, stored as an I8: `sym` is the C++ enumerator
+/// name, `val` its integer value and `str` its spelling in the IR (defaults
+/// to `sym`).
+class ConvDimEnumAttrCase<string sym, int val, string str = sym>
+    : IntEnumAttrCaseBase<I8, sym, str, val>;
+
+/// Kinds of tensor dimension that may appear in the `input_dims`,
+/// `filter_dims` and `output_dims` layouts of `linalg.conv`. The third
+/// argument of each case is its spelling in the IR.
+def ConvDimEnumAttr :
+    IntEnumAttr<I8, "ConvDimEnum", "convolution tensor dimension kind", [
+      /// Batch is a dimension of input and output, indexed from a parallel loop.
+      ConvDimEnumAttrCase<"BATCH", 0, "N">,
+      /// Input channel is a dimension of input and filter, indexed from a
+      /// reduction loop. It never appears in the output, which carries the
+      /// output channel instead. Depthwise convolutions perform no reduction
+      /// across channels and therefore do not use this.
+      ConvDimEnumAttrCase<"INPUT_CHANNEL", 1, "C">,
+      /// Output channel is a dimension of filter and output, indexed from a
+      /// parallel loop.
+      ConvDimEnumAttrCase<"OUTPUT_CHANNEL", 2, "F">,
+      /// Group is a dimension in all tensors and indexed from a parallel loop.
+      ConvDimEnumAttrCase<"GROUP", 3, "G">,
+      /// Spatial dimensions occur in all tensors. Output is indexed from a parallel
+      /// loop, filter from a reduction loop and input from both.
+      ConvDimEnumAttrCase<"SPATIAL_0", 4, "0">,
+      ConvDimEnumAttrCase<"SPATIAL_1", 5, "1">,
+      ConvDimEnumAttrCase<"SPATIAL_2", 6, "2">,
+    ]> {
+  let underlyingType = "uint8_t";
+  let cppNamespace = "::mlir::linalg";
+}
+
 #endif // LINALG_ENUMS
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 6f1c243cc4396d..752fcd8affaa27 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -117,6 +117,33 @@ FailureOr<ConvolutionDimensions> inferConvolutionDims(LinalgOp linalgOp);
 bool isaConvolutionOpInterface(LinalgOp linalgOp,
                                bool allowEmptyConvolvedDims = false);
 
+enum class ConvDimEnum : uint8_t;
+
+/// Ordered, non-owning view of the `ConvDimEnum` values describing one tensor
+/// layout of a `linalg.conv` op. The viewed array must outlive this object.
+class ConvDims {
+  ArrayRef<ConvDimEnum> layout;
+
+public:
+  ConvDims() = default;
+  ConvDims(ArrayRef<ConvDimEnum> dims) : layout(dims) {}
+  ConvDims(SmallVectorImpl<ConvDimEnum> &dims) : layout(dims) {}
+
+  /// Returns true if `dim` occurs anywhere in the layout.
+  bool contains(ConvDimEnum dim) const {
+    return llvm::find(layout, dim) != layout.end();
+  }
+
+  /// Returns the index of the first occurrence of `dim`; `dim` must be
+  /// present (asserted).
+  int64_t getPos(ConvDimEnum dim) const {
+    const auto *pos = llvm::find(layout, dim);
+    assert(pos != layout.end() && "expected dimension to be present");
+    return pos - layout.begin();
+  }
+
+  /// Number of dimensions in the layout.
+  int64_t size() const { return layout.size(); }
+
+  /// Implicit view of the underlying dimension list.
+  operator ArrayRef<ConvDimEnum>() const { return layout; }
+
+  auto begin() const { return layout.begin(); }
+  auto end() const { return layout.end(); }
+};
+
 /// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
 bool isaCopyOpInterface(LinalgOp linalgOp);
 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 37eec6e07963b1..09b2dfd75cf67e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -683,6 +683,122 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     }];
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for ConvOp
+//===----------------------------------------------------------------------===//
+
+def ConvOp : LinalgStructuredBase_Op<"conv", [AttrSizedOperandSegments]> {
+
+  let summary = [{
+    Convolution operation with configurable tensor layouts.
+  }];
+  let description = [{
+    Numeric casting is performed on the operands to the inner multiply,
+    promoting them to the same data type as the accumulator/output.
+
+    The subtype of convolution is defined by the tensor layouts of `input`,
+    `filter`, and `output`. For example, a standard batched 2D convolution:
+
+    ```
+      %0 = linalg.conv {
+          input_dims = #linalg<conv_dims {N, C, "1", "0"}>,
+          filter_dims = #linalg<conv_dims {F, C, "1", "0"}>,
+          output_dims = #linalg<conv_dims {N, F, "1", "0"}>
+        }
+        ins(%input, %filter : tensor<8x4x16x16xf32>, tensor<16x4x3x3xf32>)
+        outs(%output : tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32>
+    ```
+
+    A depthwise convolution is expressed by using a group dimension (`G`) in
+    place of the channel dimensions and omitting `C` from the filter:
+    ```
+      %0 = linalg.conv {
+          input_dims = #linalg<conv_dims {N, G, "1", "0"}>,
+          filter_dims = #linalg<conv_dims {G, "1", "0"}>,
+          output_dims = #linalg<conv_dims {N, G, "1", "0"}>
+        }
+        ins(%input, %filter : tensor<8x4x16x16xf32>, tensor<4x3x3xf32>)
+        outs(%output : tensor<8x4x14x14xf32>) -> tensor<8x4x14x14xf32>
+    ```
+
+    For the detailed semantics of the available tensor dimensions, refer to
+    `mlir::linalg::ConvDimEnum`.
+
+    Strides and dilations can be supplied as optional attributes, where
+    `strides[0]` is the stride for the `SPATIAL_0` dimension, etc. When an
+    attribute is absent, a value of 1 is used for every spatial dimension.
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs, Variadic<AnyShaped>:$outputs,
+    ConvDimsAttr:$input_dims, ConvDimsAttr:$filter_dims, ConvDimsAttr:$output_dims,
+    OptionalAttr<I64ElementsAttr>:$strides, OptionalAttr<I64ElementsAttr>:$dilations
+  );
+  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+  let regions = (region AnyRegion:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<
+      (ins "TypeRange":$resTys, "Value":$input, "Value":$filter, "Value":$output, "ConvDims":$input_dims,
+            "ConvDims":$filter_dims, "ConvDims":$output_dims, "ArrayRef<int64_t>":$strides,
+            "ArrayRef<int64_t>":$dilations, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildConvOp($_builder, $_state, resTys, input, filter, output,
+            input_dims, filter_dims, output_dims, strides, dilations,
+            attributes, ConvOp::getRegionBuilder());
+      }]>,
+    OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs, "ConvDimsAttr":$input_dims,
+            "ConvDimsAttr":$filter_dims, "ConvDimsAttr":$output_dims,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildConvOp($_builder, $_state, std::nullopt, inputs, outputs,
+            input_dims, filter_dims, output_dims, nullptr, nullptr,
+            attributes, ConvOp::getRegionBuilder());
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$outputs, "ConvDimsAttr":$input_dims,
+            "ConvDimsAttr":$filter_dims, "ConvDimsAttr":$output_dims,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildConvOp($_builder, $_state, resultTensorTypes,
+            inputs, outputs, input_dims, filter_dims, output_dims, nullptr, nullptr,
+            attributes, ConvOp::getRegionBuilder());
+      }]>
+  ];
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    /// Returns one `parallel` iterator per output dimension, followed by the
+    /// `reduction` iterators: one per spatial dimension plus one for the
+    /// input channel unless the convolution is depthwise.
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+    /// Returns the input, filter and output indexing maps derived from the
+    /// dimension attributes, strides and dilations. The result is memoized
+    /// as an attribute on the operation.
+    ArrayAttr getIndexingMaps();
+
+    /// Implements the block region builder.
+    static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+
+    /// Returns a default list of indexing maps.
+    /// NOTE(review): the original comment here was copied from MatmulOp
+    /// ("typical matmul indexing characteristic"); confirm what the ConvOp
+    /// definition actually returns.
+    static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+    static std::function<void(ImplicitLocOpBuilder &,
+                              Block &, ArrayRef<NamedAttribute>)>
+    getRegionBuilder() { return regionBuilder; }
+
+    ::mlir::MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
+
+    bool hasDynamicIndexingMaps() { return true; }
+
+    /// Returns the number of spatial dimensions, i.e. 1 for 1D convolution,
+    /// 2 for 2D convolution, etc.
+    int64_t getNumSpatialDims();
+
+    /// Returns true if the filter layout has no INPUT_CHANNEL dimension.
+    bool isDepthwise();
+    /// Returns true if input, filter and output layouts all contain GROUP.
+    bool isGrouped();
+    /// Returns true if both input and output layouts contain BATCH.
+    bool isBatched();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8973e87c063b33..03d9a7f3f09ce3 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -203,6 +203,41 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
                            attributes, regionBuilder);
 }
 
+/// Records the convolution layout attributes on `state` and then delegates to
+/// `buildStructuredOp`. The three `*_dims` attributes are always attached;
+/// `strides` and `dilations` are attached only when non-null.
+static void buildConvOp(OpBuilder &b, OperationState &state,
+                        std::optional<TypeRange> resultTensorTypes,
+                        ValueRange inputs, ValueRange outputs,
+                        ConvDimsAttr inputDims, ConvDimsAttr filterDims,
+                        ConvDimsAttr outputDims, Attribute strides,
+                        Attribute dilations,
+                        ArrayRef<NamedAttribute> attributes,
+                        RegionBuilderFn regionBuilder) {
+  state.addAttribute("input_dims", inputDims);
+  state.addAttribute("filter_dims", filterDims);
+  state.addAttribute("output_dims", outputDims);
+
+  // Optional attributes: a null Attribute means "not specified".
+  auto addIfPresent = [&state](StringRef name, Attribute attr) {
+    if (attr)
+      state.addAttribute(name, attr);
+  };
+  addIfPresent("strides", strides);
+  addIfPresent("dilations", dilations);
+
+  buildStructuredOp(b, state, resultTensorTypes, inputs, outputs, attributes,
+                    regionBuilder);
+}
+
+/// Convenience overload taking a single input, filter and output value, the
+/// layouts as `ConvDims` views and strides/dilations as plain integer arrays.
+///
+/// An empty `strides` or `dilations` array means "not specified". No
+/// attribute is emitted in that case, so the op falls back to a value of 1
+/// for every spatial dimension. Emitting an empty `I64ElementsAttr` instead
+/// would make the attribute present with zero values, and
+/// `ConvOp::getIndexingMaps` would then read past its end.
+static void buildConvOp(OpBuilder &b, OperationState &state,
+                        std::optional<TypeRange> resultTensorTypes, Value input,
+                        Value filter, Value output, ConvDims inputDims,
+                        ConvDims filterDims, ConvDims outputDims,
+                        ArrayRef<int64_t> strides, ArrayRef<int64_t> dilations,
+                        ArrayRef<NamedAttribute> attributes,
+                        RegionBuilderFn regionBuilder) {
+  auto iAttr = ConvDimsAttr::get(b.getContext(), inputDims);
+  auto fAttr = ConvDimsAttr::get(b.getContext(), filterDims);
+  auto oAttr = ConvDimsAttr::get(b.getContext(), outputDims);
+
+  // Null attributes are skipped by the attribute-based overload.
+  Attribute stridesAttr;
+  if (!strides.empty())
+    stridesAttr = b.getI64VectorAttr(strides);
+  Attribute dilationsAttr;
+  if (!dilations.empty())
+    dilationsAttr = b.getI64VectorAttr(dilations);
+
+  return buildConvOp(b, state, resultTensorTypes, {input, filter}, {output},
+                     iAttr, fAttr, oAttr, stridesAttr, dilationsAttr,
+                     attributes, regionBuilder);
+}
+
 /// Common parsing used for both named structured ops created by ods-gen and by
 /// manually defined C++ ops. Does not handle regions.
 static ParseResult
@@ -3611,5 +3646,216 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+//===----------------------------------------------------------------------===//
+// ConvOp
+//===----------------------------------------------------------------------===//
+
+/// A convolution is depthwise when it never reduces across input channels,
+/// i.e. when the filter layout carries no INPUT_CHANNEL dimension.
+bool ConvOp::isDepthwise() {
+  ConvDims filterLayout = getFilterDims();
+  bool reducesOverChannels = filterLayout.contains(ConvDimEnum::INPUT_CHANNEL);
+  return !reducesOverChannels;
+}
+
+/// Returns true only if input, filter and output layouts all contain GROUP.
+/// If any of them lacks it, the op is either not grouped or has a single
+/// group, and neither case is treated as grouped.
+bool ConvOp::isGrouped() {
+  auto hasGroup = [](ConvDims layout) {
+    return layout.contains(ConvDimEnum::GROUP);
+  };
+  return hasGroup(getInputDims()) && hasGroup(getFilterDims()) &&
+         hasGroup(getOutputDims());
+}
+
+/// Returns true when both the input and the output layouts contain BATCH.
+/// The filter is not consulted; the verifier rejects BATCH in the filter.
+bool ConvOp::isBatched() {
+  if (!getInputDims().contains(ConvDimEnum::BATCH))
+    return false;
+  return getOutputDims().contains(ConvDimEnum::BATCH);
+}
+
+/// Derives the spatial rank from the highest spatial dimension present in the
+/// input layout. Returns 3 if SPATIAL_2 is present, 2 if SPATIAL_1 is, and 1
+/// otherwise.
+int64_t ConvOp::getNumSpatialDims() {
+  ConvDims inputLayout = getInputDims();
+  if (inputLayout.contains(ConvDimEnum::SPATIAL_2))
+    return 3;
+  return inputLayout.contains(ConvDimEnum::SPATIAL_1) ? 2 : 1;
+}
+
+/// Parallel loops come first, one per output dimension. Reduction loops
+/// follow: one per spatial window dimension, plus one for the input channel
+/// when the convolution is not depthwise.
+SmallVector<utils::IteratorType> ConvOp::getIteratorTypesArray() {
+  const int64_t parallelCount = getOutputDims().size();
+  const int64_t reductionCount = getNumSpatialDims() + (isDepthwise() ? 0 : 1);
+
+  SmallVector<utils::IteratorType> loops;
+  loops.reserve(parallelCount + reductionCount);
+  loops.append(parallelCount, utils::IteratorType::parallel);
+  loops.append(reductionCount, utils::IteratorType::reduction);
+  return loops;
+}
+
+/// Builds the (input, filter, output) indexing maps from the layout
+/// attributes, strides and dilations. The result is memoized under
+/// `LinalgDialect::kMemoizedIndexingMapsAttrName` and returned directly on
+/// later calls.
+///
+/// Loop numbering: d0..d(k-1) are the parallel loops, one per output
+/// dimension in output-layout order. They are followed by the reduction
+/// loops for INPUT_CHANNEL and the spatial dimensions, in filter-layout
+/// order.
+ArrayAttr ConvOp::getIndexingMaps() {
+  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(
+      LinalgDialect::kMemoizedIndexingMapsAttrName);
+  if (cached)
+    return cached;
+
+  Builder b(getContext());
+  SmallVector<AffineExpr> strides, dilations;
+  {
+    // Absent strides/dilations default to 1 for every spatial dimension.
+    // NOTE(review): if a present attribute holds fewer than
+    // getNumSpatialDims() values, the loop below indexes past its end;
+    // confirm the verifier rejects that.
+    SmallVector<int64_t> strideValues, dilationValues;
+
+    if (getStrides())
+      strideValues = SmallVector<int64_t>(getStrides()->getValues<int64_t>());
+    else
+      strideValues = SmallVector<int64_t>(getNumSpatialDims(), 1);
+
+    if (getDilations())
+      dilationValues =
+          SmallVector<int64_t>(getDilations()->getValues<int64_t>());
+    else
+      dilationValues = SmallVector<int64_t>(getNumSpatialDims(), 1);
+
+    for (int j = 0; j < getNumSpatialDims(); ++j) {
+      strides.push_back(b.getAffineConstantExpr(strideValues[j]));
+      dilations.push_back(b.getAffineConstantExpr(dilationValues[j]));
+    }
+  }
+
+  // Loop dimension assigned to each ConvDimEnum, split by iterator kind.
+  llvm::DenseMap<ConvDimEnum, AffineExpr> parallelDims;
+  llvm::DenseMap<ConvDimEnum, AffineExpr> reductionDims;
+  SmallVector<AffineExpr> oExprs;
+
+  // Via the iterator types, we have defined the parallel loops to come first,
+  // followed by the reduction loops. We choose the order of the parallel loops
+  // to match the order of the output tensor dimensions. This is arbitrary and
+  // is done to follow the convention which most/some of the old linalg
+  // convolution ops follow.
+  int64_t i = 0;
+  for (auto d : getOutputDims()) {
+    auto expr = b.getAffineDimExpr(i++);
+    parallelDims[d] = expr;
+    oExprs.push_back(expr);
+  }
+  // Reduction loops are ordered to match the order of the filter tensor.
+  for (auto d : getFilterDims())
+    if (d == ConvDimEnum::INPUT_CHANNEL || d == ConvDimEnum::SPATIAL_0 ||
+        d == ConvDimEnum::SPATIAL_1 || d == ConvDimEnum::SPATIAL_2)
+      reductionDims[d] = b.getAffineDimExpr(i++);
+
+  // Input map: a spatial dim k reads `out_k * stride_k + window_k *
+  // dilation_k`, INPUT_CHANNEL reads its reduction loop, and every other dim
+  // reads the output's parallel loop for that dim. Note that
+  // DenseMap::operator[] yields (and inserts) a null AffineExpr for a
+  // missing key; the verifier is presumably expected to rule that out.
+  SmallVector<AffineExpr> iExprs =
+      llvm::map_to_vector(getInputDims(), [&](ConvDimEnum dim) -> AffineExpr {
+        switch (dim) {
+        case ConvDimEnum::SPATIAL_0:
+          return (parallelDims[dim] * strides[0]) +
+                 (reductionDims[dim] * dilations[0]);
+        case ConvDimEnum::SPATIAL_1:
+          return (parallelDims[dim] * strides[1]) +
+                 (reductionDims[dim] * dilations[1]);
+        case ConvDimEnum::SPATIAL_2:
+          return (parallelDims[dim] * strides[2]) +
+                 (reductionDims[dim] * dilations[2]);
+        case ConvDimEnum::INPUT_CHANNEL:
+          return reductionDims[dim];
+        default:
+          return parallelDims[dim];
+        }
+      });
+  // Filter map: reduction loop if the dim has one, otherwise the parallel
+  // loop shared with the output (e.g. OUTPUT_CHANNEL, GROUP).
+  SmallVector<AffineExpr> fExprs =
+      llvm::map_to_vector(getFilterDims(), [&](ConvDimEnum dim) -> AffineExpr {
+        if (reductionDims.contains(dim))
+          return reductionDims[dim];
+        return parallelDims[dim];
+      });
+
+  // Memoize so subsequent calls skip the reconstruction above.
+  cached = b.getAffineMapArrayAttr(
+      {AffineMap::get(getNumLoops(), 0, iExprs, getContext()),
+       AffineMap::get(getNumLoops(), 0, fExprs, getContext()),
+       AffineMap::get(getNumLoops(), 0, oExprs, getContext())});
+  getOperation()->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
+  return cached;
+}
+
+/// Builds the multiply-accumulate body `acc + cast(in) * cast(filter)`,
+/// where both operands are cast to the accumulator's type. The cast function
+/// defaults to `cast_signed` and is overridden by the first attribute named
+/// "cast" when that attribute is a `TypeFnAttr`.
+void ConvOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                           ArrayRef<NamedAttribute> attrs) {
+  RegionBuilderHelper helper(b, block);
+
+  auto castEntry = llvm::find_if(attrs, [](const NamedAttribute &named) {
+    return named.getName() == "cast";
+  });
+  TypeFnAttr castOverride =
+      castEntry == attrs.end()
+          ? TypeFnAttr()
+          : llvm::dyn_cast<TypeFnAttr>(castEntry->getValue());
+  TypeFn castKind = TypeFn::cast_signed;
+  if (castOverride)
+    castKind = castOverride.getValue();
+
+  // Block arguments: 0 = input element, 1 = filter element, 2 = accumulator.
+  Value acc = block.getArgument(2);
+  Type accType = acc.getType();
+  Value castInput = helper.buildTypeFn(castKind, accType, block.getArgument(0));
+  Value castFilter =
+      helper.buildTypeFn(castKind, accType, block.getArgument(1));
+  Value product = helper.buildBinaryFn(BinaryFn::mul, castInput, castFilter);
+  Value sum = helper.buildBinaryFn(BinaryFn::add, acc, product);
+
+  SmallVector<Value> results{sum};
+  helper.yieldOutputs(results);
+}
+
+/// Parses the custom assembly via the shared named-structured-op parser.
+ParseResult ConvOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Three region arguments; presumably input, filter and output elements,
+  // matching regionBuilder's block arguments -- confirm.
+  constexpr unsigned kNumRegionArgs = 3;
+  return ::parseNamedStructuredOp(parser, result, kNumRegionArgs,
+                                  ConvOp::getRegionBuilder());
+}
+/// Prints via the shared named-structured-op printer. Operand segment sizes
+/// and the memoized indexing maps are implementation details and are elided.
+void ConvOp::print(OpAsmPrinter &p) {
+  SmallVector<StringRef, 3> hiddenAttrs;
+  hiddenAttrs.push_back("operandSegmentSizes");
+  hiddenAttrs.push_back("linalg.memoized_indexing_maps");
+  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                           hiddenAttrs);
+}
+
+LogicalResult ConvOp::verify() {
+  // Batch dimension cannot be present in filter tensor.
+  if (getFilterDims().contains(ConvDimEnum::BATCH))
+    return emitOpError("Batch dimension cannot be present in filter tensor.");
+
+  // Output channel cannot be present in input tensor.
+  if (getInputDims().contains(ConvDimEnum::OUTPUT_CHANNEL))
+    return emitOpError("Output channel cannot be present in input tensor.");
+
+  // Higher space dimensions cannot occur without the respective lower ones, so
+  // as to work with the `strides` and `dilations` attributes.
+  bool isSpat2 = getInputDims().contains(ConvDimEnum::SPATIAL_2);
+  bool isSpat1 = getInputDims().contains(ConvDimEnum::SPATIAL_1);
+  bool isSpat0 = getInputDims().contains(ConvDimEnum::SPATIAL_0);
+
+  if ((isSpat2 && (!isSpat1 || !isSpat0)) || (isSpat1 && !isSpat0))
+    return emitOpError("Inconsistent spatial dimensions in `input_dims`.");
+
+  if (!isSpat0)
+    return emitOpError("Requires at least one spatial dimension.");
+
+  // Spatial dimensions have to match between all tensors.
+  if (isSpat2 != getFilterDims().contains(ConvDimEnum::SPATIAL_2) ||
+      isSpat2 != getOutputDims().contains(ConvDimEnum::SPATIAL_2) ||
+      isSpat1 != getFilterDims().contains(ConvDimEnum::SPATIAL_1) ||
+      isSpat1 != getOutputDims().contains(ConvDimEnum::SPATIAL_1) ||
+      isSpat0 != getFilterDims().contains(ConvDimEnum::SPATIAL_0) ||
+      isSpat0 != getOutputDims().contains(ConvDimEnum::SPATIAL...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/117688


More information about the Mlir-commits mailing list