[Mlir-commits] [mlir] 912ebf6 - [mlir][linalg] Cleanup LinalgOp usage in vectorization (NFC).

Tobias Gysi llvmlistbot at llvm.org
Tue Jun 1 01:09:29 PDT 2021


Author: Tobias Gysi
Date: 2021-06-01T08:08:40Z
New Revision: 912ebf60b15123827299df73a7c9136f6693b487

URL: https://github.com/llvm/llvm-project/commit/912ebf60b15123827299df73a7c9136f6693b487
DIFF: https://github.com/llvm/llvm-project/commit/912ebf60b15123827299df73a7c9136f6693b487.diff

LOG: [mlir][linalg] Cleanup LinalgOp usage in vectorization (NFC).

Replace the uses of deprecated Structured Op Interface methods in Vectorization.cpp. This patch is based on https://reviews.llvm.org/D103394.

Differential Revision: https://reviews.llvm.org/D103410

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 7ee5d5f4dd744..12a8d80c72fcc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -116,14 +116,14 @@ static VectorType extractVectorTypeFromShapedValue(Value v) {
 /// Linalg. This limitation is motivated by the fact that e.g.
 /// min(max(X)) != max(min(X))
 // TODO: use in LinalgOp verification, there is a circular dependency atm.
-static Operation *getSingleBinaryOpAssumedReduction(OpOperand &outputOperand) {
-  auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
+static Operation *getSingleBinaryOpAssumedReduction(OpOperand *outputOperand) {
+  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
   auto yieldOp = cast<YieldOp>(linalgOp->getRegion(0).front().getTerminator());
   unsigned yieldNum =
-      outputOperand.getOperandNumber() - linalgOp.getNumInputs();
+      outputOperand->getOperandNumber() - linalgOp.getNumInputs();
   llvm::SetVector<Operation *> backwardSlice, forwardSlice;
   BlockArgument bbArg = linalgOp->getRegion(0).front().getArgument(
-      outputOperand.getOperandNumber());
+      outputOperand->getOperandNumber());
   Value yieldVal = yieldOp->getOperand(yieldNum);
   getBackwardSlice(yieldVal, &backwardSlice, [&](Operation *op) {
     return op->getParentOp() == linalgOp;
@@ -186,16 +186,15 @@ getKindForOp(Operation *reductionOp) {
 /// return a new vector.broadcast to `shape`.
 /// Otherwise, just return value.
 static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
-                            Value value, OpOperand &outputOperand) {
-  assert(targetVectorType.getShape() ==
-         outputOperand.get().getType().cast<ShapedType>().getShape());
+                            Value value, OpOperand *outputOperand) {
+  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
+  assert(targetVectorType.getShape() == linalgOp.getShape(outputOperand));
   auto vecType = value.getType().dyn_cast<VectorType>();
   if (!vecType || vecType.getShape() == targetVectorType.getShape())
     return value;
   // At this point, we know we need to reduce. Detect the reduction operator.
   // TODO: Use the generic reduction detection util.
   Operation *reductionOp = getSingleBinaryOpAssumedReduction(outputOperand);
-  auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
   unsigned pos = 0;
   MLIRContext *ctx = b.getContext();
   SmallVector<AffineExpr> exprs;
@@ -235,23 +234,22 @@ static Value buildVectorRead(OpBuilder &b, Value source, VectorType vectorType,
 /// currently being vectorized. If `dest` has null rank, build an memref.store.
 /// Return the produced value or null if no value is produced.
 static Value buildVectorWrite(OpBuilder &b, Value value,
-                              OpOperand &outputOperand) {
+                              OpOperand *outputOperand) {
   Operation *write;
   Location loc = value.getLoc();
-  auto shapedType = outputOperand.get().getType().cast<ShapedType>();
   if (VectorType vectorType =
-          extractVectorTypeFromShapedValue(outputOperand.get())) {
-    auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
-    AffineMap map = reindexIndexingMap(
-        linalgOp.getIndexingMap(outputOperand.getOperandNumber()));
-    SmallVector<Value> indices(shapedType.getRank(),
+          extractVectorTypeFromShapedValue(outputOperand->get())) {
+    auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
+    AffineMap map =
+        reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand));
+    SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                                b.create<ConstantIndexOp>(loc, 0));
     value = broadcastIfNeeded(b, value, vectorType.getShape());
     value = reduceIfNeeded(b, vectorType, value, outputOperand);
-    write = b.create<vector::TransferWriteOp>(loc, value, outputOperand.get(),
+    write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
                                               indices, map);
   } else {
-    write = b.create<memref::StoreOp>(loc, value, outputOperand.get());
+    write = b.create<memref::StoreOp>(loc, value, outputOperand->get());
   }
   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
   if (!write->getResults().empty())
@@ -284,7 +282,7 @@ vectorizeLinalgYield(OpBuilder &b, Operation *op,
     // TODO: use a map.
     Value vectorValue = bvm.lookup(outputs.value());
     Value newResult = buildVectorWrite(
-        b, vectorValue, linalgOp.getOutputOpOperands()[outputs.index()]);
+        b, vectorValue, linalgOp.getOutputOperand(outputs.index()));
     if (newResult)
       newResults.push_back(newResult);
   }
@@ -422,8 +420,8 @@ static bool isElementwise(Operation *op) {
   if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
     return false;
   // TODO: relax the restrictions on indexing map.
-  for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
-    if (!linalgOp.getOutputIndexingMap(i).isIdentity())
+  for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
+    if (!linalgOp.getTiedIndexingMap(opOperand).isIdentity())
       return false;
   }
   if (linalgOp->getNumRegions() != 1)
@@ -479,36 +477,37 @@ LogicalResult vectorizeAsLinalgGeneric(
 
   // 3. Turn all BBArgs into vector.transfer_read / load.
   SmallVector<AffineMap> indexings;
-  for (auto bbarg : block.getArguments()) {
-    Value shapedArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
-    ShapedType shapedType = shapedArg.getType().cast<ShapedType>();
+  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
+    BlockArgument bbarg = block.getArgument(opOperand->getOperandNumber());
     // TODO: 0-d vectors.
-    if (shapedType.getShape().empty()) {
-      Value loaded = b.create<memref::LoadOp>(linalgOp.getLoc(), shapedArg);
+    if (linalgOp.getShape(opOperand).empty()) {
+      Value loaded =
+          b.create<memref::LoadOp>(linalgOp.getLoc(), opOperand->get());
       LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
                         << bbarg.getArgNumber() << "): " << loaded);
       bvm.map(bbarg, loaded);
-      bvm.map(shapedArg, loaded);
+      bvm.map(opOperand->get(), loaded);
       continue;
     }
     AffineMap map;
     VectorType vectorType;
     if (broadcastToMaximalCommonShape) {
       map = inverseAndBroadcastProjectedPermuation(
-          linalgOp.getIndexingMap(bbarg.getArgNumber()));
-      vectorType =
-          VectorType::get(commonVectorShape, shapedType.getElementType());
+          linalgOp.getTiedIndexingMap(opOperand));
+      vectorType = VectorType::get(
+          commonVectorShape, getElementTypeOrSelf(opOperand->get().getType()));
     } else {
       map = inversePermutation(
-          reindexIndexingMap(linalgOp.getIndexingMap(bbarg.getArgNumber())));
-      vectorType = VectorType::get(map.compose(shapedType.getShape()),
-                                   shapedType.getElementType());
+          reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
+      vectorType =
+          VectorType::get(map.compose(linalgOp.getShape(opOperand)),
+                          getElementTypeOrSelf(opOperand->get().getType()));
     }
-    Value vectorRead = buildVectorRead(b, shapedArg, vectorType, map);
+    Value vectorRead = buildVectorRead(b, opOperand->get(), vectorType, map);
     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
                       << bbarg.getArgNumber() << "): " << vectorRead);
     bvm.map(bbarg, vectorRead);
-    bvm.map(shapedArg, vectorRead);
+    bvm.map(opOperand->get(), vectorRead);
   }
 
   auto hooks = llvm::to_vector<4>(customVectorizationHooks);
@@ -562,7 +561,8 @@ static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
           const BlockAndValueMapping &bvm) -> VectorizationResult {
     if (!isa<MulIOp, MulFOp>(op))
       return VectorizationResult{VectorizationStatus::Failure, nullptr};
-    auto outShape = linalgOp.getOutputShapedType(0).getShape();
+    ArrayRef<int64_t> outShape =
+        linalgOp.getShape(linalgOp.getOutputOperand(0));
     auto vType = outShape.empty()
                      ? op->getResult(0).getType()
                      : VectorType::get(outShape, op->getResult(0).getType());
@@ -574,13 +574,14 @@ static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
     // TODO: consider dropping contraction special casing altogether, this will
     // require more advanced canonicalizations involving vector.multi_reduction
     // that are not yet available.
-    SmallVector<AffineMap> indexingMaps{
-        inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(0)))
-            .compose(linalgOp.getIndexingMap(0)),
-        inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(1)))
-            .compose(linalgOp.getIndexingMap(1)),
-        inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(2)))
-            .compose(linalgOp.getIndexingMap(2))};
+    SmallVector<AffineMap> indexingMaps;
+    indexingMaps.reserve(linalgOp.getNumInputsAndOutputs());
+    llvm::transform(linalgOp.getIndexingMaps(),
+                    std::back_inserter(indexingMaps),
+                    [](AffineMap indexingMap) {
+                      return inversePermutation(reindexIndexingMap(indexingMap))
+                          .compose(indexingMap);
+                    });
     Operation *contract = b.create<vector::ContractionOp>(
         loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
         b.getAffineMapArrayAttr(indexingMaps), linalgOp.iterator_types());
