[Mlir-commits] [mlir] 7cedd95 - [mlir][Linalg] Add a transform.structured.lower_unpack op

Nicolas Vasilache llvmlistbot at llvm.org
Wed Feb 1 02:33:31 PST 2023


Author: Nicolas Vasilache
Date: 2023-02-01T02:26:44-08:00
New Revision: 7cedd956d0602bf6aca24c1fa5eda2bd585750ac

URL: https://github.com/llvm/llvm-project/commit/7cedd956d0602bf6aca24c1fa5eda2bd585750ac
DIFF: https://github.com/llvm/llvm-project/commit/7cedd956d0602bf6aca24c1fa5eda2bd585750ac.diff

LOG: [mlir][Linalg] Add a transform.structured.lower_unpack op

This revision introduces `transform.structured.lower_unpack` which allows
rewriting a `tensor.unpack` into `tensor.empty` + `linalg.transpose` + `tensor.collapse_shape` + `tensor.extract_slice`.

The implementation is currently limited to static unpack ops that do not have outer_dims permutations.

Differential Revision: https://reviews.llvm.org/D142889

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 5bd5c034e1c5..8ac2c2e0ad98 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -25,6 +25,7 @@ class LinalgOp;
 
 namespace tensor {
 class PackOp;
+class UnPackOp;
 } // namespace tensor
 
 namespace transform {

diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 2aee320b8193..630000c57f8f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -229,7 +229,7 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
     #### Return modes
 
     This operation ignores non-pack ops and drops them in the return.
-    This operation produces a silenceableFailure if the padding fails for any
+    This operation produces a silenceableFailure if the rewrite fails for any
     reason.
     If all the operations referred to by the `target` are rewritten, the
     transform succeeds.
@@ -252,6 +252,46 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// LowerUnPackOp
+//===----------------------------------------------------------------------===//
+def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
+                         FunctionalStyleTransformOpTrait,
+                         MemoryEffectsOpInterface,
+                         TransformEachOpTrait,
+                         TransformOpInterface]> {
+  let description = [{
+    Lower a tensor.unpack into empty + linalg.transpose + tensor.collapse_shape +
+    tensor.extract_slice.
+
+    #### Return modes
+
+    This operation ignores non-unpack ops and drops them in the return.
+    This operation produces a silenceableFailure if the rewrite fails for any
+    reason.
+    If all the operations referred to by the `target` are rewritten, the
+    transform succeeds.
+    Return handles to the newly produced empty, transpose, collapse_shape and
+    extract_slice ops.
+  }];
+
+  let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target);
+  let results = (outs Transform_ConcreteOpType<"tensor.empty">:$empty_op,
+                      Transform_ConcreteOpType<"linalg.transpose">:$transpose_op,
+                      Transform_ConcreteOpType<"tensor.collapse_shape">:$collapse_shape_op,
+                      Transform_ConcreteOpType<"tensor.extract_slice">:$extract_slice_op);
+  let assemblyFormat = [{
+    $target attr-dict `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::tensor::UnPackOp target,
+        ::mlir::transform::ApplyToEachResultList &transformResults,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // MatchOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9952bb1cc5ac..022c94e39e23 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -782,8 +782,8 @@ struct LowerPackResult {
 };
 
 /// Rewrite pack as pad + reshape + transpose.
-static FailureOr<LowerPackResult> rewriteLowerPack(RewriterBase &rewriter,
-                                                   tensor::PackOp packOp) {
+static FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
+                                            tensor::PackOp packOp) {
   // 1. Filter out NYI cases.
   if (!packOp.getOuterDimsPerm().empty())
     return rewriter.notifyMatchFailure(packOp, "outer dims perm NYI");
@@ -876,7 +876,7 @@ DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
     transform::TransformState &state) {
   IRRewriter rewriter(target->getContext());
   rewriter.setInsertionPoint(target);
-  FailureOr<LowerPackResult> res = rewriteLowerPack(rewriter, target);
+  FailureOr<LowerPackResult> res = lowerPack(rewriter, target);
   if (failed(res)) {
     Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
     diag << "cannot lower to pad + expand + transpose";
@@ -888,6 +888,126 @@ DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// LowerUnPackOp
+//===----------------------------------------------------------------------===//
+
+/// Handles to the ops created by `lowerUnPack`, in creation order.
+struct LowerUnPackOpResult {
+  tensor::EmptyOp emptyOp;
+  linalg::TransposeOp transposeOp;
+  tensor::CollapseShapeOp collapseShapeOp;
+  tensor::ExtractSliceOp extractSliceOp;
+};
+
+/// Rewrite unpack as empty + transpose + reshape + extract_slice.
+///
+/// Only unpack ops with a static source shape and no outer_dims_perm are
+/// supported; other cases return a match failure without modifying the IR.
+static FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
+                                                  tensor::UnPackOp unPackOp) {
+  // 1. Filter out NYI cases.
+  if (!unPackOp.getOuterDimsPerm().empty())
+    return rewriter.notifyMatchFailure(unPackOp, "outer dims perm NYI");
+
+  RankedTensorType packedTensorType = unPackOp.getSourceType();
+  if (!packedTensorType.hasStaticShape()) {
+    return rewriter.notifyMatchFailure(
+        unPackOp,
+        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
+  }
+
+  Location loc = unPackOp->getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(unPackOp);
+
+  // 2. Compute the permutation vector to move the last `numPackedDims` into the
+  // `innerPosDims` of a shape of rank `packedRank`.
+  int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
+  int64_t packedRank = packedTensorType.getRank();
+  auto lastDims = llvm::to_vector(
+      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+  PackingMetadata packingMetadata =
+      computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
+  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
+      packedRank, lastDims, packingMetadata.insertPositions);
+
+  // 3. Compute the stripMinedShape: this is the packed shape without outer and
+  // inner permutations.
+  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
+  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
+
+  // 4. Transpose packedShape to stripMinedShape.
+  RankedTensorType stripMinedTensorType =
+      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
+  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+      stripMinedTensorType, packingMetadata.reassociations);
+  auto emptyOp =
+      rewriter.create<tensor::EmptyOp>(loc, stripMinedTensorType, ValueRange{});
+  auto transposeOp = rewriter.create<linalg::TransposeOp>(
+      loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
+
+  LLVM_DEBUG(
+      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
+                                                DBGS() << "insertPositions: ");
+      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
+                                      DBGS() << "packedShape: ");
+      DBGSNL();
+      llvm::interleaveComma(lastDimsToInsertPositionsPerm,
+                            DBGS() << "lastDimsToInsertPositionsPerm: ");
+      DBGSNL(); llvm::interleaveComma(
+          packingMetadata.reassociations, DBGS() << "reassociations: ",
+          [&](ReassociationIndices ri) {
+            llvm::interleaveComma(ri, llvm::dbgs() << "|");
+          });
+      DBGSNL();
+      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
+      DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
+
+  // 5. Collapse from the stripMinedShape to the padded result.
+  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
+      loc, collapsedType, transposeOp->getResult(0),
+      packingMetadata.reassociations);
+
+  // 6. ExtractSlice the unpadded result out of the collapsed tensor. Sizes come
+  // from the dest operand (same type as the result): for dynamic dims, querying
+  // the unpack result would create tensor.dim ops on a value that is replaced
+  // by this extract_slice in step 7.
+  auto destTensorType = unPackOp.getDest().getType().cast<RankedTensorType>();
+  int64_t destRank = destTensorType.getRank();
+  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
+  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+      loc, destTensorType, reshapeOp->getResult(0),
+      SmallVector<OpFoldResult>(destRank, zero),
+      tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
+      SmallVector<OpFoldResult>(destRank, one));
+
+  // 7. Replace unPackOp by extractSliceOp.
+  rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
+
+  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
+}
+
+/// Lowers `target` via `lowerUnPack` and returns handles to the produced
+/// empty, transpose, collapse_shape and extract_slice ops.
+DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
+    tensor::UnPackOp target, transform::ApplyToEachResultList &transformResults,
+    transform::TransformState &state) {
+  IRRewriter rewriter(target->getContext());
+  rewriter.setInsertionPoint(target);
+  FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
+  if (failed(res)) {
+    Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
+    diag << "cannot lower to empty + transpose + reshape + extract_slice";
+    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+  }
+  transformResults.push_back(res->emptyOp);
+  transformResults.push_back(res->transposeOp);
+  transformResults.push_back(res->collapseShapeOp);
+  transformResults.push_back(res->extractSliceOp);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===---------------------------------------------------------------------===//
 // MatchOp
 //===---------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 1b87903bba59..7a89cf90214c 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -1,5 +1,6 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter | FileCheck %s
+// RUN: mlir-opt %s -test-transform-dialect-interpreter --split-input-file | FileCheck %s
 
+// CHECK-LABEL: func.func @pack(
 func.func @pack(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
   %cst_0 = arith.constant 0.0 : f32
 
@@ -26,3 +27,31 @@ transform.sequence failures(propagate) {
     -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
 }
 
+// -----
+
+// CHECK-LABEL: func.func @unpack(
+func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+  //      CHECK: tensor.empty() : tensor<17x8x2x32x16x16xf32>
+  //      CHECK: linalg.transpose
+  // CHECK-SAME:    ins(%{{.*}} : tensor<17x2x16x16x32x8xf32>)
+  // CHECK-SAME:   outs(%{{.*}} : tensor<17x8x2x32x16x16xf32>)
+  // CHECK-SAME:   permutation = [0, 5, 1, 4, 2, 3]
+  //      CHECK: tensor.collapse_shape {{.*}}[0, 1], [2, 3], [4], [5]]
+  // CHECK-SAME:   : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
+  // CHECK-SAME:   : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
+  %unpack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
+    : tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
+  return %unpack : tensor<129x47x16x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+    : (!pdl.operation) -> !transform.op<"tensor.unpack">
+  transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+    -> (!transform.op<"tensor.empty">,
+        !transform.op<"linalg.transpose">,
+        !transform.op<"tensor.collapse_shape">,
+        !transform.op<"tensor.extract_slice">)
+}


        


More information about the Mlir-commits mailing list