[Mlir-commits] [mlir] bb69de3 - [mlir][Linalg] Add a vectorization pattern for linalg::PadTensorOp

Nicolas Vasilache llvmlistbot at llvm.org
Wed Feb 10 06:16:57 PST 2021


Author: Nicolas Vasilache
Date: 2021-02-10T14:13:49Z
New Revision: bb69de3f415653cad5ac25b79c10e016ee74dcfe

URL: https://github.com/llvm/llvm-project/commit/bb69de3f415653cad5ac25b79c10e016ee74dcfe
DIFF: https://github.com/llvm/llvm-project/commit/bb69de3f415653cad5ac25b79c10e016ee74dcfe.diff

LOG: [mlir][Linalg] Add a vectorization pattern for linalg::PadTensorOp

The new pattern is exercised from the TestLinalgTransforms pass.

Differential Revision: https://reviews.llvm.org/D96410
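
As a usage sketch (not part of this commit's diff): a pass that wants the new
rewrite registers the pattern and runs the greedy driver, mirroring the
TestLinalgTransforms change below. The function name here is illustrative.

    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Collect the new PadTensorOp vectorization pattern and apply it
    // greedily to the body of `funcOp`.
    static void applyPadTensorVectorization(mlir::FuncOp funcOp) {
      mlir::OwningRewritePatternList patterns;
      patterns.insert<mlir::linalg::PadTensorOpVectorizationPattern>(
          funcOp.getContext());
      (void)mlir::applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
    }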

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index a40d425f7f2e..6916fa78abbb 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -231,6 +231,28 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     static linalg::PadTensorOp createPadScalarOp(
         Type type, Value source, Value pad, ArrayRef<OpFoldResult> low,
         ArrayRef<OpFoldResult> high, Location loc, OpBuilder & builder);
+
+    // Return a vector of all the static or dynamic values (low/high padding) of
+    // the op.
+    inline SmallVector<OpFoldResult> getMixedPadImpl(ArrayAttr staticAttrs,
+                                                     ValueRange values) {
+      SmallVector<OpFoldResult> res;
+      unsigned numDynamic = 0;
+      unsigned count = staticAttrs.size();
+      for (unsigned idx = 0; idx < count; ++idx) {
+        if (ShapedType::isDynamic(staticAttrs[idx].cast<IntegerAttr>().getInt()))
+          res.push_back(values[numDynamic++]);
+        else
+          res.push_back(staticAttrs[idx]);
+      }
+      return res;
+    }
+    SmallVector<OpFoldResult> getMixedLowPad() {
+      return getMixedPadImpl(static_low(), low());
+    }
+    SmallVector<OpFoldResult> getMixedHighPad() {
+      return getMixedPadImpl(static_high(), high());
+    }
   }];
 
   let builders = [

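A consumer-side note on the helpers just added: getMixedLowPad /
getMixedHighPad return one OpFoldResult per dimension, interleaving the
static IntegerAttr entries with the dynamic SSA operands. A small sketch,
assuming `padOp` is an in-scope linalg::PadTensorOp and the mlir namespace
is imported:

    // For `linalg.pad_tensor %t low[2, %l, 0] ...` this visits
    // IntegerAttr(2), Value(%l), IntegerAttr(0), in dimension order.
    for (OpFoldResult ofr : padOp.getMixedLowPad()) {
      if (Attribute attr = ofr.dyn_cast<Attribute>()) {
        // Static entry, stored in the `static_low` ArrayAttr.
        int64_t staticPad = attr.cast<IntegerAttr>().getInt();
        (void)staticPad;
      } else {
        // Dynamic entry, taken from the `low` operand range.
        Value dynamicPad = ofr.get<Value>();
        (void)dynamicPad;
      }
    }
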
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 669f127f6434..4b5580a62abc 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -809,6 +809,16 @@ void populateLinalgConvGeneralizationPatterns(
 //===----------------------------------------------------------------------===//
 // Op-specific patterns.
 //===----------------------------------------------------------------------===//
+
+/// PadTensorOp does not implement the LinalgStructuredOpInterface `LinalgOp`,
+/// so it needs a specific vectorization pattern.
+struct PadTensorOpVectorizationPattern : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 /// Match and rewrite for the pattern:
 /// ```
 ///    %alloc = ...

diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 8bc21b179037..1aeb92a2faf8 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1213,6 +1213,10 @@ def Vector_TransferReadOp :
     OpBuilderDAG<(ins "VectorType":$vector, "Value":$source,
       "ValueRange":$indices, "AffineMap":$permutationMap,
       CArg<"ArrayRef<bool>", "{}">:$maybeMasked)>,
+    // Builder that sets permutation map to 'getMinorIdentityMap'.
+    OpBuilderDAG<(ins "VectorType":$vector, "Value":$source,
+      "ValueRange":$indices, "Value":$padding,
+      CArg<"ArrayRef<bool>", "{}">:$maybeMasked)>,
     // Builder that sets permutation map (resp. padding) to
     // 'getMinorIdentityMap' (resp. zero).
     OpBuilderDAG<(ins "VectorType":$vector, "Value":$source,

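Call-site sketch for the new builder above (builder, loc, and the operand
values are assumed in scope); this is how the PadTensorOp pattern in
Vectorization.cpp below uses it:

    // The permutation map is derived from the source/vector types via
    // getTransferMinorIdentityMap; only the padding value is passed in.
    Value read = builder.create<vector::TransferReadOp>(
        loc, vectorType, source, indices, paddingValue);
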
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a9a43e194d75..86f05c38ed89 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -448,9 +448,71 @@ Optional<VectorizedLinalgOp> mlir::linalg::vectorizeLinalgOp(OpBuilder &builder,
 }
 
 //----------------------------------------------------------------------------//
-// Misc. conv vectorization patterns.
+// Misc. vectorization patterns.
 //----------------------------------------------------------------------------//
-// TODO: cleanup all this.
+
+/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
+/// TransferWriteOp. For now, this only applies when all low and high paddings
+/// are determined to be zero.
+LogicalResult PadTensorOpVectorizationPattern::matchAndRewrite(
+    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
+  // Helper to detect an OpFoldResult that is not statically known to be zero.
+  auto isNotZeroIndex = [](OpFoldResult ofr) {
+    if (Attribute attr = ofr.dyn_cast<Attribute>())
+      return attr.cast<IntegerAttr>().getInt() != 0;
+    Value v = ofr.get<Value>();
+    if (auto constOp = v.getDefiningOp<ConstantIntOp>())
+      return constOp.getValue() != 0;
+    return true;
+  };
+
+  auto resultShapedType = padOp.result().getType().cast<ShapedType>();
+  // Bail on non-static shapes.
+  if (!resultShapedType.hasStaticShape())
+    return failure();
+
+  // If any low padding is not a static 0, a mask is needed; bail for now.
+  if (llvm::any_of(padOp.getMixedLowPad(), isNotZeroIndex))
+    return failure();
+  VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
+  if (!vectorType)
+    return failure();
+
+  // Only support padding with a constant for now, i.e. either:
+  //   1. A BBarg from a different block.
+  //   2. A value defined outside of the current block.
+  Block &block = padOp.region().front();
+  auto yieldOp = cast<YieldOp>(block.getTerminator());
+  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
+  Value padValue = yieldOp.values().front();
+  Operation *definingOp = padValue.getDefiningOp();
+  if (definingOp && definingOp->getBlock() == &block)
+    return failure();
+  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
+    return failure();
+
+  // TODO: if any high padding is not a static 0, a mask is needed. For now,
+  // just bail.
+  if (llvm::any_of(padOp.getMixedHighPad(), isNotZeroIndex))
+    return failure();
+
+  // Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] +
+  // TransferWriteOp@[0..0].
+  SmallVector<Value> indices(
+      resultShapedType.getRank(),
+      rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
+  Value read = rewriter.create<vector::TransferReadOp>(
+      padOp.getLoc(), vectorType, padOp.source(), indices, padValue);
+  Value init =
+      rewriter.create<InitTensorOp>(padOp.getLoc(), resultShapedType.getShape(),
+                                    resultShapedType.getElementType());
+  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(padOp, read, init,
+                                                       indices);
+
+  return success();
+}
+
+// TODO: cleanup all the convolution vectorization patterns.
 template <class ConvOp, int N>
 LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
     ConvOp op, PatternRewriter &rewriter) const {

diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 9fe8cf23c162..99b978895c7e 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1122,8 +1122,8 @@ class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
 
 } // namespace
 
-void BroadcastOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
+void BroadcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
   results.insert<BroadcastToShapeCast>(context);
 }
 
@@ -2026,17 +2026,32 @@ static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
 
 /// Builder that sets padding to zero.
 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
-                           VectorType vector, Value source, ValueRange indices,
-                           AffineMap permutationMap,
+                           VectorType vectorType, Value source,
+                           ValueRange indices, AffineMap permutationMap,
                            ArrayRef<bool> maybeMasked) {
   Type elemType = source.getType().cast<ShapedType>().getElementType();
   Value padding = builder.create<ConstantOp>(result.location, elemType,
                                              builder.getZeroAttr(elemType));
   if (maybeMasked.empty())
-    return build(builder, result, vector, source, indices, permutationMap,
+    return build(builder, result, vectorType, source, indices, permutationMap,
                  padding, ArrayAttr());
   ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
-  build(builder, result, vector, source, indices, permutationMap, padding,
+  build(builder, result, vectorType, source, indices, permutationMap, padding,
+        maskedArrayAttr);
+}
+
+/// Builder that sets permutation map to 'getMinorIdentityMap'.
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+                           VectorType vectorType, Value source,
+                           ValueRange indices, Value padding,
+                           ArrayRef<bool> maybeMasked) {
+  auto permMap = getTransferMinorIdentityMap(
+      source.getType().cast<ShapedType>(), vectorType);
+  if (maybeMasked.empty())
+    return build(builder, result, vectorType, source, indices, permMap, padding,
+                 ArrayAttr());
+  ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
+  build(builder, result, vectorType, source, indices, permMap, padding,
         maskedArrayAttr);
 }
 

diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 3904353287c5..961a9307c1f5 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -390,3 +390,44 @@ func @matmul_i8_i8_i32(%a: memref<4x6xi8>, %b: memref<6x12xi8>, %c: memref<4x12x
     outs(%c: memref<4x12xi32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @pad_static
+//   CHECK-NOT:   linalg.pad_tensor
+func @pad_static(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
+  //      CHECK: %[[C0:.*]] = constant 0 : index
+  //      CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]]
+  // CHECK-SAME:   : tensor<?x?x?xf32>, vector<2x3x4xf32>
+  //      CHECK: %[[INIT:.*]] = linalg.init_tensor [2, 3, 4] : tensor<2x3x4xf32>
+  //      CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[READ]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]]
+  // CHECK-SAME:   {masked = [false, false, false]} : vector<2x3x4xf32>, tensor<2x3x4xf32>
+  %0 = linalg.pad_tensor %arg0 low[0, 0, 0] high[0, 0, 0] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index):
+      linalg.yield %pad_value : f32
+    } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+
+  // CHECK: return %[[WRITTEN]] : tensor<2x3x4xf32>
+  return %0 : tensor<2x3x4xf32>
+}
+
+// CHECK-LABEL: func @pad_static_high_padding
+//       CHECK:   linalg.pad_tensor
+func @pad_static_high_padding(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
+  %0 = linalg.pad_tensor %arg0 low[0, 0, 0] high[0, 1, 0] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index):
+      linalg.yield %pad_value : f32
+    } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %0 : tensor<2x3x4xf32>
+}
+
+// CHECK-LABEL: func @pad_dynamic
+//       CHECK:   linalg.pad_tensor
+func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
+                  %pad_value: f32) -> tensor<6x?x?x?xf32> {
+  %0 = linalg.pad_tensor %arg0 low[2, %low, 3, 3] high[3, 3, %high, 2] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+      linalg.yield %pad_value : f32
+    } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
+  return %0 : tensor<6x?x?x?xf32>
+}

diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index f9dea42f3a8a..a492d496af51 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -491,6 +491,7 @@ static void applyLinalgToVectorPatterns(FuncOp funcOp) {
   patterns.insert<LinalgVectorizationPattern>(
       LinalgTransformationFilter()
           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
+  patterns.insert<PadTensorOpVectorizationPattern>(funcOp.getContext());
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
More information about the Mlir-commits mailing list