[Mlir-commits] [mlir] 76a1861 - [mlir][SparseTensor] Split scf.for loop into masked/unmasked parts

Matthias Springer llvmlistbot at llvm.org
Thu Aug 19 05:53:26 PDT 2021


Author: Matthias Springer
Date: 2021-08-19T21:53:11+09:00
New Revision: 76a186181634feaeaa2d0493aac2b796d2a3ef25

URL: https://github.com/llvm/llvm-project/commit/76a186181634feaeaa2d0493aac2b796d2a3ef25
DIFF: https://github.com/llvm/llvm-project/commit/76a186181634feaeaa2d0493aac2b796d2a3ef25.diff

LOG: [mlir][SparseTensor] Split scf.for loop into masked/unmasked parts

Apply the "for loop peeling" pattern from SCF dialect transforms. This pattern splits an scf.for loop into a main loop that executes only full iterations and a separate partial (last) iteration. In the full iterations, all masked loads/stores are canonicalized to unmasked loads/stores; only the partial iteration keeps the mask.

Differential Revision: https://reviews.llvm.org/D107733

Added: 
    mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/test/Dialect/SparseTensor/sparse_vector.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 3bbf6c82b5cd..01f8615de172 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -53,6 +53,7 @@ def Sparsification : Pass<"sparsification", "ModuleOp"> {
   }];
   let constructor = "mlir::createSparsificationPass()";
   let dependentDialects = [
+    "AffineDialect",
     "LLVM::LLVMDialect",
     "memref::MemRefDialect",
     "scf::SCFDialect",

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 24784555d143..f30318649dc0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -10,10 +10,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
@@ -348,7 +350,13 @@ static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
   // during vector execution. Here we rely on subsequent loop optimizations to
   // avoid executing the mask in all iterations, for example, by splitting the
   // loop into an unconditional vector loop and a scalar cleanup loop.
-  Value end = rewriter.create<SubIOp>(loc, hi, iv);
+  auto minMap = AffineMap::get(
+      /*dimCount=*/2, /*symbolCount=*/1,
+      {rewriter.getAffineSymbolExpr(0),
+       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
+      rewriter.getContext());
+  Value end =
+      rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
   return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
 }
 

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
index 9674123d13bd..ee8f948eee4d 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -1,26 +1,14 @@
-// RUN: mlir-opt %s -sparsification="vectorization-strategy=0 vl=16" | \
+// RUN: mlir-opt %s -sparsification="vectorization-strategy=0 vl=16" -split-input-file | \
 // RUN:   FileCheck %s --check-prefix=CHECK-VEC0
-// RUN: mlir-opt %s -sparsification="vectorization-strategy=1 vl=16" | \
+// RUN: mlir-opt %s -sparsification="vectorization-strategy=1 vl=16" -split-input-file | \
 // RUN:   FileCheck %s --check-prefix=CHECK-VEC1
-// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16" | \
+// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16" -split-input-file | \
 // RUN:   FileCheck %s --check-prefix=CHECK-VEC2
-// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16 enable-simd-index32=true" | \
+// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16 enable-simd-index32=true" -split-input-file | \
 // RUN:   FileCheck %s --check-prefix=CHECK-VEC3
 
 #DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>
 
-#SparseVector = #sparse_tensor.encoding<{
-  dimLevelType = [ "compressed" ],
-  pointerBitWidth = 32,
-  indexBitWidth = 32
-}>
-
-#SparseMatrix = #sparse_tensor.encoding<{
-  dimLevelType = [ "dense", "compressed" ],
-  pointerBitWidth = 32,
-  indexBitWidth = 32
-}>
-
 #trait_scale_d = {
   indexing_maps = [
     affine_map<(i) -> (i)>,  // a
@@ -77,6 +65,14 @@ func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor<1024
   return %0 : tensor<1024xf32>
 }
 
