[Mlir-commits] [mlir] 86f186e - [mlir][linalg] Add makeComposedPadHighOp.
    Tobias Gysi 
    llvmlistbot at llvm.org
       
    Wed Nov 24 11:19:44 PST 2021
    
    
  
Author: Tobias Gysi
Date: 2021-11-24T19:18:59Z
New Revision: 86f186efea7b5f542ef3d9fa2e63fd485475e011
URL: https://github.com/llvm/llvm-project/commit/86f186efea7b5f542ef3d9fa2e63fd485475e011
DIFF: https://github.com/llvm/llvm-project/commit/86f186efea7b5f542ef3d9fa2e63fd485475e011.diff
LOG: [mlir][linalg] Add makeComposedPadHighOp.
Add the makeComposedPadHighOp method, which creates a new PadTensorOp only if necessary. If the source to pad is the result of a sequence of padded LinalgOps, the method checks whether padding is needed or whether the padded result of the LinalgOp sequence can be used directly.
Example:
```
%0 = tensor.extract_slice %arg0 [%iv0, %iv1] [%sz0, %sz1]
%1 = linalg.pad_tensor %0 low[0, 0] high[...] { linalg.yield %cst }
%2 = linalg.matmul ins(...) outs(%1)
%3 = tensor.extract_slice %2 [0, 0] [%sz0, %sz1]
```
When padding %3, return %2 instead of introducing
```
%4 = linalg.pad_tensor %3 low[0, 0] high[...] { linalg.yield %cst }
```
Depends On D114161
Reviewed By: nicolasvasilache, pifon2a
Differential Revision: https://reviews.llvm.org/D114175
Added: 
    
Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/test/Dialect/Linalg/pad.mlir
Removed: 
    
################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index c1cdd3eda2cb3..55fb8bdca7914 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -93,20 +93,42 @@ FailureOr<int64_t> getConstantUpperBoundForIndex(Value value);
 ///
 /// Example:
 /// ```
-///   %0 = tensor.extract_slice %arg0[3, 4][3, 32][1, 1] : tensor<64x64xf32> to
+/// %0 = tensor.extract_slice %arg0[3, 4][3, 32][1, 1] : tensor<64x64xf32> to
 ///                                                        tensor<3x32xf32>
-///   %1 = tensor.extract_slice %0[0, 5][3, 4][1, 1] : tensor<3x32xf32> to
+/// %1 = tensor.extract_slice %0[0, 5][3, 4][1, 1] : tensor<3x32xf32> to
 ///                                                    tensor<3x4xf32>
 /// ```
 /// folds into:
 /// ```
-///   %1 = tensor.extract_slice %arg0[3, 9][3, 4][1, 1] : tensor<64x64xf32> to
-///                                                         tensor<3x4xf32>
+/// %1 = tensor.extract_slice %arg0[3, 9][3, 4][1, 1] : tensor<64x64xf32> to
+///                                                       tensor<3x4xf32>
 /// ```
 tensor::ExtractSliceOp makeComposedExtractSliceOp(
     OpBuilder &b, Location loc, Value source, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides);
 
+/// Create a PadTensorOp that pads `source` to the statically sized `type`,
+/// whose sizes are assumed to be greater than or equal to the dynamic
+/// `source` sizes, by appending trailing `pad` values. `nofold` is passed on
+/// to the created PadTensorOp. If `source` is a slice of the result of one or
+/// more LinalgOps that have been high padded with the same value and sizes,
+/// return their padded result instead of creating a new PadTensorOp.
+///
+/// Example:
+/// ```
+/// %0 = tensor.extract_slice %arg0 [%iv0, %iv1] [%sz0, %sz1]
+/// %1 = linalg.pad_tensor %0 low[0, 0] high[...] { linalg.yield %cst }
+/// %2 = linalg.matmul ins(...) outs(%1)
+/// %3 = tensor.extract_slice %2 [0, 0] [%sz0, %sz1]
+/// ```
+/// makeComposedPadHighOp(source=%3, pad=%cst) returns %2, whereas
+/// makeComposedPadHighOp(source=%3, pad=%other_cst) creates and returns %4:
+/// ```
+/// %4 = linalg.pad_tensor %3 low[0, 0] high[...] { linalg.yield %other_cst }
+/// ```
+Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
+                            Value source, Value pad, bool nofold);
+
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
 //===----------------------------------------------------------------------===//
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 752b14ee0cd44..73db8b71ff2db 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -211,9 +211,9 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
   auto staticTensorType = RankedTensorType::get(
       staticSizes, getElementTypeOrSelf(opOperand->get()));
   bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
