[Mlir-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)

Hsiangkai Wang llvmlistbot at llvm.org
Wed Jul 17 07:14:05 PDT 2024


================
@@ -38,6 +191,329 @@ constexpr TransformMapKeyTy F_2_3{2, 3};
 constexpr TransformMapKeyTy F_4_3{4, 3};
 constexpr TransformMapKeyTy F_2_5{2, 5};
 
+/// Structure to keep the information of a constant transform matrix.
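+/// `table` points to a row-major `rows` x `cols` array of coefficients;
+/// `scalarFactor` is an optional scalar multiplier associated with the
+/// matrix (1 by default).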
+struct TransformMatrix {
+  TransformMatrix(const float *table, int64_t rows, int64_t cols,
+                  int64_t scalarFactor = 1)
+      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
+
+  const float *table;
+  int64_t rows;
+  int64_t cols;
+  int64_t scalarFactor;
+};
+
+/// Utility function to convert a constant array to an arith.constant Value.
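+/// The `rows * cols` float table becomes a dense elements attribute of a
+/// `rows x cols` tensor of `type`, e.g. a 2x3 f32 table is emitted as
+///   %cst = arith.constant dense<...> : tensor<2x3xf32>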
+Value create2DTransformMatrix(OpBuilder &builder, Location loc,
+                              TransformMatrix transform, Type type) {
+  ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
+
+  return builder.create<arith::ConstantOp>(
+      loc, DenseFPElementsAttr::get(
+               RankedTensorType::get(
+                   SmallVector<int64_t>{transform.rows, transform.cols}, type),
+               constVec));
+}
+
+/// Extract height x width data from 4D or 6D tensors.
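+/// The slice is taken at (`outLoopIndex`, `inLoopIndex`) along dimensions
+/// `outLoopIdx` and `inLoopIdx`, keeps the full extent of the `heightIdx` and
+/// `widthIdx` dimensions, and is rank-reduced to a 2D height x width tensor.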
+Value extract2DData(OpBuilder &builder, Location loc, Value source,
+                    Value outLoopIndex, Value inLoopIndex, int64_t outLoopIdx,
+                    int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx,
+                    int64_t srcSize) {
+  auto sourceType = cast<ShapedType>(source.getType());
+  Type elementType = sourceType.getElementType();
+  auto sourceShape = sourceType.getShape();
+  int64_t height = sourceShape[heightIdx];
+  int64_t width = sourceShape[widthIdx];
+
+  auto zeroIndex = builder.getIndexAttr(0);
+  auto oneIndex = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult, 6> offsets(srcSize, zeroIndex);
+  offsets[outLoopIdx] = outLoopIndex;
+  offsets[inLoopIdx] = inLoopIndex;
+  SmallVector<OpFoldResult, 6> sizes(srcSize, oneIndex);
+  sizes[heightIdx] = builder.getIndexAttr(height);
+  sizes[widthIdx] = builder.getIndexAttr(width);
+  SmallVector<OpFoldResult, 6> strides(srcSize, oneIndex);
+  SmallVector<int64_t> targetShape(srcSize, 1);
+  targetShape[heightIdx] = height;
+  targetShape[widthIdx] = width;
+
+  auto targetType = RankedTensorType::get(targetShape, elementType);
+  auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
+      loc, targetType, source, offsets, sizes, strides);
+
+  auto extractFilterType = RankedTensorType::get({height, width}, elementType);
+  auto extractFilter = tensor::createCanonicalRankReducingExtractSliceOp(
+      builder, loc, extractFilterOp, extractFilterType);
+
+  return extractFilter;
+}
+
+/// Insert transformed height x width data into the 4D or 6D tensor from
+/// which it was extracted.
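+/// The 2D `source` is first expanded to a rank-`destSize` slice whose only
+/// non-unit dimensions are `heightIdx` and `widthIdx`, then inserted into
+/// `dest` at (`outLoopIndex`, `inLoopIndex`) along dimensions `outLoopIdx`
+/// and `inLoopIdx`.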
+Value insert2DData(OpBuilder &builder, Location loc, Value source, Value dest,
+                   Value outLoopIndex, Value inLoopIndex, int64_t height,
+                   int64_t width, int64_t outLoopIdx, int64_t inLoopIdx,
+                   int64_t heightIdx, int64_t widthIdx, int64_t destSize) {
+  auto sourceType = cast<ShapedType>(source.getType());
+  Type elementType = sourceType.getElementType();
+  SmallVector<int64_t> sliceShape(destSize, 1);
+  sliceShape[heightIdx] = height;
+  sliceShape[widthIdx] = width;
+  auto init = builder.create<tensor::EmptyOp>(loc, sliceShape, elementType);
+  auto result = tensor::createCanonicalRankReducingInsertSliceOp(builder, loc,
+                                                                 source, init);
+
+  auto zeroIndex = builder.getIndexAttr(0);
+  auto oneIndex = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult, 6> retOffsets(destSize, zeroIndex);
+  retOffsets[outLoopIdx] = outLoopIndex;
+  retOffsets[inLoopIdx] = inLoopIndex;
+  SmallVector<OpFoldResult, 6> retSizes(destSize, oneIndex);
+  retSizes[heightIdx] = builder.getIndexAttr(height);
+  retSizes[widthIdx] = builder.getIndexAttr(width);
+  SmallVector<OpFoldResult, 6> strides(destSize, oneIndex);
+
+  auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
+      loc, result, dest, retOffsets, retSizes, strides);
----------------
Hsiangkai wrote:

Done.

https://github.com/llvm/llvm-project/pull/96183

