[Mlir-commits] [mlir] d21beb5 - [MLIR][Linalg] Avoid padding attribute in `pack` when possible

Lorenzo Chelini llvmlistbot at llvm.org
Tue Jul 11 02:32:57 PDT 2023


Author: Lorenzo Chelini
Date: 2023-07-11T11:32:51+02:00
New Revision: d21beb598f5a4932b92a9b702a9195c0212cde03

URL: https://github.com/llvm/llvm-project/commit/d21beb598f5a4932b92a9b702a9195c0212cde03
DIFF: https://github.com/llvm/llvm-project/commit/d21beb598f5a4932b92a9b702a9195c0212cde03.diff

LOG: [MLIR][Linalg] Avoid padding attribute in `pack` when possible

If the tensor and tile sizes are statically known and each tile
perfectly divides its dimension, we can omit the padding attribute.
As a bonus, pack and unpack propagation can now run on such ops
(currently, propagation bails out when the padding attribute is
present).
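
To illustrate (a hypothetical example, not part of this commit): with a
statically shaped tensor<32x32xf32> and inner tiles [8, 8], both tiles
divide their dimensions evenly, so the pack can be emitted without a
padding value:

    // 32 is divisible by 8 on both packed dimensions: no padding_value.
    %dest = tensor.empty() : tensor<4x4x8x8xf32>
    %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 8]
        into %dest : tensor<32x32xf32> -> tensor<4x4x8x8xf32>

With tensor<30x32xf32> and the same tiles, dimension 0 is not evenly
divisible, so the padding attribute is still required:

    %cst = arith.constant 0.0 : f32
    %dest = tensor.empty() : tensor<4x4x8x8xf32>
    // Dim 0 (30) is not divisible by 8; the tail is filled with %cst.
    %pack = tensor.pack %src padding_value(%cst : f32)
        inner_dims_pos = [0, 1] inner_tiles = [8, 8]
        into %dest : tensor<30x32xf32> -> tensor<4x4x8x8xf32>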

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D154607

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/transform-op-pack.mlir
    mlir/test/Dialect/Linalg/transform-pack-greedily.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index c7b65be888bf94..d1c33d8b4c03c3 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1823,7 +1823,9 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
 
     // Returns true if we have enough static information to catch undefined
     // behavior when the tile size does not divide perfectly the dimension of
-    // the input tensor.
+    // the input tensor. If a given dimension, or the tile associated with
+    // it, is dynamic, that dimension is skipped: we lack the static
+    // information to determine whether the tile divides it perfectly.
     static bool requirePaddingValue(ArrayRef<int64_t> inputShape,
                                     ArrayRef<int64_t> innerDimsPos,
                                     ArrayRef<OpFoldResult> innerTiles);
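
A sketch of the predicate's contract (hypothetical inputs, not part
of this commit):

    // inputShape = [32, 32], innerDimsPos = [0, 1], innerTiles = [8, 8]
    //   -> false: both tiles divide their dimensions evenly.
    // inputShape = [30, 32], innerDimsPos = [0, 1], innerTiles = [8, 8]
    //   -> true: 30 % 8 != 0, so a padding value is required.
    // inputShape = [?, 32], innerDimsPos = [0, 1], innerTiles = [8, 8]
    //   -> false: the dynamic dimension is skipped for lack of static
    //      information.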

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index ce5dd46ad8f44d..e39f8470e9c7a5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -563,12 +563,25 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
       Value dest = tensor::PackOp::createDestinationTensor(
           rewriter, loc, operand, innerPackSizes, innerPos,
           /*outerDimsPerm=*/{});
-      // TODO: value of the padding attribute should be determined by consumers.
-      auto zeroAttr =
-          rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
-      Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
-      packOps.push_back(rewriter.create<tensor::PackOp>(
-          loc, operand, dest, innerPos, innerPackSizes, zero));
+      ShapedType operandType = operand.getType().cast<ShapedType>();
+      bool areConstantTiles =
+          llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
+            return getConstantIntValue(tile).has_value();
+          });
+      if (areConstantTiles && operandType.hasStaticShape() &&
+          !tensor::PackOp::requirePaddingValue(operandType.getShape(), innerPos,
+                                               innerPackSizes)) {
+        packOps.push_back(rewriter.create<tensor::PackOp>(
+            loc, operand, dest, innerPos, innerPackSizes));
+      } else {
+        // TODO: value of the padding attribute should be determined by
+        // consumers.
+        auto zeroAttr =
+            rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
+        Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
+        packOps.push_back(rewriter.create<tensor::PackOp>(
+            loc, operand, dest, innerPos, innerPackSizes, zero));
+      }
       inputsAndInits.push_back(packOps.back());
     }
   }

diff --git a/mlir/test/Dialect/Linalg/transform-op-pack.mlir b/mlir/test/Dialect/Linalg/transform-op-pack.mlir
index eaeb258590d75d..a8502d211cf80a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pack.mlir
@@ -593,3 +593,36 @@ transform.sequence failures(propagate) {
       : (!transform.op<"tensor.unpack">, !transform.op<"linalg.generic">) 
       -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.op<"tensor.unpack">)
 }
+
+// -----
+
+func.func @no_padding_on_packs(%A: tensor<32x32xf32>, %B: tensor<32x32xf32>, %C: tensor<32x32xf32>)
+    -> tensor<32x32xf32> {
+  %0 = linalg.matmul  ins(%A, %B: tensor<32x32xf32>, tensor<32x32xf32>)
+                     outs(%C: tensor<32x32xf32>)
+    -> tensor<32x32xf32>
+  return %0 : tensor<32x32xf32>
+}
+
+// CHECK-LABEL: no_padding_on_packs
+// CHECK: tensor.pack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [4, 8] 
+// CHECK-SAME:  into %{{.+}} : tensor<32x32xf32> -> tensor<8x4x4x8xf32>
+// CHECK: tensor.pack %{{.+}} outer_dims_perm = [1, 0] 
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [8, 8] 
+// CHECK-SAME:  into %{{.+}} : tensor<32x32xf32> -> tensor<4x4x8x8xf32>
+// CHECK: tensor.pack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [4, 8] 
+// CHECK-SAME:  into %{{.+}} : tensor<32x32xf32> -> tensor<8x4x4x8xf32>
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.pack %0 packed_sizes = [4, 8, 8]
+      : (!transform.any_op) -> (!transform.op<"linalg.generic">)
+    %pack = transform.get_producer_of_operand %1[1]
+    : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.pack">)
+    %2, %pack_2, %empty_unpack_2 =
+    transform.structured.pack_transpose %pack with_compute_op(%1)
+    outer_perm = [1, 0] inner_perm = [1, 0]
+     : (!transform.op<"tensor.pack">, !transform.op<"linalg.generic">)
+    -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.any_op) 
+}

diff --git a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
index 68f07067b53502..63ce9c02afb085 100644
--- a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
+++ b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
@@ -348,3 +348,37 @@ transform.sequence failures(propagate) {
       matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0]
     : (!transform.op<"linalg.matvec">) -> !transform.any_op
 }
+
+// -----
+
+func.func @no_padding_on_packs(%A: tensor<32x32xf32>, %B: tensor<32x32xf32>, %C: tensor<32x32xf32>)
+    -> tensor<32x32xf32> {
+  %0 = linalg.matmul  ins(%A, %B: tensor<32x32xf32>, tensor<32x32xf32>)
+                     outs(%C: tensor<32x32xf32>)
+    -> tensor<32x32xf32>
+  return %0 : tensor<32x32xf32>
+}
+
+// CHECK-LABEL: no_padding_on_packs
+// CHECK: tensor.pack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [8, 4] 
+// CHECK-SAME:  into %{{.+}} : tensor<32x32xf32> -> tensor<4x8x8x4xf32>
+// CHECK: tensor.pack %{{.+}} outer_dims_perm = [1, 0] 
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [4, 16] into %{{.+}} : tensor<32x32xf32> -> tensor<2x8x4x16xf32>
+// CHECK: tensor.pack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [8, 16] 
+// CHECK-SAME:  into %{{.+}} : tensor<32x32xf32> -> tensor<4x2x8x16xf32>
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+      : (!transform.any_op) -> !transform.op<"linalg.matmul">
+    %1 = transform.structured.pack_greedily %0
+        matmul_packed_sizes = [8, 16, 4] matmul_inner_dims_order = [0, 1, 2]
+      : (!transform.op<"linalg.matmul">) -> !transform.op<"linalg.generic">
+    %pack = transform.get_producer_of_operand %1[1]
+    : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.pack">)
+    %2, %pack_2, %empty_unpack_2 =
+    transform.structured.pack_transpose %pack with_compute_op(%1)
+    outer_perm = [1, 0] inner_perm = [1, 0]
+     : (!transform.op<"tensor.pack">, !transform.op<"linalg.generic">)
+    -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.any_op)
+}
