[Mlir-commits] [mlir] 1d1a331 - [mlir][Linalg] NFC - Expose packing transpose implementation as a standalone functional-style API call

Nicolas Vasilache llvmlistbot at llvm.org
Tue Jan 24 07:47:46 PST 2023


Author: Nicolas Vasilache
Date: 2023-01-24T07:47:40-08:00
New Revision: 1d1a3313513f5e15759328d27e4c2350977140c4

URL: https://github.com/llvm/llvm-project/commit/1d1a3313513f5e15759328d27e4c2350977140c4
DIFF: https://github.com/llvm/llvm-project/commit/1d1a3313513f5e15759328d27e4c2350977140c4.diff

LOG: [mlir][Linalg] NFC - Expose packing transpose implementation as a standalone functional-style API call

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 07f56770b63c8..6cb73bacb0e06 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/X86Vector/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallSet.h"
@@ -1141,12 +1142,32 @@ FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
     GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
     RewriterBase &rewriter);
 
-/// Implement packing of a single LinalgOp by performing packing by
-/// `packedSizes`. There must be one packedSizes entry per `linalgOp` iterator.
+/// Implement packing of a single LinalgOp by `packedSizes`.
+/// There must be one packedSizes entry per `linalgOp` iterator.
 /// Return the packed Linalg op on success, failure otherwise.
 FailureOr<linalg::LinalgOp> pack(RewriterBase &rewriter,
                                  linalg::LinalgOp linalgOp,
                                  ArrayRef<OpFoldResult> packedSizes);
