[Mlir-commits] [mlir] 99b3849 - [mlir][sparse] introduce vectorization pass for sparse loops

Aart Bik llvmlistbot at llvm.org
Mon Nov 21 16:12:27 PST 2022


Author: Aart Bik
Date: 2022-11-21T16:12:12-08:00
New Revision: 99b3849d89cfdbc60ce4e18fc9c70dfd377bd93b

URL: https://github.com/llvm/llvm-project/commit/99b3849d89cfdbc60ce4e18fc9c70dfd377bd93b
DIFF: https://github.com/llvm/llvm-project/commit/99b3849d89cfdbc60ce4e18fc9c70dfd377bd93b.diff

LOG: [mlir][sparse] introduce vectorization pass for sparse loops

This brings back previous SIMD functionality, but in a separate pass.
The idea is to improve this new pass incrementally, going beyond for-loops
to while-loops for co-iteration as well (masking), while introducing new
abstractions to make the lowering more progressive. The separation of
sparsification and vectorization is a very good first step on this journey.

Also brings back ArmSVE support

Still to be fine-tuned:
  + use of "index" in SIMD loop (viz. a[i] = i)
  + check that all ops really have SIMD support
  + check all forms of reductions
  + chain reduction SIMD values

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D138236

Added: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
    mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/test/Dialect/SparseTensor/sparse_vector.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 0961b5e000868..9b04c376f5c2c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -172,6 +172,16 @@ std::unique_ptr<Pass> createSparseBufferRewritePass();
 std::unique_ptr<Pass>
 createSparseBufferRewritePass(bool enableBufferInitialization);
 
+void populateSparseVectorizationPatterns(RewritePatternSet &patterns,
+                                         unsigned vectorLength,
+                                         bool enableVLAVectorization,
+                                         bool enableSIMDIndex32);
+
+std::unique_ptr<Pass> createSparseVectorizationPass();
+std::unique_ptr<Pass> createSparseVectorizationPass(unsigned vectorLength,
+                                                    bool enableVLAVectorization,
+                                                    bool enableSIMDIndex32);
+
 //===----------------------------------------------------------------------===//
 // Registration.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 32bba3a1552e4..3342d2c072caf 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -225,4 +225,64 @@ def SparseBufferRewrite : Pass<"sparse-buffer-rewrite", "ModuleOp"> {
   ];
 }
 
+def SparseVectorization : Pass<"sparse-vectorization", "ModuleOp"> {
+  let summary = "Vectorizes loops after sparsification";
+  let description = [{
+    A pass that converts loops after sparsification into vector loops.
+    The vector dialect is used as target to provide an architecture-neutral
+    way of exploiting any platform that supports SIMD instructions.
+
+    The vector length (viz. `vl`) describes the number of packed data elements
+    (e.g. both vector<16xf32> and vector<16xf64> have a vector length of 16 even
+    though the actual bitwidths differ). A small multiple of the actual lengths
+    supported in hardware typically results in efficient SIMD code, since the
+    backend will map longer vectors to multiple vector registers, thereby
+    effectively unrolling an additional level within the generated for-loop.
+
+    Example of the conversion:
+
+    ```mlir
+      Before:
+        %3 = memref.load %2[] : memref<f32>
+        %4 = scf.for %arg3 = %c0 to %c1024 step %c1 iter_args(%arg4 = %3) -> (f32) {
+          %6 = memref.load %0[%arg3] : memref<?xf32>
+          %7 = memref.load %1[%arg3] : memref<1024xf32>
+          %8 = arith.mulf %6, %7 : f32
+          %9 = arith.addf %arg4, %8 : f32
+          scf.yield %9 : f32
+        }
+        memref.store %4, %2[] : memref<f32>
+
+      After:
+        %3 = memref.load %2[] : memref<f32>
+        %4 = vector.insertelement %3, %cst[%c0 : index] : vector<32xf32>
+        %5 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4) -> (vector<32xf32>) {
+          %8 = vector.load %0[%arg3] : memref<?xf32>, vector<32xf32>
+          %9 = vector.load %1[%arg3] : memref<1024xf32>, vector<32xf32>
+          %10 = arith.mulf %8, %9 : vector<32xf32>
+          %11 = arith.addf %arg4, %10 : vector<32xf32>
+          scf.yield %11 : vector<32xf32>
+        }
+        %6 = vector.reduction <add>, %5 : vector<32xf32> into f32
+        memref.store %6, %2[] : memref<f32>
+    ```
+  }];
+  let constructor = "mlir::createSparseVectorizationPass()";
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "memref::MemRefDialect",
+    "scf::SCFDialect",
+    "sparse_tensor::SparseTensorDialect",
+    "vector::VectorDialect",
+  ];
+  let options = [
+    Option<"vectorLength", "vl", "int32_t", "0",
+           "Set the vector length (use 0 to disable vectorization)">,
+    Option<"enableVLAVectorization", "enable-vla-vectorization", "bool",
+           "false", "Enable vector length agnostic vectorization">,
+    Option<"enableSIMDIndex32", "enable-simd-index32", "bool", "false",
+           "Enable i32 indexing into vectors (for efficient gather/scatter)">,
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index eafee08ae05ea..00c624fd6e085 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
   SparseTensorRewriting.cpp
+  SparseVectorization.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index d1491dfffcb7d..06f57eb6edc50 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -27,6 +27,7 @@ namespace mlir {
 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
 #define GEN_PASS_DEF_SPARSEBUFFERREWRITE
+#define GEN_PASS_DEF_SPARSEVECTORIZATION
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 } // namespace mlir
 
@@ -67,10 +68,9 @@ struct SparsificationPass
     auto *ctx = &getContext();
     // Translate strategy flags to strategy options.
     SparsificationOptions options(parallelization);
-    // Apply sparsification and vector cleanup rewriting.
+    // Apply sparsification and cleanup rewriting.
     RewritePatternSet patterns(ctx);
     populateSparsificationPatterns(patterns, options);
-    vector::populateVectorToVectorCanonicalizationPatterns(patterns);
     scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
@@ -250,6 +250,27 @@ struct SparseBufferRewritePass
   }
 };
 
