[Mlir-commits] [mlir] 753a67b - [mlir][Linalg] Refactor and improve vectorization to add support for reduction into 0-d tensors.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Oct 12 05:47:41 PDT 2021


Author: Nicolas Vasilache
Date: 2021-10-12T12:47:36Z
New Revision: 753a67b5c98f86ddddd4326e73de600250ea3cbe

URL: https://github.com/llvm/llvm-project/commit/753a67b5c98f86ddddd4326e73de600250ea3cbe
DIFF: https://github.com/llvm/llvm-project/commit/753a67b5c98f86ddddd4326e73de600250ea3cbe.diff

LOG: [mlir][Linalg] Refactor and improve vectorization to add support for reduction into 0-d tensors.

This revision takes advantage of the recently added support for 0-d transfers and vector.multi_reduction that return a scalar.

Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D111626

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index f48ef35cdab07..cdd5fcdbc548a 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1288,7 +1288,7 @@ def Vector_TransferReadOp :
     OpBuilder<(ins "VectorType":$vector, "Value":$source,
       "ValueRange":$indices, "AffineMap":$permutationMap,
       CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
-    // Builder that sets padding to 'getMinorIdentityMap'.
+    // Builder that sets permutation map to 'getMinorIdentityMap'.
     OpBuilder<(ins "VectorType":$vector, "Value":$source,
       "ValueRange":$indices, "Value":$padding,
       CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
@@ -1306,6 +1306,17 @@ def Vector_TransferReadOp :
       "ArrayAttr":$inBounds)>
   ];
 
+  let extraClassDeclaration = [{
+    /// Temporary convenience builder to account for the fact that we do not
+    /// have 0-d vectors atm. Creates a `vector<1xt>` transfer_read of `source`
+    /// at `indices` and returns its single element extracted as a scalar.
+    // The permutation map is the constant map `() -> (0)`; padding is left to
+    // the underlying permutation-map builder (zero).
+    static Value createScalarOp(OpBuilder &builder, Location loc, Value source,
+                                ValueRange indices,
+                                ArrayRef<bool> inBounds = ArrayRef<bool>{});
+  }];
+
   let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
@@ -1416,11 +1427,12 @@ def Vector_TransferWriteOp :
   }];
 
   let builders = [
+    // Builder that sets an empty mask.
+    OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
+      "AffineMap":$permutationMap, CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
     // Builder that sets permutation map to 'getMinorIdentityMap'.
     OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
       CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
-    OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
-      "AffineMap":$permutationMap)>,
     OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
       "AffineMapAttr":$permutationMap, "ArrayAttr":$inBounds)>,
     OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
@@ -1429,6 +1441,18 @@ def Vector_TransferWriteOp :
       "AffineMap":$permutationMap, "ArrayAttr":$inBounds)>,
   ];
 
+  let extraClassDeclaration = [{
+    /// Temporary convenience builder to account for the fact that we do not
+    /// have 0-d vectors atm. Broadcasts a scalar `value` to `vector<1xt>` (a
+    /// value that is already a vector is used as is) and writes it to `dest`.
+    // The permutation map is the constant map `() -> (0)`; writes take no
+    // padding value. Returns the created vector.transfer_write op.
+    static Operation *createScalarOp(
+      OpBuilder &builder, Location loc, Value value,
+      Value dest, ValueRange indices,
+      ArrayRef<bool> inBounds = ArrayRef<bool>{});
+  }];
+
   let hasFolder = 1;
   let hasCanonicalizer = 1;
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 60a9e67e476a6..b6408135b7b7a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -40,6 +40,9 @@ using llvm::dbgs;
 
 #define DEBUG_TYPE "linalg-vectorization"
 
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG({ DBGS() << X; llvm::dbgs() << '\n'; })
+
 /// Return the unique instance of OpType in `block` if it is indeed unique.
 /// Return null if none or more than 1 instances exist.
 template <typename OpType>
