[Mlir-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 15 10:23:51 PDT 2024


================
@@ -36,6 +190,329 @@ constexpr TransformMapKeyTy F_2_3{2, 3};
 constexpr TransformMapKeyTy F_4_3{4, 3};
 constexpr TransformMapKeyTy F_2_5{2, 5};
 
+/// Describes one constant Winograd transform matrix whose entries live in a
+/// flat `float` table of `rows * cols` elements (laid out row by row, as
+/// consumed by create2DTransformMatrix).
+struct TransformMatrix {
+  TransformMatrix(const float *data, int64_t numRows, int64_t numCols,
+                  int64_t scale = 1)
+      : table(data), rows(numRows), cols(numCols), scalarFactor(scale) {}
+
+  /// Non-owning pointer to the first of `rows * cols` table entries.
+  const float *table;
+  /// Number of matrix rows.
+  int64_t rows;
+  /// Number of matrix columns.
+  int64_t cols;
+  /// Scalar factor attached to the matrix (defaults to 1). It is not read by
+  /// the code in this hunk; presumably applied to the transformed result
+  /// elsewhere — TODO confirm.
+  int64_t scalarFactor;
+};
+
+/// Utility function to convert a constant transform table into an
+/// arith.constant Value of type tensor<rows x cols x `type`>.
+///
+/// The tables are stored as `float`, but `type` is the element type of the
+/// tensor being transformed (it may be f16, bf16, f64, ...). Building a dense
+/// attribute directly from raw `float` data only works for 32-bit element
+/// types, so every entry is first converted to `type`'s floating-point
+/// semantics (round-to-nearest-even). `type` must be a FloatType.
+Value create2DTransformMatrix(OpBuilder &builder, Location loc,
+                              TransformMatrix transform, Type type) {
+  ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
+  const llvm::fltSemantics &semantics =
+      cast<FloatType>(type).getFloatSemantics();
+
+  SmallVector<APFloat> values;
+  values.reserve(constVec.size());
+  for (float element : constVec) {
+    APFloat value(element);
+    bool losesInfo = false;
+    // Precision loss is acceptable here; the tables are small rationals.
+    (void)value.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo);
+    values.push_back(value);
+  }
+
+  return builder.create<arith::ConstantOp>(
+      loc, DenseFPElementsAttr::get(
+               RankedTensorType::get(
+                   SmallVector<int64_t>{transform.rows, transform.cols}, type),
+               values));
+}
+
+/// Extract a height x width plane from a 4D or 6D tensor of rank `srcSize`.
+///
+/// The slice is taken at `outLoopIndex` along dimension `outLoopIdx` and at
+/// `inLoopIndex` along dimension `inLoopIdx`. It has unit size in every other
+/// non-spatial dimension and spans the full extent of dimensions `heightIdx`
+/// and `widthIdx`. The unit dimensions are then dropped, so the result is a
+/// rank-2 tensor<height x width> with the source's element type.
+Value extract2DData(OpBuilder &builder, Location loc, Value source,
+                    Value outLoopIndex, Value inLoopIndex, int64_t outLoopIdx,
+                    int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx,
+                    int64_t srcSize) {
+  auto srcTy = cast<ShapedType>(source.getType());
+  Type eltTy = srcTy.getElementType();
+  int64_t planeH = srcTy.getShape()[heightIdx];
+  int64_t planeW = srcTy.getShape()[widthIdx];
+
+  // Start from a unit-sized, unit-strided slice at the origin, then widen the
+  // two spatial dimensions and pin the two loop dimensions.
+  auto zeroAttr = builder.getIndexAttr(0);
+  auto oneAttr = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult, 6> sliceOffsets(srcSize, zeroAttr);
+  SmallVector<OpFoldResult, 6> sliceSizes(srcSize, oneAttr);
+  SmallVector<OpFoldResult, 6> sliceStrides(srcSize, oneAttr);
+  SmallVector<int64_t> sliceShape(srcSize, 1);
+
+  sliceOffsets[outLoopIdx] = outLoopIndex;
+  sliceOffsets[inLoopIdx] = inLoopIndex;
+  sliceSizes[heightIdx] = builder.getIndexAttr(planeH);
+  sliceSizes[widthIdx] = builder.getIndexAttr(planeW);
+  sliceShape[heightIdx] = planeH;
+  sliceShape[widthIdx] = planeW;
+
+  Value fullRankSlice = builder.create<tensor::ExtractSliceOp>(
+      loc, RankedTensorType::get(sliceShape, eltTy), source, sliceOffsets,
+      sliceSizes, sliceStrides);
+
+  // Collapse the unit dimensions away to obtain a plain 2-D plane.
+  return tensor::createCanonicalRankReducingExtractSliceOp(
+      builder, loc, fullRankSlice,
+      RankedTensorType::get({planeH, planeW}, eltTy));
+}
+
+/// Insert a transformed height x width tile into `dest`, a 4D or 6D tensor of
+/// rank `destSize`, and return the updated tensor.
+///
+/// `source` is first placed into a tensor.empty of rank `destSize`. That
+/// tensor's shape is 1 everywhere except `height` at `heightIdx` and `width`
+/// at `widthIdx`. The resulting tile is then written into `dest` at
+/// `outLoopIndex` along `outLoopIdx` and at `inLoopIndex` along `inLoopIdx`,
+/// with offset 0 in all other dimensions and unit strides.
+Value insert2DData(OpBuilder &builder, Location loc, Value source, Value dest,
+                   Value outLoopIndex, Value inLoopIndex, int64_t height,
+                   int64_t width, int64_t outLoopIdx, int64_t inLoopIdx,
+                   int64_t heightIdx, int64_t widthIdx, int64_t destSize) {
+  Type eltTy = cast<ShapedType>(source.getType()).getElementType();
+
+  // Lift the 2-D tile to rank `destSize` (unit sizes in non-spatial dims).
+  SmallVector<int64_t> tileShape(destSize, 1);
+  tileShape[heightIdx] = height;
+  tileShape[widthIdx] = width;
+  auto emptyTile = builder.create<tensor::EmptyOp>(loc, tileShape, eltTy);
+  Value expandedTile = tensor::createCanonicalRankReducingInsertSliceOp(
+      builder, loc, source, emptyTile);
+
+  // Place the lifted tile at the loop position inside `dest`.
+  auto zeroAttr = builder.getIndexAttr(0);
+  auto oneAttr = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult, 6> destOffsets(destSize, zeroAttr);
+  SmallVector<OpFoldResult, 6> destSizes(destSize, oneAttr);
+  SmallVector<OpFoldResult, 6> destStrides(destSize, oneAttr);
+
+  destOffsets[outLoopIdx] = outLoopIndex;
+  destOffsets[inLoopIdx] = inLoopIndex;
+  destSizes[heightIdx] = builder.getIndexAttr(height);
+  destSizes[widthIdx] = builder.getIndexAttr(width);
+
+  return builder.create<tensor::InsertSliceOp>(loc, expandedTile, dest,
+                                               destOffsets, destSizes,
+                                               destStrides);
+}
+
+/// This function transforms the filter. The data layout of the filter is FHWC.
+/// The transformation matrix is 2-dimension. We need to extract H x W from
+/// FHWC first. We need to generate 2 levels of loops to iterate on F and C.
+/// After the transformation, we get
+///
+/// scf.for %f = lo_f to hi_f step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract filter<h x w> from filter<f x h x w x c>
+///     %ret = linalg.matmul G, %extracted
+///     %ret = linalg.matmul %ret, GT
+///     %inserted = insert %ret into filter<h x w x c x f>
+///
+/// `leftTransform` applies G along H and requires H == r; without it H must be
+/// 1. `rightTransform` applies GT along W and requires W == r; without it W
+/// must be 1. Returns a null Value, without creating any IR, when the filter
+/// shape violates this or no transform matrix is registered for (m, r).
+Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
+                      Value retValue, int64_t m, int64_t r,
+                      bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to G transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GMatrices = {
+          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
+          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
+          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
+      };
+
+  // Map from (m, r) to GT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GTMatrices = {
+          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
+          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
+          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
+      };
+
+  auto filterType = cast<ShapedType>(filter.getType());
+  Type elementType = filterType.getElementType();
+  auto filterShape = filterType.getShape(); // F, H, W, C
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+
+  // A transformed dimension is multiplied by G (r columns) or GT (r rows), so
+  // it must have size r. An untransformed one is carried through as a single
+  // row/column, so it must have size 1. Anything else would build matmuls
+  // and slices with mismatched shapes.
+  if (filterH != (leftTransform ? r : 1))
+    return Value();
+  if (filterW != (rightTransform ? r : 1))
+    return Value();
+
+  // Resolve the transform matrices before emitting any IR. The loop body must
+  // yield exactly one value (scf::buildLoopNest asserts this), so failing
+  // from inside the body is not possible.
+  TransformMapKeyTy key = {m, r};
+  const TransformMatrix *GMatrix = nullptr;
+  const TransformMatrix *GTMatrix = nullptr;
+  if (leftTransform) {
+    auto it = GMatrices.find(key);
+    if (it == GMatrices.end())
+      return Value();
+    GMatrix = &it->second;
+  }
+  if (rightTransform) {
+    auto it = GTMatrices.find(key);
+    if (it == GTMatrices.end())
+      return Value();
+    GTMatrix = &it->second;
+  }
+
+  // linalg.matmul accumulates into its `outs` operand (C += A * B), and the
+  // contents of tensor.empty are undefined. Every accumulator is therefore
+  // zero-filled. The zero constant is loop-invariant, so create it once here.
+  Value zero;
+  if (GMatrix || GTMatrix)
+    zero = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(elementType));
+  auto createZeroTensor = [&](OpBuilder &builder, Location loc,
+                              RankedTensorType type) -> Value {
+    Value empty =
+        builder.create<tensor::EmptyOp>(loc, type.getShape(), elementType);
+    return builder
+        .create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{empty})
+        .getResult(0);
+  };
+
+  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
+                       ValueRange args) -> scf::ValueVector {
+    Value FIter = ivs[0];
+    Value CIter = ivs[1];
+
+    // Extract (H, W) from (F, H, W, C).
+    auto extractFilter = extract2DData(
+        builder, loc, filter, FIter, CIter, /*outLoopIdx=*/0,
+        /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
+
+    int64_t retRows = 1;
+    Value matmulRetValue = extractFilter;
+    if (GMatrix) {
+      retRows = GMatrix->rows;
+      auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
+      Value init = createZeroTensor(builder, loc, matmulType);
+
+      Value G = create2DTransformMatrix(builder, loc, *GMatrix, elementType);
+      // Multiply G x g.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    if (GTMatrix) {
+      auto matmulType =
+          RankedTensorType::get({retRows, GTMatrix->cols}, elementType);
+      Value init = createZeroTensor(builder, loc, matmulType);
+
+      Value GT = create2DTransformMatrix(builder, loc, *GTMatrix, elementType);
+      // Multiply u = (G x g) x GT.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    // Insert (H, W) to (H, W, C, F).
+    int64_t retHeight = leftTransform ? m + r - 1 : 1;
+    int64_t retWidth = rightTransform ? m + r - 1 : 1;
+    auto insertSliceOp = insert2DData(builder, loc, matmulRetValue, args[0],
+                                      FIter, CIter, retHeight, retWidth,
+                                      /*outLoopIdx=*/3, /*inLoopIdx=*/2,
+                                      /*heightIdx=*/0, /*widthIdx=*/1,
+                                      /*destSize=*/4);
+
+    return {insertSliceOp};
+  };
+
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
+  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  scf::LoopNest loops = scf::buildLoopNest(
+      rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
+      {oneStep, oneStep}, {retValue}, buildBody);
+  return loops.results[0];
+}
+
+/// This function transforms the input. The data layout of the input is NHWC.
+/// The transformation matrix is 2-dimension. We need to extract H x W from
+/// NHWC first. We need to generate 2 levels of loops to iterate on N and C.
+/// After the transformation, we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract input<h x w> from input<n x h x w x c>
+///     %ret = linalg.matmul BT, %extracted
+///     %ret = linalg.matmul %ret, B
+///     %inserted = insert %ret into input<h x w x n x c>
+Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
+                     Value retValue, int64_t m, int64_t r,
+                     bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to BT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      BTMatrices = {
+          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
+          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
+          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
+      };
+
+  // Map from (m, r) to B transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      BMatrices = {
+          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
+          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
+          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
+      };
+
+  auto inputType = cast<ShapedType>(input.getType());
+  Type elementType = inputType.getElementType();
+  auto inputShape = inputType.getShape(); // N, H, W, C
+  int64_t inputN = inputShape[0];
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t inputC = inputShape[3];
+  int64_t alphaH = leftTransform ? m + r - 1 : 1;
+  int64_t alphaW = rightTransform ? m + r - 1 : 1;
+
+  if (inputH != alphaH && inputH != 1)
+    return Value();
+  if (inputW != alphaW && inputW != 1)
+    return Value();
+
+  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
+                       ValueRange args) -> scf::ValueVector {
+    Value NIter = ivs[0];
+    Value CIter = ivs[1];
+
+    // Extract (H, W) from (N, H, W, C).
+    auto extractInput = extract2DData(
+        builder, loc, input, NIter, CIter, /*outLoopIdx=*/0,
+        /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
+
+    TransformMapKeyTy key = {m, r};
+    int64_t retRows = 1;
+    int64_t retCols = 1;
+    Value matmulRetValue = extractInput;
+    if (leftTransform) {
+      // Get constant transform matrix BT.
+      auto it = BTMatrices.find(key);
+      if (it == BTMatrices.end())
+        return {};
+      const TransformMatrix &BTMatrix = it->second;
+
+      retRows = BTMatrix.rows;
+      auto matmulType = RankedTensorType::get({retRows, inputW}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+
+      Value BT =
+          create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
+      // Multiply BT x d.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    if (rightTransform) {
+      // Get constant transform matrix B.
+      auto it = BMatrices.find(key);
+      if (it == BMatrices.end())
+        return {};
+      const TransformMatrix &BMatrix = it->second;
+
+      retCols = BMatrix.cols;
+      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+      Value B =
+          create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
+      // Multiply v = (BT x d) x B.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    // Insert (H, W) to (H, W, 1, 1, N, C).
----------------
Max191 wrote:

I would strongly recommend supporting tiling of all dimensions (except for alphaH and alphaW) in the TilingInterface implementation. TilingInterface is very useful for reducing problem sizes to shapes that fit nicely into caches and vector registers. Having control over all dimensions in the TilingInterface is very important for being able to do this. In the current implementation, the tiles will always contain the full size N and C dimensions (until the inner tile after decomposition, which will have a size of 1), which greatly restricts the options for tile size selection.

I would either refactor the logic into a shared utility function, or move the logic into the TilingInterface implementation. You can add the util function here if you decide to do that:
https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Linalg/Utils/Utils.cpp

https://github.com/llvm/llvm-project/pull/96183


More information about the Mlir-commits mailing list