[Mlir-commits] [mlir] b6113db - [mlir][Linalg] Generalize linalg vectorization

Nicolas Vasilache llvmlistbot at llvm.org
Thu Apr 29 00:48:14 PDT 2021


Author: Nicolas Vasilache
Date: 2021-04-29T07:44:01Z
New Revision: b6113db955aa7783de9715adeffaf88ba12f2699

URL: https://github.com/llvm/llvm-project/commit/b6113db955aa7783de9715adeffaf88ba12f2699
DIFF: https://github.com/llvm/llvm-project/commit/b6113db955aa7783de9715adeffaf88ba12f2699.diff

LOG: [mlir][Linalg] Generalize linalg vectorization

This revision adds support for vectorizing more general linalg operations with projected permutation maps.

This is achieved by eagerly broadcasting the intermediate vector to the common size
of the iteration domain of the linalg op. This allows a much more natural expression of
generalized vectorization but may introduce additional computations until all the
proper canonicalizations are implemented.

This generalization modifies the vector.transfer_read/write permutation logic and
exposes the fact that the logic employed in vector.contract was too ad-hoc.

As a consequence, changes occur in the permutation / transposition logic for contraction. In turn this prompts supporting more cases in the lowering of contract
to matrix intrinsics, which is required to make the corresponding tests pass.

Differential revision: https://reviews.llvm.org/D101165

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/SliceAnalysis.h
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h
index 9d276c2987f8..26646858a3b9 100644
--- a/mlir/include/mlir/Analysis/SliceAnalysis.h
+++ b/mlir/include/mlir/Analysis/SliceAnalysis.h
@@ -33,7 +33,7 @@ using TransitiveFilter = llvm::function_ref<bool(Operation *)>;
 /// This additionally takes a TransitiveFilter which acts as a frontier:
 /// when looking at uses transitively, an operation that does not pass the
 /// filter is never propagated through. This allows in particular to carve out
-/// the scope within a ForInst or the scope within an IfInst.
+/// the scope within a ForOp or the scope within an IfOp.
 ///
 /// The implementation traverses the use chains in postorder traversal for
 /// efficiency reasons: if an operation is already in `forwardSlice`, no
@@ -82,7 +82,7 @@ void getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
 /// This additionally takes a TransitiveFilter which acts as a frontier:
 /// when looking at defs transitively, an operation that does not pass the
 /// filter is never propagated through. This allows in particular to carve out
-/// the scope within a ForInst or the scope within an IfInst.
+/// the scope within a ForOp or the scope within an IfOp.
 ///
 /// The implementation traverses the def chains in postorder traversal for
 /// efficiency reasons: if an operation is already in `backwardSlice`, no

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 70dbdf8d0558..d08c73ff2f4d 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -286,6 +286,70 @@ def Vector_ReductionOp :
   }];
 }
 
+def Vector_MultiDimReductionOp :
+  Vector_Op<"multi_reduction", [NoSideEffect,
+     PredOpTrait<"source operand and result have same element type",
+                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
+    Arguments<(ins Vector_CombiningKindAttr:$kind,
+                   AnyVector:$source,
+                   I64ArrayAttr:$reduction_dims)>,
+    Results<(outs AnyType:$dest)> {
+  let summary = "Multi-dimensional reduction operation";
+  let description = [{
+    Reduces an n-D vector into an (n-k)-D vector using the given operation
+    (add/mul/min/max for int/fp and and/or/xor for int only).
+
+    Example:
+
+    ```mlir
+    %1 = vector.multi_reduction #vector.kind<add>, %0 [1, 3] :
+      vector<4x8x16x32xf32> to vector<4x16xf32>
+    ```
+  }];
+  let builders = [
+    OpBuilder<(ins "Value":$source, "ArrayRef<bool>":$reductionMask,
+                   "CombiningKind":$kind)>
+  ];
+  let extraClassDeclaration = [{
+    static StringRef getKindAttrName() { return "kind"; }
+    static StringRef getReductionDimsAttrName() { return "reduction_dims"; }
+
+    VectorType getSourceVectorType() {
+      return source().getType().cast<VectorType>();
+    }
+    VectorType getDestVectorType() {
+      return dest().getType().cast<VectorType>();
+    }
+
+    SmallVector<bool> getReductionMask() {
+      SmallVector<bool> res(getSourceVectorType().getRank(), false);
+      for (auto ia : reduction_dims().getAsRange<IntegerAttr>())
+        res[ia.getInt()] = true;
+      return res;
+    }
+    static SmallVector<bool> getReductionMask(
+        ArrayRef<int64_t> reductionDims, unsigned sourceRank) {
+      SmallVector<bool> res(sourceRank, false);
+      for (auto idx : reductionDims)
+        res[idx] = true;
+      return res;
+    }
+
+    static SmallVector<int64_t> inferDestShape(
+      ArrayRef<int64_t> shape, ArrayRef<bool> reducedDimsMask) {
+      assert(shape.size() == reducedDimsMask.size() &&
+             "shape and mask of different sizes");
+      SmallVector<int64_t> res;
+      for (auto it : llvm::zip(reducedDimsMask, shape))
+        if (!std::get<0>(it))
+          res.push_back(std::get<1>(it));
+      return res;
+    }
+  }];
+  let assemblyFormat =
+    "$kind `,` $source attr-dict $reduction_dims `:` type($source) `to` type($dest)";
+}
+
 def Vector_BroadcastOp :
   Vector_Op<"broadcast", [NoSideEffect,
      PredOpTrait<"source operand and result have same element type",
@@ -1317,6 +1381,18 @@ def Vector_TransferReadOp :
       "ArrayAttr":$inBounds)>
   ];
 
