[Mlir-commits] [mlir] 92f7e81 - [mlir][Linalg] Implement padding for linalg.conv and lowering to loops.
Hanhan Wang
llvmlistbot at llvm.org
Fri Mar 13 14:37:19 PDT 2020
Author: Hanhan Wang
Date: 2020-03-13T14:35:58-07:00
New Revision: 92f7e8133ae98e1f300bad164c4099b2e609bae7
URL: https://github.com/llvm/llvm-project/commit/92f7e8133ae98e1f300bad164c4099b2e609bae7
DIFF: https://github.com/llvm/llvm-project/commit/92f7e8133ae98e1f300bad164c4099b2e609bae7.diff
LOG: [mlir][Linalg] Implement padding for linalg.conv and lowering to loops.
Summary:
To enable this, two changes are needed:
1) Add an optional attribute `padding` to linalg.conv.
2) In the loops, compute whether the accessed indices are out of bounds. If so,
use the padding value `0`; otherwise, use the value loaded from the input.
In this patch, padding is only supported when lowering directly to loops, without
other transformations (e.g., tiling, fusion).
Differential Revision: https://reviews.llvm.org/D75722
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/test/Dialect/Linalg/loops.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index 4ced0675fe95..7756a08d5cb2 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -62,7 +62,7 @@ SmallVector<AffineExpr, 4> makeAffineDimExprs(unsigned num, unsigned &startIdx,
/// Builds the indexing expressions for a ConvOp `op`. Returns the vector of
/// AffineMaps representing:
-/// `stride[i] * xs[i] + dilation[i] * zs[i]`
+/// `stride[i] * xs[i] + dilation[i] * zs[i] - pad_low[i]`
SmallVector<AffineExpr, 4> weightedConvInputIndex(ConvOp op,
ArrayRef<AffineExpr> xs,
ArrayRef<AffineExpr> zs);
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index a93486744a2d..457a8db7788f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -265,13 +265,18 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
```
}];
- // TODO(ntv) padding.
- // Following the TF source of truth above, strides and dilations are integer
- // attributes of the same rank as the number of window dimensions.
+ // Following the TF source of truth above, strides, dilations and padding are
+ // integer attributes of the same rank as the number of window dimensions.
+ // The padding attribute specifies the amount of zero padding to apply to the
+ // base area. It is an N x 2 array of (low, high) padding pairs, one pair per
+ // window dimension: the low padding is the first element of each pair and
+ // the high padding is the second. Using padding is equivalent to inserting
+ // that many zero values into the input before doing the convolution.
let arguments = (ins AnyStridedMemRef:$filter, AnyStridedMemRef:$input,
AnyStridedMemRef:$output,
OptionalAttr<I64ArrayAttr>:$strides,
- OptionalAttr<I64ArrayAttr>:$dilations);
+ OptionalAttr<I64ArrayAttr>:$dilations,
+ OptionalAttr<I64ElementsAttr>:$padding);
let extraClassDeclaration = libraryCallName # [{
// TODO(ntv) extend to support more than 1 dimensions and potentially
@@ -314,9 +319,17 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
.cast<IntegerAttr>().getValue().getSExtValue();
}
- // F(z0, ..., zN-1, q, k) * I(b, x0 + z0, ..., xN-1 + zN-1, q) ->
- // O(b, x0, ..., xN-1, k)
- // for N equal to `nWindow`.
+ int64_t getLowPad(unsigned i) {
+ assert(i < getNumWindowLoops());
+ auto pad = padding();
+ return pad ? pad->getValue<int64_t>({i, 0}) : 0;
+ }
+
+ // F(z0, ..., zN-1, q, k) *
+ // I(b, x0 + z0 - pad_low_0, ..., xN-1 + zN-1 - pad_low_N-1, q)
+ // -> O(b, x0, ..., xN-1, k)
+ // for N equal to `nWindow`. If there is no padding attribute, every
+ // pad_low_i is taken to be 0.
llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
MLIRContext *context = getContext();
auto nWin = getNumWindowLoops();
@@ -346,7 +359,9 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
// filter[z[0], ..., z[N-1], q, k]
AffineMap::get(idx, 0, concat(concat(zs, qs), ks)),
// input[b,
- // x[0]*s[0] + d[0]*z[0], ..., x[N-1]*s[N-1] + d[N-1]*z[N-1],
+ // x[0]*s[0] + d[0]*z[0] - pad_low[0],
+ // ...
+ // x[N-1]*s[N-1] + d[N-1]*z[N-1] - pad_low[N-1],
// q]
AffineMap::get(idx, 0, concat(concat(bs, ws), qs)),
// output[b, x[0], ..., x[N-1], k]
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 19cf7f55bcc4..aa340e55e8b5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -900,8 +900,12 @@ mlir::linalg::weightedConvInputIndex(ConvOp op, ArrayRef<AffineExpr> xs,
assert(xs.size() == zs.size());
SmallVector<AffineExpr, 4> res;
res.reserve(xs.size());
- for (unsigned i = 0, e = xs.size(); i < e; ++i)
- res.push_back(op.getStride(i) * xs[i] + op.getDilation(i) * zs[i]);
+ for (unsigned i = 0, e = xs.size(); i < e; ++i) {
+ // TODO(ntv): add a level of indirection to linalg.generic.
+ auto expr =
+ op.getStride(i) * xs[i] + op.getDilation(i) * zs[i] - op.getLowPad(i);
+ res.push_back(expr);
+ }
return res;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 6650c353b736..cf95212982c9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -152,6 +152,18 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
"expected linalg op with buffer semantics");
assert(consumer.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
+
+ if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
+ // TODO(ntv): add a level of indirection to linalg.generic.
+ if (convOp.padding())
+ llvm_unreachable("Unexpected conv with padding");
+ }
+ if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
+ // TODO(ntv): add a level of indirection to linalg.generic.
+ if (convOp.padding())
+ llvm_unreachable("Unexpected conv with padding");
+ }
+
auto subView = dyn_cast_or_null<SubViewOp>(
consumer.getInput(consumerIdx).getDefiningOp());
auto slice =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index 05722036f8e5..22f53b2ab8bc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -177,6 +177,51 @@ class LinalgScopedEmitter<IndexedValueType, MatmulOp> {
template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, ConvOp> {
public:
+ /// Returns the input value of convOp at `imIdx`. If any window index in
+ /// `imIdx` is out of bounds (i.e. reads padding), returns 0 instead.
+ static ValueHandle getConvOpInput(ConvOp convOp, IndexedValueType im,
+ ArrayRef<ValueHandle> imIdx) {
+ // TODO(ntv): add a level of indirection to linalg.generic.
+ if (!convOp.padding())
+ return im(imIdx);
+
+ ValueHandle zeroIndex = std_constant_index(0);
+ // Running out-of-bound flag. Seed it with `false`, the identity of ||.
+ SmallVector<ValueHandle, 8> conds = {std_constant_int(/*value=*/0, /*width=*/1)};
+ SmallVector<ValueHandle, 8> clampedImIdx;
+ for (auto iter : llvm::enumerate(imIdx)) {
+ int idx = iter.index();
+ auto dim = iter.value();
+ // Only the window dimensions (all but the first and last) are padded.
+ if (idx == 0 || idx == static_cast<int>(imIdx.size()) - 1) {
+ clampedImIdx.push_back(dim);
+ continue;
+ }
+
+ using edsc::op::operator<;
+ using edsc::op::operator>=;
+ using edsc::op::operator||;
+ conds.push_back(conds.back() || (dim < zeroIndex));
+ ValueHandle bound = std_dim(convOp.input(), idx);
+ conds.push_back(conds.back() || (dim >= bound));
+
+ // Low padding can only push the index below 0, so clamp it at 0. TODO: an
+ // index >= `bound` is not clamped, so the load may still run out of bounds.
+ auto *context = ScopedContext::getContext();
+ auto maxMap = AffineMap::get(/*dimCount=*/1, 0,
+ {getAffineDimExpr(/*position=*/0, context),
+ getAffineConstantExpr(0, context)});
+ clampedImIdx.push_back(
+ affine_max(dim.getType(), maxMap, ValueRange{dim}));
+ }
+
+ auto b = ScopedContext::getBuilder();
+ Type type = convOp.input().getType().cast<MemRefType>().getElementType();
+ ValueHandle zero = std_constant(type, b.getZeroAttr(type));
+ ValueHandle readInput = im(clampedImIdx);
+ return std_select(conds.back(), zero, readInput);
+ }
+
static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
assert(convOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
@@ -192,8 +237,10 @@ class LinalgScopedEmitter<IndexedValueType, ConvOp> {
SmallVector<ValueHandle, 8> oIdx(
makeCanonicalAffineApplies(b, loc, maps[2], allIvs));
IndexedValueType F(convOp.filter()), I(convOp.input()), O(convOp.output());
+
// Emit scalar form.
- O(oIdx) += F(fIdx) * I(imIdx);
+ ValueHandle paddedInput = getConvOpInput(convOp, I, imIdx);
+ O(oIdx) += F(fIdx) * paddedInput;
}
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
index d751b46059b4..2b0654c219a0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -193,6 +193,12 @@ SmallVector<Value, 0> mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter,
auto linalgOp = cast<linalg::LinalgOp>(op);
assert(linalgOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
+ if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
+ // TODO(ntv): add a level of indirection to linalg.generic.
+ if (convOp.padding())
+ llvm_unreachable("Unexpected conv with padding");
+ }
+
edsc::ScopedContext scope(rewriter, op->getLoc());
if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
@@ -295,6 +301,12 @@ mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) &&
"DRR failure case must be a precondition");
+ if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
+ // TODO(ntv): add a level of indirection to linalg.generic.
+ if (convOp.padding())
+ llvm_unreachable("Unexpected conv with padding");
+ }
+
LinalgOp linOp = cast<LinalgOp>(op);
assert(linOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index e7a462d1a5df..54a4290e6e36 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -160,6 +160,12 @@ LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
OperationFolder *folder) {
assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
+ if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
+ // TODO(ntv): add a level of indirection to linalg.generic.
+ if (convOp.padding())
+ llvm_unreachable("Unexpected conv with padding");
+ }
+
// 1. Promote the specified views and use them in the new op.
ScopedContext scope(b, op.getLoc());
auto promotedBufferAndViews = promoteSubViews(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 3274abd81111..cabdd7497caf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -342,6 +342,12 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
tileSizes.size() &&
"expected matching number of tile sizes and loops");
+ if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
+ // TODO(ntv): add a level of indirection to linalg.generic.
+ if (convOp.padding())
+ llvm_unreachable("Unexpected conv with padding");
+ }
+
// If permutation is empty, use the identity. Build the permutation map
// otherwise.
auto invPermutationMap = AffineMap::getMultiDimIdentityMap(
@@ -421,6 +427,12 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
if (tileSizes.empty())
return llvm::None;
+ if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
+ // TODO(ntv): add a level of indirection to linalg.generic.
+ if (convOp.padding())
+ llvm_unreachable("Unexpected conv with padding");
+ }
+
// The following uses the convention that "tiling by zero" skips tiling a
// particular dimension. This convention is significantly simpler to handle
// instead of adjusting affine maps to account for missing dimensions.
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index f0c9a8bf6e16..b2cd61791d66 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -7,6 +7,7 @@
// CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
// CHECK-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
+// CHECK-DAG: #[[clampMinMap:.*]] = affine_map<(d0) -> (d0, 0)>
// CHECK-DAG: #[[Stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
// CHECK-DAG: #[[Stride2Dilation4:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
@@ -212,6 +213,44 @@ func @conv_view4(%arg0: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, %
// CHECK: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
// CHECK: store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32, #[[strided4D]]>
+func @conv_padding(%filter: memref<?x?x?x?xf32>,
+ %input: memref<?x?x?x?xf32>,
+ %output: memref<?x?x?x?xf32>) {
+ linalg.conv(%filter, %input, %output) {strides = [1, 1],
+ dilations = [1, 1],
+ padding = dense<[[0, 1], [1, 1]]> : tensor<2x2xi64>} :
+ memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+ return
+}
+// CHECK-LABEL: func @conv_padding
+// CHECK: %{{.*}}: memref<?x?x?x?xf32>, %{{.*}}: memref<?x?x?x?xf32>, %{{.*}}: memref<?x?x?x?xf32>) {
+// CHECK: %[[ZERO:.*]] = constant 0.000000e+00 : f32
+// CHECK: %[[Z0:.*]] = dim %arg0, 0 : memref<?x?x?x?xf32>
+// CHECK: %[[Z1:.*]] = dim %arg0, 1 : memref<?x?x?x?xf32>
+// CHECK: %[[Q:.*]] = dim %arg0, 2 : memref<?x?x?x?xf32>
+// CHECK: %[[K:.*]] = dim %arg0, 3 : memref<?x?x?x?xf32>
+// CHECK: %[[B:.*]] = dim %arg1, 0 : memref<?x?x?x?xf32>
+// CHECK: %[[X0:.*]] = dim %arg2, 1 : memref<?x?x?x?xf32>
+// CHECK: %[[X1:.*]] = dim %arg2, 2 : memref<?x?x?x?xf32>
+// CHECK: loop.for %{{.*}} = %{{.*}} to %[[B]] step %{{.*}} {
+// CHECK: loop.for %{{.*}} = %{{.*}} to %[[X0]] step %{{.*}} {
+// CHECK: loop.for %{{.*}} = %{{.*}} to %[[X1]] step %{{.*}} {
+// CHECK: loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} {
+// CHECK: loop.for %{{.*}} = %{{.*}} to %[[Q]] step %{{.*}} {
+// CHECK: loop.for %{{.*}} = %{{.*}} to %[[Z0]] step %{{.*}} {
+// CHECK: loop.for %{{.*}} = %{{.*}} to %[[Z1]] step %{{.*}} {
+// CHECK: %[[SUM0:.*]] = affine.apply #{{.*}}(%{{.*}}, %{{.*}})
+// CHECK: %[[SUM1:.*]] = affine.apply #{{.*}}(%{{.*}}, %{{.*}})
+// CHECK: %[[IDX:.*]] = affine.max #[[clampMinMap]](%[[SUM0]])
+// CHECK: %[[IDY:.*]] = affine.max #[[clampMinMap]](%[[SUM1]])
+// CHECK: %{{.*}} = load %{{.*}}[%{{.*}}, %[[IDX]], %[[IDY]], %{{.*}}] : memref<?x?x?x?xf32>
+// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : f32
+// CHECK: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32>
+// CHECK: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
+// CHECK: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32>
+// CHECK: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
+// CHECK: store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32>
+
func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) {
%f0 = constant 0.0 : f32
return %f0, %f0 : f32, f32
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 5cc3ab621df5..468fad45dd90 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -222,6 +222,28 @@ func @conv_view6(%arg0: memref<?x?x?x?x?x?xf32, offset: ?, strides: [?, ?, ?, ?,
// -----
+func @conv_padding(%filter: memref<?x?x?x?xf32>,
+ %input: memref<?x?x?x?xf32>,
+ %output: memref<?x?x?x?xf32>) {
+ linalg.conv(%filter, %input, %output) {strides = [1, 1],
+ padding = dense<[[0, 1], [1, 1]]> : tensor<2x2xi64>,
+ dilations = [1, 1]} :
+ memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+ return
+}
+
+// CHECK-LABEL: func @conv_padding(
+// CHECK: linalg.conv(%{{.*}}, %{{.*}}, %{{.*}}) {
+// CHECK-SAME: dilations = [1, 1],
+// CHECK-SAME: padding = dense<[
+// CHECK-SAME: [0, 1], [1, 1]]> : tensor<2x2xi64>,
+// CHECK-SAME: strides = [1, 1]} :
+// CHECK-SAME: memref<?x?x?x?xf32>,
+// CHECK-SAME: memref<?x?x?x?xf32>,
+// CHECK-SAME: memref<?x?x?x?xf32>
+
+// -----
+
// CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
More information about the Mlir-commits
mailing list