[Mlir-commits] [mlir] 92e751d - [mlir][linalg] Add NHWC + FHWC Img2Col (#68708)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Oct 13 02:20:22 PDT 2023
Author: Jack Frankland
Date: 2023-10-13T10:20:18+01:00
New Revision: 92e751d426dbc17607bc8f552325fc659f4d0f66
URL: https://github.com/llvm/llvm-project/commit/92e751d426dbc17607bc8f552325fc659f4d0f66
DIFF: https://github.com/llvm/llvm-project/commit/92e751d426dbc17607bc8f552325fc659f4d0f66.diff
LOG: [mlir][linalg] Add NHWC + FHWC Img2Col (#68708)
Adds the Img2Col transformation for the fhwc channel ordering in a
Conv2D. Because of how the channel ordering affects the matrix
dimensions in the flattened filter, this results in a slightly different
implementation of the actual "matrix multiplication". Instead of doing a
regular row-column dot-product, this arrangement requires a row-row
dot-product; otherwise the filter matrix would first need to be transposed.
Adds a lit test to the transform dialect to check the semantics of the
optimization are correct.
Signed-off-by: Jack Frankland <jack.frankland at arm.com>
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 07a192f7b8606d3..3597209d7f90c25 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1175,6 +1175,14 @@ FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp);
+/// Same as the above but for Fhwc channel orderings in the filter. Here the
+/// "matmul" is a row-wise (row-row) dot-product rather than a row-column one,
+/// which avoids transposing the collapsed filter matrix. On success, returns
+/// the op producing the im2col tensor and the op reshaping the result back to
+/// the original output shape.
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp);
+
/// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp except there is no
/// reduction among the input channels so each convolution can be a
/// matrix-vector product and by transposing both input filter so channels are
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9ce780d3d249cfb..8508507871d0c6c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3118,6 +3118,9 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
.Case([&](linalg::Conv2DNhwcHwcfOp op) {
return rewriteInIm2Col(rewriter, op);
})
+ .Case([&](linalg::Conv2DNhwcFhwcOp op) {
+ return rewriteInIm2Col(rewriter, op);
+ })
.Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
return rewriteInIm2Col(rewriter, op);
})
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 275e78aaa73dde6..e7629d79494bd47 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -494,6 +494,141 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
reshapedResult.getOperation());
}
+/// Rewrites a static-shaped, unit-dilation linalg.conv_2d_nhwc_fhwc as an
+/// im2col gather followed by a contraction, without transposing the filter:
+///   1. collapse the filter to (oc, fh*fw*ic) and the output to (n, oh*ow, oc);
+///   2. materialize the im2col tensor (n, oh*ow, fh*fw*ic) with a
+///      linalg.generic that extracts input[n, sh*oh + fh, sw*ow + fw, ic];
+///   3. contract the im2col tensor and the filter over their shared last (k)
+///      dimension with a second linalg.generic ("row-wise" dot products);
+///   4. expand the result back to the original output shape.
+/// Returns the op producing the im2col tensor and the final expand_shape op, or
+/// a match failure if the filter, input or output shape is not static or if
+/// any dilation differs from one.
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+
+  auto inputType = cast<ShapedType>(input.getType());
+  auto filterType = cast<ShapedType>(filter.getType());
+  auto outputType = cast<ShapedType>(output.getType());
+
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+
+  // Every size used below (n, oh, ow, oc) is read from the output shape, so a
+  // dynamic output dimension would leak the dynamic-size sentinel into the
+  // products and into the static-shape builders. Bail out gracefully instead.
+  if (!outputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the output");
+
+  // TODO: Support dilation.
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+
+  MLIRContext *context = rewriter.getContext();
+  Location loc = convOp.getLoc();
+
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+
+  int64_t n = outputShape[0];
+  int64_t oh = outputShape[1];
+  int64_t ow = outputShape[2];
+  int64_t oc = outputShape[3];
+  int64_t fh = filterShape[1];
+  int64_t fw = filterShape[2];
+  int64_t ic = filterShape[3];
+
+  // Strides along the height and width dimensions; read once up front rather
+  // than inside the body builder.
+  int64_t strideH = convOp.getStrides().getValues<int64_t>()[0];
+  int64_t strideW = convOp.getStrides().getValues<int64_t>()[1];
+
+  // Reshape output and filter to the result and RHS of a "row-wise" matrix
+  // multiplication.
+  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
+  auto reshapedFilterType =
+      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
+  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedFilterType, filter, filterReassocIndices);
+
+  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
+  RankedTensorType reshapedOutputType =
+      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
+  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedOutputType, output, outputReassocIndices);
+
+  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
+  Value colTensor = rewriter.create<tensor::EmptyOp>(
+      loc, colTensorShape, inputType.getElementType());
+
+  // Convert the input to a (BMK) column tensor.
+  auto nloops = colTensorShape.size();
+
+  auto parallel = utils::IteratorType::parallel;
+  auto reduction = utils::IteratorType::reduction;
+  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+
+  SmallVector<AffineMap> img2colIndexingMaps = {
+      AffineMap::getMultiDimIdentityMap(nloops, context)};
+
+  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+      loc, colTensor.getType(),
+      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+      img2colIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        // Get the iterators named based on the matmul (batch, m, k).
+        Value bIndex = nestedBuilder.create<linalg::IndexOp>(nestedLoc, 0);
+        Value mIndex = nestedBuilder.create<linalg::IndexOp>(nestedLoc, 1);
+        Value kIndex = nestedBuilder.create<linalg::IndexOp>(nestedLoc, 2);
+
+        // Recover the original iteration indices from the problem/input sizes.
+        SmallVector<Value> mIndices = unrollIndex(
+            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
+        auto ohIndex = mIndices[0];
+        auto owIndex = mIndices[1];
+
+        SmallVector<Value> kIndices = unrollIndex(
+            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
+        auto fhIndex = kIndices[0];
+        auto fwIndex = kIndices[1];
+        auto icIndex = kIndices[2];
+
+        // Extract the input element corresponding to the expanded indices.
+        Value hIndex = getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex,
+                                         fhIndex, strideH);
+        Value wIndex = getConvolvedIndex(nestedBuilder, nestedLoc, owIndex,
+                                         fwIndex, strideW);
+
+        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
+        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
+            nestedLoc, input, extractionIndices);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
+      });
+
+  // Because we didn't transpose the filters we don't actually have a batched
+  // matrix multiply. Instead, we have an operation consisting of "row-wise" dot
+  // products: both operands are indexed by k in their last dimension.
+  AffineExpr bDim, mDim, nDim, kDim;
+  bindDims(context, bDim, mDim, nDim, kDim);
+  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
+  auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
+  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
+  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
+                                                       parallel, reduction};
+
+  auto genericOp = rewriter.create<linalg::GenericOp>(
+      loc, reshapedOutputType,
+      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
+      /*outputs=*/ValueRange{reshapedOutput},
+      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        // Multiply-accumulate into the output element (args[2]).
+        Value mul = createMul(nestedLoc, args[0], args[1], args[2].getType(),
+                              nestedBuilder);
+        Value add = createAdd(nestedLoc, mul, args[2], nestedBuilder);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+      });
+  Value result = genericOp.getResults().front();
+
+  // Restore the original (n, oh, ow, oc) output shape.
+  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
+      loc, outputType, result, outputReassocIndices);
+
+  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
+
+  return std::make_pair(img2ColTensor.getOperation(),
+                        reshapedResult.getOperation());
+}
+
namespace {
class ConvertConv2DNhwcHwcf final
@@ -534,12 +669,25 @@ class ConvertConv2DNchwFchw final
return success();
}
};
+
+/// Rewrite pattern that applies the im2col rewrite to
+/// linalg::Conv2DNhwcFhwcOp.
+class ConvertConv2DNhwcFhwc final
+    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  /// Forwards to rewriteInIm2Col and reports failure exactly when it fails.
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    return failure(failed(rewriteInIm2Col(rewriter, convOp)));
+  }
+};
} // end anonymous namespace
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
- ConvertConv2DNchwFchw>(context);
+ ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
}
} // end namespace linalg
} // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index 657cf83f25460fd..b2470ed7b748042 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -279,6 +279,76 @@ transform.sequence failures(propagate) {
// -----
+// CHECK: IR printer: tensor_producer
+// CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
+
+// Collapsed indices.
+// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
+// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
+// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
+
+// Compute input channel/convolved indices.
+// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<(d0) -> (d0 mod 4)>(%[[KINDEX]])
+// CHECK: %[[CONVH:.+]] = affine.apply affine_map<(d0, d1) -> (d0 floordiv 14 + d1 floordiv 12)>(%[[MINDEX]], %[[KINDEX]])
+// CHECK: %[[CONVW:.+]] = affine.apply affine_map<(d0, d1) -> (d0 mod 14 + (d1 mod 12) floordiv 4)>(%[[MINDEX]], %[[KINDEX]])
+
+// Extract from the input tensor.
+// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
+// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
+// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
+
+// CHECK: IR printer: transformed
+// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: @conv_2d_nhwc_fhwc
+// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
+// CHECK-SAME: %[[FILTER:.+]]: tensor<16x3x3x4xf32>
+// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
+// CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x3x3x4xf32> into tensor<16x36xf32>
+// CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
+// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
+// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: #[[MAP0]]
+// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
+// CHECK: linalg.yield %{{.+}} : f32
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
+// CHECK-SAME: #[[MAP1]]
+// CHECK-SAME: #[[MAP2]]
+// CHECK-SAME: #[[MAP3]]
+// CHECK-SAME: ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<16x36xf32>)
+// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xf32>)
+// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
+// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
+// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+// CHECK: } -> tensor<1x196x16xf32>
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
+// CHECK: return %[[RESULT]]
+
+func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
+ transform.print %transformed {name = "transformed"}: !transform.any_op
+}
+
+// -----
+
// Check for signed extend when the input type is smaller than the accumulator type.
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
More information about the Mlir-commits
mailing list