[Mlir-commits] [mlir] [TOSA] Usage of 32bit integer for 'index to float' in rfft2d (PR #75098)

Dmitriy Smirnov llvmlistbot at llvm.org
Thu Dec 14 03:47:03 PST 2023


https://github.com/d-smirnov updated https://github.com/llvm/llvm-project/pull/75098

>From a2e20d086a95e60d45329c6e421379f62f2db90a Mon Sep 17 00:00:00 2001
From: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
Date: Mon, 11 Dec 2023 12:04:43 +0000
Subject: [PATCH 1/2] [TOSA] Usage of 32bit integer for 'index to float' in
 rfft2d

The lowering of rfft2d to linalg now casts index values to i32 when the
output float type is at most 32 bits wide, and to i64 otherwise.
---
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  |  7 ++-
 .../TosaToLinalg/tosa-to-linalg.mlir          | 48 +++++++++----------
 2 files changed, 29 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 3bf7bf12b5e96f..3422862fc76b22 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2284,8 +2284,11 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
 
   static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                 FloatType type, Value value) {
-    auto integerVal =
-        builder.create<arith::IndexCastUIOp>(loc, builder.getI64Type(), value);
+    // Index values are cast through i32 for float types up to 32 bits wide
+    // and through i64 for wider ones.
+    Type intType = type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
+                                                     : builder.getI32Type();
+    auto integerVal = builder.create<arith::IndexCastUIOp>(loc, intType, value);
 
     return builder.create<arith::UIToFPOp>(loc, type, integerVal);
   }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index aa53b366f6da68..6f9597dd399af6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1691,10 +1691,10 @@ func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, t
 // CHECK:   %[[EMPTY_1:.*]] = tensor.empty() : tensor<5x5x5xf32>
 // CHECK:   %[[VAR_3:.*]] = linalg.fill ins(%[[CST_ZERO:.*]]: f32) outs(%[[EMPTY_1:.*]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
 // CHECK:   %[[CST_PI:.*]] = arith.constant 6.28318548 : f32
-// CHECK:   %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i64
-// CHECK:   %[[VAR_6:.*]] = arith.uitofp %[[VAR_5:.*]] : i64 to f32
-// CHECK:   %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i64
-// CHECK:   %[[VAR_8:.*]] = arith.uitofp %[[VAR_7:.*]] : i64 to f32
+// CHECK:   %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i32
+// CHECK:   %[[VAR_6:.*]] = arith.uitofp %[[VAR_5:.*]] : i32 to f32
+// CHECK:   %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i32
+// CHECK:   %[[VAR_8:.*]] = arith.uitofp %[[VAR_7:.*]] : i32 to f32
 // CHECK:   linalg.generic {
 // CHECK:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
 // CHECK:     iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
@@ -1702,17 +1702,17 @@ func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, t
 // CHECK:     outs(%[[VAR_1]], %[[VAR_3]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
 // CHECK:   ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
 // CHECK:     %[[INDEX_1:.*]] = linalg.index 1 : index
-// CHECK:     %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i64
-// CHECK:     %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i64 to f32
+// CHECK:     %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i32
+// CHECK:     %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i32 to f32
 // CHECK:     %[[INDEX_2:.*]] = linalg.index 2 : index
-// CHECK:     %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i64
-// CHECK:     %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i64 to f32
+// CHECK:     %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i32
+// CHECK:     %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i32 to f32
 // CHECK:     %[[INDEX_3:.*]] = linalg.index 3 : index
-// CHECK:     %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i64
-// CHECK:     %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i64 to f32
+// CHECK:     %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i32
+// CHECK:     %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i32 to f32
 // CHECK:     %[[INDEX_4:.*]] = linalg.index 4 : index
-// CHECK:     %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i64
-// CHECK:     %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i64 to f32
+// CHECK:     %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i32
+// CHECK:     %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i32 to f32
 // CHECK:     %[[VAR_21:.*]] = arith.mulf %[[VAR_17]], %[[VAR_11]] : f32
 // CHECK:     %[[VAR_22:.*]] = arith.mulf %[[VAR_20]], %[[VAR_14]] : f32
 // CHECK:     %[[XCOMP:.*]] = arith.divf %[[VAR_21]], %[[VAR_6]] : f32
