[Mlir-commits] [mlir] 21e6e70 - [mlir][linalg] Match element type of result when doing propagation of unpack through elementwise

Quinn Dawkins llvmlistbot at llvm.org
Tue Feb 21 10:48:34 PST 2023


Author: Quinn Dawkins
Date: 2023-02-21T13:45:17-05:00
New Revision: 21e6e70ccc952458b1b21fc0b967ba27ca9fa6ba

URL: https://github.com/llvm/llvm-project/commit/21e6e70ccc952458b1b21fc0b967ba27ca9fa6ba
DIFF: https://github.com/llvm/llvm-project/commit/21e6e70ccc952458b1b21fc0b967ba27ca9fa6ba.diff

LOG: [mlir][linalg] Match element type of result when doing propagation of unpack through elementwise

When propagating tensor.unpack ops through elementwise generics, a new
output tensor is needed if the element type of the input differs from
that of the output in the elementwise op.
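
For illustration, the situation looks like this (abbreviated from the test
case added below; the full IR is in the diff):

    // The unpack produces f32, but the elementwise generic truncates to f16.
    %0 = tensor.empty() : tensor<12x56x56x64xf32>
    %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
        inner_dims_pos = [3] inner_tiles = [32]
        into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
    %2 = linalg.generic ... ins(%1 : tensor<12x56x56x64xf32>)
                            outs(%init : tensor<12x56x56x64xf16>)

Once the unpack is pushed below the generic it yields f16 values, so it can
no longer write into the original f32 destination %0; the rewrite now
creates a fresh f16 tensor.empty for it instead.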

Differential Revision: https://reviews.llvm.org/D144438

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
    mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index bf5e64ba1f34a..206f6c51a4929 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -434,14 +434,30 @@ pushDownUnPackOpThroughElemGenericOp(RewriterBase &rewriter,
   GenericOp newGenericOp = packElementWiseOp(rewriter, genericOp, dest,
                                              packedOutIndexingMap, packInfo);
 
-  auto unPackOp = unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
+  // If the output element type for the generic differs from the source
+  // unpack op, we need to create a new destination tensor.
+  auto loc = genericOp.getLoc();
+  Value unPackDest = producerUnPackOp.getDest();
+  auto genericOutElementType = getElementTypeOrSelf(genericOp.getResult(0));
+  if (producerUnPackOp.getDestType().getElementType() !=
+      genericOutElementType) {
+    SmallVector<OpFoldResult> unPackMixedSizes;
+    if (auto unPackEmpty = unPackDest.getDefiningOp<tensor::EmptyOp>())
+      unPackMixedSizes = unPackEmpty.getMixedSizes();
+    else
+      unPackMixedSizes = tensor::getMixedSizes(rewriter, loc, unPackDest);
+
+    unPackDest = rewriter.create<tensor::EmptyOp>(loc, unPackMixedSizes,
+                                                  genericOutElementType);
+  }
+
   // Insert an unPackOp right after the packed generic.
   Value unPackOpRes =
       rewriter
           .create<tensor::UnPackOp>(
-              genericOp.getLoc(),
+              loc,
               newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
-              unPackOp.getDest(), producerUnPackOp.getInnerDimsPos(),
+              unPackDest, producerUnPackOp.getInnerDimsPos(),
               producerUnPackOp.getMixedTiles(),
               producerUnPackOp.getOuterDimsPerm())
           .getResult();

diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 32190cac8d2f8..afe184b655adc 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -441,6 +441,46 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
 
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 
+func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
+  %0 = tensor.empty() : tensor<12x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
+  %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf16>) {
+    ^bb0(%in: f32, %out: f16):
+      %3 = arith.truncf %in : f32 to f16
+      linalg.yield %3 : f16
+  } -> tensor<12x56x56x64xf16>
+  return %2 : tensor<12x56x56x64xf16>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func.func @unpack_element_type_change
+// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:  into %[[ARG0_UNPACK_EMPTY]]
+// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
+// CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:  into %[[ARG1_PACK_EMPTY]]
+// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:  into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME:  ins(%[[ARG0_PACK]]
+// CHECK-SAME:  outs(%[[ARG1_PACK]]
+// CHECK: %[[ARG0_NEW_EMPTY_UNPACK:.+]] = tensor.empty() : tensor<12x56x56x64xf16>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:  into %[[ARG0_NEW_EMPTY_UNPACK]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
 func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
   %init = tensor.empty() : tensor<12x56x56x64xf32>
   %0 = tensor.empty() : tensor<12x56x56x64xf32>