+/// Pass wrapper around the sparse vectorization patterns. The options are
+/// declared in Passes.td (vl, enable-vla-vectorization, enable-simd-index32).
+struct SparseVectorizationPass
+    : public impl::SparseVectorizationBase<SparseVectorizationPass> {
+
+  SparseVectorizationPass() = default;
+  SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
+  SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
+    vectorLength = vl;
+    enableVLAVectorization = vla;
+    enableSIMDIndex32 = sidx32;
+  }
+
+  void runOnOperation() override {
+    // As documented for the "vl" option, a vector length of zero disables
+    // vectorization; bail out before any zero-length vector type is built.
+    if (vectorLength == 0)
+      return;
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    populateSparseVectorizationPatterns(
+        patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
+    vector::populateVectorToVectorCanonicalizationPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -322,3 +343,15 @@ std::unique_ptr<Pass>
 mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
   return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
 }
+
+/// Creates the sparse vectorization pass; its options keep their defaults
+/// unless set on the command line.
+std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
+  std::unique_ptr<Pass> pass = std::make_unique<SparseVectorizationPass>();
+  return pass;
+}
+
+/// Creates the sparse vectorization pass with explicitly given options.
+std::unique_ptr<Pass>
+mlir::createSparseVectorizationPass(unsigned vectorLength,
+                                    bool enableVLAVectorization,
+                                    bool enableSIMDIndex32) {
+  std::unique_ptr<Pass> pass = std::make_unique<SparseVectorizationPass>(
+      vectorLength, enableVLAVectorization, enableSIMDIndex32);
+  return pass;
+}

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
new file mode 100644
index 0000000000000..aed394990428d
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -0,0 +1,485 @@
+//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// A pass that converts loops generated by the sparse compiler into a form that
+// can exploit SIMD instructions of the target architecture. Note that this pass
+// ensures the sparse compiler can generate efficient SIMD (including ArmSVE
+// support) with proper separation of concerns as far as sparsification and
+// vectorization is concerned. However, this pass is not the final abstraction
+// level we want, and not the general vectorizer we want either. It forms a good
+// stepping stone for incremental future improvements though.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Matchers.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+/// SIMD properties of the target:
+///   vectorLength: number of packed data elements per vector
+///                 (e.g. vector<16xf32> has length 16)
+///   enableVLAVectorization: request scalable vectors (e.g. ArmSVE)
+///   enableSIMDIndex32: use 32-bit indices in gather/scatter for efficiency
+struct VL {
+  unsigned vectorLength;
+  bool enableVLAVectorization;
+  bool enableSIMDIndex32;
+};
+
+/// Returns true iff `val` is a constant integer equal to `idx`.
+static bool isIntValue(Value val, int64_t idx) {
+  auto cst = getConstantIntValue(val);
+  return cst && *cst == idx;
+}
+
+/// Returns the 1-D vector type with element type `etp` for the target; the
+/// single dimension is scalable when VLA vectorization is enabled.
+static VectorType vectorType(VL vl, Type etp) {
+  unsigned scalableDims = vl.enableVLAVectorization ? 1 : 0;
+  return VectorType::get(vl.vectorLength, etp, scalableDims);
+}
+
+/// Returns the vector type whose element type matches the memref `ptr`.
+static VectorType vectorType(VL vl, Value ptr) {
+  Type elementType = ptr.getType().cast<MemRefType>().getElementType();
+  return vectorType(vl, elementType);
+}
+
+/// Generates the mask that confines the vector iteration starting at `iv`
+/// to the iteration space [lo, hi) of a loop with the given (vector) step.
+static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
+                           Value iv, Value lo, Value hi, Value step) {
+  VectorType maskType = vectorType(vl, rewriter.getI1Type());
+  // When all bounds are constant and the step evenly divides the trip count
+  // (e.g. "for i = 0, 128, 16"), every lane is always active. A constant
+  // all-true mask lets subsequent folding turn all masked memory operations
+  // into unconditional ones.
+  IntegerAttr loAttr, hiAttr, stepAttr;
+  bool allConstant = matchPattern(lo, m_Constant(&loAttr)) &&
+                     matchPattern(hi, m_Constant(&hiAttr)) &&
+                     matchPattern(step, m_Constant(&stepAttr));
+  if (allConstant) {
+    int64_t span = hiAttr.getInt() - loAttr.getInt();
+    if (span % stepAttr.getInt() == 0) {
+      Value allTrue = constantI1(rewriter, loc, true);
+      return rewriter.create<vector::BroadcastOp>(loc, maskType, allTrue);
+    }
+  }
+  // Otherwise, activate min(step, hi - iv) lanes so that the last vector
+  // iteration never overruns the upper bound. Later loop optimizations
+  // (e.g. splitting off a scalar cleanup loop) may avoid evaluating this
+  // mask in every iteration.
+  AffineExpr d0 = rewriter.getAffineDimExpr(0);
+  AffineExpr d1 = rewriter.getAffineDimExpr(1);
+  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
+  AffineMap minMap = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1,
+                                    {s0, d0 - d1}, rewriter.getContext());
+  Value numActive = rewriter.createOrFold<AffineMinOp>(
+      loc, minMap, ValueRange{hi, iv, step});
+  return rewriter.create<vector::CreateMaskOp>(loc, maskType, numActive);
+}
+
+/// Broadcasts the loop-invariant scalar `val` into a vector. Subsequent loop
+/// optimizations are relied upon to hoist the broadcast out of the loop.
+static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
+                                     Value val) {
+  return rewriter.create<vector::BroadcastOp>(
+      val.getLoc(), vectorType(vl, val.getType()), val);
+}
+
+/// Generates a vectorized load: a gather lhs = a[ind[lo:hi]] when the last
+/// subscript is an index vector, or a masked load lhs = a[lo:hi] otherwise,
+/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Masked-off
+/// lanes read as zero.
+static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
+                           Value ptr, ArrayRef<Value> idxs, Value vmask) {
+  VectorType vtp = vectorType(vl, ptr);
+  Value pass = constantZero(rewriter, loc, vtp);
+  Value innermost = idxs.back();
+  if (!innermost.getType().isa<VectorType>())
+    return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, idxs, vmask,
+                                                 pass);
+  // Gather: the scalar base subscripts end in zero, while the index vector
+  // supplies the offsets along the innermost dimension.
+  SmallVector<Value> base(idxs.begin(), idxs.end() - 1);
+  base.push_back(constantIndex(rewriter, loc, 0));
+  return rewriter.create<vector::GatherOp>(loc, vtp, ptr, base, innermost,
+                                           vmask, pass);
+}
+
+/// Generates a vectorized store: a scatter a[ind[lo:hi]] = rhs when the last
+/// subscript is an index vector, or a masked store a[lo:hi] = rhs otherwise,
+/// where 'lo' denotes the current index and 'hi = lo + vl - 1'.
+static void genVectorStore(PatternRewriter &rewriter, Location loc, Value ptr,
+                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
+  Value innermost = idxs.back();
+  if (!innermost.getType().isa<VectorType>()) {
+    rewriter.create<vector::MaskedStoreOp>(loc, ptr, idxs, vmask, rhs);
+    return;
+  }
+  SmallVector<Value> base(idxs.begin(), idxs.end() - 1);
+  base.push_back(constantIndex(rewriter, loc, 0));
+  rewriter.create<vector::ScatterOp>(loc, ptr, base, innermost, vmask, rhs);
+}
+
+/// Maps the operation that defines a reduction to the combining kind of the
+/// corresponding vector reduction (subtraction accumulates into a sum, just
+/// like addition). Any other operation is a programming error.
+static vector::CombiningKind getCombiningKind(Operation *def) {
+  if (isa<arith::AddFOp, arith::AddIOp, arith::SubFOp, arith::SubIOp>(def))
+    return vector::CombiningKind::ADD;
+  if (isa<arith::MulFOp, arith::MulIOp>(def))
+    return vector::CombiningKind::MUL;
+  if (isa<arith::AndIOp>(def))
+    return vector::CombiningKind::AND;
+  if (isa<arith::OrIOp>(def))
+    return vector::CombiningKind::OR;
+  if (isa<arith::XOrIOp>(def))
+    return vector::CombiningKind::XOR;
+  llvm_unreachable("unknown reduction kind");
+}
+
+/// Generates the initial vector value of a reduction, following the scheme
+/// in Chapter 5 of "The Software Vectorization Handbook": the initial scalar
+/// accumulator 'r' is embedded in the vector so that a plain horizontal
+/// reduction completes the operation. The accumulation 'rd' is only used to
+/// determine the combining kind (e.g. addf -> sum-accumulation).
+static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
+                                VectorType vtp, Value r, Value rd) {
+  vector::CombiningKind kind = getCombiningKind(rd.getDefiningOp());
+  switch (kind) {
+  case vector::CombiningKind::AND:
+  case vector::CombiningKind::OR:
+    // Idempotent kinds: | r | .. | r | r |
+    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
+  case vector::CombiningKind::ADD:
+  case vector::CombiningKind::XOR:
+  case vector::CombiningKind::MUL: {
+    // | id | .. | id | r |, where id is 1 for MUL and 0 for ADD/XOR.
+    Value identity = kind == vector::CombiningKind::MUL
+                         ? constantOne(rewriter, loc, vtp)
+                         : constantZero(rewriter, loc, vtp);
+    Value pos = constantIndex(rewriter, loc, 0);
+    return rewriter.create<vector::InsertElementOp>(loc, r, identity, pos);
+  }
+  default:
+    llvm_unreachable("unknown reduction kind");
+  }
+}
+
+/// Generates the horizontal reduction that yields the final scalar value.
+static Value genVectorReducEnd(PatternRewriter &rewriter, Location loc,
+                               Value vexp, Value rd) {
+  return rewriter.create<vector::ReductionOp>(
+      loc, getCombiningKind(rd.getDefiningOp()), vexp);
+}
+
+/// This method is called twice to analyze and rewrite the given subscripts.
+/// The first call (!codegen) does the analysis. Then, on success, the second
+/// call (codegen) yields the proper vector form in the output parameter
+/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
+/// stay in sync. Returns false if some subscript cannot be vectorized.
+///
+/// Accepted subscripts are (1) block arguments (including the loop
+/// induction variable itself) or values defined outside the loop body,
+/// which pass through unchanged, and (2) values loaded by a memref.load
+/// inside the loop body, possibly under index_cast/extui casts, which are
+/// replaced by a vector of indices for a gather/scatter.
+///
+/// See https://llvm.org/docs/GetElementPtr.html for some background on
+/// the complications described below.
+///
+/// We need to generate a pointer/index load from the sparse storage scheme.
+/// Narrower data types need to be zero extended before casting the value
+/// into the index type used for looping and indexing.
+///
+/// For the scalar case, subscripts simply zero extend narrower indices
+/// into 64-bit values before casting to an index type without a performance
+/// penalty. Indices that already are 64-bit, in theory, cannot express the
+/// full range since the LLVM backend defines addressing in terms of an
+/// unsigned pointer/signed index pair.
+static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
+                                VL vl, ValueRange subs, bool codegen,
+                                Value vmask, SmallVectorImpl<Value> &idxs) {
+  for (auto sub : subs) {
+    // Invariant indices simply pass through. The induction variable is a
+    // block argument too; passing it through yields a contiguous access.
+    if (sub.dyn_cast<BlockArgument>() ||
+        sub.getDefiningOp()->getBlock() != &forOp.getRegion().front()) {
+      if (codegen)
+        idxs.push_back(sub);
+      continue; // success so far
+    }
+    // Look under the hood of casting (any chain of index_cast/extui).
+    auto cast = sub;
+    while (1) {
+      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
+        cast = icast->getOperand(0);
+      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
+        cast = ecast->getOperand(0);
+      else
+        break;
+    }
+    // Since the index vector is used in a subsequent gather/scatter
+    // operation, which effectively defines an unsigned pointer + signed
+    // index, we must zero extend the vector to an index width. For 8-bit
+    // and 16-bit values, a 32-bit index width suffices. For 32-bit values,
+    // zero extending the elements into 64-bit loses some performance since
+    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
+    // index variant (if the negative 32-bit index space is unused, the
+    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
+    // values, there is no good way to state that the indices are unsigned,
+    // which creates the potential of incorrect address calculations in the
+    // unlikely case we need such extremely large offsets.
+    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
+      if (codegen) {
+        // NOTE(review): the subscripts of this index load are reused as-is
+        // for a contiguous vector load, which assumes they advance together
+        // with the loop (e.g. ind[i] for induction variable i) -- confirm
+        // this holds for every loop the sparse compiler emits.
+        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
+        Location loc = forOp.getLoc();
+        Value vload =
+            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
+        Type etp = vload.getType().cast<VectorType>().getElementType();
+        // Widen non-index element types as explained above.
+        if (!etp.isa<IndexType>()) {
+          if (etp.getIntOrFloatBitWidth() < 32)
+            vload = rewriter.create<arith::ExtUIOp>(
+                loc, vectorType(vl, rewriter.getI32Type()), vload);
+          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
+            vload = rewriter.create<arith::ExtUIOp>(
+                loc, vectorType(vl, rewriter.getI64Type()), vload);
+        }
+        idxs.push_back(vload);
+      }
+      continue; // success so far
+    }
+    // Any other in-loop subscript computation is not supported.
+    return false;
+  }
+  return true;
+}
+
+// Local helpers for vectorizeExpr(): when the defining operation `def` has
+// the given type, build its vector counterpart (in codegen mode) and succeed.
+#define UNAOP(xxx)                                                             \
+  if (isa<xxx>(def)) {                                                         \
+    if (codegen)                                                               \
+      vexp = rewriter.create<xxx>(loc, vx);                                    \
+    return true;                                                               \
+  }
+
+#define BINOP(xxx)                                                             \
+  if (isa<xxx>(def)) {                                                         \
+    if (codegen)                                                               \
+      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
+    return true;                                                               \
+  }
+
+/// This method is called twice to analyze and rewrite the given expression.
+/// The first call (!codegen) does the analysis. Then, on success, the second
+/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
+/// This mechanism ensures that analysis and rewriting code stay in sync.
+static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
+                          Value exp, bool codegen, Value vmask, Value &vexp) {
+  // The induction variable is not invariant: lane k of a vector iteration
+  // that starts at i must observe i + k, so broadcasting i (as is done for
+  // invariants below) would miscompile expressions such as a[i] = i.
+  // Reject it until a proper index vector is generated.
+  if (exp == forOp.getInductionVar())
+    return false;
+  // Any other block argument is treated as invariant. (For a reduction loop,
+  // the broadcast iteration argument is later relinked to the vector
+  // iteration argument of the new loop.)
+  if (exp.isa<BlockArgument>()) {
+    if (codegen)
+      vexp = genVectorInvariantValue(rewriter, vl, exp);
+    return true;
+  }
+  // Something defined outside the loop-body is invariant as well.
+  Operation *def = exp.getDefiningOp();
+  if (def->getBlock() != &forOp.getRegion().front()) {
+    if (codegen)
+      vexp = genVectorInvariantValue(rewriter, vl, exp);
+    return true;
+  }
+  // Inside loop-body unary and binary operations. Note that it would be
+  // nicer if we could somehow test and build the operations in a more
+  // concise manner than just listing them all (although this way we know
+  // for certain that they can vectorize).
+  Location loc = forOp.getLoc();
+  if (auto load = dyn_cast<memref::LoadOp>(def)) {
+    auto subs = load.getIndices();
+    SmallVector<Value> idxs;
+    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
+      if (codegen)
+        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
+      return true;
+    }
+  } else if (def->getNumOperands() == 1) {
+    Value vx;
+    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
+                      vx)) {
+      UNAOP(math::AbsFOp)
+      UNAOP(math::AbsIOp)
+      UNAOP(math::CeilOp)
+      UNAOP(math::FloorOp)
+      UNAOP(math::SqrtOp)
+      UNAOP(math::ExpM1Op)
+      UNAOP(math::Log1pOp)
+      UNAOP(math::SinOp)
+      UNAOP(math::TanhOp)
+      UNAOP(arith::NegFOp)
+    }
+  } else if (def->getNumOperands() == 2) {
+    Value vx, vy;
+    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
+                      vx) &&
+        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
+                      vy)) {
+      BINOP(arith::MulFOp)
+      BINOP(arith::MulIOp)
+      BINOP(arith::DivFOp)
+      BINOP(arith::DivSIOp)
+      BINOP(arith::DivUIOp)
+      BINOP(arith::AddFOp)
+      BINOP(arith::AddIOp)
+      BINOP(arith::SubFOp)
+      BINOP(arith::SubIOp)
+      BINOP(arith::AndIOp)
+      BINOP(arith::OrIOp)
+      BINOP(arith::XOrIOp)
+    }
+  }
+  return false;
+}
+
+#undef UNAOP
+#undef BINOP
+
+/// Returns true if `red` is a reduction of the loop-carried value `iter`
+/// that the vector code can handle. All of the following must hold:
+///   * `red` is produced by an operation that getCombiningKind() recognizes.
+///   * That operation takes `iter` as an operand. For subtraction it must be
+///     the left operand, so that r - x accumulates -x into a sum.
+///   * `iter` has no other use in the loop body.
+/// This keeps codegen from reaching the unreachable in getCombiningKind().
+static bool isVectorizableReduction(Value red, Value iter) {
+  if (!iter.hasOneUse())
+    return false;
+  Operation *def = red.getDefiningOp();
+  if (!def || def->getNumOperands() != 2)
+    return false;
+  if (isa<arith::SubFOp, arith::SubIOp>(def))
+    return def->getOperand(0) == iter;
+  if (isa<arith::AddFOp, arith::AddIOp, arith::MulFOp, arith::MulIOp,
+          arith::AndIOp, arith::OrIOp, arith::XOrIOp>(def))
+    return def->getOperand(0) == iter || def->getOperand(1) == iter;
+  return false;
+}
+
+/// This method is called twice to analyze and rewrite the given for-loop.
+/// The first call (!codegen) does the analysis. Then, on success, the second
+/// call (codegen) rewrites the IR into vector form. This mechanism ensures
+/// that analysis and rewriting code stay in sync.
+static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
+                          bool codegen) {
+  Location loc = forOp.getLoc();
+  Block &block = forOp.getRegion().front();
+  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
+  // The operation right before the yield, or null for a body that consists
+  // of the yield only. Computed before codegen inserts anything before the
+  // yield.
+  Operation *last = yield->getPrevNode();
+  scf::ForOp forOpNew;
+
+  // Perform initial set up during codegen (we know that the first analysis
+  // pass was successful). For reductions, we need to construct a completely
+  // new for-loop, since the incoming and outgoing reduction type
+  // changes into SIMD form. For stores, we can simply adjust the stride
+  // and insert in the existing for-loop. In both cases, we set up a vector
+  // mask for all operations which takes care of confining vectors to
+  // the original iteration space (later cleanup loops or other
+  // optimizations can take care of those).
+  Value vmask;
+  if (codegen) {
+    Value step = constantIndex(rewriter, loc, vl.vectorLength);
+    if (vl.enableVLAVectorization) {
+      Value vscale =
+          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
+    }
+    if (!yield.getResults().empty()) {
+      Value init = forOp.getInitArgs()[0];
+      VectorType vtp = vectorType(vl, init.getType());
+      Value vinit =
+          genVectorReducInit(rewriter, loc, vtp, init, yield->getOperand(0));
+      forOpNew = rewriter.create<scf::ForOp>(
+          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
+      rewriter.setInsertionPointToStart(forOpNew.getBody());
+    } else {
+      // In-place modification: notify the rewriter.
+      rewriter.updateRootInPlace(forOp, [&]() { forOp.setStep(step); });
+      rewriter.setInsertionPoint(yield);
+    }
+    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
+                          forOp.getLowerBound(), forOp.getUpperBound(), step);
+  }
+
+  // Sparse for-loops either are terminated by a non-empty yield operation
+  // (reduction loop) or otherwise by a store operation (parallel loop).
+  if (!yield.getResults().empty()) {
+    if (yield->getNumOperands() != 1)
+      return false;
+    Value redOp = yield->getOperand(0);
+    // Analyze/vectorize reduction. Only reductions of the iteration argument
+    // with a known combining kind are accepted.
+    Value vrhs;
+    if (isVectorizableReduction(redOp, forOp.getRegionIterArg(0)) &&
+        vectorizeExpr(rewriter, forOp, vl, redOp, codegen, vmask, vrhs)) {
+      if (codegen) {
+        Value vpass =
+            genVectorInvariantValue(rewriter, vl, forOp.getRegionIterArg(0));
+        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
+        rewriter.create<scf::YieldOp>(loc, vred);
+        rewriter.setInsertionPointAfter(forOpNew);
+        Value vres =
+            genVectorReducEnd(rewriter, loc, forOpNew.getResult(0), redOp);
+        // Relink the loop-body values of the old loop to the new loop. The
+        // iteration argument relink is not completely type safe, but all bad
+        // uses are erased together with the old loop right away; this also
+        // folds away nop broadcast operations. Finally, replace the old loop
+        // by the reduced value.
+        forOp.getInductionVar().replaceAllUsesWith(forOpNew.getInductionVar());
+        forOp.getRegionIterArg(0).replaceAllUsesWith(
+            forOpNew.getRegionIterArg(0));
+        rewriter.replaceOp(forOp, vres);
+      }
+      return true;
+    }
+  } else if (auto store = dyn_cast_or_null<memref::StoreOp>(last)) {
+    // Analyze/vectorize store operation.
+    auto subs = store.getIndices();
+    SmallVector<Value> idxs;
+    Value rhs = store.getValue();
+    Value vrhs;
+    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
+        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
+      if (codegen) {
+        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
+        rewriter.eraseOp(store);
+      }
+      return true;
+    }
+  }
+
+  assert(!codegen && "cannot call codegen when analysis failed");
+  return false;
+}
+
+/// Basic for-loop vectorizer: rewrites a loop emitted by the sparse compiler
+/// into vector form when its body is simple enough.
+struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
+  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
+
+  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
+                bool enableVLAVectorization, bool enableSIMDIndex32)
+      : OpRewritePattern(context),
+        vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}
+
+  LogicalResult matchAndRewrite(scf::ForOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only single-block, unit-stride loops carrying the loop-emitter marker
+    // qualify. They are generated by the sparse compiler, so no data
+    // dependence analysis is needed and the loop-body is very restricted.
+    bool isCandidate =
+        op.getRegion().hasOneBlock() && isIntValue(op.getStep(), 1) &&
+        op->hasAttr(SparseTensorLoopEmitter::getLoopEmitterLoopAttrName());
+    if (!isCandidate)
+      return failure();
+    // Analyze first (no IR changes); only on success rewrite the loop-body.
+    bool vectorized = vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
+                      vectorizeStmt(rewriter, op, vl, /*codegen=*/true);
+    return success(vectorized);
+  }
+
+private:
+  const VL vl;
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Public method for populating vectorization rules.
+//===----------------------------------------------------------------------===//
+
+/// Populates the given patterns list with vectorization rules. A vector
+/// length of zero means "no vectorization" (matching the pass option), so no
+/// pattern is added in that case; the pattern would otherwise build
+/// zero-length vector types, which are not valid MLIR vector types.
+void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
+                                               unsigned vectorLength,
+                                               bool enableVLAVectorization,
+                                               bool enableSIMDIndex32) {
+  if (vectorLength == 0)
+    return;
+  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
+                              enableVLAVectorization, enableSIMDIndex32);
+}

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
old mode 100644
new mode 100755
index fca5a33332195..a5c0239dc7edd
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -1,5 +1,11 @@
 // RUN: mlir-opt %s -sparsification -cse -split-input-file | \
