[Mlir-commits] [mlir] 4d03651 - [mlir][tensor] Improve size inference in tiling tensor.pack ops.
Hanhan Wang
llvmlistbot at llvm.org
Thu Feb 23 10:35:13 PST 2023
Author: Hanhan Wang
Date: 2023-02-23T10:35:00-08:00
New Revision: 4d0365101f98061dbaf409c0f390778ef66672e7
URL: https://github.com/llvm/llvm-project/commit/4d0365101f98061dbaf409c0f390778ef66672e7
DIFF: https://github.com/llvm/llvm-project/commit/4d0365101f98061dbaf409c0f390778ef66672e7.diff
LOG: [mlir][tensor] Improve size inference in tiling tensor.pack ops.
The sizes of input operands need to be clamped only when there are
incomplete tiles, i.e., when the padding value is set. The sizes of the
input slice can be folded into constants when the shapes and tiling
sizes are static.
Reviewed By: chelini
Differential Revision: https://reviews.llvm.org/D144604
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
mlir/test/Dialect/Tensor/tiling.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index ff11a3a27f9e0..33e698f7d2c75 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -161,11 +161,13 @@ struct PackOpTiling
}
// Limit the size of the input operand for incomplete tiles.
- OpFoldResult dimSize = srcDimValues[dim];
- auto avDimSize = AV(dim0).bind(dimSize);
- auto avInputIdx = AV(dim1).bind(inputIndices.back());
- inputSizes.back() =
- ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
+ if (packOp.getPaddingValue()) {
+ OpFoldResult dimSize = srcDimValues[dim];
+ auto avDimSize = AV(dim0).bind(dimSize);
+ auto avInputIdx = AV(dim1).bind(inputIndices.back());
+ inputSizes.back() =
+ ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
+ }
}
auto oneAttr = b.getI64IntegerAttr(1);
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
index 09ebc45ccb57b..c22d29027670f 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
@@ -5,22 +5,18 @@ func.func @KCRS_to_KCRSsr(%arg0: tensor<1x1x128x64xf32>, %arg1: tensor<1x1x4x8x8
return %0 : tensor<1x1x4x8x8x32xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 32)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -32 + 128, 32)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 8)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 * -8 + 64, 8)>
// CHECK: func.func @KCRS_to_KCRSsr
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %{{.+}} = scf.for %[[R:[a-zA-Z0-9]+]] =
// CHECK: %{{.+}} = scf.for %[[S:[a-zA-Z0-9]+]] =
// CHECK: %[[IN_R:.+]] = affine.apply #[[MAP0]](%[[R]])
-// CHECK: %[[IN_R_SZ:.+]] = affine.min #[[MAP1]](%[[R]])
// CHECK: %[[IN_S:.+]] = affine.apply #[[MAP2]](%[[S]])
-// CHECK: %[[IN_S_SZ:.+]] = affine.min #[[MAP3]](%[[S]])
// CHECK: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
-// CHECK-SAME: [0, 0, %[[IN_R]], %[[IN_S]]] [1, 1, %[[IN_R_SZ]], %[[IN_S_SZ]]] [1, 1, 1, 1]
+// CHECK-SAME: [0, 0, %[[IN_R]], %[[IN_S]]] [1, 1, 32, 8] [1, 1, 1, 1]
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
-// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x?x?xf32> to tensor<32x8xf32>
+// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
// CHECK: %[[TRANSP:.+]] = linalg.transpose
// CHECK-SAME: ins(%[[TILE]]
@@ -71,22 +67,16 @@ func.func @KC_to_CKkc(%arg0: tensor<128x256xf32>, %arg1: tensor<32x4x32x8xf32>)
return %0 : tensor<32x4x32x8xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 32)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -32 + 128, 32)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 8)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0) -> (d0 * -8 + 256, 8)>
// CHECK: func.func @KC_to_CKkc
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %{{.+}} = scf.for %[[C:[a-zA-Z0-9]+]] =
// CHECK: %{{.+}} = scf.for %[[K:[a-zA-Z0-9]+]] =
// CHECK-DAG: %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]])
-// CHECK-DAG: %[[IN_K_SZ:.+]] = affine.min #[[MAP1]](%[[K]])
// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP2]](%[[C]])
-// CHECK-DAG: %[[IN_C_SZ:.+]] = affine.min #[[MAP3]](%[[C]])
-// CHECK: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
-// CHECK-SAME: [%[[IN_K]], %[[IN_C]]] [%[[IN_K_SZ]], %[[IN_C_SZ]]] [1, 1]
-// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
-// CHECK-SAME: [0, 0] [32, 8] [1, 1] : tensor<?x?xf32> to tensor<32x8xf32>
+// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]]
+// CHECK-SAME: [%[[IN_K]], %[[IN_C]]] [32, 8] [1, 1]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32>
// CHECK: %[[TRANSP:.+]] = linalg.transpose
// CHECK-SAME: ins(%[[TILE]]
diff --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir
index b92e5aeba9d99..3f07e3ce6712c 100644
--- a/mlir/test/Dialect/Tensor/tiling.mlir
+++ b/mlir/test/Dialect/Tensor/tiling.mlir
@@ -181,8 +181,6 @@ transform.sequence failures(propagate) {
// -----
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 32)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -32 + 128, 64)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * -32 + 256, 128)>
// CHECK: func.func @NC_to_NCnc
// CHECK-SAME: %[[IN:.*]]: tensor<128x256xf32>,
// CHECK-SAME: %[[OUT:.*]]: tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> {
@@ -193,10 +191,8 @@ transform.sequence failures(propagate) {
// CHECK: %[[RES0:.*]] = scf.for %[[N:.*]] = %[[C0]] to %[[C4]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor<4x8x32x32xf32>) {
// CHECK: %[[RES1:.+]] = scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[ITER0]]) -> (tensor<4x8x32x32xf32>) {
// CHECK-DAG: %[[IN_N:.+]] = affine.apply #[[MAP0]](%[[N]])
-// CHECK-DAG: %[[IN_N_SZ:.*]] = affine.min #[[MAP1]]
// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP0]](%[[C]])
-// CHECK-DAG: %[[IN_C_SZ:.*]] = affine.min #[[MAP2]]
-// CHECK: %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_N]], %[[IN_C]]] [%[[IN_N_SZ]], %[[IN_C_SZ]]] [1, 1] : tensor<128x256xf32> to tensor<?x?xf32>
+// CHECK: %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_N]], %[[IN_C]]] [64, 128] [1, 1] : tensor<128x256xf32> to tensor<64x128xf32>
// CHECK: %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[N]], %[[C]], 0, 0] [2, 4, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<2x4x32x32xf32>
// CHECK: %[[SUB_RES:.*]] = tensor.pack
// CHECK-SAME: %[[SUB_IN]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[SUB_OUT]]
@@ -221,7 +217,6 @@ transform.sequence failures(propagate) {
// -----
// CHECK: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 8)>
-// CHECK: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -8 + 256, 16)>
// CHECK: func.func @KC_to_CKkc
// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]:
// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]:
@@ -230,9 +225,8 @@ transform.sequence failures(propagate) {
// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
// CHECK: scf.for %[[C:.+]] = %[[C0]] to %[[C32]] step %[[C2]]
// CHECK-DAG: %[[IN_C:.+]] = affine.apply #[[MAP0]](%[[C]])
-// CHECK-DAG: %[[IN_C_SZ:.+]] = affine.min #[[MAP1]](%[[C]])
// CHECK: %[[INPUT_SLICE:.+]] = tensor.extract_slice %[[IN]]
-// CHECK-SAME: [0, %[[IN_C]]] [128, %[[IN_C_SZ]]]
+// CHECK-SAME: [0, %[[IN_C]]] [128, 16]
// CHECK: %[[OUTPUT_SLICE:.+]] = tensor.extract_slice %{{.+}}[%[[C]], 0, 0, 0] [2, 4, 32, 8]
// CHECK: tensor.pack
// CHECK-SAME: %[[INPUT_SLICE]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8]
@@ -620,9 +614,7 @@ transform.sequence failures(propagate) {
// -----
-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (-d0 + 6, 1)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * 2)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * -2 + 8, 2)>
// CHECK: func.func @perfect_NPQK_to_NKPQk
// CHECK-SAME: %[[SOURCE:.+]]: tensor<1x6x6x8xf32>,
// CHECK-SAME: %{{.+}}: tensor<1x4x6x6x2xf32>)
@@ -633,10 +625,7 @@ transform.sequence failures(propagate) {
// CHECK: %{{.+}} = scf.for %[[ARG2:.+]] = %[[C0]] to %[[C4]] step %[[C1]]
// CHECK: %{{.+}} = scf.for %[[ARG4:.+]] = %[[C0]] to %[[C6]] step %[[C1]]
// CHECK: %{{.+}} = scf.for %[[ARG6:.+]] = %[[C0]] to %[[C6]] step %[[C1]]
-// CHECK: %[[MIN_ARG4:.+]] = affine.min #[[MAP]](%[[ARG4]])
-// CHECK: %[[MIN_ARG6:.+]] = affine.min #[[MAP]](%[[ARG6]])
// CHECK: %[[APPLY:.+]] = affine.apply #[[MAP1]](%[[ARG2]])
-// CHECK: %[[MIN_ARG2:.+]] = affine.min #[[MAP2]](%[[ARG2]])
// CHECK: %[[SLICE_SOURCE:.+]] = tensor.extract_slice %[[SOURCE]][0, %[[ARG4]], %[[ARG6]], %[[APPLY]]]
// CHECK: %[[SLICE_DEST:.+]] = tensor.extract_slice %{{.+}}[0, %[[ARG2]], %[[ARG4]], %[[ARG6]], 0]
// CHECK: %[[PACK:.+]] = tensor.pack
More information about the Mlir-commits
mailing list