[Mlir-commits] [mlir] d5a9fc1 - [MLIR] Fix tiling for `tensor.unpack` with outer permutations
Lorenzo Chelini
llvmlistbot at llvm.org
Sun Jan 15 23:42:23 PST 2023
Author: Lorenzo Chelini
Date: 2023-01-16T08:42:18+01:00
New Revision: d5a9fc13ef82f7f6c1a350bceef79f7988cdac20
URL: https://github.com/llvm/llvm-project/commit/d5a9fc13ef82f7f6c1a350bceef79f7988cdac20
DIFF: https://github.com/llvm/llvm-project/commit/d5a9fc13ef82f7f6c1a350bceef79f7988cdac20.diff
LOG: [MLIR] Fix tiling for `tensor.unpack` with outer permutations
An outer dimension permutation requires adjusting the offsets and sizes of
the `tensor.extract_slice` operations generated during tiling. Previously,
this was done by applying the inverse of the outer permutation for both
`tensor.pack` and `tensor.unpack`. For packing, the tiling is applied on
the interchanged dimensions; thus, applying the inverse is correct. For
unpacking, on the other hand, the tiling is applied on the output tensor,
whose dimensions are not interchanged, so the permutation itself is applied
and no inverse is required.
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D141688
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
mlir/test/Dialect/Tensor/tiling.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index d84f10294d28..3fc470e1874a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -94,14 +94,13 @@ static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
return loopBounds;
}
-static void applyInversePermToRange(SmallVector<OpFoldResult> &offsets,
- SmallVector<OpFoldResult> &sizes,
- ArrayRef<int64_t> permutation) {
+static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
+ SmallVector<OpFoldResult> &sizes,
+ ArrayRef<int64_t> permutation) {
if (permutation.empty())
return;
- SmallVector<int64_t> inversedPerm = invertPermutationVector(permutation);
- applyPermutationToVector<OpFoldResult>(offsets, inversedPerm);
- applyPermutationToVector<OpFoldResult>(sizes, inversedPerm);
+ applyPermutationToVector<OpFoldResult>(offsets, permutation);
+ applyPermutationToVector<OpFoldResult>(sizes, permutation);
}
struct PackOpTiling
@@ -133,7 +132,8 @@ struct PackOpTiling
int64_t inputRank = packOp.getSourceRank();
SmallVector<OpFoldResult> origOffsets(offsets.begin(), offsets.end());
SmallVector<OpFoldResult> origSizes(sizes.begin(), sizes.end());
- applyInversePermToRange(origOffsets, origSizes, packOp.getOuterDimsPerm());
+ applyPermToRange(origOffsets, origSizes,
+ invertPermutationVector(packOp.getOuterDimsPerm()));
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
packOp.getDimAndTileMapping();
@@ -382,8 +382,8 @@ struct UnPackOpTiling
// The tiling is applied on destination dimensions. We have to apply the
// interchange on source dimensions if outer_dims_perm is set.
- applyInversePermToRange(sliceSrcIndices, sliceSrcSizes,
- unpackOp.getOuterDimsPerm());
+ applyPermToRange(sliceSrcIndices, sliceSrcSizes,
+ unpackOp.getOuterDimsPerm());
Attribute zeroAttr = b.getIndexAttr(0);
sliceSrcIndices.append(numInnerTiles, zeroAttr);
sliceSrcSizes.append(unpackOp.getMixedTiles());
diff --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir
index 9a110493874b..c874666737f8 100644
--- a/mlir/test/Dialect/Tensor/tiling.mlir
+++ b/mlir/test/Dialect/Tensor/tiling.mlir
@@ -549,6 +549,7 @@ transform.sequence failures(propagate) {
// CHECK: %[[RES:.+]] = tensor.insert_slice %[[UNPACK]]
// CHECK-SAME: into %{{.+}}[%[[K]], %[[C]]] [%[[OUT_K_SZ]], %[[OUT_C_SZ]]]
// CHECK: scf.yield %[[RES]]
+
func.func @dynamic_perfect_CKkc_to_KC(%source: tensor<?x?x2x2xf32>, %dest: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = tensor.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 2] into %dest : tensor<?x?x2x2xf32> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
@@ -559,3 +560,77 @@ transform.sequence failures(propagate) {
%0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
%1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4]
}
+
+// -----
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0) -> (d0 floordiv 2)>
+// CHECK: func.func @perfect_NKPQk_to_NPQK(
+// CHECK-SAME: %[[SOURCE:.+]]: tensor<1x4x6x6x2xf32>,
+// CHECK-SAME: %{{.+}}: tensor<1x6x6x8xf32>)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %{{.+}} = scf.for %[[P:.+]] = %[[C0]] to %[[C6]] step %[[C1]]
+// CHECK: %{{.+}} = scf.for %[[Q:.+]] = %[[C0]] to %[[C6]] step %[[C1]]
+// CHECK: %{{.+}} = scf.for %[[K:.+]] = %[[C0]] to %[[C8]] step %[[C4]]
+// CHECK: %[[K_SZ:.+]] = affine.apply #[[MAP]](%[[K]])
+// CHECK: %[[SLICE_SOURCE:.+]] = tensor.extract_slice %[[SOURCE]][0, %[[K_SZ]], %[[P]], %[[Q]], 0]
+// CHECK: %[[SLICE_DEST:.+]] = tensor.extract_slice %{{.+}}[0, %[[P]], %[[Q]], %[[K]]]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack
+// CHECK-SAME: %[[SLICE_SOURCE]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2]
+// CHECK-SAME: into %[[SLICE_DEST]]
+// CHECK: %[[RES:.+]] = tensor.insert_slice %[[UNPACK]]
+// CHECK-SAME: into %{{.+}}[0, %[[P]], %[[Q]], %[[K]]]
+// CHECK: scf.yield %[[RES]]
+
+func.func @perfect_NKPQk_to_NPQK(%source: tensor<1x4x6x6x2xf32>, %dest: tensor<1x6x6x8xf32>) -> tensor<1x6x6x8xf32> {
+ %0 = tensor.unpack %source outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2] into %dest : tensor<1x4x6x6x2xf32> -> tensor<1x6x6x8xf32>
+ return %0 : tensor<1x6x6x8xf32>
+}
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
+ %1, %loops:4 = transform.structured.tile_to_scf_for %0 [1, 1, 1, 4]
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (-d0 + 6, 1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * -2 + 8, 2)>
+// CHECK: func.func @perfect_NPQK_to_NKPQk
+// CHECK-SAME: %[[SOURCE:.+]]: tensor<1x6x6x8xf32>,
+// CHECK-SAME: %{{.+}}: tensor<1x4x6x6x2xf32>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[C6:.+]] = arith.constant 6 : index
+// CHECK: %{{.+}} = scf.for %[[ARG2:.+]] = %[[C0]] to %[[C4]] step %[[C1]]
+// CHECK: %{{.+}} = scf.for %[[ARG4:.+]] = %[[C0]] to %[[C6]] step %[[C1]]
+// CHECK: %{{.+}} = scf.for %[[ARG6:.+]] = %[[C0]] to %[[C6]] step %[[C1]]
+// CHECK: %[[MIN_ARG4:.+]] = affine.min #[[MAP]](%[[ARG4]])
+// CHECK: %[[MIN_ARG6:.+]] = affine.min #[[MAP]](%[[ARG6]])
+// CHECK: %[[APPLY:.+]] = affine.apply #[[MAP1]](%[[ARG2]])
+// CHECK: %[[MIN_ARG2:.+]] = affine.min #[[MAP2]](%[[ARG2]])
+// CHECK: %[[SLICE_SOURCE:.+]] = tensor.extract_slice %[[SOURCE]][0, %[[ARG4]], %[[ARG6]], %[[APPLY]]]
+// CHECK: %[[SLICE_DEST:.+]] = tensor.extract_slice %{{.+}}[0, %[[ARG2]], %[[ARG4]], %[[ARG6]], 0]
+// CHECK: %[[PACK:.+]] = tensor.pack
+// CHECK-SAME: %[[SLICE_SOURCE]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2]
+// CHECK-SAME: into %[[SLICE_DEST]]
+// CHECK: %[[RES:.+]] = tensor.insert_slice %[[PACK]]
+// CHECK-SAME: into %{{.+}}[0, %[[ARG2]], %[[ARG4]], %[[ARG6]], 0]
+// CHECK: scf.yield %[[RES]]
+
+func.func @perfect_NPQK_to_NKPQk(%source: tensor<1x6x6x8xf32>, %dest: tensor<1x4x6x6x2xf32>) -> tensor<1x4x6x6x2xf32> {
+ %0 = tensor.pack %source outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2] into %dest : tensor<1x6x6x8xf32> -> tensor<1x4x6x6x2xf32>
+ return %0 : tensor<1x4x6x6x2xf32>
+}
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+ %1, %loops:4 = transform.structured.tile_to_scf_for %0 [1, 1, 1, 1]
+}
More information about the Mlir-commits
mailing list