+  let extraClassDeclaration = [{
+    /// Return `map` with constant `0` results inserted in the proper positions
+    /// so that a vector.transfer_read using it produces a vector with the
+    /// element type of `vt` and shape `targetShape`.
+    /// Assumes that `map` is the permutation map of a vector.transfer_read op,
+    /// `vt` is the vector type produced by that op and `targetShape` is the
+    /// desired shape of a broadcast version of `vt`. `map` is returned as is
+    /// when `vt` already has the rank of `targetShape`.
+    static AffineMap insertBroadcasts(AffineMap map, VectorType vt,
+                                      ArrayRef<int64_t> targetShape);
+  }];
+
   let hasFolder = 1;
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 14ef418ed591..c4afed4d71f2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -25,6 +26,7 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <type_traits>
@@ -40,7 +42,8 @@ using llvm::dbgs;
 
 /// Return the unique instance of OpType in `block` if it is indeed unique.
 /// Return null if none or more than 1 instances exist.
-template <typename OpType> static OpType getSingleOpOfType(Block &block) {
+template <typename OpType>
+static OpType getSingleOpOfType(Block &block) {
   OpType res;
   block.walk([&](OpType op) {
     if (res) {
@@ -53,6 +56,31 @@ template <typename OpType> static OpType getSingleOpOfType(Block &block) {
   return res;
 }
 
+/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
+/// projectedPermutation, compress the unused dimensions to serve as a
+/// permutation_map for a vector transfer operation.
+/// For example, given a linalg op such as:
+///
+/// ```
+///   %r = linalg.generic {
+///        indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
+///                         affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>]
+///      }
+///     ins(%0 : tensor<2x3x4xf32>)
+///    outs(%1 : tensor<5x6xf32>)
+/// ```
+///
+/// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
+/// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
+/// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
+static AffineMap reindexIndexingMap(AffineMap map) {
+  assert(map.isProjectedPermutation() && "expected projected permutation");
+  auto res = compressUnusedDims(map);
+  assert(res.getNumDims() == res.getNumResults() &&
+         "expected reindexed map with same number of dims and results");
+  return res;
+}
+
 /// Helper data structure to represent the result of vectorization.
 /// In certain specific cases, like terminators, we do not want to propagate/
 enum VectorizationStatus {
@@ -83,6 +111,116 @@ static VectorType extractVectorTypeFromShapedValue(Value v) {
   return VectorType::get(st.getShape(), st.getElementType());
 }
 
+/// Given an `outputOperand` of a LinalgOp, compute the intersection of the
+/// forward slice starting from the corresponding region bbArg and the backward
+/// slice starting from the corresponding linalg.yield operand.
+/// This intersection is expected to contain a single elementwise-mappable op,
+/// assumed to be the binary reduction operation; return it, or nullptr if
+/// there is none or more than one. Multiple reduction operations would impose
+/// an ordering between reduction dimensions, which is currently unsupported in
+/// Linalg (e.g. min(max(X)) != max(min(X))).
+// TODO: use in LinalgOp verification, there is a circular dependency atm.
+static Operation *getSingleBinaryOpAssumedReduction(OpOperand &outputOperand) {
+  auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
+  auto yieldOp = cast<YieldOp>(linalgOp->getRegion(0).front().getTerminator());
+  unsigned yieldNum =
+      outputOperand.getOperandNumber() - linalgOp.getNumInputs();
+  llvm::SetVector<Operation *> backwardSlice, forwardSlice;
+  BlockArgument bbArg = linalgOp->getRegion(0).front().getArgument(
+      outputOperand.getOperandNumber());
+  Value yieldVal = yieldOp->getOperand(yieldNum);
+  getBackwardSlice(yieldVal, &backwardSlice, [&](Operation *op) {
+    return op->getParentOp() == linalgOp;
+  });
+  backwardSlice.insert(yieldVal.getDefiningOp());
+  getForwardSlice(bbArg, &forwardSlice,
+                  [&](Operation *op) { return op->getParentOp() == linalgOp; });
+  // Search for the (assumed unique) elementwise-mappable op at the intersection
+  // of forward and backward slices; a null entry is never in `forwardSlice`.
+  Operation *reductionOp = nullptr;
+  for (Operation *op : llvm::reverse(backwardSlice)) {
+    if (!forwardSlice.contains(op))
+      continue;
+    if (OpTrait::hasElementwiseMappableTraits(op)) {
+      if (reductionOp) {
+        // Reduction detection fails: found more than 1 elementwise-mappable op.
+        return nullptr;
+      }
+      reductionOp = op;
+    }
+  }
+  // TODO: also assert no other subsequent ops break the reduction.
+  return reductionOp;
+}
+
+/// If `value` of assumed VectorType has a shape different than `shape`, try to
+/// build and return a new vector.broadcast to `shape`. `value` is returned
+/// unchanged when `shape` is empty, when it already has shape `shape`, or when
+/// its rank exceeds the number of non-unit dimensions of `shape`.
+// TODO: this is best effort atm and there is currently no guarantee of
+// correctness for the broadcast semantics.
+static Value broadcastIfNeeded(OpBuilder &builder, Value value,
+                               ArrayRef<int64_t> shape) {
+  unsigned numDimsGtOne = std::count_if(shape.begin(), shape.end(),
+                                        [](int64_t val) { return val > 1; });
+  auto vecType = value.getType().dyn_cast<VectorType>();
+  if (shape.empty() ||
+      (vecType != nullptr &&
+       (vecType.getShape() == shape || vecType.getRank() > numDimsGtOne)))
+    return value;
+  auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
+                                                   : value.getType());
+  return builder.create<vector::BroadcastOp>(value.getLoc(), newVecType, value);
+}
+
+static llvm::Optional<vector::CombiningKind>
+getKindForOp(Operation *reductionOp) {
+  // A null `reductionOp` (e.g. no unique reduction candidate was found) has no
+  // combining kind.
+  if (reductionOp == nullptr)
+    return llvm::None;
+  // Integer and floating-point additions both combine with `ADD`; no other op
+  // is recognized yet.
+  if (isa<AddIOp, AddFOp>(reductionOp))
+    return vector::CombiningKind::ADD;
+  return llvm::None;
+}
+
+/// If `value` of assumed VectorType has a shape different than
+/// `targetVectorType` (which must match the shape of `outputOperand`), build
+/// and return a vector.multi_reduction of `value` along the reduction iterators
+/// of the LinalgOp owning `outputOperand`, using the combining kind detected in
+/// its region. Otherwise (including when `value` is not a vector), return it.
+static Value reduceIfNeeded(OpBuilder &builder, VectorType targetVectorType,
+                            Value value, OpOperand &outputOperand) {
+  assert(targetVectorType.getShape() ==
+         outputOperand.get().getType().cast<ShapedType>().getShape());
+  auto vecType = value.getType().dyn_cast<VectorType>();
+  if (!vecType || vecType.getShape() == targetVectorType.getShape())
+    return value;
+  // At this point, we know we need to reduce. Detect the reduction operator.
+  // TODO: Use the generic reduction detection util.
+  Operation *reductionOp = getSingleBinaryOpAssumedReduction(outputOperand);
+  auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
+  auto loc = value.getLoc();
+  // TODO: reuse common CombiningKind logic and support more than add.
+  // vectorizeLinalgOpPrecondition (via reductionPreconditions) is expected to
+  // have rejected ops whose reduction kind cannot be detected.
+  auto maybeKind = getKindForOp(reductionOp);
+  assert(maybeKind && "Failed precondition: could not get reduction kind");
+  // `value` is assumed to span the iteration space in loop order (eager
+  // broadcast): mark every reduction iterator as a dimension to reduce.
+  unsigned idx = 0;
+  SmallVector<bool> reductionMask(linalgOp.iterator_types().size(), false);
+  for (auto attr : linalgOp.iterator_types()) {
+    if (isReductionIteratorType(attr))
+      reductionMask[idx] = true;
+    ++idx;
+  }
+  return builder.create<vector::MultiDimReductionOp>(loc, value, reductionMask,
+                                                     *maybeKind);
+}
+
 /// Build a vector.transfer_read from `source` at indices set to all `0`.
 /// If source has rank zero, build an memref.load.
 /// Return the produced value.
@@ -90,29 +228,30 @@ static Value buildVectorRead(OpBuilder &builder, Value source,
                              VectorType vectorType, AffineMap map) {
   edsc::ScopedContext scope(builder);
   auto shapedType = source.getType().cast<ShapedType>();
-  if (vectorType) {
-    SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
-    if (map)
-      return vector_transfer_read(vectorType, source, indices, map);
-    return vector_transfer_read(vectorType, source, indices);
-  }
-  return memref_load(source);
+  SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
+  return vector_transfer_read(vectorType, source, indices, map);
 }
 
-/// Build a vector.transfer_write of `value` into `dest` at indices set to all
-/// `0`. If `dest` has null rank, build an memref.store.
+/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
+/// to all `0`; where `outputOperand` is an output operand of the LinalgOp
+/// currently being vectorized. For rank-0 outputs, build a memref.store.
 /// Return the produced value or null if no value is produced.
-static Value buildVectorWrite(OpBuilder &builder, Value value, Value dest) {
+static Value buildVectorWrite(OpBuilder &builder, Value value,
+                              OpOperand &outputOperand) {
   edsc::ScopedContext scope(builder);
   Operation *write;
-  auto shapedType = dest.getType().cast<ShapedType>();
-  if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
+  auto shapedType = outputOperand.get().getType().cast<ShapedType>();
+  if (VectorType vectorType =
+          extractVectorTypeFromShapedValue(outputOperand.get())) {
+    auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
+    AffineMap map = reindexIndexingMap(
+        linalgOp.getIndexingMap(outputOperand.getOperandNumber()));
     SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
-    if (vectorType != value.getType())
-      value = vector_broadcast(vectorType, value);
-    write = vector_transfer_write(value, dest, indices);
+    value = broadcastIfNeeded(builder, value, vectorType.getShape());
+    value = reduceIfNeeded(builder, vectorType, value, outputOperand);
+    write = vector_transfer_write(value, outputOperand.get(), indices, map);
   } else {
-    write = memref_store(value, dest);
+    write = memref_store(value, outputOperand.get());
   }
   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
   if (!write->getResults().empty())
@@ -120,20 +259,6 @@ static Value buildVectorWrite(OpBuilder &builder, Value value, Value dest) {
   return Value();
 }
 
-/// If value of assumed VectorType has a shape different than `shape`, buil and
-/// return a new vector.broadcast to `shape`.
-/// Otherwise, just return value.
-static Value broadcastIfNeeded(OpBuilder &builder, Value value,
-                               ArrayRef<int64_t> shape) {
-  auto vecType = value.getType().dyn_cast<VectorType>();
-  if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape))
-    return value;
-  auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
-                                                   : value.getType());
-  return builder.create<vector::BroadcastOp>(
-      builder.getInsertionPoint()->getLoc(), newVecType, value);
-}
-
 // Custom vectorization function type. Produce a vector form of Operation*
 // assuming all its vectorized operands are already in the BlockAndValueMapping.
 // Return nullptr if the Operation cannot be vectorized.
@@ -158,8 +283,8 @@ vectorizeLinalgYield(OpBuilder &builder, Operation *op,
     // TODO: Scan for an opportunity for reuse.
     // TODO: use a map.
     Value vectorValue = bvm.lookup(outputs.value());
-    Value newResult = buildVectorWrite(builder, vectorValue,
-                                       linalgOp.getOutput(outputs.index()));
+    Value newResult = buildVectorWrite(
+        builder, vectorValue, linalgOp.getOutputOpOperands()[outputs.index()]);
     if (newResult)
       newResults.push_back(newResult);
   }
@@ -307,20 +432,6 @@ static bool isElementwise(Operation *op) {
   return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
 }
 
-// Calculate the map to apply to transfer_read to convert the input shape into
-// the output shape.
-static AffineMap getTransferReadMap(LinalgOp linalgOp, unsigned argIndex) {
-  AffineMap linalgMap = linalgOp.getIndexingMap(argIndex);
-  MLIRContext *context = linalgMap.getContext();
-  AffineExpr zero = mlir::getAffineConstantExpr(0, context);
-  SmallVector<AffineExpr, 4> exprs(linalgMap.getNumInputs(), zero);
-  for (unsigned i : llvm::seq(unsigned(0), linalgMap.getNumResults())) {
-    exprs[linalgMap.getDimPosition(i)] = getAffineDimExpr(i, context);
-  }
-  return AffineMap::get(linalgMap.getNumResults(), /*symbolCount=*/0, exprs,
-                        context);
-}
-
 /// Generic vectorization function that rewrites the body of a `linalgOp` into
 /// vector form. Generic vectorization proceeds as follows:
 ///   1. Verify the `linalgOp` has one non-empty region.
@@ -333,42 +444,70 @@ static AffineMap getTransferReadMap(LinalgOp linalgOp, unsigned argIndex) {
 ///   4b. Register CustomVectorizationHook for IndexOp to access the iteration
 ///   indices.
 ///   5. Iteratively call vectorizeOneOp on the region operations.
+///
+/// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
+/// performed to the maximal common vector size implied by the `linalgOp`
+/// iteration space. This eager broadcasting is introduced in the
+/// permutation_map of the vector.transfer_read operations. The eager
+/// broadcasting makes it trivial to determine where broadcasts, transposes and
+/// reductions should occur, without any bookkeeping. The tradeoff is that, in
+/// the absence of good canonicalizations, the amount of work increases.
+/// This is not deemed a problem as we expect canonicalizations and foldings to
+/// aggressively clean up the useless work.
 LogicalResult vectorizeAsLinalgGeneric(
     OpBuilder &builder, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
+    bool broadcastToMaximalCommonShape = false,
     ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
   // 1. Fail to vectorize if the operation does not have one non-empty region.
   if (linalgOp->getNumRegions() != 1 || linalgOp->getRegion(0).empty())
     return failure();
   auto &block = linalgOp->getRegion(0).front();
 
-  BlockAndValueMapping bvm;
   // 2. Values defined above the region can only be broadcast for now. Make them
   // map to themselves.
+  BlockAndValueMapping bvm;
   SetVector<Value> valuesSet;
   mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
   bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
 
+  if (linalgOp.getNumOutputs() == 0)
+    return failure();
+
+  // TODO: the common vector shape is equal to the static loop sizes only when
+  // all indexing maps are projected permutations. For convs and stencils the
+  // logic will need to evolve.
+  SmallVector<int64_t> commonVectorShape = linalgOp.computeStaticLoopSizes();
+
   // 3. Turn all BBArgs into vector.transfer_read / load.
   SmallVector<AffineMap> indexings;
   for (auto bbarg : block.getArguments()) {
-    Value vectorArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
-    AffineMap map;
-    VectorType vectorType = extractVectorTypeFromShapedValue(vectorArg);
-    if (isElementwise(linalgOp) &&
-        !linalgOp.getIndexingMap(bbarg.getArgNumber()).isMinorIdentity()) {
-      // Currently assume we don't support output permutations.
-      assert(linalgOp.getNumOutputs() > 0 &&
-             linalgOp.getOutputIndexingMap(0).isIdentity());
-      ArrayRef<int64_t> outputShape =
-          linalgOp.getOutputShapedType(0).getShape();
-      vectorType = VectorType::get(outputShape, vectorType.getElementType());
-      map = getTransferReadMap(linalgOp, bbarg.getArgNumber());
+    Value shapedArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
+    ShapedType shapedType = shapedArg.getType().cast<ShapedType>();
+    // TODO: 0-d vectors.
+    if (shapedType.getShape().empty()) {
+      Value loaded =
+          builder.create<memref::LoadOp>(linalgOp.getLoc(), shapedArg);
+      LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
+                        << bbarg.getArgNumber() << "): " << loaded);
+      bvm.map(bbarg, loaded);
+      bvm.map(shapedArg, loaded);
+      continue;
+    }
+    AffineMap map = inversePermutation(
+        reindexIndexingMap(linalgOp.getIndexingMap(bbarg.getArgNumber())));
+    VectorType vectorType = VectorType::get(map.compose(shapedType.getShape()),
+                                            shapedType.getElementType());
+    if (broadcastToMaximalCommonShape) {
+      map = vector::TransferReadOp::insertBroadcasts(map, vectorType,
+                                                     commonVectorShape);
+      vectorType =
+          VectorType::get(commonVectorShape, vectorType.getElementType());
     }
-    Value vectorRead = buildVectorRead(builder, vectorArg, vectorType, map);
+    Value vectorRead = buildVectorRead(builder, shapedArg, vectorType, map);
     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
                       << bbarg.getArgNumber() << "): " << vectorRead);
     bvm.map(bbarg, vectorRead);
-    bvm.map(vectorArg, vectorRead);
+    bvm.map(shapedArg, vectorRead);
   }
 
   auto hooks = llvm::to_vector<4>(customVectorizationHooks);
@@ -428,15 +567,48 @@ static LogicalResult vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp,
                      : VectorType::get(outShape, op->getResult(0).getType());
     auto zero =
         builder.create<ConstantOp>(loc, vType, builder.getZeroAttr(vType));
+    // Indexing maps at the time of vector.transfer_read are adjusted to order
+    // vector dimensions in the same order as the canonical linalg op iteration
+    // space order.
+    // The indexings for the contraction therefore need to be adjusted.
+    // TODO: consider dropping contraction special casing altogether, this will
+    // require more advanced canonicalizations involving vector.multi_reduction
+    // that are not yet available.
+    SmallVector<AffineMap> indexingMaps{
+        inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(0)))
+            .compose(linalgOp.getIndexingMap(0)),
+        inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(1)))
+            .compose(linalgOp.getIndexingMap(1)),
+        inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(2)))
+            .compose(linalgOp.getIndexingMap(2))};
     Operation *contract = builder.create<vector::ContractionOp>(
         loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
-        linalgOp.indexing_maps(), linalgOp.iterator_types());
+        builder.getAffineMapArrayAttr(indexingMaps), linalgOp.iterator_types());
     return VectorizationResult{VectorizationStatus::NewOp, contract};
   };
   return vectorizeAsLinalgGeneric(builder, linalgOp, newResults,
+                                  /*broadcastToMaximalCommonShape=*/false,
                                   {vectorizeContraction});
 }
 
