[Mlir-commits] [mlir] 7889090 - [mlir][math] Propagate scalability in `convert-math-to-llvm` (#82635)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 23 01:49:02 PST 2024


Author: Benjamin Maxwell
Date: 2024-02-23T09:48:58Z
New Revision: 78890904c41cc4221839dafb7ae906971a9db51a

URL: https://github.com/llvm/llvm-project/commit/78890904c41cc4221839dafb7ae906971a9db51a
DIFF: https://github.com/llvm/llvm-project/commit/78890904c41cc4221839dafb7ae906971a9db51a.diff

LOG: [mlir][math] Propagate scalability in `convert-math-to-llvm` (#82635)

This also generally increases the coverage of scalable vector types in
the math-to-llvm tests.

Added: 
    

Modified: 
    mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
    mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 1b729611a36235..23e957288eb95e 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -148,10 +148,10 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
     return LLVM::detail::handleMultidimensionalVectors(
         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
         [&](Type llvm1DVectorTy, ValueRange operands) {
+          auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
           auto splatAttr = SplatElementsAttr::get(
-              mlir::VectorType::get(
-                  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
-                  floatType),
+              mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
+                                    {numElements.isScalable()}),
               floatOne);
           auto one =
               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
@@ -207,10 +207,10 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
     return LLVM::detail::handleMultidimensionalVectors(
         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
         [&](Type llvm1DVectorTy, ValueRange operands) {
+          auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
           auto splatAttr = SplatElementsAttr::get(
-              mlir::VectorType::get(
-                  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
-                  floatType),
+              mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
+                                    {numElements.isScalable()}),
               floatOne);
           auto one =
               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
@@ -266,10 +266,10 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
     return LLVM::detail::handleMultidimensionalVectors(
         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
         [&](Type llvm1DVectorTy, ValueRange operands) {
+          auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
           auto splatAttr = SplatElementsAttr::get(
-              mlir::VectorType::get(
-                  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
-                  floatType),
+              mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
+                                    {numElements.isScalable()}),
               floatOne);
           auto one =
               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);

diff  --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 3de2f11d1d12c7..56129dbd278892 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -77,6 +77,18 @@ func.func @log1p_2dvector_fmf(%arg0 : vector<4x3xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @log1p_scalable_vector(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @log1p_scalable_vector(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+  // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %[[VEC]]  : vector<[4]xf32>
+  // CHECK: %[[LOG:.*]] = llvm.intr.log(%[[ADD]])  : (vector<[4]xf32>) -> vector<[4]xf32>
+  %0 = math.log1p %arg0 : vector<[4]xf32>
+  func.return %0 : vector<[4]xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @expm1(
 // CHECK-SAME: f32
 func.func @expm1(%arg0 : f32) {
@@ -113,6 +125,18 @@ func.func @expm1_vector(%arg0 : vector<4xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @expm1_scalable_vector(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @expm1_scalable_vector(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+  // CHECK: %[[EXP:.*]] = llvm.intr.exp(%[[VEC]])  : (vector<[4]xf32>) -> vector<[4]xf32>
+  // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : vector<[4]xf32>
+  %0 = math.expm1 %arg0 : vector<[4]xf32>
+  func.return %0 : vector<[4]xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @expm1_vector_fmf(
 // CHECK-SAME: vector<4xf32>
 func.func @expm1_vector_fmf(%arg0 : vector<4xf32>) {
@@ -177,6 +201,16 @@ func.func @cttz_vec(%arg0 : vector<4xi32>) {
 
 // -----
 
+// CHECK-LABEL: func @cttz_scalable_vec(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xi32>
+func.func @cttz_scalable_vec(%arg0 : vector<[4]xi32>) -> vector<[4]xi32> {
+  // CHECK: "llvm.intr.cttz"(%[[VEC]]) <{is_zero_poison = false}> : (vector<[4]xi32>) -> vector<[4]xi32>
+  %0 = math.cttz %arg0 : vector<[4]xi32>
+  func.return %0 : vector<[4]xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @ctpop(
 // CHECK-SAME: i32
 func.func @ctpop(%arg0 : i32) {
@@ -197,6 +231,16 @@ func.func @ctpop_vector(%arg0 : vector<3xi32>) {
 
 // -----
 
+// CHECK-LABEL: func @ctpop_scalable_vector(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xi32>
+func.func @ctpop_scalable_vector(%arg0 : vector<[4]xi32>) -> vector<[4]xi32> {
+  // CHECK: llvm.intr.ctpop(%[[VEC]]) : (vector<[4]xi32>) -> vector<[4]xi32>
+  %0 = math.ctpop %arg0 : vector<[4]xi32>
+  func.return %0 : vector<[4]xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @rsqrt_double(
 // CHECK-SAME: f64
 func.func @rsqrt_double(%arg0 : f64) {
@@ -233,6 +277,18 @@ func.func @rsqrt_vector(%arg0 : vector<4xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @rsqrt_scalable_vector(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @rsqrt_scalable_vector(%arg0 : vector<[4]xf32>) ->  vector<[4]xf32>{
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+  // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[VEC]]) : (vector<[4]xf32>) -> vector<[4]xf32>
+  // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<[4]xf32>
+  %0 = math.rsqrt %arg0 : vector<[4]xf32>
+  func.return  %0 : vector<[4]xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @rsqrt_vector_fmf(
 // CHECK-SAME: vector<4xf32>
 func.func @rsqrt_vector_fmf(%arg0 : vector<4xf32>) {
@@ -245,6 +301,18 @@ func.func @rsqrt_vector_fmf(%arg0 : vector<4xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @rsqrt_scalable_vector_fmf(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @rsqrt_scalable_vector_fmf(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+  // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[VEC]]) {fastmathFlags = #llvm.fastmath<fast>} : (vector<[4]xf32>) -> vector<[4]xf32>
+  // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] {fastmathFlags = #llvm.fastmath<fast>} : vector<[4]xf32>
+  %0 = math.rsqrt %arg0 fastmath<fast> : vector<[4]xf32>
+  func.return %0 : vector<[4]xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @rsqrt_multidim_vector(
 func.func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) {
   // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
@@ -258,6 +326,19 @@ func.func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @rsqrt_multidim_scalable_vector(
+func.func @rsqrt_multidim_scalable_vector(%arg0 : vector<4x[4]xf32>) -> vector<4x[4]xf32> {
+  // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<[4]xf32>>
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+  // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[EXTRACT]]) : (vector<[4]xf32>) -> vector<[4]xf32>
+  // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<[4]xf32>
+  // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[DIV]], %{{.*}}[0] : !llvm.array<4 x vector<[4]xf32>>
+  %0 = math.rsqrt %arg0 : vector<4x[4]xf32>
+  func.return %0 : vector<4x[4]xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fpowi(
 // CHECK-SAME: f64
 func.func @fpowi(%arg0 : f64, %arg1 : i32) {


        


More information about the Mlir-commits mailing list