@@ -106,7 +109,7 @@ struct VectorizationResult {
 /// ShapedType of `v`.
 static VectorType extractVectorTypeFromShapedValue(Value v) {
   auto st = v.getType().cast<ShapedType>();
-  if (st.isa<MemRefType>() && st.getShape().empty())
+  if (st.getShape().empty()) // Any rank-0 ShapedType: no 0-d vectors yet.
     return VectorType();
   return VectorType::get(st.getShape(), st.getElementType());
 }
@@ -163,16 +166,23 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
   return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
 }
 
-/// If value of assumed VectorType has a shape different than `shape`, build and
-/// return a new vector.broadcast to `shape`.
-/// Otherwise, just return value.
-static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
-                            Value value, OpOperand *outputOperand) {
+/// Assuming `outputOperand` is an output operand of a LinalgOp, determine
+/// whether a reduction is needed to produce a `targetType` and create that
+/// reduction if it is the case.
+static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value,
+                            OpOperand *outputOperand) {
+  LDBG("Reduce " << value << " to type " << targetType);
+  LDBG("In LinalgOp operand #" << outputOperand->getOperandNumber() << "\n"
+                               << *(outputOperand->getOwner()));
   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
   auto vecType = value.getType().dyn_cast<VectorType>();
-  if (!vecType || vecType.getShape() == targetVectorType.getShape())
+  VectorType targetVectorType = targetType.dyn_cast<VectorType>();
+  if (!vecType)
+    return value;
+  if (targetVectorType && vecType.getShape() == targetVectorType.getShape())
     return value;
 
+  // At this point, we know we need to reduce. Detect the reduction operator.
   unsigned pos = 0;
   MLIRContext *ctx = b.getContext();
   SmallVector<AffineExpr> exprs;
@@ -181,7 +191,6 @@ static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
       exprs.push_back(getAffineDimExpr(pos++, ctx));
   auto loc = value.getLoc();
 
-  // At this point, we know we need to reduce. Detect the reduction operator.
   auto maybeKind = matchLinalgReduction(outputOperand);
   assert(maybeKind && "Failed precondition: could not get reduction kind");
   unsigned idx = 0;
@@ -196,16 +205,18 @@ static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
 }
 
 /// Build a vector.transfer_read from `source` at indices set to all `0`.
-/// If source has rank zero, build an memref.load.
+/// If `readType` is not a vector, read a `vector<1xt>` and extract the scalar.
 /// Return the produced value.
-static Value buildVectorRead(OpBuilder &b, Value source, VectorType vectorType,
+static Value buildVectorRead(OpBuilder &b, Value source, Type readType,
                               AffineMap map) {
   Location loc = source.getLoc();
   auto shapedType = source.getType().cast<ShapedType>();
   SmallVector<Value> indices(shapedType.getRank(),
                              b.create<ConstantIndexOp>(loc, 0));
-  return b.create<vector::TransferReadOp>(loc, vectorType, source, indices,
-                                          map);
+  auto vecType = readType.dyn_cast<VectorType>();
+  if (!vecType)
+    return vector::TransferReadOp::createScalarOp(b, loc, source, indices);
+  return b.create<vector::TransferReadOp>(loc, vecType, source, indices, map);
 }
 
 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
@@ -216,13 +227,14 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
                               OpOperand *outputOperand) {
   Operation *write;
   Location loc = value.getLoc();
+  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
   if (VectorType vectorType =
           extractVectorTypeFromShapedValue(outputOperand->get())) {
-    auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
     AffineMap map =
         reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand));
     SmallVector<int64_t> transposeShape =
         applyPermutationMap(inversePermutation(map), vectorType.getShape());
+    assert(!transposeShape.empty() && "unexpected empty transpose shape");
     vectorType = VectorType::get(transposeShape, vectorType.getElementType());
     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                                b.create<ConstantIndexOp>(loc, 0));
