[Mlir-commits] [mlir] afad0cd - [mlir][vector] Refactor linalg vectorization for reductions

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 14 13:38:06 PDT 2021


Author: thomasraoux
Date: 2021-10-14T13:37:56-07:00
New Revision: afad0cdf31e8fe37fc805b36a2a57c4a8db8e5f7

URL: https://github.com/llvm/llvm-project/commit/afad0cdf31e8fe37fc805b36a2a57c4a8db8e5f7
DIFF: https://github.com/llvm/llvm-project/commit/afad0cdf31e8fe37fc805b36a2a57c4a8db8e5f7.diff

LOG: [mlir][vector] Refactor linalg vectorization for reductions

Emit the reduction during op vectorization instead of when creating the
transfer write. This allows us to avoid broadcasting the output arguments used
as the reduction's initial value.

Differential Revision: https://reviews.llvm.org/D111825

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2fecffdbbd6b..f2641e20cdf0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -189,65 +189,18 @@ static Value buildVectorRead(OpBuilder &b, Value source, Type readType,
 }
 
 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
-/// assumes that `reductionOp` has tow operands and one of them is the reduction
+/// assumes that `reductionOp` has two operands and one of them is the reduction
 /// initial value.
 static Value buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
-                                 Value outputArg,
-                                 const SmallVector<bool> &reductionMask,
-                                 const BlockAndValueMapping &bvm) {
+                                 Value valueToReduce,
+                                 const SmallVector<bool> &reductionMask) {
   auto maybeKind = getKindForOp(reduceOp);
   assert(maybeKind && "Failed precondition: could not get reduction kind");
-  Value operandToReduce = reduceOp->getOperand(0) == outputArg
-                              ? reduceOp->getOperand(1)
-                              : reduceOp->getOperand(0);
-  Value vec = bvm.lookup(operandToReduce);
-  return b.create<vector::MultiDimReductionOp>(reduceOp->getLoc(), vec,
-                                               reductionMask, *maybeKind);
-}
-
-/// Read the initial value associated to the given `outputOperand`.
-static Value readInitialValue(OpBuilder &b, LinalgOp linalgOp,
-                              OpOperand *outputOperand) {
-  AffineMap map = inversePermutation(
-      reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand)));
-  Type readType;
-  if (linalgOp.getShape(outputOperand).empty()) {
-    readType = getElementTypeOrSelf(outputOperand->get());
-  } else {
-    readType = VectorType::get(map.compose(linalgOp.getShape(outputOperand)),
-                               getElementTypeOrSelf(outputOperand->get()));
-  }
-  Value vectorRead = buildVectorRead(b, outputOperand->get(), readType, map);
-  return vectorRead;
+  return b.create<vector::MultiDimReductionOp>(
+      reduceOp->getLoc(), valueToReduce, reductionMask, *maybeKind);
 }
 
-/// Assuming `outputOperand` is an output operand of a LinalgOp, determine
-/// whether a reduction is needed to produce a `targetType` and create that
-/// reduction if it is the case.
-static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value,
-                            OpOperand *outputOperand,
-                            const BlockAndValueMapping &bvm) {
-  LDBG("Reduce " << value << " to type " << targetType);
-  LDBG("In LinalgOp operand #" << outputOperand->getOperandNumber() << "\n"
-                               << *(outputOperand->getOwner()));
-  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
-  auto vecType = value.getType().dyn_cast<VectorType>();
-  VectorType targetVectorType = targetType.dyn_cast<VectorType>();
-  if (!vecType)
-    return value;
-  if (targetVectorType && vecType.getShape() == targetVectorType.getShape())
-    return value;
-
-  // At this point, we know we need to reduce. Detect the reduction operator.
-  unsigned pos = 0;
-  MLIRContext *ctx = b.getContext();
-  SmallVector<AffineExpr> exprs;
-  for (auto s : linalgOp.iterator_types())
-    if (isParallelIterator(s))
-      exprs.push_back(getAffineDimExpr(pos++, ctx));
-
-  Operation *reduceOp = matchLinalgReduction(outputOperand);
-  assert(reduceOp && "Failed precondition: could not math a reduction");
+static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
   unsigned idx = 0;
   SmallVector<bool> reductionMask(linalgOp.iterator_types().size(), false);
   for (auto attr : linalgOp.iterator_types()) {
@@ -255,24 +208,7 @@ static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value,
       reductionMask[idx] = true;
     ++idx;
   }
-  assert(reduceOp->getNumOperands() == 2 &&
-         "Only support binary reduce op right now");
-  unsigned outputPos =
-      outputOperand->getOperandNumber() - linalgOp.getNumInputs();
-  Value outputArg = linalgOp.getRegionOutputArgs()[outputPos];
-  // Reduce across the iteration space.
-  Value reduce =
-      buildMultiDimReduce(b, reduceOp, outputArg, reductionMask, bvm);
-
-  // Read the original output value.
-  Value initialValue = readInitialValue(b, linalgOp, outputOperand);
-
-  // Combine the output argument with the reduced value.
-  OperationState state(reduceOp->getLoc(), reduceOp->getName());
-  state.addAttributes(reduceOp->getAttrs());
-  state.addOperands({reduce, initialValue});
-  state.addTypes(initialValue.getType());
-  return b.createOperation(state)->getResult(0);
+  return reductionMask;
 }
 
 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