-  result = linalg::PadTensorOp::createPadHighOp(
-      staticTensorType, opOperand->get(), paddingValue.getValue(),
-      /*nofold=*/nofold, opToPad->getLoc(), b);
+  result =
+      makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
+                            opOperand->get(), paddingValue.getValue(), nofold);
   return success();
 }
 
diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index cf0aee6bd2138..a737876eff5ec 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -322,6 +322,101 @@ tensor::ExtractSliceOp makeComposedExtractSliceOp(
                                           foldedOffsets, sizes, strides);
 }
 
+/// Pad `source` at the high end to the statically shaped `type`, reusing an
+/// existing padded tensor when possible. If `source` is a zero-offset,
+/// unit-stride slice of the result of a (possibly empty) sequence of LinalgOps
+/// whose output operands lead back to a PadTensorOp that is high padded with
+/// the same constant value as `pad` and the same slice sizes, that padded
+/// tensor is returned. Otherwise a new PadTensorOp is created (honoring
+/// `nofold`).
+///
+/// NOTE(review): the ops between the PadTensorOp and `source` are only walked,
+/// not inspected; reuse assumes they keep the padded region equal to `pad`.
+Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
+                            Value source, Value pad, bool nofold) {
+  assert(type.hasStaticShape() && "expect tensor type to have static shape");
+
+  // Fallback used by every early exit: create a new high-padding PadTensorOp.
+  auto createNewPadOp = [&]() -> Value {
+    return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);
+  };
+
+  // Returns a predicate checking that an OpFoldResult is the constant `value`.
+  auto hasConstantValue = [](int64_t value) {
+    return [value](OpFoldResult ofr) {
+      return getConstantIntValue(ofr) == value;
+    };
+  };
+
+  // Exit if `source` is not defined by an ExtractSliceOp.
+  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
+  if (!sliceOp)
+    return createNewPadOp();
+
+  // Exit if `sliceOp` has non-zero offsets or non-unit strides. Returning the
+  // padded tensor is only correct if element `i` of `source` is also element
+  // `i` of the padded tensor.
+  if (!llvm::all_of(sliceOp.getMixedOffsets(), hasConstantValue(0)) ||
+      !llvm::all_of(sliceOp.getMixedStrides(), hasConstantValue(1)))
+    return createNewPadOp();
+
+  // Search the `source` use-def chain for padded LinalgOps by following, for
+  // each LinalgOp result, the output operand with the same index.
+  Value current = sliceOp.source();
+  while (current) {
+    auto linalgOp = current.getDefiningOp<LinalgOp>();
+    if (!linalgOp)
+      break;
+    OpResult opResult = current.cast<OpResult>();
+    current = linalgOp.getOutputOperand(opResult.getResultNumber())->get();
+  }
+  auto padTensorOp = current ? current.getDefiningOp<PadTensorOp>() : nullptr;
+
+  // Exit if the search fails to match a PadTensorOp at the end of the matched
+  // LinalgOp sequence.
+  if (!padTensorOp)
+    return createNewPadOp();
+
+  // Exit if the padded result type does not match.
+  if (sliceOp.source().getType() != type)
+    return createNewPadOp();
+
+  // Exit if the LinalgOps are not high padded.
+  if (!llvm::all_of(padTensorOp.getMixedLowPad(), hasConstantValue(0)))
+    return createNewPadOp();
+
+  // Exit if the slice padded by `padTensorOp` is not defined by an
+  // ExtractSliceOp or has a different number of sizes than `sliceOp` (e.g.
+  // because it is rank-reducing). `llvm::zip` stops at the shorter range, so
+  // the size comparison below would otherwise only check a prefix.
+  auto padTensorOpSliceOp =
+      padTensorOp.source().getDefiningOp<tensor::ExtractSliceOp>();
+  if (!padTensorOpSliceOp || sliceOp.getMixedSizes().size() !=
+                                 padTensorOpSliceOp.getMixedSizes().size())
+    return createNewPadOp();
+
+  // Exit if the sizes of `sliceOp` do not match the sizes of the slice padded
+  // by `padTensorOp`.
+  if (llvm::any_of(llvm::zip(sliceOp.getMixedSizes(),
+                             padTensorOpSliceOp.getMixedSizes()),
+                   [](std::tuple<OpFoldResult, OpFoldResult> it) {
+                     return !isEqualConstantIntOrValue(std::get<0>(it),
+                                                       std::get<1>(it));
+                   }))
+    return createNewPadOp();
+
+  // Exit if the padding values are not the same constant.
+  Attribute padTensorOpPadAttr, padAttr;
+  Value padTensorOpPad = padTensorOp.getConstantPaddingValue();
+  if (!padTensorOpPad ||
+      !matchPattern(padTensorOpPad, m_Constant(&padTensorOpPadAttr)) ||
+      !matchPattern(pad, m_Constant(&padAttr)) || padTensorOpPadAttr != padAttr)
+    return createNewPadOp();
+
+  // Return the padded result if the padding values and sizes match.
+  return sliceOp.source();
+}
+
 /// Specialization to build an scf "for" nest.
 template <>
 void GenerateLoopNest<scf::ForOp>::doit(
diff  --git a/mlir/test/Dialect/Linalg/pad.mlir b/mlir/test/Dialect/Linalg/pad.mlir
index 52cd242cece3e..68734803276e4 100644
--- a/mlir/test/Dialect/Linalg/pad.mlir
+++ b/mlir/test/Dialect/Linalg/pad.mlir
@@ -214,6 +214,123 @@ func @dynamic_sizes(%arg0: tensor<?x?xf32>,
 
 // -----
 
+#map0 = affine_map<(d0) -> (64, d0)>
+
+//      CHECK:  compose_padding
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<64x64xf32>
+func @compose_padding(%arg0: tensor<64x64xf32>,
+                      %iv0 : index) -> tensor<?x?xf32> {
+  %cst = arith.constant 0.0 : f32
+
+  //      CHECK:  %[[SIZE:.*]] = affine.min
+  %size = affine.min #map0(%iv0)
+
+  //      CHECK:  %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK-SAME:                                     [0, 0]
+  // CHECK-SAME:                                     [%[[SIZE]], %[[SIZE]]]
+  //      CHECK:  %[[T1:.*]] = linalg.pad_tensor %[[T0]]
+  //      CHECK:  %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]]
+  //      CHECK:  %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]]
+  %0 = tensor.extract_slice %arg0[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor<?x?xf32>
+  %1 = linalg.pad_tensor %0 low[0, 0] high[%iv0, %iv0]  {
+    ^bb0(%arg3: index, %arg4: index):  // no predecessors
+      linalg.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<64x64xf32>
+  %2 = linalg.fill(%cst, %1) : f32, tensor<64x64xf32> -> tensor<64x64xf32>
+  %3 = linalg.fill(%cst, %2) : f32, tensor<64x64xf32> -> tensor<64x64xf32>
+  %4 = tensor.extract_slice %3[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor<?x?xf32>
+
+  // Check there are no additional pad tensor operations: padding %4 composes with the existing padding of %1.
+  //  CHECK-NOT:  linalg.pad_tensor
+
+  // Check the matmul directly uses the result of the last fill operation.
+  //      CHECK:  %[[T4:.*]] = linalg.matmul ins(%[[T3]]
+  //      CHECK:  %[[T5:.*]] = tensor.extract_slice %[[T4]]
+  // CHECK-SAME:                                     [0, 0]
+  // CHECK-SAME:                                     [%[[SIZE]], %[[SIZE]]]
+  %5 = linalg.matmul ins(%4, %4 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%4 : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  //      CHECK:  return %[[T5]]
+  return %5 : tensor<?x?xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0) -> (64, d0)>
+
+//      CHECK:  different_padding_values
+func @different_padding_values(%arg0: tensor<64x64xf32>,
+                               %iv0 : index) -> tensor<?x?xf32> {
+  %cst = arith.constant 42.0 : f32
+  %size = affine.min #map0(%iv0)
+  %0 = tensor.extract_slice %arg0[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor<?x?xf32>
+  %1 = linalg.pad_tensor %0 low[0, 0] high[%iv0, %iv0]  {
+    ^bb0(%arg3: index, %arg4: index):  // no predecessors
+      linalg.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<64x64xf32>
+  %2 = linalg.fill(%cst, %1) : f32, tensor<64x64xf32> -> tensor<64x64xf32>
+  %4 = tensor.extract_slice %2[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor<?x?xf32>
+
+  // Different padding values prevent composing the paddings (42.0 vs. 0.0).
+  //      CHECK:  = linalg.fill
+  //      CHECK:  = linalg.pad_tensor
+  //      CHECK:  = linalg.matmul
+  %5 = linalg.matmul ins(%4, %4 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%4 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %5 : tensor<?x?xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0) -> (64, d0)>
+
+//      CHECK:  different_padding_dynamic_sizes
+func @different_padding_dynamic_sizes(%arg0: tensor<64x64xf32>,
+                                      %iv0 : index) -> tensor<?x?xf32> {
+  %cst = arith.constant 0.0 : f32
+  %size = affine.min #map0(%iv0)
+  %0 = tensor.extract_slice %arg0[0, 0] [%iv0, %iv0] [1, 1] : tensor<64x64xf32> to tensor<?x?xf32>
+  %1 = linalg.pad_tensor %0 low[0, 0] high[%iv0, %iv0]  {
+    ^bb0(%arg3: index, %arg4: index):  // no predecessors
+      linalg.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<64x64xf32>
+  %2 = linalg.fill(%cst, %1) : f32, tensor<64x64xf32> -> tensor<64x64xf32>
+  %4 = tensor.extract_slice %2[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor<?x?xf32>
+
+  // Different dynamic sizes prevent composing the paddings (%iv0 vs %size).
+  //      CHECK:  = linalg.fill
+  //      CHECK:  = linalg.pad_tensor
+  //      CHECK:  = linalg.matmul
+  %5 = linalg.matmul ins(%4, %4 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%4 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %5 : tensor<?x?xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0) -> (64, d0)>
+
+//      CHECK:  different_padding_static_sizes
+func @different_padding_static_sizes(%arg0: tensor<62x62xf32>,
+                                     %iv0 : index) -> tensor<?x?xf32> {
+  %cst = arith.constant 0.0 : f32
+  %size = affine.min #map0(%iv0)
+  %0 = tensor.extract_slice %arg0[0, 0] [%size, %size] [1, 1] : tensor<62x62xf32> to tensor<?x?xf32>
+  %1 = linalg.pad_tensor %0 low[0, 0] high[%iv0, %iv0]  {
+    ^bb0(%arg3: index, %arg4: index):  // no predecessors
+      linalg.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<62x62xf32>
+  %2 = linalg.fill(%cst, %1) : f32, tensor<62x62xf32> -> tensor<62x62xf32>
+  %4 = tensor.extract_slice %2[0, 0] [%size, %size] [1, 1] : tensor<62x62xf32> to tensor<?x?xf32>
+
+  // Different static sizes prevent composing the paddings (62 vs 64 derived from #map0).
+  //      CHECK:  = linalg.fill
+  //      CHECK:  = linalg.pad_tensor
+  //      CHECK:  = linalg.matmul
+  %5 = linalg.matmul ins(%4, %4 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%4 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %5 : tensor<?x?xf32>
+}
+
+// -----
+
 #map = affine_map<(d0) -> (7, -d0 + 12)>
 
 //      CHECK-FILL:  scalar_operand
        
    
    
More information about the Mlir-commits
mailing list