[Mlir-commits] [mlir] [MLIR][VectorToLLVM] Handle scalable dim in createVectorLengthValue() (PR #93361)

Zhaoshi Zheng llvmlistbot at llvm.org
Wed May 29 12:57:16 PDT 2024


https://github.com/zhaoshiz updated https://github.com/llvm/llvm-project/pull/93361

From e6156c2fe895e64eb367def898ac5ac46ab99e71 Mon Sep 17 00:00:00 2001
From: Zhaoshi Zheng <zhaoshiz at quicinc.com>
Date: Fri, 24 May 2024 17:40:16 -0700
Subject: [PATCH 1/3] [MLIR][VectorToLLVM] Handle scalable dim in
 createVectorLengthValue()

LLVM's Vector Predication Intrinsics require an explicit vector length
parameter: https://llvm.org/docs/LangRef.html#vector-predication-intrinsics.

For a scalable vector type, this should be calculated as VectorScaleOp
multiplied by the base vector length, e.g. for vector<[4]xf32> we should
return vscale * 4.
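
For illustration, the lowering now emits roughly the following sequence to
compute the vector length for a masked reduction over vector<[4]xf32>
(mirroring the CHECK lines in the updated test; SSA value names are
illustrative):

  %c4 = llvm.mlir.constant(4 : i32) : i32
  %vscale = "llvm.intr.vscale"() : () -> i64
  %vscale_idx = builtin.unrealized_conversion_cast %vscale : i64 to index
  %vscale_i32 = arith.index_cast %vscale_idx : index to i32
  %vl = arith.muli %c4, %vscale_i32 : i32

The resulting %vl is then passed as the explicit vector-length operand to the
llvm.intr.vp.reduce.* intrinsic.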
---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 14 ++++++-
 .../vector-reduction-to-llvm.mlir             | 38 +++++++++++++++++++
 2 files changed, 50 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index fe6bcc1c8b667..18bd9660525b4 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -523,7 +523,7 @@ static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
                                      llvmType);
 }
 
-/// Creates a constant value with the 1-D vector shape provided in `llvmType`.
+/// Creates a value with the 1-D vector shape provided in `llvmType`.
 /// This is used as effective vector length by some intrinsics supporting
 /// dynamic vector lengths at runtime.
 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
@@ -532,9 +532,19 @@ static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
   auto vShape = vType.getShape();
   assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
 
-  return rewriter.create<LLVM::ConstantOp>(
+  Value vLen = rewriter.create<LLVM::ConstantOp>(
       loc, rewriter.getI32Type(),
       rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
+
+  if (!vType.getScalableDims()[0])
+    return vLen;
+
+  // Create VScale*vShape[0] and return it as vector length.
+  Value vScale = rewriter.create<vector::VectorScaleOp>(loc);
+  vScale = rewriter.create<arith::IndexCastOp>(
+      loc, rewriter.getI32Type(), vScale);
+  vLen = rewriter.create<arith::MulIOp>(loc, vLen, vScale);
+  return vLen;
 }
 
 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
index f98a05f8d17e2..209afa217437b 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
@@ -79,6 +79,25 @@ func.func @masked_reduce_add_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -
 // CHECK:           "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
 
 
+// -----
+
+func.func @masked_reduce_add_f32_scalable(%arg0: vector<[4]xf32>, %mask : vector<[4]xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_add_f32_scalable(
+// CHECK-SAME:                              %[[INPUT:.*]]: vector<[4]xf32>,
+// CHECK-SAME:                              %[[MASK:.*]]: vector<[4]xi1>) -> f32 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK:           %[[VL_BASE:.*]] = llvm.mlir.constant(4 : i32) : i32
+// CHECK:           %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK:           %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK:           %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
+// CHECK:           %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
+// CHECK:           "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (f32, vector<[4]xf32>, vector<[4]xi1>, i32) -> f32
+
+
 // -----
 
 func.func @masked_reduce_mul_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
@@ -167,6 +186,25 @@ func.func @masked_reduce_add_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) ->
 // CHECK:           "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
 
 
+// -----
+
+func.func @masked_reduce_add_i8_scalable(%arg0: vector<[16]xi8>, %mask : vector<[16]xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<[16]xi8> into i8 } : vector<[16]xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_add_i8_scalable(
+// CHECK-SAME:                             %[[INPUT:.*]]: vector<[16]xi8>,
+// CHECK-SAME:                             %[[MASK:.*]]: vector<[16]xi1>) -> i8 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK:           %[[VL_BASE:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK:           %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK:           %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK:           %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
+// CHECK:           %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
+// CHECK:           "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[16]xi8>, vector<[16]xi1>, i32) -> i8
+
+
 // -----
 
 func.func @masked_reduce_mul_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {

From 1addb3c41ec9e611883933e5b821df8b1befadd0 Mon Sep 17 00:00:00 2001
From: Zhaoshi Zheng <zhaoshiz at quicinc.com>
Date: Fri, 24 May 2024 20:10:37 -0700
Subject: [PATCH 2/3] Update per clang-format. NFC.

---
 mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 18bd9660525b4..abb522d8081ad 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -541,8 +541,8 @@ static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
 
   // Create VScale*vShape[0] and return it as vector length.
   Value vScale = rewriter.create<vector::VectorScaleOp>(loc);
-  vScale = rewriter.create<arith::IndexCastOp>(
-      loc, rewriter.getI32Type(), vScale);
+  vScale =
+      rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), vScale);
   vLen = rewriter.create<arith::MulIOp>(loc, vLen, vScale);
   return vLen;
 }

From 6c588fcf94ead9d99d4cda4324cb0c77e01251c5 Mon Sep 17 00:00:00 2001
From: Zhaoshi Zheng <zhaoshiz at quicinc.com>
Date: Wed, 29 May 2024 12:38:43 -0700
Subject: [PATCH 3/3] Improve readability and add more test cases. NFC.

---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 11 +--
 .../vector-reduction-to-llvm.mlir             | 72 +++++++++++++++++++
 2 files changed, 78 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index abb522d8081ad..a1f79ada34968 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -532,19 +532,20 @@ static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
   auto vShape = vType.getShape();
   assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
 
-  Value vLen = rewriter.create<LLVM::ConstantOp>(
+  Value baseVecLength = rewriter.create<LLVM::ConstantOp>(
       loc, rewriter.getI32Type(),
       rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
 
   if (!vType.getScalableDims()[0])
-    return vLen;
+    return baseVecLength;
 
-  // Create VScale*vShape[0] and return it as vector length.
+  // For a scalable vector type, create and return `vScale * baseVecLength`.
   Value vScale = rewriter.create<vector::VectorScaleOp>(loc);
   vScale =
       rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), vScale);
-  vLen = rewriter.create<arith::MulIOp>(loc, vLen, vScale);
-  return vLen;
+  Value scalableVecLength =
+      rewriter.create<arith::MulIOp>(loc, baseVecLength, vScale);
+  return scalableVecLength;
 }
 
 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
index 209afa217437b..56f82ce54bff5 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
@@ -129,6 +129,24 @@ func.func @masked_reduce_minf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>)
 
 // -----
 
+func.func @masked_reduce_minf_f32_scalable(%arg0: vector<[16]xf32>, %mask : vector<[16]xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <minnumf>, %arg0 : vector<[16]xf32> into f32 } : vector<[16]xi1> -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_minf_f32_scalable(
+// CHECK-SAME:                                      %[[INPUT:.*]]: vector<[16]xf32>,
+// CHECK-SAME:                                      %[[MASK:.*]]: vector<[16]xi1>) -> f32 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(0xFFC00000 : f32) : f32
+// CHECK:           %[[VL_BASE:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK:           %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK:           %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK:           %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
+// CHECK:           %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
+// CHECK:           "llvm.intr.vp.reduce.fmin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (f32, vector<[16]xf32>, vector<[16]xi1>, i32) -> f32
+
+// -----
+
 func.func @masked_reduce_maxf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
   %0 = vector.mask %mask { vector.reduction <maxnumf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
   return %0 : f32
@@ -235,6 +253,24 @@ func.func @masked_reduce_minui_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -
 
 // -----
 
+func.func @masked_reduce_minui_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <minui>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_minui_i8_scalable(
+// CHECK-SAME:                               %[[INPUT:.*]]: vector<[32]xi8>,
+// CHECK-SAME:                               %[[MASK:.*]]: vector<[32]xi1>) -> i8 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(-1 : i8) : i8
+// CHECK:           %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK:           %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK:           %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK:           %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
+// CHECK:           %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
+// CHECK:           "llvm.intr.vp.reduce.umin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
+
+// -----
+
 func.func @masked_reduce_maxui_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
   %0 = vector.mask %mask { vector.reduction <maxui>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
   return %0 : i8
@@ -277,6 +313,24 @@ func.func @masked_reduce_maxsi_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -
 
 // -----
 
+func.func @masked_reduce_maxsi_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <maxsi>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_maxsi_i8_scalable(
+// CHECK-SAME:                               %[[INPUT:.*]]: vector<[32]xi8>,
+// CHECK-SAME:                               %[[MASK:.*]]: vector<[32]xi1>) -> i8 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(-128 : i8) : i8
+// CHECK:           %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK:           %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK:           %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK:           %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
+// CHECK:           %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
+// CHECK:           "llvm.intr.vp.reduce.smax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
+
+// -----
+
 func.func @masked_reduce_or_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
   %0 = vector.mask %mask { vector.reduction <or>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
   return %0 : i8
@@ -318,4 +372,22 @@ func.func @masked_reduce_xor_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) ->
 // CHECK:           %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
 // CHECK:           "llvm.intr.vp.reduce.xor"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
 
+// -----
+
+func.func @masked_reduce_xor_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <xor>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_xor_i8_scalable(
+// CHECK-SAME:                             %[[INPUT:.*]]: vector<[32]xi8>,
+// CHECK-SAME:                             %[[MASK:.*]]: vector<[32]xi1>) -> i8 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK:           %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK:           %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK:           %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK:           %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
+// CHECK:           %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
+// CHECK:           "llvm.intr.vp.reduce.xor"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
+
 