+/// Return true if every indexing map of `op` is a projected permutation.
+static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
+  auto isProjPerm = [](AffineMap map) { return map.isProjectedPermutation(); };
+  return llvm::all_of(op.getIndexingMaps(), isProjPerm);
+}
+
+/// Succeed iff `op` has at least one reduction iterator and a combining kind
+/// can be detected (see getKindForOp) for every one of its output operands.
+// TODO: probably need some extra checks for reduction followed by consumer
+// ops that may not commute (e.g. linear reduction + non-linear instructions).
+static LogicalResult reductionPreconditions(LinalgOp op) {
+  if (llvm::none_of(op.iterator_types(), isReductionIteratorType))
+    return failure();
+  return success(llvm::all_of(op.getOutputOpOperands(), [](OpOperand &operand) {
+    return getKindForOp(getSingleBinaryOpAssumedReduction(operand)).hasValue();
+  }));
+}
+
 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   auto linalgOp = cast<linalg::LinalgOp>(op);
   // All types must be static shape to go to vector.
@@ -448,7 +620,15 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
       return failure();
   if (isElementwise(op))
     return success();
-  return success(isaContractionOpInterface(linalgOp));
+  if (isaContractionOpInterface(linalgOp))
+    return success();
+  // TODO: the common vector shape is equal to the static loop sizes only when
+  // all indexing maps are projected permutations. For convs and stencils the
+  // logic will need to evolve.
+  if (allIndexingsAreProjectedPermutation(linalgOp) &&
+      succeeded(reductionPreconditions(linalgOp)))
+    return success();
+  return failure();
 }
 
 LogicalResult
