[Mlir-commits] [mlir] be75cf9 - [mlir][tensor] Fix a bug in tiling unpack op.

Hanhan Wang llvmlistbot at llvm.org
Wed Feb 15 10:25:58 PST 2023


Author: Hanhan Wang
Date: 2023-02-15T10:25:50-08:00
New Revision: be75cf931f3fd3f23a04d3e8f441f09a3cf08ef9

URL: https://github.com/llvm/llvm-project/commit/be75cf931f3fd3f23a04d3e8f441f09a3cf08ef9
DIFF: https://github.com/llvm/llvm-project/commit/be75cf931f3fd3f23a04d3e8f441f09a3cf08ef9.diff

LOG: [mlir][tensor] Fix a bug in tiling unpack op.

The inner tile sizes can be dynamic (i.e., given as SSA Values). In that
case, they must be added to tiledOperands when cloning the op.

Reviewed By: chelini

Differential Revision: https://reviews.llvm.org/D143978

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
    mlir/test/Dialect/Tensor/tiling.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 72ecf830ecfc5..e4367e049b11c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -402,9 +402,12 @@ struct UnPackOpTiling
                                     unpackOp.getDestType().getElementType());
     }
 
-    Operation *tiledUnpackOp =
-        b.create<UnPackOp>(loc, TypeRange{sliceDest.getType()},
-                           ValueRange{sliceSource, sliceDest}, op->getAttrs());
+    SmallVector<Value> tiledOperands = {sliceSource, sliceDest};
+    for (auto tile : unpackOp.getInnerTiles())
+      tiledOperands.push_back(tile);
+
+    Operation *tiledUnpackOp = b.create<UnPackOp>(
+        loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());
 
     if (isPerfectTilingCase)
       return {tiledUnpackOp};

diff  --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir
index 79276b6bd0dae..f7734fdf4ac88 100644
--- a/mlir/test/Dialect/Tensor/tiling.mlir
+++ b/mlir/test/Dialect/Tensor/tiling.mlir
@@ -580,8 +580,8 @@ transform.sequence failures(propagate) {
 // CHECK:       %[[SLICE_DEST:.+]] = tensor.extract_slice %{{.+}}[0, %[[P]], %[[Q]], %[[K]]]
 // CHECK:       %[[UNPACK:.+]] = tensor.unpack
 // CHECK-SAME:    %[[SLICE_SOURCE]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2]
-// CHECK-SAME:    into %[[SLICE_DEST]] 
-// CHECK:       %[[RES:.+]] = tensor.insert_slice %[[UNPACK]] 
+// CHECK-SAME:    into %[[SLICE_DEST]]
+// CHECK:       %[[RES:.+]] = tensor.insert_slice %[[UNPACK]]
 // CHECK-SAME:    into %{{.+}}[0, %[[P]], %[[Q]], %[[K]]]
 // CHECK:       scf.yield %[[RES]]
 
@@ -598,6 +598,32 @@ transform.sequence failures(propagate) {
 
 // -----
 
+func.func private @get_dynamic_tile_size() -> index
+
+// CHECK-LABEL: func.func @fully_dynamic_unpack
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[DST:[0-9a-zA-Z]+]]
+// CHECK:         %[[INNER_TS:.+]] = call @get_dynamic_tile_size() : () -> index
+// CHECK:         %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[DST]])
+// CHECK:           %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]])
+// CHECK:             %[[SLICE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK:             %[[EMPTY:.+]] = tensor.empty
+// CHECK:             %[[UNPACK:.+]] = tensor.unpack %[[SLICE]]
+// CHECK-SAME:          inner_dims_pos = [1, 0] inner_tiles = [%[[INNER_TS]], %[[INNER_TS]]] into %[[EMPTY]]
+func.func @fully_dynamic_unpack(%source: tensor<?x?x?x?xf32>, %dest: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = func.call @get_dynamic_tile_size() : () -> index
+  %1 = tensor.unpack %source inner_dims_pos = [1, 0] inner_tiles = [%0, %0] into %dest : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [4, 8]
+}
+
+// -----
+
 // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (-d0 + 6, 1)>
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * 2)>
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * -2 + 8, 2)>


        


More information about the Mlir-commits mailing list