+
+/// Struct to hold the result of a `packTranspose` call.
+struct PackTransposeResult {
+  tensor::PackOp transposedPackOp;
+  linalg::LinalgOp transposedLinalgOp;
+  tensor::UnPackOp transposedUnPackOp;
+};
+/// Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the
+/// transposed PackOp -> LinalgOp -> UnPackOp chain after replacements.
+/// `outerPerm` is applied to the outer dimensions of the packed tensor and
+/// `innerPerm` to its inner tile dimensions; an empty permutation stands for
+/// the identity. `maybeUnPackOp` may be null, in which case only the PackOp
+/// and the LinalgOp are transposed.
+/// Return failure if any of the following preconditions is violated:
+///   1. `packOp` must have exactly one use, and that use must be by
+///      `linalgOp`.
+///   2. `maybeUnPackOp`, if non-null, must consume the `linalgOp` result tied
+///      to the (init) operand that uses `packOp`.
+///   3. `outerPerm` (resp. `innerPerm`) must be empty or a valid permutation
+///      of the outer dimensions (resp. inner tile dimensions) of `packOp`.
+FailureOr<PackTransposeResult>
+packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
+              linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
+              ArrayRef<int64_t> outerPerm, ArrayRef<int64_t> innerPerm);
+
 } // namespace linalg
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4e0be5aa8dbd4..554328c6cbc1c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -950,6 +950,23 @@ void transform::PackOp::getEffects(
 // PackTransposeOp
 //===---------------------------------------------------------------------===//
 
+/// Verify the permutation attributes of this op: each of the inner and outer
+/// permutations must pass `isPermutationVector`, and at least one of them must
+/// be non-empty.
+LogicalResult transform::PackTransposeOp::verify() {
+  auto innerPermutation = getInnerPerm();
+  auto outerPermutation = getOuterPerm();
+
+  // Checked inner-first, then outer, so the first offending attribute is the
+  // one reported.
+  if (!isPermutationVector(innerPermutation))
+    return emitOpError() << getInnerPermAttrName()
+                         << " is not a valid permutation";
+  if (!isPermutationVector(outerPermutation))
+    return emitOpError() << getOuterPermAttrName()
+                         << " is not a valid permutation";
+
+  // Requesting neither permutation leaves nothing to transpose.
+  const bool hasPermutation =
+      !innerPermutation.empty() || !outerPermutation.empty();
+  if (hasPermutation)
+    return success();
+  return emitOpError() << " at least one of " << getInnerPermAttrName()
+                       << " or " << getOuterPermAttrName()
+                       << " must be specified";
+}
+
 namespace {
 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
 } // namespace
@@ -983,84 +1000,6 @@ bool isValidPackingPermutation(
          isPermutationVector(permutation);
 }
 
-/// Return a copy of `tensorType` after permutation by `permutationVector`.
-// Note: Should be a new method in of MemRef/RankedTensor/VectorType::Builder
-// but this would introduce a dependence on Dialect in IR.
-// TODO: Restructure.
-static RankedTensorType permuteShape(RankedTensorType tensorType,
-                                     ArrayRef<int64_t> permutationVector) {
-  SmallVector<int64_t> shape(tensorType.getShape());
-  applyPermutationToVector(shape, permutationVector);
-  return RankedTensorType::Builder(tensorType).setShape(shape);
-}
-
-/// Return a new GenericOp obtained by transposing opOperand by the permutation
-/// vector:
-///   - the corresponding indexing map is transposed by `permutation`
-///   - the corresponding operand value is replaced by `transposedValue`
-/// `linalgOp` is replaced by the return op in the process.
-/// Asserts that `transposedValue` is of the proper transposed ShapedType.
-static LinalgOp transposeOneLinalgOperandAndReplace(
-    RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
-    ArrayRef<int64_t> permutation, Value transposedValue) {
-  // Sanity check the operand.
-  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
-
-  // Sanity check of the expected transposed tensor type.
-  auto tensorType = permuteShape(
-      opOperand.get().getType().cast<RankedTensorType>(), permutation);
-  (void)tensorType;
-  assert(tensorType == transposedValue.getType() &&
-         "expected tensor type mismatch");
-
-  // Compute the transposed indexing map.
-  // Sigh unsigned pollution.
-  SmallVector<unsigned> tmpTransposition = llvm::to_vector(
-      llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
-  AffineMap permutationMap =
-      AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
-  AffineMap transposedMap =
-      permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));
-
-  // Set the transposed indexing map in the proper position.
-  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
-  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
-  // Set the transposedValue in the proper operand position.
-  SmallVector<Value> operands = linalgOp->getOperands();
-  operands[opOperand.getOperandNumber()] = transposedValue;
-
-  ValueRange operandsRef(operands);
-  auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
-      /*location=*/linalgOp->getLoc(),
-      /*resultTensorTypes=*/
-      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
-      /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
-      /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
-      /*indexingMaps=*/indexingMaps,
-      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
-  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
-  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
-
-  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
-}
-
-LogicalResult transform::PackTransposeOp::verify() {
-  if (!isPermutationVector(getInnerPerm())) {
-    return emitOpError() << getInnerPermAttrName()
-                         << " is not a valid permutation";
-  }
-  if (!isPermutationVector(getOuterPerm())) {
-    return emitOpError() << getOuterPermAttrName()
-                         << " is not a valid permutation";
-  }
-  if (getInnerPerm().empty() && getOuterPerm().empty()) {
-    return emitOpError() << " at least one of " << getInnerPermAttrName()
-                         << " or " << getOuterPermAttrName()
-                         << " must be specified";
-  }
-  return success();
-}
-
 DiagnosedSilenceableFailure
 transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
                                   transform::TransformState &state) {
@@ -1138,68 +1077,23 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
   assert(packOp && linalgOp && "unexpected null op");
 
   // Step 3. Actually transpose the ops.
-  Location loc = linalgOp.getLoc();
   IRRewriter rewriter(getContext());
+  FailureOr<PackTransposeResult> res = packTranspose(
+      rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
+  // Preconditions have been checked, it is an error to fail here.
+  assert(succeeded(res) && "unexpected packTranspose failure");
 
-  // Step 3.a. Transpose packOp.
-  rewriter.setInsertionPoint(packOp);
-  tensor::PackOp transposedPackOp = packOp.createTransposedClone(
-      rewriter, loc, getInnerPerm(), getOuterPerm());
-
-  // Step 3.b. Transpose linalgOp.
-  assert(packOp.getResult().hasOneUse() && "expect single use");
-  // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
-  // identity. Don't rely on it.
-  int64_t numLeadingDims = packOp.getSourceRank();
-  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
-  // Step 3.b.i. Compute the permutation on the whole operand.
-  // Leading part just reuse the outerPerm.
-  SmallVector<int64_t> permutation(getOuterPerm());
-  if (permutation.empty())
-    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
-  // Trailing part needs to reindex positions by `numLeadingDims`.
-  if (getInnerPerm().empty()) {
-    llvm::append_range(
-        permutation,
-        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
-  } else {
-    llvm::append_range(permutation,
-                       llvm::map_range(getInnerPerm(), [&](int64_t pos) {
-                         return numLeadingDims + pos;
-                       }));
-  }
-  assert(isPermutationVector(permutation) && "invalid permutation");
-  // Step 3.b.ii. Save the transposedPackUse operand number in case we need to
-  // get the tied OpResult after `linalgOp` has been replaced.
-  OpOperand &packUse = *(packOp.getResult().getUses().begin());
-  int64_t packUseOperandNumber = packUse.getOperandNumber();
-  // Step 3.b.iii. Actually perform the transposition.
-  rewriter.setInsertionPoint(linalgOp);
-  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
-      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
-
-  // Step 3.c. Maybe transpose unPackOp.
-  tensor::UnPackOp transposedUnPackOp;
-  if (unPackOp) {
-    OpOperand &opOperand =
-        transposedLinalgOp->getOpOperand(packUseOperandNumber);
-    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
-    rewriter.setInsertionPoint(unPackOp);
-    transposedUnPackOp = unPackOp.createTransposedClone(
-        rewriter, loc, transposedResult, getInnerPerm(), getOuterPerm());
-  }
-
-  // Step 4. Replace and return results.
-  rewriter.replaceOp(packOp, transposedPackOp->getResults());
-  transformResults.set(getPackOp().cast<OpResult>(), {transposedPackOp});
-  // transposedLinalgOp was replaced in `transposeOneLinalgOperandAndReplace`.
-  transformResults.set(getPackedOp().cast<OpResult>(), {transposedLinalgOp});
+  // Step 4. Return results.
+  transformResults.set(getPackOp().cast<OpResult>(), {res->transposedPackOp});
+  transformResults.set(getPackedOp().cast<OpResult>(),
+                       {res->transposedLinalgOp});
   if (unPackOp) {
-    rewriter.replaceOp(unPackOp, transposedUnPackOp->getResults());
-    transformResults.set(getUnPackOp().cast<OpResult>(), {transposedUnPackOp});
+    transformResults.set(getUnPackOp().cast<OpResult>(),
+                         {res->transposedUnPackOp});
   } else {
     transformResults.set(getUnPackOp().cast<OpResult>(), {});
   }
+
   return DiagnosedSilenceableFailure::success();
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 8e8657d64e851..4799b6e4e98c7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1161,3 +1161,147 @@ FailureOr<linalg::LinalgOp> linalg::pack(RewriterBase &rewriter,
   // Return packedLinalgOp.
   return cast<linalg::LinalgOp>(packedLinalgOp.getOperation());
 }
+
+//===----------------------------------------------------------------------===//
+// packTranspose transformation.
+//===----------------------------------------------------------------------===//
+
+/// Return `tensorType` with its dimension sizes reordered by
+/// `permutationVector` (using `applyPermutationToVector`). Only the shape is
+/// replaced; everything else is taken from `tensorType` through
+/// `RankedTensorType::Builder`.
+// Note: This would naturally be a method of the MemRef/RankedTensor/VectorType
+// builders, but that would introduce a dependence on Dialect in IR.
+// TODO: Restructure.
+static RankedTensorType permuteShape(RankedTensorType tensorType,
+                                     ArrayRef<int64_t> permutationVector) {
+  RankedTensorType::Builder builder(tensorType);
+  SmallVector<int64_t> permutedShape(tensorType.getShape().begin(),
+                                     tensorType.getShape().end());
+  applyPermutationToVector(permutedShape, permutationVector);
+  return builder.setShape(permutedShape);
+}
+
+/// Replace `linalgOp` by a new linalg.generic in which `opOperand` is swapped
+/// for `transposedValue`, the operand transposed by `permutation`:
+///   - the operand's indexing map is composed with the permutation map built
+///     from `permutation`;
+///   - all other operands, indexing maps and iterator types are kept, and the
+///     region of `linalgOp` is moved into the new op;
+///   - result types are taken from the (possibly updated) init operands.
+/// `linalgOp` is replaced (and erased) through `rewriter.replaceOp`; the new op
+/// is returned as a LinalgOp.
+/// Asserts that `opOperand` is owned by `linalgOp` and that `transposedValue`
+/// has the operand's ranked tensor type permuted by `permutation`.
+static LinalgOp transposeOneLinalgOperandAndReplace(
+    RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
+    ArrayRef<int64_t> permutation, Value transposedValue) {
+  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
+
+  // The caller must have transposed the value consistently with `permutation`.
+  RankedTensorType expectedType = permuteShape(
+      opOperand.get().getType().cast<RankedTensorType>(), permutation);
+  (void)expectedType;
+  assert(expectedType == transposedValue.getType() &&
+         "expected tensor type mismatch");
+
+  // AffineMap::getPermutationMap only accepts unsigned positions.
+  SmallVector<unsigned> positions;
+  positions.reserve(permutation.size());
+  for (int64_t dim : permutation)
+    positions.push_back(static_cast<unsigned>(dim));
+  AffineMap transposedMap =
+      AffineMap::getPermutationMap(positions, rewriter.getContext())
+          .compose(linalgOp.getMatchingIndexingMap(&opOperand));
+
+  // Patch the indexing map and the operand at the transposed position.
+  SmallVector<AffineMap> newIndexingMaps = linalgOp.getIndexingMapsArray();
+  newIndexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
+  SmallVector<Value> newOperands = linalgOp->getOperands();
+  newOperands[opOperand.getOperandNumber()] = transposedValue;
+
+  const auto numInputs = linalgOp.getNumDpsInputs();
+  ValueRange allOperands(newOperands);
+  ValueRange inputs = allOperands.take_front(numInputs);
+  ValueRange inits = allOperands.drop_front(numInputs);
+
+  auto genericOp = rewriter.create<linalg::GenericOp>(
+      /*location=*/linalgOp->getLoc(),
+      /*resultTensorTypes=*/inits.getTypes(),
+      /*inputs=*/inputs,
+      /*outputs=*/inits,
+      /*indexingMaps=*/newIndexingMaps,
+      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
+  // Steal the payload region before `linalgOp` is erased by replaceOp.
+  genericOp.getRegion().takeBody(linalgOp->getRegion(0));
+  rewriter.replaceOp(linalgOp, genericOp->getResults());
+
+  return cast<linalg::LinalgOp>(genericOp.getOperation());
+}
+
+/// Transpose the `packOp` -> `linalgOp` [-> `maybeUnPackOp`] chain by
+/// `outerPerm` / `innerPerm` (an empty permutation means identity) and return
+/// the replacement ops.
+/// All preconditions are checked before any IR is created or replaced, so a
+/// failure leaves the IR untouched.
+FailureOr<PackTransposeResult>
+linalg::packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
+                      linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
+                      ArrayRef<int64_t> outerPerm,
+                      ArrayRef<int64_t> innerPerm) {
+  // Step 1. Check preconditions. This must happen before any IR is created:
+  // bailing out after cloning `packOp` would leave a dangling transposed
+  // PackOp behind in the IR.
+  if (!packOp.getResult().hasOneUse())
+    return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");
+
+  OpOperand &packUse = *packOp->getUses().begin();
+  if (packUse.getOwner() != linalgOp) {
+    return rewriter.notifyMatchFailure(
+        linalgOp, "not a single use by the LinalgOp target");
+  }
+  if (maybeUnPackOp &&
+      (!linalgOp.isDpsInit(&packUse) ||
+       maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
+    return rewriter.notifyMatchFailure(linalgOp,
+                                       "not produced by the LinalgOp target");
+  }
+
+  // The packed operand has `numLeadingDims` outer dimensions followed by
+  // `numTrailingDims` inner tile dimensions. The pack's outer dims perm may be
+  // empty (identity), so the counts are derived from the source rank and the
+  // number of inner tiles instead.
+  int64_t numLeadingDims = packOp.getSourceRank();
+  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
+  // Each permutation must be empty or a permutation of exactly the dimensions
+  // it applies to. Checking only the combined vector is not enough: mis-sized
+  // parts can still combine into a valid-looking permutation of the wrong rank.
+  auto isValidPermOfSize = [](ArrayRef<int64_t> perm, int64_t size) {
+    return perm.empty() || (static_cast<int64_t>(perm.size()) == size &&
+                            isPermutationVector(perm));
+  };
+  if (!isValidPermOfSize(outerPerm, numLeadingDims) ||
+      !isValidPermOfSize(innerPerm, numTrailingDims))
+    return rewriter.notifyMatchFailure(linalgOp, "invalid permutation");
+
+  // Step 2. Compute the permutation on the whole packed operand.
+  // Leading part just reuses the outerPerm.
+  SmallVector<int64_t> permutation(outerPerm);
+  if (permutation.empty())
+    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
+  // Trailing part needs to reindex positions by `numLeadingDims`.
+  if (innerPerm.empty()) {
+    llvm::append_range(
+        permutation,
+        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
+  } else {
+    llvm::append_range(permutation,
+                       llvm::map_range(innerPerm, [&](int64_t pos) {
+                         return numLeadingDims + pos;
+                       }));
+  }
+  assert(isPermutationVector(permutation) && "invalid permutation");
+
+  // Step 3. Transpose packOp. Each clone carries the location of the op it
+  // replaces so diagnostics point at the right source op.
+  rewriter.setInsertionPoint(packOp);
+  tensor::PackOp transposedPackOp = packOp.createTransposedClone(
+      rewriter, packOp.getLoc(), innerPerm, outerPerm);
+
+  // Step 4. Transpose linalgOp. Save the pack use operand number first:
+  // `packUse` belongs to `linalgOp`, which is erased by the replacement, and is
+  // needed afterwards to find the tied result.
+  int64_t packUseOperandNumber = packUse.getOperandNumber();
+  rewriter.setInsertionPoint(linalgOp);
+  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
+      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
+
+  // Step 5. Maybe transpose unPackOp.
+  tensor::UnPackOp transposedUnPackOp;
+  if (maybeUnPackOp) {
+    OpOperand &opOperand =
+        transposedLinalgOp->getOpOperand(packUseOperandNumber);
+    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
+    rewriter.setInsertionPoint(maybeUnPackOp);
+    transposedUnPackOp = maybeUnPackOp.createTransposedClone(
+        rewriter, maybeUnPackOp.getLoc(), transposedResult, innerPerm,
+        outerPerm);
+
+    rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
+  }
+
+  // Step 6. Finally, replace packOp now that we don't need it anymore.
+  rewriter.replaceOp(packOp, transposedPackOp->getResults());
+
+  return PackTransposeResult{transposedPackOp, transposedLinalgOp,
+                             transposedUnPackOp};
+}


        


More information about the Mlir-commits mailing list