[Mlir-commits] [mlir] b6060b7 - [mlir][Linalg] Fix element type of results when folding reshapes.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 5 15:40:54 PDT 2021


Author: MaheshRavishankar
Date: 2021-05-05T15:40:41-07:00
New Revision: b6060b76731da36e14ef96c789b79e3b23672973

URL: https://github.com/llvm/llvm-project/commit/b6060b76731da36e14ef96c789b79e3b23672973
DIFF: https://github.com/llvm/llvm-project/commit/b6060b76731da36e14ef96c789b79e3b23672973.diff

LOG: [mlir][Linalg] Fix element type of results when folding reshapes.

Fixing a minor bug which led to the element type of the output being
modified when folding reshapes with a generic op.

Differential Revision: https://reviews.llvm.org/D101942

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/test/Dialect/Linalg/fusion-push-reshape.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index d1646e92b8d44..7fd6d245ccb59 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -1129,9 +1129,12 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOpTy> {
     SmallVector<Value> newOutputs;
     SmallVector<Type> newOutputTypes;
     for (auto output : op.outputs()) {
+      auto newOutputType = RankedTensorType::get(
+          reshapeFound.getSrcType().getShape(),
+          output.getType().template cast<RankedTensorType>().getElementType());
       Value newOutput = rewriter.create<TensorReshapeOp>(
-          op->getLoc(), reshapeFound.getSrcType(), output, reassociation);
-      newOutputTypes.push_back(newOutput.getType());
+          op->getLoc(), newOutputType, output, reassociation);
+      newOutputTypes.push_back(newOutputType);
       newOutputs.push_back(newOutput);
     }
     // 5. Create a new generic op with lowerer rank.

diff  --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
index eda7d460a5268..cc46bd3c273f4 100644
--- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
@@ -88,3 +88,40 @@ func @reshape_negative(%A: tensor<12544x16xf32>, %B: tensor<112xf32>) -> tensor<
   } -> tensor<112x112x16xf32>
   return %22 : tensor<112x112x16xf32>
 }
+
+// -----
+
+func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>,
+    %arg2 : tensor<5xf32>) -> tensor<2x3x5xf32> {
+  %cst_6 = constant 1.000000e+00 : f32
+  %cst_7 = constant 7.000000e+00 : f32
+  %cst_8 = constant 1.1920929E-7 : f32
+  %25 = linalg.tensor_reshape %arg0
+      [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+      : tensor<6x5xi32> into tensor<2x3x5xi32>
+  %26 = linalg.init_tensor [2, 3, 5] : tensor<2x3x5xf32>
+  %28 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%25, %arg1, %arg2 : tensor<2x3x5xi32>, tensor<5xf32>, tensor<5xf32>)
+      outs(%26 : tensor<2x3x5xf32>) {
+      ^bb0(%arg6: i32, %arg7: f32, %arg8: f32, %arg9: f32):  // no predecessors
+        %29 = sitofp %arg6 : i32 to f32
+        %30 = addf %arg7, %cst_8 : f32
+        %31 = divf %cst_7, %30 : f32
+        %32 = divf %cst_6, %31 : f32
+        %33 = mulf %29, %32 : f32
+        %34 = addf %33, %arg8 : f32
+        linalg.yield %34 : f32
+      } -> tensor<2x3x5xf32>
+  return %28 : tensor<2x3x5xf32>
+}
+// CHECK-LABEL: func @type_correctness
+//       CHECK:   %[[OP:.+]] = linalg.generic
+//  CHECK-SAME:   ins(%{{.+}}, %{{.+}}, %{{.+}} : tensor<6x5xi32>, tensor<5xf32>, tensor<5xf32>)
+//  CHECK-SAME:   outs(%{{.+}} : tensor<6x5xf32>)
+//       CHECK:   linalg.tensor_reshape %[[OP]]
+//  CHECK-SAME:   tensor<6x5xf32> into tensor<2x3x5xf32>


        


More information about the Mlir-commits mailing list