@@ -231,9 +243,12 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
     write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
                                               indices, map);
   } else {
-    write = b.create<memref::StoreOp>(loc, value, outputOperand->get());
+    value =
+        reduceIfNeeded(b, getElementTypeOrSelf(value), value, outputOperand);
+    write = vector::TransferWriteOp::createScalarOp(
+        b, loc, value, outputOperand->get(), ValueRange{});
   }
-  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
+  LDBG("vectorized op: " << *write);
   if (!write->getResults().empty())
     return write->getResult(0);
   return Value();
@@ -329,7 +344,7 @@ static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
 static VectorizationResult
 vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm,
                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
-  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorize op " << *op);
+  LDBG("vectorize op " << *op);
 
   // 1. Try to apply any CustomVectorizationHook.
   if (!customVectorizationHooks.empty()) {
@@ -466,33 +481,27 @@ LogicalResult vectorizeAsLinalgGeneric(
       continue;
     }
     // TODO: 0-d vectors.
-    if (linalgOp.getShape(opOperand).empty()) {
-      Value loaded =
-          b.create<memref::LoadOp>(linalgOp.getLoc(), opOperand->get());
-      LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
-                        << bbarg.getArgNumber() << "): " << loaded);
-      bvm.map(bbarg, loaded);
-      bvm.map(opOperand->get(), loaded);
-      continue;
-    }
+    Type readType;
     AffineMap map;
-    VectorType vectorType;
-    if (broadcastToMaximalCommonShape) {
-      map = inverseAndBroadcastProjectedPermuation(
-          linalgOp.getTiedIndexingMap(opOperand));
-      vectorType = VectorType::get(commonVectorShape,
-                                   getElementTypeOrSelf(opOperand->get()));
+    if (linalgOp.getShape(opOperand).empty()) {
+      readType = bbarg.getType();
     } else {
-      map = inversePermutation(
-          reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
-      vectorType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
+      if (broadcastToMaximalCommonShape) {
+        map = inverseAndBroadcastProjectedPermuation(
+            linalgOp.getTiedIndexingMap(opOperand));
+        readType = VectorType::get(commonVectorShape,
+                                   getElementTypeOrSelf(opOperand->get()));
+      } else {
+        map = inversePermutation(
+            reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
+        readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
                                    getElementTypeOrSelf(opOperand->get()));
+      }
     }
-    Value vectorRead = buildVectorRead(b, opOperand->get(), vectorType, map);
-    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
-                      << bbarg.getArgNumber() << "): " << vectorRead);
-    bvm.map(bbarg, vectorRead);
-    bvm.map(opOperand->get(), vectorRead);
+    Value readValue = buildVectorRead(b, opOperand->get(), readType, map);
+    LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue);
+    bvm.map(bbarg, readValue);
+    bvm.map(opOperand->get(), readValue);
   }
 
   auto hooks = llvm::to_vector<4>(customVectorizationHooks);
@@ -516,12 +525,11 @@ LogicalResult vectorizeAsLinalgGeneric(
   for (Operation &op : block.getOperations()) {
     VectorizationResult result = vectorizeOneOp(b, &op, bvm, hooks);
     if (result.status == VectorizationStatus::Failure) {
-      LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
+      LDBG("failed to vectorize: " << op);
       return failure();
     }
     if (result.status == VectorizationStatus::NewOp) {
-      LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
-                        << *result.newOp;);
+      LDBG("new vector op: " << *result.newOp;);
       bvm.map(op.getResults(), result.newOp->getResults());
     }
   }
@@ -536,9 +544,9 @@ static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
   Location loc = linalgOp.getLoc();
   // Vectorize other ops as vector contraction.
   // TODO: interface.
-  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
-                    << "Rewrite linalg op as vector.contract: ";
-             linalgOp.dump());
+  LDBG(""
+           << "Rewrite linalg op as vector.contract: ";
+       linalgOp.dump());
   // Special function that describes how to vectorize the multiplication op in a
   // linalg contraction.
   CustomVectorizationHook vectorizeContraction =
