[Mlir-commits] [mlir] eef1bfb - [mlir][Linalg] Conv {1, 2, 3}D ops defined with TC syntax

Alex Zinenko llvmlistbot at llvm.org
Fri Jul 31 04:20:25 PDT 2020


Author: Jakub Lichman
Date: 2020-07-31T13:20:17+02:00
New Revision: eef1bfb2d219191cee16ee24efbf2d204488696c

URL: https://github.com/llvm/llvm-project/commit/eef1bfb2d219191cee16ee24efbf2d204488696c
DIFF: https://github.com/llvm/llvm-project/commit/eef1bfb2d219191cee16ee24efbf2d204488696c.diff

LOG: [mlir][Linalg] Conv {1,2,3}D ops defined with TC syntax

Replaced the definitions of the named N-D ConvOps with tensor comprehension
(TC) syntax, which reduces boilerplate code significantly. Furthermore,
new ops to support TF convolutions were added (without strides and dilations).

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D84628

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/loops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index 056f0723e92d..27d4330a54d5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -17,3 +17,55 @@ ods_def<BatchMatmulOp>:
 def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
   C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(b, k, n)));
 }
+
+ods_def<ConvWOp>:
+def conv_1d(I: f32(W), K: f32(KW)) -> (O: f32(W)) {
+  O(w) = std_addf(O(w), std_mulf(I(w + kw), K(kw)));
+}
+
+ods_def<ConvNWCOp>:
+def conv_1d_nwc(I: f32(N, W, C), K: f32(F, KW, C)) -> (O: f32(N, W, F)) {
+  O(n, w, f) = std_addf(O(n, w, f),
+    std_mulf(I(n, w + kw, c), K(f, kw, c)));
+}
+
+ods_def<ConvNCWOp>:
+def conv_1d_ncw(I: f32(N, C, W), K: f32(F, C, KW)) -> (O: f32(N, F, W)) {
+  O(n, f, w) = std_addf(O(n, f, w),
+    std_mulf(I(n, c, w + kw), K(f, c, kw)));
+}
+
+ods_def<ConvHWOp>:
+def conv_2d(I: f32(H, W), K: f32(KH, KW)) -> (O: f32(H, W)) {
+  O(h, w) = std_addf(O(h, w), std_mulf(I(h + kh, w + kw), K(kh, kw)));
+}
+
+ods_def<ConvNHWCOp>:
+def conv_2d_nhwc(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) {
+  O(n, h, w, f) = std_addf(O(n, h, w, f),
+    std_mulf(I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
+}
+
+ods_def<ConvNCHWOp>:
+def conv_2d_nchw(I: f32(N, C, H, W), K: f32(F, C, KH, KW)) -> (O: f32(N, F, H, W)) {
+  O(n, f, h, w) = std_addf(O(n, f, h, w),
+    std_mulf(I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
+}
+
+ods_def<ConvDHWOp>:
+def conv_3d(I: f32(D, H, W), K: f32(KD, KH, KW)) -> (O: f32(D, H, W)) {
+  O(d, h, w) = std_addf(O(d, h, w),
+    std_mulf(I(d + kd, h + kh, w + kw), K(kd, kh, kw)));
+}
+
+ods_def<ConvNDHWCOp>:
+def conv_3d_ndhwc(I: f32(N, D, H, W, C), K: f32(F, KD, KH, KW, C)) -> (O: f32(N, D, H, W, F)) {
+  O(n, d, h, w, f) = std_addf(O(n, d, h, w, f),
+    std_mulf(I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c)));
+}
+
+ods_def<ConvNCDHWOp>:
+def conv_3d_ncdhw(I: f32(N, C, D, H, W), K: f32(F, C, KD, KH, KW)) -> (O: f32(N, F, D, H, W)) {
+  O(n, f, d, h, w) = std_addf(O(n, f, d, h, w),
+    std_mulf(I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
+}
\ No newline at end of file

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index 75e6599bf9fe..21bff4185abf 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -85,14 +85,6 @@ AffineMap extractOrIdentityMap(Optional<AffineMap> maybeMap, unsigned rank,
 SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
                                   ArrayRef<AffineExpr> b);
 
-/// Generates indexing maps for convolution with the following structure:
-/// input:   (m_1, ..., m_r, n_1, ..., n_r) -> (m_1 + n_1, ..., m_r + n_r)
-/// kernel:  (m_1, ..., m_r, n_1, ..., n_r) -> (n_1, ..., n_r)
-/// output:  (m_1, ..., m_r, n_1, ..., n_r) -> (m_1, ..., m_r)
-/// where r is the rank of the input, kernel and output
-llvm::Optional<SmallVector<AffineMap, 8>>
-createConvNDIndexingMaps(MLIRContext *context, unsigned rank);
-
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.h.inc"
 
 #define GET_OP_CLASSES

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 84ae8e440bee..1e3321af981e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -180,131 +180,6 @@ def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
   let hasFolder = 1;
 }
 
-class ConvOpBase<string mnemonic, int N>
-  : LinalgStructured_Op<mnemonic, [NInputs<2>, NOutputs<1>]> {
-  let description = [{
-    Base operation for any N-D Convolution implemented as a linalg.generic op.
-
-    Usage:
-
-    ```mlir
-    linalg.conv<N>D(%in, %filter, %out) : memref<(?x)+f32>,
-                                          memref<(?x)+f32>,
-                                          memref<(?x)+f32>
-    ```
-
-    where    %in:     input array
-             %filter: kernel or filter that will be applied on the input array
-             %out:    output array
-
-    and rank of the operands is *N*.
-
-    Every child convolution is expressed as:
-
-    ```mlir
-    #conv_trait = {
-      args_in = 2,
-      args_out = 1,
-      indexing_maps = #conv_accesses,
-      library_call  = "linalg_conv",
-      iterator_types = [("parallel", "parallel")+], // `2 * rank` iterators
-    }
-
-    linalg.generic #conv_trait %in, %filter, %out {
-      ^bb0(%a: f32, %b: f32, %c: f32) :
-        %d = mulf %a, %b : f32
-        %e = addf %c, %d : f32
-        linalg.yield %e : f32
-    } : memref<(?x)+f32>,
-        memref<(?x)+f32>,
-        memref<(?x)+f32>
-    ```
-
-    where #conv_accesses depend on the rank of the operands and thus
-    can be found in the documentation of each N-D case.
-    Please note that the input array is expected to be right-padded i.e.
-    the size of the input is greater than or equal to the size of the output
-    + size of the kernel - 1. If it is not padded the behavior of the op
-    is undefined.
-  }];
-
-  let arguments = (ins AnyStridedMemRefOfRank<N>,
-                       AnyStridedMemRefOfRank<N>,
-                       AnyStridedMemRefOfRank<N>);
-
-  let extraClassDeclaration = libraryCallName # [{
-    llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
-      // There are always 2 loops for each dimension of the convolution. First
-      // iterates output and second kernel. Since ranks of all 3 operands must
-      // be the same it does not matter which operand is picked to get the rank.
-      // Loops iterating the output can be parallelized and thus are marked as
-      // "parallel" while loops iterating the kernel are accumulating the
-      // products and therefore are marked as "reduction".
-      unsigned rank = getInputShapedType(0).getRank();
-      SmallVector<StringRef, 8> parallel(rank, getParallelIteratorTypeName());
-      SmallVector<StringRef, 8> reduction(rank, getReductionIteratorTypeName());
-      parallel.insert(parallel.end(), reduction.begin(), reduction.end());
-      return parallel;
-    }
-
-    // Generates indexing maps with the following structure:
-    // input:   (m_1, ..., m_r, n_1, ..., n_r) -> (m_1 + n_1, ..., m_r + n_r)
-    // kernel:  (m_1, ..., m_r, n_1, ..., n_r) -> (n_1, ..., n_r)
-    // output:  (m_1, ..., m_r, n_1, ..., n_r) -> (m_1, ..., m_r)
-    // where r is the rank of the input, kernel and output
-    llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
-      MLIRContext *context = getContext();
-      unsigned rank = getInputShapedType(0).getRank();
-      return createConvNDIndexingMaps(context, rank);
-    }
-  }];
-
-  let hasFolder = 1;
-  let verifier = [{ return ::verify(*this); }];
-}
-
-def Conv1DOp : ConvOpBase<"conv1D", 1> {
-  let description = [{
-    *1D* convolution which uses following affine maps to access operands:
-
-    ```mlir
-    #conv_accesses = [
-      affine_map<(m, n) -> (m + n)>, // in
-      affine_map<(m, n) -> (n)>, // kernel
-      affine_map<(m, n) -> (m)> // out
-    ]
-    ```
-  }];
-}
-
-def Conv2DOp : ConvOpBase<"conv2D", 2> {
-  let description = [{
-    *2D* convolution which uses following affine maps to access operands:
-
-    ```mlir
-    #conv_accesses = [
-      affine_map<(m1, m2, n1, n2) -> (m1 + n1, m2 + n2)>, // in
-      affine_map<(m1, m2, n1, n2) -> (n1, n2)>, // kernel
-      affine_map<(m1, m2, n1, n2) -> (m1, m2) // out
-    ]
-    ```
-  }];
-}
-
-def Conv3DOp : ConvOpBase<"conv3D", 3> {
-  let description = [{
-    *3D* convolution which uses following affine maps to access operands:
-
-    ```mlir
-    #conv_accesses = [
-      affine_map<(m1, m2, m3, n1, n2, n3) -> (m1 + n1, m2 + n2, m3 + n3)>, // in
-      affine_map<(m1, m2, m3, n1, n2, n3) -> (n1, n2, n3)>, // kernel
-      affine_map<(m1, m2, m3, n1, n2, n3) -> (m1, m2, m3)> // out
-    ]
-    ```
-  }];
-}
-
 /// A base class for pooling operation such as conv. The arguments must contain
 /// optional arguments `strides`, `dilations` and `padding` with following type:
 ///   OptionalAttr<I64ArrayAttr>:$strides