@@ -280,8 +216,7 @@ static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value,
 /// currently being vectorized. If `dest` has null rank, build an memref.store.
 /// Return the produced value or null if no value is produced.
 static Value buildVectorWrite(OpBuilder &b, Value value,
-                              OpOperand *outputOperand,
-                              const BlockAndValueMapping &bvm) {
+                              OpOperand *outputOperand) {
   Operation *write;
   Location loc = value.getLoc();
   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
@@ -296,12 +231,9 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                                b.create<arith::ConstantIndexOp>(loc, 0));
     value = broadcastIfNeeded(b, value, vectorType.getShape());
-    value = reduceIfNeeded(b, vectorType, value, outputOperand, bvm);
     write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
                                               indices, map);
   } else {
-    value = reduceIfNeeded(b, getElementTypeOrSelf(value), value, outputOperand,
-                           bvm);
     write = vector::TransferWriteOp::createScalarOp(
         b, loc, value, outputOperand->get(), ValueRange{});
   }
@@ -336,7 +268,7 @@ vectorizeLinalgYield(OpBuilder &b, Operation *op,
     // TODO: use a map.
     Value vectorValue = bvm.lookup(outputs.value());
     Value newResult = buildVectorWrite(
-        b, vectorValue, linalgOp.getOutputOperand(outputs.index()), bvm);
+        b, vectorValue, linalgOp.getOutputOperand(outputs.index()));
     if (newResult)
       newResults.push_back(newResult);
   }
@@ -379,6 +311,36 @@ static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
   return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
 }
 
+/// Create a new vectorized version of `op` with the given operands and types.
+static Operation *createVectorizedOp(OpBuilder &b, Operation *op,
+                                     ValueRange newOperands,
+                                     ArrayRef<Type> types) {
+  OperationState state(op->getLoc(), op->getName());
+  state.addAttributes(op->getAttrs());
+  state.addOperands(newOperands);
+  state.addTypes(types);
+  return b.createOperation(state);
+}
+
+/// Reduce `reduceValue` if its vectorized shape differs from that of
+/// `initialValue`, recombining via `op`; returns nullptr if not needed.
+static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
+                                 Value reduceValue, Value initialValue,
+                                 const BlockAndValueMapping &bvm) {
+  Value reduceVec = bvm.lookup(reduceValue);
+  Value outputVec = bvm.lookup(initialValue);
+  auto reduceType = reduceVec.getType().dyn_cast<VectorType>();
+  auto outputType = outputVec.getType().dyn_cast<VectorType>();
+  // Reduce only if needed as the value may already have been reduced for
+  // contraction vectorization.
+  if (!reduceType ||
+      (outputType && reduceType.getShape() == outputType.getShape()))
+    return nullptr;
+  SmallVector<bool> reductionMask = getReductionMask(linalgOp);
+  Value reduce = buildMultiDimReduce(b, op, reduceVec, reductionMask);
+  return createVectorizedOp(b, op, {reduce, outputVec}, reduce.getType());
+}
+
 /// Generic vectorization for a single operation `op`, given already vectorized
 /// operands carried by `bvm`. Vectorization occurs as follows:
 ///   1. Try to apply any of the `customVectorizationHooks` and return its
@@ -399,7 +361,8 @@ static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
 /// This function does not update `bvm` but returns a VectorizationStatus that
 /// instructs the caller what `bvm` update needs to occur.
 static VectorizationResult
-vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm,
+vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
+               const BlockAndValueMapping &bvm,
                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
   LDBG("vectorize op " << *op);
 
@@ -422,7 +385,30 @@ vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm,
   if (!OpTrait::hasElementwiseMappableTraits(op))
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
 