@@ -592,11 +600,15 @@ static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
 // TODO: probably need some extra checks for reduction followed by consumer
 // ops that may not commute (e.g. linear reduction + non-linear instructions).
 static LogicalResult reductionPreconditions(LinalgOp op) {
-  if (llvm::none_of(op.iterator_types(), isReductionIterator))
+  if (llvm::none_of(op.iterator_types(), isReductionIterator)) {
+    LDBG("reduction precondition failed: no reduction iterator");
     return failure();
+  } // At least one reduction iterator is required.
   for (OpOperand *opOperand : op.getOutputOperands()) {
-    if (!matchLinalgReduction(opOperand))
+    if (!matchLinalgReduction(opOperand)) {
+      LDBG("reduction precondition failed: reduction detection failed");
       return failure();
+    } // Every output operand must match a supported reduction.
   }
   return success();
 }
@@ -604,8 +616,10 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   auto linalgOp = cast<linalg::LinalgOp>(op);
   // All types must be static shape to go to vector.
-  if (linalgOp.hasDynamicShape())
+  if (linalgOp.hasDynamicShape()) {
+    LDBG("precondition failed: dynamic shape");
     return failure();
+  }
   if (isElementwise(op))
     return success();
   if (isaContractionOpInterface(linalgOp))
@@ -613,10 +627,15 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   // TODO: the common vector shape is equal to the static loop sizes only when
   // all indexing maps are projected permutations. For convs and stencils the
   // logic will need to evolve.
-  if (allIndexingsAreProjectedPermutation(linalgOp) &&
-      succeeded(reductionPreconditions(linalgOp)))
-    return success();
-  return failure();
+  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
+    LDBG("precondition failed: not projected permutations");
+    return failure();
+  }
+  if (failed(reductionPreconditions(linalgOp))) {
+    LDBG("precondition failed: reduction preconditions");
+    return failure();
+  }
+  return success();
 }
 
 LogicalResult
@@ -629,10 +648,10 @@ mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
   if (isaContractionOpInterface(linalgOp))
     return vectorizeContraction(b, linalgOp, newResults);
 
-  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
-                    << "Vectorize linalg op as a generic by broadcasting to "
-                       "maximal common shape: "
-                    << *op);
+  LDBG(""
+       << "Vectorize linalg op as a generic by broadcasting to "
+          "maximal common shape: "
+       << *op);
   return vectorizeAsLinalgGeneric(b, linalgOp, newResults,
                                   /*broadcastToMaximalCommonShape=*/true);
 }
@@ -1200,9 +1219,8 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                     ValueRange values) {
   if (firstOp->getBlock() != secondOp->getBlock() ||
       !firstOp->isBeforeInBlock(secondOp)) {
-    LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
-                            << "interleavedUses precondition failed, firstOp: "
-                            << *firstOp << ", second op: " << *secondOp);
+    LDBG("interleavedUses precondition failed, firstOp: "
+         << *firstOp << ", second op: " << *secondOp);
     return true;
   }
   for (auto v : values) {
@@ -1214,10 +1232,8 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
       if (owner->getBlock() == firstOp->getBlock() &&
           (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
         continue;
-      LLVM_DEBUG(llvm::dbgs()
-                 << "\n[" DEBUG_TYPE "]: "
-                 << " found interleaved op " << *owner
-                 << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
+      LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
+                                    << ", second op: " << *secondOp);
       return true;
     }
   }
@@ -1248,15 +1264,14 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
     return failure();
 
-  LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " << viewOrAlloc);
+  LDBG(viewOrAlloc);
 
   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
   if (!subViewOp)
     return failure();
   Value subView = subViewOp.getResult();
-  LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
-                          << "with subView " << subView);
+  LDBG("with subView " << subView);
 
   // Find the copy into `subView` without interleaved uses.
   CopyOp copyOp;
@@ -1265,8 +1280,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
       assert(newCopyOp.output().getType().isa<MemRefType>());
       if (newCopyOp.output() != subView)
         continue;
