[Mlir-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 11 07:34:01 PDT 2024


================
@@ -36,6 +190,329 @@ constexpr TransformMapKeyTy F_2_3{2, 3};
 constexpr TransformMapKeyTy F_4_3{4, 3};
 constexpr TransformMapKeyTy F_2_5{2, 5};
 
+/// Describes one constant transform matrix: a pointer to its float
+/// coefficients plus its dimensions and an associated scalar factor.
+struct TransformMatrix {
+  // Coefficient table; consumers read `rows * cols` contiguous floats from it.
+  const float *table;
+  // Number of rows of the matrix.
+  int64_t rows;
+  // Number of columns of the matrix.
+  int64_t cols;
+  // Scalar factor stored alongside the matrix (1 unless specified).
+  int64_t scalarFactor;
+
+  TransformMatrix(const float *table, int64_t rows, int64_t cols,
+                  int64_t scalarFactor = 1)
+      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
+};
+
+/// Materialize a constant transform matrix as an arith.constant whose result
+/// is a `rows x cols` ranked tensor with element type `type`. The values are
+/// the first `rows * cols` floats of `transform.table`.
+Value create2DTransformMatrix(OpBuilder &builder, Location loc,
+                              TransformMatrix transform, Type type) {
+  // View the raw coefficient table as a flat array of floats.
+  const int64_t numElements = transform.rows * transform.cols;
+  ArrayRef<float> coefficients(transform.table, numElements);
+
+  // Static 2-D tensor type that carries the matrix dimensions.
+  auto matrixType = RankedTensorType::get(
+      SmallVector<int64_t>{transform.rows, transform.cols}, type);
+
+  auto matrixAttr = DenseFPElementsAttr::get(matrixType, coefficients);
+  return builder.create<arith::ConstantOp>(loc, matrixAttr);
+}
+
+/// Slice a 2-D (height x width) tile out of a rank-`srcSize` tensor.
+///
+/// `outLoopIndex` and `inLoopIndex` become the slice offsets along dimensions
+/// `outLoopIdx` and `inLoopIdx`; every other offset is 0. Dimensions
+/// `heightIdx` and `widthIdx` are taken at the extent recorded in the source
+/// shape and every other dimension has size 1. The unit dimensions are then
+/// dropped, so the returned value has type tensor<height x width>.
+Value extract2DData(OpBuilder &builder, Location loc, Value source,
+                    Value outLoopIndex, Value inLoopIndex, int64_t outLoopIdx,
+                    int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx,
+                    int64_t srcSize) {
+  auto srcType = cast<ShapedType>(source.getType());
+  auto srcShape = srcType.getShape();
+  int64_t tileHeight = srcShape[heightIdx];
+  int64_t tileWidth = srcShape[widthIdx];
+  Type elemType = srcType.getElementType();
+
+  auto zeroAttr = builder.getIndexAttr(0);
+  auto oneAttr = builder.getIndexAttr(1);
+
+  // Offsets: the two loop indices select the tile, everything else is 0.
+  SmallVector<OpFoldResult, 6> sliceOffsets(srcSize, zeroAttr);
+  sliceOffsets[outLoopIdx] = outLoopIndex;
+  sliceOffsets[inLoopIdx] = inLoopIndex;
+
+  // Sizes and the matching static result shape: full height x width, 1
+  // elsewhere.
+  SmallVector<OpFoldResult, 6> sliceSizes(srcSize, oneAttr);
+  SmallVector<int64_t> sliceShape(srcSize, 1);
+  sliceSizes[heightIdx] = builder.getIndexAttr(tileHeight);
+  sliceSizes[widthIdx] = builder.getIndexAttr(tileWidth);
+  sliceShape[heightIdx] = tileHeight;
+  sliceShape[widthIdx] = tileWidth;
+
+  // Unit strides along every dimension.
+  SmallVector<OpFoldResult, 6> sliceStrides(srcSize, oneAttr);
+
+  // First take the slice at the source rank ...
+  auto fullRankType = RankedTensorType::get(sliceShape, elemType);
+  auto fullRankSlice = builder.create<tensor::ExtractSliceOp>(
+      loc, fullRankType, source, sliceOffsets, sliceSizes, sliceStrides);
+
+  // ... then drop the unit dimensions to obtain the 2-D tile.
+  auto tileType = RankedTensorType::get({tileHeight, tileWidth}, elemType);
+  return tensor::createCanonicalRankReducingExtractSliceOp(
+      builder, loc, fullRankSlice, tileType);
+}
+
+/// Write a 2-D (height x width) tile `source` into the rank-`destSize` tensor
+/// `dest` and return the updated tensor.
+///
+/// The tile is first expanded to rank `destSize` (size 1 in every dimension
+/// other than `heightIdx` / `widthIdx`) and then inserted into `dest` at the
+/// position given by `outLoopIndex` / `inLoopIndex` along dimensions
+/// `outLoopIdx` / `inLoopIdx`; all remaining offsets are 0 and all strides
+/// are 1.
+Value insert2DData(OpBuilder &builder, Location loc, Value source, Value dest,
+                   Value outLoopIndex, Value inLoopIndex, int64_t height,
+                   int64_t width, int64_t outLoopIdx, int64_t inLoopIdx,
+                   int64_t heightIdx, int64_t widthIdx, int64_t destSize) {
+  Type elemType = cast<ShapedType>(source.getType()).getElementType();
+
+  // Static shape of the tile once expanded to the destination rank.
+  SmallVector<int64_t> expandedShape(destSize, 1);
+  expandedShape[heightIdx] = height;
+  expandedShape[widthIdx] = width;
+
+  // Re-introduce the unit dimensions by inserting the tile into an empty
+  // tensor of the expanded shape.
+  auto expandedInit =
+      builder.create<tensor::EmptyOp>(loc, expandedShape, elemType);
+  auto expandedTile = tensor::createCanonicalRankReducingInsertSliceOp(
+      builder, loc, source, expandedInit);
+
+  auto zeroAttr = builder.getIndexAttr(0);
+  auto oneAttr = builder.getIndexAttr(1);
+
+  // Destination offsets: the loop indices pick the slot, the rest stay at 0.
+  SmallVector<OpFoldResult, 6> destOffsets(destSize, zeroAttr);
+  destOffsets[outLoopIdx] = outLoopIndex;
+  destOffsets[inLoopIdx] = inLoopIndex;
+
+  // Destination sizes mirror the expanded tile shape.
+  SmallVector<OpFoldResult, 6> destSizes(destSize, oneAttr);
+  destSizes[heightIdx] = builder.getIndexAttr(height);
+  destSizes[widthIdx] = builder.getIndexAttr(width);
+
+  SmallVector<OpFoldResult, 6> destStrides(destSize, oneAttr);
+
+  return builder.create<tensor::InsertSliceOp>(
+      loc, expandedTile, dest, destOffsets, destSizes, destStrides);
+}
+
+/// This function transforms the filter. The data layout of the filter is FHWC.
+/// The transformation matrix is 2-dimensional. We need to extract H x W from
+/// FHWC first. We need to generate 2 levels of loops to iterate on F and C.
+/// After the transformation, we get
+///
+/// scf.for %f = lo_f to hi_f step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract filter<h x w> from filter<f x h x w x c>
+///     %ret = linalg.matmul G, %extracted
+///     %ret = linalg.matmul %ret, GT
+///     %inserted = insert %ret into filter<h x w x c x f>
+///
+/// `retValue` is the initial value of the loop-carried result tensor that
+/// every iteration inserts into. `leftTransform` / `rightTransform` select
+/// whether the multiplication by G / GT is emitted. A null Value is returned,
+/// before any IR is created, when the filter height or width is neither `r`
+/// nor 1, or when no transform matrix is registered for (m, r).
+Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
+                      Value retValue, int64_t m, int64_t r,
+                      bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to G transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GMatrices = {
+          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
+          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
+          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
+      };
+
+  // Map from (m, r) to GT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GTMatrices = {
+          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
+          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
+          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
+      };
+
+  auto filterType = cast<ShapedType>(filter.getType());
+  Type elementType = filterType.getElementType();
+  auto filterShape = filterType.getShape(); // F, H, W, C
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+
+  if (filterH != r && filterH != 1)
+    return Value();
+  if (filterW != r && filterW != 1)
+    return Value();
+
+  // Resolve the transform matrices up front. Doing the lookup inside the loop
+  // body builder would force it to return an empty result list on failure,
+  // which does not match the single loop-carried value expected by
+  // scf::buildLoopNest. Failing here keeps the error path identical to the
+  // shape checks above and leaves the IR untouched.
+  TransformMapKeyTy key = {m, r};
+  const TransformMatrix *GMatrix = nullptr;
+  if (leftTransform) {
+    auto it = GMatrices.find(key);
+    if (it == GMatrices.end())
+      return Value();
+    GMatrix = &it->second;
+  }
+  const TransformMatrix *GTMatrix = nullptr;
+  if (rightTransform) {
+    auto it = GTMatrices.find(key);
+    if (it == GTMatrices.end())
+      return Value();
+    GTMatrix = &it->second;
+  }
+
+  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
+                       ValueRange args) -> scf::ValueVector {
+    Value FIter = ivs[0];
+    Value CIter = ivs[1];
+
+    // linalg.matmul accumulates into its output operand, so the output must
+    // be zero-filled instead of being a bare (uninitialized) tensor.empty.
+    auto createZeroInit = [&](RankedTensorType type) -> Value {
+      Value zero = builder.create<arith::ConstantOp>(
+          loc, builder.getZeroAttr(elementType));
+      Value empty =
+          builder.create<tensor::EmptyOp>(loc, type.getShape(), elementType)
+              .getResult();
+      return builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
+    };
+
+    // Extract (H, W) from (F, H, W, C).
+    auto extractFilter = extract2DData(
+        builder, loc, filter, FIter, CIter, /*outLoopIdx=*/0,
+        /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
+
+    int64_t retRows = 1;
+    Value matmulRetValue = extractFilter;
+    if (leftTransform) {
+      retRows = GMatrix->rows;
+      auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
+      Value init = createZeroInit(matmulType);
+
+      // The coefficient tables are `float`, so the constant must be built
+      // with an f32 element type regardless of the filter element type.
+      Value G = create2DTransformMatrix(builder, loc, *GMatrix,
+                                        builder.getF32Type());
+      // Multiply G x g.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    if (rightTransform) {
+      auto matmulType =
+          RankedTensorType::get({retRows, GTMatrix->cols}, elementType);
+      Value init = createZeroInit(matmulType);
+
+      // See above: the GT table is `float`, hence the f32 constant.
+      Value GT = create2DTransformMatrix(builder, loc, *GTMatrix,
+                                         builder.getF32Type());
+      // Multiply u = (G x g) x GT.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    // Insert (H, W) to (H, W, C, F).
+    int64_t retHeight = leftTransform ? m + r - 1 : 1;
+    int64_t retWidth = rightTransform ? m + r - 1 : 1;
+    auto insertSliceOp = insert2DData(builder, loc, matmulRetValue, args[0],
+                                      FIter, CIter, retHeight, retWidth,
+                                      /*outLoopIdx=*/3, /*inLoopIdx=*/2,
+                                      /*heightIdx=*/0, /*widthIdx=*/1,
+                                      /*destSize=*/4);
+
+    return {insertSliceOp};
+  };
+
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
+  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  scf::LoopNest loops = scf::buildLoopNest(
+      rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
+      {oneStep, oneStep}, {retValue}, buildBody);
+  return loops.results[0];
+}
+
+/// This function transforms the input. The data layout of the input is NHWC.
+/// The transformation matrix is 2-dimensional. We need to extract H x W from
+/// NHWC first. We need to generate 2 levels of loops to iterate on N and C.
+/// After the transformation, we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract input<h x w> from input<n x h x w x c>
+///     %ret = linalg.matmul BT, %extracted
+///     %ret = linalg.matmul %ret, B
+///     %inserted = insert %ret into input<h x w x n x c>
+Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
+                     Value retValue, int64_t m, int64_t r,
+                     bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to BT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      BTMatrices = {
+          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
+          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
+          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
+      };
+
+  // Map from (m, r) to B transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      BMatrices = {
+          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
+          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
+          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
+      };
+
+  auto inputType = cast<ShapedType>(input.getType());
+  Type elementType = inputType.getElementType();
+  auto inputShape = inputType.getShape(); // N, H, W, C
+  int64_t inputN = inputShape[0];
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t inputC = inputShape[3];
+  int64_t alphaH = leftTransform ? m + r - 1 : 1;
+  int64_t alphaW = rightTransform ? m + r - 1 : 1;
+
+  if (inputH != alphaH && inputH != 1)
+    return Value();
+  if (inputW != alphaW && inputW != 1)
+    return Value();
+
+  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
+                       ValueRange args) -> scf::ValueVector {
+    Value NIter = ivs[0];
+    Value CIter = ivs[1];
+
+    // Extract (H, W) from (N, H, W, C).
+    auto extractInput = extract2DData(
+        builder, loc, input, NIter, CIter, /*outLoopIdx=*/0,
+        /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
+
+    TransformMapKeyTy key = {m, r};
+    int64_t retRows = 1;
+    int64_t retCols = 1;
+    Value matmulRetValue = extractInput;
+    if (leftTransform) {
+      // Get constant transform matrix BT.
+      auto it = BTMatrices.find(key);
+      if (it == BTMatrices.end())
+        return {};
+      const TransformMatrix &BTMatrix = it->second;
+
+      retRows = BTMatrix.rows;
+      auto matmulType = RankedTensorType::get({retRows, inputW}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+
+      Value BT =
+          create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
+      // Multiply BT x d.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    if (rightTransform) {
+      // Get constant transform matrix B.
+      auto it = BMatrices.find(key);
+      if (it == BMatrices.end())
+        return {};
+      const TransformMatrix &BMatrix = it->second;
+
+      retCols = BMatrix.cols;
+      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+      Value B =
+          create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
+      // Multiply v = (BT x d) x B.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    // Insert (H, W) to (H, W, 1, 1, N, C).
----------------
Max191 wrote:

The shape of the destination here should not always be `1x1` in the inner H and W dims. There are 2 ways that I would suggest for implementing the decomposition here:

1. Before decomposing, tile `N`, `C`, `HDiv`, `WDiv` down to `1`, and require these sizes to be `1` when you do the decomposition. The decomposition is very simple then, because it does not need any loop nest. This is what we have done in IREE, and the requirement to tile these dims to 1 before decomposing has not posed any problems so far.
2. The other way to do it would be to allow for partial tiling of `N`, `C`, `HDiv`, `WDiv`, and do the remaining tiling down to 1 as part of the decomposition. This approach makes the tile size selection before decomposition a little more flexible, but you need to be careful that the tiling implementation tiles `HDiv` and `WDiv` of the result instead of `H` and `W` of the input (this is natural since tiling implementations usually tile based on the result, but in the output transform things are reversed, and you would want to tile based on the input).

What is being done by this PR seems to be sort of a mix of the 2. The `HDiv` and `WDiv` dimensions are required to be 1, but the `N` and `C` dimensions are not. If you are going to restrict dimensions, then restrict them all to be `1` before decomposition, and just make the decomposition simpler, since the tile size selection will need to be aware of this requirement regardless of how many dimensions are required to be `1`. Option (2) is the more general way to do it, but it makes the decomposition a bit more complex (although you already have implemented some of the complexity here, so maybe worth trying). I will leave the decision between (1) and (2) up to you, but I would recommend picking one or the other.

https://github.com/llvm/llvm-project/pull/96183


More information about the Mlir-commits mailing list