diff  --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index 921445bd03b1..55ffa3f8b6e6 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -236,9 +236,6 @@ void mlir::populateLinalgToStandardConversionPatterns(
       LinalgOpConversion<PoolingMinOp>,
       LinalgOpConversion<PoolingSumOp>,
       LinalgOpConversion<CopyOp>,
-      LinalgOpConversion<Conv1DOp>,
-      LinalgOpConversion<Conv2DOp>,
-      LinalgOpConversion<Conv3DOp>,
       LinalgOpConversion<FillOp>,
       LinalgOpConversion<GenericOp>,
       LinalgOpConversion<IndexedGenericOp>>(ctx);

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e67adf8c2042..03bd71f17716 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -986,17 +986,6 @@ static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op,
   return success();
 }
 
-template <typename ConvNDOp>
-static LogicalResult verify(ConvNDOp op) {
-  auto outputType = op.getOutputShapedType(0).getElementType();
-  auto inputType = op.getInputShapedType(0).getElementType();
-  auto kernelType = op.getInputShapedType(1).getElementType();
-  if (outputType != inputType || inputType != kernelType)
-    return op.emitOpError("expected all element types of operands to match");
-
-  return success();
-}
-
 static LogicalResult verify(ConvOp op) {
   auto oType = op.output().getType().cast<MemRefType>();
   auto fType = op.filter().getType().cast<MemRefType>();
@@ -1107,27 +1096,6 @@ mlir::linalg::weightedPoolingInputIndex(PoolingOp op,
   return res;
 }
 
-llvm::Optional<SmallVector<AffineMap, 8>>
-mlir::linalg::createConvNDIndexingMaps(MLIRContext *context, unsigned rank) {
-  unsigned numDims = rank * 2, idx = 0;
-
-  SmallVector<AffineExpr, 8> dims, in, kernel, out;
-  dims = makeAffineDimExprs(numDims, idx, context);
-  in.reserve(rank);
-  kernel.reserve(rank);
-  out.reserve(rank);
-
-  for (unsigned i = 0; i < rank; i++) {
-    in.push_back(dims[i] + dims[rank + i]);
-    kernel.push_back(dims[rank + i]);
-    out.push_back(dims[i]);
-  }
-
-  return SmallVector<AffineMap, 8>{AffineMap::get(numDims, 0, in, context),
-                                   AffineMap::get(numDims, 0, kernel, context),
-                                   AffineMap::get(numDims, 0, out, context)};
-}
-
 #define INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(OP_TYPE)                      \
   template SmallVector<AffineExpr, 4>                                          \
   mlir::linalg::weightedPoolingInputIndex<OP_TYPE>(                            \
@@ -1209,18 +1177,6 @@ LogicalResult FillOp::fold(ArrayRef<Attribute>,
                            SmallVectorImpl<OpFoldResult> &) {
   return foldMemRefCast(*this);
 }
-LogicalResult Conv1DOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult Conv2DOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult Conv3DOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
 LogicalResult GenericOp::fold(ArrayRef<Attribute>,
                               SmallVectorImpl<OpFoldResult> &) {
   return foldMemRefCast(*this);
@@ -1362,3 +1318,39 @@ LogicalResult MatvecOp::fold(ArrayRef<Attribute>,
                              SmallVectorImpl<OpFoldResult> &) {
   return foldMemRefCast(*this);
 }
+LogicalResult ConvWOp::fold(ArrayRef<Attribute>,
+                            SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNWCOp::fold(ArrayRef<Attribute>,
+                              SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNCWOp::fold(ArrayRef<Attribute>,
+                              SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvHWOp::fold(ArrayRef<Attribute>,
+                             SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNHWCOp::fold(ArrayRef<Attribute>,
+                               SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNCHWOp::fold(ArrayRef<Attribute>,
+                               SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvDHWOp::fold(ArrayRef<Attribute>,
+                              SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNDHWCOp::fold(ArrayRef<Attribute>,
+                                SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNCDHWOp::fold(ArrayRef<Attribute>,
+                                SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index db29835e2caa..281edd9a91f6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -295,61 +295,6 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
   nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value();
 }
 
-/// Following functions emit scalar part of the N-D convolution op.
-/// N-D convolution has 2N loops:
-///   1-N: Iterate over the output array *O* with iterators *m1, ..., mN*.
-///   N-2N:. Iterate over the kernel *K* with iterators *n1, ..., nN*.
-///
-/// The scalar part accumulates products of input array *I* values with kernel
-/// ones. The accumulation expression therefore looks like:
-///   O[m1, ..., mN] += I[m1 + n1, ..., mN + nN] * K[n1, ..., nN].
-/// Note that the input array has to be padded in order to prevent
-/// out of bounds accesses.
-template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, Conv1DOp convOp) {
-  assert(convOp.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  assert(allIvs.size() == 2);
-  Value m1(allIvs[0]);
-  Value n1(allIvs[1]);
-  IndexedValueType I(convOp.getInput(0)), K(convOp.getInput(1)),
-      O(convOp.getOutputBuffer(0));
-  // Emit scalar form for the 1D conv case.
-  Value i1 = m1 + n1;
-  O(m1) = O(m1) + I(i1) * K(n1);
-}
-
-template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, Conv2DOp convOp) {
-  assert(convOp.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  assert(allIvs.size() == 4);
-  Value m1(allIvs[0]), m2(allIvs[1]);
-  Value n1(allIvs[2]), n2(allIvs[3]);
-  IndexedValueType I(convOp.getInput(0)), K(convOp.getInput(1)),
-      O(convOp.getOutputBuffer(0));
-  // Emit scalar form for the 2D conv case.
-  Value i1 = m1 + n1;
-  Value i2 = m2 + n2;
-  O(m1, m2) = O(m1, m2) + I(i1, i2) * K(n1, n2);
-}
-
-template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, Conv3DOp convOp) {
-  assert(convOp.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  assert(allIvs.size() == 6);
-  Value m1(allIvs[0]), m2(allIvs[1]), m3(allIvs[2]);
-  Value n1(allIvs[3]), n2(allIvs[4]), n3(allIvs[5]);
-  IndexedValueType I(convOp.getInput(0)), K(convOp.getInput(1)),
-      O(convOp.getOutputBuffer(0));
-  // Emit scalar form for the 3D conv case.
-  Value i1 = m1 + n1;
-  Value i2 = m2 + n2;
-  Value i3 = m3 + n3;
-  O(m1, m2, m3) = O(m1, m2, m3) + I(i1, i2, i3) * K(n1, n2, n3);
-}
-
 template <typename IndexedValueType>
 Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
                      MutableArrayRef<Value> imIdx) {
@@ -738,6 +683,24 @@ static Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
     return linalgOpToLoopsImpl<LoopTy, DotOp>(op, builder);
   if (isa<BatchMatmulOp>(op))
     return linalgOpToLoopsImpl<LoopTy, BatchMatmulOp>(op, builder);
+  if (isa<ConvWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvWOp>(op, builder);
+  if (isa<ConvNWCOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNWCOp>(op, builder);
+  if (isa<ConvNCWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNCWOp>(op, builder);
+  if (isa<ConvHWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvHWOp>(op, builder);
+  if (isa<ConvNHWCOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNHWCOp>(op, builder);
+  if (isa<ConvNCHWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNCHWOp>(op, builder);
+  if (isa<ConvDHWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvDHWOp>(op, builder);
+  if (isa<ConvNDHWCOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNDHWCOp>(op, builder);
+  if (isa<ConvNCDHWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNCDHWOp>(op, builder);
   llvm_unreachable("Unexpected op in linalgOpToLoopsImpl");
 }
 

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index a5a6e9bee34f..ca59ecd387ec 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -507,11 +507,3 @@ func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?x
   linalg.batch_matmul %a3, %b3, %c3 : (memref<?x?x?xf32>, memref<?x?xf32>, memref<?x?x?xf32>) -> ()
   return
 }
-
-// -----
-
-func @conv_type_mismatch(%in: memref<?xi32>, %filter: memref<?xf32>, %out: memref<?xf32>) {
-  // expected-error @+1 {{expected all element types of operands to match}}
-  linalg.conv1D(%in, %filter, %out) : memref<?xi32>, memref<?xf32>, memref<?xf32>
-  return
-}

diff  --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index ee63d59ca8c4..6af53a2b8d22 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1288,7 +1288,7 @@ func @conv4d(%in : memref<?x?x?x?xf32>, %filter : memref<?x?x?x?xf32>, %out :  m
 //       CHECKPARALLEL:   store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref<?x?x?x?xf32>
 
 func @conv1d_no_symbols(%in : memref<?xf32>, %filter : memref<?xf32>, %out : memref<?xf32>) -> () {
-  linalg.conv1D(%in, %filter, %out) : memref<?xf32>, memref<?xf32>, memref<?xf32>
+  linalg.conv_1d %in, %filter, %out : (memref<?xf32>, memref<?xf32>, memref<?xf32>)
   return
 }
 
@@ -1303,10 +1303,10 @@ func @conv1d_no_symbols(%in : memref<?xf32>, %filter : memref<?xf32>, %out : mem
 //       CHECKLOOP: scf.for %[[b:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
 //       CHECKLOOP:   scf.for %[[m:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
 //       CHECKLOOP:     %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
-//       CHECKLOOP:     %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
 //       CHECKLOOP:     %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
-//       CHECKLOOP:     %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+//       CHECKLOOP:     %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
 //       CHECKLOOP:     %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
+//       CHECKLOOP:     %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
 //       CHECKLOOP:     %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
 //       CHECKLOOP:     store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
 
@@ -1318,19 +1318,18 @@ func @conv1d_no_symbols(%in : memref<?xf32>, %filter : memref<?xf32>, %out : mem
 //       CHECKPARALLEL: %[[c1:.*]] = constant 1 : index
 //       CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
 //       CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref<?xf32>
-//       CHECKPARALLEL: scf.parallel (%[[b:.*]]) = (%[[c0]]) to (%[[dim1]]) step (%[[c1]]) {
-//       CHECKPARALLEL:   scf.for %[[m:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
-//       CHECKPARALLEL:     %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
-//       CHECKPARALLEL:     %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
-//       CHECKPARALLEL:     %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
-//       CHECKPARALLEL:     %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
-//       CHECKPARALLEL:     %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
-//       CHECKPARALLEL:     %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
-//       CHECKPARALLEL:     store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
+//       CHECKPARALLEL: scf.parallel (%[[b:.*]], %[[m:.*]]) = (%[[c0]], %[[c0]]) to (%[[dim1]], %[[dim0]]) step (%[[c1]], %[[c1]]) {
+//       CHECKPARALLEL:   %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
+//       CHECKPARALLEL:   %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
+//       CHECKPARALLEL:   %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
+//       CHECKPARALLEL:   %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
+//       CHECKPARALLEL:   %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+//       CHECKPARALLEL:   %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECKPARALLEL:   store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
 
 
 func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x?xf32>) -> () {
-  linalg.conv2D(%in, %filter, %out) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+  linalg.conv_2d %in, %filter, %out : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
   return
 }
 // CHECKLOOP-LABEL: @conv2d_no_symbols
@@ -1349,10 +1348,12 @@ func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out :
 //       CHECKLOOP:       scf.for %[[arg6:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
 //       CHECKLOOP:         %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
 //       CHECKLOOP:         %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
-//       CHECKLOOP:         %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
 //       CHECKLOOP:         %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
-//       CHECKLOOP:         %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+
+//       CHECKLOOP:         %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
 //       CHECKLOOP:         %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
+
+//       CHECKLOOP:         %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
 //       CHECKLOOP:         %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
 //       CHECKLOOP:         store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
 
@@ -1366,21 +1367,19 @@ func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out :
 //       CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
 //       CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?xf32>
 //       CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?xf32>
-//       CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]]) = (%[[c0]], %[[c0]]) to (%[[dim2]], %[[dim3]]) step (%[[c1]], %[[c1]]) {
-//       CHECKPARALLEL:   scf.for %[[arg5:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
-//       CHECKPARALLEL:     scf.for %[[arg6:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
-//       CHECKPARALLEL:       %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
-//       CHECKPARALLEL:       %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
-//       CHECKPARALLEL:       %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
-//       CHECKPARALLEL:       %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
-//       CHECKPARALLEL:       %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
-//       CHECKPARALLEL:       %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
-//       CHECKPARALLEL:       %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
-//       CHECKPARALLEL:       store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
+//       CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]], %[[arg6:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim2]], %[[dim3]], %[[dim0]], %[[dim1]]) step (%[[c1]], %[[c1]], %[[c1]], %[[c1]]) {
+//       CHECKPARALLEL:   %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
+//       CHECKPARALLEL:   %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
+//       CHECKPARALLEL:   %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
+//       CHECKPARALLEL:   %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
+//       CHECKPARALLEL:   %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
+//       CHECKPARALLEL:   %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+//       CHECKPARALLEL:   %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECKPARALLEL:   store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
 
 
 func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memref<?x?x?xf32>) -> () {
-  linalg.conv3D(%in, %filter, %out) : memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>
+  linalg.conv_3d %in, %filter, %out : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>)
   return
 }
 
@@ -1406,10 +1405,12 @@ func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %o
 //       CHECKLOOP:             %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
 //       CHECKLOOP:             %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
 //       CHECKLOOP:             %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
-//       CHECKLOOP:             %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
 //       CHECKLOOP:             %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
-//       CHECKLOOP:             %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+
+//       CHECKLOOP:             %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
 //       CHECKLOOP:             %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
+
+//       CHECKLOOP:             %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
 //       CHECKLOOP:             %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
 //       CHECKLOOP:             store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
 
@@ -1426,16 +1427,13 @@ func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %o
 //       CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?xf32>
 //       CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
 //       CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?xf32>
-//       CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[dim3]], %[[dim4]], %[[dim5]]) step (%[[c1]], %[[c1]], %[[c1]]) {
-//       CHECKPARALLEL:   scf.for %[[arg6:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
-//       CHECKPARALLEL:     scf.for %[[arg7:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
-//       CHECKPARALLEL:       scf.for %[[arg8:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
-//       CHECKPARALLEL:         %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
-//       CHECKPARALLEL:         %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
-//       CHECKPARALLEL:         %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
-//       CHECKPARALLEL:         %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
-//       CHECKPARALLEL:         %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
-//       CHECKPARALLEL:         %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
-//       CHECKPARALLEL:         %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
-//       CHECKPARALLEL:         %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
-//       CHECKPARALLEL:         store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]], %[[arg6:.*]], %[[arg7:.*]], %[[arg8:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim3]], %[[dim4]], %[[dim5]], %[[dim0]], %[[dim1]], %[[dim2]]) step (%[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c1]]) {
+//       CHECKPARALLEL:   %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
+//       CHECKPARALLEL:   %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
+//       CHECKPARALLEL:   %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
+//       CHECKPARALLEL:   %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:   %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:   %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:   %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+//       CHECKPARALLEL:   %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECKPARALLEL:   store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>


        


More information about the Mlir-commits mailing list