[Mlir-commits] [mlir] [mlir][vector] Add pattern for dropping unit dims from for loops (PR #109585)

Quinn Dawkins llvmlistbot at llvm.org
Sun Sep 22 11:02:54 PDT 2024


https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/109585

This adds a pattern for dropping unit dims from the iter_args of scf.for ops using vector.shape_cast. This composes with the other patterns for dropping unit dims from elementwise ops and transposes.

From 8f3b21204806d74190ca3e43879e5b0e703306a4 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Sun, 22 Sep 2024 13:27:04 -0400
Subject: [PATCH] [mlir][vector] Add pattern for dropping unit dims from for
 loops

This adds a pattern for dropping unit dims from the iter_args of scf.for
ops using vector.shape_cast. This composes with the other patterns for
dropping unit dims from elementwise ops and transposes.
---
 mlir/include/mlir/Dialect/SCF/IR/SCF.h        |  11 ++
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 142 +++++++++---------
 .../Vector/Transforms/VectorTransforms.cpp    |  60 +++++++-
 .../drop-unit-dims-with-shape-cast.mlir       |  75 +++++++++
 4 files changed, 215 insertions(+), 73 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index 644118ca884c6b..d89d566ece62c1 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -107,6 +107,17 @@ LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs,
                        function_ref<void(OpBuilder &, Location, ValueRange)>
                            bodyBuilder = nullptr);
 
+/// Replace one iter OpOperand of an scf.for with the `replacement` value,
+/// which has a different type. The `castFn` callback is used to insert cast
+/// ops inside the block and after the loop to account for type differences.
+using ValueTypeCastFnTy =
+    std::function<Value(OpBuilder &, Location loc, Type, Value)>;
+SmallVector<Value> replaceAndCastForOpIterArg(RewriterBase &rewriter,
+                                              scf::ForOp forOp,
+                                              OpOperand &operand,
+                                              Value replacement,
+                                              const ValueTypeCastFnTy &castFn);
+
 } // namespace scf
 } // namespace mlir
 #endif // MLIR_DIALECT_SCF_SCF_H
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 6d47ff3890977a..d1c9fd2d217dad 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -772,6 +772,70 @@ LoopNest mlir::scf::buildLoopNest(
                        });
 }
 
+SmallVector<Value>
+mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
+                                      OpOperand &operand, Value replacement,
+                                      const ValueTypeCastFnTy &castFn) {
+  assert(operand.getOwner() == forOp);
+  Type oldType = operand.get().getType(), newType = replacement.getType();
+
+  // 1. Build the new iter operands; exactly one of them is replaced.
+  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
+         "expected an iter OpOperand");
+  assert(operand.get().getType() != replacement.getType() &&
+         "Expected a different type");
+  SmallVector<Value> newIterOperands;
+  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
+    if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
+      newIterOperands.push_back(replacement);
+      continue;
+    }
+    newIterOperands.push_back(opOperand.get());
+  }
+
+  // 2. Create the new forOp shell, carrying over the old op's attributes.
+  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
+      forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+      forOp.getStep(), newIterOperands);
+  newForOp->setAttrs(forOp->getAttrs());
+  Block &newBlock = newForOp.getRegion().front();
+  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
+                                             newBlock.getArguments().end());
+
+  // 3. Inject an incoming cast (new type -> old type) at the start of the new
+  // block for the bbArg corresponding to the `replacement` value.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(&newBlock, newBlock.begin());
+  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
+      &newForOp->getOpOperand(operand.getOperandNumber()));
+  Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
+  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
+
+  // 4. Move the old block's ops, remapping its args to newBlockTransferArgs.
+  Block &oldBlock = forOp.getRegion().front();
+  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
+
+  // 5. Cast the yielded value (old type -> new type) and yield that instead.
+  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
+  rewriter.setInsertionPoint(clonedYieldOp);
+  unsigned yieldIdx =
+      newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
+  Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
+                         clonedYieldOp.getOperand(yieldIdx));
+  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
+  newYieldOperands[yieldIdx] = castOut;
+  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
+  rewriter.eraseOp(clonedYieldOp);
+
+  // 6. Cast the matching loop result back to the old type after the forOp.
+  rewriter.setInsertionPointAfter(newForOp);
+  SmallVector<Value> newResults = newForOp.getResults();
+  newResults[yieldIdx] =
+      castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
+
+  return newResults;
+}
+
 namespace {
 // Fold away ForOp iter arguments when:
 // 1) The op yields the iter arguments.
@@ -973,76 +1037,6 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
   }
 };
 
