[Mlir-commits] [mlir] [Linalg] Add *Conv1D* matchers (PR #168050)

Abhishek Varma llvmlistbot at llvm.org
Sun Nov 16 22:42:00 PST 2025


https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/168050

From 9eb2d484003a917a533f385ec32c23b6e2a6bdd1 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Fri, 14 Nov 2025 06:08:20 -0600
Subject: [PATCH 1/2] [Linalg] Add *Conv1D* matchers

-- This commit is the second in a series adding matchers
   for linalg.*conv*/*pool* ops. Refer: https://github.com/llvm/llvm-project/pull/163724
-- It adds matchers for all variants of the Conv1D convolution ops.
-- To complete the infrastructure needed for ops that take no
   dilations/strides information at creation time, this commit also adds
   basic Conv2D and Conv3D ops with coverage in the lit test.

Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
 .../Dialect/Linalg/Transforms/Specialize.cpp  |  28 +-
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       | 300 +++++++++++++++++-
 .../convolution/roundtrip-convolution.mlir    |  88 ++++-
 3 files changed, 409 insertions(+), 7 deletions(-)
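
For orientation, the matcher entry point this series extends has the
following shape, taken from the explicit specializations added below (the
declaration site and the template parameter name are not shown in this
patch, so treat them as assumed):

    // One explicit specialization per named conv op; on a successful
    // match, *dilations and *strides receive the extracted values.
    template <typename ConvOpTy>
    bool isaConvolutionOpOfType(LinalgOp op, SmallVector<int64_t> *dilations,
                                SmallVector<int64_t> *strides);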

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 249a74b007dce..c2485a08932dd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -245,14 +245,22 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
                    ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
   SmallVector<Value> inputs = genericOp.getDpsInputs();
   ValueRange outputs = genericOp.getDpsInits();
-  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
   SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
                                       ? TypeRange(ValueRange(outputs))
                                       : TypeRange{};
-  Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
-  Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
-  LinalgOp namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
-      genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+  LinalgOp namedOp;
+  // Ops with no dilations and no strides.
+  if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
+                                                    inputs, outputs);
+  } else {
+    Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
+    Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
+        genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+  }
   return namedOp;
 }
 
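The if constexpr dispatch above is what lets ops without strides/dilations
builders share this helper: the discarded branch is never instantiated for
a given ConvOpTy. A standalone sketch of the idiom, using stand-in types
rather than the actual Linalg ops:

    #include <iostream>
    #include <type_traits>

    struct PlainConv {   // stand-in for ops built without strides/dilations
      void buildNoAttrs() const { std::cout << "no strides/dilations\n"; }
    };
    struct StridedConv { // stand-in for ops that take strides/dilations
      void buildWithAttrs() const { std::cout << "with strides/dilations\n"; }
    };

    template <typename OpTy>
    void buildConv(const OpTy &op) {
      // The discarded branch is not instantiated for this OpTy, so each
      // type only needs to provide the builder used in its own branch;
      // a plain runtime if would not compile here.
      if constexpr (std::is_same_v<OpTy, PlainConv>)
        op.buildNoAttrs();
      else
        op.buildWithAttrs();
    }

    int main() {
      buildConv(PlainConv{});   // prints: no strides/dilations
      buildConv(StridedConv{}); // prints: with strides/dilations
    }
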
@@ -265,9 +273,19 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
     return specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations,        \
                                         strides);                              \
   // -----------------------------
+  // Convolution ops.
+  // -----------------------------
+  CONV_OP_SPECIALIZER(linalg::Conv1DOp);
+  CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp);
+  CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DOp);
+  CONV_OP_SPECIALIZER(linalg::Conv3DOp);
+  // -----------------------------
   // Depthwise Convolution ops.
   // -----------------------------
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNcwCwOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
   // -----------------------------
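
Only the tail of CONV_OP_SPECIALIZER is visible in this hunk; its guard is
defined above the shown context. Judging from the matchers added in this
patch, each invocation presumably expands to something like the following
sketch (hypothetical, inferred rather than taken from the source):

    if (isaConvolutionOpOfType<linalg::Conv1DOp>(genericOp, &dilations,
                                                 &strides))
      return specializeToConvOp<linalg::Conv1DOp>(rewriter, genericOp,
                                                  dilations, strides);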
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 5dd5e1b055f0d..6b85e6ba0ede2 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -390,7 +390,7 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
   unsigned inputMapIdx = 0, filterMapIdx = 1,
            outputMapIdx = indexingMaps.size() - 1;
   AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim);
-  auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
+  auto addExpr = dyn_cast_or_null<AffineBinaryOpExpr>(inpExpr);
   if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
     return false;
 
@@ -434,6 +434,263 @@ static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
                           })));
 }
 
+// #inputMap = affine_map<(W, w) -> (W + w)>
+// #filterMap = affine_map<(W, w) -> (w)>
+// #outputMap = affine_map<(W, w) -> (W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
+                                              SmallVector<int64_t> *dilations,
+                                              SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv1DOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(1, 1);
+  *strides = SmallVector<int64_t>(1, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr W = getAffineDimExpr(0, context);
+  AffineExpr w = getAffineDimExpr(1, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First, fetch dilations/strides:
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
+                                  /*oDim=*/0, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{W * (*strides)[0] + w * (*dilations)[0]},
+           /*filterMap=*/{w},
+           /*outputMap=*/{W}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap  = affine_map<(N, W, F, w, c) -> (N, W + w, c)>
+// #filterMap = affine_map<(N, W, F, w, c) -> (w, c, F)>
+// #outputMap = affine_map<(N, W, F, w, c) -> (N, W, F)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv1DNwcWcfOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(1, 1);
+  *strides = SmallVector<int64_t>(1, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr W = getAffineDimExpr(1, context);
+  AffineExpr F = getAffineDimExpr(2, context);
+  AffineExpr w = getAffineDimExpr(3, context);
+  AffineExpr c = getAffineDimExpr(4, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First, fetch dilations/strides:
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], c},
+           /*filterMap=*/{w, c, F},
+           /*outputMap=*/{N, W, F}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap  = affine_map<(N, F, W, c, w) -> (N, c, W + w)>
+// #filterMap = affine_map<(N, F, W, c, w) -> (F, c, w)>
+// #outputMap = affine_map<(N, F, W, c, w) -> (N, F, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv1DNcwFcwOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(1, 1);
+  *strides = SmallVector<int64_t>(1, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr F = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr c = getAffineDimExpr(3, context);
+  AffineExpr w = getAffineDimExpr(4, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First, fetch dilations/strides:
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, c, W * (*strides)[0] + w * (*dilations)[0]},
+           /*filterMap=*/{F, c, w},
+           /*outputMap=*/{N, F, W}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap  = affine_map<(H, W, h, w) -> (H + h, W + w)>
+// #filterMap = affine_map<(H, W, h, w) -> (h, w)>
+// #outputMap = affine_map<(H, W, h, w) -> (H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
+                                              SmallVector<int64_t> *dilations,
+                                              SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr H = getAffineDimExpr(0, context);
+  AffineExpr W = getAffineDimExpr(1, context);
+  AffineExpr h = getAffineDimExpr(2, context);
+  AffineExpr w = getAffineDimExpr(3, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First, fetch dilations/strides:
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
+                                  /*oDim=*/0, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+                                  /*oDim=*/1, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1]},
+           /*filterMap=*/{h, w},
+           /*outputMap=*/{H, W}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap  = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)>
+// #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)>
+// #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
+                                              SmallVector<int64_t> *dilations,
+                                              SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv3DOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(3, 1);
+  *strides = SmallVector<int64_t>(3, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr D = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr d = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First, fetch dilations/strides:
+  // Match: D * stride + d * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
+                                  /*oDim=*/0, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+                                  /*oDim=*/1, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, (*dilations)[2], (*strides)[2]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{D * (*strides)[0] + d * (*dilations)[0],
+                         H * (*strides)[1] + h * (*dilations)[1],
+                         W * (*strides)[2] + w * (*dilations)[2]},
+           /*filterMap=*/{d, h, w},
+           /*outputMap=*/{D, H, W}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap  = affine_map<(N, W, C, w) -> (N, C, W + w)>
+// #filterMap = affine_map<(N, W, C, w) -> (C, w)>
+// #outputMap = affine_map<(N, W, C, w) -> (N, C, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(1, 1);
+  *strides = SmallVector<int64_t>(1, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr W = getAffineDimExpr(1, context);
+  AffineExpr C = getAffineDimExpr(2, context);
+  AffineExpr w = getAffineDimExpr(3, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First, fetch dilations/strides:
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, C, W * (*strides)[0] + w * (*dilations)[0]},
+           /*filterMap=*/{C, w},
+           /*outputMap=*/{N, C, W}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
 // #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
 // #filterMap = affine_map<(N, W, C, w) -> (w, C)>
 // #outputMap = affine_map<(N, W, C, w) -> (N, W, C)>
@@ -474,6 +731,47 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
   return bodyMatcherForConvolutionOps(yieldVal, body);
 }
 
+// #inputMap  = affine_map<(N, W, C, CM, w) -> (N, W + w, C)>
+// #filterMap = affine_map<(N, W, C, CM, w) -> (w, C, CM)>
+// #outputMap = affine_map<(N, W, C, CM, w) -> (N, W, C, CM)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(1, 1);
+  *strides = SmallVector<int64_t>(1, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr W = getAffineDimExpr(1, context);
+  AffineExpr C = getAffineDimExpr(2, context);
+  AffineExpr CM = getAffineDimExpr(3, context);
+  AffineExpr w = getAffineDimExpr(4, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First, fetch dilations/strides:
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
+           /*filterMap=*/{w, C, CM},
+           /*outputMap=*/{N, W, C, CM}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
 // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)>
 // #filterMap = affine_map<(N, H, W, C, h, w) -> (C, h, w)>
 // #outputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)>
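
A hypothetical call site for the matchers above (a minimal sketch; the
enclosing pattern plumbing is assumed and not part of this patch):

    // op is a LinalgOp already known to implement ConvolutionOpInterface,
    // per the asserts in the matchers above.
    SmallVector<int64_t> dilations, strides;
    if (isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(op, &dilations,
                                                       &strides)) {
      // op computes conv_1d_nwc_wcf semantics; dilations[0]/strides[0]
      // hold the factors extracted from its input indexing map.
    }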
diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
index 8f22cc749bee9..131bcd81e580c 100644
--- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
@@ -1,8 +1,81 @@
 // The following test examples of linalg convolution named ops lowered to linalg.generic and then
 // lifted back up to named op.
+// NOTE: Most tests in this file use dynamic shapes, since the underlying transformations don't modify shapes. There is one exception, added as a smoke test.
+
 // RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s --implicit-check-not=linalg.generic
 
-// NOTE: Most tests in this file use dynamic shapes as the underlying transformations don't modify shapes. There's one exception that's added as a smoke test. 
+func.func @conv_1d(%in : tensor<?xf32>, %filter : tensor<?xf32>, %out : tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.conv_1d
+         ins(%in, %filter : tensor<?xf32>, tensor<?xf32>)
+         outs(%out : tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+//      CHECK: @conv_1d
+//      CHECK:   linalg.conv_1d
+
+// -----
+
+func.func @conv_1d_nwc_wcf(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.conv_1d_nwc_wcf
+         {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+         ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+         outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+//      CHECK: @conv_1d_nwc_wcf
+//      CHECK:   linalg.conv_1d_nwc_wcf
+// CHECK-SAME:      dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+
+// -----
+
+func.func @conv_1d_ncw_fcw(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.conv_1d_ncw_fcw
+         {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+         ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+         outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+//      CHECK: @conv_1d_ncw_fcw
+//      CHECK:   linalg.conv_1d_ncw_fcw
+// CHECK-SAME:      dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+
+// -----
+
+func.func @conv_2d(%in : tensor<?x?xf32>, %filter : tensor<?x?xf32>, %out : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.conv_2d
+         ins(%in, %filter : tensor<?x?xf32>, tensor<?x?xf32>)
+         outs(%out: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+//      CHECK: @conv_2d
+//      CHECK:   linalg.conv_2d
+
+// -----
+
+func.func @conv_3d(%in : tensor<?x?x?xf32>, %filter : tensor<?x?x?xf32>, %out : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.conv_3d
+         ins(%in, %filter : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+         outs(%out : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+//      CHECK: @conv_3d
+//      CHECK:   linalg.conv_3d
+
+// -----
+
+func.func @depthwise_conv_1d_ncw_cw(%input: tensor<?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.depthwise_conv_1d_ncw_cw
+         {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+         ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?xf32>)
+         outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+//      CHECK: @depthwise_conv_1d_ncw_cw
+//      CHECK:   linalg.depthwise_conv_1d_ncw_cw
+// CHECK-SAME:      dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+
+// -----
+
 func.func @depthwise_conv_1d_nwc_wc_static(%input: tensor<1x25x8xi8>, %filter: tensor<3x8xi8>, %output: tensor<1x10x8xi32>) -> tensor<1x10x8xi32> {
   %0 = linalg.depthwise_conv_1d_nwc_wc 
          {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
@@ -16,6 +89,19 @@ func.func @depthwise_conv_1d_nwc_wc_static(%input: tensor<1x25x8xi8>, %filter: t
 
 // -----
 
+func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_1d_nwc_wcm
+         {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+         ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+         outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @depthwise_conv_1d_nwc_wcm
+//      CHECK:   linalg.depthwise_conv_1d_nwc_wcm
+// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>
+
+// -----
+
 func.func @depthwise_conv_2d_nchw_chw(%input: tensor<?x?x?x?xf16>, %filter: tensor<?x?x?xf16>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
   %0 = linalg.depthwise_conv_2d_nchw_chw
          {dilations = dense<[2,3]> : vector<2xi64>, strides = dense<[4,5]> : vector<2xi64>}

From c7eebaef9255236cf215a37a6baaecec85fd7472 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Mon, 17 Nov 2025 00:41:32 -0600
Subject: [PATCH 2/2] Review comments v1.0

---
 .../Linalg/convolution/roundtrip-convolution.mlir        | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
index 131bcd81e580c..4b2d42a3ae4e0 100644
--- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
@@ -4,6 +4,9 @@
 
 // RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s --implicit-check-not=linalg.generic
 
+// -----------------------------
+// Convolution ops.
+// -----------------------------
 func.func @conv_1d(%in : tensor<?xf32>, %filter : tensor<?xf32>, %out : tensor<?xf32>) -> tensor<?xf32> {
   %0 = linalg.conv_1d
          ins(%in, %filter : tensor<?xf32>, tensor<?xf32>)
@@ -63,6 +66,9 @@ func.func @conv_3d(%in : tensor<?x?x?xf32>, %filter : tensor<?x?x?xf32>, %out :
 
 // -----
 
+// -----------------------------
+// Depthwise Convolution ops.
+// -----------------------------
 func.func @depthwise_conv_1d_ncw_cw(%input: tensor<?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
   %0 = linalg.depthwise_conv_1d_ncw_cw
          {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
@@ -128,6 +134,9 @@ func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<?x?x?x?x?xf32>, %filter:
 
 // -----
 
+// -----------------------------
+// Pooling ops.
+// -----------------------------
 func.func @pooling_nhwc_max(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
   %0 = linalg.pooling_nhwc_max
          {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}