+// -----
+
+#SparseVector = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed" ],
+  pointerBitWidth = 32,
+  indexBitWidth = 32
+}>
+
 #trait_mul_s = {
   indexing_maps = [
     affine_map<(i) -> (i)>,  // a
@@ -128,6 +124,7 @@ func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor<1024
 // CHECK-VEC1:       }
 // CHECK-VEC1:       return
 //
+// CHECK-VEC2:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
 // CHECK-VEC2-LABEL: func @mul_s
 // CHECK-VEC2-DAG:   %[[c0:.*]] = constant 0 : index
 // CHECK-VEC2-DAG:   %[[c1:.*]] = constant 1 : index
@@ -139,7 +136,7 @@ func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor<1024
 // CHECK-VEC2:       %[[b:.*]] = zexti %[[r]] : i32 to i64
 // CHECK-VEC2:       %[[s:.*]] = index_cast %[[b]] : i64 to index
 // CHECK-VEC2:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
-// CHECK-VEC2:         %[[sub:.*]] = subi %[[s]], %[[i]] : index
+// CHECK-VEC2:         %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]]
 // CHECK-VEC2:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
 // CHECK-VEC2:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
 // CHECK-VEC2:         %[[zi:.*]] = zexti %[[li]] : vector<16xi32> to vector<16xi64>
@@ -150,6 +147,7 @@ func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor<1024
 // CHECK-VEC2:       }
 // CHECK-VEC2:       return
 //
+// CHECK-VEC3:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
 // CHECK-VEC3-LABEL: func @mul_s
 // CHECK-VEC3-DAG:   %[[c0:.*]] = constant 0 : index
 // CHECK-VEC3-DAG:   %[[c1:.*]] = constant 1 : index
@@ -161,7 +159,7 @@ func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor<1024
 // CHECK-VEC3:       %[[b:.*]] = zexti %[[r]] : i32 to i64
 // CHECK-VEC3:       %[[s:.*]] = index_cast %[[b]] : i64 to index
 // CHECK-VEC3:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
-// CHECK-VEC3:         %[[sub:.*]] = subi %{{.*}}, %[[i]] : index
+// CHECK-VEC3:         %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]]
 // CHECK-VEC3:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
 // CHECK-VEC3:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
 // CHECK-VEC3:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
@@ -182,6 +180,10 @@ func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %ar
   return %0 : tensor<1024xf32>
 }
 
+// -----
+
+#DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>
+
 #trait_reduction_d = {
   indexing_maps = [
     affine_map<(i) -> (i)>,  // a
@@ -248,6 +250,14 @@ func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024xf32>
   return %0 : tensor<f32>
 }
 
