[Mlir-commits] [mlir] [mlir][spirv] Add tests for `scf.while` and `scf.for` in `convert-to-spirv` pass (PR #102528)

Angel Zhang llvmlistbot at llvm.org
Fri Aug 9 11:27:40 PDT 2024


https://github.com/angelz913 updated https://github.com/llvm/llvm-project/pull/102528

From 7242400f6be026f39d05e293959dc3e0f5692290 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Thu, 8 Aug 2024 16:45:20 +0000
Subject: [PATCH 1/2] [mlir][spirv] Add tests for scf.while and scf.for in
 convert-to-spirv

---
 mlir/test/Conversion/ConvertToSPIRV/scf.mlir | 54 ++++++++++++++++++++
 1 file changed, 54 insertions(+)

diff --git a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
index 246464928b81c0..9ff5ae7a91b1df 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
@@ -33,6 +33,23 @@ func.func @if_yield(%arg0: i1) -> f32 {
 }
 
 // CHECK-LABEL: @while
+// CHECK-SAME: (%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+// CHECK:       %[[INITVAR:.*]] = spirv.Constant 2 : i32
+// CHECK:       %[[VAR1:.*]] = spirv.Variable : !spirv.ptr<i32, Function>
+// CHECK:       spirv.mlir.loop {
+// CHECK:         spirv.Branch ^[[HEADER:.*]](%[[ARG1]] : i32)
+// CHECK:       ^[[HEADER]](%[[INDVAR1:.*]]: i32):
+// CHECK:         %[[CMP:.*]] = spirv.SLessThan %[[INDVAR1]], %[[ARG2]] : i32
+// CHECK:         spirv.Store "Function" %[[VAR1]], %[[INDVAR1]] : i32
+// CHECK:         spirv.BranchConditional %[[CMP]], ^[[BODY:.*]](%[[INDVAR1]] : i32), ^[[MERGE:.*]]
+// CHECK:       ^[[BODY]](%[[INDVAR2:.*]]: i32):
+// CHECK:         %[[UPDATED:.*]] = spirv.IMul %[[INDVAR2]], %[[INITVAR]] : i32
+// CHECK:       spirv.Branch ^[[HEADER]](%[[UPDATED]] : i32)
+// CHECK:       ^[[MERGE]]:
+// CHECK:         spirv.mlir.merge
+// CHECK:       }
+// CHECK:       %[[OUT:.*]] = spirv.Load "Function" %[[VAR1]] : i32
+// CHECK:       spirv.ReturnValue %[[OUT]] : i32
 func.func @while(%arg0: i32, %arg1: i32) -> i32 {
   %c2_i32 = arith.constant 2 : i32
   %0 = scf.while (%arg3 = %arg0) : (i32) -> (i32) {
@@ -45,3 +62,40 @@ func.func @while(%arg0: i32, %arg1: i32) -> i32 {
   }
   return %0 : i32
 }
+
+// CHECK-LABEL: @for
+// CHECK:       %[[LB:.*]] = spirv.Constant 4 : i32
+// CHECK:       %[[UB:.*]] = spirv.Constant 42 : i32
+// CHECK:       %[[STEP:.*]] = spirv.Constant 2 : i32
+// CHECK:       %[[INITVAR1:.*]] = spirv.Constant 0.000000e+00 : f32
+// CHECK:       %[[INITVAR2:.*]] = spirv.Constant 1.000000e+00 : f32
+// CHECK:       %[[VAR1:.*]] = spirv.Variable : !spirv.ptr<f32, Function>
+// CHECK:       %[[VAR2:.*]] = spirv.Variable : !spirv.ptr<f32, Function>
+// CHECK:       spirv.mlir.loop {
+// CHECK:         spirv.Branch ^[[HEADER:.*]](%[[LB]], %[[INITVAR1]], %[[INITVAR2]] : i32, f32, f32)
+// CHECK:       ^[[HEADER]](%[[INDVAR:.*]]: i32, %[[CARRIED1:.*]]: f32, %[[CARRIED2:.*]]: f32):
+// CHECK:         %[[CMP:.*]] = spirv.SLessThan %[[INDVAR]], %[[UB]] : i32
+// CHECK:         spirv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]]
+// CHECK:       ^[[BODY]]:
+// CHECK:         %[[UPDATED:.*]] = spirv.FAdd %[[CARRIED1]], %[[CARRIED1]] : f32
+// CHECK-DAG:     %[[INCREMENT:.*]] = spirv.IAdd %[[INDVAR]], %[[STEP]] : i32
+// CHECK-DAG:     spirv.Store "Function" %[[VAR1]], %[[UPDATED]] : f32
+// CHECK-DAG:     spirv.Store "Function" %[[VAR2]], %[[UPDATED]] : f32
+// CHECK:       spirv.Branch ^[[HEADER]](%[[INCREMENT]], %[[UPDATED]], %[[UPDATED]] : i32, f32, f32)
+// CHECK:       ^[[MERGE]]:
+// CHECK:         spirv.mlir.merge
+// CHECK:       }
+// CHECK-DAG:  %[[OUT1:.*]] = spirv.Load "Function" %[[VAR1]] : f32
+// CHECK-DAG:  %[[OUT2:.*]] = spirv.Load "Function" %[[VAR2]] : f32
+func.func @for() {
+  %lb = arith.constant 4 : index
+  %ub = arith.constant 42 : index
+  %step = arith.constant 2 : index
+  %s0 = arith.constant 0.0 : f32
+  %s1 = arith.constant 1.0 : f32
+  %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%si = %s0, %sj = %s1) -> (f32, f32) {
+    %sn = arith.addf %si, %si : f32
+    scf.yield %sn, %sn: f32, f32
+  }
+  return
+}

From 9cee07d550607d39c5c319f4328f36906f627376 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Fri, 9 Aug 2024 18:27:27 +0000
Subject: [PATCH 2/2] Trim CHECK lines in scf.while and scf.for tests

---
 mlir/test/Conversion/ConvertToSPIRV/scf.mlir | 54 +++++++-------------
 1 file changed, 19 insertions(+), 35 deletions(-)

diff --git a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
index 9ff5ae7a91b1df..8f0a9b75a06e1f 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
@@ -33,23 +33,17 @@ func.func @if_yield(%arg0: i1) -> f32 {
 }
 
 // CHECK-LABEL: @while
-// CHECK-SAME: (%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
-// CHECK:       %[[INITVAR:.*]] = spirv.Constant 2 : i32
-// CHECK:       %[[VAR1:.*]] = spirv.Variable : !spirv.ptr<i32, Function>
 // CHECK:       spirv.mlir.loop {
-// CHECK:         spirv.Branch ^[[HEADER:.*]](%[[ARG1]] : i32)
-// CHECK:       ^[[HEADER]](%[[INDVAR1:.*]]: i32):
-// CHECK:         %[[CMP:.*]] = spirv.SLessThan %[[INDVAR1]], %[[ARG2]] : i32
-// CHECK:         spirv.Store "Function" %[[VAR1]], %[[INDVAR1]] : i32
-// CHECK:         spirv.BranchConditional %[[CMP]], ^[[BODY:.*]](%[[INDVAR1]] : i32), ^[[MERGE:.*]]
-// CHECK:       ^[[BODY]](%[[INDVAR2:.*]]: i32):
-// CHECK:         %[[UPDATED:.*]] = spirv.IMul %[[INDVAR2]], %[[INITVAR]] : i32
-// CHECK:       spirv.Branch ^[[HEADER]](%[[UPDATED]] : i32)
-// CHECK:       ^[[MERGE]]:
+// CHECK:         spirv.Branch ^[[HEADER:.*]](%{{.*}} : i32)
+// CHECK:       ^[[HEADER]]
+// CHECK:         spirv.Store "Function"
+// CHECK:         spirv.BranchConditional %{{.*}}, ^[[BODY:.*]](%{{.*}} : i32), ^[[MERGE:.*]]
+// CHECK:       ^[[BODY]]
+// CHECK:       spirv.Branch
+// CHECK:       ^[[MERGE]]
 // CHECK:         spirv.mlir.merge
 // CHECK:       }
-// CHECK:       %[[OUT:.*]] = spirv.Load "Function" %[[VAR1]] : i32
-// CHECK:       spirv.ReturnValue %[[OUT]] : i32
+// CHECK:       spirv.Load "Function"
 func.func @while(%arg0: i32, %arg1: i32) -> i32 {
   %c2_i32 = arith.constant 2 : i32
   %0 = scf.while (%arg3 = %arg0) : (i32) -> (i32) {
@@ -64,29 +58,19 @@ func.func @while(%arg0: i32, %arg1: i32) -> i32 {
 }
 
 // CHECK-LABEL: @for
-// CHECK:       %[[LB:.*]] = spirv.Constant 4 : i32
-// CHECK:       %[[UB:.*]] = spirv.Constant 42 : i32
-// CHECK:       %[[STEP:.*]] = spirv.Constant 2 : i32
-// CHECK:       %[[INITVAR1:.*]] = spirv.Constant 0.000000e+00 : f32
-// CHECK:       %[[INITVAR2:.*]] = spirv.Constant 1.000000e+00 : f32
-// CHECK:       %[[VAR1:.*]] = spirv.Variable : !spirv.ptr<f32, Function>
-// CHECK:       %[[VAR2:.*]] = spirv.Variable : !spirv.ptr<f32, Function>
 // CHECK:       spirv.mlir.loop {
-// CHECK:         spirv.Branch ^[[HEADER:.*]](%[[LB]], %[[INITVAR1]], %[[INITVAR2]] : i32, f32, f32)
-// CHECK:       ^[[HEADER]](%[[INDVAR:.*]]: i32, %[[CARRIED1:.*]]: f32, %[[CARRIED2:.*]]: f32):
-// CHECK:         %[[CMP:.*]] = spirv.SLessThan %[[INDVAR]], %[[UB]] : i32
-// CHECK:         spirv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]]
-// CHECK:       ^[[BODY]]:
-// CHECK:         %[[UPDATED:.*]] = spirv.FAdd %[[CARRIED1]], %[[CARRIED1]] : f32
-// CHECK-DAG:     %[[INCREMENT:.*]] = spirv.IAdd %[[INDVAR]], %[[STEP]] : i32
-// CHECK-DAG:     spirv.Store "Function" %[[VAR1]], %[[UPDATED]] : f32
-// CHECK-DAG:     spirv.Store "Function" %[[VAR2]], %[[UPDATED]] : f32
-// CHECK:       spirv.Branch ^[[HEADER]](%[[INCREMENT]], %[[UPDATED]], %[[UPDATED]] : i32, f32, f32)
-// CHECK:       ^[[MERGE]]:
+// CHECK:         spirv.Branch ^[[HEADER:.*]](%{{.*}}, %{{.*}}, %{{.*}} : i32, f32, f32)
+// CHECK:       ^[[HEADER]]
+// CHECK:         spirv.BranchConditional %{{.*}}, ^[[BODY:.*]], ^[[MERGE:.*]]
+// CHECK:       ^[[BODY]]
+// CHECK-DAG:     spirv.Store "Function"
+// CHECK-DAG:     spirv.Store "Function"
+// CHECK:       spirv.Branch ^[[HEADER]]
+// CHECK:       ^[[MERGE]]
 // CHECK:         spirv.mlir.merge
-// CHECK:       }
-// CHECK-DAG:  %[[OUT1:.*]] = spirv.Load "Function" %[[VAR1]] : f32
-// CHECK-DAG:  %[[OUT2:.*]] = spirv.Load "Function" %[[VAR2]] : f32
+// CHECK:      }
+// CHECK-DAG:  spirv.Load "Function"
+// CHECK-DAG:  spirv.Load "Function"
 func.func @for() {
   %lb = arith.constant 4 : index
   %ub = arith.constant 42 : index



More information about the Mlir-commits mailing list