[Mlir-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)
Hsiangkai Wang
llvmlistbot at llvm.org
Fri Jul 12 23:18:25 PDT 2024
================
@@ -36,6 +190,329 @@ constexpr TransformMapKeyTy F_2_3{2, 3};
constexpr TransformMapKeyTy F_4_3{4, 3};
constexpr TransformMapKeyTy F_2_5{2, 5};
+/// Non-owning description of a constant transform matrix: a flat table of
+/// `rows * cols` float entries laid out row by row, plus a scalar factor
+/// (defaulting to 1) that travels with the matrix.
+struct TransformMatrix {
+  TransformMatrix(const float *data, int64_t numRows, int64_t numCols,
+                  int64_t factor = 1)
+      : table(data), rows(numRows), cols(numCols), scalarFactor(factor) {}
+
+  /// Matrix entries; the table is referenced, not owned.
+  const float *table;
+  /// Number of rows in the matrix.
+  int64_t rows;
+  /// Number of columns in the matrix.
+  int64_t cols;
+  /// Scalar factor stored alongside the matrix (this struct does not apply it).
+  int64_t scalarFactor;
+};
+
+/// Utility function to convert a constant transform table into an
+/// arith.constant Value of type tensor<rows x cols x `type`>.
+///
+/// The tables are stored as `float`. Handing the raw `float` array to
+/// DenseFPElementsAttr::get only type-checks when `type` is a 32-bit float,
+/// so every entry is first converted (round to nearest, ties to even) to the
+/// semantics of `type`. This makes the helper usable for other float element
+/// types (e.g. f16, f64) while producing the same constant for f32.
+/// `type` must be a FloatType.
+Value create2DTransformMatrix(OpBuilder &builder, Location loc,
+                              TransformMatrix transform, Type type) {
+  ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
+  const llvm::fltSemantics &semantics =
+      cast<FloatType>(type).getFloatSemantics();
+
+  SmallVector<APFloat> values;
+  values.reserve(constVec.size());
+  for (float element : constVec) {
+    APFloat value(element);
+    // Narrowing (e.g. to f16) is accepted; precision loss is not an error here.
+    bool losesInfo = false;
+    (void)value.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo);
+    values.push_back(value);
+  }
+
+  return builder.create<arith::ConstantOp>(
+      loc, DenseFPElementsAttr::get(
+               RankedTensorType::get(
+                   SmallVector<int64_t>{transform.rows, transform.cols}, type),
+               values));
+}
+
+/// Extract height x width data from 4D or 6D tensors.
+///
+/// Emits a `tensor.extract_slice` of `source` that has extent 1 in every
+/// dimension except `heightIdx` and `widthIdx`, which are taken in full.
+/// Dimension `outLoopIdx` starts at `outLoopIndex` and `inLoopIdx` starts at
+/// `inLoopIndex`; every other offset is 0 and all strides are 1. The unit
+/// dimensions are then dropped, so the returned value is a 2-D tensor of
+/// shape (height, width).
+Value extract2DData(OpBuilder &builder, Location loc, Value source,
+                    Value outLoopIndex, Value inLoopIndex, int64_t outLoopIdx,
+                    int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx,
+                    int64_t srcSize) {
+  auto shapedTy = cast<ShapedType>(source.getType());
+  Type elemTy = shapedTy.getElementType();
+  ArrayRef<int64_t> dims = shapedTy.getShape();
+  const int64_t h = dims[heightIdx];
+  const int64_t w = dims[widthIdx];
+
+  // Start from a unit window at the origin, then widen the H/W dimensions and
+  // move the window to the current loop position.
+  OpFoldResult zero = builder.getIndexAttr(0);
+  OpFoldResult one = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult, 6> sliceOffsets(srcSize, zero);
+  SmallVector<OpFoldResult, 6> sliceSizes(srcSize, one);
+  SmallVector<OpFoldResult, 6> sliceStrides(srcSize, one);
+  SmallVector<int64_t> sliceShape(srcSize, 1);
+
+  sliceOffsets[outLoopIdx] = outLoopIndex;
+  sliceOffsets[inLoopIdx] = inLoopIndex;
+  sliceSizes[heightIdx] = builder.getIndexAttr(h);
+  sliceSizes[widthIdx] = builder.getIndexAttr(w);
+  sliceShape[heightIdx] = h;
+  sliceShape[widthIdx] = w;
+
+  auto sliceTy = RankedTensorType::get(sliceShape, elemTy);
+  auto slice = builder.create<tensor::ExtractSliceOp>(
+      loc, sliceTy, source, sliceOffsets, sliceSizes, sliceStrides);
+
+  // Collapse the unit dimensions so callers receive a plain 2-D tensor.
+  auto planeTy = RankedTensorType::get({h, w}, elemTy);
+  return tensor::createCanonicalRankReducingExtractSliceOp(builder, loc, slice,
+                                                            planeTy);
+}
+
+/// Write a transformed 2-D `height x width` tile back into a 4D or 6D tensor.
+///
+/// `source` is first rank-expanded into a `destSize`-rank tensor whose extents
+/// are 1 everywhere except `heightIdx`/`widthIdx`. That tensor is then
+/// inserted into `dest` at offset `outLoopIndex` along `outLoopIdx` and
+/// `inLoopIndex` along `inLoopIdx` (offset 0 elsewhere, unit strides).
+/// Returns the updated `dest`.
+Value insert2DData(OpBuilder &builder, Location loc, Value source, Value dest,
+                   Value outLoopIndex, Value inLoopIndex, int64_t height,
+                   int64_t width, int64_t outLoopIdx, int64_t inLoopIdx,
+                   int64_t heightIdx, int64_t widthIdx, int64_t destSize) {
+  Type elemTy = cast<ShapedType>(source.getType()).getElementType();
+
+  // Lift the 2-D tile to the rank of `dest`.
+  SmallVector<int64_t> expandedShape(destSize, 1);
+  expandedShape[heightIdx] = height;
+  expandedShape[widthIdx] = width;
+  Value expandedInit =
+      builder.create<tensor::EmptyOp>(loc, expandedShape, elemTy);
+  Value expanded = tensor::createCanonicalRankReducingInsertSliceOp(
+      builder, loc, source, expandedInit);
+
+  // Unit window at the current loop position, spanning the full H/W extent.
+  OpFoldResult zero = builder.getIndexAttr(0);
+  OpFoldResult one = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult, 6> offsets(destSize, zero);
+  SmallVector<OpFoldResult, 6> sizes(destSize, one);
+  SmallVector<OpFoldResult, 6> strides(destSize, one);
+  offsets[outLoopIdx] = outLoopIndex;
+  offsets[inLoopIdx] = inLoopIndex;
+  sizes[heightIdx] = builder.getIndexAttr(height);
+  sizes[widthIdx] = builder.getIndexAttr(width);
+
+  return builder.create<tensor::InsertSliceOp>(loc, expanded, dest, offsets,
+                                               sizes, strides);
+}
+
+/// This function transforms the filter. The data layout of the filter is FHWC.
+/// The transformation matrix is 2-dimensional. We need to extract H x W from
+/// FHWC first. We need to generate 2 levels of loops to iterate on F and C.
+/// After the transformation, we get
+///
+/// scf.for %f = lo_f to hi_f step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract filter<h x w> from filter<f x h x w x c>
+///     %ret = linalg.matmul G, %extracted
+///     %ret = linalg.matmul %ret, GT
+///     %inserted = insert %ret into filter<h x w x c x f>
+///
+/// `leftTransform` applies G along H and requires filterH == r; when it is
+/// false, filterH must be 1 and H is passed through. `rightTransform` applies
+/// GT along W with the same requirement on filterW.
+///
+/// Returns a null Value without creating any IR when the filter shape does
+/// not satisfy these requirements, or when no G/GT matrix is registered for
+/// (m, r). Otherwise returns the result of the loop nest, which is `retValue`
+/// with every transformed (H, W) tile inserted.
+Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
+                      Value retValue, int64_t m, int64_t r,
+                      bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to G transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GMatrices = {
+          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
+          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
+          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
+      };
+
+  // Map from (m, r) to GT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GTMatrices = {
+          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
+          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
+          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
+      };
+
+  auto filterType = cast<ShapedType>(filter.getType());
+  Type elementType = filterType.getElementType();
+  auto filterShape = filterType.getShape(); // F, H, W, C
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+
+  // A transformed dimension is multiplied by an (m + r - 1) x r matrix, so it
+  // must have extent r. An untransformed dimension is inserted as-is and must
+  // have extent 1. Accepting either extent regardless of the flag would build
+  // mismatched matmuls or insert_slices.
+  if (filterH != (leftTransform ? r : 1))
+    return Value();
+  if (filterW != (rightTransform ? r : 1))
+    return Value();
+
+  // Resolve the constant matrices before building any loops. Bailing out
+  // inside the loop body would leave the nest without a yielded value for
+  // its iter_arg.
+  TransformMapKeyTy key = {m, r};
+  const TransformMatrix *GMatrix = nullptr;
+  if (leftTransform) {
+    auto it = GMatrices.find(key);
+    if (it == GMatrices.end())
+      return Value();
+    GMatrix = &it->second;
+  }
+  const TransformMatrix *GTMatrix = nullptr;
+  if (rightTransform) {
+    auto it = GTMatrices.find(key);
+    if (it == GTMatrices.end())
+      return Value();
+    GTMatrix = &it->second;
+  }
+
+  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
+                       ValueRange args) -> scf::ValueVector {
+    Value FIter = ivs[0];
+    Value CIter = ivs[1];
+
+    // Extract (H, W) from (F, H, W, C).
+    auto extractFilter = extract2DData(
+        builder, loc, filter, FIter, CIter, /*outLoopIdx=*/0,
+        /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
+
+    int64_t retRows = 1;
+    Value matmulRetValue = extractFilter;
+    if (leftTransform) {
+      retRows = GMatrix->rows;
+      auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+
+      Value G = create2DTransformMatrix(builder, loc, *GMatrix, elementType);
+      // Multiply G x g.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    if (rightTransform) {
+      auto matmulType =
+          RankedTensorType::get({retRows, GTMatrix->cols}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+
+      Value GT = create2DTransformMatrix(builder, loc, *GTMatrix, elementType);
+      // Multiply u = (G x g) x GT.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    // Insert (H, W) to (H, W, C, F).
+    int64_t retHeight = leftTransform ? m + r - 1 : 1;
+    int64_t retWidth = rightTransform ? m + r - 1 : 1;
+    auto insertSliceOp = insert2DData(builder, loc, matmulRetValue, args[0],
+                                      FIter, CIter, retHeight, retWidth,
+                                      /*outLoopIdx=*/3, /*inLoopIdx=*/2,
+                                      /*heightIdx=*/0, /*widthIdx=*/1,
+                                      /*destSize=*/4);
+
+    return {insertSliceOp};
+  };
+
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
+  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  scf::LoopNest loops = scf::buildLoopNest(
+      rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
+      {oneStep, oneStep}, {retValue}, buildBody);
+  return loops.results[0];
+}
+
+/// This function transforms the input. The data layout of the input is NHWC.
+/// The transformation matrix is 2-dimension. We need to extract H x W from
+/// NHWC first. We need to generate 2 levels of loops to iterate on N and C.
+/// After the transformation, we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+/// scf.for %c = lo_c to hi_c step 1
+/// %extracted = extract input<h x w> from input<n x h x w x c>
+/// %ret = linalg.matmul BT, %extracted
+/// %ret = linalg.matmul %ret, B
+/// %inserted = insert %ret into input<h x w x n x c>
+Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
+ Value retValue, int64_t m, int64_t r,
+ bool leftTransform = true, bool rightTransform = true) {
+ // Map from (m, r) to BT transform matrix.
+ static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ BTMatrices = {
+ {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
+ {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
+ {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
+ };
+
+ // Map from (m, r) to B transform matrix.
+ static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ BMatrices = {
+ {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
+ {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
+ {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
+ };
+
+ auto inputType = cast<ShapedType>(input.getType());
+ Type elementType = inputType.getElementType();
+ auto inputShape = inputType.getShape(); // N, H, W, C
+ int64_t inputN = inputShape[0];
+ int64_t inputH = inputShape[1];
+ int64_t inputW = inputShape[2];
+ int64_t inputC = inputShape[3];
+ int64_t alphaH = leftTransform ? m + r - 1 : 1;
+ int64_t alphaW = rightTransform ? m + r - 1 : 1;
+
+ if (inputH != alphaH && inputH != 1)
+ return Value();
+ if (inputW != alphaW && inputW != 1)
+ return Value();
+
+ auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
+ ValueRange args) -> scf::ValueVector {
+ Value NIter = ivs[0];
+ Value CIter = ivs[1];
+
+ // Extract (H, W) from (N, H, W, C).
+ auto extractInput = extract2DData(
+ builder, loc, input, NIter, CIter, /*outLoopIdx=*/0,
+ /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
+
+ TransformMapKeyTy key = {m, r};
+ int64_t retRows = 1;
+ int64_t retCols = 1;
+ Value matmulRetValue = extractInput;
+ if (leftTransform) {
+ // Get constant transform matrix BT.
+ auto it = BTMatrices.find(key);
+ if (it == BTMatrices.end())
+ return {};
+ const TransformMatrix &BTMatrix = it->second;
+
+ retRows = BTMatrix.rows;
+ auto matmulType = RankedTensorType::get({retRows, inputW}, elementType);
+ auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+ elementType);
+
+ Value BT =
+ create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
+ // Multiply BT x d.
+ auto matmulOp = builder.create<linalg::MatmulOp>(
+ loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
+ matmulRetValue = matmulOp.getResult(0);
+ }
+
+ if (rightTransform) {
+ // Get constant transform matrix B.
+ auto it = BMatrices.find(key);
+ if (it == BMatrices.end())
+ return {};
+ const TransformMatrix &BMatrix = it->second;
+
+ retCols = BMatrix.cols;
+ auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+ auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+ elementType);
+ Value B =
+ create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
+ // Multiply v = (BT x d) x B.
+ auto matmulOp = builder.create<linalg::MatmulOp>(
+ loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
+ matmulRetValue = matmulOp.getResult(0);
+ }
+
+ // Insert (H, W) to (H, W, 1, 1, N, C).
----------------
Hsiangkai wrote:
The implementation is based on the pseudo code in the original paper, which describes how to support arbitrary heights and widths through tiling, with nested loops over N and C. That's why my implementation only considers tiling height and width outside the decomposition. Your proposal 2, supporting arbitrary values of tileH and tileW inside the decomposition, is interesting. I can give it a try to make the decomposition more flexible.
Thanks for your review and suggestion.
https://github.com/llvm/llvm-project/pull/96183
More information about the Mlir-commits
mailing list