[Mlir-commits] [mlir] 1aaf8aa - [mlir][Linalg] Conv1D, Conv2D and Conv3D added as named ops

Alex Zinenko llvmlistbot at llvm.org
Wed Jul 29 07:40:04 PDT 2020


Author: Jakub Lichman
Date: 2020-07-29T16:39:56+02:00
New Revision: 1aaf8aa53d694309087b322861038130490bdd5e

URL: https://github.com/llvm/llvm-project/commit/1aaf8aa53d694309087b322861038130490bdd5e
DIFF: https://github.com/llvm/llvm-project/commit/1aaf8aa53d694309087b322861038130490bdd5e.diff

LOG: [mlir][Linalg] Conv1D, Conv2D and Conv3D added as named ops

This commit is part of a larger effort to add full end-to-end
support for convolutions in MLIR. Having a named conv op per
rank, rather than a single generic ConvOp, enables better
rank-specific optimizations, reflects the memory layout of the
input/kernel buffers more directly, and simplifies the code. We
expect plain linalg.conv to be progressively retired.
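
For reference, a minimal usage sketch of one of the new named ops
(the function name and shapes are illustrative; the input must be
right-padded as documented on the op):

```mlir
func @example(%in : memref<?x?xf32>, %filter : memref<?x?xf32>,
              %out : memref<?x?xf32>) {
  linalg.conv2D(%in, %filter, %out)
    : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
  return
}
```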

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D83879

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/loops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index 21bff4185abf..75e6599bf9fe 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -85,6 +85,14 @@ AffineMap extractOrIdentityMap(Optional<AffineMap> maybeMap, unsigned rank,
 SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
                                   ArrayRef<AffineExpr> b);
 
+/// Generates indexing maps for convolution with the following structure:
+/// input:   (m_1, ..., m_r, n_1, ..., n_r) -> (m_1 + n_1, ..., m_r + n_r)
+/// kernel:  (m_1, ..., m_r, n_1, ..., n_r) -> (n_1, ..., n_r)
+/// output:  (m_1, ..., m_r, n_1, ..., n_r) -> (m_1, ..., m_r)
+/// where r is the rank of the input, kernel and output
+llvm::Optional<SmallVector<AffineMap, 8>>
+createConvNDIndexingMaps(MLIRContext *context, unsigned rank);
+
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.h.inc"
 
 #define GET_OP_CLASSES

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 1e3321af981e..84ae8e440bee 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -180,6 +180,131 @@ def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
   let hasFolder = 1;
 }
 
+class ConvOpBase<string mnemonic, int N>
+  : LinalgStructured_Op<mnemonic, [NInputs<2>, NOutputs<1>]> {
+  let description = [{
+    Base operation for any N-D Convolution implemented as a linalg.generic op.
+
+    Usage:
+
+    ```mlir
+    linalg.conv<N>D(%in, %filter, %out) : memref<(?x)+f32>,
+                                          memref<(?x)+f32>,
+                                          memref<(?x)+f32>
+    ```
+
+    where    %in:     input array
+             %filter: kernel or filter that will be applied on the input array
+             %out:    output array
+
+    and the rank of the operands is *N*.
+
+    Every child convolution is expressed as:
+
+    ```mlir
+    #conv_trait = {
+      args_in = 2,
+      args_out = 1,
+      indexing_maps = #conv_accesses,
+      library_call  = "linalg_conv",
+      iterator_types = [("parallel", "parallel")+], // `2 * rank` iterators
+    }
+
+    linalg.generic #conv_trait %in, %filter, %out {
+      ^bb0(%a: f32, %b: f32, %c: f32) :
+        %d = mulf %a, %b : f32
+        %e = addf %c, %d : f32
+        linalg.yield %e : f32
+    } : memref<(?x)+f32>,
+        memref<(?x)+f32>,
+        memref<(?x)+f32>
+    ```
+
+    where #conv_accesses depends on the rank of the operands; the
+    concrete maps are listed in the documentation of each N-D case.
+    Please note that the input array is expected to be right-padded,
+    i.e. along each dimension the input size must be greater than or
+    equal to the output size + the kernel size - 1 (e.g. an output of
+    size 3 and a kernel of size 2 require an input of size at least 4).
+    If the input is not padded, the behavior of the op is undefined.
+  }];
+
+  let arguments = (ins AnyStridedMemRefOfRank<N>,
+                       AnyStridedMemRefOfRank<N>,
+                       AnyStridedMemRefOfRank<N>);
+
+  let extraClassDeclaration = libraryCallName # [{
+    llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
+      // There are 2 loops for each dimension of the convolution: the first
+      // `rank` loops iterate over the output and the remaining `rank` loops
+      // over the kernel. Since the ranks of all 3 operands must match, it
+      // does not matter which operand is queried for the rank. Loops
+      // iterating over the output can be parallelized and are thus marked
+      // "parallel", while loops iterating over the kernel accumulate the
+      // products and are therefore marked "reduction".
+      unsigned rank = getInputShapedType(0).getRank();
+      SmallVector<StringRef, 8> parallel(rank, getParallelIteratorTypeName());
+      SmallVector<StringRef, 8> reduction(rank, getReductionIteratorTypeName());
+      parallel.insert(parallel.end(), reduction.begin(), reduction.end());
+      return parallel;
+    }
+
+    // Generates indexing maps with the following structure:
+    // input:   (m_1, ..., m_r, n_1, ..., n_r) -> (m_1 + n_1, ..., m_r + n_r)
+    // kernel:  (m_1, ..., m_r, n_1, ..., n_r) -> (n_1, ..., n_r)
+    // output:  (m_1, ..., m_r, n_1, ..., n_r) -> (m_1, ..., m_r)
+    // where r is the rank of the input, kernel and output
+    llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
+      MLIRContext *context = getContext();
+      unsigned rank = getInputShapedType(0).getRank();
+      return createConvNDIndexingMaps(context, rank);
+    }
+  }];
+
+  let hasFolder = 1;
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def Conv1DOp : ConvOpBase<"conv1D", 1> {
+  let description = [{
+    *1D* convolution which uses the following affine maps to access its operands:
+
+    ```mlir
+    #conv_accesses = [
+      affine_map<(m, n) -> (m + n)>, // in
+      affine_map<(m, n) -> (n)>, // kernel
+      affine_map<(m, n) -> (m)> // out
+    ]
+    ```
+  }];
+}
+
+def Conv2DOp : ConvOpBase<"conv2D", 2> {
+  let description = [{
+    *2D* convolution which uses the following affine maps to access its operands:
+
+    ```mlir
+    #conv_accesses = [
+      affine_map<(m1, m2, n1, n2) -> (m1 + n1, m2 + n2)>, // in
+      affine_map<(m1, m2, n1, n2) -> (n1, n2)>, // kernel
+      affine_map<(m1, m2, n1, n2) -> (m1, m2)> // out
+    ]
+    ```
+  }];
+}
+
+def Conv3DOp : ConvOpBase<"conv3D", 3> {
+  let description = [{
+    *3D* convolution which uses the following affine maps to access its operands:
+
+    ```mlir
+    #conv_accesses = [
+      affine_map<(m1, m2, m3, n1, n2, n3) -> (m1 + n1, m2 + n2, m3 + n3)>, // in
+      affine_map<(m1, m2, m3, n1, n2, n3) -> (n1, n2, n3)>, // kernel
+      affine_map<(m1, m2, m3, n1, n2, n3) -> (m1, m2, m3)> // out
+    ]
+    ```
+  }];
+}
+
 /// A base class for pooling operations such as conv. The arguments must contain
 /// optional arguments `strides`, `dilations` and `padding` with following type:
 ///   OptionalAttr<I64ArrayAttr>:$strides

diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index 8a54c93d7685..921445bd03b1 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -235,7 +235,10 @@ void mlir::populateLinalgToStandardConversionPatterns(
       LinalgOpConversion<PoolingMaxOp>,
       LinalgOpConversion<PoolingMinOp>,
       LinalgOpConversion<PoolingSumOp>,
-      LinalgOpConversion<CopyOp>, 
+      LinalgOpConversion<CopyOp>,
+      LinalgOpConversion<Conv1DOp>,
+      LinalgOpConversion<Conv2DOp>,
+      LinalgOpConversion<Conv3DOp>,
       LinalgOpConversion<FillOp>,
       LinalgOpConversion<GenericOp>,
       LinalgOpConversion<IndexedGenericOp>>(ctx);

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 192179b3ff50..e67adf8c2042 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -986,6 +986,17 @@ static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op,
   return success();
 }
 
+template <typename ConvNDOp>
+static LogicalResult verify(ConvNDOp op) {
+  auto outputType = op.getOutputShapedType(0).getElementType();
+  auto inputType = op.getInputShapedType(0).getElementType();
+  auto kernelType = op.getInputShapedType(1).getElementType();
+  if (outputType != inputType || inputType != kernelType)
+    return op.emitOpError("expected all element types of operands to match");
+
+  return success();
+}
+
 static LogicalResult verify(ConvOp op) {
   auto oType = op.output().getType().cast<MemRefType>();
   auto fType = op.filter().getType().cast<MemRefType>();
@@ -1096,6 +1107,27 @@ mlir::linalg::weightedPoolingInputIndex(PoolingOp op,
   return res;
 }
 
+llvm::Optional<SmallVector<AffineMap, 8>>
+mlir::linalg::createConvNDIndexingMaps(MLIRContext *context, unsigned rank) {
+  unsigned numDims = rank * 2, idx = 0;
+
+  SmallVector<AffineExpr, 8> dims, in, kernel, out;
+  dims = makeAffineDimExprs(numDims, idx, context);
+  in.reserve(rank);
+  kernel.reserve(rank);
+  out.reserve(rank);
+
+  for (unsigned i = 0; i < rank; i++) {
+    in.push_back(dims[i] + dims[rank + i]);
+    kernel.push_back(dims[rank + i]);
+    out.push_back(dims[i]);
+  }
+
+  return SmallVector<AffineMap, 8>{AffineMap::get(numDims, 0, in, context),
+                                   AffineMap::get(numDims, 0, kernel, context),
+                                   AffineMap::get(numDims, 0, out, context)};
+}
+
 #define INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(OP_TYPE)                      \
   template SmallVector<AffineExpr, 4>                                          \
   mlir::linalg::weightedPoolingInputIndex<OP_TYPE>(                            \
@@ -1177,6 +1209,18 @@ LogicalResult FillOp::fold(ArrayRef<Attribute>,
                            SmallVectorImpl<OpFoldResult> &) {
   return foldMemRefCast(*this);
 }
+LogicalResult Conv1DOp::fold(ArrayRef<Attribute>,
+                             SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult Conv2DOp::fold(ArrayRef<Attribute>,
+                             SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult Conv3DOp::fold(ArrayRef<Attribute>,
+                             SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
 LogicalResult GenericOp::fold(ArrayRef<Attribute>,
                               SmallVectorImpl<OpFoldResult> &) {
   return foldMemRefCast(*this);

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 32e50cb597d7..db29835e2caa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -295,6 +295,61 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
   nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value();
 }
 
+/// The following functions emit the scalar part of the N-D convolution op.
+/// An N-D convolution has 2N loops:
+///   1..N:    iterate over the output array *O* with iterators *m1, ..., mN*.
+///   N+1..2N: iterate over the kernel *K* with iterators *n1, ..., nN*.
+///
+/// The scalar part accumulates products of input array *I* values with kernel
+/// *K* values. The accumulation expression therefore looks like:
+///   O[m1, ..., mN] += I[m1 + n1, ..., mN + nN] * K[n1, ..., nN].
+/// Note that the input array has to be padded in order to prevent
+/// out of bounds accesses.
+template <typename IndexedValueType>
+void emitScalarImplementation(ArrayRef<Value> allIvs, Conv1DOp convOp) {
+  assert(convOp.hasBufferSemantics() &&
+         "expected linalg op with buffer semantics");
+  assert(allIvs.size() == 2);
+  Value m1(allIvs[0]);
+  Value n1(allIvs[1]);
+  IndexedValueType I(convOp.getInput(0)), K(convOp.getInput(1)),
+      O(convOp.getOutputBuffer(0));
+  // Emit scalar form for the 1D conv case.
+  Value i1 = m1 + n1;
+  O(m1) = O(m1) + I(i1) * K(n1);
+}
+
+template <typename IndexedValueType>
+void emitScalarImplementation(ArrayRef<Value> allIvs, Conv2DOp convOp) {
+  assert(convOp.hasBufferSemantics() &&
+         "expected linalg op with buffer semantics");
+  assert(allIvs.size() == 4);
+  Value m1(allIvs[0]), m2(allIvs[1]);
+  Value n1(allIvs[2]), n2(allIvs[3]);
+  IndexedValueType I(convOp.getInput(0)), K(convOp.getInput(1)),
+      O(convOp.getOutputBuffer(0));
+  // Emit scalar form for the 2D conv case.
+  Value i1 = m1 + n1;
+  Value i2 = m2 + n2;
+  O(m1, m2) = O(m1, m2) + I(i1, i2) * K(n1, n2);
+}
+
+template <typename IndexedValueType>
+void emitScalarImplementation(ArrayRef<Value> allIvs, Conv3DOp convOp) {
+  assert(convOp.hasBufferSemantics() &&
+         "expected linalg op with buffer semantics");
+  assert(allIvs.size() == 6);
+  Value m1(allIvs[0]), m2(allIvs[1]), m3(allIvs[2]);
+  Value n1(allIvs[3]), n2(allIvs[4]), n3(allIvs[5]);
+  IndexedValueType I(convOp.getInput(0)), K(convOp.getInput(1)),
+      O(convOp.getOutputBuffer(0));
+  // Emit scalar form for the 3D conv case.
+  Value i1 = m1 + n1;
+  Value i2 = m2 + n2;
+  Value i3 = m3 + n3;
+  O(m1, m2, m3) = O(m1, m2, m3) + I(i1, i2, i3) * K(n1, n2, n3);
+}
+
 template <typename IndexedValueType>
 Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
                      MutableArrayRef<Value> imIdx) {

diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index ca59ecd387ec..a5a6e9bee34f 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -507,3 +507,11 @@ func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?x
   linalg.batch_matmul %a3, %b3, %c3 : (memref<?x?x?xf32>, memref<?x?xf32>, memref<?x?x?xf32>) -> ()
   return
 }
+
+// -----
+
+func @conv_type_mismatch(%in: memref<?xi32>, %filter: memref<?xf32>, %out: memref<?xf32>) {
+  // expected-error @+1 {{expected all element types of operands to match}}
+  linalg.conv1D(%in, %filter, %out) : memref<?xi32>, memref<?xf32>, memref<?xf32>
+  return
+}

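For contrast, the same op with matching element types passes
verification (a sketch, not one of this commit's tests):

```mlir
func @conv_type_match(%in: memref<?xf32>, %filter: memref<?xf32>, %out: memref<?xf32>) {
  linalg.conv1D(%in, %filter, %out) : memref<?xf32>, memref<?xf32>, memref<?xf32>
  return
}
```
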
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index b01beb7e8f17..ee63d59ca8c4 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1286,3 +1286,156 @@ func @conv4d(%in : memref<?x?x?x?xf32>, %filter : memref<?x?x?x?xf32>, %out :  m
 //       CHECKPARALLEL:   %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
 //       CHECKPARALLEL:   %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
 //       CHECKPARALLEL:   store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref<?x?x?x?xf32>
+
+func @conv1d_no_symbols(%in : memref<?xf32>, %filter : memref<?xf32>, %out : memref<?xf32>) -> () {
+  linalg.conv1D(%in, %filter, %out) : memref<?xf32>, memref<?xf32>, memref<?xf32>
+  return
+}
+
+// CHECKLOOP-LABEL: @conv1d_no_symbols
+//  CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?xf32>
+//  CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?xf32>
+//  CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?xf32>
+//       CHECKLOOP: %[[c0:.*]] = constant 0 : index
+//       CHECKLOOP: %[[c1:.*]] = constant 1 : index
+//       CHECKLOOP: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
+//       CHECKLOOP: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref<?xf32>
+//       CHECKLOOP: scf.for %[[b:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
+//       CHECKLOOP:   scf.for %[[m:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
+//       CHECKLOOP:     %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
+//       CHECKLOOP:     %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
+//       CHECKLOOP:     %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
+//       CHECKLOOP:     %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+//       CHECKLOOP:     %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
+//       CHECKLOOP:     %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECKLOOP:     store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
+
+// CHECKPARALLEL-LABEL: @conv1d_no_symbols
+//  CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?xf32>
+//  CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?xf32>
+//  CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?xf32>
+//       CHECKPARALLEL: %[[c0:.*]] = constant 0 : index
+//       CHECKPARALLEL: %[[c1:.*]] = constant 1 : index
+//       CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
+//       CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref<?xf32>
+//       CHECKPARALLEL: scf.parallel (%[[b:.*]]) = (%[[c0]]) to (%[[dim1]]) step (%[[c1]]) {
+//       CHECKPARALLEL:   scf.for %[[m:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
+//       CHECKPARALLEL:     %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
+//       CHECKPARALLEL:     %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
+//       CHECKPARALLEL:     %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
+//       CHECKPARALLEL:     %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+//       CHECKPARALLEL:     %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
+//       CHECKPARALLEL:     %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECKPARALLEL:     store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
+
+
+func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x?xf32>) -> () {
+  linalg.conv2D(%in, %filter, %out) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+  return
+}
+// CHECKLOOP-LABEL: @conv2d_no_symbols
+//  CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//       CHECKLOOP: %[[c0:.*]] = constant 0 : index
+//       CHECKLOOP: %[[c1:.*]] = constant 1 : index
+//       CHECKLOOP: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?xf32>
+//       CHECKLOOP: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
+//       CHECKLOOP: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?xf32>
+//       CHECKLOOP: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?xf32>
+//       CHECKLOOP: scf.for %[[arg3:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
+//       CHECKLOOP:   scf.for %[[arg4:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
+//       CHECKLOOP:     scf.for %[[arg5:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
+//       CHECKLOOP:       scf.for %[[arg6:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
+//       CHECKLOOP:         %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
+//       CHECKLOOP:         %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
+//       CHECKLOOP:         %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
+//       CHECKLOOP:         %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
+//       CHECKLOOP:         %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+//       CHECKLOOP:         %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
+//       CHECKLOOP:         %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECKLOOP:         store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
+
+// CHECKPARALLEL-LABEL: @conv2d_no_symbols
+//  CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//       CHECKPARALLEL: %[[c0:.*]] = constant 0 : index
+//       CHECKPARALLEL: %[[c1:.*]] = constant 1 : index
+//       CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?xf32>
+//       CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
+//       CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?xf32>
+//       CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?xf32>
+//       CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]]) = (%[[c0]], %[[c0]]) to (%[[dim2]], %[[dim3]]) step (%[[c1]], %[[c1]]) {
+//       CHECKPARALLEL:   scf.for %[[arg5:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
+//       CHECKPARALLEL:     scf.for %[[arg6:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
+//       CHECKPARALLEL:       %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
+//       CHECKPARALLEL:       %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
+//       CHECKPARALLEL:       %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
+//       CHECKPARALLEL:       %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
+//       CHECKPARALLEL:       %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+//       CHECKPARALLEL:       %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
+//       CHECKPARALLEL:       %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECKPARALLEL:       store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
+
+
+func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memref<?x?x?xf32>) -> () {
+  linalg.conv3D(%in, %filter, %out) : memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>
+  return
+}
+
+// CHECKLOOP-LABEL: @conv3d_no_symbols
+//  CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//       CHECKLOOP: %[[c2:.*]] = constant 2 : index
+//       CHECKLOOP: %[[c0:.*]] = constant 0 : index
+//       CHECKLOOP: %[[c1:.*]] = constant 1 : index
+//       CHECKLOOP: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?xf32>
+//       CHECKLOOP: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
+//       CHECKLOOP: %[[dim2:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
+//       CHECKLOOP: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?xf32>
+//       CHECKLOOP: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
+//       CHECKLOOP: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?xf32>
+//       CHECKLOOP: scf.for %[[arg3:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
+//       CHECKLOOP:   scf.for %[[arg4:.*]] = %[[c0]] to %[[dim4]] step %[[c1]] {
+//       CHECKLOOP:     scf.for %[[arg5:.*]] = %[[c0]] to %[[dim5]] step %[[c1]] {
+//       CHECKLOOP:       scf.for %[[arg6:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
+//       CHECKLOOP:         scf.for %[[arg7:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
+//       CHECKLOOP:           scf.for %[[arg8:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
+//       CHECKLOOP:             %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
+//       CHECKLOOP:             %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
+//       CHECKLOOP:             %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
+//       CHECKLOOP:             %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
+//       CHECKLOOP:             %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
+//       CHECKLOOP:             %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+//       CHECKLOOP:             %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
+//       CHECKLOOP:             %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECKLOOP:             store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
+
+// CHECKPARALLEL-LABEL: @conv3d_no_symbols
+//  CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[c2:.*]] = constant 2 : index
+//       CHECKPARALLEL: %[[c0:.*]] = constant 0 : index
+//       CHECKPARALLEL: %[[c1:.*]] = constant 1 : index
+//       CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?xf32>
+//       CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[dim3]], %[[dim4]], %[[dim5]]) step (%[[c1]], %[[c1]], %[[c1]]) {
+//       CHECKPARALLEL:   scf.for %[[arg6:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
+//       CHECKPARALLEL:     scf.for %[[arg7:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
+//       CHECKPARALLEL:       scf.for %[[arg8:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
+//       CHECKPARALLEL:         %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
+//       CHECKPARALLEL:         %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
+//       CHECKPARALLEL:         %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
+//       CHECKPARALLEL:         %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:         %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:         %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
+//       CHECKPARALLEL:         %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:         %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECKPARALLEL:         store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>

