[Mlir-commits] [mlir] 8c5ad0a - [mlir][Vector] Add a masked vectorization of tensor.pad

Nicolas Vasilache llvmlistbot at llvm.org
Thu Apr 13 13:20:36 PDT 2023


Author: Nicolas Vasilache
Date: 2023-04-13T13:20:29-07:00
New Revision: 8c5ad0a2f6532cec2f6841cc3e9a1ea043409398

URL: https://github.com/llvm/llvm-project/commit/8c5ad0a2f6532cec2f6841cc3e9a1ea043409398
DIFF: https://github.com/llvm/llvm-project/commit/8c5ad0a2f6532cec2f6841cc3e9a1ea043409398.diff

LOG: [mlir][Vector] Add a masked vectorization of tensor.pad

This revision takes advantage of masking support to introduce a vectorized
version of tensor.pad that does not require first lowering the op to a
lower-level form.

That lowering (if/else + generate + fill + copy + insert_slice) creates
unnecessary complexity that can be sidestepped entirely by using masked
vectorization properly.
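
For illustration, a sketch of the rewrite on a 2x4 example (mirroring the
test added below; %cst is the constant padding value, %c0/%c1 are index
constants). The input

  %1 = tensor.pad %0 low[0, 0] high[%h0, %h1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<?x?xf32> to tensor<2x4xf32>

becomes, roughly:

  %empty = tensor.empty() : tensor<2x4xf32>
  %d0 = tensor.dim %0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %0, %c1 : tensor<?x?xf32>
  %mask = vector.create_mask %d0, %d1 : vector<2x4xi1>
  %read = vector.mask %mask {
    vector.transfer_read %0[%c0, %c0], %cst {in_bounds = [true, true]}
      : tensor<?x?xf32>, vector<2x4xf32>
  } : vector<2x4xi1> -> vector<2x4xf32>
  %1 = vector.transfer_write %read, %empty[%c0, %c0] {in_bounds = [true, true]}
    : vector<2x4xf32>, tensor<2x4xf32>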

Differential Revision: https://reviews.llvm.org/D148261

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7eaa2f7168bd..52982c3fd537 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -589,6 +589,13 @@ LogicalResult vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
 
+/// Vectorize a `padOp` with (1) static result type, (2) constant padding
+/// value, and (3) all-zero lowPad, to
+///   `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
+FailureOr<vector::TransferWriteOp>
+maskedVectorize(RewriterBase &rewriter, tensor::PadOp padOp,
+                ArrayRef<int64_t> inputVectorSizes);
+
 /// Emit a loop nest of `scf.for` with the proper body for `linalgOp`.
 FailureOr<LinalgLoops> linalgOpToLoops(RewriterBase &rewriter,
                                        LinalgOp linalgOp);

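For reference, a minimal sketch of a call site for the new entry point
(hypothetical code, not part of this commit; assumes a matched tensor::PadOp
whose static result shape equals the requested vector sizes):

  // Hypothetical driver exercising linalg::maskedVectorize under the
  // documented preconditions.
  static LogicalResult tryVectorizePad(RewriterBase &rewriter,
                                       tensor::PadOp padOp) {
    // Vector sizes must match the static result shape, e.g. tensor<2x4xf32>.
    SmallVector<int64_t> vectorSizes = {2, 4};
    FailureOr<vector::TransferWriteOp> maybeWriteOp =
        linalg::maskedVectorize(rewriter, padOp, vectorSizes);
    if (failed(maybeWriteOp))
      return failure(); // E.g. non-constant pad value or non-zero low pad.
    return success();
  }
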
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 0b844e1e3333..2a95ff243fbf 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2322,6 +2322,11 @@ def Vector_CreateMaskOp :
     ```
   }];
 
+  let builders = [
+    // Build with mixed static/dynamic operands.
+    OpBuilder<(ins "VectorType":$type, "ArrayRef<OpFoldResult>":$mixedOperands)>
+  ];
+
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
   let assemblyFormat = "$operands attr-dict `:` type(results)";

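A sketch of what the new builder enables (hypothetical usage; `source` is a
tensor value whose dimension 1 is dynamic):

  // Build a 2x4 i1 mask from one static and one dynamic size; the builder
  // materializes index constants for attribute operands as needed.
  auto maskType = VectorType::get({2, 4}, builder.getI1Type());
  SmallVector<OpFoldResult> mixedSizes = {
      builder.getIndexAttr(2),
      builder.createOrFold<tensor::DimOp>(loc, source, 1)};
  Value mask = builder.create<vector::CreateMaskOp>(loc, maskType, mixedSizes);
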
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 2970d3476d51..39f7802a688a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3098,6 +3098,16 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
   // TODO: Check that the correct number of vectorSizes was provided.
 
   for (Operation *target : targets) {
+    if (auto padOp = dyn_cast<tensor::PadOp>(target)) {
+      FailureOr<vector::TransferWriteOp> maybeWriteOp =
+          maskedVectorize(rewriter, padOp, vectorSizes);
+      if (failed(maybeWriteOp)) {
+        return mlir::emitSilenceableFailure(target->getLoc())
+               << "failed to vectorize padOp";
+      }
+      continue;
+    }
+
     auto linalgOp = dyn_cast<LinalgOp>(target);
     if (!linalgOp) {
       return mlir::emitSilenceableFailure(target->getLoc())
@@ -3107,7 +3117,7 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
     if (failed(linalg::vectorize(rewriter, linalgOp, vectorSizes,
                                  getVectorizeNdExtract()))) {
       return mlir::emitSilenceableFailure(target->getLoc())
-             << "failed to vectorize op";
+             << "failed to vectorize linalg op";
     }
   }
 

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b54eb0fa9a4f..14726a8c579f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -26,6 +26,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -1385,6 +1386,63 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
   }
 }
 
+FailureOr<vector::TransferWriteOp>
+mlir::linalg::maskedVectorize(RewriterBase &rewriter, tensor::PadOp padOp,
+                              ArrayRef<int64_t> inputVectorSizes) {
+  auto padValue = padOp.getConstantPaddingValue();
+  if (!padValue) {
+    LDBG("pad value is not constant: " << padOp << "\n");
+    return rewriter.notifyMatchFailure(padOp, "pad value is not constant");
+  }
+
+  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
+  if (resultTensorShape != inputVectorSizes) {
+    LDBG("result tensor shape must match input vector sizes: " << padOp
+                                                               << "\n");
+    return rewriter.notifyMatchFailure(
+        padOp, "result tensor shape must match input vector sizes");
+  }
+  if (llvm::any_of(padOp.getStaticLow(),
+                   [](int64_t val) { return val != 0; })) {
+    LDBG("low pad must all be zero: " << padOp << "\n");
+    return rewriter.notifyMatchFailure(padOp, "low pad must all be zero");
+  }
+
+  Location loc = padOp.getLoc();
+  int64_t rank = inputVectorSizes.size();
+  auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
+  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+
+  // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(padOp);
+  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto emptyOp =
+      rewriter.create<tensor::EmptyOp>(loc, padOp.getResultType(),
+                                       /*dynamicSizes=*/ValueRange{});
+  SmallVector<OpFoldResult> mixedSourceDims =
+      getMixedDimensions(rewriter, loc, padOp.getSource());
+  Value mask =
+      rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+      loc,
+      /*vectorType=*/vectorType,
+      /*source=*/padOp.getSource(),
+      /*indices=*/SmallVector<Value>(rank, zero),
+      /*padding=*/padValue,
+      /*inBounds=*/SmallVector<bool>(rank, true));
+  auto maskedOp = cast<vector::MaskOp>(
+      mlir::vector::maskOperation(rewriter, transferReadOp, mask));
+  auto transferWriteOp = rewriter.create<vector::TransferWriteOp>(
+      loc,
+      /*vector=*/maskedOp->getResult(0),
+      /*source=*/emptyOp,
+      /*indices=*/SmallVector<Value>(rank, zero),
+      /*inBounds=*/SmallVector<bool>(rank, true));
+  rewriter.replaceOp(padOp, transferWriteOp->getResults());
+  return transferWriteOp;
+}
+
 /// Emit a suitable vector form for a Linalg op. If provided, `inputVectorSizes`
 /// are used to vectorize this operation. `inputVectorSizes` must match the rank
 /// of the iteration space of the operation and the input vector sizes must be

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8ee59659c3a4..89ca09911230 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -640,10 +640,9 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
   auto loc = parser.getCurrentLocation();
   DictionaryAttr dictAttr;
   // TODO: Unify linalg op attribute parsing.
-  if (parser.parseAttribute(dictAttr) ||
-      parser.parseOperand(lhsInfo) || parser.parseComma() ||
-      parser.parseOperand(rhsInfo) || parser.parseComma() ||
-      parser.parseOperand(accInfo) ||
+  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
+      parser.parseComma() || parser.parseOperand(rhsInfo) ||
+      parser.parseComma() || parser.parseOperand(accInfo) ||
       parser.parseTrailingOperandList(masksInfo) ||
       parser.parseOptionalAttrDict(result.attributes) ||
       parser.parseColonTypeList(types) ||
@@ -5369,6 +5368,14 @@ LogicalResult ConstantMaskOp::verify() {
 // CreateMaskOp
 //===----------------------------------------------------------------------===//
 
+void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
+                         VectorType type,
+                         ArrayRef<OpFoldResult> mixedOperands) {
+  SmallVector<Value> operands =
+      getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
+  build(builder, result, type, operands);
+}
+
 LogicalResult CreateMaskOp::verify() {
   auto vectorType = getResult().getType().cast<VectorType>();
   // Verify that an operand was specified for each result vector dimension.

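With this builder, creating a mask from the mixed operands {2, %d1} would
materialize IR along these lines (sketch); getValueOrCreateConstantIndexOp
turns each attribute operand into an arith.constant of index type and passes
SSA values through unchanged:

  %c2 = arith.constant 2 : index
  %mask = vector.create_mask %c2, %d1 : vector<2x4xi1>
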
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index c407b49d896c..d54a2f57617c 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -2757,3 +2757,37 @@ transform.sequence failures(propagate) {
   %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!pdl.operation) -> !pdl.operation
   transform.structured.masked_vectorize %0 vector_sizes [2, 4]
 }
+
+// -----
+
+// CHECK-LABEL: func @test_masked_vectorize_pad
+func.func @test_masked_vectorize_pad(
+  %0 : tensor<?x?xf32>, %h0 : index, %h1 : index) 
+    -> tensor<2x4xf32>
+{
+  //  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+  //  CHECK-DAG: %[[c42:.*]] = arith.constant 4.243000e+01 : f32
+  //  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4xf32>
+  //      CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
+  //      CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
+  //      CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
+  //      CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] { 
+  // CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0]], %[[c0]]], %[[c42]] 
+  // CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32> 
+  // CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
+  //      CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0]], %[[c0]]]
+  // CHECK-SAME:   {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>
+  %cst = arith.constant 42.43 : f32
+  %1 = tensor.pad %0 low[0, 0] high[%h0, %h1]  {
+    ^bb0(%hh1: index, %hh2: index):
+      tensor.yield %cst : f32
+    } : tensor<?x?xf32> to tensor<2x4xf32>
+  return %1: tensor<2x4xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.masked_vectorize %0 vector_sizes [2, 4]
+}

More information about the Mlir-commits mailing list