[Mlir-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)
llvmlistbot at llvm.org
Thu Jul 11 07:34:00 PDT 2024
================
@@ -38,6 +191,329 @@ constexpr TransformMapKeyTy F_2_3{2, 3};
constexpr TransformMapKeyTy F_4_3{4, 3};
constexpr TransformMapKeyTy F_2_5{2, 5};
+/// Structure to keep information of constant transform matrices.
+struct TransformMatrix {
+  TransformMatrix(const float *table, int64_t rows, int64_t cols,
+                  int64_t scalarFactor = 1)
+      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
+
+  const float *table;
+  int64_t rows;
+  int64_t cols;
+  int64_t scalarFactor;
+};
+
+/// Utility function to convert constant array to arith.constant Value.
+Value create2DTransformMatrix(OpBuilder &builder, Location loc,
+                              TransformMatrix transform, Type type) {
+  ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
+
+  return builder.create<arith::ConstantOp>(
+      loc, DenseFPElementsAttr::get(
+               RankedTensorType::get(
+                   SmallVector<int64_t>{transform.rows, transform.cols}, type),
+               constVec));
+}
+
+/// Extract height x width data from 4D or 6D tensors.
+Value extract2DData(OpBuilder &builder, Location loc, Value source,
+                    Value outLoopIndex, Value inLoopIndex, int64_t outLoopIdx,
+                    int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx,
+                    int64_t srcSize) {
+  auto sourceType = cast<ShapedType>(source.getType());
+  Type elementType = sourceType.getElementType();
+  auto sourceShape = sourceType.getShape();
+  int64_t height = sourceShape[heightIdx];
+  int64_t width = sourceShape[widthIdx];
+
+  auto zeroIndex = builder.getIndexAttr(0);
+  auto oneIndex = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult, 6> offsets(srcSize, zeroIndex);
+  offsets[outLoopIdx] = outLoopIndex;
+  offsets[inLoopIdx] = inLoopIndex;
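For readers less familiar with the transform tables, the following plain C++ sketch (no MLIR dependency) mirrors the `TransformMatrix` struct in the quoted diff and shows how a flat row-major table, which is the layout `create2DTransformMatrix` passes to `DenseFPElementsAttr::get` with shape `{rows, cols}`, is indexed as a matrix. The table values are assumed to be the F(2,3) filter-transform matrix G from the Lavin and Gray Winograd formulation; they are for illustration only and are not necessarily the table used in this PR. `at` and `applyFilterTransform1D` are hypothetical helpers.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Mirrors the TransformMatrix struct from the patch: a non-owning pointer to
// a row-major table plus its shape.
struct TransformMatrix {
  TransformMatrix(const float *table, int64_t rows, int64_t cols,
                  int64_t scalarFactor = 1)
      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}

  // Hypothetical accessor: element (r, c) of a row-major rows x cols table.
  float at(int64_t r, int64_t c) const { return table[r * cols + c]; }

  const float *table;
  int64_t rows;
  int64_t cols;
  int64_t scalarFactor;
};

// Assumed F(2,3) filter transform G (4 x 3), as in Lavin & Gray.
constexpr float G_2_3[] = {1.0f,  0.0f, 0.0f,
                           0.5f,  0.5f, 0.5f,
                           0.5f, -0.5f, 0.5f,
                           0.0f,  0.0f, 1.0f};

// Hypothetical helper: computes G * g for a 1-D filter g of length m.cols.
std::vector<float> applyFilterTransform1D(const TransformMatrix &m,
                                          const std::vector<float> &g) {
  assert(static_cast<int64_t>(g.size()) == m.cols);
  std::vector<float> out(static_cast<std::size_t>(m.rows), 0.0f);
  for (int64_t r = 0; r < m.rows; ++r)
    for (int64_t c = 0; c < m.cols; ++c)
      out[r] += m.at(r, c) * g[c];
  return out;
}
```

With this table, `applyFilterTransform1D(TransformMatrix(G_2_3, 4, 3), {1, 2, 3})` yields `{1, 3, 1, 3}`. The full 2-D filter transform applies G on both sides of the 3x3 filter (G g Gᵀ).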
----------------
Max191 wrote:
I don't fully understand what `outLoopIndex` and `inLoopIndex` are supposed to represent. If this can extract from a 6D tensor, shouldn't it also be able to set offsets for the 2 dims that don't fall under `{outLoopIdx, inLoopIdx, heightIdx, widthIdx}`?
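To make the question concrete, here is a plain C++ model of the offset vector built in the quoted lines, with `OpFoldResult` replaced by `int64_t` and no MLIR types. With `srcSize == 6`, the two dimensions outside `{outLoopIdx, inLoopIdx, heightIdx, widthIdx}` always get offset 0, so only their first slice is ever read. The function names, the 6-D layout used in the example, and `extract2DOffsetsGeneral` are hypothetical; the latter is just one possible way for the helper to accept an index for every iterated dimension, not what the PR implements.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Plain-integer model of the offsets built in extract2DData: every dimension
// starts at offset 0, then only outLoopIdx and inLoopIdx receive loop indices.
std::vector<int64_t> extract2DOffsets(int64_t srcSize, int64_t outLoopIdx,
                                      int64_t outLoopIndex, int64_t inLoopIdx,
                                      int64_t inLoopIndex) {
  assert(srcSize == 4 || srcSize == 6);
  std::vector<int64_t> offsets(static_cast<std::size_t>(srcSize), 0);
  offsets[outLoopIdx] = outLoopIndex;
  offsets[inLoopIdx] = inLoopIndex;
  return offsets;
}

// Hypothetical generalization: one (dim, index) pair per iterated dimension,
// so a 6-D source can pin all four non-spatial dims instead of leaving two at 0.
std::vector<int64_t> extract2DOffsetsGeneral(
    int64_t srcSize, const std::vector<std::pair<int64_t, int64_t>> &loopDims) {
  std::vector<int64_t> offsets(static_cast<std::size_t>(srcSize), 0);
  for (const auto &dimAndIndex : loopDims)
    offsets[dimAndIndex.first] = dimAndIndex.second;
  return offsets;
}
```

For example, with a hypothetical 6-D layout whose spatial dims are 0 and 1, `extract2DOffsets(6, 2, 7, 5, 3)` yields `{0, 0, 7, 0, 0, 3}`, leaving dims 3 and 4 pinned at 0.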
https://github.com/llvm/llvm-project/pull/96183