[Mlir-commits] [mlir] [mlir][vector] Add extra check on distribute types to avoid crashes (PR #102952)

Bangtian Liu llvmlistbot at llvm.org
Tue Aug 13 14:26:00 PDT 2024


https://github.com/bangtianliu updated https://github.com/llvm/llvm-project/pull/102952

>From 3edb3003a35c066a9701eddf6e202a8440efa407 Mon Sep 17 00:00:00 2001
From: Bangtian Liu <liubangtian at gmail.com>
Date: Mon, 12 Aug 2024 11:55:00 -0700
Subject: [PATCH 1/4] add extra check on distribute types to avoid crashes

Signed-off-by: Bangtian Liu <liubangtian at gmail.com>
---
 mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 7285ad65fb549e..29899f44eb2e22 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1689,6 +1689,9 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
           }
         });
 
+    if(llvm::any_of(distTypes, [](Type type){return !type;}))
+      return failure();
+
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, escapingValues.getArrayRef(), distTypes,

>From 2a85257cf7b7e627f42e0c9cfe4bfc5451495e87 Mon Sep 17 00:00:00 2001
From: Bangtian Liu <liubangtian at gmail.com>
Date: Mon, 12 Aug 2024 12:12:15 -0700
Subject: [PATCH 2/4] format the code

Signed-off-by: Bangtian Liu <liubangtian at gmail.com>
---
 mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 29899f44eb2e22..6596f3fc0ed81f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1689,7 +1689,7 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
           }
         });
 
-    if(llvm::any_of(distTypes, [](Type type){return !type;}))
+    if (llvm::any_of(distTypes, [](Type type) { return !type; }))
       return failure();
 
     SmallVector<size_t> newRetIndices;

>From 26d70f67ac5a15a5f9aabb43e83f22b804f03ae0 Mon Sep 17 00:00:00 2001
From: Bangtian Liu <liubangtian at gmail.com>
Date: Tue, 13 Aug 2024 13:53:25 -0700
Subject: [PATCH 3/4] add a test

Signed-off-by: Bangtian Liu <liubangtian at gmail.com>
---
 .../Vector/Transforms/VectorDistribute.cpp    |  2 +-
 .../Vector/vector-warp-distribute.mlir        | 32 +++++++++++++++++++
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 6596f3fc0ed81f..2289fd1ff1364e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1689,7 +1689,7 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
           }
         });
 
-    if (llvm::any_of(distTypes, [](Type type) { return !type; }))
+    if (llvm::is_contained(distTypes, Type{}))
       return failure();
 
     SmallVector<size_t> newRetIndices;
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index bf90c4a6ebb3c2..c4ec2d25cee390 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -620,6 +620,38 @@ func.func @vector_reduction(%laneid: index) -> (f32) {
 
 // -----
 
+// CHECK-PROP-LABEL: func @warp_distribute(
+// CHECK-PROP-SAME:   %[[ID:.*]]: index, %[[SRC:.+]]: memref<128xf32>, %[[DEST:.+]]: memref<128xf32>)
+// CHECK-PROP:   vector.warp_execute_on_lane_0(%[[ID]])[32]
+// CHECK-PROP-NEXT:     "some_def"() : () -> vector<4096xf32>
+// CHECK-PROP-NEXT: %{{.*}} = vector.reduction
+// CHECK-PROP-DAG: %[[DEF:.*]] = arith.divf %{{.*}}, %{{.*}} : vector<1xf32>
+// CHECK-PROP-NOT:   vector.warp_execute_on_lane_0
+// CHECK-PROP: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}}
+// CHECK-PROP: %{{.*}} = arith.subf %{{.*}}, %[[DEF]] : vector<1xf32>
+func.func @warp_distribute(%arg0: index, %src: memref<128xf32>, %dest: memref<128xf32>){
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c128 = arith.constant 128 : index
+  %f0 = arith.constant 0.000000e+00 : f32
+  vector.warp_execute_on_lane_0(%arg0)[32]{
+    %cst_1 = arith.constant dense<2.621440e+05> : vector<1xf32>
+    %0 = "some_def"() : () -> (vector<4096xf32>)
+    %1 = vector.reduction <add>, %0, %cst : vector<4096xf32> into f32
+    %2 = vector.broadcast %1 : f32 to vector<1xf32>
+    %3 = arith.divf %2, %cst_1 : vector<1xf32>
+    scf.for %arg1 = %c0 to %c128 step %c1 {
+        %4 = vector.transfer_read %src[%arg1], %f0 {in_bounds = [true]} : memref<128xf32>, vector<1xf32>
+        %5 = arith.subf %4, %3 : vector<1xf32>
+        vector.transfer_write %5, %dest[%arg1] : vector<1xf32>, memref<128xf32>
+    }
+  }
+  return
+}
+
+// -----
+
 func.func @vector_reduction(%laneid: index, %m0: memref<4x2x32xf32>, %m1: memref<f32>) {
   %c0 = arith.constant 0: index
   %f0 = arith.constant 0.0: f32

>From db3090f353a644108680aced9ade071129f6266e Mon Sep 17 00:00:00 2001
From: Bangtian Liu <liubangtian at gmail.com>
Date: Tue, 13 Aug 2024 14:26:16 -0700
Subject: [PATCH 4/4] format the test file

Signed-off-by: Bangtian Liu <liubangtian at gmail.com>
---
 .../Dialect/Vector/vector-warp-distribute.mlir | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index c4ec2d25cee390..0544cef3e38281 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -621,14 +621,16 @@ func.func @vector_reduction(%laneid: index) -> (f32) {
 // -----
 
 // CHECK-PROP-LABEL: func @warp_distribute(
-// CHECK-PROP-SAME:   %[[ID:.*]]: index, %[[SRC:.+]]: memref<128xf32>, %[[DEST:.+]]: memref<128xf32>)
-// CHECK-PROP:   vector.warp_execute_on_lane_0(%[[ID]])[32]
-// CHECK-PROP-NEXT:     "some_def"() : () -> vector<4096xf32>
-// CHECK-PROP-NEXT: %{{.*}} = vector.reduction
-// CHECK-PROP-DAG: %[[DEF:.*]] = arith.divf %{{.*}}, %{{.*}} : vector<1xf32>
-// CHECK-PROP-NOT:   vector.warp_execute_on_lane_0
-// CHECK-PROP: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}}
-// CHECK-PROP: %{{.*}} = arith.subf %{{.*}}, %[[DEF]] : vector<1xf32>
+//  CHECK-PROP-SAME:    %[[ID:[a-zA-Z0-9]+]]
+//  CHECK-PROP-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+//  CHECK-PROP-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+//       CHECK-PROP:    vector.warp_execute_on_lane_0(%[[ID]])[32]
+//  CHECK-PROP-NEXT:      "some_def"() : () -> vector<4096xf32>
+//  CHECK-PROP-NEXT:      %{{.*}} = vector.reduction
+//       CHECK-PROP:      %[[DEF:.*]] = arith.divf %{{.*}}, %{{.*}} : vector<1xf32>
+//   CHECK-PROP-NOT:      vector.warp_execute_on_lane_0
+//       CHECK-PROP:      scf.for
+//       CHECK-PROP:        %{{.*}} = arith.subf %{{.*}}, %[[DEF]] : vector<1xf32>
 func.func @warp_distribute(%arg0: index, %src: memref<128xf32>, %dest: memref<128xf32>){
   %cst = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index



More information about the Mlir-commits mailing list