@@ -601,8 +602,8 @@ static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
 static LogicalResult reductionPreconditions(LinalgOp op) {
   if (llvm::none_of(op.iterator_types(), isReductionIteratorType))
     return failure();
-  for (auto &operand : op.getOutputOpOperands()) {
-    Operation *reductionOp = getSingleBinaryOpAssumedReduction(operand);
+  for (OpOperand *opOperand : op.getOutputOperands()) {
+    Operation *reductionOp = getSingleBinaryOpAssumedReduction(opOperand);
     if (!getKindForOp(reductionOp))
       return failure();
   }
@@ -612,12 +613,8 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   auto linalgOp = cast<linalg::LinalgOp>(op);
   // All types must be static shape to go to vector.
-  for (Value operand : linalgOp.getShapedOperands())
-    if (!operand.getType().cast<ShapedType>().hasStaticShape())
-      return failure();
-  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
-    if (!outputTensorType.cast<ShapedType>().hasStaticShape())
-      return failure();
+  if (linalgOp.hasDynamicShape())
+    return failure();
   if (isElementwise(op))
     return success();
   if (isaContractionOpInterface(linalgOp))
@@ -722,13 +719,14 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
   Location loc = op.getLoc();
   MLIRContext *context = op.getContext();
 
-  ShapedType inShapeType = op.getInputShapedType(0);
-  ShapedType kShapeType = op.getInputShapedType(1);
-
-  ArrayRef<int64_t> inShape = inShapeType.getShape();
-  ArrayRef<int64_t> kShape = kShapeType.getShape();
+  OpOperand *input = op.getInputOperand(0);
+  OpOperand *kernel = op.getInputOperand(1);
+  OpOperand *output = op.getOutputOperand(0);
+  ArrayRef<int64_t> inShape = op.getShape(input);
+  ArrayRef<int64_t> kShape = op.getShape(kernel);
 
-  if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
+  if (llvm::any_of(inShape, ShapedType::isDynamic) ||
+      llvm::any_of(kShape, ShapedType::isDynamic))
     return failure();
 
   SmallVector<AffineExpr, 4> mapping;
@@ -747,22 +745,18 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
     }
   }
 
