[Mlir-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)

Hsiangkai Wang llvmlistbot at llvm.org
Wed Jul 17 07:15:44 PDT 2024


================
@@ -107,6 +584,162 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
   return expandOutput;
 }
 
+/// This function transforms the output. The data layout of the output is HWNF.
+/// The transformation matrix is 2-dimensional. We need to extract H x W from
+/// HWNF first. We need to generate 2 levels of loops to iterate on N and F.
+/// After the transformation, we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+///   scf.for %f = lo_f to hi_f step 1
+///     %extracted = extract input<h x w> from result<h x w x n x f>
+///     %ret = linalg.matmul AT, %extracted
+///     %ret = linalg.matmul %ret, A
+///     %inserted = insert %ret into ret<n x h x w x f>
+Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
+                      Value output, int64_t m, int64_t r,
+                      bool leftTransform = true, bool rightTransform = true) {
+  // Map from (m, r) to AT transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      ATMatrices = {
+          {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
+          {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
+          {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
+      };
+
+  // Map from (m, r) to A transform matrix.
+  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+      AMatrices = {
+          {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
+          {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
+          {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
+      };
+
+  auto valueType = cast<ShapedType>(value.getType());
+  Type elementType = valueType.getElementType();
+  auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
+  int64_t valueH = valueShape[0];
+  int64_t valueW = valueShape[1];
+  int64_t valueN = valueShape[4];
+  int64_t valueF = valueShape[5];
+  int64_t alphaH = leftTransform ? m + r - 1 : 1;
+  int64_t alphaW = rightTransform ? m + r - 1 : 1;
+
+  if (valueH != alphaH && valueH != 1)
+    return Value();
+  if (valueW != alphaW && valueW != 1)
+    return Value();
+
+  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
+                       ValueRange args) -> scf::ValueVector {
+    Value NIter = ivs[0];
+    Value FIter = ivs[1];
+
+    // Extract (H, W) from (H, W, 1, 1, N, F).
----------------
Hsiangkai wrote:

Done.
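
For readers following the quoted doc comment, the loop nest it describes (iterate over N and F, extract one H x W tile, multiply by AT on the left and A on the right, insert the result) can be sketched in plain C++ without MLIR. This is only an illustration under stated assumptions: the AT matrix below is the textbook F(2, 3) form and may not match the PR's AT_2x2_3x3 constant, the tile dimensions are dropped so the input is laid out as (alphaH, alphaW, N, F), and the names outputTransformSketch, transformTile, kAT, kM and kAlpha are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Assumption: textbook F(2, 3) output-transform matrix A^T (2 x 4).
// The PR's AT_2x2_3x3 constant may use a different but equivalent form.
constexpr int kM = 2;     // output tile size m
constexpr int kAlpha = 4; // m + r - 1 with r = 3
constexpr float kAT[kM][kAlpha] = {{1, 1, 1, 0}, {0, 1, -1, -1}};

using InTile = std::array<std::array<float, kAlpha>, kAlpha>;
using OutTile = std::array<std::array<float, kM>, kM>;

// Computes A^T * X * A for one alpha x alpha tile X.
OutTile transformTile(const InTile &x) {
  // tmp = A^T * X, shape (m x alpha).
  float tmp[kM][kAlpha] = {};
  for (int i = 0; i < kM; ++i)
    for (int j = 0; j < kAlpha; ++j)
      for (int k = 0; k < kAlpha; ++k)
        tmp[i][j] += kAT[i][k] * x[k][j];

  // y = tmp * A, shape (m x m); A[k][j] is read as kAT[j][k].
  OutTile y{};
  for (int i = 0; i < kM; ++i)
    for (int j = 0; j < kM; ++j)
      for (int k = 0; k < kAlpha; ++k)
        y[i][j] += tmp[i][k] * kAT[j][k];
  return y;
}

// Input is row-major (alphaH, alphaW, N, F); output is row-major (N, m, m, F).
std::vector<float> outputTransformSketch(const std::vector<float> &value,
                                         int N, int F) {
  assert(value.size() == static_cast<std::size_t>(kAlpha * kAlpha * N * F));
  std::vector<float> out(static_cast<std::size_t>(N * kM * kM * F), 0.0f);
  for (int n = 0; n < N; ++n) {
    for (int f = 0; f < F; ++f) {
      // Extract the (H, W) slice for this (n, f) pair.
      InTile x{};
      for (int h = 0; h < kAlpha; ++h)
        for (int w = 0; w < kAlpha; ++w)
          x[h][w] = value[((h * kAlpha + w) * N + n) * F + f];

      OutTile y = transformTile(x);

      // Insert the transformed tile into the (N, m, m, F) result.
      for (int i = 0; i < kM; ++i)
        for (int j = 0; j < kM; ++j)
          out[((n * kM + i) * kM + j) * F + f] = y[i][j];
    }
  }
  return out;
}
```

The sketch keeps the two matrix multiplications separate to mirror the two linalg.matmul steps in the comment. It does not model the leftTransform/rightTransform flags or the matrix scaling factors from the patch.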

https://github.com/llvm/llvm-project/pull/96183

