[Mlir-commits] [mlir] [MLIR] Fix rewrite of ops with vector operands to LLVM on GPU (PR #127844)

Benoit Jacob llvmlistbot at llvm.org
Wed Feb 19 11:50:38 PST 2025


https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/127844

>From d5f407ebaecdfaee6d4e7bc15bc9f18c44f66a7c Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Wed, 19 Feb 2025 09:59:23 -0600
Subject: [PATCH 1/2] [MLIR] Unroll GPU vector ops whose operands lower to !llvm.array

Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
---
 .../Conversion/GPUCommon/GPUOpsLowering.cpp   | 51 ++++++++++++-------
 .../lib/Conversion/GPUCommon/GPUOpsLowering.h |  5 +-
 .../Conversion/MathToROCDL/math-to-rocdl.mlir | 41 +++++++++++++++
 3 files changed, 78 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index cfa434699cdef..c3b3a78abe7f7 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -9,6 +9,7 @@
 #include "GPUOpsLowering.h"
 
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -586,22 +587,15 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
   return success();
 }
 
-/// Unrolls op if it's operating on vectors.
-LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
-                                      ConversionPatternRewriter &rewriter,
-                                      const LLVMTypeConverter &converter) {
+/// Helper for impl::scalarizeVectorOp. Scalarizes vectors to elements.
+/// Used either directly (for ops on 1D vectors) or as the callback passed to
+/// detail::handleMultidimensionalVectors (for ops on higher-rank vectors).
+static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands,
+                                     Type llvm1DVectorTy,
+                                     ConversionPatternRewriter &rewriter,
+                                     const LLVMTypeConverter &converter) {
   TypeRange operandTypes(operands);
-  if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
-    return rewriter.notifyMatchFailure(op, "expected vector operand");
-  }
-  if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
-    return rewriter.notifyMatchFailure(op, "expected no region/successor");
-  if (op->getNumResults() != 1)
-    return rewriter.notifyMatchFailure(op, "expected single result");
-  VectorType vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
-  if (!vectorType)
-    return rewriter.notifyMatchFailure(op, "expected vector result");
-
+  VectorType vectorType = cast<VectorType>(llvm1DVectorTy);
   Location loc = op->getLoc();
   Value result = rewriter.create<LLVM::PoisonOp>(loc, vectorType);
   Type indexType = converter.convertType(rewriter.getIndexType());
@@ -621,9 +615,49 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
     result = rewriter.create<LLVM::InsertElementOp>(
         loc, result, scalarOp->getResult(0), index);
   }
+  return result;
+}
 
-  rewriter.replaceOp(op, result);
-  return success();
+/// Unrolls op into one scalar op per vector element.
+///
+/// Operands arrive already type-converted: a 1-D vector stays a vector and is
+/// scalarized directly, while a higher-rank vector arrives as an !llvm.array
+/// of 1-D vectors, which LLVM::detail::handleMultidimensionalVectors splits
+/// into 1-D slices that are each scalarized. Ops with regions/successors, or
+/// without exactly one vector result, are rejected with a match failure.
+LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
+                                      ConversionPatternRewriter &rewriter,
+                                      const LLVMTypeConverter &converter) {
+  // Reject unsupported ops up front: the code below assumes a single
+  // vector-typed result and would otherwise fail a cast<VectorType> assertion.
+  if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
+    return rewriter.notifyMatchFailure(op, "expected no region/successor");
+  if (op->getNumResults() != 1)
+    return rewriter.notifyMatchFailure(op, "expected single result");
+  auto vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
+  if (!vectorType)
+    return rewriter.notifyMatchFailure(op, "expected vector result");
+
+  TypeRange operandTypes(operands);
+  // 1-D case: at least one operand is a vector; unroll over the result type.
+  if (llvm::any_of(operandTypes, llvm::IsaPred<VectorType>)) {
+    rewriter.replaceOp(op, scalarizeVectorOpHelper(op, operands, vectorType,
+                                                   rewriter, converter));
+    return success();
+  }
+
+  // N-D case: operands are !llvm.array of 1-D vectors; unroll each 1-D slice.
+  if (llvm::any_of(operandTypes, llvm::IsaPred<LLVM::LLVMArrayType>)) {
+    return LLVM::detail::handleMultidimensionalVectors(
+        op, operands, converter,
+        [&](Type llvm1DVectorTy, ValueRange sliceOperands) -> Value {
+          return scalarizeVectorOpHelper(op, sliceOperands, llvm1DVectorTy,
+                                         rewriter, converter);
+        },
+        rewriter);
+  }
+
+  return rewriter.notifyMatchFailure(op, "no llvm.array or vector to unroll");
 }
 
 static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index e73a74845d2b6..bd2fd020f684b 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -172,13 +172,13 @@ struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {
 };
 
 namespace impl {
-/// Unrolls op if it's operating on vectors.
+/// Unrolls op into scalar ops on the elements of its vector/array operands.
 LogicalResult scalarizeVectorOp(Operation *op, ValueRange operands,
                                 ConversionPatternRewriter &rewriter,
                                 const LLVMTypeConverter &converter);
 } // namespace impl
 
-/// Rewriting that unrolls SourceOp to scalars if it's operating on vectors.
+/// Unrolls SourceOp to array/vector elements.
 template <typename SourceOp>
 struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
 public:
@@ -191,6 +191,7 @@ struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
                                    *this->getTypeConverter());
   }
 };
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index e4b2f01d6544a..b6493ca9b32c3 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -513,3 +513,44 @@ module {
     "test.possible_terminator"() : () -> ()
   }) : () -> ()
 }
+
+// -----
+
+module @test_module {
+  // CHECK-LABEL: func @math_sin_vector_1d
+  func.func @math_sin_vector_1d(%arg : vector<4xf16>) -> vector<4xf16> {
+    // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+    %result = math.sin %arg : vector<4xf16>
+    func.return %result : vector<4xf16>
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK-LABEL: func @math_sin_vector_2d
+  func.func @math_sin_vector_2d(%arg : vector<2x2xf16>) -> vector<2x2xf16> {
+    // CHECK: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>    
+    %result = math.sin %arg : vector<2x2xf16>
+    func.return %result : vector<2x2xf16>
+  }
+}

>From 524888966a756eecef92189ff228b758e56bf8fa Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Wed, 19 Feb 2025 13:50:21 -0600
Subject: [PATCH 2/2] [MLIR] Check __ocml_sin_f16 calls in vector math-to-rocdl tests

Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
---
 mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index b6493ca9b32c3..9448304f11dbd 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -517,15 +517,20 @@ module {
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
   // CHECK-LABEL: func @math_sin_vector_1d
   func.func @math_sin_vector_1d(%arg : vector<4xf16>) -> vector<4xf16> {
     // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
     // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
     // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
     // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
     // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
     // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
     // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
     // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
     %result = math.sin %arg : vector<4xf16>
     func.return %result : vector<4xf16>
@@ -535,19 +540,24 @@ module @test_module {
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
   // CHECK-LABEL: func @math_sin_vector_2d
   func.func @math_sin_vector_2d(%arg : vector<2x2xf16>) -> vector<2x2xf16> {
     // CHECK: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
     // CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
     // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
     // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
     // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
     // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
     // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
     // CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
     // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
     // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
     // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
     // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
     // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>    
     %result = math.sin %arg : vector<2x2xf16>



More information about the Mlir-commits mailing list