-  Value input = op.getInput(0);
-  Value kernel = op.getInput(1);
-  Value output = op.getOutputBuffer(0);
-
-  unsigned rank = inShapeType.getRank();
-  unsigned numDims = mapping.size();
-  Type elemType = inShapeType.getElementType();
+  int64_t rank = op.getRank(input);
+  int64_t numDims = mapping.size();
+  Type elemType = getElementTypeOrSelf(input->get().getType());
 
   auto map = AffineMap::get(rank, 0, mapping, context);
   SmallVector<Value, 4> zeros(rank, rewriter.create<ConstantIndexOp>(loc, 0));
   auto vecType = VectorType::get(vectorDims, elemType);
 
-  auto inputVec =
-      rewriter.create<vector::TransferReadOp>(loc, vecType, input, zeros, map);
-  auto kernelVec =
-      rewriter.create<vector::TransferReadOp>(loc, vecType, kernel, zeros, map);
+  auto inputVec = rewriter.create<vector::TransferReadOp>(
+      loc, vecType, input->get(), zeros, map);
+  auto kernelVec = rewriter.create<vector::TransferReadOp>(
+      loc, vecType, kernel->get(), zeros, map);
 
   auto acc = rewriter.create<ConstantOp>(loc, elemType,
                                          rewriter.getZeroAttr(elemType));