@@ -1761,10 +1761,10 @@ func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>,
 // CHECK:   %[[CST_2:.*]] = arith.constant 2 : index
 // CHECK:   %[[DIM_8:.*]] = tensor.dim %[[ARG_0]], %[[CST_2]] : tensor<?x?x?xf32>
 // CHECK:   %[[CST_9:.*]] = arith.constant 6.28318548 : f32
-// CHECK:   %[[VAR_6:.*]] = arith.index_castui %[[DIM_6]] : index to i64
-// CHECK:   %[[VAR_7:.*]] = arith.uitofp %[[VAR_6]] : i64 to f32
-// CHECK:   %[[VAR_8:.*]] = arith.index_castui %[[DIM_8]] : index to i64
-// CHECK:   %[[VAR_9:.*]] = arith.uitofp %[[VAR_8]] : i64 to f32
+// CHECK:   %[[VAR_6:.*]] = arith.index_castui %[[DIM_6]] : index to i32
+// CHECK:   %[[VAR_7:.*]] = arith.uitofp %[[VAR_6]] : i32 to f32
+// CHECK:   %[[VAR_8:.*]] = arith.index_castui %[[DIM_8]] : index to i32
+// CHECK:   %[[VAR_9:.*]] = arith.uitofp %[[VAR_8]] : i32 to f32
 // CHECK:   linalg.generic {
 // CHECK:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
 // CHECK:     iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
@@ -1772,17 +1772,17 @@ func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>,
 // CHECK:     outs(%[[VAR_3]], %[[VAR_5]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
 // CHECK:   ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
 // CHECK:     %[[INDEX_1:.*]] = linalg.index 1 : index
-// CHECK:     %[[VAR_12:.*]] = arith.index_castui %[[INDEX_1]] : index to i64
-// CHECK:     %[[VAR_13:.*]] = arith.uitofp %[[VAR_12]] : i64 to f32
+// CHECK:     %[[VAR_12:.*]] = arith.index_castui %[[INDEX_1]] : index to i32
+// CHECK:     %[[VAR_13:.*]] = arith.uitofp %[[VAR_12]] : i32 to f32
 // CHECK:     %[[INDEX_2:.*]] = linalg.index 2 : index
-// CHECK:     %[[VAR_15:.*]] = arith.index_castui %[[INDEX_2]] : index to i64
-// CHECK:     %[[VAR_16:.*]] = arith.uitofp %[[VAR_15]] : i64 to f32
+// CHECK:     %[[VAR_15:.*]] = arith.index_castui %[[INDEX_2]] : index to i32
+// CHECK:     %[[VAR_16:.*]] = arith.uitofp %[[VAR_15]] : i32 to f32
 // CHECK:     %[[INDEX_3:.*]] = linalg.index 3 : index
-// CHECK:     %[[VAR_18:.*]] = arith.index_castui %[[INDEX_3]] : index to i64
-// CHECK:     %[[VAR_19:.*]] = arith.uitofp %[[VAR_18]] : i64 to f32
+// CHECK:     %[[VAR_18:.*]] = arith.index_castui %[[INDEX_3]] : index to i32
+// CHECK:     %[[VAR_19:.*]] = arith.uitofp %[[VAR_18]] : i32 to f32
 // CHECK:     %[[INDEX_4:.*]] = linalg.index 4 : index
-// CHECK:     %[[VAR_21:.*]] = arith.index_castui %[[INDEX_4]] : index to i64
-// CHECK:     %[[VAR_22:.*]] = arith.uitofp %[[VAR_21]] : i64 to f32
+// CHECK:     %[[VAR_21:.*]] = arith.index_castui %[[INDEX_4]] : index to i32
+// CHECK:     %[[VAR_22:.*]] = arith.uitofp %[[VAR_21]] : i32 to f32
 // CHECK:     %[[VAR_23:.*]] = arith.mulf %[[VAR_19]], %[[VAR_13]] : f32
 // CHECK:     %[[VAR_24:.*]] = arith.mulf %[[VAR_22]], %[[VAR_16]] : f32
 // CHECK:     %[[XCOMP:.*]] = arith.divf %[[VAR_23]], %[[VAR_7]] : f32

>From f055337dff55826eec4d7cbaaa4727f427cca603 Mon Sep 17 00:00:00 2001
From: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
Date: Thu, 14 Dec 2023 11:26:53 +0000
Subject: [PATCH 2/2] [TOSA] WIP: lower fft2d to linalg

