[Mlir-commits] [mlir] [TOSA] FFT2D operator (PR #77005)
Jakub Kuderski
llvmlistbot at llvm.org
Tue Jan 9 19:50:41 PST 2024
================
@@ -2344,6 +2344,134 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
}
};
+struct FFT2dConverter final : public OpRewritePattern<FFT2dOp> {
+ using OpRewritePattern<FFT2dOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(FFT2dOp fft2d,
+ PatternRewriter &rewriter) const override {
+ if (!llvm::all_of(fft2d->getOperandTypes(),
+ RFFT2dConverter::isRankedTensor) ||
+ !llvm::all_of(fft2d->getResultTypes(),
+ RFFT2dConverter::isRankedTensor)) {
+ return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
+ }
+
+ auto loc = fft2d.getLoc();
+ auto input_real = fft2d.getInputReal();
+ auto input_imag = fft2d.getInputImag();
+ auto inverse = fft2d.getInverseAttr();
+
+ auto real_el_ty = input_real.getType()
+ .cast<ShapedType>()
+ .getElementType()
+ .cast<FloatType>();
+ auto imag_el_ty = input_imag.getType()
+ .cast<ShapedType>()
+ .getElementType()
+ .cast<FloatType>();
+
+ assert(real_el_ty == imag_el_ty);
+
+ // Compute the output type and set of dynamic sizes
+ llvm::SmallVector<Value> dynamicSizes;
+
+ // Get [N, H, W]
+ auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
----------------
kuhar wrote:
Similarly for the other code below.
https://github.com/llvm/llvm-project/pull/77005
More information about the Mlir-commits mailing list