[Mlir-commits] [mlir] andrzej/extend vector to llvm test 8 (PR #111997)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Oct 11 06:43:50 PDT 2024
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/111997
- **[mlir][vector] Add more tests for ConvertVectorToLLVM (7/n)**
- **fixup! [mlir][vector] Add more tests for ConvertVectorToLLVM (7/n)**
- **[mlir][vector] Add more tests for ConvertVectorToLLVM (8/n)**
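To reproduce locally, the updated file runs through lit as usual, or
directly via mlir-opt piped into FileCheck. A sketch, assuming a
standard build tree (the authoritative flags are the RUN line at the
top of the test file):

  $ bin/llvm-lit -v mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

  $ mlir-opt --convert-vector-to-llvm --split-input-file \
      mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir \
    | FileCheck mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir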
From 26aaeec0c822d80efb3c8a7fbb7ad00e8f80d699 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 10 Oct 2024 19:59:41 +0100
Subject: [PATCH 1/3] [mlir][vector] Add more tests for ConvertVectorToLLVM
(7/n)
Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops:
* vector.fma
* vector.reduction (a minimal lowering sketch follows)
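These lowerings map scalable reductions directly onto LLVM's
vector-reduction intrinsics. A minimal sketch of the pattern being
locked down (function name hypothetical; the vector<[16]x...> variants
in the diff follow the same shape):

  func.func @example(%v : vector<[4]xi32>) -> i32 {
    %0 = vector.reduction <add>, %v : vector<[4]xi32> into i32
    return %0 : i32
  }
  // ... which lowers to, roughly:
  // %0 = "llvm.intr.vector.reduce.add"(%v) : (vector<[4]xi32>) -> i32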
---
.../VectorToLLVM/vector-to-llvm.mlir | 308 ++++++++++++++++++
1 file changed, 308 insertions(+)
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index ae1d6fe8bd1672..c7e76ea9a5bbc9 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2001,6 +2001,37 @@ func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf
return %0, %1, %2, %3: vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>, vector<f32>
}
+func.func @vector_fma_scalable(%a: vector<[8]xf32>, %b: vector<2x[4]xf32>, %c: vector<1x1x[1]xf32>, %d: vector<f32>) -> (vector<[8]xf32>, vector<2x[4]xf32>, vector<1x1x[1]xf32>) {
+ // CHECK-LABEL: @vector_fma
+ // CHECK-SAME: %[[A:.*]]: vector<[8]xf32>
+ // CHECK-SAME: %[[B:.*]]: vector<2x[4]xf32>
+ // CHECK-SAME: %[[C:.*]]: vector<1x1x[1]xf32>
+ // CHECK: %[[BL:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x[4]xf32> to !llvm.array<2 x vector<[4]xf32>>
+ // CHECK: llvm.intr.fmuladd
+ // CHECK-SAME: (vector<[8]xf32>, vector<[8]xf32>, vector<[8]xf32>) -> vector<[8]xf32>
+ %0 = vector.fma %a, %a, %a : vector<[8]xf32>
+
+ // CHECK: %[[b00:.*]] = llvm.extractvalue %[[BL]][0] : !llvm.array<2 x vector<[4]xf32>>
+ // CHECK: %[[b01:.*]] = llvm.extractvalue %[[BL]][0] : !llvm.array<2 x vector<[4]xf32>>
+ // CHECK: %[[b02:.*]] = llvm.extractvalue %[[BL]][0] : !llvm.array<2 x vector<[4]xf32>>
+ // CHECK: %[[B0:.*]] = llvm.intr.fmuladd(%[[b00]], %[[b01]], %[[b02]]) :
+ // CHECK-SAME: (vector<[4]xf32>, vector<[4]xf32>, vector<[4]xf32>) -> vector<[4]xf32>
+ // CHECK: llvm.insertvalue %[[B0]], {{.*}}[0] : !llvm.array<2 x vector<[4]xf32>>
+ // CHECK: %[[b10:.*]] = llvm.extractvalue %[[BL]][1] : !llvm.array<2 x vector<[4]xf32>>
+ // CHECK: %[[b11:.*]] = llvm.extractvalue %[[BL]][1] : !llvm.array<2 x vector<[4]xf32>>
+ // CHECK: %[[b12:.*]] = llvm.extractvalue %[[BL]][1] : !llvm.array<2 x vector<[4]xf32>>
+ // CHECK: %[[B1:.*]] = llvm.intr.fmuladd(%[[b10]], %[[b11]], %[[b12]]) :
+ // CHECK-SAME: (vector<[4]xf32>, vector<[4]xf32>, vector<[4]xf32>) -> vector<[4]xf32>
+ // CHECK: llvm.insertvalue %[[B1]], {{.*}}[1] : !llvm.array<2 x vector<[4]xf32>>
+ %1 = vector.fma %b, %b, %b : vector<2x[4]xf32>
+
+ // CHECK: %[[C0:.*]] = llvm.intr.fmuladd
+ // CHECK-SAME: (vector<[1]xf32>, vector<[1]xf32>, vector<[1]xf32>) -> vector<[1]xf32>
+ %2 = vector.fma %c, %c, %c : vector<1x1x[1]xf32>
+
+ return %0, %1, %2: vector<[8]xf32>, vector<2x[4]xf32>, vector<1x1x[1]xf32>
+}
+
// -----
func.func @reduce_0d_f32(%arg0: vector<f32>) -> f32 {
@@ -2028,6 +2059,17 @@ func.func @reduce_f16(%arg0: vector<16xf16>) -> f16 {
// CHECK-SAME: <{fastmathFlags = #llvm.fastmath<none>}> : (f16, vector<16xf16>) -> f16
// CHECK: return %[[V]] : f16
+func.func @reduce_f16_scalable(%arg0: vector<[16]xf16>) -> f16 {
+ %0 = vector.reduction <add>, %arg0 : vector<[16]xf16> into f16
+ return %0 : f16
+}
+// CHECK-LABEL: @reduce_f16_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xf16>)
+// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f16) : f16
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]])
+// CHECK-SAME: <{fastmathFlags = #llvm.fastmath<none>}> : (f16, vector<[16]xf16>) -> f16
+// CHECK: return %[[V]] : f16
+
// -----
func.func @reduce_f32(%arg0: vector<16xf32>) -> f32 {
@@ -2041,6 +2083,17 @@ func.func @reduce_f32(%arg0: vector<16xf32>) -> f32 {
// CHECK-SAME: <{fastmathFlags = #llvm.fastmath<none>}> : (f32, vector<16xf32>) -> f32
// CHECK: return %[[V]] : f32
+func.func @reduce_f32_scalable(%arg0: vector<[16]xf32>) -> f32 {
+ %0 = vector.reduction <add>, %arg0 : vector<[16]xf32> into f32
+ return %0 : f32
+}
+// CHECK-LABEL: @reduce_f32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>)
+// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]])
+// CHECK-SAME: <{fastmathFlags = #llvm.fastmath<none>}> : (f32, vector<[16]xf32>) -> f32
+// CHECK: return %[[V]] : f32
+
// -----
func.func @reduce_f64(%arg0: vector<16xf64>) -> f64 {
@@ -2054,6 +2107,17 @@ func.func @reduce_f64(%arg0: vector<16xf64>) -> f64 {
// CHECK-SAME: <{fastmathFlags = #llvm.fastmath<none>}> : (f64, vector<16xf64>) -> f64
// CHECK: return %[[V]] : f64
+func.func @reduce_f64_scalable(%arg0: vector<[16]xf64>) -> f64 {
+ %0 = vector.reduction <add>, %arg0 : vector<[16]xf64> into f64
+ return %0 : f64
+}
+// CHECK-LABEL: @reduce_f64_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xf64>)
+// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f64) : f64
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]])
+// CHECK-SAME: <{fastmathFlags = #llvm.fastmath<none>}> : (f64, vector<[16]xf64>) -> f64
+// CHECK: return %[[V]] : f64
+
// -----
func.func @reduce_i8(%arg0: vector<16xi8>) -> i8 {
@@ -2065,6 +2129,15 @@ func.func @reduce_i8(%arg0: vector<16xi8>) -> i8 {
// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.add"(%[[A]])
// CHECK: return %[[V]] : i8
+func.func @reduce_i8_scalable(%arg0: vector<[16]xi8>) -> i8 {
+ %0 = vector.reduction <add>, %arg0 : vector<[16]xi8> into i8
+ return %0 : i8
+}
+// CHECK-LABEL: @reduce_i8_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xi8>)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.add"(%[[A]])
+// CHECK: return %[[V]] : i8
+
// -----
func.func @reduce_i32(%arg0: vector<16xi32>) -> i32 {
@@ -2076,6 +2149,15 @@ func.func @reduce_i32(%arg0: vector<16xi32>) -> i32 {
// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.add"(%[[A]])
// CHECK: return %[[V]] : i32
+func.func @reduce_i32_scalable(%arg0: vector<[16]xi32>) -> i32 {
+ %0 = vector.reduction <add>, %arg0 : vector<[16]xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_i32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xi32>)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.add"(%[[A]])
+// CHECK: return %[[V]] : i32
+
// -----
func.func @reduce_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
@@ -2088,6 +2170,16 @@ func.func @reduce_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
// CHECK: %[[V:.*]] = llvm.add %[[ACC]], %[[R]]
// CHECK: return %[[V]] : i32
+func.func @reduce_acc_i32_scalable(%arg0: vector<[16]xi32>, %arg1 : i32) -> i32 {
+ %0 = vector.reduction <add>, %arg0, %arg1 : vector<[16]xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_acc_i32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xi32>, %[[ACC:.*]]: i32)
+// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.add"(%[[A]])
+// CHECK: %[[V:.*]] = llvm.add %[[ACC]], %[[R]]
+// CHECK: return %[[V]] : i32
+
// -----
func.func @reduce_mul_i32(%arg0: vector<16xi32>) -> i32 {
@@ -2099,6 +2191,15 @@ func.func @reduce_mul_i32(%arg0: vector<16xi32>) -> i32 {
// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.mul"(%[[A]])
// CHECK: return %[[V]] : i32
+func.func @reduce_mul_i32_scalable(%arg0: vector<[16]xi32>) -> i32 {
+ %0 = vector.reduction <mul>, %arg0 : vector<[16]xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_mul_i32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xi32>)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.mul"(%[[A]])
+// CHECK: return %[[V]] : i32
+
// -----
func.func @reduce_mul_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
@@ -2111,6 +2212,16 @@ func.func @reduce_mul_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
// CHECK: %[[V:.*]] = llvm.mul %[[ACC]], %[[R]]
// CHECK: return %[[V]] : i32
+func.func @reduce_mul_acc_i32_scalable(%arg0: vector<[16]xi32>, %arg1 : i32) -> i32 {
+ %0 = vector.reduction <mul>, %arg0, %arg1 : vector<[16]xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_mul_acc_i32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xi32>, %[[ACC:.*]]: i32)
+// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.mul"(%[[A]])
+// CHECK: %[[V:.*]] = llvm.mul %[[ACC]], %[[R]]
+// CHECK: return %[[V]] : i32
+
// -----
func.func @reduce_fmaximum_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
@@ -2123,6 +2234,16 @@ func.func @reduce_fmaximum_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
// CHECK: %[[R:.*]] = llvm.intr.maximum(%[[V]], %[[B]]) : (f32, f32) -> f32
// CHECK: return %[[R]] : f32
+func.func @reduce_fmaximum_f32_scalable(%arg0: vector<[16]xf32>, %arg1: f32) -> f32 {
+ %0 = vector.reduction <maximumf>, %arg0, %arg1 : vector<[16]xf32> into f32
+ return %0 : f32
+}
+// CHECK-LABEL: @reduce_fmaximum_f32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>, %[[B:.*]]: f32)
+// CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmaximum(%[[A]]) : (vector<[16]xf32>) -> f32
+// CHECK: %[[R:.*]] = llvm.intr.maximum(%[[V]], %[[B]]) : (f32, f32) -> f32
+// CHECK: return %[[R]] : f32
+
// -----
func.func @reduce_fminimum_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
@@ -2135,6 +2256,16 @@ func.func @reduce_fminimum_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
// CHECK: %[[R:.*]] = llvm.intr.minimum(%[[V]], %[[B]]) : (f32, f32) -> f32
// CHECK: return %[[R]] : f32
+func.func @reduce_fminimum_f32_scalable(%arg0: vector<[16]xf32>, %arg1: f32) -> f32 {
+ %0 = vector.reduction <minimumf>, %arg0, %arg1 : vector<[16]xf32> into f32
+ return %0 : f32
+}
+// CHECK-LABEL: @reduce_fminimum_f32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>, %[[B:.*]]: f32)
+// CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fminimum(%[[A]]) : (vector<[16]xf32>) -> f32
+// CHECK: %[[R:.*]] = llvm.intr.minimum(%[[V]], %[[B]]) : (f32, f32) -> f32
+// CHECK: return %[[R]] : f32
+
// -----
func.func @reduce_fmax_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
@@ -2147,6 +2278,16 @@ func.func @reduce_fmax_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
// CHECK: %[[R:.*]] = llvm.intr.maxnum(%[[V]], %[[B]]) : (f32, f32) -> f32
// CHECK: return %[[R]] : f32
+func.func @reduce_fmax_f32_scalable(%arg0: vector<[16]xf32>, %arg1: f32) -> f32 {
+ %0 = vector.reduction <maxnumf>, %arg0, %arg1 : vector<[16]xf32> into f32
+ return %0 : f32
+}
+// CHECK-LABEL: @reduce_fmax_f32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>, %[[B:.*]]: f32)
+// CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmax(%[[A]]) : (vector<[16]xf32>) -> f32
+// CHECK: %[[R:.*]] = llvm.intr.maxnum(%[[V]], %[[B]]) : (f32, f32) -> f32
+// CHECK: return %[[R]] : f32
+
// -----
func.func @reduce_fmin_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
@@ -2159,6 +2300,16 @@ func.func @reduce_fmin_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
// CHECK: %[[R:.*]] = llvm.intr.minnum(%[[V]], %[[B]]) : (f32, f32) -> f32
// CHECK: return %[[R]] : f32
+func.func @reduce_fmin_f32_scalable(%arg0: vector<[16]xf32>, %arg1: f32) -> f32 {
+ %0 = vector.reduction <minnumf>, %arg0, %arg1 : vector<[16]xf32> into f32
+ return %0 : f32
+}
+// CHECK-LABEL: @reduce_fmin_f32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>, %[[B:.*]]: f32)
+// CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmin(%[[A]]) : (vector<[16]xf32>) -> f32
+// CHECK: %[[R:.*]] = llvm.intr.minnum(%[[V]], %[[B]]) : (f32, f32) -> f32
+// CHECK: return %[[R]] : f32
+
// -----
func.func @reduce_minui_i32(%arg0: vector<16xi32>) -> i32 {
@@ -2170,6 +2321,15 @@ func.func @reduce_minui_i32(%arg0: vector<16xi32>) -> i32 {
// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.umin"(%[[A]])
// CHECK: return %[[V]] : i32
+func.func @reduce_minui_i32_scalable(%arg0: vector<[16]xi32>) -> i32 {
+ %0 = vector.reduction <minui>, %arg0 : vector<[16]xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_minui_i32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xi32>)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.umin"(%[[A]])
+// CHECK: return %[[V]] : i32
+
// -----
func.func @reduce_minui_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
@@ -2183,6 +2343,17 @@ func.func @reduce_minui_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
// CHECK: %[[V:.*]] = llvm.select %[[S]], %[[ACC]], %[[R]]
// CHECK: return %[[V]] : i32
+func.func @reduce_minui_acc_i32_scalable(%arg0: vector<[16]xi32>, %arg1 : i32) -> i32 {
+ %0 = vector.reduction <minui>, %arg0, %arg1 : vector<[16]xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_minui_acc_i32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xi32>, %[[ACC:.*]]: i32)
+// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.umin"(%[[A]])
+// CHECK: %[[S:.*]] = llvm.icmp "ule" %[[ACC]], %[[R]]
+// CHECK: %[[V:.*]] = llvm.select %[[S]], %[[ACC]], %[[R]]
+// CHECK: return %[[V]] : i32
+
// -----
func.func @reduce_maxui_i32(%arg0: vector<16xi32>) -> i32 {
@@ -2194,6 +2365,15 @@ func.func @reduce_maxui_i32(%arg0: vector<16xi32>) -> i32 {
// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.umax"(%[[A]])
// CHECK: return %[[V]] : i32
+func.func @reduce_maxui_i32_scalable(%arg0: vector<[16]xi32>) -> i32 {
+ %0 = vector.reduction <maxui>, %arg0 : vector<[16]xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_maxui_i32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xi32>)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.umax"(%[[A]])
+// CHECK: return %[[V]] : i32
+
// -----
func.func @reduce_maxui_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
@@ -2207,6 +2387,17 @@ func.func @reduce_maxui_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
// CHECK: %[[V:.*]] = llvm.select %[[S]], %[[ACC]], %[[R]]
// CHECK: return %[[V]] : i32
+func.func @reduce_maxui_acc_i32_scalable(%arg0: vector<[16]xi32>, %arg1 : i32) -> i32 {
+ %0 = vector.reduction <maxui>, %arg0, %arg1 : vector<[16]xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_maxui_acc_i32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xi32>, %[[ACC:.*]]: i32)
+// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.umax"(%[[A]])
+// CHECK: %[[S:.*]] = llvm.icmp "uge" %[[ACC]], %[[R]]
+// CHECK: %[[V:.*]] = llvm.select %[[S]], %[[ACC]], %[[R]]
+// CHECK: return %[[V]] : i32
+
// -----
func.func @reduce_minsi_i32(%arg0: vector<16xi32>) -> i32 {
@@ -2218,6 +2409,15 @@ func.func @reduce_minsi_i32(%arg0: vector<16xi32>) -> i32 {
// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.smin"(%[[A]])
// CHECK: return %[[V]] : i32
+func.func @reduce_minsi_i32_scalable(%arg0: vector<[16]xi32>) -> i32 {
+ %0 = vector.reduction <minsi>, %arg0 : vector<[16]xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_minsi_i32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xi32>)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.smin"(%[[A]])
+// CHECK: return %[[V]] : i32
+
// -----
func.func @reduce_minsi_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
@@ -2231,6 +2431,17 @@ func.func @reduce_minsi_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
// CHECK: %[[V:.*]] = llvm.select %[[S]], %[[ACC]], %[[R]]
// CHECK: return %[[V]] : i32
+func.func @reduce_minsi_acc_i32_scalable(%arg0: vector<[16]xi32>, %arg1 : i32) -> i32 {
+ %0 = vector.reduction <minsi>, %arg0, %arg1 : vector<[16]xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_minsi_acc_i32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xi32>, %[[ACC:.*]]: i32)
+// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.smin"(%[[A]])
+// CHECK: %[[S:.*]] = llvm.icmp "sle" %[[ACC]], %[[R]]
+// CHECK: %[[V:.*]] = llvm.select %[[S]], %[[ACC]], %[[R]]
+// CHECK: return %[[V]] : i32
+
// -----
func.func @reduce_maxsi_i32(%arg0: vector<16xi32>) -> i32 {
@@ -2242,6 +2453,15 @@ func.func @reduce_maxsi_i32(%arg0: vector<16xi32>) -> i32 {
// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.smax"(%[[A]])
// CHECK: return %[[V]] : i32
+func.func @reduce_maxsi_i32_scalable(%arg0: vector<[16]xi32>) -> i32 {
+ %0 = vector.reduction <maxsi>, %arg0 : vector<[16]xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_maxsi_i32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xi32>)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.smax"(%[[A]])
+// CHECK: return %[[V]] : i32
+
// -----
func.func @reduce_maxsi_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
@@ -2255,6 +2475,17 @@ func.func @reduce_maxsi_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
// CHECK: %[[V:.*]] = llvm.select %[[S]], %[[ACC]], %[[R]]
// CHECK: return %[[V]] : i32
+func.func @reduce_maxsi_acc_i32_scalable(%arg0: vector<[16]xi32>, %arg1 : i32) -> i32 {
+ %0 = vector.reduction <maxsi>, %arg0, %arg1 : vector<[16]xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_maxsi_acc_i32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xi32>, %[[ACC:.*]]: i32)
+// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.smax"(%[[A]])
+// CHECK: %[[S:.*]] = llvm.icmp "sge" %[[ACC]], %[[R]]
+// CHECK: %[[V:.*]] = llvm.select %[[S]], %[[ACC]], %[[R]]
+// CHECK: return %[[V]] : i32
+
// -----
func.func @reduce_and_i32(%arg0: vector<16xi32>) -> i32 {
@@ -2266,6 +2497,15 @@ func.func @reduce_and_i32(%arg0: vector<16xi32>) -> i32 {
// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.and"(%[[A]])
// CHECK: return %[[V]] : i32
+func.func @reduce_and_i32_scalable(%arg0: vector<[16]xi32>) -> i32 {
+ %0 = vector.reduction <and>, %arg0 : vector<[16]xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_and_i32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xi32>)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.and"(%[[A]])
+// CHECK: return %[[V]] : i32
+
// -----
func.func @reduce_and_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
@@ -2278,6 +2518,16 @@ func.func @reduce_and_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
// CHECK: %[[V:.*]] = llvm.and %[[ACC]], %[[R]]
// CHECK: return %[[V]] : i32
+func.func @reduce_and_acc_i32_scalable(%arg0: vector<[16]xi32>, %arg1 : i32) -> i32 {
+ %0 = vector.reduction <and>, %arg0, %arg1 : vector<[16]xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_and_acc_i32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xi32>, %[[ACC:.*]]: i32)
+// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.and"(%[[A]])
+// CHECK: %[[V:.*]] = llvm.and %[[ACC]], %[[R]]
+// CHECK: return %[[V]] : i32
+
// -----
func.func @reduce_or_i32(%arg0: vector<16xi32>) -> i32 {
@@ -2289,6 +2539,15 @@ func.func @reduce_or_i32(%arg0: vector<16xi32>) -> i32 {
// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.or"(%[[A]])
// CHECK: return %[[V]] : i32
+func.func @reduce_or_i32_scalable(%arg0: vector<[16]xi32>) -> i32 {
+ %0 = vector.reduction <or>, %arg0 : vector<[16]xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_or_i32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xi32>)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.or"(%[[A]])
+// CHECK: return %[[V]] : i32
+
// -----
func.func @reduce_or_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
@@ -2301,6 +2560,16 @@ func.func @reduce_or_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
// CHECK: %[[V:.*]] = llvm.or %[[ACC]], %[[R]]
// CHECK: return %[[V]] : i32
+func.func @reduce_or_acc_i32_scalable(%arg0: vector<[16]xi32>, %arg1 : i32) -> i32 {
+ %0 = vector.reduction <or>, %arg0, %arg1 : vector<[16]xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_or_acc_i32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xi32>, %[[ACC:.*]]: i32)
+// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.or"(%[[A]])
+// CHECK: %[[V:.*]] = llvm.or %[[ACC]], %[[R]]
+// CHECK: return %[[V]] : i32
+
// -----
func.func @reduce_xor_i32(%arg0: vector<16xi32>) -> i32 {
@@ -2312,6 +2581,15 @@ func.func @reduce_xor_i32(%arg0: vector<16xi32>) -> i32 {
// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.xor"(%[[A]])
// CHECK: return %[[V]] : i32
+func.func @reduce_xor_i32_scalable(%arg0: vector<[16]xi32>) -> i32 {
+ %0 = vector.reduction <xor>, %arg0 : vector<[16]xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_xor_i32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xi32>)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.xor"(%[[A]])
+// CHECK: return %[[V]] : i32
+
// -----
func.func @reduce_xor_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
@@ -2324,6 +2602,16 @@ func.func @reduce_xor_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
// CHECK: %[[V:.*]] = llvm.xor %[[ACC]], %[[R]]
// CHECK: return %[[V]] : i32
+func.func @reduce_xor_acc_i32_scalable(%arg0: vector<[16]xi32>, %arg1 : i32) -> i32 {
+ %0 = vector.reduction <xor>, %arg0, %arg1 : vector<[16]xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_xor_acc_i32_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xi32>, %[[ACC:.*]]: i32)
+// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.xor"(%[[A]])
+// CHECK: %[[V:.*]] = llvm.xor %[[ACC]], %[[R]]
+// CHECK: return %[[V]] : i32
+
// -----
func.func @reduce_i64(%arg0: vector<16xi64>) -> i64 {
@@ -2335,6 +2623,15 @@ func.func @reduce_i64(%arg0: vector<16xi64>) -> i64 {
// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.add"(%[[A]])
// CHECK: return %[[V]] : i64
+func.func @reduce_i64_scalable(%arg0: vector<[16]xi64>) -> i64 {
+ %0 = vector.reduction <add>, %arg0 : vector<[16]xi64> into i64
+ return %0 : i64
+}
+// CHECK-LABEL: @reduce_i64_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xi64>)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.add"(%[[A]])
+// CHECK: return %[[V]] : i64
+
// -----
func.func @reduce_index(%arg0: vector<16xindex>) -> index {
@@ -2348,6 +2645,17 @@ func.func @reduce_index(%arg0: vector<16xindex>) -> index {
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : i64 to index
// CHECK: return %[[T2]] : index
+func.func @reduce_index_scalable(%arg0: vector<[16]xindex>) -> index {
+ %0 = vector.reduction <add>, %arg0 : vector<[16]xindex> into index
+ return %0 : index
+}
+// CHECK-LABEL: @reduce_index_scalable(
+// CHECK-SAME: %[[A:.*]]: vector<[16]xindex>)
+// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<[16]xindex> to vector<[16]xi64>
+// CHECK: %[[T1:.*]] = "llvm.intr.vector.reduce.add"(%[[T0]])
+// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : i64 to index
+// CHECK: return %[[T2]] : index
+
// 4x16 16x3 4x3
// -----
From a8bdf82ad5fa8df32c09fc965b861ee1eb9b0845 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 11 Oct 2024 13:07:41 +0100
Subject: [PATCH 2/3] fixup! [mlir][vector] Add more tests for
ConvertVectorToLLVM (7/n)
Fix CHECK-LABEL: the new test re-used the `@vector_fma` label, so its
CHECKs could anchor at the wrong function. Use `@vector_fma_scalable`
to keep them scoped to the scalable test.
---
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index c7e76ea9a5bbc9..243082d2ba9aa9 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2002,7 +2002,7 @@ func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf
}
func.func @vector_fma_scalable(%a: vector<[8]xf32>, %b: vector<2x[4]xf32>, %c: vector<1x1x[1]xf32>, %d: vector<f32>) -> (vector<[8]xf32>, vector<2x[4]xf32>, vector<1x1x[1]xf32>) {
- // CHECK-LABEL: @vector_fma
+ // CHECK-LABEL: @vector_fma_scalable
// CHECK-SAME: %[[A:.*]]: vector<[8]xf32>
// CHECK-SAME: %[[B:.*]]: vector<2x[4]xf32>
// CHECK-SAME: %[[C:.*]]: vector<1x1x[1]xf32>
From 455a109ef8f7c31e34dfe037ffd31c0212db104a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 11 Oct 2024 14:37:07 +0100
Subject: [PATCH 3/3] [mlir][vector] Add more tests for ConvertVectorToLLVM
(8/n)
Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops:
* vector.transfer_read
* vector.transfer_write
In addition:
* Tests that exercise both xfer_read and xfer_write are renamed to
  reflect that (e.g. `@transfer_read_1d_mask` ->
  `@transfer_read_write_1d_mask`); the combined shape is sketched
  below.
* `@transfer_write_1d_scalable_mask` and
  `@transfer_read_1d_scalable_mask` are merged and re-written as
  `@transfer_read_write_1d_mask_scalable`, making it clear that this
  case complements `@transfer_read_write_1d_mask`.
* `@transfer_write_tensor` is updated to also test `xfer_read` and is
  renamed to `@transfer_read_write_tensor`.
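The renamed tests all share one combined read-then-write shape, so a
single case exercises both the masked-load and the masked-store
lowerings. An illustrative sketch distilled from the diff below (names
and vector sizes hypothetical):

  func.func @xfer_example(%A : memref<?xf32>, %base : index,
                          %m : vector<[4]xi1>) -> vector<[4]xf32> {
    %pad = arith.constant 0.0 : f32
    // Masked read; lowers to llvm.intr.masked.load.
    %f = vector.transfer_read %A[%base], %pad, %m
        : memref<?xf32>, vector<[4]xf32>
    // Masked write of the value just read; lowers to llvm.intr.masked.store.
    vector.transfer_write %f, %A[%base], %m
        : vector<[4]xf32>, memref<?xf32>
    return %f : vector<[4]xf32>
  }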
---
.../VectorToLLVM/vector-to-llvm.mlir | 223 +++++++++++++++---
1 file changed, 191 insertions(+), 32 deletions(-)
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 243082d2ba9aa9..a63068efe0ba76 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2685,7 +2685,7 @@ func.func @matrix_ops_index(%A: vector<64xindex>, %B: vector<48xindex>) -> vecto
// -----
-func.func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
+func.func @transfer_read_write_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
%f7 = arith.constant 7.0: f32
%f = vector.transfer_read %A[%base], %f7
{permutation_map = affine_map<(d0) -> (d0)>} :
@@ -2695,7 +2695,7 @@ func.func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32>
vector<17xf32>, memref<?xf32>
return %f: vector<17xf32>
}
-// CHECK-LABEL: func @transfer_read_1d
+// CHECK-LABEL: func @transfer_read_write_1d
// CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
// CHECK-SAME: %[[BASE:.*]]: index) -> vector<17xf32>
// CHECK: %[[C7:.*]] = arith.constant 7.0
@@ -2757,9 +2757,77 @@ func.func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32>
// CHECK-SAME: {alignment = 4 : i32} :
// CHECK-SAME: vector<17xf32>, vector<17xi1> into !llvm.ptr
+func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) -> vector<[17]xf32> {
+ %f7 = arith.constant 7.0: f32
+ %f = vector.transfer_read %A[%base], %f7
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ memref<?xf32>, vector<[17]xf32>
+ vector.transfer_write %f, %A[%base]
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ vector<[17]xf32>, memref<?xf32>
+ return %f: vector<[17]xf32>
+}
+// CHECK-LABEL: func @transfer_read_write_1d_scalable
+// CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[BASE:.*]]: index) -> vector<[17]xf32>
+// CHECK: %[[C7:.*]] = arith.constant 7.0
+//
+// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
+// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE]] : index
+//
+// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// CHECK: %[[linearIndex:.*]] = llvm.intr.stepvector : vector<[17]xi32>
+//
+// 3. Create bound vector to compute in-bound mask:
+// [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
+// CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to i32
+// CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
+// CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
+// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
+// CHECK-SAME: : vector<[17]xi32>
+//
+// 4. Create pass-through vector.
+// CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<[17]xf32>
+//
+// 5. Bitcast to vector form.
+// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} :
+// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
+//
+// 6. Rewrite as a masked read.
+// CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[gep]], %[[mask]],
+// CHECK-SAME: %[[PASS_THROUGH]] {alignment = 4 : i32} :
+// CHECK-SAME: -> vector<[17]xf32>
+//
+// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
+// CHECK: %[[C0_b:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref<?xf32>
+// CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index
+//
+// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// CHECK: %[[linearIndex_b:.*]] = llvm.intr.stepvector : vector<[17]xi32>
+//
+// 3. Create bound vector to compute in-bound mask:
+// [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
+// CHECK: %[[btrunc_b:.*]] = arith.index_cast %[[BOUND_b]] : index to i32
+// CHECK: %[[boundVecInsert_b:.*]] = llvm.insertelement %[[btrunc_b]]
+// CHECK: %[[boundVect_b:.*]] = llvm.shufflevector %[[boundVecInsert_b]]
+// CHECK: %[[mask_b:.*]] = arith.cmpi slt, %[[linearIndex_b]],
+// CHECK-SAME: %[[boundVect_b]] : vector<[17]xi32>
+//
+// 4. Bitcast to vector form.
+// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
+// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
+//
+// 5. Rewrite as a masked write.
+// CHECK: llvm.intr.masked.store %[[loaded]], %[[gep_b]], %[[mask_b]]
+// CHECK-SAME: {alignment = 4 : i32} :
+// CHECK-SAME: vector<[17]xf32>, vector<[17]xi1> into !llvm.ptr
+
// -----
-func.func @transfer_read_index_1d(%A : memref<?xindex>, %base: index) -> vector<17xindex> {
+func.func @transfer_read_write_index_1d(%A : memref<?xindex>, %base: index) -> vector<17xindex> {
%f7 = arith.constant 7: index
%f = vector.transfer_read %A[%base], %f7
{permutation_map = affine_map<(d0) -> (d0)>} :
@@ -2769,7 +2837,7 @@ func.func @transfer_read_index_1d(%A : memref<?xindex>, %base: index) -> vector<
vector<17xindex>, memref<?xindex>
return %f: vector<17xindex>
}
-// CHECK-LABEL: func @transfer_read_index_1d
+// CHECK-LABEL: func @transfer_read_write_index_1d
// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xindex>
// CHECK: %[[SPLAT:.*]] = arith.constant dense<7> : vector<17xindex>
// CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[SPLAT]] : vector<17xindex> to vector<17xi64>
@@ -2780,6 +2848,27 @@ func.func @transfer_read_index_1d(%A : memref<?xindex>, %base: index) -> vector<
// CHECK: llvm.intr.masked.store %[[loaded]], %{{.*}}, %{{.*}} {alignment = 8 : i32} :
// CHECK-SAME: vector<17xi64>, vector<17xi1> into !llvm.ptr
+func.func @transfer_read_write_index_1d_scalable(%A : memref<?xindex>, %base: index) -> vector<[17]xindex> {
+ %f7 = arith.constant 7: index
+ %f = vector.transfer_read %A[%base], %f7
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ memref<?xindex>, vector<[17]xindex>
+ vector.transfer_write %f, %A[%base]
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ vector<[17]xindex>, memref<?xindex>
+ return %f: vector<[17]xindex>
+}
+// CHECK-LABEL: func @transfer_read_write_index_1d_scalable
+// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xindex>
+// CHECK: %[[SPLAT:.*]] = arith.constant dense<7> : vector<[17]xindex>
+// CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[SPLAT]] : vector<[17]xindex> to vector<[17]xi64>
+
+// CHECK: %[[loaded:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} :
+// CHECK-SAME: (!llvm.ptr, vector<[17]xi1>, vector<[17]xi64>) -> vector<[17]xi64>
+
+// CHECK: llvm.intr.masked.store %[[loaded]], %{{.*}}, %{{.*}} {alignment = 8 : i32} :
+// CHECK-SAME: vector<[17]xi64>, vector<[17]xi1> into !llvm.ptr
+
// -----
func.func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<17xf32> {
@@ -2809,9 +2898,34 @@ func.func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: i
// CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
+func.func @transfer_read_2d_to_1d_scalable(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<[17]xf32> {
+ %f7 = arith.constant 7.0: f32
+ %f = vector.transfer_read %A[%base0, %base1], %f7
+ {permutation_map = affine_map<(d0, d1) -> (d1)>} :
+ memref<?x?xf32>, vector<[17]xf32>
+ return %f: vector<[17]xf32>
+}
+// CHECK-LABEL: func @transfer_read_2d_to_1d_scalable
+// CHECK-SAME: %[[BASE_0:[a-zA-Z0-9]*]]: index, %[[BASE_1:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32>
+// CHECK: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c1]] : memref<?x?xf32>
+//
+// Compute the in-bound index (dim - offset)
+// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE_1]] : index
+//
+// Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// CHECK: %[[linearIndex:.*]] = llvm.intr.stepvector : vector<[17]xi32>
+//
+// Create bound vector to compute in-bound mask:
+// [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
+// CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to i32
+// CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
+// CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
+// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
+
// -----
-func.func @transfer_read_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %base: index) -> vector<17xf32> {
+func.func @transfer_read_write_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %base: index) -> vector<17xf32> {
%f7 = arith.constant 7.0: f32
%f = vector.transfer_read %A[%base], %f7
{permutation_map = affine_map<(d0) -> (d0)>} :
@@ -2821,7 +2935,7 @@ func.func @transfer_read_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %base: ind
vector<17xf32>, memref<?xf32, 3>
return %f: vector<17xf32>
}
-// CHECK-LABEL: func @transfer_read_1d_non_zero_addrspace
+// CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace
// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
//
// 1. Check address space for GEP is correct.
@@ -2836,6 +2950,31 @@ func.func @transfer_read_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %base: ind
// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
// CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
+func.func @transfer_read_write_1d_non_zero_addrspace_scalable(%A : memref<?xf32, 3>, %base: index) -> vector<[17]xf32> {
+ %f7 = arith.constant 7.0: f32
+ %f = vector.transfer_read %A[%base], %f7
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ memref<?xf32, 3>, vector<[17]xf32>
+ vector.transfer_write %f, %A[%base]
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ vector<[17]xf32>, memref<?xf32, 3>
+ return %f: vector<[17]xf32>
+}
+// CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace_scalable
+// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32>
+//
+// 1. Check address space for GEP is correct.
+// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
+// CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
+//
+// 2. Check address space of the memref is correct.
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c0]] : memref<?xf32, 3>
+//
+// 3. Check address space for GEP is correct.
+// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
+// CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
+
// -----
func.func @transfer_read_1d_inbounds(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
@@ -2854,51 +2993,71 @@ func.func @transfer_read_1d_inbounds(%A : memref<?xf32>, %base: index) -> vector
// 2. Rewrite as a load.
// CHECK: %[[loaded:.*]] = llvm.load %[[gep]] {alignment = 4 : i64} : !llvm.ptr -> vector<17xf32>
+func.func @transfer_read_1d_inbounds_scalable(%A : memref<?xf32>, %base: index) -> vector<[17]xf32> {
+ %f7 = arith.constant 7.0: f32
+ %f = vector.transfer_read %A[%base], %f7 {in_bounds = [true]} :
+ memref<?xf32>, vector<[17]xf32>
+ return %f: vector<[17]xf32>
+}
+// CHECK-LABEL: func @transfer_read_1d_inbounds_scalable
+// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32>
+//
+// 1. Bitcast to vector form.
+// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
+// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
+//
+// 2. Rewrite as a load.
+// CHECK: %[[loaded:.*]] = llvm.load %[[gep]] {alignment = 4 : i64} : !llvm.ptr -> vector<[17]xf32>
+
// -----
-// CHECK-LABEL: func @transfer_read_1d_mask
+// CHECK-LABEL: func @transfer_read_write_1d_mask
// CHECK: %[[mask1:.*]] = arith.constant dense<[false, false, true, false, true]>
// CHECK: %[[cmpi:.*]] = arith.cmpi slt
// CHECK: %[[mask2:.*]] = arith.andi %[[cmpi]], %[[mask1]]
// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask2]]
+// CHECK: %[[cmpi_1:.*]] = arith.cmpi slt
+// CHECK: %[[mask3:.*]] = arith.andi %[[cmpi_1]], %[[mask1]]
+// CHECK: llvm.intr.masked.store %[[r]], %{{.*}}, %[[mask3]]
// CHECK: return %[[r]]
-func.func @transfer_read_1d_mask(%A : memref<?xf32>, %base : index) -> vector<5xf32> {
+func.func @transfer_read_write_1d_mask(%A : memref<?xf32>, %base : index) -> vector<5xf32> {
%m = arith.constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
%f7 = arith.constant 7.0: f32
%f = vector.transfer_read %A[%base], %f7, %m : memref<?xf32>, vector<5xf32>
+ vector.transfer_write %f, %A[%base], %m : vector<5xf32>, memref<?xf32>
return %f: vector<5xf32>
}
-// -----
-
-// CHECK-LABEL: func @transfer_read_1d_scalable_mask
-// CHECK: %[[passtru:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
-// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %[[passtru]] {alignment = 4 : i32} : (!llvm.ptr, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32>
-// CHECK: return %[[r]] : vector<[4]xf32>
-func.func @transfer_read_1d_scalable_mask(%arg0: memref<1x?xf32>, %mask: vector<[4]xi1>) -> vector<[4]xf32> {
- %c0 = arith.constant 0 : index
- %pad = arith.constant 0.0 : f32
- %vec = vector.transfer_read %arg0[%c0, %c0], %pad, %mask {in_bounds = [true]} : memref<1x?xf32>, vector<[4]xf32>
- return %vec : vector<[4]xf32>
+// CHECK-LABEL: func @transfer_read_write_1d_mask_scalable
+// CHECK-SAME: %[[mask:[a-zA-Z0-9]*]]: vector<[5]xi1>
+// CHECK: %[[cmpi:.*]] = arith.cmpi slt
+// CHECK: %[[mask1:.*]] = arith.andi %[[cmpi]], %[[mask]]
+// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask1]]
+// CHECK: %[[cmpi_1:.*]] = arith.cmpi slt
+// CHECK: %[[mask2:.*]] = arith.andi %[[cmpi_1]], %[[mask]]
+// CHECK: llvm.intr.masked.store %[[r]], %{{.*}}, %[[mask2]]
+// CHECK: return %[[r]]
+func.func @transfer_read_write_1d_mask_scalable(%A : memref<?xf32>, %base : index, %m : vector<[5]xi1>) -> vector<[5]xf32> {
+ %f7 = arith.constant 7.0: f32
+ %f = vector.transfer_read %A[%base], %f7, %m : memref<?xf32>, vector<[5]xf32>
+ vector.transfer_write %f, %A[%base], %m : vector<[5]xf32>, memref<?xf32>
+ return %f: vector<[5]xf32>
}
// -----
-// CHECK-LABEL: func @transfer_write_1d_scalable_mask
-// CHECK: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into !llvm.ptr
-func.func @transfer_write_1d_scalable_mask(%arg0: memref<1x?xf32>, %vec: vector<[4]xf32>, %mask: vector<[4]xi1>) {
- %c0 = arith.constant 0 : index
- vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true]} : vector<[4]xf32>, memref<1x?xf32>
- return
-}
-// -----
+// Can't lower xfer_read/xfer_write on tensors, but this shouldn't crash
-// CHECK-LABEL: func @transfer_write_tensor
+// CHECK-LABEL: func @transfer_read_write_tensor
+// CHECK: vector.transfer_read
// CHECK: vector.transfer_write
-func.func @transfer_write_tensor(%arg0: vector<4xf32>,%arg1: tensor<?xf32>) -> tensor<?xf32> {
- %c0 = arith.constant 0 : index
- %0 = vector.transfer_write %arg0, %arg1[%c0] : vector<4xf32>, tensor<?xf32>
- return %0 : tensor<?xf32>
+func.func @transfer_read_write_tensor(%A: tensor<?xf32>, %base : index) -> vector<4xf32> {
+ %f7 = arith.constant 7.0: f32
+ %c0 = arith.constant 0: index
+ %f = vector.transfer_read %A[%base], %f7 : tensor<?xf32>, vector<4xf32>
+ %w = vector.transfer_write %f, %A[%c0] : vector<4xf32>, tensor<?xf32>
+ "test.some_use"(%w) : (tensor<?xf32>) -> ()
+ return %f : vector<4xf32>
}
// -----