[Mlir-commits] [mlir] 23dc96b - [mlir][sparse] fix crashes when using custom reduce with unary operation.

Peiming Liu llvmlistbot at llvm.org
Mon Jun 5 16:41:32 PDT 2023


Author: Peiming Liu
Date: 2023-06-05T23:41:26Z
New Revision: 23dc96bbe41b1a80bafd90b67e76a6882d9b6cd9

URL: https://github.com/llvm/llvm-project/commit/23dc96bbe41b1a80bafd90b67e76a6882d9b6cd9
DIFF: https://github.com/llvm/llvm-project/commit/23dc96bbe41b1a80bafd90b67e76a6882d9b6cd9.diff

LOG: [mlir][sparse] fix crashes when using custom reduce with unary operation.

The test case is copied directly from https://reviews.llvm.org/D152179, authored by @aartbik.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D152204

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_prod.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index 20fc678773be0..caef60eb1ab7b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -505,8 +505,15 @@ static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
 /// that analysis and rewriting code stay in sync.
 static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                           bool codegen) {
-  Location loc = forOp.getLoc();
   Block &block = forOp.getRegion().front();
+  // For loops with single yield statement (as below) could be generated
+  // when custom reduce is used with unary operation.
+  // for (...)
+  //   yield c_0
+  if (block.getOperations().size() <= 1)
+    return false;
+
+  Location loc = forOp.getLoc();
   scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
   auto &last = *++block.rbegin();
   scf::ForOp forOpNew;

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_prod.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_prod.mlir
index 06b8a1ad0f3a7..9884c65cea53e 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_prod.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_prod.mlir
@@ -30,6 +30,7 @@
 // An example of vector reductions.
 module {
 
+  // Custom prod reduction: stored i32 elements only.
   func.func @prod_dreduction_i32(%arga: tensor<32xi32, #DV>,
                                  %argx: tensor<i32>) -> tensor<i32> {
     %c = tensor.extract %argx[] : tensor<i32>
@@ -47,6 +48,7 @@ module {
     return %0 : tensor<i32>
   }
 
+  // Custom prod reduction: stored f32 elements only.
   func.func @prod_dreduction_f32(%arga: tensor<32xf32, #DV>,
                                  %argx: tensor<f32>) -> tensor<f32> {
     %c = tensor.extract %argx[] : tensor<f32>
@@ -64,6 +66,7 @@ module {
     return %0 : tensor<f32>
   }
 
+  // Custom prod reduction: stored i32 elements only.
   func.func @prod_sreduction_i32(%arga: tensor<32xi32, #SV>,
                                  %argx: tensor<i32>) -> tensor<i32> {
     %c = tensor.extract %argx[] : tensor<i32>
@@ -81,6 +84,7 @@ module {
     return %0 : tensor<i32>
   }
 
+  // Custom prod reduction: stored f32 elements only.
   func.func @prod_sreduction_f32(%arga: tensor<32xf32, #SV>,
                                  %argx: tensor<f32>) -> tensor<f32> {
     %c = tensor.extract %argx[] : tensor<f32>
@@ -98,6 +102,42 @@ module {
     return %0 : tensor<f32>
   }
 
+  // Custom prod reduction: stored i32 elements and implicit zeros.
+  //
+  // NOTE: this is a somewhat strange operation, since for most sparse
+  //       situations the outcome would always be zero; it is added
+  //       to test full functionality and illustrate the subtle differences
+  //       between the various custom operations; it would make a bit more
+  //       sense for e.g. a min/max reductions, although it still would
+  //       "densify" the iteration space.
+  //
+  func.func @prod_xreduction_i32(%arga: tensor<32xi32, #SV>,
+                                 %argx: tensor<i32>) -> tensor<i32> {
+    %c = tensor.extract %argx[] : tensor<i32>
+    %0 = linalg.generic #trait_reduction
+      ins(%arga: tensor<32xi32, #SV>)
+      outs(%argx: tensor<i32>) {
+        ^bb(%a: i32, %b: i32):
+           %u = sparse_tensor.unary %a : i32 to i32
+           present={
+             ^bb0(%x: i32):
+             sparse_tensor.yield %x : i32
+           } absent={
+             ^bb0:
+             %c0 = arith.constant 0 : i32
+             sparse_tensor.yield %c0 : i32
+          }
+          %1 = sparse_tensor.reduce %u, %b, %c : i32 {
+            ^bb0(%x: i32, %y: i32):
+              %2 = arith.muli %x, %y : i32
+              sparse_tensor.yield %2 : i32
+          }
+          linalg.yield %1 : i32
+    } -> tensor<i32>
+    return %0 : tensor<i32>
+  }
+
+
   func.func @dump_i32(%arg0 : tensor<i32>) {
     %v = tensor.extract %arg0[] : tensor<i32>
     vector.print %v : i32
@@ -174,6 +214,8 @@ module {
     %6 = call @prod_sreduction_i32(%s1_i32, %ri) : (tensor<32xi32, #SV>, tensor<i32>) -> tensor<i32>
     %7 = call @prod_sreduction_f32(%s1_f32, %rf) : (tensor<32xf32, #SV>, tensor<f32>) -> tensor<f32>
     %8 = call @prod_sreduction_i32(%s0,     %ri) : (tensor<32xi32, #SV>, tensor<i32>) -> tensor<i32>
+    %9 = call @prod_xreduction_i32(%s0_i32, %ri) : (tensor<32xi32, #SV>, tensor<i32>) -> tensor<i32>
+    %10 = call @prod_xreduction_i32(%s1_i32, %ri) : (tensor<32xi32, #SV>, tensor<i32>) -> tensor<i32>
 
     // Verify results. Note that the custom reduction gave permission
     // to treat an explicit vs implicit zero differently to compute the
@@ -190,6 +232,8 @@ module {
     // CHECK: 3087
     // CHECK: 168
     // CHECK: 0
+    // CHECK: 0
+    // CHECK: 3087
     //
     call @dump_i32(%0) : (tensor<i32>) -> ()
     call @dump_f32(%1) : (tensor<f32>) -> ()
@@ -200,6 +244,8 @@ module {
     call @dump_i32(%6) : (tensor<i32>) -> ()
     call @dump_f32(%7) : (tensor<f32>) -> ()
     call @dump_i32(%8) : (tensor<i32>) -> ()
+    call @dump_i32(%9) : (tensor<i32>) -> ()
+    call @dump_i32(%10) : (tensor<i32>) -> ()
 
     // Release the resources.
     bufferization.dealloc_tensor %d0_i32 : tensor<32xi32, #DV>


        


More information about the Mlir-commits mailing list