-/// Perform a replacement of one iter OpOperand of an scf.for to the
-/// `replacement` value which is expected to be the source of a tensor.cast.
-/// tensor.cast ops are inserted inside the block to account for the type cast.
-static SmallVector<Value>
-replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
-                              Value replacement) {
-  Type oldType = operand.get().getType(), newType = replacement.getType();
-  assert(llvm::isa<RankedTensorType>(oldType) &&
-         llvm::isa<RankedTensorType>(newType) &&
-         "expected ranked tensor types");
-
-  // 1. Create new iter operands, exactly 1 is replaced.
-  ForOp forOp = cast<ForOp>(operand.getOwner());
-  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
-         "expected an iter OpOperand");
-  assert(operand.get().getType() != replacement.getType() &&
-         "Expected a different type");
-  SmallVector<Value> newIterOperands;
-  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
-    if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
-      newIterOperands.push_back(replacement);
-      continue;
-    }
-    newIterOperands.push_back(opOperand.get());
-  }
-
-  // 2. Create the new forOp shell.
-  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
-      forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
-      forOp.getStep(), newIterOperands);
-  newForOp->setAttrs(forOp->getAttrs());
-  Block &newBlock = newForOp.getRegion().front();
-  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
-                                             newBlock.getArguments().end());
-
-  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
-  // corresponding to the `replacement` value.
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(&newBlock, newBlock.begin());
-  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
-      &newForOp->getOpOperand(operand.getOperandNumber()));
-  Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
-                                                 newRegionIterArg);
-  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
-
-  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
-  Block &oldBlock = forOp.getRegion().front();
-  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
-
-  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
-  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
-  rewriter.setInsertionPoint(clonedYieldOp);
-  unsigned yieldIdx =
-      newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
-  Value castOut = rewriter.create<tensor::CastOp>(
-      newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
-  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
-  newYieldOperands[yieldIdx] = castOut;
-  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
-  rewriter.eraseOp(clonedYieldOp);
-
-  // 6. Inject an outgoing cast op after the forOp.
-  rewriter.setInsertionPointAfter(newForOp);
-  SmallVector<Value> newResults = newForOp.getResults();
-  newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
-      newForOp.getLoc(), oldType, newResults[yieldIdx]);
-
-  return newResults;
-}
-
 /// Fold scf.for iter_arg/result pairs that go through incoming/ougoing
 /// a tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
 ///
@@ -1090,9 +1084,13 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
         continue;
 
       // Create a new ForOp with that iter operand replaced.
+      ValueTypeCastFnTy castFn = [](OpBuilder &b, Location loc, Type type,
+                                    Value source) {
+        return b.create<tensor::CastOp>(loc, type, source);
+      };
       rewriter.replaceOp(
-          op, replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
-                                            incomingCast.getSource()));
+          op, replaceAndCastForOpIterArg(rewriter, op, iterOpOperand,
+                                         incomingCast.getSource(), castFn));
       return success();
     }
     return failure();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index ad4e42b31962e1..ba32583fc3cdc4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1796,6 +1796,63 @@ struct DropUnitDimsFromTransposeOp final
   }
 };
 
+/// A pattern to drop unit dims from the iter_args of an scf.for.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> {
+///    ...
+///    scf.yield %next : vector<[4]x1x1x4xf32>
+///  }
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %drop = vector.shape_cast %init
+///    : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
+///  %new_loop = scf.for ... iter_args(%iter = %drop) -> vector<[4]x4xf32> {
+///    %new_iter = vector.shape_cast %iter
+///      : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+///    ...
+///  }
+///  %res = vector.shape_cast %new_loop
+///    : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+///  ```
+struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    // Rewrite only the first iter_arg with droppable unit dims. Re-applying
+    // this pattern to the new loop then handles any later arguments.
+    for (OpOperand &operand : forOp.getInitArgsMutable()) {
+      auto vectorType = dyn_cast<VectorType>(operand.get().getType());
+      if (!vectorType)
+        continue;
+
+      VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
+      if (vectorType == newVectorType)
+        continue;
+
+      // Create a new ForOp with that iter operand replaced.
+      mlir::scf::ValueTypeCastFnTy castFn = [](OpBuilder &b, Location loc,
+                                               Type type, Value source) {
+        return b.create<vector::ShapeCastOp>(loc, type, source);
+      };
+
+      Value replacement =
+          castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
+      rewriter.replaceOp(forOp,
+                         replaceAndCastForOpIterArg(rewriter, forOp, operand,
+                                                    replacement, castFn));
+      return success();
+    }
+    return failure();
+  }
+};
+
 /// Pattern to eliminate redundant zero-constants added to reduction operands.
 /// It's enough for there to be one initial zero value, so we can eliminate the
 /// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -2001,7 +2058,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