-      LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
-                              << "copy candidate " << *newCopyOp);
+      LDBG("copy candidate " << *newCopyOp);
       if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
         continue;
       copyOp = newCopyOp;
@@ -1275,8 +1289,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
   }
   if (!copyOp)
     return failure();
-  LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
-                          << "with copy " << *copyOp);
+  LDBG("with copy " << *copyOp);
 
   // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
   FillOp maybeFillOp;
@@ -1285,8 +1298,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
       assert(newFillOp.output().getType().isa<MemRefType>());
       if (newFillOp.output() != viewOrAlloc)
         continue;
-      LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
-                              << "fill candidate " << *newFillOp);
+      LDBG("fill candidate " << *newFillOp);
       if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
         continue;
       maybeFillOp = newFillOp;
@@ -1297,8 +1309,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
   if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
     return failure();
   if (maybeFillOp)
-    LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
-                            << "with maybeFillOp " << *maybeFillOp);
+    LDBG("with maybeFillOp " << *maybeFillOp);
 
   // `in` is the subview that linalg.copy reads. Replace it.
   Value in = copyOp.input();

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 09d1bdc5349c2..e696f3481ed0f 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2439,6 +2439,18 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
         /*mask=*/Value(), inBounds);
 }
 
+Value TransferReadOp::createScalarOp(OpBuilder &builder, Location loc,
+                                     Value source, ValueRange indices,
+                                     ArrayRef<bool> inBounds) {
+  // Read a `vector<1xt>` through the constant map `() -> (0)`, then extract.
+  MLIRContext *ctx = loc.getContext();
+  auto elementType = source.getType().cast<ShapedType>().getElementType();
+  AffineMap map = AffineMap::get(0, 0, getAffineConstantExpr(0, ctx));
+  Value oneElt = builder.create<vector::TransferReadOp>(
+      loc, VectorType::get({1}, elementType), source, indices, map, inBounds);
+  return builder.create<vector::ExtractOp>(loc, oneElt, ArrayRef<int64_t>{0});
+}
+
 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
   SmallVector<StringRef, 3> elidedAttrs;
   elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
@@ -2769,6 +2781,16 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // TransferWriteOp
 //===----------------------------------------------------------------------===//
 
+void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
+                            Value vector, Value dest, ValueRange indices,
+                            AffineMap permutationMap, ArrayRef<bool> inBounds) {
+  // An empty `inBounds` means "no in_bounds attribute"; else materialize it.
+  ArrayAttr inBoundsAttr =
+      inBounds.empty() ? ArrayAttr() : builder.getBoolArrayAttr(inBounds);
+  build(builder, result, vector, dest, indices, permutationMap,
+        /*mask=*/Value(), inBoundsAttr);
+}
+
 /// Builder that sets permutation map to 'getMinorIdentityMap'.
 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                             Value vector, Value source, ValueRange indices,
@@ -2783,13 +2805,6 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
   build(builder, result, vector, source, indices, permMap, inBoundsArrayAttr);
 }
 
-void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
-                            Value vector, Value source, ValueRange indices,
-                            AffineMap permutationMap) {
-  build(builder, result, vector, source, indices, permutationMap,
-        /*inBounds=*/ArrayAttr());
-}
-
 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                             Value vector, Value source, ValueRange indices,
                             AffineMapAttr permutationMap,
@@ -2817,6 +2832,20 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
         mask, inBounds);
 }
 