+// -----
+
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  pointerBitWidth = 32,
+  indexBitWidth = 32
+}>
+
 #trait_mul_ds = {
   indexing_maps = [
     affine_map<(i,j) -> (i,j)>,  // A
@@ -307,6 +317,7 @@ func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024xf32>
 // CHECK-VEC1:       }
 // CHECK-VEC1:       return
 //
+// CHECK-VEC2:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
 // CHECK-VEC2-LABEL: func @mul_ds
 // CHECK-VEC2-DAG:   %[[c0:.*]] = constant 0 : index
 // CHECK-VEC2-DAG:   %[[c1:.*]] = constant 1 : index
@@ -321,7 +332,7 @@ func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024xf32>
 // CHECK-VEC2:         %[[b:.*]] = zexti %[[r]] : i32 to i64
 // CHECK-VEC2:         %[[s:.*]] = index_cast %[[b]] : i64 to index
 // CHECK-VEC2:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
-// CHECK-VEC2:           %[[sub:.*]] = subi %[[s]], %[[j]] : index
+// CHECK-VEC2:           %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[c16]]]
 // CHECK-VEC2:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
 // CHECK-VEC2:           %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
 // CHECK-VEC2:           %[[zj:.*]] = zexti %[[lj]] : vector<16xi32> to vector<16xi64>
@@ -333,6 +344,7 @@ func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024xf32>
 // CHECK-VEC2:       }
 // CHECK-VEC2:       return
 //
+// CHECK-VEC3:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
 // CHECK-VEC3-LABEL: func @mul_ds
 // CHECK-VEC3-DAG:   %[[c0:.*]] = constant 0 : index
 // CHECK-VEC3-DAG:   %[[c1:.*]] = constant 1 : index
@@ -347,7 +359,7 @@ func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024xf32>
 // CHECK-VEC3:         %[[b:.*]] = zexti %[[r]] : i32 to i64
 // CHECK-VEC3:         %[[s:.*]] = index_cast %[[b]] : i64 to index
 // CHECK-VEC3:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
-// CHECK-VEC3:           %[[sub:.*]] = subi %[[s]], %[[j]] : index
+// CHECK-VEC3:           %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[c16]]]
 // CHECK-VEC3:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
 // CHECK-VEC3:           %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
 // CHECK-VEC3:           %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
new file mode 100644
index 000000000000..26c424f3f791
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16" -for-loop-peeling -canonicalize -split-input-file | \
+// RUN:   FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed" ],
+  pointerBitWidth = 32,
+  indexBitWidth = 32
+}>
+
+#trait_mul_s = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a
+    affine_map<(i) -> (i)>,  // b
+    affine_map<(i) -> (i)>   // x (out)
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) = a(i) * b(i)"
+}
+
+// CHECK-DAG:   #[[$map0:.*]] = affine_map<()[s0, s1] -> (s0 + ((-s0 + s1) floordiv 16) * 16)>
+// CHECK-DAG:   #[[$map1:.*]] = affine_map<()[s0, s1] -> ((s0 - s1) mod 16)>
+// CHECK-LABEL: func @mul_s
+// CHECK-DAG:   %[[c0:.*]] = constant 0 : index
+// CHECK-DAG:   %[[c1:.*]] = constant 1 : index
+// CHECK-DAG:   %[[c16:.*]] = constant 16 : index
+// CHECK:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
+// CHECK:       %[[a:.*]] = zexti %[[p]] : i32 to i64
+// CHECK:       %[[q:.*]] = index_cast %[[a]] : i64 to index
+// CHECK:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
+// CHECK:       %[[b:.*]] = zexti %[[r]] : i32 to i64
+// CHECK:       %[[s:.*]] = index_cast %[[b]] : i64 to index
+// CHECK:       %[[boundary:.*]] = affine.apply #[[$map0]]()[%[[q]], %[[s]]]
+// CHECK:       scf.for %[[i:.*]] = %[[q]] to %[[boundary]] step %[[c16]] {
+// CHECK:         %[[mask:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK:         %[[li:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xi32>, vector<16xi32>
+// CHECK:         %[[zi:.*]] = zexti %[[li]] : vector<16xi32> to vector<16xi64>
+// CHECK:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
+// CHECK:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK:         vector.scatter %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
+// CHECK:       }
+// CHECK:       %[[has_more:.*]] = cmpi slt, %[[boundary]], %[[s]] : index
+// CHECK:       scf.if %[[has_more]] {
+// CHECK:         %[[sub:.*]] = affine.apply #[[$map1]]()[%[[s]], %[[q]]]
+// CHECK:         %[[mask2:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK:         %[[li2:.*]] = vector.maskedload %{{.*}}[%[[boundary]]], %[[mask2]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
+// CHECK:         %[[zi2:.*]] = zexti %[[li2]] : vector<16xi32> to vector<16xi64>
+// CHECK:         %[[la2:.*]] = vector.maskedload %{{.*}}[%[[boundary]]], %[[mask2]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK:         %[[lb2:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi2]]], %[[mask2]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK:         %[[m2:.*]] = mulf %[[la2]], %[[lb2]] : vector<16xf32>
+// CHECK:         vector.scatter %{{.*}}[%[[c0]]] [%[[zi2]]], %[[mask2]], %[[m2]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
+// CHECK:       }
+// CHECK:       return
+//
+func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
+  %0 = linalg.generic #trait_mul_s
+    ins(%arga, %argb: tensor<1024xf32, #SparseVector>, tensor<1024xf32>)
+    outs(%argx: tensor<1024xf32>) {
+      ^bb(%a: f32, %b: f32, %x: f32):
+        %0 = mulf %a, %b : f32
+        linalg.yield %0 : f32
+  } -> tensor<1024xf32>
+  return %0 : tensor<1024xf32>
+}

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index e3a7e7b4d606..db2381168374 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1636,6 +1636,7 @@ cc_library(
     hdrs = ["include/mlir/Dialect/SparseTensor/Transforms/Passes.h"],
     includes = ["include"],
     deps = [
+        ":Affine",
         ":IR",
         ":LLVMDialect",
         ":LinalgOps",


        


More information about the Mlir-commits mailing list