[Mlir-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)

Hsiangkai Wang llvmlistbot at llvm.org
Mon Jul 15 02:09:35 PDT 2024


================
@@ -36,6 +190,329 @@ constexpr TransformMapKeyTy F_2_3{2, 3};
 constexpr TransformMapKeyTy F_4_3{4, 3};
 constexpr TransformMapKeyTy F_2_5{2, 5};
 
+/// Describes one constant transform matrix: a non-owning pointer to its
+/// float entries together with the matrix dimensions and a scalar factor.
+struct TransformMatrix {
+  /// Store the table pointer, dimensions and scalar factor as given. The
+  /// scalar factor defaults to 1 when omitted.
+  TransformMatrix(const float *table, int64_t rows, int64_t cols,
+                  int64_t scalarFactor = 1)
+      : table{table}, rows{rows}, cols{cols}, scalarFactor{scalarFactor} {}
+
+  /// Matrix entries; read as `rows * cols` contiguous floats.
+  const float *table;
+  /// Number of rows of the matrix.
+  int64_t rows;
+  /// Number of columns of the matrix.
+  int64_t cols;
+  /// Scalar factor associated with the matrix.
+  int64_t scalarFactor;
+};
+
+/// Materialize a constant transform table as an arith.constant whose result
+/// is a `rows x cols` ranked tensor with element type `type`. The table is
+/// read as `rows * cols` contiguous floats.
+Value create2DTransformMatrix(OpBuilder &builder, Location loc,
+                              TransformMatrix transform, Type type) {
+  const int64_t numElements = transform.rows * transform.cols;
+  ArrayRef<float> tableValues(transform.table, numElements);
+
+  SmallVector<int64_t> matrixShape{transform.rows, transform.cols};
+  auto matrixType = RankedTensorType::get(matrixShape, type);
+  auto matrixAttr = DenseFPElementsAttr::get(matrixType, tableValues);
+  return builder.create<arith::ConstantOp>(loc, matrixAttr);
+}
+
+/// Extract a height x width slice from a 4D or 6D tensor.
+///
+/// `srcSize` is the rank of `source`. The slice starts at `outLoopIndex` along
+/// dimension `outLoopIdx` and at `inLoopIndex` along dimension `inLoopIdx`
+/// (offset 0 elsewhere), covers the full extent of dimensions `heightIdx` and
+/// `widthIdx`, and has size 1 in every other dimension. The slice is then
+/// rank-reduced to a 2-D height x width tensor, which is returned.
+Value extract2DData(OpBuilder &builder, Location loc, Value source,
+                    Value outLoopIndex, Value inLoopIndex, int64_t outLoopIdx,
+                    int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx,
+                    int64_t srcSize) {
+  auto srcType = cast<ShapedType>(source.getType());
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  Type elemType = srcType.getElementType();
+  const int64_t planeH = srcShape[heightIdx];
+  const int64_t planeW = srcShape[widthIdx];
+
+  // Static shape of the rank-preserving slice: 1 everywhere except H and W.
+  SmallVector<int64_t> sliceShape(srcSize, 1);
+  sliceShape[heightIdx] = planeH;
+  sliceShape[widthIdx] = planeW;
+
+  auto zeroAttr = builder.getIndexAttr(0);
+  auto oneAttr = builder.getIndexAttr(1);
+
+  // Offsets are zero except along the two loop dimensions.
+  SmallVector<OpFoldResult, 6> sliceOffsets(srcSize, zeroAttr);
+  sliceOffsets[outLoopIdx] = outLoopIndex;
+  sliceOffsets[inLoopIdx] = inLoopIndex;
+
+  // Sizes mirror the static slice shape; strides are all one.
+  SmallVector<OpFoldResult, 6> sliceSizes(srcSize, oneAttr);
+  sliceSizes[heightIdx] = builder.getIndexAttr(planeH);
+  sliceSizes[widthIdx] = builder.getIndexAttr(planeW);
+  SmallVector<OpFoldResult, 6> sliceStrides(srcSize, oneAttr);
+
+  auto sliceType = RankedTensorType::get(sliceShape, elemType);
+  auto fullRankSlice = builder.create<tensor::ExtractSliceOp>(
+      loc, sliceType, source, sliceOffsets, sliceSizes, sliceStrides);
+
+  // Drop the unit dimensions so that only (height, width) remains.
+  auto planeType = RankedTensorType::get({planeH, planeW}, elemType);
+  return tensor::createCanonicalRankReducingExtractSliceOp(
+      builder, loc, fullRankSlice, planeType);
+}
+
+/// Insert a transformed height x width tensor into a 4D or 6D destination.
+///
+/// `destSize` is the rank of `dest`. `source` is first expanded to a
+/// rank-`destSize` tensor that has `height` and `width` at dimensions
+/// `heightIdx` and `widthIdx` and size 1 elsewhere. That tensor is then
+/// inserted into `dest` at offset `outLoopIndex` along dimension `outLoopIdx`
+/// and `inLoopIndex` along dimension `inLoopIdx` (offset 0 elsewhere), with
+/// unit strides. Returns the updated destination tensor.
+Value insert2DData(OpBuilder &builder, Location loc, Value source, Value dest,
+                   Value outLoopIndex, Value inLoopIndex, int64_t height,
+                   int64_t width, int64_t outLoopIdx, int64_t inLoopIdx,
+                   int64_t heightIdx, int64_t widthIdx, int64_t destSize) {
+  Type elemType = cast<ShapedType>(source.getType()).getElementType();
+
+  // Shape of the expanded slice: 1 everywhere except H and W.
+  SmallVector<int64_t> expandedShape(destSize, 1);
+  expandedShape[heightIdx] = height;
+  expandedShape[widthIdx] = width;
+
+  // Expand the 2-D source back to the destination rank.
+  auto expandedInit =
+      builder.create<tensor::EmptyOp>(loc, expandedShape, elemType);
+  auto expandedSource = tensor::createCanonicalRankReducingInsertSliceOp(
+      builder, loc, source, expandedInit);
+
+  auto zeroAttr = builder.getIndexAttr(0);
+  auto oneAttr = builder.getIndexAttr(1);
+
+  // Offsets are zero except along the two loop dimensions.
+  SmallVector<OpFoldResult, 6> destOffsets(destSize, zeroAttr);
+  destOffsets[outLoopIdx] = outLoopIndex;
+  destOffsets[inLoopIdx] = inLoopIndex;
+
+  // Sizes mirror the expanded slice shape; strides are all one.
+  SmallVector<OpFoldResult, 6> destSizes(destSize, oneAttr);
+  destSizes[heightIdx] = builder.getIndexAttr(height);
+  destSizes[widthIdx] = builder.getIndexAttr(width);
+  SmallVector<OpFoldResult, 6> destStrides(destSize, oneAttr);
+
+  return builder.create<tensor::InsertSliceOp>(
+      loc, expandedSource, dest, destOffsets, destSizes, destStrides);
+}
+
+/// This function transforms the filter. The data layout of the filter is FHWC.
+/// The transformation matrix is 2-dimension. We need to extract H x W from
+/// FHWC first. We need to generate 2 levels of loops to iterate on F and C.
+/// After the transformation, we get
+///
+/// scf.for %f = lo_f to hi_f step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract filter<h x w> from filter<f x h x w x c>
+///     %ret = linalg.matmul G, %extracted
+///     %ret = linalg.matmul %ret, GT
+///     %inserted = insert %ret into filter<h x w x c x f>
+///
+/// Each linalg.matmul accumulates into a zero-filled tensor.
+///
+/// `retValue` is the initial value of the (H, W, C, F) result tensor carried
+/// through the loop nest. `leftTransform` / `rightTransform` select whether
+/// the multiplication by G / GT is emitted.
+///
+/// Returns the transformed filter, or a null Value (without emitting any IR)
+/// if the filter height/width is neither `r` nor 1, or if no transform matrix
+/// is registered for (m, r).
+Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
+                      Value retValue, int64_t m, int64_t r,
+                      bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to G transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GMatrices = {
+          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
+          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
+          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
+      };
+
+  // Map from (m, r) to GT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      GTMatrices = {
+          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
+          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
+          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
+      };
+
+  auto filterType = cast<ShapedType>(filter.getType());
+  Type elementType = filterType.getElementType();
+  auto filterShape = filterType.getShape(); // F, H, W, C
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+
+  if (filterH != r && filterH != 1)
+    return Value();
+  if (filterW != r && filterW != 1)
+    return Value();
+
+  // Resolve the transform matrices before emitting any IR. The loop body
+  // builder must return exactly one value per iteration argument, so it
+  // cannot signal failure by returning an empty vector; bail out here instead.
+  TransformMapKeyTy key = {m, r};
+  const TransformMatrix *GMatrix = nullptr;
+  if (leftTransform) {
+    auto it = GMatrices.find(key);
+    if (it == GMatrices.end())
+      return Value();
+    GMatrix = &it->second;
+  }
+  const TransformMatrix *GTMatrix = nullptr;
+  if (rightTransform) {
+    auto it = GTMatrices.find(key);
+    if (it == GTMatrices.end())
+      return Value();
+    GTMatrix = &it->second;
+  }
+
+  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
+                       ValueRange args) -> scf::ValueVector {
+    Value FIter = ivs[0];
+    Value CIter = ivs[1];
+
+    // Extract (H, W) from (F, H, W, C).
+    auto extractFilter = extract2DData(
+        builder, loc, filter, FIter, CIter, /*outLoopIdx=*/0,
+        /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
+
+    // linalg.matmul accumulates into its output operand, so the accumulator
+    // has to be zero-filled rather than left as an uninitialized
+    // tensor.empty. The zero constant is created on first use and shared.
+    Value zero;
+    auto createZeroInit = [&](ArrayRef<int64_t> shape) -> Value {
+      if (!zero)
+        zero = builder.create<arith::ConstantOp>(
+            loc, builder.getZeroAttr(elementType));
+      Value empty =
+          builder.create<tensor::EmptyOp>(loc, shape, elementType).getResult();
+      return builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
+    };
+
+    int64_t retRows = 1;
+    Value matmulRetValue = extractFilter;
+    if (leftTransform) {
+      retRows = GMatrix->rows;
+      auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
+      Value init = createZeroInit(matmulType.getShape());
+
+      // Constant transform matrix G.
+      Value G = create2DTransformMatrix(builder, loc, *GMatrix, elementType);
+      // Multiply G x g.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    if (rightTransform) {
+      auto matmulType =
+          RankedTensorType::get({retRows, GTMatrix->cols}, elementType);
+      Value init = createZeroInit(matmulType.getShape());
+
+      // Constant transform matrix GT.
+      Value GT = create2DTransformMatrix(builder, loc, *GTMatrix, elementType);
+      // Multiply u = (G x g) x GT.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    // Insert (H, W) to (H, W, C, F).
+    int64_t retHeight = leftTransform ? m + r - 1 : 1;
+    int64_t retWidth = rightTransform ? m + r - 1 : 1;
+    auto insertSliceOp = insert2DData(builder, loc, matmulRetValue, args[0],
+                                      FIter, CIter, retHeight, retWidth,
+                                      /*outLoopIdx=*/3, /*inLoopIdx=*/2,
+                                      /*heightIdx=*/0, /*widthIdx=*/1,
+                                      /*destSize=*/4);
+
+    return {insertSliceOp};
+  };
+
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
+  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  scf::LoopNest loops = scf::buildLoopNest(
+      rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
+      {oneStep, oneStep}, {retValue}, buildBody);
+  return loops.results[0];
+}
+
+/// This function transforms the input. The data layout of the input is NHWC.
+/// The transformation matrix is 2-dimension. We need to extract H x W from
+/// NHWC first. We need to generate 2 levels of loops to iterate on N and C.
+/// After the transformation, we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract input<h x w> from input<n x h x w x c>
+///     %ret = linalg.matmul BT, %extracted
+///     %ret = linalg.matmul %ret, B
+///     %inserted = insert %ret into input<h x w x n x c>
+Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
+                     Value retValue, int64_t m, int64_t r,
+                     bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to BT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      BTMatrices = {
+          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
+          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
+          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
+      };
+
+  // Map from (m, r) to B transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      BMatrices = {
+          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
+          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
+          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
+      };
+
+  auto inputType = cast<ShapedType>(input.getType());
+  Type elementType = inputType.getElementType();
+  auto inputShape = inputType.getShape(); // N, H, W, C
+  int64_t inputN = inputShape[0];
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t inputC = inputShape[3];
+  int64_t alphaH = leftTransform ? m + r - 1 : 1;
+  int64_t alphaW = rightTransform ? m + r - 1 : 1;
+
+  if (inputH != alphaH && inputH != 1)
+    return Value();
+  if (inputW != alphaW && inputW != 1)
+    return Value();
+
+  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
+                       ValueRange args) -> scf::ValueVector {
+    Value NIter = ivs[0];
+    Value CIter = ivs[1];
+
+    // Extract (H, W) from (N, H, W, C).
+    auto extractInput = extract2DData(
+        builder, loc, input, NIter, CIter, /*outLoopIdx=*/0,
+        /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4);
+
+    TransformMapKeyTy key = {m, r};
+    int64_t retRows = 1;
+    int64_t retCols = 1;
+    Value matmulRetValue = extractInput;
+    if (leftTransform) {
+      // Get constant transform matrix BT.
+      auto it = BTMatrices.find(key);
+      if (it == BTMatrices.end())
+        return {};
+      const TransformMatrix &BTMatrix = it->second;
+
+      retRows = BTMatrix.rows;
+      auto matmulType = RankedTensorType::get({retRows, inputW}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+
+      Value BT =
+          create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
+      // Multiply BT x d.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    if (rightTransform) {
+      // Get constant transform matrix B.
+      auto it = BMatrices.find(key);
+      if (it == BMatrices.end())
+        return {};
+      const TransformMatrix &BMatrix = it->second;
+
+      retCols = BMatrix.cols;
+      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+      Value B =
+          create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
+      // Multiply v = (BT x d) x B.
+      auto matmulOp = builder.create<linalg::MatmulOp>(
+          loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
+      matmulRetValue = matmulOp.getResult(0);
+    }
+
+    // Insert (H, W) to (H, W, 1, 1, N, C).
----------------
Hsiangkai wrote:

After trying to handle tileH and tileW in (alphaH, alphaW, tileH, tileW, N, C) inside the input decomposition, I found that we would need to handle the tiling of H and W while iterating over tileH and tileW. However, similar logic is already implemented in TilingInterface. I suggest keeping the decomposition simple and not duplicating the H and W tiling logic in it. N and C have no relation to H and W, so it is simple to handle them inside the decomposition and keep it as flexible as the current implementation. I therefore suggest keeping my original implementation. What do you think, @Max191?

https://github.com/llvm/llvm-project/pull/96183


More information about the Mlir-commits mailing list