[Mlir-commits] [mlir] e444dce - [mlir][Transform] Add classify_convolution_dims op

Quinn Dawkins llvmlistbot at llvm.org
Mon Jul 31 13:18:06 PDT 2023


Author: Quinn Dawkins
Date: 2023-07-31T16:15:25-04:00
New Revision: e444dcef7cadb29490deca726cb743e38d0412b3

URL: https://github.com/llvm/llvm-project/commit/e444dcef7cadb29490deca726cb743e38d0412b3
DIFF: https://github.com/llvm/llvm-project/commit/e444dcef7cadb29490deca726cb743e38d0412b3.diff

LOG: [mlir][Transform] Add classify_convolution_dims op

Includes `inferConvolutionDims` based on the existing helper for
contractions, `inferContractionDims`. This allows matching and
identifying the relevant dims for a convolution sub-computation of
a linalg operation.

Additionally adds stride/dilation inference to the captures and the
convolution interface matcher.

Differential Revision: https://reviews.llvm.org/D156080
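
A minimal C++ sketch of how the new helper might be used from a pass or
matcher (the `linalgOp` handle and surrounding context are assumed; this is an
illustration, not part of the commit):

    #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
    #include "llvm/Support/raw_ostream.h"
    using namespace mlir;

    // Classify the loops of an assumed linalg::LinalgOp handle.
    void printConvolutionDims(linalg::LinalgOp linalgOp) {
      FailureOr<linalg::ConvolutionDimensions> dims =
          linalg::inferConvolutionDims(linalgOp);
      if (failed(dims))
        return; // Not recognized as containing a convolution.
      // Each kind of loop position is reported in sorted order.
      for (unsigned pos : dims->outputImage)
        llvm::outs() << "output image loop: " << pos << "\n";
      // Strides follow the output image dims; dilations follow the filter
      // loop dims.
      for (int64_t stride : dims->strides)
        llvm::outs() << "stride: " << stride << "\n";
      for (int64_t dilation : dims->dilations)
        llvm::outs() << "dilation: " << dilation << "\n";
    }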

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
    mlir/test/Dialect/Linalg/match-ops-interpreter.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index b321a376309976..e735fdc59ec1be 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -67,6 +67,45 @@ FailureOr<ContractionDimensions> inferContractionDims(LinalgOp linalgOp);
 // TODO: embed within `isa<ContractionOpInterface>` if possible / natural.
 bool isaContractionOpInterface(LinalgOp linalgOp);
 
+/// Positions of a Linalg op loops that correspond to different kinds of a
+/// convolution dimension.
+struct ConvolutionDimensions {
+  SmallVector<unsigned, 2> batch;
+  SmallVector<unsigned, 2> outputImage;
+  SmallVector<unsigned, 2> outputChannel;
+  SmallVector<unsigned, 2> filterLoop;
+  SmallVector<unsigned, 2> inputChannel;
+  SmallVector<unsigned, 2> depth;
+  SmallVector<int64_t, 2> strides;
+  SmallVector<int64_t, 2> dilations;
+};
+
+/// Find at least 1 parallel (output_image) and reduction (filter_loop)
+/// dimension candidates that form a convolution subcomputation within
+/// `linalgOp`. The LHS is assumed to be the convolution input while the
+/// RHS is assumed to be the filter.
+/// These dimensions are such that:
+///   1. Optional batch dimensions that appear in the input and the output.
+///   2. The output_image dimension is involved in a cross-correlation along LHS
+///      (i.e. it is a permutation on RES and LHS and has an associated
+///      filter_loop in RHS).
+///   3. Optional output_channel dimension is involved in an outer-product along
+///      RHS (i.e. it is a permutation on RES and RHS and does not appear in
+///      LHS).
+///   4. Optional input_channel dimension appears as a permutation on LHS and
+///      RHS.
+///   5. The filter_loop dimension appears as a permutation on the RHS and
+///      represents the shape of the kernel cross-correlated along a
+///      corresponding output_image dim.
+///   6. The input_channel dimension appears as a permutation on LHS and RHS.
+///   7. All dimensions appear only once in any given indexing map.
+/// This allows e.g. detecting that some convolution is embedded within
+/// `linalgOp` with some orthogonal heuristic.
+/// When multiple dimension occurrences exist that match any classification,
+/// indices are returned in sorted order.
+/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
+FailureOr<ConvolutionDimensions> inferConvolutionDims(LinalgOp linalgOp);
+
 /// Checks whether `linalgOp` conforms to ConvolutionOpInterface.
 // TODO: embed within `isa<ConvolutionOpInterface>` if possible / natural.
 bool isaConvolutionOpInterface(LinalgOp linalgOp);
@@ -100,9 +139,6 @@ enum class MatchContractionResult;
 /// Checks whether `op` conforms to ContractionOpInterface and populates
 /// `dimensions` with indexes of the different kinds of dimensions when
 /// present.
-// TODO: Extract a standalone `inferConvolutionDims` that can also detect
-// whether a conv pattern exists within a bigger linalg op (see
-// inferContractionDims).
 MatchContractionResult
 isContractionInterfaceImpl(Operation *op,
                            ContractionDimensions *dimensions = nullptr);
@@ -115,17 +151,6 @@ StringRef getMatchContractionMessage(MatchContractionResult res);
 /// convolution.
 enum class MatchConvolutionResult;
 
-/// Positions of a Linalg op loops that correspond to different kinds of a
-/// convolution dimension.
-struct ConvolutionDimensions {
-  SmallVector<unsigned, 2> batch;
-  SmallVector<unsigned, 2> outputImage;
-  SmallVector<unsigned, 2> outputChannel;
-  SmallVector<unsigned, 2> filterLoop;
-  SmallVector<unsigned, 2> inputChannel;
-  SmallVector<unsigned, 2> depth;
-};
-
 /// Checks whether `op` conforms to ConvolutionOpInterface and populates
 /// `dimensions` with indexes of the different kinds of dimensions when
 /// present.

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
index ad348d0ce89a64..9e108529ec129b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
@@ -170,6 +170,57 @@ def MatchStructuredClassifyContractionDimsOp
   let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
 }
 
+def MatchStructuredClassifyConvolutionDimsOp
+    : Op<Transform_Dialect, "match.structured.classify_convolution_dims", [
+    SingleOpMatcher,
+    StructuredPredicate,
+    MatchOpInterface,
+    MemoryEffectsOpInterface]> {
+  let summary =
+      "Checks if an operation has convolution-like dimensions and returns them";
+  let description = !strconcat([{
+    Checks if the structured payload op has convolution-like dimensions as
+    follows:
+
+      C(batch, depth, oi, oc) += A(batch, depth, oi, ic) * B(fl, depth, ic, oc)
+
+    That is:
+
+      - 'batch' are parallel dimensions used in the input and result;
+      - 'output_image' ('oi') are parallel dimensions used in the input and result;
+      - 'output_channel' ('oc') are parallel dimensions used in the filter and result;
+      - 'filter_loop' ('fl') are reduction dimensions representing the dimensions of the sliding window;
+      - 'input_channel' ('ic') are reduction dimensions present only in the input and filter;
+      - 'depth' are parallel dimensions present in the input, filter, and output.
+
+    Additionally, this will match stride and dilation information for the convolution:
+      - 'strides' are the static strides per convolution window dimension;
+      - 'dilations' are the static dilations per convolution window dimension.
+
+    Note that this doesn't check the operation in the body.
+
+  }], StructuredPredicate.extraDescription, [{
+
+    #### Return modes
+
+    Succeeds if the operation has the convolution-like dimensions, produces a
+    silenceable failure otherwise.
+  }]);
+
+  let arguments = (ins TransformHandleTypeInterface:$operand_handle);
+  let results = (outs TransformParamTypeInterface:$batch,
+                      TransformParamTypeInterface:$output_image,
+                      TransformParamTypeInterface:$output_channel,
+                      TransformParamTypeInterface:$filter_loop,
+                      TransformParamTypeInterface:$input_channel,
+                      TransformParamTypeInterface:$depth,
+                      TransformParamTypeInterface:$strides,
+                      TransformParamTypeInterface:$dilations);
+  let assemblyFormat =
+    "$operand_handle attr-dict `:` functional-type(operands, results)";
+  let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
+}
+
 class StructuredDimDescription<string kind> {
   string description = !strconcat([{
      The following }], kind ,[{ specifications are supported:

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 08de32818f0047..95a766b3375713 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -161,10 +161,10 @@ static bool isContractionBody(Block &block) {
 /// determining whether:
 ///   - It is a single AffineDimExpr.
 ///   - It is the only result involving this AffineDimExpr.
-static DenseSet<int64_t>
+static llvm::SmallDenseSet<int64_t>
 findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
                                 utils::IteratorType iter) {
-  DenseSet<int64_t> res;
+  llvm::SmallDenseSet<int64_t> res;
   assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
   AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
   for (AffineExpr e : indexingMap.getResults()) {
@@ -200,30 +200,30 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
   if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
     return failure();
 
-  DenseSet<int64_t> a = findPermutationsIndexingOperand(
+  llvm::SmallDenseSet<int64_t> a = findPermutationsIndexingOperand(
       linalgOp, linalgOp.getDpsInputOperand(0), par);
-  DenseSet<int64_t> b = findPermutationsIndexingOperand(
+  llvm::SmallDenseSet<int64_t> b = findPermutationsIndexingOperand(
       linalgOp, linalgOp.getDpsInputOperand(1), par);
-  DenseSet<int64_t> c = findPermutationsIndexingOperand(
+  llvm::SmallDenseSet<int64_t> c = findPermutationsIndexingOperand(
       linalgOp, linalgOp.getDpsInitOperand(0), par);
 
   // A & C - B are the iterators involved in an outer-product along A (the LHS).
-  DenseSet<int64_t> ac = a;
+  llvm::SmallDenseSet<int64_t> ac = a;
   llvm::set_intersect(ac, c);
   llvm::set_subtract(ac, b);
   // B & C - A are the iterators involved in an outer-product along B (the RHS).
-  DenseSet<int64_t> bc = b;
+  llvm::SmallDenseSet<int64_t> bc = b;
   llvm::set_intersect(bc, c);
   llvm::set_subtract(bc, a);
   // A & B & C are the "batch" dimensions.
-  DenseSet<int64_t> batches = a;
+  llvm::SmallDenseSet<int64_t> batches = a;
   llvm::set_intersect(batches, b);
   llvm::set_intersect(batches, c);
 
   // A & B red are the reduction dimensions.
-  DenseSet<int64_t> ra = findPermutationsIndexingOperand(
+  llvm::SmallDenseSet<int64_t> ra = findPermutationsIndexingOperand(
       linalgOp, linalgOp.getDpsInputOperand(0), red);
-  DenseSet<int64_t> rb = findPermutationsIndexingOperand(
+  llvm::SmallDenseSet<int64_t> rb = findPermutationsIndexingOperand(
       linalgOp, linalgOp.getDpsInputOperand(1), red);
   llvm::set_intersect(ra, rb);
 
@@ -236,10 +236,10 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
       SmallVector<unsigned, 2>(ac.begin(), ac.end()),
       SmallVector<unsigned, 2>(bc.begin(), bc.end()),
       SmallVector<unsigned, 2>(ra.begin(), ra.end())};
-  std::sort(dimensions.batch.begin(), dimensions.batch.end());
-  std::sort(dimensions.m.begin(), dimensions.m.end());
-  std::sort(dimensions.n.begin(), dimensions.n.end());
-  std::sort(dimensions.k.begin(), dimensions.k.end());
+  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
+  llvm::sort(dimensions.m.begin(), dimensions.m.end());
+  llvm::sort(dimensions.n.begin(), dimensions.n.end());
+  llvm::sort(dimensions.k.begin(), dimensions.k.end());
   return dimensions;
 }
 
@@ -359,8 +359,37 @@ namespace {
 /// dimensions and verifies each dimension occurs only once.
 struct ConvAccessExprWalker
     : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
-  llvm::SmallDenseSet<unsigned> convolvedDims;
-  llvm::SmallDenseSet<unsigned> unConvolvedDims;
+  // Stores dimensions used in expressions of the above form.
+  llvm::SmallDenseSet<int64_t> convolvedDims;
+  // Stores the dual mapping between LHS and RHS of convolution exprs.
+  llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
+  // Stores single use dimensions used by an AffineDimExpr.
+  llvm::SmallDenseSet<int64_t> unConvolvedDims;
+  // Stores a mapping from convolved dims to their coefficient.
+  llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
+
+  // Removes dims with multiple uses in the source input map from dimension
+  // sets tracked by this walker.
+  void clearMultiUseDims(AffineMap map) {
+    for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
+      if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
+            return e.isFunctionOfDim(dimPos);
+          }) > 1) {
+        convolvedDims.erase(dimPos);
+        unConvolvedDims.erase(dimPos);
+        // If a duplicate dim is marked as convolved, the pair of the duplicate
+        // dim must be removed from the map as well.
+        if (convolvedDimMapping.contains(dimPos)) {
+          int64_t pairedDim = convolvedDimMapping[dimPos];
+          convolvedDims.erase(pairedDim);
+          unConvolvedDims.erase(pairedDim);
+          strideAndDilationMapping.erase(pairedDim);
+          convolvedDimMapping.erase(dimPos);
+          convolvedDimMapping.erase(pairedDim);
+        }
+      }
+    }
+  }
 
   LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
     unsigned position = dimExpr.getPosition();
@@ -379,17 +408,25 @@ struct ConvAccessExprWalker
     // In pre-order visit, top level op has to be an add op.
     if (binaryExpr.getKind() != AffineExprKind::Add)
       return failure();
-    return success(succeeded(isDimExprOrMulExpr(binaryExpr.getLHS())) &&
-                   succeeded(isDimExprOrMulExpr(binaryExpr.getRHS())));
+    auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
+    auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
+    if (failed(lhsDimPos) || failed(rhsDimPos))
+      return failure();
+    convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
+    convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
+    return success();
   }
 
-  LogicalResult isDimExprOrMulExpr(AffineExpr expr) {
+  FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
     if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
-      unsigned dim = dimExpr.getPosition();
+      int64_t dim = dimExpr.getPosition();
       if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
         return failure();
+      // Stride/dilation for this dim is implicitly 1.
+      strideAndDilationMapping[dim] =
+          getAffineConstantExpr(1, expr.getContext());
       convolvedDims.insert(dim);
-      return success();
+      return dim;
     }
     if (auto symbolMulExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
       if (symbolMulExpr.getKind() != AffineExprKind::Mul)
@@ -406,26 +443,170 @@ struct ConvAccessExprWalker
       auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
       if (!mulExpr || !dimExpr)
         return failure();
-      unsigned dim = dimExpr.getPosition();
+      int64_t dim = dimExpr.getPosition();
       if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
         return failure();
+      strideAndDilationMapping[dim] = mulExpr;
       convolvedDims.insert(dim);
-      return success();
+      return dim;
     }
     return failure();
   }
 };
 } // namespace
 
-static llvm::SmallDenseSet<unsigned> getPreservedDims(AffineMap map) {
+static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
   assert(map.isProjectedPermutation() &&
          "expected map to have projected permutations");
-  llvm::SmallDenseSet<unsigned> preservedDims;
+  llvm::SmallDenseSet<int64_t> preservedDims;
   for (auto expr : map.getResults())
     preservedDims.insert(expr.cast<AffineDimExpr>().getPosition());
   return preservedDims;
 }
 
+static SmallVector<int64_t, 2>
+getConstantsFromExprList(SmallVector<AffineExpr, 2> exprs) {
+  SmallVector<int64_t, 2> vals;
+  for (auto e : exprs) {
+    auto constantExpr = e.dyn_cast<AffineConstantExpr>();
+    assert(constantExpr && "Found non-constant stride/dilation");
+    vals.push_back(constantExpr.getValue());
+  }
+  return vals;
+}
+
+/// Classifies dimensions in the `linalgOp` used by a convolution
+/// subcomputation, as captured by `inputExprWalker`. If
+/// `allowEmptyConvolvedDims` is not set, this will fail if there is not at
+/// least one convolved dimension pair (output image + filter loop).
+/// Convolution dimensions are specified in sorted order, and strides match
+/// the order of the output image dimensions, while the dilations match the
+/// order of the filter loop dimensions.
+static FailureOr<ConvolutionDimensions>
+inferConvolutionDimsImpl(LinalgOp linalgOp,
+                         ConvAccessExprWalker &inputExprWalker,
+                         bool allowEmptyConvolvedDims) {
+  llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
+      linalgOp, linalgOp.getDpsInputOperand(1), par);
+  llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
+      linalgOp, linalgOp.getDpsInitOperand(0), par);
+
+  // unConvolvedDims & outputDims - filterDims are the batch iterators.
+  llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
+  llvm::set_intersect(batch, outputDims);
+  llvm::set_subtract(batch, filterDims);
+
+  // convolvedDims & outputDims are the output image iterators.
+  llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
+  llvm::set_intersect(oi, outputDims);
+
+  // filterDims & outputDims - unConvolvedDims are the output channel iterators.
+  llvm::SmallDenseSet<int64_t> oc = filterDims;
+  llvm::set_intersect(oc, outputDims);
+  llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
+
+  // filterDims & outputDims & unConvolvedDims are the depth iterators.
+  llvm::SmallDenseSet<int64_t> depth = filterDims;
+  llvm::set_intersect(depth, outputDims);
+  llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
+
+  llvm::SmallDenseSet<int64_t> filterReducedDims =
+      findPermutationsIndexingOperand(linalgOp, linalgOp.getDpsInputOperand(1),
+                                      red);
+
+  // convolvedDims & filterReducedDims are the filter loop iterators.
+  llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
+  llvm::set_intersect(fl, filterReducedDims);
+
+  // unConvolvedDims & filterReducedDims are the input channel iterators.
+  llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
+  llvm::set_intersect(ic, filterReducedDims);
+
+  if (oi.empty() && !allowEmptyConvolvedDims)
+    return failure();
+
+  // Return each set in sorted order.
+  ConvolutionDimensions dimensions{
+      SmallVector<unsigned, 2>(batch.begin(), batch.end()),
+      SmallVector<unsigned, 2>(oi.begin(), oi.end()),
+      SmallVector<unsigned, 2>(oc.begin(), oc.end()),
+      SmallVector<unsigned, 2>(fl.begin(), fl.end()),
+      SmallVector<unsigned, 2>(ic.begin(), ic.end()),
+      SmallVector<unsigned, 2>(depth.begin(), depth.end()),
+      /*strides=*/SmallVector<int64_t, 2>{},
+      /*dilations=*/SmallVector<int64_t, 2>{}};
+  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
+  llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
+  llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
+  llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
+  llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
+  llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
+
+  // Use the op carried strides/dilations attribute if present.
+  auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
+  if (!nativeStrides) {
+    SmallVector<AffineExpr, 2> strideExprs;
+    for (unsigned oiDim : dimensions.outputImage)
+      strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
+    dimensions.strides = getConstantsFromExprList(strideExprs);
+  } else {
+    dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
+  }
+  auto nativeDilations =
+      linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
+  if (!nativeDilations) {
+    SmallVector<AffineExpr, 2> dilationExprs;
+    for (unsigned flDim : dimensions.filterLoop)
+      dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
+    dimensions.dilations = getConstantsFromExprList(dilationExprs);
+  } else {
+    dimensions.dilations =
+        llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
+  }
+  return dimensions;
+}
+
+/// Find at least 1 parallel (output_image) and reduction (filter_loop)
+/// dimension candidates that form a convolution subcomputation within
+/// `linalgOp`. The LHS is assumed to be the convolution input while the
+/// RHS is assumed to be the filter.
+/// These dimensions are such that:
+///   1. Optional batch dimensions that appear in the input and the output.
+///   2. The output_image dimension is involved in a cross-correlation along LHS
+///      (i.e. it is a permutation on RES and LHS and has an associated
+///      filter_loop in RHS).
+///   3. Optional output_channel dimension is involved in an outer-product along
+///      RHS (i.e. it is a permutation on RES and RHS and does not appear in
+///      LHS).
+///   4. Optional input_channel dimension appears as a permutation on LHS and
+///      RHS.
+///   5. The filter_loop dimension appears as a permutation on the RHS and
+///      represents the shape of the kernel cross-correlated along a
+///      corresponding output_image dim.
+///   6. The input_channel dimension appears as a permutation on LHS and RHS.
+///   7. All dimensions appear only once in any given indexing map.
+/// This allows e.g. detecting that some convolution is embedded within
+/// `linalgOp` with some orthogonal heuristic.
+/// When multiple dimension occurrences exist that match any classification,
+/// indices are returned in sorted order.
+/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
+FailureOr<ConvolutionDimensions>
+mlir::linalg::inferConvolutionDims(LinalgOp linalgOp) {
+  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
+    return failure();
+
+  auto indexingMaps = linalgOp.getIndexingMapsArray();
+
+  // Check the input indexing map has the right form.
+  ConvAccessExprWalker inputExprWalker;
+  for (AffineExpr expr : indexingMaps[0].getResults())
+    (void)inputExprWalker.visit(expr);
+  inputExprWalker.clearMultiUseDims(indexingMaps[0]);
+
+  return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
+                                  /*allowEmptyConvolvedDims=*/false);
+}
+
 namespace mlir::linalg::detail {
 enum class MatchConvolutionResult {
   Success = 0,
@@ -466,9 +647,9 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
 
   auto iteratorTypes = linalgOp.getIteratorTypesArray();
 
-  llvm::SmallDenseSet<unsigned> outputDims =
+  llvm::SmallDenseSet<int64_t> outputDims =
       getPreservedDims(indexingMaps.back());
-  llvm::SmallDenseSet<unsigned> filterDims = getPreservedDims(indexingMaps[1]);
+  llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
   // Make sure all loops are characterized as one of:
   // - Batch loop : present in output, as non-convolved in input, not present in
   //   filter.
@@ -482,17 +663,15 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
   //   present in filter.
   // - Depth multiplier : unconvolved in input, present in output, present in
   //   filter.
-  llvm::SmallDenseSet<unsigned> allLoopDims;
+  llvm::SmallDenseSet<int64_t> allLoopDims;
   for (auto outputExpr : indexingMaps.back().getResults()) {
-    unsigned outputDim = outputExpr.cast<AffineDimExpr>().getPosition();
+    int64_t outputDim = outputExpr.cast<AffineDimExpr>().getPosition();
     if (inputExprWalker.unConvolvedDims.count(outputDim) &&
         !filterDims.count(outputDim)) {
       // Batch dimension.
       if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
-      if (dimensions)
-        dimensions->batch.push_back(outputDim);
       continue;
     }
     if (inputExprWalker.convolvedDims.count(outputDim) &&
@@ -501,8 +680,6 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
       if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
-      if (dimensions)
-        dimensions->outputImage.push_back(outputDim);
       continue;
     }
     if (!inputExprWalker.convolvedDims.count(outputDim) &&
@@ -512,8 +689,6 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
       if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
-      if (dimensions)
-        dimensions->outputChannel.push_back(outputDim);
       continue;
     }
     if (inputExprWalker.unConvolvedDims.count(outputDim) &&
@@ -522,21 +697,16 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
       if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
-      if (dimensions)
-        dimensions->depth.push_back(outputDim);
       continue;
     }
     return MatchConvolutionResult::NonConvolutionLoop;
   }
   for (auto filterExpr : indexingMaps[1].getResults()) {
-    unsigned filterDim = filterExpr.cast<AffineDimExpr>().getPosition();
+    int64_t filterDim = filterExpr.cast<AffineDimExpr>().getPosition();
     if (outputDims.count(filterDim) &&
         !inputExprWalker.unConvolvedDims.count(filterDim) &&
         !inputExprWalker.convolvedDims.count(filterDim)) {
       // Output channel dimension. This is already seen, continue;
-      assert((!dimensions ||
-              llvm::is_contained(dimensions->outputChannel, filterDim)) &&
-             "expected output channel to have been found from output dims");
       continue;
     }
     if (inputExprWalker.convolvedDims.count(filterDim) &&
@@ -547,8 +717,6 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
       if (allLoopDims.count(filterDim))
         return MatchConvolutionResult::NonConvolutionLoop;
       allLoopDims.insert(filterDim);
-      if (dimensions)
-        dimensions->filterLoop.push_back(filterDim);
       continue;
     }
     if (inputExprWalker.unConvolvedDims.count(filterDim) &&
@@ -559,16 +727,11 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
       if (allLoopDims.count(filterDim))
         return MatchConvolutionResult::NonConvolutionLoop;
       allLoopDims.insert(filterDim);
-      if (dimensions)
-        dimensions->inputChannel.push_back(filterDim);
       continue;
     }
     if (inputExprWalker.unConvolvedDims.count(filterDim) &&
         outputDims.count(filterDim)) {
       // Depthwise loop. Already seen.
-      assert(
-          (!dimensions || llvm::is_contained(dimensions->depth, filterDim)) &&
-          "expected depthwise dimension to have been found from output dims");
       continue;
     }
     return MatchConvolutionResult::NonConvolutionLoop;
@@ -578,12 +741,11 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
     return MatchConvolutionResult::NonConvolutionLoop;
 
   if (dimensions) {
-    assert(dimensions->batch.size() + dimensions->outputImage.size() +
-                   dimensions->outputChannel.size() +
-                   dimensions->filterLoop.size() +
-                   dimensions->inputChannel.size() + dimensions->depth.size() ==
-               linalgOp.getNumLoops() &&
-           "expected all loops to be classified");
+    FailureOr<ConvolutionDimensions> res =
+        inferConvolutionDimsImpl(linalgOp, inputExprWalker,
+                                 /*allowEmptyConvolvedDims=*/true);
+    assert(succeeded(res) && "unexpected failure to infer convolution dims");
+    *dimensions = *res;
   }
 
   return MatchConvolutionResult::Success;

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 2220930e8f4dd0..7b5b4924fe6926 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -259,6 +259,53 @@ transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// MatchStructuredClassifyConvolutionDimsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  FailureOr<linalg::ConvolutionDimensions> convolutionDims =
+      linalg::inferConvolutionDims(cast<linalg::LinalgOp>(current));
+  if (failed(convolutionDims))
+    return emitSilenceableError() << "could not infer convolution dimensions";
+
+  MLIRContext *context = current->getContext();
+  Builder builder(context);
+  auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
+    return llvm::to_vector(
+        llvm::map_range(values, [&](unsigned value) -> Attribute {
+          return builder.getI64IntegerAttr(value);
+        }));
+  };
+  results.setParams(getBatch().cast<OpResult>(),
+                    makeI64Attrs(convolutionDims->batch));
+  results.setParams(getOutputImage().cast<OpResult>(),
+                    makeI64Attrs(convolutionDims->outputImage));
+  results.setParams(getOutputChannel().cast<OpResult>(),
+                    makeI64Attrs(convolutionDims->outputChannel));
+  results.setParams(getFilterLoop().cast<OpResult>(),
+                    makeI64Attrs(convolutionDims->filterLoop));
+  results.setParams(getInputChannel().cast<OpResult>(),
+                    makeI64Attrs(convolutionDims->inputChannel));
+  results.setParams(getDepth().cast<OpResult>(),
+                    makeI64Attrs(convolutionDims->depth));
+
+  auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
+    return llvm::to_vector(
+        llvm::map_range(values, [&](int64_t value) -> Attribute {
+          return builder.getI64IntegerAttr(value);
+        }));
+  };
+  results.setParams(getStrides().cast<OpResult>(),
+                    makeI64AttrsFromI64(convolutionDims->strides));
+  results.setParams(getDilations().cast<OpResult>(),
+                    makeI64AttrsFromI64(convolutionDims->dilations));
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Utilities for structured match predicates.
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index 063e1e44fd604c..bad6893eaa99e1 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -934,3 +934,97 @@ module attributes { transform.target_tag = "start_here" } {
     return %result : tensor<40x10x50x15xf32>
   }
 }
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @match_convolution(%arg0: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>) {
+    %1:8 = transform.match.structured %arg0 : (!transform.any_op) -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>) {
+    ^bb0(%struct: !transform.any_op):
+      transform.match.structured.body %struct { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op
+      %0:8 = transform.match.structured.classify_convolution_dims %struct
+        : (!transform.any_op) -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>)
+      transform.match.structured.yield %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7
+        : !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
+    }
+    transform.yield %arg0, %1#0, %1#1, %1#2, %1#3, %1#4, %1#5, %1#6, %1#7 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
+  }
+
+  transform.named_sequence @print_convolution(
+      %op: !transform.any_op {transform.readonly},
+      %batch: !transform.param<i64> {transform.readonly},
+      %oi: !transform.param<i64> {transform.readonly},
+      %oc: !transform.param<i64> {transform.readonly},
+      %fl: !transform.param<i64> {transform.readonly},
+      %ic: !transform.param<i64> {transform.readonly},
+      %depth: !transform.param<i64> {transform.readonly},
+      %strides: !transform.param<i64> {transform.readonly},
+      %dilations: !transform.param<i64> {transform.readonly}) {
+    transform.test_print_remark_at_operand %op, "convolution" : !transform.any_op
+    transform.test_print_param %batch, "batch dims" at %op : !transform.param<i64>, !transform.any_op
+    transform.test_print_param %oi, "output image dims" at %op : !transform.param<i64>, !transform.any_op
+    transform.test_print_param %oc, "output channel dims" at %op : !transform.param<i64>, !transform.any_op
+    transform.test_print_param %fl, "filter loop dims" at %op : !transform.param<i64>, !transform.any_op
+    transform.test_print_param %ic, "input channel dims" at %op : !transform.param<i64>, !transform.any_op
+    transform.test_print_param %depth, "depth dims" at %op : !transform.param<i64>, !transform.any_op
+    transform.test_print_param %strides, "strides" at %op : !transform.param<i64>, !transform.any_op
+    transform.test_print_param %dilations, "dilations" at %op : !transform.param<i64>, !transform.any_op
+    transform.yield
+  }
+
+  transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
+  ^bb0(%arg0: !transform.any_op):
+    %3 = transform.foreach_match in %arg0 @match_convolution -> @print_convolution : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+module attributes { transform.target_tag = "start_here" } {
+  func.func @convolution_simple(%input: tensor<10x20x30xf32>, %filter: tensor<3x30x15xf32>) -> tensor<10x18x15xf64> {
+    %cst = arith.constant 0.0 : f64
+    %empty = tensor.empty() : tensor<10x18x15xf64>
+    %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x18x15xf64>) -> tensor<10x18x15xf64>
+    // expected-remark @below {{convolution}}
+    // expected-remark @below {{batch dims 0}}
+    // expected-remark @below {{output image dims 1}}
+    // expected-remark @below {{output channel dims 2}}
+    // expected-remark @below {{filter loop dims 3}}
+    // expected-remark @below {{input channel dims 4}}
+    // expected-remark @below {{depth dims}}
+    // expected-remark @below {{strides 1}}
+    // expected-remark @below {{dilations 1}}
+    %result = linalg.conv_1d_nwc_wcf {dilations = dense<1> : tensor<1xi64>,
+                                      strides = dense<1> : tensor<1xi64>}
+       ins(%input, %filter: tensor<10x20x30xf32>, tensor<3x30x15xf32>) outs(%fill: tensor<10x18x15xf64>) -> tensor<10x18x15xf64>
+    return %result : tensor<10x18x15xf64>
+  }
+
+  func.func @convolution_multi_channel(%input: tensor<2x34x68x16xf32>, %filter: tensor<8x2x3x5x16x16xf32>) -> tensor<8x32x32x16xf32> {
+    %cst = arith.constant 0.0 : f32
+    %empty = tensor.empty() : tensor<8x32x32x16xf32>
+    %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<8x32x32x16xf32>) -> tensor<8x32x32x16xf32>
+    // expected-remark @below {{convolution}}
+    // expected-remark @below {{batch dims}}
+    // expected-remark @below {{output image dims 1 : i64, 2 : i64}}
+    // expected-remark @below {{output channel dims 0 : i64, 3 : i64}}
+    // expected-remark @below {{filter loop dims 5 : i64, 6 : i64}}
+    // expected-remark @below {{input channel dims 4 : i64, 7 : i64}}
+    // expected-remark @below {{depth dims}}
+    // expected-remark @below {{strides 1 : i64, 2 : i64}}
+    // expected-remark @below {{dilations 1 : i64, 1 : i64}}
+    %result = linalg.generic {
+        indexing_maps = [
+            affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d1 + d5, 2 * d2 + d6, d7)>,
+            affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d4, d5, d6, d7, d3)>,
+            affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3)>],
+        iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction", "reduction"]}
+        ins(%input, %filter : tensor<2x34x68x16xf32>, tensor<8x2x3x5x16x16xf32>) outs(%fill : tensor<8x32x32x16xf32>) {
+          ^bb0(%in: f32, %in_0: f32, %out: f32):
+            %mul = arith.mulf %in, %in_0 : f32
+            %add = arith.addf %mul, %out : f32
+            linalg.yield %add : f32
+          } -> tensor<8x32x32x16xf32>
+    return %result : tensor<8x32x32x16xf32>
+  }
+}


        

