[Mlir-commits] [mlir] [mlir][scf][vector] Add `scf.parallel` vectorizer (PR #94168)

Ivan Butygin llvmlistbot at llvm.org
Mon Jun 3 10:48:42 PDT 2024


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/94168

>From 0437a9ed2e2f2ab02cb4a73da54ec64144f271ee Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 5 Dec 2023 16:44:04 -0600
Subject: [PATCH 01/10] [mlir][scf] upstream numba's scf vectorizer

---
 mlir/include/mlir/Transforms/SCFVectorize.h |  49 ++
 mlir/lib/Transforms/CMakeLists.txt          |   1 +
 mlir/lib/Transforms/SCFVectorize.cpp        | 661 ++++++++++++++++++++
 3 files changed, 711 insertions(+)
 create mode 100644 mlir/include/mlir/Transforms/SCFVectorize.h
 create mode 100644 mlir/lib/Transforms/SCFVectorize.cpp

diff --git a/mlir/include/mlir/Transforms/SCFVectorize.h b/mlir/include/mlir/Transforms/SCFVectorize.h
new file mode 100644
index 0000000000000..d754b38d5bc23
--- /dev/null
+++ b/mlir/include/mlir/Transforms/SCFVectorize.h
@@ -0,0 +1,49 @@
+//===- SCFVectorize.h - Vectorize scf.parallel loops ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_SCFVECTORIZE_H_
+#define MLIR_TRANSFORMS_SCFVECTORIZE_H_
+
+#include <memory>
+#include <optional>
+
+// Forward declarations of the MLIR types referenced by the declarations below.
+namespace mlir {
+class OpBuilder;
+class Pass;
+struct LogicalResult;
+namespace scf {
+class ParallelOp;
+} // namespace scf
+} // namespace mlir
+
+namespace mlir {
+
+/// Loop vectorization info, as computed by `getLoopVectorizeInfo`.
+struct SCFVectorizeInfo {
+  /// Loop dimension on which to vectorize.
+  unsigned dim = 0;
+
+  /// Biggest usable vector width, in elements.
+  unsigned factor = 0;
+
+  /// Number of ops that will be vectorized.
+  unsigned count = 0;
+
+  /// Whether masked vector ops can be used for out-of-bounds memory accesses.
+  bool masked = false;
+};
+
+/// Collect vectorization statistics on the specified `scf.parallel` dimension.
+/// Returns `SCFVectorizeInfo`, or `std::nullopt` if the loop cannot be
+/// vectorized on that dimension (e.g. its step is not the constant 1, the body
+/// contains ops with nested regions, or no op is worth vectorizing).
+///
+/// `vectorBitwidth` - maximum vector size, in bits.
+std::optional<SCFVectorizeInfo> getLoopVectorizeInfo(mlir::scf::ParallelOp loop,
+                                                     unsigned dim,
+                                                     unsigned vectorBitwidth);
+
+/// Vectorization parameters for `vectorizeLoop`.
+struct SCFVectorizeParams {
+  /// Loop dimension on which to vectorize.
+  unsigned dim = 0;
+
+  /// Desired vector length, in elements. Must be greater than 1.
+  unsigned factor = 0;
+
+  /// Use masked vector ops for memory accesses outside the loop bounds.
+  bool masked = false;
+};
+
+/// Vectorize `loop` on the specified dimension with the specified factor.
+/// The step of that dimension must be the constant 1.
+///
+/// If `masked` is `true`, instead of generating a second loop for the
+/// remaining iterations, the loop count is rounded up and masked vector ops
+/// guard out-of-bounds memory accesses. Otherwise the original loop is kept
+/// to process the remaining iterations.
+mlir::LogicalResult vectorizeLoop(mlir::OpBuilder &builder,
+                                  mlir::scf::ParallelOp loop,
+                                  const SCFVectorizeParams &params);
+
+/// Creates a pass that vectorizes `scf.parallel` loops nested in functions
+/// carrying a vector length attribute.
+std::unique_ptr<mlir::Pass> createSCFVectorizePass();
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_SCFVECTORIZE_H_
\ No newline at end of file
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 90c0298fb5e46..ed71c73c938ed 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_library(MLIRTransforms
   PrintIR.cpp
   RemoveDeadValues.cpp
   SCCP.cpp
+  SCFVectorize.cpp
   SROA.cpp
   StripDebugInfo.cpp
   SymbolDCE.cpp
diff --git a/mlir/lib/Transforms/SCFVectorize.cpp b/mlir/lib/Transforms/SCFVectorize.cpp
new file mode 100644
index 0000000000000..d7545ee30e29a
--- /dev/null
+++ b/mlir/lib/Transforms/SCFVectorize.cpp
@@ -0,0 +1,661 @@
+//===- SCFVectorize.cpp - Vectorize scf.parallel loops --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/SCFVectorize.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Pass/Pass.h"
+
+/// Returns the bit width used for vectorization decisions: 64 for `index`,
+/// the exact width for integer and float types, and 0 for any other type
+/// (meaning "not vectorizable").
+static unsigned getTypeBitWidth(mlir::Type type) {
+  // TODO: unhardcode the `index` width.
+  if (mlir::isa<mlir::IndexType>(type))
+    return 64;
+
+  return type.isIntOrFloat() ? type.getIntOrFloatBitWidth() : 0;
+}
+
+/// Returns the widest vectorizable bit width among the operand and result
+/// types of `op`, or 0 if none of them is vectorizable.
+static unsigned getArgsTypeWidth(mlir::Operation &op) {
+  unsigned maxWidth = 0;
+  auto widen = [&](mlir::Type t) {
+    maxWidth = std::max(maxWidth, getTypeBitWidth(t));
+  };
+  llvm::for_each(op.getOperandTypes(), widen);
+  llvm::for_each(op.getResultTypes(), widen);
+  return maxWidth;
+}
+
+/// Ops carrying the `Vectorizable` trait are vectorized by cloning them with
+/// vector result types.
+static bool isSupportedVectorOp(mlir::Operation &op) {
+  using mlir::OpTrait::Vectorizable;
+  return op.hasTrait<Vectorizable>();
+}
+
+/// Element types that may be placed into a `vector`: integers, `index`, and
+/// floats.
+static bool isSupportedVecElem(mlir::Type type) {
+  return type.isIntOrIndex() || mlir::isa<mlir::FloatType>(type);
+}
+
+/// Checks whether `val1` is a permutation of `val2`, i.e. both contain the same
+/// values with the same multiplicities, potentially in a different order.
+///
+/// Comparing multiplicities matters: e.g. `(i, i)` must not be accepted as a
+/// permutation of `(i, j)`, otherwise an access like `A[i, i]` would be
+/// treated as a contiguous access along `i`.
+static bool isRangePermutation(mlir::ValueRange val1, mlir::ValueRange val2) {
+  if (val1.size() != val2.size())
+    return false;
+
+  // Ranges are loop-rank sized, so the quadratic counting is negligible.
+  return llvm::all_of(val1, [&](mlir::Value v) {
+    return llvm::count(val1, v) == llvm::count(val2, v);
+  });
+}
+
+/// Checks whether the memref access `memOp` can be turned into a contiguous
+/// vector load/store when vectorizing `loop` along `dim`.
+///
+/// Requirements: the element type has a known bit width, the memref has an
+/// identity layout, the access indices are exactly the loop induction
+/// variables (in any order) with the vectorized one last, and the memref is
+/// defined outside the loop.
+///
+/// Returns the element bit width, or `std::nullopt` if not vectorizable.
+template <typename Op>
+static std::optional<unsigned>
+cavTriviallyVectorizeMemOpImpl(mlir::scf::ParallelOp loop, unsigned dim,
+                               Op memOp) {
+  auto loopIndexVars = loop.getInductionVars();
+  assert(dim < loopIndexVars.size());
+  auto memref = memOp.getMemRef();
+  auto type = mlir::cast<mlir::MemRefType>(memref.getType());
+  auto width = getTypeBitWidth(type.getElementType());
+  if (width == 0)
+    return std::nullopt;
+
+  if (!type.getLayout().isIdentity())
+    return std::nullopt;
+
+  if (!isRangePermutation(memOp.getIndices(), loopIndexVars))
+    return std::nullopt;
+
+  // The vectorized dimension must be the innermost (contiguous) index.
+  if (memOp.getIndices().back() != loopIndexVars[dim])
+    return std::nullopt;
+
+  // The memref is used inside the loop, so in valid IR being defined outside
+  // the loop is equivalent to properly dominating it. This avoids building a
+  // fresh `DominanceInfo` (a whole-region analysis) for every memory op.
+  if (!loop.isDefinedOutsideOfLoop(memref))
+    return std::nullopt;
+
+  return width;
+}
+
+/// Dispatches `op` to the load/store contiguity check.
+///
+/// Returns the memref element bit width when `op` is a `memref.load` or
+/// `memref.store` that can become a contiguous vector access, and
+/// `std::nullopt` for any other op.
+static std::optional<unsigned>
+cavTriviallyVectorizeMemOp(mlir::scf::ParallelOp loop, unsigned dim,
+                           mlir::Operation &op) {
+  assert(dim < loop.getInductionVars().size());
+  if (auto load = mlir::dyn_cast<mlir::memref::LoadOp>(op))
+    return cavTriviallyVectorizeMemOpImpl(loop, dim, load);
+
+  if (auto store = mlir::dyn_cast<mlir::memref::StoreOp>(op))
+    return cavTriviallyVectorizeMemOpImpl(loop, dim, store);
+
+  return std::nullopt;
+}
+
+/// Type predicate usable as a plain function pointer (see `getReductionKind`).
+template <typename T>
+static bool isOp(mlir::Operation &op) {
+  return mlir::isa<T>(&op);
+}
+
+/// Returns the `vector.reduction` kind for the given `scf.reduce` op, or
+/// `std::nullopt` if its body is not a single supported binary op applied to
+/// the two region arguments and returned as-is.
+static std::optional<mlir::vector::CombiningKind>
+getReductionKind(mlir::scf::ReduceOp op) {
+  mlir::Block &body = op.getReductionOperator().front();
+  if (!llvm::hasSingleElement(body.without_terminator()))
+    return std::nullopt;
+
+  mlir::Operation &redOp = body.front();
+
+  // The single op must combine exactly the two block arguments and its result
+  // must be what the region returns. Otherwise, e.g. `addi %lhs, %c1`, it is
+  // not a plain binary reduction and cannot be expressed as
+  // `vector.reduction`.
+  if (body.getNumArguments() != 2 || redOp.getNumOperands() != 2 ||
+      redOp.getNumResults() != 1)
+    return std::nullopt;
+
+  mlir::Value lhs = body.getArgument(0);
+  mlir::Value rhs = body.getArgument(1);
+  mlir::Value opLhs = redOp.getOperand(0);
+  mlir::Value opRhs = redOp.getOperand(1);
+  // Either operand order is fine: all handled ops are commutative.
+  if (!((opLhs == lhs && opRhs == rhs) || (opLhs == rhs && opRhs == lhs)))
+    return std::nullopt;
+
+  mlir::Operation *term = body.getTerminator();
+  if (term->getNumOperands() != 1 ||
+      term->getOperand(0) != redOp.getResult(0))
+    return std::nullopt;
+
+  using fptr_t = bool (*)(mlir::Operation &);
+  using CC = mlir::vector::CombiningKind;
+  const std::pair<fptr_t, CC> handlers[] = {
+      // clang-format off
+      {&isOp<mlir::arith::AddIOp>, CC::ADD},
+      {&isOp<mlir::arith::AddFOp>, CC::ADD},
+      {&isOp<mlir::arith::MulIOp>, CC::MUL},
+      {&isOp<mlir::arith::MulFOp>, CC::MUL},
+      // clang-format on
+  };
+
+  for (auto &&[handler, cc] : handlers) {
+    if (handler(redOp))
+      return cc;
+  }
+
+  return std::nullopt;
+}
+
+/// Scans the body of `loop` and estimates how well it vectorizes along `dim`
+/// for vectors of at most `vectorBitwidth` bits.
+///
+/// The factor starts at one lane per byte and is narrowed by every vectorized
+/// op's element width. Masking stays possible only while every reduction maps
+/// onto `vector.reduction` and every non-memory op is vectorizable.
+std::optional<mlir::SCFVectorizeInfo>
+mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
+                           unsigned vectorBitwidth) {
+  assert(dim < loop.getStep().size());
+  assert(vectorBitwidth > 0);
+
+  unsigned factor = vectorBitwidth / 8;
+  if (factor < 2)
+    return std::nullopt;
+
+  // Only unit-step dimensions are supported.
+  if (!mlir::isConstantIntValue(loop.getStep()[dim], 1))
+    return std::nullopt;
+
+  unsigned numVectorized = 0;
+  bool canMask = true;
+
+  // Narrows `factor` to the lane count allowed by `width`-bit elements and
+  // counts the op, unless fewer than two lanes would fit.
+  auto accountFor = [&](unsigned width) {
+    unsigned lanes = vectorBitwidth / width;
+    if (lanes < 2)
+      return;
+    factor = std::min(factor, lanes);
+    ++numVectorized;
+  };
+
+  for (mlir::Operation &op : loop.getBody()->without_terminator()) {
+    if (auto reduce = mlir::dyn_cast<mlir::scf::ReduceOp>(op)) {
+      if (!getReductionKind(reduce))
+        canMask = false;
+      continue;
+    }
+
+    if (op.getNumRegions() != 0)
+      return std::nullopt;
+
+    if (std::optional<unsigned> memWidth =
+            cavTriviallyVectorizeMemOp(loop, dim, op)) {
+      accountFor(*memWidth);
+      continue;
+    }
+
+    // Unsupported ops are replicated per lane, which rules out masking.
+    if (!isSupportedVectorOp(op)) {
+      canMask = false;
+      continue;
+    }
+
+    unsigned width = getArgsTypeWidth(op);
+    if (width == 0)
+      return std::nullopt;
+
+    accountFor(width);
+  }
+
+  if (numVectorized == 0)
+    return std::nullopt;
+
+  return SCFVectorizeInfo{dim, factor, numVectorized, canMask};
+}
+
+/// Returns the fastmath flags of `op`, or `none` when the op does not
+/// implement `ArithFastMathInterface`.
+static mlir::arith::FastMathFlags getFMF(mlir::Operation &op) {
+  auto fmfOp = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(op);
+  return fmfOp ? fmfOp.getFastMathFlagsAttr().getValue()
+               : mlir::arith::FastMathFlags::none;
+}
+
+/// Vectorizes `loop` along dimension `params.dim` with vector length
+/// `params.factor`.
+///
+/// A new `scf.parallel` is created in front of `loop`. On `dim` it runs over
+/// `(upper - lower) / factor` chunks, rounded up when `params.masked` is set.
+/// Ops are handled as follows:
+///  - ops with the `Vectorizable` trait are cloned with vector result types;
+///  - contiguous memref accesses become `vector.load`/`vector.store`, or the
+///    masked forms in masked mode;
+///  - other single-index accesses become `vector.gather`/`vector.scatter`;
+///  - everything else is replicated `factor` times on extracted scalars, which
+///    is only allowed in non-masked mode.
+///
+/// In masked mode the original loop is erased. Otherwise it is kept as the
+/// remainder loop, starting where the vectorized iterations end and seeded
+/// with the vectorized loop's results as init values.
+///
+/// Returns failure (after emitting an error) if an op cannot be handled in
+/// masked mode. NOTE(review): on failure the partially built vector loop is
+/// left in the IR.
+mlir::LogicalResult
+mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
+                    const mlir::SCFVectorizeParams &params) {
+  auto dim = params.dim;
+  auto factor = params.factor;
+  auto masked = params.masked;
+  assert(dim < loop.getStep().size());
+  assert(factor > 1);
+  assert(mlir::isConstantIntValue(loop.getStep()[dim], 1));
+
+  mlir::OpBuilder::InsertionGuard g(builder);
+  builder.setInsertionPoint(loop);
+
+  auto lower = llvm::to_vector(loop.getLowerBound());
+  auto upper = llvm::to_vector(loop.getUpperBound());
+  auto step = llvm::to_vector(loop.getStep());
+
+  auto loc = loop.getLoc();
+
+  auto origIndexVar = loop.getInductionVars()[dim];
+
+  mlir::Value factorVal =
+      builder.create<mlir::arith::ConstantIndexOp>(loc, factor);
+
+  auto origLower = lower[dim];
+  auto origUpper = upper[dim];
+  mlir::Value count =
+      builder.create<mlir::arith::SubIOp>(loc, origUpper, origLower);
+  mlir::Value newCount;
+
+  // New trip count: ceildiv(count, factor) in masked mode, where the tail is
+  // handled with masks; floordiv otherwise, where the tail is left to the
+  // original loop.
+  if (masked) {
+    mlir::Value incCount =
+        builder.create<mlir::arith::AddIOp>(loc, count, factorVal);
+    mlir::Value one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
+    incCount = builder.create<mlir::arith::SubIOp>(loc, incCount, one);
+    newCount = builder.create<mlir::arith::DivSIOp>(loc, incCount, factorVal);
+  } else {
+    newCount = builder.create<mlir::arith::DivSIOp>(loc, count, factorVal);
+  }
+
+  mlir::Value zero = builder.create<mlir::arith::ConstantIndexOp>(loc, 0);
+  lower[dim] = zero;
+  upper[dim] = newCount;
+
+  // Vectorized loop: iteration `k` on `dim` covers the original indices
+  // `[origLower + k * factor, origLower + (k + 1) * factor)`.
+  auto newLoop = builder.create<mlir::scf::ParallelOp>(loc, lower, upper, step,
+                                                       loop.getInitVals());
+  auto newIndexVar = newLoop.getInductionVars()[dim];
+
+  auto toVectorType = [&](mlir::Type elemType) -> mlir::VectorType {
+    int64_t f = factor;
+    return mlir::VectorType::get(f, elemType);
+  };
+
+  // `mapping`: original value -> its vector counterpart in the new loop.
+  // `scalarMapping`: scratch mapping used when cloning scalar ops.
+  mlir::IRMapping mapping;
+  mlir::IRMapping scalarMapping;
+
+  auto createPoisonVec = [&](mlir::VectorType vecType) -> mlir::Value {
+    return builder.create<mlir::ub::PoisonOp>(loc, vecType, nullptr);
+  };
+
+  // Get the vector value in the new loop for `orig` from the source loop.
+  auto getVecVal = [&](mlir::Value orig) -> mlir::Value {
+    if (auto mapped = mapping.lookupOrNull(orig))
+      return mapped;
+
+    // Vectorized loop index:
+    // `splat(origLower + newIndexVar * factor) + (0, 1, ..., factor - 1)`.
+    if (orig == origIndexVar) {
+      auto vecType = toVectorType(builder.getIndexType());
+      llvm::SmallVector<mlir::Attribute> elems(factor);
+      for (auto i : llvm::seq(0u, factor))
+        elems[i] = builder.getIndexAttr(i);
+      auto attr = mlir::DenseElementsAttr::get(vecType, elems);
+      mlir::Value vec =
+          builder.create<mlir::arith::ConstantOp>(loc, vecType, attr);
+
+      mlir::Value idx =
+          builder.create<mlir::arith::MulIOp>(loc, newIndexVar, factorVal);
+      idx = builder.create<mlir::arith::AddIOp>(loc, idx, origLower);
+      idx = builder.create<mlir::vector::SplatOp>(loc, idx, vecType);
+      vec = builder.create<mlir::arith::AddIOp>(loc, idx, vec);
+      mapping.map(orig, vec);
+      return vec;
+    }
+    auto type = orig.getType();
+    assert(isSupportedVecElem(type));
+
+    // Induction variables of other dimensions are replaced by the new loop's
+    // counterpart; values defined outside the body are used as-is. Either
+    // way the scalar is splatted. Values defined inside the body are mapped
+    // when their defining op is processed.
+    mlir::Value val = orig;
+    auto origIndexVars = loop.getInductionVars();
+    auto it = llvm::find(origIndexVars, orig);
+    if (it != origIndexVars.end())
+      val = newLoop.getInductionVars()[it - origIndexVars.begin()];
+
+    auto vecType = toVectorType(type);
+    mlir::Value vec = builder.create<mlir::vector::SplatOp>(loc, val, vecType);
+    mapping.map(orig, vec);
+    return vec;
+  };
+
+  // Per-lane scalar values of an original value, used when an op has to be
+  // replicated instead of vectorized.
+  llvm::DenseMap<mlir::Value, llvm::SmallVector<mlir::Value>> unpackedVals;
+  auto getUnpackedVals = [&](mlir::Value val) -> mlir::ValueRange {
+    auto it = unpackedVals.find(val);
+    if (it != unpackedVals.end())
+      return it->second;
+
+    auto &ret = unpackedVals[val];
+    assert(ret.empty());
+    // Non-vectorizable types (e.g. memrefs) are shared by all lanes.
+    if (!isSupportedVecElem(val.getType())) {
+      ret.resize(factor, val);
+      return ret;
+    }
+
+    auto vecVal = getVecVal(val);
+    ret.resize(factor);
+    for (auto i : llvm::seq(0u, factor)) {
+      mlir::Value idx = builder.create<mlir::arith::ConstantIndexOp>(loc, i);
+      ret[i] = builder.create<mlir::vector::ExtractElementOp>(loc, vecVal, idx);
+    }
+    return ret;
+  };
+
+  // Records the per-lane results of a replicated op and, for vectorizable
+  // types, also packs them into a vector for vectorized users.
+  auto setUnpackedVals = [&](mlir::Value origVal, mlir::ValueRange newVals) {
+    assert(newVals.size() == factor);
+    assert(unpackedVals.count(origVal) == 0);
+    unpackedVals[origVal].append(newVals.begin(), newVals.end());
+
+    auto type = origVal.getType();
+    if (!isSupportedVecElem(type))
+      return;
+
+    auto vecType = toVectorType(type);
+
+    mlir::Value vec = createPoisonVec(vecType);
+    for (auto i : llvm::seq(0u, factor)) {
+      mlir::Value idx = builder.create<mlir::arith::ConstantIndexOp>(loc, i);
+      vec = builder.create<mlir::vector::InsertElementOp>(loc, newVals[i], vec,
+                                                          idx);
+    }
+    mapping.map(origVal, vec);
+  };
+
+  // Lazily created lane mask: in masked mode only lanes inside the original
+  // bound are active; otherwise all `factor` lanes are.
+  mlir::Value mask;
+  auto getMask = [&]() -> mlir::Value {
+    if (mask)
+      return mask;
+
+    mlir::OpFoldResult maskSize;
+    if (masked) {
+      mlir::Value size =
+          builder.create<mlir::arith::MulIOp>(loc, factorVal, newIndexVar);
+      maskSize =
+          builder.create<mlir::arith::SubIOp>(loc, count, size).getResult();
+    } else {
+      maskSize = builder.getIndexAttr(factor);
+    }
+    auto vecType = toVectorType(builder.getI1Type());
+    mask = builder.create<mlir::vector::CreateMaskOp>(loc, vecType, maskSize);
+
+    return mask;
+  };
+
+  mlir::DominanceInfo dom;
+
+  auto canTriviallyVectorizeMemOp = [&](auto op) -> bool {
+    return !!::cavTriviallyVectorizeMemOpImpl(loop, dim, op);
+  };
+
+  // Indices of the first element accessed by the current vector chunk.
+  auto getMemrefVecIndices = [&](mlir::ValueRange indices) {
+    scalarMapping.clear();
+    scalarMapping.map(loop.getInductionVars(), newLoop.getInductionVars());
+
+    llvm::SmallVector<mlir::Value> ret(indices.size());
+    for (auto &&[i, val] : llvm::enumerate(indices)) {
+      if (val == origIndexVar) {
+        mlir::Value idx =
+            builder.create<mlir::arith::MulIOp>(loc, newIndexVar, factorVal);
+        idx = builder.create<mlir::arith::AddIOp>(loc, idx, origLower);
+        ret[i] = idx;
+        continue;
+      }
+      ret[i] = scalarMapping.lookup(val);
+    }
+
+    return ret;
+  };
+
+  auto canGatherScatter = [&](auto op) {
+    auto memref = op.getMemRef();
+    auto memrefType = mlir::cast<mlir::MemRefType>(memref.getType());
+    if (!isSupportedVecElem(memrefType.getElementType()))
+      return false;
+
+    return dom.properlyDominates(memref, loop) && op.getIndices().size() == 1 &&
+           memrefType.getLayout().isIdentity();
+  };
+
+  auto genLoad = [&](auto loadOp) {
+    auto indices = getMemrefVecIndices(loadOp.getIndices());
+    auto resType = toVectorType(loadOp.getResult().getType());
+    auto memref = loadOp.getMemRef();
+    mlir::Value vecLoad;
+    if (masked) {
+      auto mask = getMask();
+      auto init = createPoisonVec(resType);
+      vecLoad = builder.create<mlir::vector::MaskedLoadOp>(loc, resType, memref,
+                                                           indices, mask, init);
+    } else {
+      vecLoad =
+          builder.create<mlir::vector::LoadOp>(loc, resType, memref, indices);
+    }
+    mapping.map(loadOp.getResult(), vecLoad);
+  };
+
+  auto genStore = [&](auto storeOp) {
+    auto indices = getMemrefVecIndices(storeOp.getIndices());
+    auto value = getVecVal(storeOp.getValueToStore());
+    auto memref = storeOp.getMemRef();
+    if (masked) {
+      auto mask = getMask();
+      builder.create<mlir::vector::MaskedStoreOp>(loc, memref, indices, mask,
+                                                  value);
+    } else {
+      builder.create<mlir::vector::StoreOp>(loc, value, memref, indices);
+    }
+  };
+
+  llvm::SmallVector<mlir::Value> duplicatedArgs;
+  llvm::SmallVector<mlir::Value> duplicatedResults;
+
+  builder.setInsertionPointToStart(newLoop.getBody());
+  for (mlir::Operation &op : loop.getBody()->without_terminator()) {
+    loc = op.getLoc();
+    if (isSupportedVectorOp(op)) {
+      for (auto arg : op.getOperands())
+        getVecVal(arg); // init mapper for op args
+
+      auto newOp = builder.clone(op, mapping);
+      for (auto res : newOp->getResults())
+        res.setType(toVectorType(res.getType()));
+
+      continue;
+    }
+
+    if (auto reduceOp = mlir::dyn_cast<mlir::scf::ReduceOp>(op)) {
+      scalarMapping.clear();
+      auto &reduceBody = reduceOp.getReductionOperator().front();
+      assert(reduceBody.getNumArguments() == 2);
+
+      mlir::Value reduceVal;
+      if (auto redKind = getReductionKind(reduceOp)) {
+        mlir::Value redArg = getVecVal(reduceOp.getOperand());
+        // In masked mode, lanes past the loop bound hold values from
+        // out-of-bounds iterations. Replace them with the neutral element so
+        // they don't contribute. In non-masked mode every lane is in bounds.
+        if (masked) {
+          auto neutral = mlir::arith::getNeutralElement(&reduceBody.front());
+          assert(neutral);
+          mlir::Value neutralVal =
+              builder.create<mlir::arith::ConstantOp>(loc, *neutral);
+          mlir::Value neutralVec = builder.create<mlir::vector::SplatOp>(
+              loc, neutralVal, redArg.getType());
+          auto mask = getMask();
+          redArg = builder.create<mlir::arith::SelectOp>(loc, mask, redArg,
+                                                         neutralVec);
+        }
+
+        auto fmf = getFMF(reduceBody.front());
+        reduceVal = builder.create<mlir::vector::ReductionOp>(loc, *redKind,
+                                                              redArg, fmf);
+      } else {
+        // Generic reduction: fold all `factor` lanes with clones of the
+        // reduction body. Lanes are folded unconditionally, so this is not
+        // supported in masked mode.
+        if (masked)
+          return op.emitError("Cannot vectorize op in masked mode");
+
+        auto reduceTerm =
+            mlir::cast<mlir::scf::ReduceReturnOp>(reduceBody.getTerminator());
+        auto lhs = reduceBody.getArgument(0);
+        auto rhs = reduceBody.getArgument(1);
+        auto unpacked = getUnpackedVals(reduceOp.getOperand());
+        assert(unpacked.size() == factor);
+        reduceVal = unpacked.front();
+        for (auto i : llvm::seq(1u, factor)) {
+          mlir::Value val = unpacked[i];
+          scalarMapping.map(lhs, reduceVal);
+          scalarMapping.map(rhs, val);
+          for (auto &redOp : reduceBody.without_terminator())
+            builder.clone(redOp, scalarMapping);
+
+          reduceVal = scalarMapping.lookupOrDefault(reduceTerm.getResult());
+        }
+      }
+      // Re-emit the `scf.reduce` with the chunk's partial result as operand.
+      scalarMapping.clear();
+      scalarMapping.map(reduceOp.getOperand(), reduceVal);
+      builder.clone(op, scalarMapping);
+      continue;
+    }
+
+    if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op)) {
+      if (canTriviallyVectorizeMemOp(loadOp)) {
+        genLoad(loadOp);
+        continue;
+      }
+      if (canGatherScatter(loadOp)) {
+        auto resType = toVectorType(loadOp.getResult().getType());
+        auto memref = loadOp.getMemRef();
+        auto mask = getMask();
+        auto indexVec = getVecVal(loadOp.getIndices()[0]);
+        auto init = createPoisonVec(resType);
+
+        auto gather = builder.create<mlir::vector::GatherOp>(
+            loc, resType, memref, zero, indexVec, mask, init);
+        mapping.map(loadOp.getResult(), gather.getResult());
+        continue;
+      }
+    }
+
+    if (auto storeOp = mlir::dyn_cast<mlir::memref::StoreOp>(op)) {
+      if (canTriviallyVectorizeMemOp(storeOp)) {
+        genStore(storeOp);
+        continue;
+      }
+      if (canGatherScatter(storeOp)) {
+        auto memref = storeOp.getMemRef();
+        auto value = getVecVal(storeOp.getValueToStore());
+        auto mask = getMask();
+        auto indexVec = getVecVal(storeOp.getIndices()[0]);
+
+        builder.create<mlir::vector::ScatterOp>(loc, memref, zero, indexVec,
+                                                mask, value);
+        // The store is fully handled by the scatter; without this `continue`
+        // it would also be replicated by the fallback below (or rejected in
+        // masked mode).
+        continue;
+      }
+    }
+
+    // Fallback: Failed to vectorize op, just duplicate it `factor` times
+    if (masked)
+      return op.emitError("Cannot vectorize op in masked mode");
+
+    scalarMapping.clear();
+
+    auto numArgs = op.getNumOperands();
+    auto numResults = op.getNumResults();
+    // Layouts: duplicatedArgs is lane-major (lane * numArgs + arg);
+    // duplicatedResults is result-major (result * factor + lane).
+    duplicatedArgs.resize(numArgs * factor);
+    duplicatedResults.resize(numResults * factor);
+
+    for (auto &&[i, arg] : llvm::enumerate(op.getOperands())) {
+      auto unpacked = getUnpackedVals(arg);
+      assert(unpacked.size() == factor);
+      for (auto j : llvm::seq(0u, factor))
+        duplicatedArgs[j * numArgs + i] = unpacked[j];
+    }
+
+    for (auto i : llvm::seq(0u, factor)) {
+      auto args = mlir::ValueRange(duplicatedArgs)
+                      .drop_front(numArgs * i)
+                      .take_front(numArgs);
+      scalarMapping.map(op.getOperands(), args);
+      auto results = builder.clone(op, scalarMapping)->getResults();
+
+      for (auto j : llvm::seq(0u, numResults))
+        duplicatedResults[j * factor + i] = results[j];
+    }
+
+    for (auto i : llvm::seq(0u, numResults)) {
+      auto results = mlir::ValueRange(duplicatedResults)
+                         .drop_front(factor * i)
+                         .take_front(factor);
+      setUnpackedVals(op.getResult(i), results);
+    }
+  }
+
+  if (masked) {
+    // The vectorized loop covers every iteration; drop the original.
+    loop->replaceAllUsesWith(newLoop.getResults());
+    loop->erase();
+  } else {
+    // Keep the original loop as the remainder loop: start it right after the
+    // last full vector chunk and seed it with the vectorized loop's results.
+    builder.setInsertionPoint(loop);
+    mlir::Value newLower =
+        builder.create<mlir::arith::MulIOp>(loc, newCount, factorVal);
+    newLower = builder.create<mlir::arith::AddIOp>(loc, origLower, newLower);
+
+    auto lowerCopy = llvm::to_vector(loop.getLowerBound());
+    lowerCopy[dim] = newLower;
+    loop.getLowerBoundMutable().assign(lowerCopy);
+    loop.getInitValsMutable().assign(newLoop.getResults());
+  }
+
+  return mlir::success();
+}
+
+/// Name of the function attribute holding the target vector length, in bits.
+/// File-local: a non-static helper in the global namespace would export a
+/// generic symbol from the library. NOTE(review): the name is inherited from
+/// numba; consider a non-numba-specific attribute name upstream.
+static llvm::StringRef getVectorLengthName() { return "numba.vector_length"; }
+
+/// Returns the target vector length (in bits) from the vector-length attribute
+/// of the function enclosing `op`. Returns `std::nullopt` when:
+///  - there is no enclosing function;
+///  - the attribute is missing or not an integer;
+///  - its value is outside `[1, UINT_MAX]`.
+static std::optional<unsigned> getVectorLength(mlir::Operation *op) {
+  auto func = op->getParentOfType<mlir::FunctionOpInterface>();
+  if (!func)
+    return std::nullopt;
+
+  auto attr = func->getAttrOfType<mlir::IntegerAttr>(getVectorLengthName());
+  if (!attr)
+    return std::nullopt;
+
+  // `IntegerAttr::getInt` asserts on explicitly signed/unsigned types (and on
+  // values wider than 64 bits), so read the raw APInt instead. Signless and
+  // index values are interpreted as signed, as `getInt` would; `ui*` values
+  // are interpreted as unsigned.
+  llvm::APInt value = attr.getValue();
+  bool isUnsigned = attr.getType().isUnsignedInteger();
+  if (value.isZero() || (!isUnsigned && value.isNegative()))
+    return std::nullopt;
+
+  if (value.ugt(std::numeric_limits<unsigned>::max()))
+    return std::nullopt;
+
+  return static_cast<unsigned>(value.getZExtValue());
+}
+
+namespace {
+/// Vectorizes `scf.parallel` loops nested in the pass root. A loop is only
+/// considered when its enclosing function provides a vector length (see
+/// `getVectorLength`). For each loop, the dimension with the best heuristic
+/// score is chosen.
+struct SCFVectorizePass
+    : public mlir::PassWrapper<SCFVectorizePass, mlir::OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFVectorizePass)
+
+  void getDependentDialects(mlir::DialectRegistry &registry) const override {
+    registry.insert<mlir::arith::ArithDialect, mlir::scf::SCFDialect,
+                    mlir::ub::UBDialect, mlir::vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    using Candidate =
+        std::pair<mlir::scf::ParallelOp, mlir::SCFVectorizeParams>;
+    llvm::SmallVector<Candidate> candidates;
+
+    // Heuristic score: wider vectors, more vectorized ops, and maskability
+    // (which doubles the score) are preferred.
+    auto score = [](const mlir::SCFVectorizeInfo &info) {
+      return info.factor * info.count * (int(info.masked) + 1);
+    };
+
+    // Collect first, transform afterwards, so the walk never visits IR that
+    // is being rewritten.
+    getOperation()->walk([&](mlir::scf::ParallelOp loop) {
+      std::optional<unsigned> vectorLength = getVectorLength(loop);
+      if (!vectorLength)
+        return;
+
+      std::optional<mlir::SCFVectorizeInfo> best;
+      for (auto dim : llvm::seq(0u, loop.getNumLoops())) {
+        std::optional<mlir::SCFVectorizeInfo> info =
+            mlir::getLoopVectorizeInfo(loop, dim, *vectorLength);
+        // Strict `>`: on ties the lowest dimension wins.
+        if (info && (!best || score(*info) > score(*best)))
+          best = info;
+      }
+
+      if (best)
+        candidates.emplace_back(
+            loop,
+            mlir::SCFVectorizeParams{best->dim, best->factor, best->masked});
+    });
+
+    if (candidates.empty())
+      return markAllAnalysesPreserved();
+
+    mlir::OpBuilder builder(&getContext());
+    for (auto &[loop, params] : candidates) {
+      builder.setInsertionPoint(loop);
+      if (mlir::failed(mlir::vectorizeLoop(builder, loop, params)))
+        return signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+/// Creates a new instance of the `scf.parallel` vectorization pass.
+std::unique_ptr<mlir::Pass> mlir::createSCFVectorizePass() {
+  auto pass = std::make_unique<SCFVectorizePass>();
+  return pass;
+}

>From 7f5e4672c73a50c3188c0265504bab9c2b536db4 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 25 Apr 2024 10:57:05 -0500
Subject: [PATCH 02/10] get new stuff

---
 mlir/include/mlir/Transforms/SCFVectorize.h |  27 ++-
 mlir/lib/Transforms/SCFVectorize.cpp        | 210 ++++++++++++++------
 2 files changed, 172 insertions(+), 65 deletions(-)

diff --git a/mlir/include/mlir/Transforms/SCFVectorize.h b/mlir/include/mlir/Transforms/SCFVectorize.h
index d754b38d5bc23..93a7864b976ec 100644
--- a/mlir/include/mlir/Transforms/SCFVectorize.h
+++ b/mlir/include/mlir/Transforms/SCFVectorize.h
@@ -22,23 +22,48 @@ class ParallelOp;
 } // namespace mlir
 
 namespace mlir {
+
+/// Loop vectorization info
 struct SCFVectorizeInfo {
+  /// Loop dimension on which to vectorize.
   unsigned dim = 0;
+
+  /// Biggest vector width, in elements.
   unsigned factor = 0;
+
+  /// Number of ops that will be vectorized.
   unsigned count = 0;
+
+  /// Can use masked vector ops for out-of-bounds memory accesses.
   bool masked = false;
 };
 
+/// Collect vectorization statistics on specified `scf.parallel` dimension.
+/// Return `SCFVectorizeInfo` or `std::nullopt` if loop cannot be vectorized on
+/// specified dimension.
+///
+/// `vectorBitwidth` - maximum vector size, in bits.
 std::optional<SCFVectorizeInfo> getLoopVectorizeInfo(mlir::scf::ParallelOp loop,
                                                      unsigned dim,
-                                                     unsigned vectorBitWidth);
+                                                     unsigned vectorBitwidth);
 
+/// Vectorization params
 struct SCFVectorizeParams {
+  /// Loop dimension on which to vectorize.
   unsigned dim = 0;
+
+  /// Desired vector length, in elements
   unsigned factor = 0;
+
+  /// Use masked vector ops for memory access outside loop bounds.
   bool masked = false;
 };
 
+/// Vectorize loop on specified dimension with specified factor.
+///
+/// If `masked` is `true` and the loop trip count is not divisible by `factor`,
+/// then instead of generating a second loop for the remaining iterations, the
+/// loop count is rounded up and masked vector ops guard out-of-bounds accesses.
 mlir::LogicalResult vectorizeLoop(mlir::OpBuilder &builder,
                                   mlir::scf::ParallelOp loop,
                                   const SCFVectorizeParams &params);
diff --git a/mlir/lib/Transforms/SCFVectorize.cpp b/mlir/lib/Transforms/SCFVectorize.cpp
index d7545ee30e29a..13a9eca9cd2d3 100644
--- a/mlir/lib/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Transforms/SCFVectorize.cpp
@@ -16,7 +16,17 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
-
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/MemRef/IR/MemRef.h>
+#include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/Dialect/UB/IR/UBOps.h>
+#include <mlir/Dialect/Vector/IR/VectorOps.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/Interfaces/FunctionInterfaces.h>
+#include <mlir/Pass/Pass.h>
+
+/// Return type bitwidth for vectorization purposes or 0 if type cannot be
+/// vectorized.
 static unsigned getTypeBitWidth(mlir::Type type) {
   if (mlir::isa<mlir::IndexType>(type))
     return 64; // TODO: unhardcode
@@ -46,6 +56,8 @@ static bool isSupportedVecElem(mlir::Type type) {
   return type.isIntOrIndexOrFloat();
 }
 
+/// Check if one `ValueRange` is permutation of another, i.e. contains same
+/// values, potentially in different order.
 static bool isRangePermutation(mlir::ValueRange val1, mlir::ValueRange val2) {
   if (val1.size() != val2.size())
     return false;
@@ -86,6 +98,10 @@ cavTriviallyVectorizeMemOpImpl(mlir::scf::ParallelOp loop, unsigned dim,
   return width;
 }
 
+/// Check if memref load/store can be converted into vectorized load/store
+///
+/// Returns memref element bitwidth or `std::nullopt` if access cannot be
+/// vectorized.
 static std::optional<unsigned>
 cavTriviallyVectorizeMemOp(mlir::scf::ParallelOp loop, unsigned dim,
                            mlir::Operation &op) {
@@ -104,9 +120,10 @@ static bool isOp(mlir::Operation &op) {
   return mlir::isa<T>(op);
 }
 
+/// Returns the `vector.reduction` kind for the specified `scf.reduce` body, or
+/// `std::nullopt` if the reduction cannot be handled by `vector.reduction`.
 static std::optional<mlir::vector::CombiningKind>
-getReductionKind(mlir::scf::ReduceOp op) {
-  mlir::Block &body = op.getReductionOperator().front();
+getReductionKind(mlir::Block &body) {
   if (!llvm::hasSingleElement(body.without_terminator()))
     return std::nullopt;
 
@@ -140,23 +157,31 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
   if (factor <= 1)
     return std::nullopt;
 
+  /// Only step==1 is supported for now.
   if (!mlir::isConstantIntValue(loop.getStep()[dim], 1))
     return std::nullopt;
 
   unsigned count = 0;
   bool masked = true;
 
-  for (mlir::Operation &op : loop.getBody()->without_terminator()) {
-    if (auto reduce = mlir::dyn_cast<mlir::scf::ReduceOp>(op)) {
-      if (!getReductionKind(reduce))
-        masked = false;
+  /// Check if `scf.reduce` can be handled by `vector.reduce`.
+  /// If not, we can still vectorize the loop, but we cannot use masked
+  /// vectorization.
+  auto reduce =
+      mlir::cast<mlir::scf::ReduceOp>(loop.getBody()->getTerminator());
+  for (mlir::Region &reg : reduce.getReductions()) {
+    if (!getReductionKind(reg.front()))
+      masked = false;
 
-      continue;
-    }
+    continue;
+  }
 
+  for (mlir::Operation &op : loop.getBody()->without_terminator()) {
+    /// Ops with nested regions are not supported yet.
     if (op.getNumRegions() > 0)
       return std::nullopt;
 
+    /// Check mem ops.
     if (auto w = cavTriviallyVectorizeMemOp(loop, dim, op)) {
       auto newFactor = vectorBitwidth / *w;
       if (newFactor > 1) {
@@ -166,6 +191,8 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
       continue;
     }
 
+    /// If we encounter an op that cannot be vectorized, we can replicate it and
+    /// still vectorize other ops, but we cannot use masked vectorization.
     if (!isSupportedVectorOp(op)) {
       masked = false;
       continue;
@@ -184,12 +211,14 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
     ++count;
   }
 
+  /// No ops to vectorize.
   if (count == 0)
     return std::nullopt;
 
   return SCFVectorizeInfo{dim, factor, count, masked};
 }
 
+/// Get fastmath flags if ops support them or default (none).
 static mlir::arith::FastMathFlags getFMF(mlir::Operation &op) {
   if (auto fmf = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(op))
     return fmf.getFastMathFlagsAttr().getValue();
@@ -226,6 +255,8 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
   mlir::Value count =
       builder.create<mlir::arith::SubIOp>(loc, origUpper, origLower);
   mlir::Value newCount;
+
+  // Compute new loop count, ceildiv if masked, floordiv otherwise.
   if (masked) {
     mlir::Value incCount =
         builder.create<mlir::arith::AddIOp>(loc, count, factorVal);
@@ -240,6 +271,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
   lower[dim] = zero;
   upper[dim] = newCount;
 
+  // Vectorized loop.
   auto newLoop = builder.create<mlir::scf::ParallelOp>(loc, lower, upper, step,
                                                        loop.getInitVals());
   auto newIndexVar = newLoop.getInductionVars()[dim];
@@ -256,10 +288,14 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     return builder.create<mlir::ub::PoisonOp>(loc, vecType, nullptr);
   };
 
+  // Get vector value in new loop for provided `orig` value in source loop.
   auto getVecVal = [&](mlir::Value orig) -> mlir::Value {
+    // Use cached value if present.
     if (auto mapped = mapping.lookupOrNull(orig))
       return mapped;
 
+    // Vectorized loop index, loop index is divided by factor, so for factorN
+    // vectorized index will looks like `splat(idx) + (0, 1, ..., N - 1)`
     if (orig == origIndexVar) {
       auto vecType = toVectorType(builder.getIndexType());
       llvm::SmallVector<mlir::Attribute> elems(factor);
@@ -283,9 +319,16 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     mlir::Value val = orig;
     auto origIndexVars = loop.getInductionVars();
     auto it = llvm::find(origIndexVars, orig);
+
+    // If loop index, but not on vectorized dimension, just take new loop index
+    // and splat it.
     if (it != origIndexVars.end())
       val = newLoop.getInductionVars()[it - origIndexVars.begin()];
 
+    // Values which are defined inside loop body are preemptively added to the
+    // mapper and not handled here. Values defined outside body are just
+    // splatted.
+
     auto vecType = toVectorType(type);
     mlir::Value vec = builder.create<mlir::vector::SplatOp>(loc, val, vecType);
     mapping.map(orig, vec);
@@ -293,18 +336,28 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
   };
 
   llvm::DenseMap<mlir::Value, llvm::SmallVector<mlir::Value>> unpackedVals;
+
+  // Get unpacked values for provided `orig` value in source loop.
+  // Values are returned as `ValueRange` and not as vector value.
   auto getUnpackedVals = [&](mlir::Value val) -> mlir::ValueRange {
+    // Use cached values if present.
     auto it = unpackedVals.find(val);
     if (it != unpackedVals.end())
       return it->second;
 
+    // Values which are defined inside loop body are preemptively added to the
+    // cache and not handled here.
+
     auto &ret = unpackedVals[val];
     assert(ret.empty());
     if (!isSupportedVecElem(val.getType())) {
+      // Non vectorizable value, it must be a value defined outside the loop,
+      // just replicate it.
       ret.resize(factor, val);
       return ret;
     }
 
+    // Get vector value and extract elements from it.
     auto vecVal = getVecVal(val);
     ret.resize(factor);
     for (auto i : llvm::seq(0u, factor)) {
@@ -314,6 +367,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     return ret;
   };
 
+  // Add unpacked values to the cache.
   auto setUnpackedVals = [&](mlir::Value origVal, mlir::ValueRange newVals) {
     assert(newVals.size() == factor);
     assert(unpackedVals.count(origVal) == 0);
@@ -323,6 +377,8 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     if (!isSupportedVecElem(type))
       return;
 
+    // If type is vectorizabale construct a vector add it to vector cache as
+    // well.
     auto vecType = toVectorType(type);
 
     mlir::Value vec = createPosionVec(vecType);
@@ -335,6 +391,9 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
   };
 
   mlir::Value mask;
+
+  // Contruct mask value and cache it. If not a masked mode mask is always all
+  // 1s.
   auto getMask = [&]() -> mlir::Value {
     if (mask)
       return mask;
@@ -360,6 +419,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     return !!::cavTriviallyVectorizeMemOpImpl(loop, dim, op);
   };
 
+  // Get idices for vectorized memref load/store.
   auto getMemrefVecIndices = [&](mlir::ValueRange indices) {
     scalarMapping.clear();
     scalarMapping.map(loop.getInductionVars(), newLoop.getInductionVars());
@@ -379,6 +439,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     return ret;
   };
 
+  // Check if memref access can be converted into gather/scatter.
   auto canGatherScatter = [&](auto op) {
     auto memref = op.getMemRef();
     auto memrefType = mlir::cast<mlir::MemRefType>(memref.getType());
@@ -389,6 +450,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
            memrefType.getLayout().isIdentity();
   };
 
+  // Create vectorized memref load for specified non-vectorized load.
   auto genLoad = [&](auto loadOp) {
     auto indices = getMemrefVecIndices(loadOp.getIndices());
     auto resType = toVectorType(loadOp.getResult().getType());
@@ -406,6 +468,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     mapping.map(loadOp.getResult(), vecLoad);
   };
 
+  // Create vectorized memref store for specified non-vectorized store.
   auto genStore = [&](auto storeOp) {
     auto indices = getMemrefVecIndices(storeOp.getIndices());
     auto value = getVecVal(storeOp.getValueToStore());
@@ -426,6 +489,8 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
   for (mlir::Operation &op : loop.getBody()->without_terminator()) {
     loc = op.getLoc();
     if (isSupportedVectorOp(op)) {
+      // If op can be vectorized, clone it with vectorized inputs and  update
+      // resuls to vectorized types.
       for (auto arg : op.getOperands())
         getVecVal(arg); // init mapper for op args
 
@@ -436,56 +501,8 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
       continue;
     }
 
-    if (auto reduceOp = mlir::dyn_cast<mlir::scf::ReduceOp>(op)) {
-      scalarMapping.clear();
-      auto &reduceBody = reduceOp.getReductionOperator().front();
-      assert(reduceBody.getNumArguments() == 2);
-
-      mlir::Value reduceVal;
-      if (auto redKind = getReductionKind(reduceOp)) {
-        mlir::Value redArg = getVecVal(reduceOp.getOperand());
-        if (redArg) {
-          auto neutral = mlir::arith::getNeutralElement(&reduceBody.front());
-          assert(neutral);
-          mlir::Value neutralVal =
-              builder.create<mlir::arith::ConstantOp>(loc, *neutral);
-          mlir::Value neutralVec = builder.create<mlir::vector::SplatOp>(
-              loc, neutralVal, redArg.getType());
-          auto mask = getMask();
-          redArg = builder.create<mlir::arith::SelectOp>(loc, mask, redArg,
-                                                         neutralVec);
-        }
-
-        auto fmf = getFMF(reduceBody.front());
-        reduceVal = builder.create<mlir::vector::ReductionOp>(loc, *redKind,
-                                                              redArg, fmf);
-      } else {
-        if (masked)
-          return op.emitError("Cannot vectorize op in masked mode");
-
-        auto reduceTerm =
-            mlir::cast<mlir::scf::ReduceReturnOp>(reduceBody.getTerminator());
-        auto lhs = reduceBody.getArgument(0);
-        auto rhs = reduceBody.getArgument(1);
-        auto unpacked = getUnpackedVals(reduceOp.getOperand());
-        assert(unpacked.size() == factor);
-        reduceVal = unpacked.front();
-        for (auto i : llvm::seq(1u, factor)) {
-          mlir::Value val = unpacked[i];
-          scalarMapping.map(lhs, reduceVal);
-          scalarMapping.map(rhs, val);
-          for (auto &redOp : reduceBody.without_terminator())
-            builder.clone(redOp, scalarMapping);
-
-          reduceVal = scalarMapping.lookupOrDefault(reduceTerm.getResult());
-        }
-      }
-      scalarMapping.clear();
-      scalarMapping.map(reduceOp.getOperand(), reduceVal);
-      builder.clone(op, scalarMapping);
-      continue;
-    }
-
+    // Vectorize memref load/store ops, vector load/store are preffered over
+    // gather/scatter.
     if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op)) {
       if (canTriviallyVectorizeMemOp(loadOp)) {
         genLoad(loadOp);
@@ -558,6 +575,70 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     }
   }
 
+  // Vectorize `scf.reduce` op.
+  auto reduceOp =
+      mlir::cast<mlir::scf::ReduceOp>(loop.getBody()->getTerminator());
+  llvm::SmallVector<mlir::Value> reduceVals;
+  reduceVals.reserve(reduceOp.getNumOperands());
+
+  for (auto &&[body, arg] :
+       llvm::zip(reduceOp.getReductions(), reduceOp.getOperands())) {
+    scalarMapping.clear();
+    mlir::Block &reduceBody = body.front();
+    assert(reduceBody.getNumArguments() == 2);
+
+    mlir::Value reduceVal;
+    if (auto redKind = getReductionKind(reduceBody)) {
+      // Generate `vector.reduce` if possible.
+      mlir::Value redArg = getVecVal(arg);
+      if (redArg) {
+        auto neutral = mlir::arith::getNeutralElement(&reduceBody.front());
+        assert(neutral);
+        mlir::Value neutralVal =
+            builder.create<mlir::arith::ConstantOp>(loc, *neutral);
+        mlir::Value neutralVec = builder.create<mlir::vector::SplatOp>(
+            loc, neutralVal, redArg.getType());
+        auto mask = getMask();
+        redArg = builder.create<mlir::arith::SelectOp>(loc, mask, redArg,
+                                                       neutralVec);
+      }
+
+      auto fmf = getFMF(reduceBody.front());
+      reduceVal =
+          builder.create<mlir::vector::ReductionOp>(loc, *redKind, redArg, fmf);
+    } else {
+      if (masked)
+        return reduceOp.emitError("Cannot vectorize reduce op in masked mode");
+
+      // If `vector.reduce` cannot be used, unpack values and reduce them
+      // individually.
+
+      auto reduceTerm =
+          mlir::cast<mlir::scf::ReduceReturnOp>(reduceBody.getTerminator());
+      auto lhs = reduceBody.getArgument(0);
+      auto rhs = reduceBody.getArgument(1);
+      auto unpacked = getUnpackedVals(arg);
+      assert(unpacked.size() == factor);
+      reduceVal = unpacked.front();
+      for (auto i : llvm::seq(1u, factor)) {
+        mlir::Value val = unpacked[i];
+        scalarMapping.map(lhs, reduceVal);
+        scalarMapping.map(rhs, val);
+        for (auto &redOp : reduceBody.without_terminator())
+          builder.clone(redOp, scalarMapping);
+
+        reduceVal = scalarMapping.lookupOrDefault(reduceTerm.getResult());
+      }
+    }
+    reduceVals.emplace_back(reduceVal);
+  }
+
+  // Clone `scf.reduce` op to reduce across loop iterations.
+  if (!reduceVals.empty())
+    builder.clone(*reduceOp)->setOperands(reduceVals);
+
+  // If in masked mode remove old loop, otherwise update loop bounds to
+  // repurpose it for handling remaining values.
   if (masked) {
     loop->replaceAllUsesWith(newLoop.getResults());
     loop->erase();
@@ -576,14 +657,12 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
   return mlir::success();
 }
 
-llvm::StringRef getVectorLengthName() { return "numba.vector_length"; }
-
 static std::optional<unsigned> getVectorLength(mlir::Operation *op) {
   auto func = op->getParentOfType<mlir::FunctionOpInterface>();
   if (!func)
     return std::nullopt;
 
-  auto attr = func->getAttrOfType<mlir::IntegerAttr>(getVectorLengthName());
+  auto attr = func->getAttrOfType<mlir::IntegerAttr>("mlir.vector_length");
   if (!attr)
     return std::nullopt;
 
@@ -599,7 +678,8 @@ struct SCFVectorizePass
     : public mlir::PassWrapper<SCFVectorizePass, mlir::OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFVectorizePass)
 
-  void getDependentDialects(mlir::DialectRegistry &registry) const override {
+  virtual void
+  getDependentDialects(mlir::DialectRegistry &registry) const override {
     registry.insert<mlir::arith::ArithDialect>();
     registry.insert<mlir::scf::SCFDialect>();
     registry.insert<mlir::ub::UBDialect>();
@@ -611,6 +691,8 @@ struct SCFVectorizePass
         std::pair<mlir::scf::ParallelOp, mlir::SCFVectorizeParams>>
         toVectorize;
 
+    // Simple heuristic: total number of elements processed by vector ops, but
+    // prefer masked mode over non-masked.
     auto getBenefit = [](const mlir::SCFVectorizeInfo &info) {
       return info.factor * info.count * (int(info.masked) + 1);
     };
@@ -658,4 +740,4 @@ struct SCFVectorizePass
 
 std::unique_ptr<mlir::Pass> mlir::createSCFVectorizePass() {
   return std::make_unique<SCFVectorizePass>();
-}
+}
\ No newline at end of file

>From cbebaccdda8428163c4081e59498e748f5dc4c89 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 2 Jun 2024 17:07:51 +0200
Subject: [PATCH 03/10] working on pass

---
 mlir/include/mlir/Transforms/SCFVectorize.h   |  20 +-
 mlir/lib/Transforms/SCFVectorize.cpp          | 449 +++++++-----------
 mlir/test/Transforms/test-scf-vectorize.mlir  | 272 +++++++++++
 mlir/test/lib/Transforms/CMakeLists.txt       |   1 +
 mlir/test/lib/Transforms/TestSCFVectorize.cpp | 110 +++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |  46 +-
 6 files changed, 592 insertions(+), 306 deletions(-)
 create mode 100644 mlir/test/Transforms/test-scf-vectorize.mlir
 create mode 100644 mlir/test/lib/Transforms/TestSCFVectorize.cpp

diff --git a/mlir/include/mlir/Transforms/SCFVectorize.h b/mlir/include/mlir/Transforms/SCFVectorize.h
index 93a7864b976ec..d2a5e3085ae37 100644
--- a/mlir/include/mlir/Transforms/SCFVectorize.h
+++ b/mlir/include/mlir/Transforms/SCFVectorize.h
@@ -9,12 +9,10 @@
 #ifndef MLIR_TRANSFORMS_SCFVECTORIZE_H_
 #define MLIR_TRANSFORMS_SCFVECTORIZE_H_
 
-#include <memory>
 #include <optional>
 
 namespace mlir {
-class OpBuilder;
-class Pass;
+class DataLayout;
 struct LogicalResult;
 namespace scf {
 class ParallelOp;
@@ -43,9 +41,9 @@ struct SCFVectorizeInfo {
 /// specified dimension.
 ///
 /// `vectorBitwidth` - maximum vector size, in bits.
-std::optional<SCFVectorizeInfo> getLoopVectorizeInfo(mlir::scf::ParallelOp loop,
-                                                     unsigned dim,
-                                                     unsigned vectorBitwidth);
+std::optional<SCFVectorizeInfo>
+getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
+                     unsigned vectorBitwidth, const DataLayout *DL = nullptr);
 
 /// Vectorization params
 struct SCFVectorizeParams {
@@ -64,11 +62,9 @@ struct SCFVectorizeParams {
 /// If `masked` is `true` and loop bound is not divisible by `factor`, instead
 /// of generating second loop to process remainig iterations, extend loop count
 /// and generate masked vector ops to handle out-of bounds memory accesses.
-mlir::LogicalResult vectorizeLoop(mlir::OpBuilder &builder,
-                                  mlir::scf::ParallelOp loop,
-                                  const SCFVectorizeParams &params);
-
-std::unique_ptr<mlir::Pass> createSCFVectorizePass();
+mlir::LogicalResult vectorizeLoop(mlir::scf::ParallelOp loop,
+                                  const SCFVectorizeParams &params,
+                                  const DataLayout *DL = nullptr);
 } // namespace mlir
 
-#endif // MLIR_TRANSFORMS_SCFVECTORIZE_H_
\ No newline at end of file
+#endif // MLIR_TRANSFORMS_SCFVECTORIZE_H_
diff --git a/mlir/lib/Transforms/SCFVectorize.cpp b/mlir/lib/Transforms/SCFVectorize.cpp
index 13a9eca9cd2d3..29e184e584a56 100644
--- a/mlir/lib/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Transforms/SCFVectorize.cpp
@@ -1,4 +1,4 @@
-//===- ControlFlowSink.cpp - Code to perform control-flow sinking ---------===//
+//===- SCFVectorize.cpp - SCF vectorization utilities ---------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -9,27 +9,25 @@
 #include "mlir/Transforms/SCFVectorize.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // getCombinerOpKind
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/IRMapping.h"
-#include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Pass/Pass.h"
-#include <mlir/Dialect/Arith/IR/Arith.h>
-#include <mlir/Dialect/MemRef/IR/MemRef.h>
-#include <mlir/Dialect/SCF/IR/SCF.h>
-#include <mlir/Dialect/UB/IR/UBOps.h>
-#include <mlir/Dialect/Vector/IR/VectorOps.h>
-#include <mlir/IR/IRMapping.h>
-#include <mlir/Interfaces/FunctionInterfaces.h>
-#include <mlir/Pass/Pass.h>
+
+using namespace mlir;
+
+static bool isSupportedVecElem(Type type) { return type.isIntOrIndexOrFloat(); }
 
 /// Return type bitwidth for vectorization purposes or 0 if type cannot be
 /// vectorized.
-static unsigned getTypeBitWidth(mlir::Type type) {
-  if (mlir::isa<mlir::IndexType>(type))
-    return 64; // TODO: unhardcode
+static unsigned getTypeBitWidth(Type type, const DataLayout *DL) {
+  if (!isSupportedVecElem(type))
+    return 0;
+
+  if (DL)
+    return DL->getTypeSizeInBits(type);
 
   if (type.isIntOrFloat())
     return type.getIntOrFloatBitWidth();
@@ -37,28 +35,24 @@ static unsigned getTypeBitWidth(mlir::Type type) {
   return 0;
 }
 
-static unsigned getArgsTypeWidth(mlir::Operation &op) {
+static unsigned getArgsTypeWidth(Operation &op, const DataLayout *DL) {
   unsigned ret = 0;
   for (auto arg : op.getOperands())
-    ret = std::max(ret, getTypeBitWidth(arg.getType()));
+    ret = std::max(ret, getTypeBitWidth(arg.getType(), DL));
 
   for (auto res : op.getResults())
-    ret = std::max(ret, getTypeBitWidth(res.getType()));
+    ret = std::max(ret, getTypeBitWidth(res.getType(), DL));
 
   return ret;
 }
 
-static bool isSupportedVectorOp(mlir::Operation &op) {
-  return op.hasTrait<mlir::OpTrait::Vectorizable>();
-}
-
-static bool isSupportedVecElem(mlir::Type type) {
-  return type.isIntOrIndexOrFloat();
+static bool isSupportedVectorOp(Operation &op) {
+  return op.hasTrait<OpTrait::Vectorizable>();
 }
 
 /// Check if one `ValueRange` is permutation of another, i.e. contains same
 /// values, potentially in different order.
-static bool isRangePermutation(mlir::ValueRange val1, mlir::ValueRange val2) {
+static bool isRangePermutation(ValueRange val1, ValueRange val2) {
   if (val1.size() != val2.size())
     return false;
 
@@ -72,13 +66,13 @@ static bool isRangePermutation(mlir::ValueRange val1, mlir::ValueRange val2) {
 
 template <typename Op>
 static std::optional<unsigned>
-cavTriviallyVectorizeMemOpImpl(mlir::scf::ParallelOp loop, unsigned dim,
-                               Op memOp) {
+cavTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
+                               const DataLayout *DL) {
   auto loopIndexVars = loop.getInductionVars();
   assert(dim < loopIndexVars.size());
   auto memref = memOp.getMemRef();
-  auto type = mlir::cast<mlir::MemRefType>(memref.getType());
-  auto width = getTypeBitWidth(type.getElementType());
+  auto type = cast<MemRefType>(memref.getType());
+  auto width = getTypeBitWidth(type.getElementType(), DL);
   if (width == 0)
     return std::nullopt;
 
@@ -91,7 +85,7 @@ cavTriviallyVectorizeMemOpImpl(mlir::scf::ParallelOp loop, unsigned dim,
   if (memOp.getIndices().back() != loopIndexVars[dim])
     return std::nullopt;
 
-  mlir::DominanceInfo dom;
+  DominanceInfo dom;
   if (!dom.properlyDominates(memref, loop))
     return std::nullopt;
 
@@ -103,54 +97,69 @@ cavTriviallyVectorizeMemOpImpl(mlir::scf::ParallelOp loop, unsigned dim,
 /// Returns memref element bitwidth or `std::nullopt` if access cannot be
 /// vectorized.
 static std::optional<unsigned>
-cavTriviallyVectorizeMemOp(mlir::scf::ParallelOp loop, unsigned dim,
-                           mlir::Operation &op) {
+cavTriviallyVectorizeMemOp(scf::ParallelOp loop, unsigned dim, Operation &op,
+                           const DataLayout *DL) {
   assert(dim < loop.getInductionVars().size());
-  if (auto storeOp = mlir::dyn_cast<mlir::memref::StoreOp>(op))
-    return cavTriviallyVectorizeMemOpImpl(loop, dim, storeOp);
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return cavTriviallyVectorizeMemOpImpl(loop, dim, storeOp, DL);
+
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return cavTriviallyVectorizeMemOpImpl(loop, dim, loadOp, DL);
+
+  return std::nullopt;
+}
+
+/// Returns element bitwidth if `op` can become a gather/scatter, else nullopt.
+template <typename Op>
+static std::optional<unsigned> canGatherScatterImpl(scf::ParallelOp loop, Op op,
+                                                    const DataLayout *DL) {
+  auto memref = op.getMemRef();
+  auto memrefType = cast<MemRefType>(memref.getType());
+  auto width = getTypeBitWidth(memrefType.getElementType(), DL);
+  // Return the width, not a bool (a bool would become an engaged optional).
+  if (width == 0 || !DominanceInfo().properlyDominates(memref, loop) ||
+      op.getIndices().size() != 1 || !memrefType.getLayout().isIdentity())
+    return std::nullopt;
+  return width;
+}
 
-  if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op))
-    return cavTriviallyVectorizeMemOpImpl(loop, dim, loadOp);
+/// Check if memref load/store access can be converted into gather/scatter.
+///
+/// Returns memref element bitwidth, or `std::nullopt` if `op` is not a
+/// `memref.load`/`memref.store` or its access cannot be vectorized.
+static std::optional<unsigned>
+canGatherScatter(scf::ParallelOp loop, Operation &op, const DataLayout *DL) {
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return canGatherScatterImpl(loop, storeOp, DL);
+
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return canGatherScatterImpl(loop, loadOp, DL);
 
   return std::nullopt;
 }
 
-template <typename T>
-static bool isOp(mlir::Operation &op) {
-  return mlir::isa<T>(op);
+static std::optional<unsigned> cenVectorizeMemrefOp(scf::ParallelOp loop,
+                                                    unsigned dim, Operation &op,
+                                                    const DataLayout *DL) {
+  if (auto w = cavTriviallyVectorizeMemOp(loop, dim, op, DL))
+    return w;
+
+  return canGatherScatter(loop, op, DL);
 }
 
 /// Returns `vector.reduce` kind for specified `scf.parallel` reduce op ot
 /// `std::nullopt` if reduction cannot be handled by `vector.reduce`.
-static std::optional<mlir::vector::CombiningKind>
-getReductionKind(mlir::Block &body) {
+static std::optional<vector::CombiningKind> getReductionKind(Block &body) {
   if (!llvm::hasSingleElement(body.without_terminator()))
     return std::nullopt;
 
-  mlir::Operation &redOp = body.front();
-
-  using fptr_t = bool (*)(mlir::Operation &);
-  using CC = mlir::vector::CombiningKind;
-  const std::pair<fptr_t, CC> handlers[] = {
-      // clang-format off
-      {&isOp<mlir::arith::AddIOp>, CC::ADD},
-      {&isOp<mlir::arith::AddFOp>, CC::ADD},
-      {&isOp<mlir::arith::MulIOp>, CC::MUL},
-      {&isOp<mlir::arith::MulFOp>, CC::MUL},
-      // clang-format on
-  };
-
-  for (auto &&[handler, cc] : handlers) {
-    if (handler(redOp))
-      return cc;
-  }
-
-  return std::nullopt;
+  // TODO: Move getCombinerOpKind to vector dialect.
+  return linalg::getCombinerOpKind(&body.front());
 }
 
-std::optional<mlir::SCFVectorizeInfo>
-mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
-                           unsigned vectorBitwidth) {
+std::optional<SCFVectorizeInfo>
+mlir::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim,
+                           unsigned vectorBitwidth, const DataLayout *DL) {
   assert(dim < loop.getStep().size());
   assert(vectorBitwidth > 0);
   unsigned factor = vectorBitwidth / 8;
@@ -158,7 +167,7 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
     return std::nullopt;
 
   /// Only step==1 is supported for now.
-  if (!mlir::isConstantIntValue(loop.getStep()[dim], 1))
+  if (!isConstantIntValue(loop.getStep()[dim], 1))
     return std::nullopt;
 
   unsigned count = 0;
@@ -167,22 +176,21 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
   /// Check if `scf.reduce` can be handled by `vector.reduce`.
   /// If not we still can vectorize the loop but we cannot use masked
   /// vectorize.
-  auto reduce =
-      mlir::cast<mlir::scf::ReduceOp>(loop.getBody()->getTerminator());
-  for (mlir::Region &reg : reduce.getReductions()) {
+  auto reduce = cast<scf::ReduceOp>(loop.getBody()->getTerminator());
+  for (Region &reg : reduce.getReductions()) {
     if (!getReductionKind(reg.front()))
       masked = false;
 
     continue;
   }
 
-  for (mlir::Operation &op : loop.getBody()->without_terminator()) {
+  for (Operation &op : loop.getBody()->without_terminator()) {
     /// Ops with nested regions are not supported yet.
     if (op.getNumRegions() > 0)
       return std::nullopt;
 
     /// Check mem ops.
-    if (auto w = cavTriviallyVectorizeMemOp(loop, dim, op)) {
+    if (auto w = cenVectorizeMemrefOp(loop, dim, op, DL)) {
       auto newFactor = vectorBitwidth / *w;
       if (newFactor > 1) {
         factor = std::min(factor, newFactor);
@@ -198,7 +206,7 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
       continue;
     }
 
-    auto width = getArgsTypeWidth(op);
+    auto width = getArgsTypeWidth(op, DL);
     if (width == 0)
       return std::nullopt;
 
@@ -219,26 +227,24 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
 }
 
 /// Get fastmath flags if ops support them or default (none).
-static mlir::arith::FastMathFlags getFMF(mlir::Operation &op) {
-  if (auto fmf = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(op))
+static arith::FastMathFlags getFMF(Operation &op) {
+  if (auto fmf = dyn_cast<arith::ArithFastMathInterface>(op))
     return fmf.getFastMathFlagsAttr().getValue();
 
-  return mlir::arith::FastMathFlags::none;
+  return arith::FastMathFlags::none;
 }
 
-mlir::LogicalResult
-mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
-                    const mlir::SCFVectorizeParams &params) {
+LogicalResult mlir::vectorizeLoop(scf::ParallelOp loop,
+                                  const SCFVectorizeParams &params,
+                                  const DataLayout *DL) {
   auto dim = params.dim;
   auto factor = params.factor;
   auto masked = params.masked;
   assert(dim < loop.getStep().size());
   assert(factor > 1);
-  assert(mlir::isConstantIntValue(loop.getStep()[dim], 1));
-
-  mlir::OpBuilder::InsertionGuard g(builder);
-  builder.setInsertionPoint(loop);
+  assert(isConstantIntValue(loop.getStep()[dim], 1));
 
+  OpBuilder builder(loop);
   auto lower = llvm::to_vector(loop.getLowerBound());
   auto upper = llvm::to_vector(loop.getUpperBound());
   auto step = llvm::to_vector(loop.getStep());
@@ -247,49 +253,53 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
 
   auto origIndexVar = loop.getInductionVars()[dim];
 
-  mlir::Value factorVal =
-      builder.create<mlir::arith::ConstantIndexOp>(loc, factor);
+  Value factorVal = builder.create<arith::ConstantIndexOp>(loc, factor);
 
   auto origLower = lower[dim];
   auto origUpper = upper[dim];
-  mlir::Value count =
-      builder.create<mlir::arith::SubIOp>(loc, origUpper, origLower);
-  mlir::Value newCount;
+  Value count = builder.createOrFold<arith::SubIOp>(loc, origUpper, origLower);
+  Value newCount;
 
   // Compute new loop count, ceildiv if masked, floordiv otherwise.
   if (masked) {
-    mlir::Value incCount =
-        builder.create<mlir::arith::AddIOp>(loc, count, factorVal);
-    mlir::Value one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
-    incCount = builder.create<mlir::arith::SubIOp>(loc, incCount, one);
-    newCount = builder.create<mlir::arith::DivSIOp>(loc, incCount, factorVal);
+    newCount = builder.createOrFold<arith::CeilDivSIOp>(loc, count, factorVal);
   } else {
-    newCount = builder.create<mlir::arith::DivSIOp>(loc, count, factorVal);
+    newCount = builder.createOrFold<arith::DivSIOp>(loc, count, factorVal);
   }
 
-  mlir::Value zero = builder.create<mlir::arith::ConstantIndexOp>(loc, 0);
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   lower[dim] = zero;
   upper[dim] = newCount;
 
   // Vectorized loop.
-  auto newLoop = builder.create<mlir::scf::ParallelOp>(loc, lower, upper, step,
-                                                       loop.getInitVals());
+  auto newLoop = builder.create<scf::ParallelOp>(loc, lower, upper, step,
+                                                 loop.getInitVals());
   auto newIndexVar = newLoop.getInductionVars()[dim];
 
-  auto toVectorType = [&](mlir::Type elemType) -> mlir::VectorType {
+  auto toVectorType = [&](Type elemType) -> VectorType {
     int64_t f = factor;
-    return mlir::VectorType::get(f, elemType);
+    return VectorType::get(f, elemType);
+  };
+
+  IRMapping mapping;
+  IRMapping scalarMapping;
+
+  auto createPosionVec = [&](VectorType vecType) -> Value {
+    return builder.create<ub::PoisonOp>(loc, vecType, nullptr);
   };
 
-  mlir::IRMapping mapping;
-  mlir::IRMapping scalarMapping;
+  Value indexVarMult;
+  auto getrIndexVarMult = [&]() -> Value {
+    if (indexVarMult)
+      return indexVarMult;
 
-  auto createPosionVec = [&](mlir::VectorType vecType) -> mlir::Value {
-    return builder.create<mlir::ub::PoisonOp>(loc, vecType, nullptr);
+    indexVarMult =
+        builder.createOrFold<arith::MulIOp>(loc, newIndexVar, factorVal);
+    return indexVarMult;
   };
 
   // Get vector value in new loop for provided `orig` value in source loop.
-  auto getVecVal = [&](mlir::Value orig) -> mlir::Value {
+  auto getVecVal = [&](Value orig) -> Value {
     // Use cached value if present.
     if (auto mapped = mapping.lookupOrNull(orig))
       return mapped;
@@ -298,25 +308,23 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     // vectorized index will looks like `splat(idx) + (0, 1, ..., N - 1)`
     if (orig == origIndexVar) {
       auto vecType = toVectorType(builder.getIndexType());
-      llvm::SmallVector<mlir::Attribute> elems(factor);
+      llvm::SmallVector<Attribute> elems(factor);
       for (auto i : llvm::seq(0u, factor))
         elems[i] = builder.getIndexAttr(i);
-      auto attr = mlir::DenseElementsAttr::get(vecType, elems);
-      mlir::Value vec =
-          builder.create<mlir::arith::ConstantOp>(loc, vecType, attr);
-
-      mlir::Value idx =
-          builder.create<mlir::arith::MulIOp>(loc, newIndexVar, factorVal);
-      idx = builder.create<mlir::arith::AddIOp>(loc, idx, origLower);
-      idx = builder.create<mlir::vector::SplatOp>(loc, idx, vecType);
-      vec = builder.create<mlir::arith::AddIOp>(loc, idx, vec);
+      auto attr = DenseElementsAttr::get(vecType, elems);
+      Value vec = builder.create<arith::ConstantOp>(loc, vecType, attr);
+
+      Value idx = getrIndexVarMult();
+      idx = builder.createOrFold<arith::AddIOp>(loc, idx, origLower);
+      idx = builder.create<vector::SplatOp>(loc, idx, vecType);
+      vec = builder.createOrFold<arith::AddIOp>(loc, idx, vec);
       mapping.map(orig, vec);
       return vec;
     }
     auto type = orig.getType();
     assert(isSupportedVecElem(type));
 
-    mlir::Value val = orig;
+    Value val = orig;
     auto origIndexVars = loop.getInductionVars();
     auto it = llvm::find(origIndexVars, orig);
 
@@ -330,16 +338,16 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     // splatted.
 
     auto vecType = toVectorType(type);
-    mlir::Value vec = builder.create<mlir::vector::SplatOp>(loc, val, vecType);
+    Value vec = builder.create<vector::SplatOp>(loc, val, vecType);
     mapping.map(orig, vec);
     return vec;
   };
 
-  llvm::DenseMap<mlir::Value, llvm::SmallVector<mlir::Value>> unpackedVals;
+  llvm::DenseMap<Value, llvm::SmallVector<Value>> unpackedVals;
 
   // Get unpacked values for provided `orig` value in source loop.
   // Values are returned as `ValueRange` and not as vector value.
-  auto getUnpackedVals = [&](mlir::Value val) -> mlir::ValueRange {
+  auto getUnpackedVals = [&](Value val) -> ValueRange {
     // Use cached values if present.
     auto it = unpackedVals.find(val);
     if (it != unpackedVals.end())
@@ -361,14 +369,14 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     auto vecVal = getVecVal(val);
     ret.resize(factor);
     for (auto i : llvm::seq(0u, factor)) {
-      mlir::Value idx = builder.create<mlir::arith::ConstantIndexOp>(loc, i);
-      ret[i] = builder.create<mlir::vector::ExtractElementOp>(loc, vecVal, idx);
+      Value idx = builder.create<arith::ConstantIndexOp>(loc, i);
+      ret[i] = builder.create<vector::ExtractElementOp>(loc, vecVal, idx);
     }
     return ret;
   };
 
   // Add unpacked values to the cache.
-  auto setUnpackedVals = [&](mlir::Value origVal, mlir::ValueRange newVals) {
+  auto setUnpackedVals = [&](Value origVal, ValueRange newVals) {
     assert(newVals.size() == factor);
     assert(unpackedVals.count(origVal) == 0);
     unpackedVals[origVal].append(newVals.begin(), newVals.end());
@@ -381,55 +389,53 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     // well.
     auto vecType = toVectorType(type);
 
-    mlir::Value vec = createPosionVec(vecType);
+    Value vec = createPosionVec(vecType);
     for (auto i : llvm::seq(0u, factor)) {
-      mlir::Value idx = builder.create<mlir::arith::ConstantIndexOp>(loc, i);
-      vec = builder.create<mlir::vector::InsertElementOp>(loc, newVals[i], vec,
-                                                          idx);
+      Value idx = builder.create<arith::ConstantIndexOp>(loc, i);
+      vec = builder.create<vector::InsertElementOp>(loc, newVals[i], vec, idx);
     }
     mapping.map(origVal, vec);
   };
 
-  mlir::Value mask;
+  Value mask;
 
   // Contruct mask value and cache it. If not a masked mode mask is always all
   // 1s.
-  auto getMask = [&]() -> mlir::Value {
+  auto getMask = [&]() -> Value {
     if (mask)
       return mask;
 
-    mlir::OpFoldResult maskSize;
+    OpFoldResult maskSize;
     if (masked) {
-      mlir::Value size =
-          builder.create<mlir::arith::MulIOp>(loc, factorVal, newIndexVar);
-      maskSize =
-          builder.create<mlir::arith::SubIOp>(loc, count, size).getResult();
+      Value size = getrIndexVarMult();
+      maskSize = builder.createOrFold<arith::SubIOp>(loc, count, size);
     } else {
       maskSize = builder.getIndexAttr(factor);
     }
     auto vecType = toVectorType(builder.getI1Type());
-    mask = builder.create<mlir::vector::CreateMaskOp>(loc, vecType, maskSize);
+    mask = builder.create<vector::CreateMaskOp>(loc, vecType, maskSize);
 
     return mask;
   };
 
-  mlir::DominanceInfo dom;
-
   auto canTriviallyVectorizeMemOp = [&](auto op) -> bool {
-    return !!::cavTriviallyVectorizeMemOpImpl(loop, dim, op);
+    return !!::cavTriviallyVectorizeMemOpImpl(loop, dim, op, DL);
+  };
+
+  auto canGatherScatter = [&](auto op) {
+    return !!::canGatherScatterImpl(loop, op, DL);
   };
 
   // Get idices for vectorized memref load/store.
-  auto getMemrefVecIndices = [&](mlir::ValueRange indices) {
+  auto getMemrefVecIndices = [&](ValueRange indices) {
     scalarMapping.clear();
     scalarMapping.map(loop.getInductionVars(), newLoop.getInductionVars());
 
-    llvm::SmallVector<mlir::Value> ret(indices.size());
+    llvm::SmallVector<Value> ret(indices.size());
     for (auto &&[i, val] : llvm::enumerate(indices)) {
       if (val == origIndexVar) {
-        mlir::Value idx =
-            builder.create<mlir::arith::MulIOp>(loc, newIndexVar, factorVal);
-        idx = builder.create<mlir::arith::AddIOp>(loc, idx, origLower);
+        Value idx = getrIndexVarMult();
+        idx = builder.createOrFold<arith::AddIOp>(loc, idx, origLower);
         ret[i] = idx;
         continue;
       }
@@ -439,31 +445,19 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     return ret;
   };
 
-  // Check if memref access can be converted into gather/scatter.
-  auto canGatherScatter = [&](auto op) {
-    auto memref = op.getMemRef();
-    auto memrefType = mlir::cast<mlir::MemRefType>(memref.getType());
-    if (!isSupportedVecElem(memrefType.getElementType()))
-      return false;
-
-    return dom.properlyDominates(memref, loop) && op.getIndices().size() == 1 &&
-           memrefType.getLayout().isIdentity();
-  };
-
   // Create vectorized memref load for specified non-vectorized load.
   auto genLoad = [&](auto loadOp) {
     auto indices = getMemrefVecIndices(loadOp.getIndices());
     auto resType = toVectorType(loadOp.getResult().getType());
     auto memref = loadOp.getMemRef();
-    mlir::Value vecLoad;
+    Value vecLoad;
     if (masked) {
       auto mask = getMask();
       auto init = createPosionVec(resType);
-      vecLoad = builder.create<mlir::vector::MaskedLoadOp>(loc, resType, memref,
-                                                           indices, mask, init);
+      vecLoad = builder.create<vector::MaskedLoadOp>(loc, resType, memref,
+                                                     indices, mask, init);
     } else {
-      vecLoad =
-          builder.create<mlir::vector::LoadOp>(loc, resType, memref, indices);
+      vecLoad = builder.create<vector::LoadOp>(loc, resType, memref, indices);
     }
     mapping.map(loadOp.getResult(), vecLoad);
   };
@@ -475,18 +469,17 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     auto memref = storeOp.getMemRef();
     if (masked) {
       auto mask = getMask();
-      builder.create<mlir::vector::MaskedStoreOp>(loc, memref, indices, mask,
-                                                  value);
+      builder.create<vector::MaskedStoreOp>(loc, memref, indices, mask, value);
     } else {
-      builder.create<mlir::vector::StoreOp>(loc, value, memref, indices);
+      builder.create<vector::StoreOp>(loc, value, memref, indices);
     }
   };
 
-  llvm::SmallVector<mlir::Value> duplicatedArgs;
-  llvm::SmallVector<mlir::Value> duplicatedResults;
+  llvm::SmallVector<Value> duplicatedArgs;
+  llvm::SmallVector<Value> duplicatedResults;
 
   builder.setInsertionPointToStart(newLoop.getBody());
-  for (mlir::Operation &op : loop.getBody()->without_terminator()) {
+  for (Operation &op : loop.getBody()->without_terminator()) {
     loc = op.getLoc();
     if (isSupportedVectorOp(op)) {
       // If op can be vectorized, clone it with vectorized inputs and  update
@@ -503,7 +496,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
 
     // Vectorize memref load/store ops, vector load/store are preffered over
     // gather/scatter.
-    if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op)) {
+    if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
       if (canTriviallyVectorizeMemOp(loadOp)) {
         genLoad(loadOp);
         continue;
@@ -515,14 +508,14 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
         auto indexVec = getVecVal(loadOp.getIndices()[0]);
         auto init = createPosionVec(resType);
 
-        auto gather = builder.create<mlir::vector::GatherOp>(
+        auto gather = builder.create<vector::GatherOp>(
             loc, resType, memref, zero, indexVec, mask, init);
         mapping.map(loadOp.getResult(), gather.getResult());
         continue;
       }
     }
 
-    if (auto storeOp = mlir::dyn_cast<mlir::memref::StoreOp>(op)) {
+    if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
       if (canTriviallyVectorizeMemOp(storeOp)) {
         genStore(storeOp);
         continue;
@@ -533,8 +526,9 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
         auto mask = getMask();
         auto indexVec = getVecVal(storeOp.getIndices()[0]);
 
-        builder.create<mlir::vector::ScatterOp>(loc, memref, zero, indexVec,
-                                                mask, value);
+        builder.create<vector::ScatterOp>(loc, memref, zero, indexVec, mask,
+                                          value);
+        continue;
       }
     }
 
@@ -557,7 +551,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     }
 
     for (auto i : llvm::seq(0u, factor)) {
-      auto args = mlir::ValueRange(duplicatedArgs)
+      auto args = ValueRange(duplicatedArgs)
                       .drop_front(numArgs * i)
                       .take_front(numArgs);
       scalarMapping.map(op.getOperands(), args);
@@ -568,7 +562,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     }
 
     for (auto i : llvm::seq(0u, numResults)) {
-      auto results = mlir::ValueRange(duplicatedResults)
+      auto results = ValueRange(duplicatedResults)
                          .drop_front(factor * i)
                          .take_front(factor);
       setUnpackedVals(op.getResult(i), results);
@@ -576,36 +570,33 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
   }
 
   // Vectorize `scf.reduce` op.
-  auto reduceOp =
-      mlir::cast<mlir::scf::ReduceOp>(loop.getBody()->getTerminator());
-  llvm::SmallVector<mlir::Value> reduceVals;
+  auto reduceOp = cast<scf::ReduceOp>(loop.getBody()->getTerminator());
+  llvm::SmallVector<Value> reduceVals;
   reduceVals.reserve(reduceOp.getNumOperands());
 
   for (auto &&[body, arg] :
        llvm::zip(reduceOp.getReductions(), reduceOp.getOperands())) {
     scalarMapping.clear();
-    mlir::Block &reduceBody = body.front();
+    Block &reduceBody = body.front();
     assert(reduceBody.getNumArguments() == 2);
 
-    mlir::Value reduceVal;
+    Value reduceVal;
     if (auto redKind = getReductionKind(reduceBody)) {
       // Generate `vector.reduce` if possible.
-      mlir::Value redArg = getVecVal(arg);
+      Value redArg = getVecVal(arg);
       if (redArg) {
-        auto neutral = mlir::arith::getNeutralElement(&reduceBody.front());
+        auto neutral = arith::getNeutralElement(&reduceBody.front());
         assert(neutral);
-        mlir::Value neutralVal =
-            builder.create<mlir::arith::ConstantOp>(loc, *neutral);
-        mlir::Value neutralVec = builder.create<mlir::vector::SplatOp>(
-            loc, neutralVal, redArg.getType());
+        Value neutralVal = builder.create<arith::ConstantOp>(loc, *neutral);
+        Value neutralVec =
+            builder.create<vector::SplatOp>(loc, neutralVal, redArg.getType());
         auto mask = getMask();
-        redArg = builder.create<mlir::arith::SelectOp>(loc, mask, redArg,
-                                                       neutralVec);
+        redArg = builder.create<arith::SelectOp>(loc, mask, redArg, neutralVec);
       }
 
       auto fmf = getFMF(reduceBody.front());
       reduceVal =
-          builder.create<mlir::vector::ReductionOp>(loc, *redKind, redArg, fmf);
+          builder.create<vector::ReductionOp>(loc, *redKind, redArg, fmf);
     } else {
       if (masked)
         return reduceOp.emitError("Cannot vectorize reduce op in masked mode");
@@ -613,15 +604,14 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
       // If `vector.reduce` cannot be used, unpack values and reduce them
       // individually.
 
-      auto reduceTerm =
-          mlir::cast<mlir::scf::ReduceReturnOp>(reduceBody.getTerminator());
+      auto reduceTerm = cast<scf::ReduceReturnOp>(reduceBody.getTerminator());
       auto lhs = reduceBody.getArgument(0);
       auto rhs = reduceBody.getArgument(1);
       auto unpacked = getUnpackedVals(arg);
       assert(unpacked.size() == factor);
       reduceVal = unpacked.front();
       for (auto i : llvm::seq(1u, factor)) {
-        mlir::Value val = unpacked[i];
+        Value val = unpacked[i];
         scalarMapping.map(lhs, reduceVal);
         scalarMapping.map(rhs, val);
         for (auto &redOp : reduceBody.without_terminator())
@@ -644,9 +634,9 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     loop->erase();
   } else {
     builder.setInsertionPoint(loop);
-    mlir::Value newLower =
-        builder.create<mlir::arith::MulIOp>(loc, newCount, factorVal);
-    newLower = builder.create<mlir::arith::AddIOp>(loc, origLower, newLower);
+    Value newLower =
+        builder.createOrFold<arith::MulIOp>(loc, newCount, factorVal);
+    newLower = builder.createOrFold<arith::AddIOp>(loc, origLower, newLower);
 
     auto lowerCopy = llvm::to_vector(loop.getLowerBound());
     lowerCopy[dim] = newLower;
@@ -654,90 +644,5 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     loop.getInitValsMutable().assign(newLoop.getResults());
   }
 
-  return mlir::success();
-}
-
-static std::optional<unsigned> getVectorLength(mlir::Operation *op) {
-  auto func = op->getParentOfType<mlir::FunctionOpInterface>();
-  if (!func)
-    return std::nullopt;
-
-  auto attr = func->getAttrOfType<mlir::IntegerAttr>("mlir.vector_length");
-  if (!attr)
-    return std::nullopt;
-
-  auto val = attr.getInt();
-  if (val <= 0 || val > std::numeric_limits<unsigned>::max())
-    return std::nullopt;
-
-  return static_cast<unsigned>(val);
+  return success();
 }
-
-namespace {
-struct SCFVectorizePass
-    : public mlir::PassWrapper<SCFVectorizePass, mlir::OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFVectorizePass)
-
-  virtual void
-  getDependentDialects(mlir::DialectRegistry &registry) const override {
-    registry.insert<mlir::arith::ArithDialect>();
-    registry.insert<mlir::scf::SCFDialect>();
-    registry.insert<mlir::ub::UBDialect>();
-    registry.insert<mlir::vector::VectorDialect>();
-  }
-
-  void runOnOperation() override {
-    llvm::SmallVector<
-        std::pair<mlir::scf::ParallelOp, mlir::SCFVectorizeParams>>
-        toVectorize;
-
-    // Simple heuristic: total number of elements processed by vector ops, but
-    // prefer masked mode over non-masked.
-    auto getBenefit = [](const mlir::SCFVectorizeInfo &info) {
-      return info.factor * info.count * (int(info.masked) + 1);
-    };
-
-    getOperation()->walk([&](mlir::scf::ParallelOp loop) {
-      auto len = getVectorLength(loop);
-      if (!len)
-        return;
-
-      std::optional<mlir::SCFVectorizeInfo> best;
-      for (auto dim : llvm::seq(0u, loop.getNumLoops())) {
-        auto info = mlir::getLoopVectorizeInfo(loop, dim, *len);
-        if (!info)
-          continue;
-
-        if (!best) {
-          best = *info;
-          continue;
-        }
-
-        if (getBenefit(*info) > getBenefit(*best))
-          best = *info;
-      }
-
-      if (!best)
-        return;
-
-      toVectorize.emplace_back(
-          loop,
-          mlir::SCFVectorizeParams{best->dim, best->factor, best->masked});
-    });
-
-    if (toVectorize.empty())
-      return markAllAnalysesPreserved();
-
-    mlir::OpBuilder builder(&getContext());
-    for (auto &&[loop, params] : toVectorize) {
-      builder.setInsertionPoint(loop);
-      if (mlir::failed(mlir::vectorizeLoop(builder, loop, params)))
-        return signalPassFailure();
-    }
-  }
-};
-} // namespace
-
-std::unique_ptr<mlir::Pass> mlir::createSCFVectorizePass() {
-  return std::make_unique<SCFVectorizePass>();
-}
\ No newline at end of file
diff --git a/mlir/test/Transforms/test-scf-vectorize.mlir b/mlir/test/Transforms/test-scf-vectorize.mlir
new file mode 100644
index 0000000000000..f4a817b44aa39
--- /dev/null
+++ b/mlir/test/Transforms/test-scf-vectorize.mlir
@@ -0,0 +1,272 @@
+// RUN: mlir-opt %s --test-scf-vectorize=vector-bitwidth=128 -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @test
+//  CHECK-SAME:  (%[[A:.*]]: memref<?xi32>, %[[B:.*]]: memref<?xi32>, %[[C:.*]]: memref<?xi32>) {
+//       CHECK:  %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:  %[[DIM:.*]] = memref.dim %[[A]], %[[C0]] : memref<?xi32>
+//       CHECK:  %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:  %[[COUNT:.*]] = arith.ceildivsi %[[DIM]], %[[C4]] : index
+//       CHECK:  scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%[[COUNT]]) step (%{{.*}}) {
+//       CHECK:  %[[MULT:.*]] = arith.muli %[[I]], %[[C4]] : index
+//       CHECK:  %[[M:.*]] = arith.subi %[[DIM]], %[[MULT]] : index
+//       CHECK:  %[[MASK:.*]] = vector.create_mask %[[M]] : vector<4xi1>
+//       CHECK:  %[[P:.*]] = ub.poison : vector<4xi32>
+//       CHECK:  %[[A_VAL:.*]] = vector.maskedload %[[A]][%[[MULT]]], %[[MASK]], %[[P]] : memref<?xi32>, vector<4xi1>, vector<4xi32> into vector<4xi32>
+//       CHECK:  %[[P:.*]] = ub.poison : vector<4xi32>
+//       CHECK:  %[[B_VAL:.*]] = vector.maskedload %[[B]][%[[MULT]]], %[[MASK]], %[[P]] : memref<?xi32>, vector<4xi1>, vector<4xi32> into vector<4xi32>
+//       CHECK:  %[[RES:.*]] = arith.addi %[[A_VAL]], %[[B_VAL]] : vector<4xi32>
+//       CHECK:  vector.maskedstore %[[C]][%[[MULT]]], %[[MASK]], %[[RES]] : memref<?xi32>, vector<4xi1>, vector<4xi32>
+//       CHECK:  scf.reduce
+func.func @test(%A: memref<?xi32>, %B: memref<?xi32>, %C: memref<?xi32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %count = memref.dim %A, %c0 : memref<?xi32>
+  scf.parallel (%i) = (%c0) to (%count) step (%c1) {
+    %1 = memref.load %A[%i] : memref<?xi32>
+    %2 = memref.load %B[%i] : memref<?xi32>
+    %3 =  arith.addi %1, %2 : i32
+    memref.store %3, %C[%i] : memref<?xi32>
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test
+//  CHECK-SAME:  (%[[A:.*]]: memref<?xindex>, %[[B:.*]]: memref<?xindex>, %[[C:.*]]: memref<?xindex>) {
+//       CHECK:  %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:  %[[DIM:.*]] = memref.dim %[[A]], %[[C0]] : memref<?xindex>
+//       CHECK:  %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:  %[[COUNT:.*]] = arith.ceildivsi %[[DIM]], %[[C4]] : index
+//       CHECK:  scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%[[COUNT]]) step (%{{.*}}) {
+//       CHECK:  %[[MULT:.*]] = arith.muli %[[I]], %[[C4]] : index
+//       CHECK:  %[[M:.*]] = arith.subi %[[DIM]], %[[MULT]] : index
+//       CHECK:  %[[MASK:.*]] = vector.create_mask %[[M]] : vector<4xi1>
+//       CHECK:  %[[P:.*]] = ub.poison : vector<4xindex>
+//       CHECK:  %[[A_VAL:.*]] = vector.maskedload %[[A]][%[[MULT]]], %[[MASK]], %[[P]] : memref<?xindex>, vector<4xi1>, vector<4xindex> into vector<4xindex>
+//       CHECK:  %[[P:.*]] = ub.poison : vector<4xindex>
+//       CHECK:  %[[B_VAL:.*]] = vector.maskedload %[[B]][%[[MULT]]], %[[MASK]], %[[P]] : memref<?xindex>, vector<4xi1>, vector<4xindex> into vector<4xindex>
+//       CHECK:  %[[RES:.*]] = arith.addi %[[A_VAL]], %[[B_VAL]] : vector<4xindex>
+//       CHECK:  vector.maskedstore %[[C]][%[[MULT]]], %[[MASK]], %[[RES]] : memref<?xindex>, vector<4xi1>, vector<4xindex>
+//       CHECK:  scf.reduce
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32>> } {
+func.func @test(%A: memref<?xindex>, %B: memref<?xindex>, %C: memref<?xindex>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %count = memref.dim %A, %c0 : memref<?xindex>
+  scf.parallel (%i) = (%c0) to (%count) step (%c1) {
+    %1 = memref.load %A[%i] : memref<?xindex>
+    %2 = memref.load %B[%i] : memref<?xindex>
+    %3 =  arith.addi %1, %2 : index
+    memref.store %3, %C[%i] : memref<?xindex>
+  }
+  return
+}
+}
+
+// -----
+
+func.func private @non_vectorizable(i32) -> (i32)
+
+// CHECK-LABEL: @test
+//  CHECK-SAME:  (%[[A:.*]]: memref<?xi32>, %[[B:.*]]: memref<?xi32>, %[[C:.*]]: memref<?xi32>) {
+//       CHECK:  %[[C00:.*]] = arith.constant 0 : index
+//       CHECK:  %[[DIM:.*]] = memref.dim %[[A]], %[[C00]] : memref<?xi32>
+//       CHECK:  %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:  %[[COUNT:.*]] = arith.divsi %[[DIM]], %[[C4]] : index
+//       CHECK:  scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%[[COUNT]]) step (%{{.*}}) {
+//       CHECK:    %[[MULT:.*]] = arith.muli %[[I]], %[[C4]] : index
+//       CHECK:    %[[A_VAL:.*]] = vector.load %[[A]][%[[MULT]]] : memref<?xi32>, vector<4xi32>
+//       CHECK:    %[[B_VAL:.*]] = vector.load %[[B]][%[[MULT]]] : memref<?xi32>, vector<4xi32>
+//       CHECK:    %[[R1:.*]] = arith.addi %[[A_VAL]], %[[B_VAL]] : vector<4xi32>
+//       CHECK:    %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:    %[[E0:.*]] = vector.extractelement %[[R1]][%[[C0]] : index] : vector<4xi32>
+//       CHECK:    %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:    %[[E1:.*]] = vector.extractelement %[[R1]][%[[C1]] : index] : vector<4xi32>
+//       CHECK:    %[[C2:.*]] = arith.constant 2 : index
+//       CHECK:    %[[E2:.*]] = vector.extractelement %[[R1]][%[[C2]] : index] : vector<4xi32>
+//       CHECK:    %[[C3:.*]] = arith.constant 3 : index
+//       CHECK:    %[[E3:.*]] = vector.extractelement %[[R1]][%[[C3]] : index] : vector<4xi32>
+//       CHECK:    %[[R2:.*]] = func.call @non_vectorizable(%[[E0]]) : (i32) -> i32
+//       CHECK:    %[[R3:.*]] = func.call @non_vectorizable(%[[E1]]) : (i32) -> i32
+//       CHECK:    %[[R4:.*]] = func.call @non_vectorizable(%[[E2]]) : (i32) -> i32
+//       CHECK:    %[[R5:.*]] = func.call @non_vectorizable(%[[E3]]) : (i32) -> i32
+//       CHECK:    %[[RES1:.*]] = ub.poison : vector<4xi32>
+//       CHECK:    %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:    %[[RES2:.*]] = vector.insertelement %[[R2]], %[[RES1]][%[[C0]] : index] : vector<4xi32>
+//       CHECK:    %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:    %[[RES3:.*]] = vector.insertelement %[[R3]], %[[RES2]][%[[C1]] : index] : vector<4xi32>
+//       CHECK:    %[[C2:.*]] = arith.constant 2 : index
+//       CHECK:    %[[RES4:.*]] = vector.insertelement %[[R4]], %[[RES3]][%[[C2]] : index] : vector<4xi32>
+//       CHECK:    %[[C3:.*]] = arith.constant 3 : index
+//       CHECK:    %[[RES5:.*]] = vector.insertelement %[[R5]], %[[RES4]][%[[C3]] : index] : vector<4xi32>
+//       CHECK:    vector.store %[[RES5]], %[[C]][%[[MULT]]] : memref<?xi32>, vector<4xi32>
+//       CHECK:    scf.reduce
+//       CHECK:  }
+//       CHECK:  %[[UB1:.*]] = arith.muli %[[COUNT]], %[[C4]] : index
+//       CHECK:  %[[UB2:.*]] = arith.addi %[[UB1]], %[[C00]] : index
+//       CHECK:  scf.parallel (%[[I:.*]]) = (%[[UB2]]) to (%[[DIM]]) step (%{{.*}}) {
+//       CHECK:    %[[A_VAL:.*]] = memref.load %[[A]][%[[I]]] : memref<?xi32>
+//       CHECK:    %[[B_VAL:.*]] = memref.load %[[B]][%[[I]]] : memref<?xi32>
+//       CHECK:    %[[R1:.*]] = arith.addi %[[A_VAL]], %[[B_VAL]] : i32
+//       CHECK:    %[[R2:.*]] = func.call @non_vectorizable(%[[R1]]) : (i32) -> i32
+//       CHECK:    memref.store %[[R2]], %[[C]][%[[I]]] : memref<?xi32>
+//       CHECK:    scf.reduce
+//       CHECK:  }
+func.func @test(%A: memref<?xi32>, %B: memref<?xi32>, %C: memref<?xi32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %count = memref.dim %A, %c0 : memref<?xi32>
+  scf.parallel (%i) = (%c0) to (%count) step (%c1) {
+    %1 = memref.load %A[%i] : memref<?xi32>
+    %2 = memref.load %B[%i] : memref<?xi32>
+    %3 =  arith.addi %1, %2 : i32
+    %4 = func.call @non_vectorizable(%3) : (i32) -> (i32)
+    memref.store %4, %C[%i] : memref<?xi32>
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test
+//  CHECK-SAME:  (%[[A:.*]]: memref<?xindex>, %[[B:.*]]: memref<?xindex>, %[[C:.*]]: memref<?xindex>) {
+//       CHECK:  %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:  %[[C2:.*]] = arith.constant 2 : index
+//       CHECK:  %[[DIM:.*]] = memref.dim %[[A]], %[[C0]] : memref<?xindex>
+//       CHECK:  %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:  %[[COUNT:.*]] = arith.ceildivsi %[[DIM]], %[[C4]] : index
+//       CHECK:  scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%[[COUNT]]) step (%{{.*}}) {
+//       CHECK:  %[[OFFSETS:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+//       CHECK:  %[[MULT:.*]] = arith.muli %[[I]], %[[C4]] : index
+//       CHECK:  %[[O1:.*]] = vector.splat %[[MULT]] : vector<4xindex>
+//       CHECK:  %[[O2:.*]] = arith.addi %[[O1]], %[[OFFSETS]] : vector<4xindex>
+//       CHECK:  %[[O3:.*]] = vector.splat %[[C2]] : vector<4xindex>
+//       CHECK:  %[[O4:.*]] = arith.muli %[[O2]], %[[O3]] : vector<4xindex>
+//       CHECK:  %[[M:.*]] = arith.subi %[[DIM]], %[[MULT]] : index
+//       CHECK:  %[[MASK:.*]] = vector.create_mask %[[M]] : vector<4xi1>
+//       CHECK:  %[[P:.*]] = ub.poison : vector<4xindex>
+//       CHECK:  %[[A_VAL:.*]] = vector.gather %[[A]][%{{.*}}] [%[[O4]]], %[[MASK]], %[[P]] : memref<?xindex>, vector<4xindex>, vector<4xi1>, vector<4xindex> into vector<4xindex>
+//       CHECK:  %[[P:.*]] = ub.poison : vector<4xindex>
+//       CHECK:  %[[B_VAL:.*]] = vector.maskedload %[[B]][%[[MULT]]], %[[MASK]], %[[P]] : memref<?xindex>, vector<4xi1>, vector<4xindex> into vector<4xindex>
+//       CHECK:  %[[RES:.*]] = arith.addi %[[A_VAL]], %[[B_VAL]] : vector<4xindex>
+//       CHECK:  vector.scatter %[[C]][%{{.*}}] [%[[O4]]], %[[MASK]], %[[RES]] : memref<?xindex>, vector<4xindex>, vector<4xi1>, vector<4xindex>
+//       CHECK:  scf.reduce
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32>> } {
+func.func @test(%A: memref<?xindex>, %B: memref<?xindex>, %C: memref<?xindex>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %count = memref.dim %A, %c0 : memref<?xindex>
+  scf.parallel (%i) = (%c0) to (%count) step (%c1) {
+    %0 = arith.muli %i, %c2 : index
+    %1 = memref.load %A[%0] : memref<?xindex>
+    %2 = memref.load %B[%i] : memref<?xindex>
+    %3 =  arith.addi %1, %2 : index
+    memref.store %3, %C[%0] : memref<?xindex>
+  }
+  return
+}
+}
+
+// -----
+
+// CHECK-LABEL: @test
+//  CHECK-SAME:  (%[[A:.*]]: memref<?xf32>) -> f32 {
+//       CHECK:  %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:  %[[INIT:.*]] = arith.constant 0.0{{.*}} : f32
+//       CHECK:  %[[DIM:.*]] = memref.dim %[[A]], %[[C0]] : memref<?xf32>
+//       CHECK:  %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:  %[[COUNT:.*]] = arith.ceildivsi %[[DIM]], %[[C4]] : index
+//       CHECK:  %[[RES:.*]] = scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%[[COUNT]]) step (%{{.*}}) init (%[[INIT]]) -> f32 {
+//       CHECK:  %[[MULT:.*]] = arith.muli %[[I]], %[[C4]] : index
+//       CHECK:  %[[TAIL:.*]] = arith.subi %[[DIM]], %[[MULT]] : index
+//       CHECK:  %[[MASK:.*]] = vector.create_mask %[[TAIL]] : vector<4xi1>
+//       CHECK:  %[[POISON:.*]] = ub.poison : vector<4xf32>
+//       CHECK:  %[[LOADED:.*]] = vector.maskedload %[[A]][%[[MULT]]], %[[MASK]], %[[POISON]] : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+//       CHECK:  %[[NEUTRAL:.*]] = arith.constant 0.0{{.*}} : f32
+//       CHECK:  %[[NEUTRAL_VEC:.*]] = vector.splat %[[NEUTRAL]] : vector<4xf32>
+//       CHECK:  %[[SELECTED:.*]] = arith.select %[[MASK]], %[[LOADED]], %[[NEUTRAL_VEC]] : vector<4xi1>, vector<4xf32>
+//       CHECK:  %[[PARTIAL:.*]] = vector.reduction <add>, %[[SELECTED]] : vector<4xf32> into f32
+//       CHECK:  scf.reduce(%[[PARTIAL]] : f32) {
+//       CHECK:  ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+//       CHECK:    %[[SUM:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
+//       CHECK:    scf.reduce.return %[[SUM]] : f32
+//       CHECK:  }
+//       CHECK:  return %[[RES]] : f32
+func.func @test(%A: memref<?xf32>) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %init = arith.constant 0.0 : f32
+  %count = memref.dim %A, %c0 : memref<?xf32>
+  %res = scf.parallel (%i) = (%c0) to (%count) step (%c1) init (%init) -> f32 {
+    %elem = memref.load %A[%i] : memref<?xf32>
+    scf.reduce(%elem : f32) {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %sum = arith.addf %lhs, %rhs : f32
+      scf.reduce.return %sum : f32
+    }
+  }
+  return %res : f32
+}
+
+
+// -----
+
+func.func private @combine(f32, f32) -> (f32)
+
+// CHECK-LABEL: @test
+//  CHECK-SAME:  (%[[A:.*]]: memref<?xf32>) -> f32 {
+//       CHECK:  %[[C00:.*]] = arith.constant 0 : index
+//       CHECK:  %[[INIT:.*]] = arith.constant 0.0{{.*}} : f32
+//       CHECK:  %[[DIM:.*]] = memref.dim %[[A]], %[[C00]] : memref<?xf32>
+//       CHECK:  %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:  %[[COUNT:.*]] = arith.divsi %[[DIM]], %[[C4]] : index
+//       CHECK:  %[[RES:.*]] = scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%[[COUNT]]) step (%{{.*}}) init (%[[INIT]]) -> f32 {
+//       CHECK:    %[[MULT:.*]] = arith.muli %[[I]], %[[C4]] : index
+//       CHECK:    %[[A_VAL:.*]] = vector.load %[[A]][%[[MULT]]] : memref<?xf32>, vector<4xf32>
+//       CHECK:    %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:    %[[E0:.*]] = vector.extractelement %[[A_VAL]][%[[C0]] : index] : vector<4xf32>
+//       CHECK:    %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:    %[[E1:.*]] = vector.extractelement %[[A_VAL]][%[[C1]] : index] : vector<4xf32>
+//       CHECK:    %[[C2:.*]] = arith.constant 2 : index
+//       CHECK:    %[[E2:.*]] = vector.extractelement %[[A_VAL]][%[[C2]] : index] : vector<4xf32>
+//       CHECK:    %[[C3:.*]] = arith.constant 3 : index
+//       CHECK:    %[[E3:.*]] = vector.extractelement %[[A_VAL]][%[[C3]] : index] : vector<4xf32>
+//       CHECK:    %[[R0:.*]] = func.call @combine(%[[E0]], %[[E1]]) : (f32, f32) -> f32
+//       CHECK:    %[[R1:.*]] = func.call @combine(%[[R0]], %[[E2]]) : (f32, f32) -> f32
+//       CHECK:    %[[R2:.*]] = func.call @combine(%[[R1]], %[[E3]]) : (f32, f32) -> f32
+//       CHECK:    scf.reduce(%[[R2]] : f32) {
+//       CHECK:    ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+//       CHECK:      %[[R:.*]] = func.call @combine(%[[LHS]], %[[RHS]]) : (f32, f32) -> f32
+//       CHECK:      scf.reduce.return %[[R]] : f32
+//       CHECK:    }
+//       CHECK:  }
+//       CHECK:  %[[UB1:.*]] = arith.muli %[[COUNT]], %[[C4]] : index
+//       CHECK:  %[[UB2:.*]] = arith.addi %[[UB1]], %[[C00]] : index
+//       CHECK:  %[[RES1:.*]] = scf.parallel (%[[I:.*]]) = (%[[UB2]]) to (%[[DIM]]) step (%{{.*}}) init (%[[RES]]) -> f32 {
+//       CHECK:    %[[A_VAL:.*]] = memref.load %[[A]][%[[I]]] : memref<?xf32>
+//       CHECK:    scf.reduce(%[[A_VAL]] : f32) {
+//       CHECK:    ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+//       CHECK:      %[[R:.*]] = func.call @combine(%[[LHS]], %[[RHS]]) : (f32, f32) -> f32
+//       CHECK:      scf.reduce.return %[[R]] : f32
+//       CHECK:    }
+//       CHECK:  }
+//       CHECK:  return %[[RES1]] : f32
+func.func @test(%A: memref<?xf32>) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %init = arith.constant 0.0 : f32
+  %count = memref.dim %A, %c0 : memref<?xf32>
+  %res = scf.parallel (%i) = (%c0) to (%count) step (%c1) init (%init) -> f32 {
+    %1 = memref.load %A[%i] : memref<?xf32>
+    scf.reduce(%1 : f32) {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %2 = func.call @combine(%lhs, %rhs) : (f32, f32) -> (f32)
+      scf.reduce.return %2 : f32
+    }
+  }
+  return %res : f32
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 975a41ac3d5fe..01c92199b6f3a 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -26,6 +26,7 @@ add_mlir_library(MLIRTestTransforms
   TestInlining.cpp
   TestIntRangeInference.cpp
   TestMakeIsolatedFromAbove.cpp
+  TestSCFVectorize.cpp
   ${MLIRTestTransformsPDLSrc}
 
   EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Transforms/TestSCFVectorize.cpp b/mlir/test/lib/Transforms/TestSCFVectorize.cpp
new file mode 100644
index 0000000000000..84ea190a33de2
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestSCFVectorize.cpp
@@ -0,0 +1,110 @@
+//===- TestSCFVectorize.cpp - SCF vectorization test pass -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/SCFVectorize.h"
+
+#include "mlir/Analysis/DataLayoutAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+
+using namespace mlir;
+
+namespace {
+struct TestSCFVectorizePass
+    : public PassWrapper<TestSCFVectorizePass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFVectorizePass)
+
+  TestSCFVectorizePass() = default;
+  TestSCFVectorizePass(const TestSCFVectorizePass &pass) : PassWrapper(pass) {}
+
+  Option<unsigned> vectorBitwidth{*this, "vector-bitwidth",
+                                  llvm::cl::desc("Target vector bitwidth "),
+                                  llvm::cl::init(128)};
+
+  StringRef getArgument() const final { return "test-scf-vectorize"; }
+  StringRef getDescription() const final { return "Test SCF vectorization"; }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect>();
+    registry.insert<scf::SCFDialect>();
+    registry.insert<ub::UBDialect>();
+    registry.insert<vector::VectorDialect>();
+  }
+
+  LogicalResult initializeOptions(
+      StringRef options,
+      function_ref<LogicalResult(const Twine &)> errorHandler) override {
+    if (failed(PassWrapper::initializeOptions(options, errorHandler)))
+      return failure();
+
+    if (vectorBitwidth == 0)
+      return errorHandler("Invalid vector bitwidth: " +
+                          llvm::Twine(vectorBitwidth));
+
+    return success();
+  }
+
+  void runOnOperation() override {
+    Operation *root = getOperation();
+    auto &DLAnalysis = getAnalysis<DataLayoutAnalysis>();
+
+    llvm::SmallVector<std::pair<scf::ParallelOp, SCFVectorizeParams>>
+        toVectorize;
+
+    // Simple heuristic: total number of elements processed by vector ops, but
+    // prefer masked mode over non-masked.
+    auto getBenefit = [](const SCFVectorizeInfo &info) {
+      return info.factor * info.count * (int(info.masked) + 1);
+    };
+
+    root->walk([&](scf::ParallelOp loop) {
+      const DataLayout &DL = DLAnalysis.getAbove(loop);
+      std::optional<SCFVectorizeInfo> best;
+      for (auto dim : llvm::seq(0u, loop.getNumLoops())) {
+        auto info = getLoopVectorizeInfo(loop, dim, vectorBitwidth, &DL);
+        if (!info)
+          continue;
+
+        if (!best) {
+          best = *info;
+          continue;
+        }
+
+        if (getBenefit(*info) > getBenefit(*best))
+          best = *info;
+      }
+
+      if (!best)
+        return;
+
+      toVectorize.emplace_back(
+          loop, SCFVectorizeParams{best->dim, best->factor, best->masked});
+    });
+
+    if (toVectorize.empty())
+      return markAllAnalysesPreserved();
+
+    for (auto &&[loop, params] : toVectorize) {
+      const DataLayout &DL = DLAnalysis.getAbove(loop);
+      if (failed(vectorizeLoop(loop, params, &DL)))
+        return signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTesSCFVectorize() { PassRegistration<TestSCFVectorizePass>(); }
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 0e8b161d51345..e2b30ed0626a3 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -68,7 +68,6 @@ void registerTosaTestQuantUtilAPIPass();
 void registerVectorizerTestPass();
 
 namespace test {
-void registerTestCompositePass();
 void registerCommutativityUtils();
 void registerConvertCallOpPass();
 void registerConvertFuncOpPass();
@@ -76,13 +75,17 @@ void registerInliner();
 void registerMemRefBoundCheck();
 void registerPatternsTestPass();
 void registerSimpleParametricTilingPass();
+void registerTesSCFVectorize();
 void registerTestAffineLoopParametricTilingPass();
-void registerTestArithEmulateWideIntPass();
 void registerTestAliasAnalysisPass();
+void registerTestArithEmulateWideIntPass();
 void registerTestBuiltinAttributeInterfaces();
 void registerTestBuiltinDistinctAttributes();
 void registerTestCallGraphPass();
 void registerTestCfAssertPass();
+void registerTestCFGLoopInfoPass();
+void registerTestComposeSubView();
+void registerTestCompositePass();
 void registerTestConstantFold();
 void registerTestControlFlowSink();
 void registerTestDataLayoutPropagation();
@@ -95,12 +98,10 @@ void registerTestDynamicPipelinePass();
 void registerTestEmulateNarrowTypePass();
 void registerTestExpandMathPass();
 void registerTestFooAnalysisPass();
-void registerTestComposeSubView();
-void registerTestMultiBuffering();
-void registerTestIntRangeInference();
-void registerTestIRVisitorsPass();
 void registerTestGenericIRVisitorsPass();
 void registerTestInterfaces();
+void registerTestIntRangeInference();
+void registerTestIRVisitorsPass();
 void registerTestLastModifiedPass();
 void registerTestLinalgDecomposeOps();
 void registerTestLinalgDropUnitDims();
@@ -110,7 +111,6 @@ void registerTestLinalgTransforms();
 void registerTestLivenessAnalysisPass();
 void registerTestLivenessPass();
 void registerTestLoopFusion();
-void registerTestCFGLoopInfoPass();
 void registerTestLoopMappingPass();
 void registerTestLoopUnrollingPass();
 void registerTestLowerToArmNeon();
@@ -123,12 +123,14 @@ void registerTestMathPolynomialApproximationPass();
 void registerTestMathToVCIXPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
-void registerTestMeshSimplificationsPass();
 void registerTestMeshReshardingSpmdizationPass();
-void registerTestOpLoweringPasses();
+void registerTestMeshSimplificationsPass();
+void registerTestMultiBuffering();
 void registerTestNextAccessPass();
+void registerTestNVGPULowerings();
 void registerTestOneToNTypeConversionPass();
 void registerTestOpaqueLoc();
+void registerTestOpLoweringPasses();
 void registerTestPadFusion();
 void registerTestRecursiveTypesPass();
 void registerTestSCFUpliftWhileToFor();
@@ -141,10 +143,9 @@ void registerTestTensorCopyInsertionPass();
 void registerTestTensorTransforms();
 void registerTestTopologicalSortAnalysisPass();
 void registerTestTransformDialectEraseSchedulePass();
-void registerTestWrittenToPass();
 void registerTestVectorLowerings();
 void registerTestVectorReductionToSPIRVDotProd();
-void registerTestNVGPULowerings();
+void registerTestWrittenToPass();
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 void registerTestDialectConversionPasses();
 void registerTestPDLByteCodePass();
@@ -197,7 +198,6 @@ void registerTestPasses() {
   registerVectorizerTestPass();
   registerTosaTestQuantUtilAPIPass();
 
-  mlir::test::registerTestCompositePass();
   mlir::test::registerCommutativityUtils();
   mlir::test::registerConvertCallOpPass();
   mlir::test::registerConvertFuncOpPass();
@@ -205,6 +205,7 @@ void registerTestPasses() {
   mlir::test::registerMemRefBoundCheck();
   mlir::test::registerPatternsTestPass();
   mlir::test::registerSimpleParametricTilingPass();
+  mlir::test::registerTesSCFVectorize();
   mlir::test::registerTestAffineLoopParametricTilingPass();
   mlir::test::registerTestAliasAnalysisPass();
   mlir::test::registerTestArithEmulateWideIntPass();
@@ -212,24 +213,25 @@ void registerTestPasses() {
   mlir::test::registerTestBuiltinDistinctAttributes();
   mlir::test::registerTestCallGraphPass();
   mlir::test::registerTestCfAssertPass();
+  mlir::test::registerTestCFGLoopInfoPass();
+  mlir::test::registerTestComposeSubView();
+  mlir::test::registerTestCompositePass();
   mlir::test::registerTestConstantFold();
   mlir::test::registerTestControlFlowSink();
-  mlir::test::registerTestDiagnosticsPass();
-  mlir::test::registerTestDecomposeCallGraphTypes();
   mlir::test::registerTestDataLayoutPropagation();
   mlir::test::registerTestDataLayoutQuery();
   mlir::test::registerTestDeadCodeAnalysisPass();
+  mlir::test::registerTestDecomposeCallGraphTypes();
+  mlir::test::registerTestDiagnosticsPass();
   mlir::test::registerTestDominancePass();
   mlir::test::registerTestDynamicPipelinePass();
   mlir::test::registerTestEmulateNarrowTypePass();
   mlir::test::registerTestExpandMathPass();
   mlir::test::registerTestFooAnalysisPass();
-  mlir::test::registerTestComposeSubView();
-  mlir::test::registerTestMultiBuffering();
-  mlir::test::registerTestIntRangeInference();
-  mlir::test::registerTestIRVisitorsPass();
   mlir::test::registerTestGenericIRVisitorsPass();
   mlir::test::registerTestInterfaces();
+  mlir::test::registerTestIntRangeInference();
+  mlir::test::registerTestIRVisitorsPass();
   mlir::test::registerTestLastModifiedPass();
   mlir::test::registerTestLinalgDecomposeOps();
   mlir::test::registerTestLinalgDropUnitDims();
@@ -239,7 +241,6 @@ void registerTestPasses() {
   mlir::test::registerTestLivenessAnalysisPass();
   mlir::test::registerTestLivenessPass();
   mlir::test::registerTestLoopFusion();
-  mlir::test::registerTestCFGLoopInfoPass();
   mlir::test::registerTestLoopMappingPass();
   mlir::test::registerTestLoopUnrollingPass();
   mlir::test::registerTestLowerToArmNeon();
@@ -252,12 +253,14 @@ void registerTestPasses() {
   mlir::test::registerTestMathToVCIXPass();
   mlir::test::registerTestMemRefDependenceCheck();
   mlir::test::registerTestMemRefStrideCalculation();
-  mlir::test::registerTestOpLoweringPasses();
-  mlir::test::registerTestMeshSimplificationsPass();
   mlir::test::registerTestMeshReshardingSpmdizationPass();
+  mlir::test::registerTestMeshSimplificationsPass();
+  mlir::test::registerTestMultiBuffering();
   mlir::test::registerTestNextAccessPass();
+  mlir::test::registerTestNVGPULowerings();
   mlir::test::registerTestOneToNTypeConversionPass();
   mlir::test::registerTestOpaqueLoc();
+  mlir::test::registerTestOpLoweringPasses();
   mlir::test::registerTestPadFusion();
   mlir::test::registerTestRecursiveTypesPass();
   mlir::test::registerTestSCFUpliftWhileToFor();
@@ -272,7 +275,6 @@ void registerTestPasses() {
   mlir::test::registerTestTransformDialectEraseSchedulePass();
   mlir::test::registerTestVectorLowerings();
   mlir::test::registerTestVectorReductionToSPIRVDotProd();
-  mlir::test::registerTestNVGPULowerings();
   mlir::test::registerTestWrittenToPass();
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
   mlir::test::registerTestDialectConversionPasses();

>From cd50c6ab6621ea4d505d7a92e4bd275734557e10 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 3 Jun 2024 12:19:00 +0200
Subject: [PATCH 04/10] fix typo

---
 mlir/lib/Transforms/SCFVectorize.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Transforms/SCFVectorize.cpp b/mlir/lib/Transforms/SCFVectorize.cpp
index 29e184e584a56..c74cfa4abf80d 100644
--- a/mlir/lib/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Transforms/SCFVectorize.cpp
@@ -66,7 +66,7 @@ static bool isRangePermutation(ValueRange val1, ValueRange val2) {
 
 template <typename Op>
 static std::optional<unsigned>
-cavTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
+canTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
                                const DataLayout *DL) {
   auto loopIndexVars = loop.getInductionVars();
   assert(dim < loopIndexVars.size());
@@ -97,14 +97,14 @@ cavTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
 /// Returns memref element bitwidth or `std::nullopt` if access cannot be
 /// vectorized.
 static std::optional<unsigned>
-cavTriviallyVectorizeMemOp(scf::ParallelOp loop, unsigned dim, Operation &op,
+canTriviallyVectorizeMemOp(scf::ParallelOp loop, unsigned dim, Operation &op,
                            const DataLayout *DL) {
   assert(dim < loop.getInductionVars().size());
   if (auto storeOp = dyn_cast<memref::StoreOp>(op))
-    return cavTriviallyVectorizeMemOpImpl(loop, dim, storeOp, DL);
+    return canTriviallyVectorizeMemOpImpl(loop, dim, storeOp, DL);
 
   if (auto loadOp = dyn_cast<memref::LoadOp>(op))
-    return cavTriviallyVectorizeMemOpImpl(loop, dim, loadOp, DL);
+    return canTriviallyVectorizeMemOpImpl(loop, dim, loadOp, DL);
 
   return std::nullopt;
 }
@@ -141,7 +141,7 @@ canGatherScatter(scf::ParallelOp loop, Operation &op, const DataLayout *DL) {
 static std::optional<unsigned> cenVectorizeMemrefOp(scf::ParallelOp loop,
                                                     unsigned dim, Operation &op,
                                                     const DataLayout *DL) {
-  if (auto w = cavTriviallyVectorizeMemOp(loop, dim, op, DL))
+  if (auto w = canTriviallyVectorizeMemOp(loop, dim, op, DL))
     return w;
 
   return canGatherScatter(loop, op, DL);
@@ -419,7 +419,7 @@ LogicalResult mlir::vectorizeLoop(scf::ParallelOp loop,
   };
 
   auto canTriviallyVectorizeMemOp = [&](auto op) -> bool {
-    return !!::cavTriviallyVectorizeMemOpImpl(loop, dim, op, DL);
+    return !!::canTriviallyVectorizeMemOpImpl(loop, dim, op, DL);
   };
 
   auto canGatherScatter = [&](auto op) {

>From 27a11b3ce9a53db82e9f327f9db15c821e47610b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 3 Jun 2024 12:20:14 +0200
Subject: [PATCH 05/10] use has_value()

---
 mlir/lib/Transforms/SCFVectorize.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Transforms/SCFVectorize.cpp b/mlir/lib/Transforms/SCFVectorize.cpp
index c74cfa4abf80d..7b907b655976a 100644
--- a/mlir/lib/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Transforms/SCFVectorize.cpp
@@ -419,11 +419,11 @@ LogicalResult mlir::vectorizeLoop(scf::ParallelOp loop,
   };
 
   auto canTriviallyVectorizeMemOp = [&](auto op) -> bool {
-    return !!::canTriviallyVectorizeMemOpImpl(loop, dim, op, DL);
+    return ::canTriviallyVectorizeMemOpImpl(loop, dim, op, DL).has_value();
   };
 
   auto canGatherScatter = [&](auto op) {
-    return !!::canGatherScatterImpl(loop, op, DL);
+    return ::canGatherScatterImpl(loop, op, DL).has_value();
   };
 
   // Get idices for vectorized memref load/store.

>From 4a8bc62e0e859a5231c90d93f0a1494b1454d4d0 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 3 Jun 2024 12:44:54 +0200
Subject: [PATCH 06/10] move files to scf dialect

---
 .../{ => Dialect/SCF}/Transforms/SCFVectorize.h    |  5 ++---
 mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt     |  3 ++-
 .../{ => Dialect/SCF}/Transforms/SCFVectorize.cpp  | 14 +++++++-------
 mlir/lib/Transforms/CMakeLists.txt                 |  1 -
 .../SCF}/test-scf-vectorize.mlir                   |  0
 mlir/test/lib/Dialect/SCF/CMakeLists.txt           |  1 +
 .../SCF}/TestSCFVectorize.cpp                      | 14 +++++++-------
 mlir/test/lib/Transforms/CMakeLists.txt            |  1 -
 8 files changed, 19 insertions(+), 20 deletions(-)
 rename mlir/include/mlir/{ => Dialect/SCF}/Transforms/SCFVectorize.h (98%)
 rename mlir/lib/{ => Dialect/SCF}/Transforms/SCFVectorize.cpp (97%)
 rename mlir/test/{Transforms => Dialect/SCF}/test-scf-vectorize.mlir (100%)
 rename mlir/test/lib/{Transforms => Dialect/SCF}/TestSCFVectorize.cpp (87%)

diff --git a/mlir/include/mlir/Transforms/SCFVectorize.h b/mlir/include/mlir/Dialect/SCF/Transforms/SCFVectorize.h
similarity index 98%
rename from mlir/include/mlir/Transforms/SCFVectorize.h
rename to mlir/include/mlir/Dialect/SCF/Transforms/SCFVectorize.h
index d2a5e3085ae37..ebaa1edac531c 100644
--- a/mlir/include/mlir/Transforms/SCFVectorize.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/SCFVectorize.h
@@ -17,9 +17,7 @@ struct LogicalResult;
 namespace scf {
 class ParallelOp;
 }
-} // namespace mlir
-
-namespace mlir {
+namespace scf {
 
 /// Loop vectorization info
 struct SCFVectorizeInfo {
@@ -65,6 +63,7 @@ struct SCFVectorizeParams {
 mlir::LogicalResult vectorizeLoop(mlir::scf::ParallelOp loop,
                                   const SCFVectorizeParams &params,
                                   const DataLayout *DL = nullptr);
+} // namespace scf
 } // namespace mlir
 
 #endif // MLIR_TRANSFORMS_SCFVECTORIZE_H_
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index e7671c9cc28f8..8b4c29e0493c1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -12,10 +12,11 @@ add_mlir_dialect_library(MLIRSCFTransforms
   ParallelLoopCollapsing.cpp
   ParallelLoopFusion.cpp
   ParallelLoopTiling.cpp
+  SCFVectorize.cpp
   StructuralTypeConversions.cpp
   TileUsingInterface.cpp
-  WrapInZeroTripCheck.cpp
   UpliftWhileToFor.cpp
+  WrapInZeroTripCheck.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
diff --git a/mlir/lib/Transforms/SCFVectorize.cpp b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
similarity index 97%
rename from mlir/lib/Transforms/SCFVectorize.cpp
rename to mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
index 7b907b655976a..7bc7fa544f286 100644
--- a/mlir/lib/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Transforms/SCFVectorize.h"
+#include "mlir/Dialect/SCF/Transforms/SCFVectorize.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" // getCombinerOpKind
@@ -157,9 +157,9 @@ static std::optional<vector::CombiningKind> getReductionKind(Block &body) {
   return linalg::getCombinerOpKind(&body.front());
 }
 
-std::optional<SCFVectorizeInfo>
-mlir::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim,
-                           unsigned vectorBitwidth, const DataLayout *DL) {
+std::optional<scf::SCFVectorizeInfo>
+mlir::scf::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim,
+                                unsigned vectorBitwidth, const DataLayout *DL) {
   assert(dim < loop.getStep().size());
   assert(vectorBitwidth > 0);
   unsigned factor = vectorBitwidth / 8;
@@ -234,9 +234,9 @@ static arith::FastMathFlags getFMF(Operation &op) {
   return arith::FastMathFlags::none;
 }
 
-LogicalResult mlir::vectorizeLoop(scf::ParallelOp loop,
-                                  const SCFVectorizeParams &params,
-                                  const DataLayout *DL) {
+LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
+                                       const scf::SCFVectorizeParams &params,
+                                       const DataLayout *DL) {
   auto dim = params.dim;
   auto factor = params.factor;
   auto masked = params.masked;
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index ed71c73c938ed..90c0298fb5e46 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -14,7 +14,6 @@ add_mlir_library(MLIRTransforms
   PrintIR.cpp
   RemoveDeadValues.cpp
   SCCP.cpp
-  SCFVectorize.cpp
   SROA.cpp
   StripDebugInfo.cpp
   SymbolDCE.cpp
diff --git a/mlir/test/Transforms/test-scf-vectorize.mlir b/mlir/test/Dialect/SCF/test-scf-vectorize.mlir
similarity index 100%
rename from mlir/test/Transforms/test-scf-vectorize.mlir
rename to mlir/test/Dialect/SCF/test-scf-vectorize.mlir
diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
index 792430cc84b65..9af1459d17df9 100644
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(MLIRSCFTestPasses
   TestLoopParametricTiling.cpp
   TestLoopUnrolling.cpp
   TestSCFUtils.cpp
+  TestSCFVectorize.cpp
   TestSCFWrapInZeroTripCheck.cpp
   TestUpliftWhileToFor.cpp
   TestWhileOpBuilder.cpp
diff --git a/mlir/test/lib/Transforms/TestSCFVectorize.cpp b/mlir/test/lib/Dialect/SCF/TestSCFVectorize.cpp
similarity index 87%
rename from mlir/test/lib/Transforms/TestSCFVectorize.cpp
rename to mlir/test/lib/Dialect/SCF/TestSCFVectorize.cpp
index 84ea190a33de2..3f92dd03438bd 100644
--- a/mlir/test/lib/Transforms/TestSCFVectorize.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFVectorize.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Transforms/SCFVectorize.h"
+#include "mlir/Dialect/SCF/Transforms/SCFVectorize.h"
 
 #include "mlir/Analysis/DataLayoutAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -58,20 +58,20 @@ struct TestSCFVectorizePass
     Operation *root = getOperation();
     auto &DLAnalysis = getAnalysis<DataLayoutAnalysis>();
 
-    llvm::SmallVector<std::pair<scf::ParallelOp, SCFVectorizeParams>>
+    llvm::SmallVector<std::pair<scf::ParallelOp, scf::SCFVectorizeParams>>
         toVectorize;
 
     // Simple heuristic: total number of elements processed by vector ops, but
     // prefer masked mode over non-masked.
-    auto getBenefit = [](const SCFVectorizeInfo &info) {
+    auto getBenefit = [](const scf::SCFVectorizeInfo &info) {
       return info.factor * info.count * (int(info.masked) + 1);
     };
 
     root->walk([&](scf::ParallelOp loop) {
       const DataLayout &DL = DLAnalysis.getAbove(loop);
-      std::optional<SCFVectorizeInfo> best;
+      std::optional<scf::SCFVectorizeInfo> best;
       for (auto dim : llvm::seq(0u, loop.getNumLoops())) {
-        auto info = getLoopVectorizeInfo(loop, dim, vectorBitwidth, &DL);
+        auto info = scf::getLoopVectorizeInfo(loop, dim, vectorBitwidth, &DL);
         if (!info)
           continue;
 
@@ -88,7 +88,7 @@ struct TestSCFVectorizePass
         return;
 
       toVectorize.emplace_back(
-          loop, SCFVectorizeParams{best->dim, best->factor, best->masked});
+          loop, scf::SCFVectorizeParams{best->dim, best->factor, best->masked});
     });
 
     if (toVectorize.empty())
@@ -96,7 +96,7 @@ struct TestSCFVectorizePass
 
     for (auto &&[loop, params] : toVectorize) {
       const DataLayout &DL = DLAnalysis.getAbove(loop);
-      if (failed(vectorizeLoop(loop, params, &DL)))
+      if (failed(scf::vectorizeLoop(loop, params, &DL)))
         return signalPassFailure();
     }
   }
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 01c92199b6f3a..975a41ac3d5fe 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -26,7 +26,6 @@ add_mlir_library(MLIRTestTransforms
   TestInlining.cpp
   TestIntRangeInference.cpp
   TestMakeIsolatedFromAbove.cpp
-  TestSCFVectorize.cpp
   ${MLIRTestTransformsPDLSrc}
 
   EXCLUDE_FROM_LIBMLIR

>From ee67f7e26a836dc3f32b7a10842140f185c54138 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 3 Jun 2024 19:10:39 +0200
Subject: [PATCH 07/10] getTypeBitWidth std::optional

---
 .../Dialect/SCF/Transforms/SCFVectorize.cpp   | 46 +++++++++++--------
 1 file changed, 28 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
index 7bc7fa544f286..7e7add925dbb9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
@@ -20,11 +20,12 @@ using namespace mlir;
 
 static bool isSupportedVecElem(Type type) { return type.isIntOrIndexOrFloat(); }
 
-/// Return type bitwidth for vectorization purposes or 0 if type cannot be
+/// Return type bitwidth for vectorization purposes or empty if type cannot be
 /// vectorized.
-static unsigned getTypeBitWidth(Type type, const DataLayout *DL) {
+static std::optional<unsigned> getTypeBitWidth(Type type,
+                                               const DataLayout *DL) {
   if (!isSupportedVecElem(type))
-    return 0;
+    return std::nullopt;
 
   if (DL)
     return DL->getTypeSizeInBits(type);
@@ -32,16 +33,21 @@ static unsigned getTypeBitWidth(Type type, const DataLayout *DL) {
   if (type.isIntOrFloat())
     return type.getIntOrFloatBitWidth();
 
-  return 0;
+  return std::nullopt;
 }
 
-static unsigned getArgsTypeWidth(Operation &op, const DataLayout *DL) {
+static std::optional<unsigned> getArgsTypeWidth(Operation &op,
+                                                const DataLayout *DL) {
   unsigned ret = 0;
-  for (auto arg : op.getOperands())
-    ret = std::max(ret, getTypeBitWidth(arg.getType(), DL));
+  for (auto r : {ValueRange(op.getOperands()), ValueRange(op.getResults())}) {
+    for (auto arg : op.getOperands()) {
+      std::optional<unsigned> w = getTypeBitWidth(arg.getType(), DL);
+      if (!w)
+        return std::nullopt;
 
-  for (auto res : op.getResults())
-    ret = std::max(ret, getTypeBitWidth(res.getType(), DL));
+      ret = std::max(ret, *w);
+    }
+  }
 
   return ret;
 }
@@ -72,8 +78,8 @@ canTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
   assert(dim < loopIndexVars.size());
   auto memref = memOp.getMemRef();
   auto type = cast<MemRefType>(memref.getType());
-  auto width = getTypeBitWidth(type.getElementType(), DL);
-  if (width == 0)
+  std::optional<unsigned> width = getTypeBitWidth(type.getElementType(), DL);
+  if (!width)
     return std::nullopt;
 
   if (!type.getLayout().isIdentity())
@@ -114,13 +120,17 @@ static std::optional<unsigned> canGatherScatterImpl(scf::ParallelOp loop, Op op,
                                                     const DataLayout *DL) {
   auto memref = op.getMemRef();
   auto memrefType = cast<MemRefType>(memref.getType());
-  auto width = getTypeBitWidth(memrefType.getElementType(), DL);
-  if (width == 0)
+  std::optional<unsigned> width =
+      getTypeBitWidth(memrefType.getElementType(), DL);
+  if (!width)
     return std::nullopt;
 
   DominanceInfo dom;
-  return dom.properlyDominates(memref, loop) && op.getIndices().size() == 1 &&
-         memrefType.getLayout().isIdentity();
+  if (!dom.properlyDominates(memref, loop) || op.getIndices().size() != 1 ||
+      !memrefType.getLayout().isIdentity())
+    return std::nullopt;
+
+  return width;
 }
 
 // Check if memref access can be converted into gather/scatter.
@@ -206,11 +216,11 @@ mlir::scf::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim,
       continue;
     }
 
-    auto width = getArgsTypeWidth(op, DL);
-    if (width == 0)
+    std::optional<unsigned> width = getArgsTypeWidth(op, DL);
+    if (!width)
       return std::nullopt;
 
-    auto newFactor = vectorBitwidth / width;
+    auto newFactor = vectorBitwidth / *width;
     if (newFactor <= 1)
       continue;
 

>From 7c8e1641c6ac04116d076946cbfba1ed5e5264af Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 3 Jun 2024 19:17:05 +0200
Subject: [PATCH 08/10] update assert messages

---
 .../Dialect/SCF/Transforms/SCFVectorize.cpp   | 30 +++++++++----------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
index 7e7add925dbb9..536efc72a0305 100644
--- a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
@@ -75,7 +75,7 @@ static std::optional<unsigned>
 canTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
                                const DataLayout *DL) {
   auto loopIndexVars = loop.getInductionVars();
-  assert(dim < loopIndexVars.size());
+  assert(dim < loopIndexVars.size() && "Invalid loop dimension");
   auto memref = memOp.getMemRef();
   auto type = cast<MemRefType>(memref.getType());
   std::optional<unsigned> width = getTypeBitWidth(type.getElementType(), DL);
@@ -105,7 +105,7 @@ canTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
 static std::optional<unsigned>
 canTriviallyVectorizeMemOp(scf::ParallelOp loop, unsigned dim, Operation &op,
                            const DataLayout *DL) {
-  assert(dim < loop.getInductionVars().size());
+  assert(dim < loop.getInductionVars().size() && "Invalid loop dimension");
   if (auto storeOp = dyn_cast<memref::StoreOp>(op))
     return canTriviallyVectorizeMemOpImpl(loop, dim, storeOp, DL);
 
@@ -170,8 +170,8 @@ static std::optional<vector::CombiningKind> getReductionKind(Block &body) {
 std::optional<scf::SCFVectorizeInfo>
 mlir::scf::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim,
                                 unsigned vectorBitwidth, const DataLayout *DL) {
-  assert(dim < loop.getStep().size());
-  assert(vectorBitwidth > 0);
+  assert(dim < loop.getStep().size() && "Invalid loop dimension");
+  assert(vectorBitwidth > 0 && "Invalid vector bitwidth");
   unsigned factor = vectorBitwidth / 8;
   if (factor <= 1)
     return std::nullopt;
@@ -250,9 +250,9 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
   auto dim = params.dim;
   auto factor = params.factor;
   auto masked = params.masked;
-  assert(dim < loop.getStep().size());
-  assert(factor > 1);
-  assert(isConstantIntValue(loop.getStep()[dim], 1));
+  assert(dim < loop.getStep().size() && "Invalid loop dimension");
+  assert(factor > 1 && "Invalid vectorize factor");
+  assert(isConstantIntValue(loop.getStep()[dim], 1) && "Loop stepust be 1");
 
   OpBuilder builder(loop);
   auto lower = llvm::to_vector(loop.getLowerBound());
@@ -332,7 +332,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
       return vec;
     }
     auto type = orig.getType();
-    assert(isSupportedVecElem(type));
+    assert(isSupportedVecElem(type) && "Unsupported vector element type");
 
     Value val = orig;
     auto origIndexVars = loop.getInductionVars();
@@ -367,7 +367,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
     // cache and not handled here.
 
     auto &ret = unpackedVals[val];
-    assert(ret.empty());
+    assert(ret.empty() && "Invalid unpackedVals state");
     if (!isSupportedVecElem(val.getType())) {
       // Non vectorizable value, it must be a value defined outside the loop,
       // just replicate it.
@@ -387,8 +387,8 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
 
   // Add unpacked values to the cache.
   auto setUnpackedVals = [&](Value origVal, ValueRange newVals) {
-    assert(newVals.size() == factor);
-    assert(unpackedVals.count(origVal) == 0);
+    assert(newVals.size() == factor && "Invalid values count");
+    assert(unpackedVals.count(origVal) == 0 && "Invalid unpackedVals state");
     unpackedVals[origVal].append(newVals.begin(), newVals.end());
 
     auto type = origVal.getType();
@@ -555,7 +555,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
 
     for (auto &&[i, arg] : llvm::enumerate(op.getOperands())) {
       auto unpacked = getUnpackedVals(arg);
-      assert(unpacked.size() == factor);
+      assert(unpacked.size() == factor && "Invalid unpacked size");
       for (auto j : llvm::seq(0u, factor))
         duplicatedArgs[j * numArgs + i] = unpacked[j];
     }
@@ -588,7 +588,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
        llvm::zip(reduceOp.getReductions(), reduceOp.getOperands())) {
     scalarMapping.clear();
     Block &reduceBody = body.front();
-    assert(reduceBody.getNumArguments() == 2);
+    assert(reduceBody.getNumArguments() == 2 && "Malformed scf.reduce");
 
     Value reduceVal;
     if (auto redKind = getReductionKind(reduceBody)) {
@@ -596,7 +596,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
       Value redArg = getVecVal(arg);
       if (redArg) {
         auto neutral = arith::getNeutralElement(&reduceBody.front());
-        assert(neutral);
+        assert(neutral && "getNeutralElement has unepectedly failed");
         Value neutralVal = builder.create<arith::ConstantOp>(loc, *neutral);
         Value neutralVec =
             builder.create<vector::SplatOp>(loc, neutralVal, redArg.getType());
@@ -618,7 +618,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
       auto lhs = reduceBody.getArgument(0);
       auto rhs = reduceBody.getArgument(1);
       auto unpacked = getUnpackedVals(arg);
-      assert(unpacked.size() == factor);
+      assert(unpacked.size() == factor && "Invalid unpacked size");
       reduceVal = unpacked.front();
       for (auto i : llvm::seq(1u, factor)) {
         Value val = unpacked[i];

>From ef70123ed930eec093be43be74036598ac4c8dff Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 3 Jun 2024 19:42:01 +0200
Subject: [PATCH 09/10] remove auto

---
 .../Dialect/SCF/Transforms/SCFVectorize.cpp   | 139 +++++++++---------
 1 file changed, 69 insertions(+), 70 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
index 536efc72a0305..b7d1281fb20ca 100644
--- a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
@@ -40,7 +40,7 @@ static std::optional<unsigned> getArgsTypeWidth(Operation &op,
                                                 const DataLayout *DL) {
   unsigned ret = 0;
   for (auto r : {ValueRange(op.getOperands()), ValueRange(op.getResults())}) {
-    for (auto arg : op.getOperands()) {
+    for (Value arg : op.getOperands()) {
       std::optional<unsigned> w = getTypeBitWidth(arg.getType(), DL);
       if (!w)
         return std::nullopt;
@@ -62,7 +62,7 @@ static bool isRangePermutation(ValueRange val1, ValueRange val2) {
   if (val1.size() != val2.size())
     return false;
 
-  for (auto v1 : val1) {
+  for (Value v1 : val1) {
     auto it = llvm::find(val2, v1);
     if (it == val2.end())
       return false;
@@ -74,9 +74,9 @@ template <typename Op>
 static std::optional<unsigned>
 canTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
                                const DataLayout *DL) {
-  auto loopIndexVars = loop.getInductionVars();
+  ValueRange loopIndexVars = loop.getInductionVars();
   assert(dim < loopIndexVars.size() && "Invalid loop dimension");
-  auto memref = memOp.getMemRef();
+  Value memref = memOp.getMemRef();
   auto type = cast<MemRefType>(memref.getType());
   std::optional<unsigned> width = getTypeBitWidth(type.getElementType(), DL);
   if (!width)
@@ -118,7 +118,7 @@ canTriviallyVectorizeMemOp(scf::ParallelOp loop, unsigned dim, Operation &op,
 template <typename Op>
 static std::optional<unsigned> canGatherScatterImpl(scf::ParallelOp loop, Op op,
                                                     const DataLayout *DL) {
-  auto memref = op.getMemRef();
+  Value memref = op.getMemRef();
   auto memrefType = cast<MemRefType>(memref.getType());
   std::optional<unsigned> width =
       getTypeBitWidth(memrefType.getElementType(), DL);
@@ -151,7 +151,7 @@ canGatherScatter(scf::ParallelOp loop, Operation &op, const DataLayout *DL) {
 static std::optional<unsigned> cenVectorizeMemrefOp(scf::ParallelOp loop,
                                                     unsigned dim, Operation &op,
                                                     const DataLayout *DL) {
-  if (auto w = canTriviallyVectorizeMemOp(loop, dim, op, DL))
+  if (std::optional<unsigned> w = canTriviallyVectorizeMemOp(loop, dim, op, DL))
     return w;
 
   return canGatherScatter(loop, op, DL);
@@ -200,8 +200,8 @@ mlir::scf::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim,
       return std::nullopt;
 
     /// Check mem ops.
-    if (auto w = cenVectorizeMemrefOp(loop, dim, op, DL)) {
-      auto newFactor = vectorBitwidth / *w;
+    if (std::optional<unsigned> w = cenVectorizeMemrefOp(loop, dim, op, DL)) {
+      unsigned newFactor = vectorBitwidth / *w;
       if (newFactor > 1) {
         factor = std::min(factor, newFactor);
         ++count;
@@ -220,7 +220,7 @@ mlir::scf::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim,
     if (!width)
       return std::nullopt;
 
-    auto newFactor = vectorBitwidth / *width;
+    unsigned newFactor = vectorBitwidth / *width;
     if (newFactor <= 1)
       continue;
 
@@ -247,26 +247,26 @@ static arith::FastMathFlags getFMF(Operation &op) {
 LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
                                        const scf::SCFVectorizeParams &params,
                                        const DataLayout *DL) {
-  auto dim = params.dim;
-  auto factor = params.factor;
-  auto masked = params.masked;
+  unsigned dim = params.dim;
+  unsigned factor = params.factor;
+  bool masked = params.masked;
   assert(dim < loop.getStep().size() && "Invalid loop dimension");
   assert(factor > 1 && "Invalid vectorize factor");
   assert(isConstantIntValue(loop.getStep()[dim], 1) && "Loop stepust be 1");
 
   OpBuilder builder(loop);
-  auto lower = llvm::to_vector(loop.getLowerBound());
-  auto upper = llvm::to_vector(loop.getUpperBound());
-  auto step = llvm::to_vector(loop.getStep());
+  SmallVector<Value> lower = llvm::to_vector(loop.getLowerBound());
+  SmallVector<Value> upper = llvm::to_vector(loop.getUpperBound());
+  SmallVector<Value> step = llvm::to_vector(loop.getStep());
 
-  auto loc = loop.getLoc();
+  Location loc = loop.getLoc();
 
-  auto origIndexVar = loop.getInductionVars()[dim];
+  Value origIndexVar = loop.getInductionVars()[dim];
 
   Value factorVal = builder.create<arith::ConstantIndexOp>(loc, factor);
 
-  auto origLower = lower[dim];
-  auto origUpper = upper[dim];
+  Value origLower = lower[dim];
+  Value origUpper = upper[dim];
   Value count = builder.createOrFold<arith::SubIOp>(loc, origUpper, origLower);
   Value newCount;
 
@@ -284,10 +284,10 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
   // Vectorized loop.
   auto newLoop = builder.create<scf::ParallelOp>(loc, lower, upper, step,
                                                  loop.getInitVals());
-  auto newIndexVar = newLoop.getInductionVars()[dim];
+  Value newIndexVar = newLoop.getInductionVars()[dim];
 
   auto toVectorType = [&](Type elemType) -> VectorType {
-    int64_t f = factor;
+    auto f = static_cast<int64_t>(factor);
     return VectorType::get(f, elemType);
   };
 
@@ -311,14 +311,14 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
   // Get vector value in new loop for provided `orig` value in source loop.
   auto getVecVal = [&](Value orig) -> Value {
     // Use cached value if present.
-    if (auto mapped = mapping.lookupOrNull(orig))
+    if (Value mapped = mapping.lookupOrNull(orig))
       return mapped;
 
     // Vectorized loop index, loop index is divided by factor, so for factorN
     // vectorized index will looks like `splat(idx) + (0, 1, ..., N - 1)`
     if (orig == origIndexVar) {
-      auto vecType = toVectorType(builder.getIndexType());
-      llvm::SmallVector<Attribute> elems(factor);
+      VectorType vecType = toVectorType(builder.getIndexType());
+      SmallVector<Attribute> elems(factor);
       for (auto i : llvm::seq(0u, factor))
         elems[i] = builder.getIndexAttr(i);
       auto attr = DenseElementsAttr::get(vecType, elems);
@@ -331,11 +331,11 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
       mapping.map(orig, vec);
       return vec;
     }
-    auto type = orig.getType();
+    Type type = orig.getType();
     assert(isSupportedVecElem(type) && "Unsupported vector element type");
 
     Value val = orig;
-    auto origIndexVars = loop.getInductionVars();
+    ValueRange origIndexVars = loop.getInductionVars();
     auto it = llvm::find(origIndexVars, orig);
 
     // If loop index, but not on vectorized dimension, just take new loop index
@@ -346,14 +346,12 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
     // Values which are defined inside loop body are preemptively added to the
     // mapper and not handled here. Values defined outside body are just
     // splatted.
-
-    auto vecType = toVectorType(type);
-    Value vec = builder.create<vector::SplatOp>(loc, val, vecType);
+    Value vec = builder.create<vector::SplatOp>(loc, val, toVectorType(type));
     mapping.map(orig, vec);
     return vec;
   };
 
-  llvm::DenseMap<Value, llvm::SmallVector<Value>> unpackedVals;
+  llvm::DenseMap<Value, SmallVector<Value>> unpackedVals;
 
   // Get unpacked values for provided `orig` value in source loop.
   // Values are returned as `ValueRange` and not as vector value.
@@ -376,7 +374,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
     }
 
     // Get vector value and extract elements from it.
-    auto vecVal = getVecVal(val);
+    Value vecVal = getVecVal(val);
     ret.resize(factor);
     for (auto i : llvm::seq(0u, factor)) {
       Value idx = builder.create<arith::ConstantIndexOp>(loc, i);
@@ -391,13 +389,13 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
     assert(unpackedVals.count(origVal) == 0 && "Invalid unpackedVals state");
     unpackedVals[origVal].append(newVals.begin(), newVals.end());
 
-    auto type = origVal.getType();
+    Type type = origVal.getType();
     if (!isSupportedVecElem(type))
       return;
 
     // If type is vectorizabale construct a vector add it to vector cache as
     // well.
-    auto vecType = toVectorType(type);
+    VectorType vecType = toVectorType(type);
 
     Value vec = createPosionVec(vecType);
     for (auto i : llvm::seq(0u, factor)) {
@@ -422,7 +420,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
     } else {
       maskSize = builder.getIndexAttr(factor);
     }
-    auto vecType = toVectorType(builder.getI1Type());
+    VectorType vecType = toVectorType(builder.getI1Type());
     mask = builder.create<vector::CreateMaskOp>(loc, vecType, maskSize);
 
     return mask;
@@ -437,11 +435,11 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
   };
 
   // Get idices for vectorized memref load/store.
-  auto getMemrefVecIndices = [&](ValueRange indices) {
+  auto getMemrefVecIndices = [&](ValueRange indices) -> SmallVector<Value> {
     scalarMapping.clear();
     scalarMapping.map(loop.getInductionVars(), newLoop.getInductionVars());
 
-    llvm::SmallVector<Value> ret(indices.size());
+    SmallVector<Value> ret(indices.size());
     for (auto &&[i, val] : llvm::enumerate(indices)) {
       if (val == origIndexVar) {
         Value idx = getrIndexVarMult();
@@ -457,13 +455,13 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
 
   // Create vectorized memref load for specified non-vectorized load.
   auto genLoad = [&](auto loadOp) {
-    auto indices = getMemrefVecIndices(loadOp.getIndices());
-    auto resType = toVectorType(loadOp.getResult().getType());
-    auto memref = loadOp.getMemRef();
+    SmallVector<Value> indices = getMemrefVecIndices(loadOp.getIndices());
+    VectorType resType = toVectorType(loadOp.getResult().getType());
+    Value memref = loadOp.getMemRef();
     Value vecLoad;
     if (masked) {
-      auto mask = getMask();
-      auto init = createPosionVec(resType);
+      Value mask = getMask();
+      Value init = createPosionVec(resType);
       vecLoad = builder.create<vector::MaskedLoadOp>(loc, resType, memref,
                                                      indices, mask, init);
     } else {
@@ -474,19 +472,19 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
 
   // Create vectorized memref store for specified non-vectorized store.
   auto genStore = [&](auto storeOp) {
-    auto indices = getMemrefVecIndices(storeOp.getIndices());
-    auto value = getVecVal(storeOp.getValueToStore());
-    auto memref = storeOp.getMemRef();
+    SmallVector<Value> indices = getMemrefVecIndices(storeOp.getIndices());
+    Value value = getVecVal(storeOp.getValueToStore());
+    Value memref = storeOp.getMemRef();
     if (masked) {
-      auto mask = getMask();
+      Value mask = getMask();
       builder.create<vector::MaskedStoreOp>(loc, memref, indices, mask, value);
     } else {
       builder.create<vector::StoreOp>(loc, value, memref, indices);
     }
   };
 
-  llvm::SmallVector<Value> duplicatedArgs;
-  llvm::SmallVector<Value> duplicatedResults;
+  SmallVector<Value> duplicatedArgs;
+  SmallVector<Value> duplicatedResults;
 
   builder.setInsertionPointToStart(newLoop.getBody());
   for (Operation &op : loop.getBody()->without_terminator()) {
@@ -494,11 +492,11 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
     if (isSupportedVectorOp(op)) {
       // If op can be vectorized, clone it with vectorized inputs and  update
       // resuls to vectorized types.
-      for (auto arg : op.getOperands())
+      for (Value arg : op.getOperands())
         getVecVal(arg); // init mapper for op args
 
-      auto newOp = builder.clone(op, mapping);
-      for (auto res : newOp->getResults())
+      Operation *newOp = builder.clone(op, mapping);
+      for (Value res : newOp->getResults())
         res.setType(toVectorType(res.getType()));
 
       continue;
@@ -512,11 +510,11 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
         continue;
       }
       if (canGatherScatter(loadOp)) {
-        auto resType = toVectorType(loadOp.getResult().getType());
-        auto memref = loadOp.getMemRef();
-        auto mask = getMask();
-        auto indexVec = getVecVal(loadOp.getIndices()[0]);
-        auto init = createPosionVec(resType);
+        VectorType resType = toVectorType(loadOp.getResult().getType());
+        Value memref = loadOp.getMemRef();
+        Value mask = getMask();
+        Value indexVec = getVecVal(loadOp.getIndices()[0]);
+        Value init = createPosionVec(resType);
 
         auto gather = builder.create<vector::GatherOp>(
             loc, resType, memref, zero, indexVec, mask, init);
@@ -531,10 +529,10 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
         continue;
       }
       if (canGatherScatter(storeOp)) {
-        auto memref = storeOp.getMemRef();
-        auto value = getVecVal(storeOp.getValueToStore());
-        auto mask = getMask();
-        auto indexVec = getVecVal(storeOp.getIndices()[0]);
+        Value memref = storeOp.getMemRef();
+        Value value = getVecVal(storeOp.getValueToStore());
+        Value mask = getMask();
+        Value indexVec = getVecVal(storeOp.getIndices()[0]);
 
         builder.create<vector::ScatterOp>(loc, memref, zero, indexVec, mask,
                                           value);
@@ -554,7 +552,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
     duplicatedResults.resize(numResults * factor);
 
     for (auto &&[i, arg] : llvm::enumerate(op.getOperands())) {
-      auto unpacked = getUnpackedVals(arg);
+      ValueRange unpacked = getUnpackedVals(arg);
       assert(unpacked.size() == factor && "Invalid unpacked size");
       for (auto j : llvm::seq(0u, factor))
         duplicatedArgs[j * numArgs + i] = unpacked[j];
@@ -565,7 +563,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
                       .drop_front(numArgs * i)
                       .take_front(numArgs);
       scalarMapping.map(op.getOperands(), args);
-      auto results = builder.clone(op, scalarMapping)->getResults();
+      ValueRange results = builder.clone(op, scalarMapping)->getResults();
 
       for (auto j : llvm::seq(0u, numResults))
         duplicatedResults[j * factor + i] = results[j];
@@ -581,7 +579,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
 
   // Vectorize `scf.reduce` op.
   auto reduceOp = cast<scf::ReduceOp>(loop.getBody()->getTerminator());
-  llvm::SmallVector<Value> reduceVals;
+  SmallVector<Value> reduceVals;
   reduceVals.reserve(reduceOp.getNumOperands());
 
   for (auto &&[body, arg] :
@@ -595,16 +593,17 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
       // Generate `vector.reduce` if possible.
       Value redArg = getVecVal(arg);
       if (redArg) {
-        auto neutral = arith::getNeutralElement(&reduceBody.front());
+        std::optional<TypedAttr> neutral =
+            arith::getNeutralElement(&reduceBody.front());
         assert(neutral && "getNeutralElement has unepectedly failed");
         Value neutralVal = builder.create<arith::ConstantOp>(loc, *neutral);
         Value neutralVec =
             builder.create<vector::SplatOp>(loc, neutralVal, redArg.getType());
-        auto mask = getMask();
+        Value mask = getMask();
         redArg = builder.create<arith::SelectOp>(loc, mask, redArg, neutralVec);
       }
 
-      auto fmf = getFMF(reduceBody.front());
+      arith::FastMathFlags fmf = getFMF(reduceBody.front());
       reduceVal =
           builder.create<vector::ReductionOp>(loc, *redKind, redArg, fmf);
     } else {
@@ -615,16 +614,16 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
       // individually.
 
       auto reduceTerm = cast<scf::ReduceReturnOp>(reduceBody.getTerminator());
-      auto lhs = reduceBody.getArgument(0);
-      auto rhs = reduceBody.getArgument(1);
-      auto unpacked = getUnpackedVals(arg);
+      Value lhs = reduceBody.getArgument(0);
+      Value rhs = reduceBody.getArgument(1);
+      ValueRange unpacked = getUnpackedVals(arg);
       assert(unpacked.size() == factor && "Invalid unpacked size");
       reduceVal = unpacked.front();
       for (auto i : llvm::seq(1u, factor)) {
         Value val = unpacked[i];
         scalarMapping.map(lhs, reduceVal);
         scalarMapping.map(rhs, val);
-        for (auto &redOp : reduceBody.without_terminator())
+        for (Operation &redOp : reduceBody.without_terminator())
           builder.clone(redOp, scalarMapping);
 
         reduceVal = scalarMapping.lookupOrDefault(reduceTerm.getResult());
@@ -648,7 +647,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
         builder.createOrFold<arith::MulIOp>(loc, newCount, factorVal);
     newLower = builder.createOrFold<arith::AddIOp>(loc, origLower, newLower);
 
-    auto lowerCopy = llvm::to_vector(loop.getLowerBound());
+    SmallVector<Value> lowerCopy = llvm::to_vector(loop.getLowerBound());
     lowerCopy[dim] = newLower;
     loop.getLowerBoundMutable().assign(lowerCopy);
     loop.getInitValsMutable().assign(newLoop.getResults());

>From ab4302fb38336c4125d7757b372f88a9395f543a Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 3 Jun 2024 19:48:27 +0200
Subject: [PATCH 10/10] remove tmp var

---
 mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
index b7d1281fb20ca..d441dffc6c58a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
@@ -287,8 +287,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
   Value newIndexVar = newLoop.getInductionVars()[dim];
 
   auto toVectorType = [&](Type elemType) -> VectorType {
-    auto f = static_cast<int64_t>(factor);
-    return VectorType::get(f, elemType);
+    return VectorType::get(factor, elemType);
   };
 
   IRMapping mapping;



More information about the Mlir-commits mailing list