[Mlir-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 11 07:33:59 PDT 2024
================
@@ -38,6 +191,329 @@ constexpr TransformMapKeyTy F_2_3{2, 3};
// NOTE(review): these keys presumably select Winograd F(m, r) variants
// (m = output tile size, r = filter size), judging by the F_<m>_<r> naming;
// confirm against the definition of TransformMapKeyTy.
constexpr TransformMapKeyTy F_4_3{4, 3};
constexpr TransformMapKeyTy F_2_5{2, 5};
/// Describes one constant transform matrix: a non-owning view of a table of
/// `rows * cols` floats (laid out row-major, as consumed when the table is
/// turned into a rows x cols dense attribute) plus an integer scale factor.
struct TransformMatrix {
  /// Non-owning pointer to the matrix entries; the pointed-to storage must
  /// outlive every TransformMatrix that refers to it.
  const float *table;
  /// Number of rows in `table`.
  int64_t rows;
  /// Number of columns in `table`.
  int64_t cols;
  /// Integer scale factor associated with the matrix (1 unless specified).
  int64_t scalarFactor;

  TransformMatrix(const float *table, int64_t rows, int64_t cols,
                  int64_t scalarFactor = 1)
      : table{table}, rows{rows}, cols{cols}, scalarFactor{scalarFactor} {}
};
+
+/// Utility function to convert constant array to arith.constant Value.
+Value create2DTransformMatrix(OpBuilder &builder, Location loc,
+ TransformMatrix transform, Type type) {
+ ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
+
+ return builder.create<arith::ConstantOp>(
+ loc, DenseFPElementsAttr::get(
+ RankedTensorType::get(
+ SmallVector<int64_t>{transform.rows, transform.cols}, type),
+ constVec));
+}
+
+/// Extract height x width data from 4D or 6D tensors.
+Value extract2DData(OpBuilder &builder, Location loc, Value source,
+ Value outLoopIndex, Value inLoopIndex, int64_t outLoopIdx,
+ int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx,
+ int64_t srcSize) {
+ auto sourceType = cast<ShapedType>(source.getType());
+ Type elementType = sourceType.getElementType();
+ auto sourceShape = sourceType.getShape();
+ int64_t height = sourceShape[heightIdx];
+ int64_t width = sourceShape[widthIdx];
+
+ auto zeroIndex = builder.getIndexAttr(0);
+ auto oneIndex = builder.getIndexAttr(1);
+ SmallVector<OpFoldResult, 6> offsets(srcSize, zeroIndex);
+ offsets[outLoopIdx] = outLoopIndex;
+ offsets[inLoopIdx] = inLoopIndex;
+ SmallVector<OpFoldResult, 6> sizes(srcSize, oneIndex);
+ sizes[heightIdx] = builder.getIndexAttr(height);
+ sizes[widthIdx] = builder.getIndexAttr(width);
+ SmallVector<OpFoldResult, 6> strides(srcSize, oneIndex);
+ SmallVector<int64_t> targetShape(srcSize, 1);
+ targetShape[heightIdx] = height;
+ targetShape[widthIdx] = width;
+
+ auto targetType = RankedTensorType::get(targetShape, elementType);
+ auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
+ loc, targetType, source, offsets, sizes, strides);
+
+ auto extractFilterType = RankedTensorType::get({height, width}, elementType);
+ auto extractFilter = tensor::createCanonicalRankReducingExtractSliceOp(
+ builder, loc, extractFilterOp, extractFilterType);
+
+ return extractFilter;
+}
+
+/// Insert transformed height x width data to 4D or 6D tensors which it is
+/// extracted from.
+Value insert2DData(OpBuilder &builder, Location loc, Value source, Value dest,
+ Value outLoopIndex, Value inLoopIndex, int64_t height,
+ int64_t width, int64_t outLoopIdx, int64_t inLoopIdx,
+ int64_t heightIdx, int64_t widthIdx, int64_t destSize) {
+ auto sourceType = cast<ShapedType>(source.getType());
+ Type elementType = sourceType.getElementType();
+ SmallVector<int64_t> sliceShape(destSize, 1);
+ sliceShape[heightIdx] = height;
+ sliceShape[widthIdx] = width;
+ auto init = builder.create<tensor::EmptyOp>(loc, sliceShape, elementType);
+ auto result = tensor::createCanonicalRankReducingInsertSliceOp(builder, loc,
+ source, init);
+
+ auto zeroIndex = builder.getIndexAttr(0);
+ auto oneIndex = builder.getIndexAttr(1);
+ SmallVector<OpFoldResult, 6> retOffsets(destSize, zeroIndex);
+ retOffsets[outLoopIdx] = outLoopIndex;
+ retOffsets[inLoopIdx] = inLoopIndex;
+ SmallVector<OpFoldResult, 6> retSizes(destSize, oneIndex);
+ retSizes[heightIdx] = builder.getIndexAttr(height);
+ retSizes[widthIdx] = builder.getIndexAttr(width);
+ SmallVector<OpFoldResult, 6> strides(destSize, oneIndex);
+
+ auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
+ loc, result, dest, retOffsets, retSizes, strides);
----------------
Max191 wrote:
You can just pass `source` instead of `result` here and avoid needing an extra insert_slice above.
https://github.com/llvm/llvm-project/pull/96183
More information about the Mlir-commits mailing list