[Mlir-commits] [mlir] [mlir][tosa] Convert tosa.transpose_conv2d to linalg.generic directly (PR #79824)
llvmlistbot at llvm.org
Mon Jan 29 05:25:37 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-linalg
Author: Hsiangkai Wang (Hsiangkai)
Changes:
Currently, we emulate tosa.transpose_conv2d with a sequence of reverse, pad, reshape, and conv2d operations. This patch adds a pattern that converts tosa.transpose_conv2d to linalg.generic directly. A sketch of the emitted form follows.
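For illustration, here is a minimal sketch of the IR the new pattern produces, adapted from the test case added below (stride = 1 and out_pad = 0, so the output map reduces to (n, ih + kh, iw + kw, oc); the function and value names are illustrative, and the broadcast of the bias into %bias_bcast is elided):

``````````mlir
// Input iterates over (n, ih, iw, ic); weight over (oc, kh, kw, ic).
#lhs = affine_map<(n, ih, iw, oc, ic, kh, kw) -> (n, ih, iw, ic)>
#rhs = affine_map<(n, ih, iw, oc, ic, kh, kw) -> (oc, kh, kw, ic)>
// Each input pixel is scattered to output position
// (ih * stride + out_pad + kh, iw * stride + out_pad + kw); here stride = 1
// and out_pad = 0.
#out = affine_map<(n, ih, iw, oc, ic, kh, kw) -> (n, ih + kh, iw + kw, oc)>

func.func @sketch(%input: tensor<1x1x2x1xf32>, %weight: tensor<1x1x2x1xf32>,
                  %bias_bcast: tensor<1x1x3x1xf32>) -> tensor<1x1x3x1xf32> {
  // %bias_bcast is the bias already broadcast to the result shape; it seeds
  // the accumulation.
  %0 = linalg.generic
         {indexing_maps = [#lhs, #rhs, #out],
          iterator_types = ["parallel", "parallel", "parallel", "parallel",
                            "parallel", "parallel", "parallel"]}
         ins(%input, %weight : tensor<1x1x2x1xf32>, tensor<1x1x2x1xf32>)
         outs(%bias_bcast : tensor<1x1x3x1xf32>) {
  ^bb0(%in: f32, %w: f32, %acc: f32):
    %m = arith.mulf %in, %w : f32
    %a = arith.addf %m, %acc : f32
    linalg.yield %a : f32
  } -> tensor<1x1x3x1xf32>
  return %0 : tensor<1x1x3x1xf32>
}
``````````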
---
Full diff: https://github.com/llvm/llvm-project/pull/79824.diff
3 Files Affected:
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+77-1)
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp (+1)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+32)
``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 8dc2d27bd545ff8..b9fd03ff67c4e23 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -1010,6 +1010,81 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
     return success();
   }
 };
+
+class TransposeConv2DConverter
+    : public OpConversionPattern<tosa::TransposeConv2DOp> {
+public:
+  using OpConversionPattern<tosa::TransposeConv2DOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tosa::TransposeConv2DOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    Location loc = op->getLoc();
+    Value input = op->getOperand(0);
+    Value weight = op->getOperand(1);
+    Value bias = op->getOperand(2);
+
+    ShapedType inputTy = cast<ShapedType>(input.getType());
+    ShapedType weightTy = cast<ShapedType>(weight.getType());
+    ShapedType biasTy = cast<ShapedType>(bias.getType());
+    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
+
+    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op,
+          "tosa.transpose_conv2d requires static shapes for weight and bias");
+
+    Type inputETy = inputTy.getElementType();
+    Type resultETy = resultTy.getElementType();
+
+    if (inputETy.isUnsignedInteger())
+      return rewriter.notifyMatchFailure(
+          op, "tosa.transpose_conv2d does not support unsigned integer input");
+
+    // Broadcast the bias as the starting values for accumulation.
+    auto emptyTensor =
+        rewriter.create<tensor::EmptyOp>(loc, resultTy.getShape(), resultETy);
+
+    Value broadcastBias =
+        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, emptyTensor);
+
+    auto *context = op->getContext();
+    AffineExpr n, ih, iw, oc, ic, kh, kw;
+    bindDims(context, n, ih, iw, oc, ic, kh, kw);
+
+    constexpr unsigned numDims = 7;
+    auto lhsMap = AffineMap::get(numDims, 0, {n, ih, iw, ic}, context);
+    auto rhsMap = AffineMap::get(numDims, 0, {oc, kh, kw, ic}, context);
+    /* outPad: top, bottom, left, right */
+    ArrayRef<int64_t> outPad = op.getOutPadAttr();
+    ArrayRef<int64_t> stride = op.getStrideAttr();
+    auto resultMap = AffineMap::get(numDims, 0,
+                                    {n, ih * stride[0] + outPad[0] + kh,
+                                     iw * stride[1] + outPad[2] + kw, oc},
+                                    context);
+
+    auto transposeConv2D =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, resultTy, ValueRange({input, weight}), broadcastBias,
+                ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap},
+                tosa::getNParallelLoopsAttrs(numDims),
+                [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                    ValueRange args) {
+                  auto mul =
+                      nestedBuilder.create<arith::MulFOp>(loc, args[0], args[1])
+                          .getResult();
+                  auto acc =
+                      nestedBuilder.create<arith::AddFOp>(loc, mul, args[2])
+                          .getResult();
+                  nestedBuilder.create<linalg::YieldOp>(loc, acc);
+                })
+            .getResult(0);
+
+    rewriter.replaceOp(op, transposeConv2D);
+
+    return success();
+  }
+};
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
@@ -1031,7 +1106,8 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
       MaxPool2dConverter,
       AvgPool2dConverter,
       FullyConnectedConverter,
-      TransposeConverter
+      TransposeConverter,
+      TransposeConv2DConverter
   >(patterns->getContext());
   // clang-format on
 }
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index 096969391e51b9d..422d1e6189a21e6 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -61,6 +61,7 @@ struct TosaToLinalgNamed
     target.addIllegalOp<tosa::MatMulOp>();
     target.addIllegalOp<tosa::FullyConnectedOp>();
     target.addIllegalOp<tosa::TransposeOp>();
+    target.addIllegalOp<tosa::TransposeConv2DOp>();
 
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 6616ea7cf699fa5..b39c488e3bbbeb0 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -781,3 +781,35 @@ func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
%1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
return
}
+
+// -----
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4)>
+// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d5, d6, d4)>
+// CHECK: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
+
+// CHECK-LABEL: @transpose_conv2d
+func.func @transpose_conv2d(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xf32>, %arg2: tensor<1xf32>) -> tensor<1x1x3x1xf32> {
+  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x1x3x1xf32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%arg2 : tensor<1xf32>)
+  // CHECK-SAME: outs(%[[INIT]] : tensor<1x1x3x1xf32>) {
+  // CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+  // CHECK: linalg.yield %[[IN]] : f32
+  // CHECK: } -> tensor<1x1x3x1xf32>
+  // CHECK: %[[RESULT:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%arg0, %arg1 : tensor<1x1x2x1xf32>, tensor<1x1x2x1xf32>)
+  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x1x3x1xf32>) {
+  // CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_0:.+]]: f32, %[[OUT:.+]]: f32):
+  // CHECK: %[[S3:.+]] = arith.mulf %[[IN]], %[[IN_0]] : f32
+  // CHECK: %[[S4:.+]] = arith.addf %[[S3]], %[[OUT]] : f32
+  // CHECK: linalg.yield %[[S4]] : f32
+  // CHECK: } -> tensor<1x1x3x1xf32>
+  // CHECK: return %[[RESULT]] : tensor<1x1x3x1xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 1, 3, 1>, stride = array<i64: 1, 1>} : (tensor<1x1x2x1xf32>, tensor<1x1x2x1xf32>, tensor<1xf32>) -> tensor<1x1x3x1xf32>
+  return %0 : tensor<1x1x3x1xf32>
+}
``````````
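In effect (a sketch using the dimension and attribute names from the patch; out_pad indexing follows the top/bottom/left/right convention noted in the code), the new generic computes

``````````
output[n, ih * stride[0] + outPad[0] + kh, iw * stride[1] + outPad[2] + kw, oc] +=
    input[n, ih, iw, ic] * weight[oc, kh, kw, ic]
``````````

scattering each input pixel across the output at every kernel offset, with the broadcast bias seeding the accumulator.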
https://github.com/llvm/llvm-project/pull/79824
More information about the Mlir-commits mailing list