+Operation *TransferWriteOp::createScalarOp(OpBuilder &builder, Location loc,
+                                           Value value, Value dest,
+                                           ValueRange indices,
+                                           ArrayRef<bool> inBounds) {
+  // Wrap a scalar into `vector<1xt>`; a vector value is written as is.
+  if (!value.getType().isa<VectorType>())
+    value = builder.create<vector::BroadcastOp>(
+        loc, VectorType::get({1}, value.getType()), value);
+  MLIRContext *ctx = loc.getContext();
+  AffineMap map = AffineMap::get(0, 0, getAffineConstantExpr(0, ctx));
+  return builder.create<vector::TransferWriteOp>(loc, value, dest, indices,
+                                                 map, inBounds);
+}
+
 static ParseResult parseTransferWriteOp(OpAsmParser &parser,
                                         OperationState &result) {
   auto &builder = parser.getBuilder();

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index d56214fb2fdf7..c3a4a05413deb 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -203,8 +203,9 @@ func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
 
 // CHECK-LABEL: func @test_vectorize_fill
 func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) {
-  //  CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[V:.*]]: f32)
-  //       CHECK:   store %[[V]], %[[M]][] : memref<f32>
+  // CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[val:.*]]: f32)
+  //      CHECK:   %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector<1xf32>
+  //      CHECK:   vector.transfer_write %[[VEC]], %[[M]][] {{.*}} : vector<1xf32>, memref<f32>
   linalg.fill(%arg0, %A) : f32, memref<f32>
   return
 }
@@ -223,8 +224,11 @@ func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
 
 // CHECK-LABEL: func @test_vectorize_copy_scalar
 func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
-  //       CHECK: %[[V:.*]] = memref.load {{.*}} : memref<f32>
-  //       CHECK: store %[[V]], {{.*}} : memref<f32>
+  //  CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>)
+  //       CHECK:   %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<1xf32>
+  //       CHECK:   %[[val:.*]] = vector.extract %[[V]][0] : vector<1xf32>
+  //       CHECK:   %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<1xf32>
+  //       CHECK:   vector.transfer_write %[[VV]], %[[B]][] {{.*}} : vector<1xf32>, memref<f32>
   linalg.copy(%A, %B) :  memref<f32>, memref<f32>
   return
 }
@@ -857,3 +861,42 @@ func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   return %red : tensor<4xf32>
 }
 
+// -----
+
+//  CHECK-LABEL: func @reduce_1d(
+//   CHECK-SAME:   %[[A:.*]]: tensor<32xf32>
+func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
+  //  CHECK-DAG: %[[F0_v1:.*]] = constant dense<0.000000e+00> : vector<1xf32>
+  //  CHECK-DAG: %[[F0_v32:.*]] = constant dense<0.000000e+00> : vector<32xf32>
+  //  CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  %f0 = constant 0.000000e+00 : f32  // Initial value of the accumulator.
+
+  //      CHECK: %[[init:.*]] = linalg.init_tensor [] : tensor<f32>
+  %0 = linalg.init_tensor [] : tensor<f32>
+
+  //      CHECK: %[[f:.*]] = vector.transfer_write %[[F0_v1]], %[[init]][]
+  // CHECK-SAME:   : vector<1xf32>, tensor<f32>
+  %1 = linalg.fill(%f0, %0) : f32, tensor<f32> -> tensor<f32>
+
+  //      CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
+  // CHECK-SAME:   : tensor<32xf32>, vector<32xf32>
+  //      CHECK: %[[a:.*]] = addf %[[r]], %[[F0_v32]] : vector<32xf32>
+  //      CHECK: %[[red:.*]] = vector.multi_reduction #vector.kind<add>, %[[a]] [0]
+  // CHECK-SAME:   : vector<32xf32> to f32
+  //      CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<1xf32>
+  //      CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][]
+  // CHECK-SAME:   : vector<1xf32>, tensor<f32>
+  %2 = linalg.generic {
+         indexing_maps = [affine_map<(d0) -> (d0)>,
+                          affine_map<(d0) -> ()>],
+         iterator_types = ["reduction"]}
+         ins(%arg0 : tensor<32xf32>)
+         outs(%1 : tensor<f32>) {
+    ^bb0(%a: f32, %b: f32):  // no predecessors
+      %3 = addf %a, %b : f32  // Add the input element to the 0-d output.
+      linalg.yield %3 : f32
+    } -> tensor<f32>
+
+  return %2 : tensor<f32>
+}
+


        


More information about the Mlir-commits mailing list