---
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 187 ++++++++++++++++++
 1 file changed, 187 insertions(+)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 3422862fc76b22..0ba5bca3594451 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2393,6 +2393,192 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
   }
 };
 
+// Lowers tosa.fft2d to a single linalg.generic that evaluates the 2-D DFT
+// directly, following the TOSA specification pseudo-code:
+//
+//   sign_val = inverse ? -1.0 : 1.0
+//   for_each(0 <= n < N, 0 <= oy < H, 0 <= ox < W) {
+//     sum_real = sum_imag = 0.0
+//     for_each(0 <= iy < H, 0 <= ix < W) {
+//       a = sign_val * 2 * pi() * ((iy * oy) / H + (ix * ox) / W)
+//       sum_real +=  val_real * cos(a) + val_imag * sin(a)
+//       sum_imag += -val_real * sin(a) + val_imag * cos(a)
+//     }
+//     output_real[n, oy, ox] = sum_real
+//     output_imag[n, oy, ox] = sum_imag
+//   }
+//
+// The (n, oy, ox) loops become parallel iterators and the (iy, ix) loops
+// become reductions. The spec's LEVEL_CHECK and power-of-two ERROR_IF
+// conditions on H and W are not checked by this pattern.
+struct FFT2dConverter final : public OpRewritePattern<FFT2dOp> {
+  using OpRewritePattern<FFT2dOp>::OpRewritePattern;
+
+  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }
+
+  // Creates a tensor of `type` filled with zeros. `dynamicSizes` holds one
+  // extent per dynamic dimension of `type`, in order.
+  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
+                                RankedTensorType type,
+                                llvm::ArrayRef<Value> dynamicSizes) {
+    auto emptyTensor =
+        rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
+    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
+    auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
+    auto filledTensor = rewriter
+                            .create<linalg::FillOp>(loc, ValueRange{fillValue},
+                                                    ValueRange{emptyTensor})
+                            .result();
+    return filledTensor;
+  }
+
+  // Converts an index value to `type` through an unsigned integer cast: i32
+  // for float types at most 32 bits wide, i64 for wider ones.
+  static Value castIndexToFloat(OpBuilder &builder, Location loc,
+                                FloatType type, Value value) {
+    auto integerVal = builder.create<arith::IndexCastUIOp>(
+        loc,
+        32 < type.getIntOrFloatBitWidth() ? builder.getI64Type()
+                                          : builder.getI32Type(),
+        value);
+
+    return builder.create<arith::UIToFPOp>(loc, type, integerVal);
+  }
+
+  // Returns the iteration index of loop dimension `index` of the enclosing
+  // linalg op, converted to `type`.
+  static Value createLinalgIndex(OpBuilder &builder, Location loc,
+                                 FloatType type, int64_t index) {
+    auto indexVal = builder.create<linalg::IndexOp>(loc, index);
+    return castIndexToFloat(builder, loc, type, indexVal);
+  }
+
+  // Returns the affine dimension expressions d<args>... in order.
+  template <typename... Args>
+  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
+                                                         Args... args) {
+    return {builder.getAffineDimExpr(args)...};
+  }
+
+  LogicalResult matchAndRewrite(FFT2dOp fft2d,
+                                PatternRewriter &rewriter) const override {
+    if (!llvm::all_of(fft2d->getOperandTypes(), isRankedTensor) ||
+        !llvm::all_of(fft2d->getResultTypes(), isRankedTensor)) {
+      return rewriter.notifyMatchFailure(fft2d,
+                                         "only supports ranked tensors");
+    }
+
+    auto loc = fft2d.getLoc();
+    Value inputReal = fft2d.getInputReal();
+    Value inputImag = fft2d.getInputImag();
+
+    // All inputs and results must share one floating-point element type;
+    // the whole computation is carried out in that type.
+    auto elementType = dyn_cast<FloatType>(
+        cast<RankedTensorType>(inputReal.getType()).getElementType());
+    auto hasElementType = [&](Type type) {
+      return cast<RankedTensorType>(type).getElementType() == elementType;
+    };
+    if (!elementType ||
+        !llvm::all_of(fft2d->getOperandTypes(), hasElementType) ||
+        !llvm::all_of(fft2d->getResultTypes(), hasElementType)) {
+      return rewriter.notifyMatchFailure(
+          fft2d, "expects a single floating-point element type");
+    }
+
+    // Zero-initialised accumulators, one per result. The outputs have the
+    // same [N, H, W] shape as the inputs, so the extent of every dynamic
+    // result dimension is read from the real input.
+    llvm::SmallVector<Value> genericOpOutputs;
+    for (Type resultType : fft2d->getResultTypes()) {
+      auto outputType = cast<RankedTensorType>(resultType);
+      llvm::SmallVector<Value> dynamicSizes;
+      for (int64_t dim = 0, rank = outputType.getRank(); dim < rank; ++dim) {
+        if (outputType.isDynamicDim(dim))
+          dynamicSizes.push_back(
+              rewriter.createOrFold<tensor::DimOp>(loc, inputReal, dim));
+      }
+      genericOpOutputs.push_back(
+          createZeroTensor(rewriter, loc, outputType, dynamicSizes));
+    }
+
+    // Loops are (n, oy, ox, iy, ix): the output positions are parallel and
+    // the input positions are reductions.
+    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
+        utils::IteratorType::parallel, utils::IteratorType::parallel,
+        utils::IteratorType::parallel, utils::IteratorType::reduction,
+        utils::IteratorType::reduction};
+
+    llvm::SmallVector<Value> genericOpInputs = {inputReal, inputImag};
+
+    // Inputs are read at (n, iy, ix); outputs are written at (n, oy, ox).
+    auto indexingMaps = AffineMap::inferFromExprList(llvm::ArrayRef{
+        affineDimsExpr(rewriter, 0, 3, 4), affineDimsExpr(rewriter, 0, 3, 4),
+        affineDimsExpr(rewriter, 0, 1, 2), affineDimsExpr(rewriter, 0, 1, 2)});
+
+    // H and W of the real input, converted to the element type.
+    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, inputReal, 1);
+    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, inputReal, 2);
+    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
+    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
+
+    // sign_val * 2 * pi, materialised once outside the loop body. Negating
+    // the constant is exact, so this equals negating the angle afterwards.
+    constexpr double twoPi = 6.283185307179586;
+    auto signedTwoPiAttr = rewriter.getFloatAttr(
+        elementType, fft2d.getInverseAttr().getValue() ? -twoPi : twoPi);
+    auto signedTwoPi = rewriter.create<arith::ConstantOp>(loc, signedTwoPiAttr);
+
+    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
+      Value valReal = args[0];
+      Value valImag = args[1];
+      Value sumReal = args[2];
+      Value sumImag = args[3];
+
+      auto oy = createLinalgIndex(builder, loc, elementType, 1);
+      auto ox = createLinalgIndex(builder, loc, elementType, 2);
+      auto iy = createLinalgIndex(builder, loc, elementType, 3);
+      auto ix = createLinalgIndex(builder, loc, elementType, 4);
+
+      // angle = sign_val * 2 * pi * ((iy * oy) / H + (ix * ox) / W)
+      // NOTE(review): iy * oy and ix * ox are formed in the element type and
+      // can exceed its exact-integer range for large H/W; reducing them
+      // modulo H/W in integer arithmetic first would be more accurate.
+      auto iyXoy = builder.create<arith::MulFOp>(loc, iy, oy);
+      auto ixXox = builder.create<arith::MulFOp>(loc, ix, ox);
+      auto yComponent = builder.create<arith::DivFOp>(loc, iyXoy, constH);
+      auto xComponent = builder.create<arith::DivFOp>(loc, ixXox, constW);
+      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
+      auto angle = builder.create<arith::MulFOp>(loc, signedTwoPi, sumXY);
+
+      auto cosAngle = builder.create<math::CosOp>(loc, angle);
+      auto sinAngle = builder.create<math::SinOp>(loc, angle);
+
+      // realComponent = valReal * cos(angle) + valImag * sin(angle)
+      auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
+      auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
+      auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);
+
+      // imagComponent = valImag * cos(angle) - valReal * sin(angle)
+      auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
+      auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
+      auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);
+
+      // Accumulate into the running sums for output position (n, oy, ox).
+      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
+      auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);
+
+      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
+    };
+
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
+        indexingMaps, iteratorTypes, buildBody);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgConversionPatterns(
@@ -2456,6 +2642,7 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       RescaleConverter,
       ReverseConverter,
       RFFT2dConverter,
+      FFT2dConverter,
       TableConverter,
       TileConverter,
       TransposeConverter>(patterns->getContext());



More information about the Mlir-commits mailing list