@@ -458,13 +638,17 @@ mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op,
     return failure();
 
   edsc::ScopedContext scope(builder, op->getLoc());
-  if (isElementwise(op)) {
-    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
-                      << "Vectorize linalg op as a generic: " << *op);
-    return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op), newResults);
-  }
+  auto linalgOp = cast<LinalgOp>(op);
+
+  if (isaContractionOpInterface(linalgOp))
+    return vectorizeContraction(builder, linalgOp, newResults);
 
-  return vectorizeContraction(builder, cast<LinalgOp>(op), newResults);
+  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
+                    << "Vectorize linalg op as a generic by broadcasting to "
+                       "maximal common shape: "
+                    << *op);
+  return vectorizeAsLinalgGeneric(builder, linalgOp, newResults,
+                                  /*broadcastToMaximalCommonShape=*/true);
 }
 
 //----------------------------------------------------------------------------//

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index d65481955155..9fd9e1e40866 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -231,6 +231,45 @@ ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
   return builder.getI64ArrayAttr(values);
 }
 
+//===----------------------------------------------------------------------===//
+// MultiDimReductionOp
+//===----------------------------------------------------------------------===//
+
+void vector::MultiDimReductionOp::build(OpBuilder &builder,
+                                        OperationState &result, Value source,
+                                        ArrayRef<bool> reductionMask,
+                                        CombiningKind kind) {
+  // The result keeps the element type of `source` and drops every dimension
+  // whose `reductionMask` entry is set.
+  auto srcType = source.getType().cast<VectorType>();
+  SmallVector<int64_t> destShape = MultiDimReductionOp::inferDestShape(
+      srcType.getShape(), reductionMask);
+  // The reduced dimensions are recorded as their positions in `source`.
+  SmallVector<int64_t> reducedDims;
+  for (unsigned dim = 0, e = reductionMask.size(); dim < e; ++dim)
+    if (reductionMask[dim])
+      reducedDims.push_back(dim);
+  result.addOperands(source);
+  result.addTypes(VectorType::get(destShape, srcType.getElementType()));
+  result.addAttribute(getReductionDimsAttrName(),
+                      builder.getI64ArrayAttr(reducedDims));
+  result.addAttribute(getKindAttrName(),
+                      CombiningKindAttr::get(kind, builder.getContext()));
+}
+
+static LogicalResult verify(MultiDimReductionOp op) {
+  VectorType srcType = op.getSourceVectorType();
+  for (auto dim : op.reduction_dims().getAsRange<IntegerAttr>())
+    if (dim.getInt() < 0 || dim.getInt() >= srcType.getRank())
+      return op.emitError("reduction dimension out of range: ") << dim.getInt();
+  auto expected = VectorType::get(MultiDimReductionOp::inferDestShape(
+      srcType.getShape(), op.getReductionMask()), srcType.getElementType());
+  if (expected != op.dest().getType())
+    return op.emitError("invalid output vector type: ")
+           << op.dest().getType() << " (expected: " << expected << ")";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ReductionOp
 //===----------------------------------------------------------------------===//
@@ -2160,6 +2199,29 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
 // TransferReadOp
 //===----------------------------------------------------------------------===//
 
+AffineMap TransferReadOp::insertBroadcasts(AffineMap map, VectorType vt,
+                                           ArrayRef<int64_t> targetShape) {
+  ArrayRef<int64_t> srcShape = vt.getShape();
+  assert(srcShape.size() <= targetShape.size() && "mismatching ranks");
+  // Equal ranks: there is nothing to broadcast.
+  if (srcShape.size() == targetShape.size())
+    return map;
+  // Walk `targetShape` greedily: a target dim whose size equals the next
+  // not-yet-consumed dim of `vt` reuses that dim's result expression, every
+  // other target dim becomes a broadcast (constant `0` result). Equal sizes
+  // are matched first-come-first-served.
+  MLIRContext *ctx = map.getContext();
+  AffineExpr zero = getAffineConstantExpr(0, ctx);
+  SmallVector<AffineExpr> exprs;
+  exprs.reserve(targetShape.size());
+  unsigned srcPos = 0;
+  for (int64_t targetSize : targetShape) {
+    bool reuse = srcPos < srcShape.size() && srcShape[srcPos] == targetSize;
+    exprs.push_back(reuse ? map.getResult(srcPos++) : zero);
+  }
+  return AffineMap::get(map.getNumDims(), /*numSymbols=*/0, exprs, ctx);
+}
+
 template <typename EmitFun>
 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                           EmitFun emitOpError) {

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 2af0b361db4d..58e2c1d3a83d 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1752,18 +1752,23 @@ namespace mlir {
 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
 /// semantics to:
 /// ```
-///    %flattened_a = vector.shape_cast %a
-///    %flattened_b = vector.shape_cast %b
+///    %mta = maybe_transpose
+///    %mtb = maybe_transpose
+///    %flattened_a = vector.shape_cast %mta
+///    %flattened_b = vector.shape_cast %mtb
 ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
-///    %d = vector.shape_cast %%flattened_d
+///    %mtd = vector.shape_cast %flattened_d
+///    %d = maybe_untranspose %mtd
 ///    %e = add %c, %d
 /// ```
 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
 //
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
-/// the vector.contract op is a row-major matrix multiply.
-LogicalResult ContractionOpToMatmulOpLowering::matchAndRewrite(
-    vector::ContractionOp op, PatternRewriter &rewriter) const {
+/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
+/// vector.transpose operations are inserted if the vector.contract op is not a
+/// row-major matrix multiply.
+LogicalResult
+ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
+                                                 PatternRewriter &rew) const {
   // TODO: implement masks
   if (llvm::size(op.masks()) != 0)
     return failure();
@@ -1779,37 +1784,77 @@ LogicalResult ContractionOpToMatmulOpLowering::matchAndRewrite(
       !isReductionIterator(iteratorTypes[2]))
     return failure();
 
-  if (!isRowMajorMatmul(op.indexing_maps()))
-    return failure();
-
   Type elementType = op.getLhsType().getElementType();
   if (!elementType.isIntOrFloat())
     return failure();
 
-  VectorType lhsType = op.getLhsType();
-  VectorType rhsType = op.getRhsType();
+  // Perform lhs + rhs (+ acc) transpositions to conform to matmul row-major
+  // semantics. Bail out if the contraction cannot be put in this form.
+  // All indexing maps are validated before any IR is created, since a
+  // rewrite pattern must not modify the IR and then return failure().
+  MLIRContext *ctx = op.getContext();
+  Location loc = op.getLoc();
+  AffineExpr m, n, k;
+  bindDims(rew.getContext(), m, n, k);
+  auto lhsMap = op.indexing_maps()[0].cast<AffineMapAttr>().getValue();
+  auto rhsMap = op.indexing_maps()[1].cast<AffineMapAttr>().getValue();
+  auto accMap = op.indexing_maps()[2].cast<AffineMapAttr>().getValue();
+
+  // LHS must be A(m, k) or A(k, m).
+  bool transposeLhs = lhsMap == AffineMap::get(3, 0, {k, m}, ctx);
+  if (!transposeLhs && lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
+    return failure();
+
+  // RHS must be B(k, n) or B(n, k).
+  bool transposeRhs = rhsMap == AffineMap::get(3, 0, {n, k}, ctx);
+  if (!transposeRhs && rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
+    return failure();
+
+  // ACC must be C(m, n) or C(n, m).
+  bool transposeAcc = accMap == AffineMap::get(3, 0, {n, m}, ctx);
+  if (!transposeAcc && accMap != AffineMap::get(3, 0, {m, n}, ctx))
+    return failure();
+
+  // All checks passed: IR creation starts here and the pattern succeeds.
+  Value lhs = op.lhs();
+  if (transposeLhs)
+    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
+  Value rhs = op.rhs();
+  if (transposeRhs)
+    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
+
+  // At this point lhs and rhs are in row-major.
+  VectorType lhsType = lhs.getType().cast<VectorType>();
+  VectorType rhsType = rhs.getType().cast<VectorType>();
   int64_t lhsRows = lhsType.getDimSize(0);
   int64_t lhsColumns = lhsType.getDimSize(1);
   int64_t rhsColumns = rhsType.getDimSize(1);
 
   Type flattenedLHSType =
       VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
+  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
+
   Type flattenedRHSType =
       VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
-  auto lhs = rewriter.create<vector::ShapeCastOp>(op.getLoc(), flattenedLHSType,
-                                                  op.lhs());
-  auto rhs = rewriter.create<vector::ShapeCastOp>(op.getLoc(), flattenedRHSType,
-                                                  op.rhs());
-
-  Value mul = rewriter.create<vector::MatmulOp>(op.getLoc(), lhs, rhs, lhsRows,
-                                                lhsColumns, rhsColumns);
-  mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(), op.acc().getType(),
-                                             mul);
-  if (elementType.isa<IntegerType>())
-    rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
-  else
-    rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
-
+  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
+
+  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
+                                           rhsColumns);
+  mul = rew.create<vector::ShapeCastOp>(
+      loc,
+      VectorType::get({lhsRows, rhsColumns},
+                      getElementTypeOrSelf(op.acc().getType())),
+      mul);
+
+  // Bring the (m, n) product into the layout expected by the accumulator.
+  if (transposeAcc)
+    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
+
+  Value res = elementType.isa<IntegerType>()
+                  ? static_cast<Value>(rew.create<AddIOp>(loc, op.acc(), mul))
+                  : static_cast<Value>(rew.create<AddFOp>(loc, op.acc(), mul));
+
+  rew.replaceOp(op, res);
   return success();
 }
 

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
index dbdf19341920..2a8a2cc6fb88 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -22,6 +22,6 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
 //
 //      CHECK: vector.contract
 // CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"]
-// CHECK-SAME:   : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
+// CHECK-SAME:   : vector<8x16xf32>, vector<12x16xf32> into vector<8x12xf32>
 //
 //      CHECK: linalg.copy

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index c18bf5b5cd8b..3eafc5acd6f5 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -58,18 +58,19 @@ func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 
+// CHECK-DAG: #[[$trans_2d:.*]] =  affine_map<(d0, d1) -> (d1, d0)>
 // CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$nk:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
 // CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
 // CHECK-LABEL: func @vectorization_test
 func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<8x32xf32>) {
   //       CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
-  //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32>
+  //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<32x16xf32>
   //       CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
-  //       CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]]
-  //  CHECK-SAME:   vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
+  //       CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$nk]], #[[$mn]]]
+  //  CHECK-SAME:   vector<8x16xf32>, vector<32x16xf32> into vector<8x32xf32>
   //       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
   linalg.generic #matmul_trait
     ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
@@ -95,18 +96,19 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 
+// CHECK-DAG: #[[$trans_2d:.*]] =  affine_map<(d0, d1) -> (d1, d0)>
 // CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$nk:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
 // CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
 // CHECK-LABEL: func @vectorization_test_integer
 func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
                                  %C: memref<8x32xi32>) {
   //       CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32>
-  //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32>
+  //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<32x16xi32>
   //       CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
-  //       CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]],
-  //  CHECK-SAME:   vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32>
+  //       CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$nk]], #[[$mn]]],
+  //  CHECK-SAME:   vector<8x16xi32>, vector<32x16xi32> into vector<8x32xi32>
   //       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
   linalg.generic #matmul_trait
     ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>)
@@ -252,13 +254,12 @@ func @generic_vectorize(%arg0: memref<4x256xf32>,
     memref<4x256xf32>, memref<4x256xf32>) {
   ^bb0(%arg3 : f32, %arg4 : f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32,
   //       CHECK:   %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
-  //       CHECK:   %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<256xf32>
+  //       CHECK:   %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<4x256xf32>
   //       CHECK:   %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
   //       CHECK:   %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
     %arg9 : f32, %arg10 : f32, %arg11 : f32, %arg12 : f32, %arg13 : f32,
     %arg14 : f32):
-  //       CHECK:   %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
-  //       CHECK:   %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
+  //       CHECK:   %[[ADD:.*]] = addf %[[V0]], %[[V1]] : vector<4x256xf32>
     %6 = addf %arg4, %arg6 : f32
   //       CHECK:   %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32>
     %7 = cmpf ogt, %arg3, %arg6 : f32
@@ -274,8 +275,7 @@ func @generic_vectorize(%arg0: memref<4x256xf32>,
     %12 = math.rsqrt %arg5 : f32
   //       CHECK:   %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
     %13 = select %7, %arg5, %arg6 : f32
-  //       CHECK:   %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
-  //       CHECK:   %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32>
+  //       CHECK:   %[[SUB:.*]] = subf %[[V3]], %[[V0]] : vector<4x256xf32>
     %14 = subf %arg5, %arg4 : f32
   //       CHECK:   %[[TAN:.*]] = math.tanh %[[V3]] : vector<4x256xf32>
     %15 = math.tanh %arg5 : f32
@@ -334,11 +334,10 @@ func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
   //   CHECK-DAG:   %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
   //   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
   //       CHECK:   %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
-  //       CHECK:   %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<256xf32>
+  //       CHECK:   %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<4x256xf32>
   //       CHECK:   %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
   //       CHECK:   %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
-  //       CHECK:   %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
-  //       CHECK:   %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
+  //       CHECK:   %[[ADD:.*]] = addf %[[V0]], %[[V1]] : vector<4x256xf32>
     %6 = addf %arg4, %arg6 : f32
   //       CHECK:   %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32>
     %7 = cmpf ogt, %arg3, %arg6 : f32
@@ -354,8 +353,7 @@ func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
     %12 = math.rsqrt %arg5 : f32
   //       CHECK:   %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
     %13 = select %7, %arg5, %arg6 : f32
-  //       CHECK:   %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
-  //       CHECK:   %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32>
+  //       CHECK:   %[[SUB:.*]] = subf %[[V3]], %[[V0]] : vector<4x256xf32>
     %14 = subf %arg5, %arg4 : f32
   //       CHECK:   %[[TAN:.*]] = math.tanh %[[V3]] : vector<4x256xf32>
     %15 = math.tanh %arg5 : f32
@@ -428,12 +426,15 @@ func @matmul_tensors(
   //   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
   //   CHECK-DAG:   %[[VEC_C0:.*]] = constant dense<0.000000e+00> : vector<8x12xf32>
   //   CHECK-DAG:   %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32>
-  //   CHECK-DAG:   %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32>
+  //   CHECK-DAG:   %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<12x4xf32>
   //   CHECK-DAG:   %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
   //
   // linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
   // a later canonicalization fuses the add into vector.contract.
-  //       CHECK:   %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V0]], %[[V1]], %[[VEC_C0]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
+  //       CHECK:   %[[C:.*]] = vector.contract
+  //  CHECK-SAME:     iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+  //  CHECK-SAME:     %[[V0]], %[[V1]], %[[VEC_C0]] :
+  //  CHECK-SAME:     vector<8x4xf32>, vector<12x4xf32> into vector<8x12xf32>
   //       CHECK:   %[[C2:.*]] = addf %[[V2]], %[[C]] : vector<8x12xf32>
   //       CHECK:   %[[W:.*]] = vector.transfer_write %[[C2]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
   %0 = linalg.matmul  ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
@@ -453,15 +454,17 @@ func @matmul_i8_i8_i32(%a: memref<4x6xi8>, %b: memref<6x12xi8>, %c: memref<4x12x
   //   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
   //   CHECK-DAG:   %[[VEC_C0:.*]] = constant dense<0> : vector<4x12xi32>
   //   CHECK-DAG:   %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x6xi8>, vector<4x6xi8>
-  //   CHECK-DAG:   %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<6x12xi8>, vector<6x12xi8>
+  //   CHECK-DAG:   %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<6x12xi8>, vector<12x6xi8>
   //   CHECK-DAG:   %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : memref<4x12xi32>, vector<4x12xi32>
   //   CHECK-DAG:   %[[V0_32:.*]] = sexti %[[V0]] : vector<4x6xi8> to vector<4x6xi32>
-  //   CHECK-DAG:   %[[V1_32:.*]] = sexti %[[V1]] : vector<6x12xi8> to vector<6x12xi32>
+  //   CHECK-DAG:   %[[V1_32:.*]] = sexti %[[V1]] : vector<12x6xi8> to vector<12x6xi32>
   //
   // linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
   // a later canonicalization fuses the add into vector.contract.
-  //       CHECK:   %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V0_32]], %[[V1_32]], %[[VEC_C0]]
-  //  CHECK-SAME:     vector<4x6xi32>, vector<6x12xi32> into vector<4x12xi32>
+  //       CHECK:   %[[C:.*]] = vector.contract
+  //  CHECK-SAME:      iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+  //  CHECK-SAME:      %[[V0_32]], %[[V1_32]], %[[VEC_C0]]
+  //  CHECK-SAME:     vector<4x6xi32>, vector<12x6xi32> into vector<4x12xi32>
   //       CHECK:   %[[RES:.*]] = addi %[[V2]], %[[C]] : vector<4x12xi32>
   //       CHECK:   vector.transfer_write %[[RES]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]}
   //  CHECK-SAME:     vector<4x12xi32>, memref<4x12xi32>
@@ -491,6 +494,8 @@ func @pad_static(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32>
   return %0 : tensor<2x3x4xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func @pad_static_high_padding
 //       CHECK:   linalg.pad_tensor
 func @pad_static_high_padding(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
@@ -501,6 +506,8 @@ func @pad_static_high_padding(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tens
   return %0 : tensor<2x3x4xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func @pad_dynamic
 //       CHECK:   linalg.pad_tensor
 func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
@@ -511,3 +518,72 @@ func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
     } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
   return %0 : tensor<6x?x?x?xf32>
 }
+
+// -----
+
+// CHECK-DAG: #[[$M0:.*]] = affine_map<(d0, d1) -> (d0, d1, 0)>
+
+// CHECK-LABEL: func @sum_exp
+func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>)
+  -> tensor<4x16xf32>
+{
+  // CHECK: vector.transfer_read {{.*}} : tensor<4x16x8xf32>, vector<4x16x8xf32>
+  // CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$M0]]} : tensor<4x16xf32>, vector<4x16x8xf32>
+  // CHECK: math.exp {{.*}} : vector<4x16x8xf32>
+  // CHECK: addf {{.*}} : vector<4x16x8xf32>
+  // CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
+  // CHECK: vector.transfer_write {{.*}} : vector<4x16xf32>, tensor<4x16xf32>
+  // CHECK: return {{.*}} : tensor<4x16xf32>
+  %0 = linalg.generic {
+      indexing_maps = [
+        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+        affine_map<(d0, d1, d2) -> (d0, d1)>
+      ],
+      iterator_types = ["parallel", "parallel", "reduction"]
+    } ins(%input : tensor<4x16x8xf32>) outs(%output : tensor<4x16xf32>) {
+    ^bb0(%arg0: f32, %arg1: f32):  // %arg0: %input element, %arg1: %output accumulator
+      %1 = math.exp %arg0 : f32
+      %2 = addf %1, %arg1 : f32
+      linalg.yield %2 : f32
+    } -> tensor<4x16xf32>
+  return %0 : tensor<4x16xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[$M1:.*]] =  affine_map<(d0, d1) -> (d1, d0, 0, 0)>
+// CHECK-DAG: #[[$M2:.*]] =  affine_map<(d0, d1) -> (0, 0, d1, d0)>
+// CHECK-DAG: #[[$M3:.*]] =  affine_map<(d0, d1) -> (d1, 0, 0, d0)>
+// CHECK-DAG: #[[$M4:.*]] =  affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: func @sum_exp_2
+func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output: tensor<5x2xf32>)
+  -> tensor<5x2xf32>
+{
+  // CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$M1]]} : tensor<3x2xf32>, vector<2x3x4x5xf32>
+  // CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$M2]]} : tensor<5x4xf32>, vector<2x3x4x5xf32>
+  // CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$M3]]} : tensor<5x2xf32>, vector<2x3x4x5xf32>
+  // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
+  // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
+  // CHECK: addf {{.*}} : vector<2x3x4x5xf32>
+  // CHECK: addf {{.*}} : vector<2x3x4x5xf32>
+  // CHECK: vector.multi_reduction #vector.kind<add>, {{.*}}  [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
+  // CHECK: vector.transfer_write {{.*}} {permutation_map = #[[$M4]]} : vector<2x5xf32>, tensor<5x2xf32>
+  // CHECK: return {{.*}} : tensor<5x2xf32>
+  %0 = linalg.generic {
+      indexing_maps = [
+        affine_map<(d0, d1, d2, d3) -> (d1, d0)>,
+        affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+        affine_map<(d0, d1, d2, d3) -> (d3, d0)>
+      ],
+      iterator_types = ["parallel", "reduction", "reduction", "parallel"]
+    } ins(%input, %input_2 : tensor<3x2xf32>, tensor<5x4xf32>) outs(%output : tensor<5x2xf32>) {
+    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):  // %arg0: %input, %arg1: %input_2, %arg2: %output accumulator
+      %1 = math.exp %arg0 : f32
+      %2 = math.exp %arg1 : f32
+      %3 = addf %1, %2 : f32
+      %4 = addf %3, %arg2 : f32
+      linalg.yield %4 : f32
+    } -> tensor<5x2xf32>
+  return %0 : tensor<5x2xf32>
+}


        


More information about the Mlir-commits mailing list