-// RUN:   FileCheck %s
+// RUN:   FileCheck %s --check-prefix=CHECK-SCALAR
+// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=16" -cse -split-input-file | \
+// RUN:   FileCheck %s --check-prefix=CHECK-VEC16
+// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=16 enable-simd-index32=true" -cse -split-input-file | \
+// RUN:   FileCheck %s --check-prefix=CHECK-VEC16-IDX32
+// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=4 enable-vla-vectorization=true" -cse -split-input-file | \
+// RUN:   FileCheck %s --check-prefix=CHECK-VEC4-SVE
 
 #DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>
 
@@ -13,18 +19,59 @@
 }
 
 //
-// CHECK-LABEL: func @scale_d
-// CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
-// CHECK-DAG:   %[[c1:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
-// CHECK:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] {
-// CHECK:         %[[l:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
-// CHECK:         %[[m:.*]] = arith.mulf %[[l]], %{{.*}} : f32
-// CHECK:         store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>
-// CHECK:       }
-// CHECK:       return
+// CHECK-SCALAR-LABEL: func @scale_d
+// CHECK-SCALAR-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-SCALAR-DAG:   %[[c1:.*]] = arith.constant 1 : index
+// CHECK-SCALAR-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
+// CHECK-SCALAR:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] {
+// CHECK-SCALAR:         %[[l:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
+// CHECK-SCALAR:         %[[m:.*]] = arith.mulf %[[l]], %{{.*}} : f32
+// CHECK-SCALAR:         store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>
+// CHECK-SCALAR:       }
+// CHECK-SCALAR:       return
+//
+// CHECK-VEC16-LABEL: func @scale_d
+// CHECK-VEC16-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-VEC16-DAG:   %[[c16:.*]] = arith.constant 16 : index
+// CHECK-VEC16-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
+// CHECK-VEC16:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
+// CHECK-VEC16:         %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
+// CHECK-VEC16:         %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
+// CHECK-VEC16:         %[[m:.*]] = arith.mulf %[[r]], %[[b]] : vector<16xf32>
+// CHECK-VEC16:         vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC16:       }
+// CHECK-VEC16:       return
+//
+// CHECK-VEC16-IDX32-LABEL: func @scale_d
+// CHECK-VEC16-IDX32-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-VEC16-IDX32-DAG:   %[[c16:.*]] = arith.constant 16 : index
+// CHECK-VEC16-IDX32-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
+// CHECK-VEC16-IDX32:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
+// CHECK-VEC16-IDX32:         %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
+// CHECK-VEC16-IDX32:         %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
+// CHECK-VEC16-IDX32:         %[[m:.*]] = arith.mulf %[[r]], %[[b]] : vector<16xf32>
+// CHECK-VEC16-IDX32:         vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC16-IDX32:       }
+// CHECK-VEC16-IDX32:       return
+//
+// CHECK-VEC4-SVE:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)
+// CHECK-VEC4-SVE-LABEL: func @scale_d
+// CHECK-VEC4-SVE-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-VEC4-SVE-DAG:   %[[c4:.*]] = arith.constant 4 : index
+// CHECK-VEC4-SVE-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
+// CHECK-VEC4-SVE-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+// CHECK-VEC4-SVE-DAG:   %[[vscale:.*]] = vector.vscale
+// CHECK-VEC4-SVE:       %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
+// CHECK-VEC4-SVE:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] {
+// CHECK-VEC4-SVE:         %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]]
+// CHECK-VEC4-SVE:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
+// CHECK-VEC4-SVE:         %[[val:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+// CHECK-VEC4-SVE:         %[[scalev:.*]] = vector.broadcast %{{.*}} : f32 to vector<[4]xf32>
+// CHECK-VEC4-SVE:         %[[scaled:.*]] = arith.mulf %[[val]], %[[scalev]] : vector<[4]xf32>
+// CHECK-VEC4-SVE:         vector.maskedstore %{{.*}}[%[[i]]], %[[mask]], %[[scaled]] : memref<1024xf32>, vector<[4]xi1>, vector<[4]xf32>
+// CHECK-VEC4-SVE:       }
+// CHECK-VEC4-SVE:       return
 //
