[Mlir-commits] [mlir] 7783a17 - [mlir][Sparse] Add option for VLA sparsification

Javier Setoain llvmlistbot at llvm.org
Fri Mar 25 03:55:53 PDT 2022


Author: Javier Setoain
Date: 2022-03-25T10:54:49Z
New Revision: 7783a178f5752b24267167c8abc5db38d839b839

URL: https://github.com/llvm/llvm-project/commit/7783a178f5752b24267167c8abc5db38d839b839
DIFF: https://github.com/llvm/llvm-project/commit/7783a178f5752b24267167c8abc5db38d839b839.diff

LOG: [mlir][Sparse] Add option for VLA sparsification

Use "enable-vla-vectorization=vla" to generate a vector length agnostic
loops during vectorization. This option works for vectorization strategy 2.

Differential Revision: https://reviews.llvm.org/D118379
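
The option is exercised by the new RUN line in sparse_vector.mlir; a
usage sketch mirroring that invocation:

  mlir-opt input.mlir -sparsification="vectorization-strategy=2 vl=4 enable-vla-vectorization=true"

With the flag set, vectorized loops step by vscale * vl instead of a
fixed vl, e.g. (SSA names illustrative, per the CHECK-VEC4 patterns in
the test diff below):

  %vscale = vector.vscale
  %step = arith.muli %vscale, %c4 : index
  scf.for %i = %c0 to %c1024 step %step {
    // masked operations on vector<[4]xf32> values
  }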

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
    mlir/test/Dialect/SparseTensor/sparse_vector.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
index 782f3b443d428..a7064a2508312 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -41,12 +41,16 @@ struct SparseCompilerOptions
   PassOptions::Option<bool> enableSIMDIndex32{
       *this, "enable-simd-index32",
       desc("Enable i32 indexing into vectors (for efficiency)"), init(false)};
+  PassOptions::Option<bool> enableVLAVectorization{
+      *this, "enable-vla-vectorization",
+      desc("Enable vector length agnostic vectorization"), init(false)};
 
   /// Projects out the options for `createSparsificationPass`.
   SparsificationOptions sparsificationOptions() const {
     return SparsificationOptions(sparseParallelizationStrategy(parallelization),
                                  sparseVectorizationStrategy(vectorization),
-                                 vectorLength, enableSIMDIndex32);
+                                 vectorLength, enableSIMDIndex32,
+                                 enableVLAVectorization);
   }
 
   // These options must be kept in sync with `SparseTensorConversionBase`.

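Since the pipeline options project onto the pass options above, the flag
should also be settable on the sparse compiler pipeline itself; a hedged
sketch, assuming the pipeline registered by SparseTensor/Pipelines is
exposed as --sparse-compiler (the flag name is taken from the option
declaration above):

  mlir-opt input.mlir --sparse-compiler="enable-vla-vectorization=true"
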
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 1888b45fc8442..20a322e97dfc0 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -64,16 +64,19 @@ SparseVectorizationStrategy sparseVectorizationStrategy(int32_t flag);
 /// Options for the Sparsification pass.
 struct SparsificationOptions {
   SparsificationOptions(SparseParallelizationStrategy p,
-                        SparseVectorizationStrategy v, unsigned vl, bool e)
+                        SparseVectorizationStrategy v, unsigned vl, bool e,
+                        bool vla)
       : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl),
-        enableSIMDIndex32(e) {}
+        enableSIMDIndex32(e), enableVLAVectorization(vla) {}
   SparsificationOptions()
       : SparsificationOptions(SparseParallelizationStrategy::kNone,
-                              SparseVectorizationStrategy::kNone, 1u, false) {}
+                              SparseVectorizationStrategy::kNone, 1u, false,
+                              false) {}
   SparseParallelizationStrategy parallelizationStrategy;
   SparseVectorizationStrategy vectorizationStrategy;
   unsigned vectorLength;
   bool enableSIMDIndex32;
+  bool enableVLAVectorization;
 };
 
 /// Sets up sparsification rewriting rules with the given options.

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 89aacd69b67a0..db5f881e684d7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -70,7 +70,9 @@ def Sparsification : Pass<"sparsification", "ModuleOp"> {
     Option<"vectorLength", "vl", "int32_t", "1",
            "Set the vector length">,
     Option<"enableSIMDIndex32", "enable-simd-index32", "bool", "false",
-           "Enable i32 indexing into vectors (for efficiency)">
+           "Enable i32 indexing into vectors (for efficiency)">,
+    Option<"enableVLAVectorization", "enable-vla-vectorization", "bool",
+           "false", "Enable vector length agnostic vectorization">
   ];
 }
 

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 2124aecc128da..a43932b1cecf8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -40,6 +40,7 @@ struct SparsificationPass : public SparsificationBase<SparsificationPass> {
     vectorization = static_cast<int32_t>(options.vectorizationStrategy);
     vectorLength = options.vectorLength;
     enableSIMDIndex32 = options.enableSIMDIndex32;
+    enableVLAVectorization = options.enableVLAVectorization;
   }
 
   void runOnOperation() override {
@@ -49,7 +50,7 @@ struct SparsificationPass : public SparsificationBase<SparsificationPass> {
     SparsificationOptions options(
         sparseParallelizationStrategy(parallelization),
         sparseVectorizationStrategy(vectorization), vectorLength,
-        enableSIMDIndex32);
+        enableSIMDIndex32, enableVLAVectorization);
     // Apply rewriting.
     populateSparsificationPatterns(patterns, options);
     vector::populateVectorToVectorCanonicalizationPatterns(patterns);

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 0ebd4e4b5cd2e..398c4faa07d90 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -552,7 +552,8 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
 
 /// Constructs vector type.
 static VectorType vectorType(CodeGen &codegen, Type etp) {
-  return VectorType::get(codegen.curVecLength, etp);
+  unsigned numScalableDims = codegen.options.enableVLAVectorization;
+  return VectorType::get(codegen.curVecLength, etp, numScalableDims);
 }
 
 /// Constructs vector type from pointer.
@@ -1164,6 +1165,11 @@ static Operation *genFor(Merger &merger, CodeGen &codegen,
   Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
   Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
   Value step = constantIndex(rewriter, loc, codegen.curVecLength);
+  if (isVector && codegen.options.enableVLAVectorization) {
+    Value vscale = rewriter.create<vector::VectorScaleOp>(
+        loc, IndexType::get(rewriter.getContext()));
+    step = rewriter.create<arith::MulIOp>(loc, vscale, step);
+  }
 
   // Emit a parallel loop.
   if (isParallel) {

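The genFor change above multiplies the fixed step (curVecLength) by
vector.vscale, so the loop step scales with the hardware vector length.
A minimal sketch of the emitted loop header for vl=4, consistent with
the CHECK-VEC4 patterns in the test diff below (SSA names illustrative):

  %c4 = arith.constant 4 : index
  %vscale = vector.vscale
  %step = arith.muli %vscale, %c4 : index
  scf.for %i = %lo to %hi step %step {
    %sub = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>(%hi, %i)[%step]
    %mask = vector.create_mask %sub : vector<[4]xi1>
    // masked loads/stores guarded by %mask
  }
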
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 005278ca70d22..4b05fdb8770cd 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -563,7 +563,7 @@ Type Merger::inferType(unsigned e, Value src) {
   // Inspect source type. For vector types, apply the same
   // vectorization to the destination type.
   if (auto vtp = src.getType().dyn_cast<VectorType>())
-    return VectorType::get(vtp.getNumElements(), dtp);
+    return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
   return dtp;
 }
 
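With getNumScalableDims() forwarded here, a type inferred from a
scalable source keeps the scalable shape; a one-line sketch matching the
CHECK-VEC4 pattern in the test diff below:

  %widened = arith.extui %idxs : vector<[4]xi32> to vector<[4]xi64>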

diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
index 82a4bc4d5d91c..7b2d3a4213494 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -6,6 +6,8 @@
 // RUN:   FileCheck %s --check-prefix=CHECK-VEC2
 // RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16 enable-simd-index32=true" -cse -split-input-file | \
 // RUN:   FileCheck %s --check-prefix=CHECK-VEC3
+// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=4 enable-vla-vectorization=true" -cse -split-input-file | \
+// RUN:   FileCheck %s --check-prefix=CHECK-VEC4
 
 #DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>
 
@@ -54,6 +56,24 @@
 // CHECK-VEC2:       }
 // CHECK-VEC2:       return
 //
+// CHECK-VEC4:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)
+// CHECK-VEC4-LABEL: func @scale_d
+// CHECK-VEC4-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-VEC4-DAG:   %[[c4:.*]] = arith.constant 4 : index
+// CHECK-VEC4-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
+// CHECK-VEC4-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+// CHECK-VEC4-DAG:   %[[vscale:.*]] = vector.vscale
+// CHECK-VEC4:       %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
+// CHECK-VEC4:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] {
+// CHECK-VEC4:         %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]]
+// CHECK-VEC4:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
+// CHECK-VEC4:         %[[val:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+// CHECK-VEC4:         %[[scalev:.*]] = vector.broadcast %{{.*}} : f32 to vector<[4]xf32>
+// CHECK-VEC4:         %[[scaled:.*]] = arith.mulf %[[val]], %[[scalev]] : vector<[4]xf32>
+// CHECK-VEC4:         vector.maskedstore %{{.*}}[%[[i]]], %[[mask]], %[[scaled]] : memref<1024xf32>, vector<[4]xi1>, vector<[4]xf32>
+// CHECK-VEC4:       }
+// CHECK-VEC4:       return
+//
 func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
   %0 = linalg.generic #trait_scale_d
     ins(%arga: tensor<1024xf32, #DenseVector>)
@@ -169,6 +189,33 @@ func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor<1024
 // CHECK-VEC3:       }
 // CHECK-VEC3:       return
 //
+// CHECK-VEC4:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)
+// CHECK-VEC4-LABEL: func @mul_s
+// CHECK-VEC4-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-VEC4-DAG:   %[[c1:.*]] = arith.constant 1 : index
+// CHECK-VEC4-DAG:   %[[c4:.*]] = arith.constant 4 : index
+// CHECK-VEC4-DAG:   %[[v0i:.*]] = arith.constant dense<0> : vector<[4]xi32>
+// CHECK-VEC4-DAG:   %[[v0f:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+// CHECK-VEC4:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
+// CHECK-VEC4:       %[[a:.*]] = arith.extui %[[p]] : i32 to i64
+// CHECK-VEC4:       %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
+// CHECK-VEC4:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
+// CHECK-VEC4:       %[[b:.*]] = arith.extui %[[r]] : i32 to i64
+// CHECK-VEC4:       %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
+// CHECK-VEC4:       %[[vscale:.*]] = vector.vscale
+// CHECK-VEC4:       %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
+// CHECK-VEC4:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[step]] {
+// CHECK-VEC4:         %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[step]]]
+// CHECK-VEC4:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
+// CHECK-VEC4:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0i]] : memref<?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
+// CHECK-VEC4:         %[[lii64:.*]] = arith.extui %[[li]] : vector<[4]xi32> to vector<[4]xi64>
+// CHECK-VEC4:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0f]] : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+// CHECK-VEC4:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[lii64]]], %[[mask]], %[[v0f]] : memref<1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+// CHECK-VEC4:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<[4]xf32>
+// CHECK-VEC4:         vector.scatter %{{.*}}[%[[c0]]] [%[[lii64]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32>
+// CHECK-VEC4:       }
+// CHECK-VEC4:       return
+//
 func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
   %0 = linalg.generic #trait_mul_s
     ins(%arga, %argb: tensor<1024xf32, #SparseVector>, tensor<1024xf32>)
@@ -242,6 +289,29 @@ func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %ar
 // CHECK-VEC2:       %{{.*}} = vector.reduction <add>, %[[red]] : vector<16xf32> into f32
 // CHECK-VEC2:       return
 //
+// CHECK-VEC4:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)
+// CHECK-VEC4-LABEL: func @reduction_d
+// CHECK-VEC4-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-VEC4-DAG:   %[[c4:.*]] = arith.constant 4 : index
+// CHECK-VEC4-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
+// CHECK-VEC4-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+// CHECK-VEC4:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
+// CHECK-VEC4:       %[[vscale:.*]] = vector.vscale
+// CHECK-VEC4:       %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
+// CHECK-VEC4:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<[4]xf32>
+// CHECK-VEC4:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<[4]xf32>) {
+// CHECK-VEC4:         %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]]
+// CHECK-VEC4:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
+// CHECK-VEC4:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+// CHECK-VEC4:         %[[lb:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<1024xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+// CHECK-VEC4:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<[4]xf32>
+// CHECK-VEC4:         %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : vector<[4]xf32>
+// CHECK-VEC4:         %[[sa:.*]] = arith.select %[[mask]], %[[a]], %[[red_in]] : vector<[4]xi1>, vector<[4]xf32>
+// CHECK-VEC4:         scf.yield %[[sa]] : vector<[4]xf32>
+// CHECK-VEC4:       }
+// CHECK-VEC4:       %{{.*}} = vector.reduction <add>, %[[red]] : vector<[4]xf32> into f32
+// CHECK-VEC4:       return
+//
 func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024xf32>, %argx: tensor<f32>) -> tensor<f32> {
   %0 = linalg.generic #trait_reduction_d
     ins(%arga, %argb: tensor<1024xf32, #DenseVector>, tensor<1024xf32>)
@@ -374,6 +444,37 @@ func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024xf32>
 // CHECK-VEC3:       }
 // CHECK-VEC3:       return
 //
+// CHECK-VEC4:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)
+// CHECK-VEC4-LABEL: func @mul_ds
+// CHECK-VEC4-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-VEC4-DAG:   %[[c1:.*]] = arith.constant 1 : index
+// CHECK-VEC4-DAG:   %[[c4:.*]] = arith.constant 4 : index
+// CHECK-VEC4-DAG:   %[[c512:.*]] = arith.constant 512 : index
+// CHECK-VEC4-DAG:   %[[v0i:.*]] = arith.constant dense<0> : vector<[4]xi32>
+// CHECK-VEC4-DAG:   %[[v0f:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+// CHECK-VEC4:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
+// CHECK-VEC4:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
+// CHECK-VEC4:         %[[a:.*]] = arith.extui %[[p]] : i32 to i64
+// CHECK-VEC4:         %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
+// CHECK-VEC4:         %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index
+// CHECK-VEC4:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
+// CHECK-VEC4:         %[[b:.*]] = arith.extui %[[r]] : i32 to i64
+// CHECK-VEC4:         %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
+// CHECK-VEC4:         %[[vscale:.*]] = vector.vscale
+// CHECK-VEC4:         %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
+// CHECK-VEC4:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[step]] {
+// CHECK-VEC4:           %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[step]]]
+// CHECK-VEC4:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
+// CHECK-VEC4:           %[[lji32:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %[[v0i]] : memref<?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
+// CHECK-VEC4:           %[[lj:.*]] = arith.extui %[[lji32]] : vector<[4]xi32> to vector<[4]xi64>
+// CHECK-VEC4:           %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %[[v0f]] : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+// CHECK-VEC4:           %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[v0f]] : memref<512x1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+// CHECK-VEC4:           %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<[4]xf32>
+// CHECK-VEC4:           vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32>
+// CHECK-VEC4:         }
+// CHECK-VEC4:       }
+// CHECK-VEC4:       return
+//
 func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>, %argb: tensor<512x1024xf32>, %argx: tensor<512x1024xf32>) -> tensor<512x1024xf32> {
   %0 = linalg.generic #trait_mul_ds
     ins(%arga, %argb: tensor<512x1024xf32, #SparseMatrix>, tensor<512x1024xf32>)
@@ -457,6 +558,32 @@ func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>, %argb: tensor<512x1024x
 // CHECK-VEC2:       }
 // CHECK-VEC2:       return
 //
+// CHECK-VEC4:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)
+// CHECK-VEC4-LABEL: func @add_dense
+// CHECK-VEC4-DAG:   %[[c0:.*]] = arith.constant 0 : index
+// CHECK-VEC4-DAG:   %[[c1:.*]] = arith.constant 1 : index
+// CHECK-VEC4-DAG:   %[[c4:.*]] = arith.constant 4 : index
+// CHECK-VEC4-DAG:   %[[c32:.*]] = arith.constant 32 : index
+// CHECK-VEC4-DAG:   %[[v0idx:.*]] = arith.constant dense<0> : vector<[4]xindex>
+// CHECK-VEC4-DAG:   %[[v0f64:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf64>
+// CHECK-VEC4:       scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
+// CHECK-VEC4:         %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
+// CHECK-VEC4:         %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index
+// CHECK-VEC4:         %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
+// CHECK-VEC4:         %[[vscale:.*]] = vector.vscale
+// CHECK-VEC4:         %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
+// CHECK-VEC4:         scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[step]] {
+// CHECK-VEC4:           %[[sub:.*]] = affine.min #[[$map]](%[[hi]], %[[jj]])[%[[step]]]
+// CHECK-VEC4:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
+// CHECK-VEC4:           %[[j:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %[[v0idx]] : memref<?xindex>
+// CHECK-VEC4:           %[[x:.*]] = vector.gather %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %[[v0f64]] : memref<33x64xf64>
+// CHECK-VEC4:           %[[a:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %[[v0f64]] : memref<?xf64>
+// CHECK-VEC4:           %[[s:.*]] = arith.addf %[[x]], %[[a]] : vector<[4]xf64>
+// CHECK-VEC4:           vector.scatter %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %[[s]] : memref<33x64xf64>
+// CHECK-VEC4:         }
+// CHECK-VEC4:       }
+// CHECK-VEC4:       return
+//
 func @add_dense(%arga: tensor<32x64xf64, #SparseMatrix>,
                 %argx: tensor<33x64xf64> {linalg.inplaceable = true}) -> tensor<33x64xf64> {
   %0 = linalg.generic #trait_affine