-               ShapeCastOpFolder>(patterns.getContext(), benefit);
+               ShapeCastOpFolder, DropUnitDimsFromScfForOp>(
+      patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
index af3fc924c1dbe7..8249400a43c757 100644
--- a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
@@ -207,3 +207,78 @@ func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vect
 
 // CHECK-LABEL: func.func @negative_transpose_with_no_unit_dims
 // CHECK-NOT: vector.shape_cast
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// [Pattern: DropUnitDimsFromScfForOp]
+///----------------------------------------------------------------------------------------
+
+func.func @scf_for_with_internal_unit_dims(%vec: vector<4x1x1x[4]xf32>) -> vector<4x1x1x[4]xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %res = scf.for %i = %c0 to %c4 step %c1 iter_args(%iter = %vec) -> vector<4x1x1x[4]xf32> {
+    %s = math.sqrt %iter : vector<4x1x1x[4]xf32>
+    scf.yield %s : vector<4x1x1x[4]xf32>
+  }
+  return %res : vector<4x1x1x[4]xf32>
+}
+
+// CHECK-LABEL: func.func @scf_for_with_internal_unit_dims
+//  CHECK-SAME:   %[[VEC:[A-Za-z0-9]+]]: vector<4x1x1x[4]xf32>
+//       CHECK:   %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<4x1x1x[4]xf32> to vector<4x[4]xf32>
+//       CHECK:   %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]])
+//       CHECK:     %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<4x[4]xf32>
+//       CHECK:     scf.yield %[[SQRT]]
+//       CHECK:   %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<4x[4]xf32> to vector<4x1x1x[4]xf32>
+//       CHECK:   return %[[CASTBACK]]
+
+// -----
+
+func.func @scf_for_with_all_unit_dims(%vec: vector<1x1xf32>) -> vector<1x1xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %res = scf.for %i = %c0 to %c4 step %c1 iter_args(%iter = %vec) -> vector<1x1xf32> {
+    %s = math.sqrt %iter : vector<1x1xf32>
+    scf.yield %s : vector<1x1xf32>
+  }
+  return %res : vector<1x1xf32>
+}
+
+// CHECK-LABEL: func.func @scf_for_with_all_unit_dims
+//  CHECK-SAME:   %[[VEC:[A-Za-z0-9]+]]: vector<1x1xf32>
+//       CHECK:   %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x1xf32> to vector<1xf32>
+//       CHECK:   %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]])
+//       CHECK:     %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<1xf32>
+//       CHECK:     scf.yield %[[SQRT]]
+//       CHECK:   %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<1xf32> to vector<1x1xf32>
+//       CHECK:   return %[[CASTBACK]]
+
+// -----
+
+func.func @scf_for_with_multiple_operands(%idx: index, %vec0: vector<1x4xf32>, %vec1: vector<1x4xf32>) -> vector<1x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %res:3 = scf.for %i = %c0 to %c4 step %c1
+    iter_args(%id = %idx, %iter0 = %vec0, %iter1 = %vec1) -> (index, vector<1x4xf32>, vector<1x4xf32>) {
+    %add = arith.addf %iter0, %iter1 : vector<1x4xf32>
+    scf.yield %id, %add, %add : index, vector<1x4xf32>, vector<1x4xf32>
+  }
+  return %res#1 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: func.func @scf_for_with_multiple_operands
+//  CHECK-SAME:   %[[IDX:[A-Za-z0-9]+]]: index
+//  CHECK-SAME:   %[[VEC0:[A-Za-z0-9]+]]: vector<1x4xf32>
+//  CHECK-SAME:   %[[VEC1:[A-Za-z0-9]+]]: vector<1x4xf32>
+//   CHECK-DAG:   %[[CAST0:.+]] = vector.shape_cast %[[VEC0]] : vector<1x4xf32> to vector<4xf32>
+//   CHECK-DAG:   %[[CAST1:.+]] = vector.shape_cast %[[VEC1]] : vector<1x4xf32> to vector<4xf32>
+//       CHECK:   %[[LOOP:.+]]:3 = scf.for
+//  CHECK-SAME:     iter_args(%{{.*}} = %[[IDX]], %[[ITER0:.+]] = %[[CAST0]], %[[ITER1:.+]] = %[[CAST1]])
+//       CHECK:     %[[ADD:.+]] = arith.addf %[[ITER0]], %[[ITER1]] : vector<4xf32>
+//       CHECK:     scf.yield %{{.*}}, %[[ADD]], %[[ADD]]
+//       CHECK:   %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]]#1 : vector<4xf32> to vector<1x4xf32>
+//       CHECK:   return %[[CASTBACK]]



More information about the Mlir-commits mailing list