-
 func.func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
   %0 = linalg.generic #trait_scale_d
     ins(%arga: tensor<1024xf32, #DenseVector>)
@@ -55,27 +102,101 @@ func.func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor
 }
 
 //
-// CHECK-LABEL: func @mul_s
-// CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
-// CHECK-DAG:   %[[c1:.*]] = arith.constant 1 : index
-// CHECK:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
-// CHECK:       %[[a:.*]] = arith.extui %[[p]] : i32 to i64
-// CHECK:       %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
-// CHECK:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
-// CHECK:       %[[b:.*]] = arith.extui %[[r]] : i32 to i64
-// CHECK:       %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
-// CHECK:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] {
-// CHECK:         %[[li:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
-// CHECK:         %[[zi:.*]] = arith.extui %[[li]] : i32 to i64
-// CHECK:         %[[ci:.*]] = arith.index_cast %[[zi]] : i64 to index
-// CHECK:         %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
-// CHECK:         %[[lb:.*]] = memref.load %{{.*}}[%[[ci]]] : memref<1024xf32>
-// CHECK:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
-// CHECK:         store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32>
-// CHECK:       }
-// CHECK:       return
-//
-func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
+// CHECK-SCALAR-LABEL: func @mul_s
+// CHECK-SCALAR-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-SCALAR-DAG:   %[[c1:.*]] = arith.constant 1 : index
+// CHECK-SCALAR:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
+// CHECK-SCALAR:       %[[a:.*]] = arith.extui %[[p]] : i32 to i64
+// CHECK-SCALAR:       %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
+// CHECK-SCALAR:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
+// CHECK-SCALAR:       %[[b:.*]] = arith.extui %[[r]] : i32 to i64
+// CHECK-SCALAR:       %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
+// CHECK-SCALAR:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] {
+// CHECK-SCALAR:         %[[li:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
+// CHECK-SCALAR:         %[[zi:.*]] = arith.extui %[[li]] : i32 to i64
+// CHECK-SCALAR:         %[[ci:.*]] = arith.index_cast %[[zi]] : i64 to index
+// CHECK-SCALAR:         %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
+// CHECK-SCALAR:         %[[lb:.*]] = memref.load %{{.*}}[%[[ci]]] : memref<1024xf32>
+// CHECK-SCALAR:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
+// CHECK-SCALAR:         store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32>
+// CHECK-SCALAR:       }
+// CHECK-SCALAR:       return
+//
+// CHECK-VEC16:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
+// CHECK-VEC16-LABEL: func @mul_s
+// CHECK-VEC16-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-VEC16-DAG:   %[[c1:.*]] = arith.constant 1 : index
+// CHECK-VEC16-DAG:   %[[c16:.*]] = arith.constant 16 : index
+// CHECK-VEC16:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
+// CHECK-VEC16:       %[[a:.*]] = arith.extui %[[p]] : i32 to i64
+// CHECK-VEC16:       %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
+// CHECK-VEC16:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
+// CHECK-VEC16:       %[[b:.*]] = arith.extui %[[r]] : i32 to i64
+// CHECK-VEC16:       %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
+// CHECK-VEC16:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
+// CHECK-VEC16:         %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]]
+// CHECK-VEC16:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC16:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
+// CHECK-VEC16:         %[[zi:.*]] = arith.extui %[[li]] : vector<16xi32> to vector<16xi64>
+// CHECK-VEC16:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC16:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC16:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC16:         vector.scatter %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC16:       }
+// CHECK-VEC16:       return
+//
+// CHECK-VEC16-IDX32:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
+// CHECK-VEC16-IDX32-LABEL: func @mul_s
+// CHECK-VEC16-IDX32-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-VEC16-IDX32-DAG:   %[[c1:.*]] = arith.constant 1 : index
+// CHECK-VEC16-IDX32-DAG:   %[[c16:.*]] = arith.constant 16 : index
+// CHECK-VEC16-IDX32:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
+// CHECK-VEC16-IDX32:       %[[a:.*]] = arith.extui %[[p]] : i32 to i64
+// CHECK-VEC16-IDX32:       %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
+// CHECK-VEC16-IDX32:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
+// CHECK-VEC16-IDX32:       %[[b:.*]] = arith.extui %[[r]] : i32 to i64
+// CHECK-VEC16-IDX32:       %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
+// CHECK-VEC16-IDX32:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
+// CHECK-VEC16-IDX32:         %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]]
+// CHECK-VEC16-IDX32:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC16-IDX32:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
+// CHECK-VEC16-IDX32:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC16-IDX32:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC16-IDX32:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC16-IDX32:         vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC16-IDX32:       }
+// CHECK-VEC16-IDX32:       return
+//
+// CHECK-VEC4-SVE:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)
+// CHECK-VEC4-SVE-LABEL: func @mul_s
+// CHECK-VEC4-SVE-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-VEC4-SVE-DAG:   %[[c1:.*]] = arith.constant 1 : index
+// CHECK-VEC4-SVE-DAG:   %[[c4:.*]] = arith.constant 4 : index
+// CHECK-VEC4-SVE-DAG:   %[[v0i:.*]] = arith.constant dense<0> : vector<[4]xi32>
+// CHECK-VEC4-SVE-DAG:   %[[v0f:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+// CHECK-VEC4-SVE:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
+// CHECK-VEC4-SVE:       %[[a:.*]] = arith.extui %[[p]] : i32 to i64
+// CHECK-VEC4-SVE:       %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
+// CHECK-VEC4-SVE:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
+// CHECK-VEC4-SVE:       %[[b:.*]] = arith.extui %[[r]] : i32 to i64
+// CHECK-VEC4-SVE:       %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
+// CHECK-VEC4-SVE:       %[[vscale:.*]] = vector.vscale
+// CHECK-VEC4-SVE:       %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
+// CHECK-VEC4-SVE:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[step]] {
+// CHECK-VEC4-SVE:         %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[step]]]
+// CHECK-VEC4-SVE:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
+// CHECK-VEC4-SVE:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0i]] : memref<?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
+// CHECK-VEC4-SVE:         %[[lii64:.*]] = arith.extui %[[li]] : vector<[4]xi32> to vector<[4]xi64>
+// CHECK-VEC4-SVE:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0f]] : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+// CHECK-VEC4-SVE:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[lii64]]], %[[mask]], %[[v0f]] : memref<1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+// CHECK-VEC4-SVE:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<[4]xf32>
+// CHECK-VEC4-SVE:         vector.scatter %{{.*}}[%[[c0]]] [%[[lii64]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32>
+// CHECK-VEC4-SVE:       }
+// CHECK-VEC4-SVE:       return
+//
+func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
+                 %argb: tensor<1024xf32>,
+		 %argx: tensor<1024xf32>) -> tensor<1024xf32> {
   %0 = linalg.generic #trait_mul_s
     ins(%arga, %argb: tensor<1024xf32, #SparseVector>, tensor<1024xf32>)
     outs(%argx: tensor<1024xf32>) {
@@ -101,20 +222,79 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>
 }
 
 //
-// CHECK-LABEL: func @reduction_d
-// CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
-// CHECK-DAG:   %[[c1:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
-// CHECK:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] iter_args(%[[red_in:.*]] = %{{.*}}) -> (f32) {
-// CHECK:         %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
-// CHECK:         %[[lb:.*]] = memref.load %{{.*}}[%[[i]]] : memref<1024xf32>
-// CHECK:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
-// CHECK:         %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : f32
-// CHECK:         scf.yield %[[a]] : f32
-// CHECK:       }
-// CHECK:       return
-//
-func.func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024xf32>, %argx: tensor<f32>) -> tensor<f32> {
+// CHECK-SCALAR-LABEL: func @reduction_d
+// CHECK-SCALAR-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-SCALAR-DAG:   %[[c1:.*]] = arith.constant 1 : index
+// CHECK-SCALAR-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
+// CHECK-SCALAR:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] iter_args(%[[red_in:.*]] = %{{.*}}) -> (f32) {
+// CHECK-SCALAR:         %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
+// CHECK-SCALAR:         %[[lb:.*]] = memref.load %{{.*}}[%[[i]]] : memref<1024xf32>
+// CHECK-SCALAR:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
+// CHECK-SCALAR:         %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : f32
+// CHECK-SCALAR:         scf.yield %[[a]] : f32
+// CHECK-SCALAR:       }
+// CHECK-SCALAR:       return
+//
+// CHECK-VEC16-LABEL: func @reduction_d
+// CHECK-VEC16-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-VEC16-DAG:   %[[c16:.*]] = arith.constant 16 : index
+// CHECK-VEC16-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
+// CHECK-VEC16-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
+// CHECK-VEC16:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
+// CHECK-VEC16:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
+// CHECK-VEC16:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
+// CHECK-VEC16:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
+// CHECK-VEC16:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC16:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC16:         %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : vector<16xf32>
+// CHECK-VEC16:         scf.yield %[[a]] : vector<16xf32>
+// CHECK-VEC16:       }
+// CHECK-VEC16:       %{{.*}} = vector.reduction <add>, %[[red]] : vector<16xf32> into f32
+// CHECK-VEC16:       return
+//
+// CHECK-VEC16-IDX32-LABEL: func @reduction_d
+// CHECK-VEC16-IDX32-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-VEC16-IDX32-DAG:   %[[c16:.*]] = arith.constant 16 : index
+// CHECK-VEC16-IDX32-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
+// CHECK-VEC16-IDX32-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
+// CHECK-VEC16-IDX32:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
+// CHECK-VEC16-IDX32:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
+// CHECK-VEC16-IDX32:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
+// CHECK-VEC16-IDX32:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
+// CHECK-VEC16-IDX32:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC16-IDX32:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC16-IDX32:         %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : vector<16xf32>
+// CHECK-VEC16-IDX32:         scf.yield %[[a]] : vector<16xf32>
+// CHECK-VEC16-IDX32:       }
+// CHECK-VEC16-IDX32:       %{{.*}} = vector.reduction <add>, %[[red]] : vector<16xf32> into f32
+// CHECK-VEC16-IDX32:       return
+//
+// CHECK-VEC4-SVE:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)
+// CHECK-VEC4-SVE-LABEL: func @reduction_d
+// CHECK-VEC4-SVE-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-VEC4-SVE-DAG:   %[[c4:.*]] = arith.constant 4 : index
+// CHECK-VEC4-SVE-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
+// CHECK-VEC4-SVE-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+// CHECK-VEC4-SVE:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
+// CHECK-VEC4-SVE:       %[[vscale:.*]] = vector.vscale
+// CHECK-VEC4-SVE:       %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
+// CHECK-VEC4-SVE:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<[4]xf32>
+// CHECK-VEC4-SVE:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<[4]xf32>) {
+// CHECK-VEC4-SVE:         %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]]
+// CHECK-VEC4-SVE:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
+// CHECK-VEC4-SVE:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+// CHECK-VEC4-SVE:         %[[lb:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<1024xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+// CHECK-VEC4-SVE:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<[4]xf32>
+// CHECK-VEC4-SVE:         %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : vector<[4]xf32>
+// CHECK-VEC4-SVE:         %[[sa:.*]] = arith.select %[[mask]], %[[a]], %[[red_in]] : vector<[4]xi1>, vector<[4]xf32>
+// CHECK-VEC4-SVE:         scf.yield %[[sa]] : vector<[4]xf32>
+// CHECK-VEC4-SVE:       }
+// CHECK-VEC4-SVE:       %{{.*}} = vector.reduction <add>, %[[red]] : vector<[4]xf32> into f32
+// CHECK-VEC4-SVE:       return
+//
+func.func @reduction_d(%arga: tensor<1024xf32, #DenseVector>,
+                       %argb: tensor<1024xf32>,
+		       %argx: tensor<f32>) -> tensor<f32> {
   %0 = linalg.generic #trait_reduction_d
     ins(%arga, %argb: tensor<1024xf32, #DenseVector>, tensor<1024xf32>)
     outs(%argx: tensor<f32>) {
@@ -145,31 +325,117 @@ func.func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024
 }
 
 //
-// CHECK-LABEL: func @mul_ds
-// CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
-// CHECK-DAG:   %[[c1:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[c512:.*]] = arith.constant 512 : index
-// CHECK:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
-// CHECK:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
-// CHECK:         %[[a:.*]] = arith.extui %[[p]] : i32 to i64
-// CHECK:         %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
-// CHECK:         %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index
-// CHECK:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
-// CHECK:         %[[b:.*]] = arith.extui %[[r]] : i32 to i64
-// CHECK:         %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
-// CHECK:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] {
-// CHECK:           %[[lj:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xi32>
-// CHECK:           %[[zj:.*]] = arith.extui %[[lj]] : i32 to i64
-// CHECK:           %[[cj:.*]] = arith.index_cast %[[zj]] : i64 to index
-// CHECK:           %[[la:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xf32>
-// CHECK:           %[[lb:.*]] = memref.load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
-// CHECK:           %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
-// CHECK:           store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
-// CHECK:         }
-// CHECK:       }
-// CHECK:       return
-//
-func.func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>, %argb: tensor<512x1024xf32>, %argx: tensor<512x1024xf32>) -> tensor<512x1024xf32> {
+// CHECK-SCALAR-LABEL: func @mul_ds
+// CHECK-SCALAR-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-SCALAR-DAG:   %[[c1:.*]] = arith.constant 1 : index
+// CHECK-SCALAR-DAG:   %[[c512:.*]] = arith.constant 512 : index
+// CHECK-SCALAR:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
+// CHECK-SCALAR:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
+// CHECK-SCALAR:         %[[a:.*]] = arith.extui %[[p]] : i32 to i64
+// CHECK-SCALAR:         %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
+// CHECK-SCALAR:         %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index
+// CHECK-SCALAR:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
+// CHECK-SCALAR:         %[[b:.*]] = arith.extui %[[r]] : i32 to i64
+// CHECK-SCALAR:         %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
+// CHECK-SCALAR:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] {
+// CHECK-SCALAR:           %[[lj:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xi32>
+// CHECK-SCALAR:           %[[zj:.*]] = arith.extui %[[lj]] : i32 to i64
+// CHECK-SCALAR:           %[[cj:.*]] = arith.index_cast %[[zj]] : i64 to index
+// CHECK-SCALAR:           %[[la:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xf32>
+// CHECK-SCALAR:           %[[lb:.*]] = memref.load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
+// CHECK-SCALAR:           %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32
+// CHECK-SCALAR:           store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
+// CHECK-SCALAR:         }
+// CHECK-SCALAR:       }
+// CHECK-SCALAR:       return
+//
+// CHECK-VEC16:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
+// CHECK-VEC16-LABEL: func @mul_ds
+// CHECK-VEC16-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-VEC16-DAG:   %[[c1:.*]] = arith.constant 1 : index
+// CHECK-VEC16-DAG:   %[[c16:.*]] = arith.constant 16 : index
+// CHECK-VEC16-DAG:   %[[c512:.*]] = arith.constant 512 : index
+// CHECK-VEC16:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
+// CHECK-VEC16:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
+// CHECK-VEC16:         %[[a:.*]] = arith.extui %[[p]] : i32 to i64
+// CHECK-VEC16:         %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
+// CHECK-VEC16:         %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index
+// CHECK-VEC16:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
+// CHECK-VEC16:         %[[b:.*]] = arith.extui %[[r]] : i32 to i64
+// CHECK-VEC16:         %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
+// CHECK-VEC16:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
+// CHECK-VEC16:           %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[c16]]]
+// CHECK-VEC16:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC16:           %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
+// CHECK-VEC16:           %[[zj:.*]] = arith.extui %[[lj]] : vector<16xi32> to vector<16xi64>
+// CHECK-VEC16:           %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC16:           %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[zj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC16:           %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC16:           vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[zj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC16:         }
+// CHECK-VEC16:       }
+// CHECK-VEC16:       return
+//
+// CHECK-VEC16-IDX32:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
+// CHECK-VEC16-IDX32-LABEL: func @mul_ds
+// CHECK-VEC16-IDX32-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-VEC16-IDX32-DAG:   %[[c1:.*]] = arith.constant 1 : index
+// CHECK-VEC16-IDX32-DAG:   %[[c16:.*]] = arith.constant 16 : index
+// CHECK-VEC16-IDX32-DAG:   %[[c512:.*]] = arith.constant 512 : index
+// CHECK-VEC16-IDX32:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
+// CHECK-VEC16-IDX32:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
+// CHECK-VEC16-IDX32:         %[[a:.*]] = arith.extui %[[p]] : i32 to i64
+// CHECK-VEC16-IDX32:         %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
+// CHECK-VEC16-IDX32:         %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index
+// CHECK-VEC16-IDX32:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
+// CHECK-VEC16-IDX32:         %[[b:.*]] = arith.extui %[[r]] : i32 to i64
+// CHECK-VEC16-IDX32:         %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
+// CHECK-VEC16-IDX32:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
+// CHECK-VEC16-IDX32:           %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[c16]]]
+// CHECK-VEC16-IDX32:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC16-IDX32:           %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
+// CHECK-VEC16-IDX32:           %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC16-IDX32:           %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC16-IDX32:           %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC16-IDX32:           vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC16-IDX32:         }
+// CHECK-VEC16-IDX32:       }
+// CHECK-VEC16-IDX32:       return
+//
+// CHECK-VEC4-SVE:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)
+// CHECK-VEC4-SVE-LABEL: func @mul_ds
+// CHECK-VEC4-SVE-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-VEC4-SVE-DAG:   %[[c1:.*]] = arith.constant 1 : index
+// CHECK-VEC4-SVE-DAG:   %[[c4:.*]] = arith.constant 4 : index
+// CHECK-VEC4-SVE-DAG:   %[[c512:.*]] = arith.constant 512 : index
+// CHECK-VEC4-SVE-DAG:   %[[v0i:.*]] = arith.constant dense<0> : vector<[4]xi32>
+// CHECK-VEC4-SVE-DAG:   %[[v0f:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+// CHECK-VEC4-SVE:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
+// CHECK-VEC4-SVE:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
+// CHECK-VEC4-SVE:         %[[a:.*]] = arith.extui %[[p]] : i32 to i64
+// CHECK-VEC4-SVE:         %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
+// CHECK-VEC4-SVE:         %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index
+// CHECK-VEC4-SVE:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
+// CHECK-VEC4-SVE:         %[[b:.*]] = arith.extui %[[r]] : i32 to i64
+// CHECK-VEC4-SVE:         %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
+// CHECK-VEC4-SVE:         %[[vscale:.*]] = vector.vscale
+// CHECK-VEC4-SVE:         %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
+// CHECK-VEC4-SVE:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[step]] {
+// CHECK-VEC4-SVE:           %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[step]]]
+// CHECK-VEC4-SVE:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
+// CHECK-VEC4-SVE:           %[[lji32:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %[[v0i]] : memref<?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
+// CHECK-VEC4-SVE:           %[[lj:.*]] = arith.extui %[[lji32]] : vector<[4]xi32> to vector<[4]xi64>
+// CHECK-VEC4-SVE:           %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %[[v0f]] : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+// CHECK-VEC4-SVE:           %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[v0f]] : memref<512x1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+// CHECK-VEC4-SVE:           %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<[4]xf32>
+// CHECK-VEC4-SVE:           vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32>
+// CHECK-VEC4-SVE:         }
+// CHECK-VEC4-SVE:       }
+// CHECK-VEC4-SVE:       return
+//
+func.func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>,
+                  %argb: tensor<512x1024xf32>,
+		  %argx: tensor<512x1024xf32>) -> tensor<512x1024xf32> {
   %0 = linalg.generic #trait_mul_ds
     ins(%arga, %argb: tensor<512x1024xf32, #SparseMatrix>, tensor<512x1024xf32>)
     outs(%argx: tensor<512x1024xf32>) {
@@ -194,26 +460,96 @@ func.func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>, %argb: tensor<512x
 }
 
 //
-// CHECK-LABEL: func @add_dense
-// CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
-// CHECK-DAG:   %[[c1:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[c32:.*]] = arith.constant 32 : index
-// CHECK:       scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
-// CHECK:         %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
-// CHECK:         %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index
-// CHECK:         %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
-// CHECK:         scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c1]] {
-// CHECK:           %[[j:.*]] = memref.load %{{.*}}[%[[jj]]] : memref<?xindex>
-// CHECK:           %[[x:.*]] = memref.load %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64>
-// CHECK:           %[[a:.*]] = memref.load %{{.*}}[%[[jj]]] : memref<?xf64>
-// CHECK:           %[[s:.*]] = arith.addf %[[x]], %[[a]] : f64
-// CHECK:           memref.store %[[s]], %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64>
-// CHECK:         }
-// CHECK:       }
-// CHECK:       return
+// CHECK-SCALAR-LABEL: func @add_dense
+// CHECK-SCALAR-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-SCALAR-DAG:   %[[c1:.*]] = arith.constant 1 : index
+// CHECK-SCALAR-DAG:   %[[c32:.*]] = arith.constant 32 : index
+// CHECK-SCALAR:       scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
+// CHECK-SCALAR:         %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
+// CHECK-SCALAR:         %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index
+// CHECK-SCALAR:         %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
+// CHECK-SCALAR:         scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c1]] {
+// CHECK-SCALAR:           %[[j:.*]] = memref.load %{{.*}}[%[[jj]]] : memref<?xindex>
+// CHECK-SCALAR:           %[[x:.*]] = memref.load %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64>
+// CHECK-SCALAR:           %[[a:.*]] = memref.load %{{.*}}[%[[jj]]] : memref<?xf64>
+// CHECK-SCALAR:           %[[s:.*]] = arith.addf %[[x]], %[[a]] : f64
+// CHECK-SCALAR:           memref.store %[[s]], %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64>
+// CHECK-SCALAR:         }
+// CHECK-SCALAR:       }
+// CHECK-SCALAR:       return
+//
+// CHECK-VEC16:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
+// CHECK-VEC16-LABEL: func @add_dense
+// CHECK-VEC16-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-VEC16-DAG:   %[[c1:.*]] = arith.constant 1 : index
+// CHECK-VEC16-DAG:   %[[c16:.*]] = arith.constant 16 : index
+// CHECK-VEC16-DAG:   %[[c32:.*]] = arith.constant 32 : index
+// CHECK-VEC16:       scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
+// CHECK-VEC16:         %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
+// CHECK-VEC16:         %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index
+// CHECK-VEC16:         %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
+// CHECK-VEC16:         scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c16]] {
+// CHECK-VEC16:           %[[sub:.*]] = affine.min #[[$map]](%[[hi]], %[[jj]])[%[[c16]]]
+// CHECK-VEC16:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC16:           %[[j:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %{{.*}} : memref<?xindex>
+// CHECK-VEC16:           %[[x:.*]] = vector.gather %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %{{.*}} : memref<33x64xf64>
+// CHECK-VEC16:           %[[a:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %{{.*}} : memref<?xf64>
+// CHECK-VEC16:           %[[s:.*]] = arith.addf %[[x]], %[[a]] : vector<16xf64>
+// CHECK-VEC16:           vector.scatter %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %[[s]] : memref<33x64xf64>
+// CHECK-VEC16:         }
+// CHECK-VEC16:       }
+// CHECK-VEC16:       return
+//
+// CHECK-VEC16-IDX32:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
+// CHECK-VEC16-IDX32-LABEL: func @add_dense
+// CHECK-VEC16-IDX32-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-VEC16-IDX32-DAG:   %[[c1:.*]] = arith.constant 1 : index
+// CHECK-VEC16-IDX32-DAG:   %[[c16:.*]] = arith.constant 16 : index
+// CHECK-VEC16-IDX32-DAG:   %[[c32:.*]] = arith.constant 32 : index
+// CHECK-VEC16-IDX32:       scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
+// CHECK-VEC16-IDX32:         %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
+// CHECK-VEC16-IDX32:         %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index
+// CHECK-VEC16-IDX32:         %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
+// CHECK-VEC16-IDX32:         scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c16]] {
+// CHECK-VEC16-IDX32:           %[[sub:.*]] = affine.min #[[$map]](%[[hi]], %[[jj]])[%[[c16]]]
+// CHECK-VEC16-IDX32:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC16-IDX32:           %[[j:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %{{.*}} : memref<?xindex>
+// CHECK-VEC16-IDX32:           %[[x:.*]] = vector.gather %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %{{.*}} : memref<33x64xf64>
+// CHECK-VEC16-IDX32:           %[[a:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %{{.*}} : memref<?xf64>
+// CHECK-VEC16-IDX32:           %[[s:.*]] = arith.addf %[[x]], %[[a]] : vector<16xf64>
+// CHECK-VEC16-IDX32:           vector.scatter %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %[[s]] : memref<33x64xf64>
+// CHECK-VEC16-IDX32:         }
+// CHECK-VEC16-IDX32:       }
+// CHECK-VEC16-IDX32:       return
+//
+// CHECK-VEC4-SVE:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)
+// CHECK-VEC4-SVE-LABEL: func @add_dense
+// CHECK-VEC4-SVE-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-VEC4-SVE-DAG:   %[[c1:.*]] = arith.constant 1 : index
+// CHECK-VEC4-SVE-DAG:   %[[c4:.*]] = arith.constant 4 : index
+// CHECK-VEC4-SVE-DAG:   %[[c32:.*]] = arith.constant 32 : index
+// CHECK-VEC4-SVE-DAG:   %[[v0idx:.*]] = arith.constant dense<0> : vector<[4]xindex>
+// CHECK-VEC4-SVE-DAG:   %[[v0f64:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf64>
+// CHECK-VEC4-SVE:       scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
+// CHECK-VEC4-SVE:         %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
+// CHECK-VEC4-SVE:         %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index
+// CHECK-VEC4-SVE:         %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
+// CHECK-VEC4-SVE:         %[[vscale:.*]] = vector.vscale
+// CHECK-VEC4-SVE:         %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
+// CHECK-VEC4-SVE:         scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[step]] {
+// CHECK-VEC4-SVE:           %[[sub:.*]] = affine.min #[[$map]](%[[hi]], %[[jj]])[%[[step]]]
+// CHECK-VEC4-SVE:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
+// CHECK-VEC4-SVE:           %[[j:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %[[v0idx]] : memref<?xindex>
+// CHECK-VEC4-SVE:           %[[x:.*]] = vector.gather %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %[[v0f64]] : memref<33x64xf64>
+// CHECK-VEC4-SVE:           %[[a:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %[[v0f64]] : memref<?xf64>
+// CHECK-VEC4-SVE:           %[[s:.*]] = arith.addf %[[x]], %[[a]] : vector<[4]xf64>
+// CHECK-VEC4-SVE:           vector.scatter %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %[[s]] : memref<33x64xf64>
+// CHECK-VEC4-SVE:         }
+// CHECK-VEC4-SVE:       }
+// CHECK-VEC4-SVE:       return
 //
 func.func @add_dense(%arga: tensor<32x64xf64, #SparseMatrix>,
-                %argx: tensor<33x64xf64>) -> tensor<33x64xf64> {
+                     %argx: tensor<33x64xf64>) -> tensor<33x64xf64> {
   %0 = linalg.generic #trait_affine
      ins(%arga: tensor<32x64xf64, #SparseMatrix>)
     outs(%argx: tensor<33x64xf64>) {

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 9b67e59eee5d4..03124a6e23377 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2224,6 +2224,7 @@ cc_library(
         ":LinalgDialect",
         ":LinalgTransforms",
         ":LinalgUtils",
+        ":MathDialect",
         ":MemRefDialect",
         ":Pass",
         ":SCFDialect",
@@ -2235,6 +2236,7 @@ cc_library(
         ":Support",
         ":TensorDialect",
         ":Transforms",
+        ":VectorDialect",
         "//llvm:Support",
     ],
 )


        


More information about the Mlir-commits mailing list