@@ -779,7 +773,8 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
       rewriter.getAffineMapArrayAttr(indexingMaps),
       rewriter.getStrArrayAttr(iteratorTypes));
 
-  rewriter.create<memref::StoreOp>(loc, result, output, ValueRange(zeros));
+  rewriter.create<memref::StoreOp>(loc, result, output->get(),
+                                   ValueRange(zeros));
   rewriter.eraseOp(op);
   return success();
 }
@@ -939,7 +934,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
   CopyOp copyOp;
   for (auto &u : subView.getUses()) {
     if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
-      if (newCopyOp.getOutputBuffer(0) != subView)
+      assert(newCopyOp.output().getType().isa<MemRefType>());
+      if (newCopyOp.output() != subView)
         continue;
       LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                               << "copy candidate " << *newCopyOp);
@@ -958,7 +954,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
   FillOp maybeFillOp;
   for (auto &u : viewOrAlloc.getUses()) {
     if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
-      if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
+      assert(newFillOp.output().getType().isa<MemRefType>());
+      if (newFillOp.output() != viewOrAlloc)
         continue;
       LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                               << "fill candidate " << *newFillOp);
@@ -976,7 +973,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
                             << "with maybeFillOp " << *maybeFillOp);
 
   // `in` is the subview that linalg.copy reads. Replace it.
-  Value in = copyOp.getInput(0);
+  Value in = copyOp.input();
 
   // linalg.copy + linalg.fill can be used to create a padded local buffer.
   // The `masked` attribute is only valid on this padded buffer.
@@ -1014,7 +1011,7 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
   CopyOp copyOp;
   for (auto &u : subViewOp.getResult().getUses()) {
     if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
-      if (newCopyOp.getInput(0) != subView)
+      if (newCopyOp.getInputOperand(0)->get() != subView)
         continue;
       if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
         continue;
@@ -1026,7 +1023,8 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
     return failure();
 
   // `out` is the subview copied into that we replace.
-  Value out = copyOp.getOutputBuffer(0);
+  assert(copyOp.output().getType().isa<MemRefType>());
+  Value out = copyOp.output();
 
   // Forward vector.transfer into copy.
   // linalg.copy + linalg.fill can be used to create a padded local buffer.


        


More information about the Mlir-commits mailing list