-  // 4. Generic vectorization path for ElementwiseMappable ops.
+  // 4. Check if the operation is a reduction.
+  SmallVector<std::pair<Value, Value>> reductionOperands;
+  for (Value operand : op->getOperands()) {
+    auto arg = operand.dyn_cast<BlockArgument>();
+    if (!arg || arg.getArgNumber() < linalgOp.getNumInputs())
+      continue;
+    SmallVector<Operation *> reductionOps;
+    Value reduceValue = matchReduction(
+        linalgOp.getRegionOutputArgs(),
+        arg.getArgNumber() - linalgOp.getNumInputs(), reductionOps);
+    if (!reduceValue)
+      continue;
+    reductionOperands.push_back(std::make_pair(reduceValue, operand));
+  }
+  if (!reductionOperands.empty()) {
+    assert(reductionOperands.size() == 1);
+    Operation *reduceOp =
+        reduceIfNeeded(b, linalgOp, op, reductionOperands[0].first,
+                       reductionOperands[0].second, bvm);
+    if (reduceOp)
+      return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
+  }
+
+  // 5. Generic vectorization path for ElementwiseMappable ops.
   //   a. first get the first max ranked shape.
   SmallVector<int64_t, 4> firstMaxRankedShape;
   for (Value operand : op->getOperands()) {
@@ -444,12 +430,10 @@ vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm,
   });
 
   // Build and return the new op.
-  OperationState state(op->getLoc(), op->getName());
-  state.addAttributes(op->getAttrs());
-  state.addOperands(llvm::to_vector<4>(vectorizedOperands));
-  state.addTypes(llvm::to_vector<4>(returnTypes));
-  return VectorizationResult{VectorizationStatus::NewOp,
-                             b.createOperation(state)};
+  return VectorizationResult{
+      VectorizationStatus::NewOp,
+      createVectorizedOp(b, op, llvm::to_vector<4>(vectorizedOperands),
+                         llvm::to_vector<4>(returnTypes))};
 }
 
 /// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
@@ -539,7 +523,8 @@ LogicalResult vectorizeAsLinalgGeneric(
     if (linalgOp.getShape(opOperand).empty()) {
       readType = bbarg.getType();
     } else {
-      if (broadcastToMaximalCommonShape) {
+      if (broadcastToMaximalCommonShape &&
+          opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
         map = inverseAndBroadcastProjectedPermuation(
             linalgOp.getTiedIndexingMap(opOperand));
         readType = VectorType::get(commonVectorShape,
@@ -576,7 +561,7 @@ LogicalResult vectorizeAsLinalgGeneric(
 
   // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
   for (Operation &op : block->getOperations()) {
-    VectorizationResult result = vectorizeOneOp(b, &op, bvm, hooks);
+    VectorizationResult result = vectorizeOneOp(b, linalgOp, &op, bvm, hooks);
     if (result.status == VectorizationStatus::Failure) {
       LDBG("failed to vectorize: " << op);
       return failure();

diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 99c863d4b4d8..206bd7f94a5c 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -749,9 +749,9 @@ func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>)
   -> tensor<4x16xf32>
 {
   // CHECK: vector.transfer_read {{.*}} : tensor<4x16x8xf32>, vector<4x16x8xf32>
+  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32>
   // CHECK: math.exp {{.*}} : vector<4x16x8xf32>
   // CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
-  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32>
   // CHECK: addf {{.*}} : vector<4x16xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4x16xf32>, tensor<4x16xf32>
   // CHECK: return {{.*}} : tensor<4x16xf32>
@@ -782,11 +782,11 @@ func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output: ten
 {
   // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true, true], permutation_map = #[[$M1]]} : tensor<3x2xf32>, vector<2x3x4x5xf32>
   // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true, true], permutation_map = #[[$M2]]} : tensor<5x4xf32>, vector<2x3x4x5xf32>
+  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : tensor<5x2xf32>, vector<2x5xf32>
   // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
   // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
   // CHECK: addf {{.*}} : vector<2x3x4x5xf32>
   // CHECK: vector.multi_reduction #vector.kind<add>, {{.*}}  [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
-  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : tensor<5x2xf32>, vector<2x5xf32>
   // CHECK: addf {{.*}} : vector<2x5xf32>
   // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : vector<2x5xf32>, tensor<5x2xf32>
   // CHECK: return {{.*}} : tensor<5x2xf32>


        


More information about the Mlir-commits mailing list