[Mlir-commits] [mlir] b6060b7 - [mlir][Linalg] Fix element type of results when folding reshapes.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 5 15:40:54 PDT 2021
Author: MaheshRavishankar
Date: 2021-05-05T15:40:41-07:00
New Revision: b6060b76731da36e14ef96c789b79e3b23672973
URL: https://github.com/llvm/llvm-project/commit/b6060b76731da36e14ef96c789b79e3b23672973
DIFF: https://github.com/llvm/llvm-project/commit/b6060b76731da36e14ef96c789b79e3b23672973.diff
LOG: [mlir][Linalg] Fix element type of results when folding reshapes.
Fix a minor bug that led to the element type of the output being
modified when folding reshapes with a generic op.
Differential Revision: https://reviews.llvm.org/D101942
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index d1646e92b8d44..7fd6d245ccb59 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -1129,9 +1129,12 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOpTy> {
SmallVector<Value> newOutputs;
SmallVector<Type> newOutputTypes;
for (auto output : op.outputs()) {
+ auto newOutputType = RankedTensorType::get(
+ reshapeFound.getSrcType().getShape(),
+ output.getType().template cast<RankedTensorType>().getElementType());
Value newOutput = rewriter.create<TensorReshapeOp>(
- op->getLoc(), reshapeFound.getSrcType(), output, reassociation);
- newOutputTypes.push_back(newOutput.getType());
+ op->getLoc(), newOutputType, output, reassociation);
+ newOutputTypes.push_back(newOutputType);
newOutputs.push_back(newOutput);
}
// 5. Create a new generic op with lowerer rank.
diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
index eda7d460a5268..cc46bd3c273f4 100644
--- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
@@ -88,3 +88,40 @@ func @reshape_negative(%A: tensor<12544x16xf32>, %B: tensor<112xf32>) -> tensor<
} -> tensor<112x112x16xf32>
return %22 : tensor<112x112x16xf32>
}
+
+// -----
+
+func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>,
+ %arg2 : tensor<5xf32>) -> tensor<2x3x5xf32> {
+ %cst_6 = constant 1.000000e+00 : f32
+ %cst_7 = constant 7.000000e+00 : f32
+ %cst_8 = constant 1.1920929E-7 : f32
+ %25 = linalg.tensor_reshape %arg0
+ [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+ : tensor<6x5xi32> into tensor<2x3x5xi32>
+ %26 = linalg.init_tensor [2, 3, 5] : tensor<2x3x5xf32>
+ %28 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%25, %arg1, %arg2 : tensor<2x3x5xi32>, tensor<5xf32>, tensor<5xf32>)
+ outs(%26 : tensor<2x3x5xf32>) {
+ ^bb0(%arg6: i32, %arg7: f32, %arg8: f32, %arg9: f32): // no predecessors
+ %29 = sitofp %arg6 : i32 to f32
+ %30 = addf %arg7, %cst_8 : f32
+ %31 = divf %cst_7, %30 : f32
+ %32 = divf %cst_6, %31 : f32
+ %33 = mulf %29, %32 : f32
+ %34 = addf %33, %arg8 : f32
+ linalg.yield %34 : f32
+ } -> tensor<2x3x5xf32>
+ return %28 : tensor<2x3x5xf32>
+}
+// CHECK-LABEL: func @type_correctness
+// CHECK: %[[OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : tensor<6x5xi32>, tensor<5xf32>, tensor<5xf32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<6x5xf32>)
+// CHECK: linalg.tensor_reshape %[[OP]]
+// CHECK-SAME: tensor<6x5xf32> into tensor<2x3x5xf32>
More information about the Mlir-commits
mailing list