[Mlir-commits] [mlir] 92f7e81 - [mlir][Linalg] Implement padding for linalg.conv and lowering to loops.

Hanhan Wang llvmlistbot at llvm.org
Fri Mar 13 14:37:19 PDT 2020


Author: Hanhan Wang
Date: 2020-03-13T14:35:58-07:00
New Revision: 92f7e8133ae98e1f300bad164c4099b2e609bae7

URL: https://github.com/llvm/llvm-project/commit/92f7e8133ae98e1f300bad164c4099b2e609bae7
DIFF: https://github.com/llvm/llvm-project/commit/92f7e8133ae98e1f300bad164c4099b2e609bae7.diff

LOG: [mlir][Linalg] Implement padding for linalg.conv and lowering to loops.

Summary:
To enable this, two changes are needed:
1) Add an optional attribute `padding` to linalg.conv.
2) Compute if the indices accessing is out of bound in the loops. If so, use the
padding value `0`. Otherwise, use the value derived from load.

In the patch, the padding only works for lowering without other transformations,
e.g., tiling, fusion, etc.

Differential Revision: https://reviews.llvm.org/D75722

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
    mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
    mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/test/Dialect/Linalg/loops.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index 4ced0675fe95..7756a08d5cb2 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -62,7 +62,7 @@ SmallVector<AffineExpr, 4> makeAffineDimExprs(unsigned num, unsigned &startIdx,
 
 /// Builds the indexing expressions for a ConvOp `op`. Returns the vector of
 /// AffineMaps representing:
-///   `stride[i] * xs[i] + dilation[i] * zs[i]`
+///   `stride[i] * xs[i] + dilation[i] * zs[i] - pad_low[i]`
 SmallVector<AffineExpr, 4> weightedConvInputIndex(ConvOp op,
                                                   ArrayRef<AffineExpr> xs,
                                                   ArrayRef<AffineExpr> zs);

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index a93486744a2d..457a8db7788f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -265,13 +265,18 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
     ```
   }];
 
-  // TODO(ntv) padding.
-  // Following the TF source of truth above, strides and dilations are integer
-  // attributes of the same rank as the number of window dimensions.
+  // Following the TF source of truth above, strides, dilations and padding are
+  // integer attributes of the same rank as the number of window dimensions.
+  // The padding attribute specifies the amount of zero padding to be applied to
+  // the base area, which is a n-d array of (low, high) padding. Each pair has
+  // the low padding as the first element and the high padding as the second
+  // element. Using padding is equivalent to inserting those same zero values
+  // into the input before doing the convolution.
   let arguments = (ins AnyStridedMemRef:$filter, AnyStridedMemRef:$input,
                    AnyStridedMemRef:$output,
                    OptionalAttr<I64ArrayAttr>:$strides,
-                   OptionalAttr<I64ArrayAttr>:$dilations);
+                   OptionalAttr<I64ArrayAttr>:$dilations,
+                   OptionalAttr<I64ElementsAttr>:$padding);
 
   let extraClassDeclaration = libraryCallName # [{
     // TODO(ntv) extend to support more than 1 dimensions and potentially
@@ -314,9 +319,17 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
         .cast<IntegerAttr>().getValue().getSExtValue();
     }
 
-    //   F(z0, ..., zN-1, q, k) * I(b, x0 + z0, ..., xN-1 + zN-1, q) ->
-    //     O(b, x0, ..., xN-1, k)
-    // for N equal to `nWindow`.
+    int64_t getLowPad(unsigned i) {
+      assert(i < getNumWindowLoops());
+      if (!padding().hasValue()) return 0;
+      return padding().getValue().getValue<int64_t>({i, 0});
+    }
+
+    //   F(z0, ..., zN-1, q, k) *
+    //     I(b, x0 + z0 - pad_low_0, ..., xN-1 + zN-1 - pad_low_N-1, q)
+    //   ->  O(b, x0, ..., xN-1, k)
+    // for N equal to `nWindow`. If there is no padding attirbute, it will be
+    // ignored.
     llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
       MLIRContext *context = getContext();
       auto nWin = getNumWindowLoops();
@@ -346,7 +359,9 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
         // filter[z[0], ..., z[N-1], q, k]
         AffineMap::get(idx, 0, concat(concat(zs, qs), ks)),
         // input[b,
-        //       x[0]*s[0] + d[0]*z[0], ..., x[N-1]*s[N-1] + d[N-1]*z[N-1],
+        //       x[0]*s[0] + d[0]*z[0] - pad_low[0],
+        //       ...
+        //       x[N-1]*s[N-1] + d[N-1]*z[N-1] - pad_low[N-1],
         //       q]
         AffineMap::get(idx, 0, concat(concat(bs, ws), qs)),
         // output[b, x[0], ..., x[N-1], k]

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 19cf7f55bcc4..aa340e55e8b5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -900,8 +900,12 @@ mlir::linalg::weightedConvInputIndex(ConvOp op, ArrayRef<AffineExpr> xs,
   assert(xs.size() == zs.size());
   SmallVector<AffineExpr, 4> res;
   res.reserve(xs.size());
-  for (unsigned i = 0, e = xs.size(); i < e; ++i)
-    res.push_back(op.getStride(i) * xs[i] + op.getDilation(i) * zs[i]);
+  for (unsigned i = 0, e = xs.size(); i < e; ++i) {
+    // TODO(ntv): add a level of indirection to linalg.generic.
+    auto expr =
+        op.getStride(i) * xs[i] + op.getDilation(i) * zs[i] - op.getLowPad(i);
+    res.push_back(expr);
+  }
   return res;
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 6650c353b736..cf95212982c9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -152,6 +152,18 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
          "expected linalg op with buffer semantics");
   assert(consumer.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
+
+  if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
+    // TODO(ntv): add a level of indirection to linalg.generic.
+    if (convOp.padding())
+      llvm_unreachable("Unexpected conv with padding");
+  }
+  if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
+    // TODO(ntv): add a level of indirection to linalg.generic.
+    if (convOp.padding())
+      llvm_unreachable("Unexpected conv with padding");
+  }
+
   auto subView = dyn_cast_or_null<SubViewOp>(
       consumer.getInput(consumerIdx).getDefiningOp());
   auto slice =

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index 05722036f8e5..22f53b2ab8bc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -177,6 +177,51 @@ class LinalgScopedEmitter<IndexedValueType, MatmulOp> {
 template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, ConvOp> {
 public:
+  /// Returns the input value of convOp. If the indices in `imIdx` is out of
+  /// boundrary, returns 0 instead.
+  static ValueHandle getConvOpInput(ConvOp convOp, IndexedValueType im,
+                                    ArrayRef<ValueHandle> imIdx) {
+    // TODO(ntv): add a level of indirection to linalg.generic.
+    if (!convOp.padding())
+      return im(imIdx);
+
+    ValueHandle zeroIndex = std_constant_index(0);
+    SmallVector<ValueHandle, 8> conds = {
+        std_constant_int(/*value=*/1, /*width=*/1)};
+    SmallVector<ValueHandle, 8> clampedImIdx;
+    for (auto iter : llvm::enumerate(imIdx)) {
+      int idx = iter.index();
+      auto dim = iter.value();
+      // Only need to iterate over the window dimensions.
+      if (idx == 0 || idx == static_cast<int>(imIdx.size()) - 1) {
+        clampedImIdx.push_back(dim);
+        continue;
+      }
+
+      using edsc::op::operator<;
+      using edsc::op::operator>=;
+      using edsc::op::operator||;
+      conds.push_back(conds.back() || (dim < zeroIndex));
+      ValueHandle bound = std_dim(convOp.input(), idx);
+      conds.push_back(conds.back() || (dim >= bound));
+
+      // When padding is involed, the indices will only be shifted to negative,
+      // so having a max op is enough.
+      auto *context = ScopedContext::getContext();
+      auto maxMap = AffineMap::get(/*dimCount=*/1, 0,
+                                   {getAffineDimExpr(/*position=*/0, context),
+                                    getAffineConstantExpr(0, context)});
+      clampedImIdx.push_back(
+          affine_max(dim.getType(), maxMap, ValueRange{dim}));
+    }
+
+    auto b = ScopedContext::getBuilder();
+    Type type = convOp.input().getType().cast<MemRefType>().getElementType();
+    ValueHandle zero = std_constant(type, b.getZeroAttr(type));
+    ValueHandle readInput = im(clampedImIdx);
+    return std_select(conds.back(), zero, readInput);
+  }
+
   static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
     assert(convOp.hasBufferSemantics() &&
            "expected linalg op with buffer semantics");
@@ -192,8 +237,10 @@ class LinalgScopedEmitter<IndexedValueType, ConvOp> {
     SmallVector<ValueHandle, 8> oIdx(
         makeCanonicalAffineApplies(b, loc, maps[2], allIvs));
     IndexedValueType F(convOp.filter()), I(convOp.input()), O(convOp.output());
+
     // Emit scalar form.
-    O(oIdx) += F(fIdx) * I(imIdx);
+    ValueHandle paddedInput = getConvOpInput(convOp, I, imIdx);
+    O(oIdx) += F(fIdx) * paddedInput;
   }
 };
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
index d751b46059b4..2b0654c219a0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -193,6 +193,12 @@ SmallVector<Value, 0> mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter,
   auto linalgOp = cast<linalg::LinalgOp>(op);
   assert(linalgOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
+  if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
+    // TODO(ntv): add a level of indirection to linalg.generic.
+    if (convOp.padding())
+      llvm_unreachable("Unexpected conv with padding");
+  }
+
   edsc::ScopedContext scope(rewriter, op->getLoc());
 
   if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
@@ -295,6 +301,12 @@ mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
   assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) &&
          "DRR failure case must be a precondition");
 
+  if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
+    // TODO(ntv): add a level of indirection to linalg.generic.
+    if (convOp.padding())
+      llvm_unreachable("Unexpected conv with padding");
+  }
+
   LinalgOp linOp = cast<LinalgOp>(op);
   assert(linOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index e7a462d1a5df..54a4290e6e36 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -160,6 +160,12 @@ LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
                                               OperationFolder *folder) {
   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
 
+  if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
+    // TODO(ntv): add a level of indirection to linalg.generic.
+    if (convOp.padding())
+      llvm_unreachable("Unexpected conv with padding");
+  }
+
   // 1. Promote the specified views and use them in the new op.
   ScopedContext scope(b, op.getLoc());
   auto promotedBufferAndViews = promoteSubViews(

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 3274abd81111..cabdd7497caf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -342,6 +342,12 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
              tileSizes.size() &&
          "expected matching number of tile sizes and loops");
 
+  if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
+    // TODO(ntv): add a level of indirection to linalg.generic.
+    if (convOp.padding())
+      llvm_unreachable("Unexpected conv with padding");
+  }
+
   // If permutation is empty, use the identity. Build the permutation map
   // otherwise.
   auto invPermutationMap = AffineMap::getMultiDimIdentityMap(
@@ -421,6 +427,12 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
   if (tileSizes.empty())
     return llvm::None;
 
+  if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
+    // TODO(ntv): add a level of indirection to linalg.generic.
+    if (convOp.padding())
+      llvm_unreachable("Unexpected conv with padding");
+  }
+
   // The following uses the convention that "tiling by zero" skips tiling a
   // particular dimension. This convention is significantly simpler to handle
   // instead of adjusting affine maps to account for missing dimensions.

diff  --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index f0c9a8bf6e16..b2cd61791d66 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -7,6 +7,7 @@
 // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
 // CHECK-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
+// CHECK-DAG: #[[clampMinMap:.*]] = affine_map<(d0) -> (d0, 0)>
 
 // CHECK-DAG: #[[Stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
 // CHECK-DAG: #[[Stride2Dilation4:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
@@ -212,6 +213,44 @@ func @conv_view4(%arg0: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, %
 //       CHECK:                 %{{.*}} = addf %{{.*}}, %{{.*}} : f32
 //       CHECK:                 store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32, #[[strided4D]]>
 
+func @conv_padding(%arg0: memref<?x?x?x?xf32>,
+                   %arg1: memref<?x?x?x?xf32>,
+                   %arg2: memref<?x?x?x?xf32>) {
+  linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1],
+                                    padding = dense<[[0, 1], [1, 1]]> : tensor<2x2xi64>,
+                                    strides = [1, 1]} :
+    memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+  return
+}
+// CHECK-LABEL: func @conv_padding
+//       CHECK: %{{.*}}: memref<?x?x?x?xf32>, %{{.*}}: memref<?x?x?x?xf32>, %{{.*}}: memref<?x?x?x?xf32>) {
+//       CHECK:   %[[ZERO:.*]] = constant 0.000000e+00 : f32
+//       CHECK:   %[[Z0:.*]] = dim %arg0, 0 : memref<?x?x?x?xf32>
+//       CHECK:   %[[Z1:.*]] = dim %arg0, 1 : memref<?x?x?x?xf32>
+//       CHECK:   %[[Q:.*]] =  dim %arg0, 2 : memref<?x?x?x?xf32>
+//       CHECK:   %[[K:.*]] =  dim %arg0, 3 : memref<?x?x?x?xf32>
+//       CHECK:   %[[B:.*]] =  dim %arg1, 0 : memref<?x?x?x?xf32>
+//       CHECK:   %[[X0:.*]] = dim %arg2, 1 : memref<?x?x?x?xf32>
+//       CHECK:   %[[X1:.*]] = dim %arg2, 2 : memref<?x?x?x?xf32>
+//       CHECK:   loop.for %{{.*}} = %{{.*}} to %[[B]] step %{{.*}} {
+//       CHECK:     loop.for %{{.*}} = %{{.*}} to %[[X0]] step %{{.*}} {
+//       CHECK:       loop.for %{{.*}} = %{{.*}} to %[[X1]] step %{{.*}} {
+//       CHECK:         loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} {
+//       CHECK:           loop.for %{{.*}} = %{{.*}} to %[[Q]] step %{{.*}} {
+//       CHECK:             loop.for %{{.*}} = %{{.*}} to %[[Z0]] step %{{.*}} {
+//       CHECK:               loop.for %{{.*}} = %{{.*}} to %[[Z1]] step %{{.*}} {
+//       CHECK:                 %[[SUM0:.*]] = affine.apply #{{.*}}(%{{.*}}, %{{.*}})
+//       CHECK:                 %[[SUM1:.*]] = affine.apply #{{.*}}(%{{.*}}, %{{.*}})
+//       CHECK:                 %[[IDX:.*]] = affine.max #[[clampMinMap]](%[[SUM0]])
+//       CHECK:                 %[[IDY:.*]] = affine.max #[[clampMinMap]](%[[SUM1]])
+//       CHECK:                 %{{.*}} = load %{{.*}}[%{{.*}}, %[[IDX]], %[[IDY]], %{{.*}}] : memref<?x?x?x?xf32>
+//       CHECK:                 %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : f32
+//       CHECK:                 %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32>
+//       CHECK:                 %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
+//       CHECK:                 %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32>
+//       CHECK:                 %{{.*}} = addf %{{.*}}, %{{.*}} : f32
+//       CHECK:                 store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32>
+
 func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) {
   %f0 = constant 0.0 : f32
   return %f0, %f0 : f32, f32

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 5cc3ab621df5..468fad45dd90 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -222,6 +222,28 @@ func @conv_view6(%arg0: memref<?x?x?x?x?x?xf32, offset: ?, strides: [?, ?, ?, ?,
 
 // -----
 
+func @conv_padding(%arg0: memref<?x?x?x?xf32>,
+                   %arg1: memref<?x?x?x?xf32>,
+                   %arg2: memref<?x?x?x?xf32>) {
+  linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1],
+                                    padding = dense<[[0, 1], [1, 1]]> : tensor<2x2xi64>,
+                                    strides = [1, 1]} :
+    memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+  return
+}
+
+// CHECK-LABEL: func @conv_padding(
+//       CHECK:   linalg.conv(%{{.*}}, %{{.*}}, %{{.*}}) {
+//  CHECK-SAME:     dilations = [1, 1],
+//  CHECK-SAME:     padding = dense<[
+//  CHECK-SAME:                      [0, 1], [1, 1]]> : tensor<2x2xi64>,
+//  CHECK-SAME:     strides = [1, 1]} :
+//  CHECK-SAME:     memref<?x?x?x?xf32>,
+//  CHECK-SAME:     memref<?x?x?x?xf32>,
+//  CHECK-SAME:     memref<?x?x?x?xf32>
+
+// -----
+
 // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
 


        


More information about the Mlir-commits mailing list