[Mlir-commits] [mlir] 4a411eb - [MLIR] Fix rewrite of ops with vector operands to LLVM on GPU (#127844)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 19 11:52:05 PST 2025


Author: Benoit Jacob
Date: 2025-02-19T14:52:02-05:00
New Revision: 4a411eb4ee673e2687d38fda16d6db6b907f37d2

URL: https://github.com/llvm/llvm-project/commit/4a411eb4ee673e2687d38fda16d6db6b907f37d2
DIFF: https://github.com/llvm/llvm-project/commit/4a411eb4ee673e2687d38fda16d6db6b907f37d2.diff

LOG: [MLIR] Fix rewrite of ops with vector operands to LLVM on GPU (#127844)

There was a discrepancy between the type-converter and rewrite-pattern
parts of the conversion to LLVM used by various GPU targets, at least ROCDL
and NVVM:
- The TypeConverter part was handling vectors of arbitrary rank,
converting them to nests of `!llvm.array< ... >` with a vector at the
inner-most dimension:
https://github.com/llvm/llvm-project/blob/8337d01e3058e7f47675f5b2b908b4e7821895d7/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp#L629-L655
- The rewrite pattern part was not handling `llvm.array`:
https://github.com/llvm/llvm-project/blob/8337d01e3058e7f47675f5b2b908b4e7821895d7/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp#L594-L596

That led to conversion failures when lowering `math` dialect ops on
rank-2 vectors, as in the testcase being added in this PR.

This PR fixes this by reusing a shared utility,
`LLVM::detail::handleMultidimensionalVectors`, already used in other
conversions to LLVM:

https://github.com/llvm/llvm-project/blob/8337d01e3058e7f47675f5b2b908b4e7821895d7/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp#L80-L104

---------

Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
    mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index cfa434699cdef..c3b3a78abe7f7 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -9,6 +9,7 @@
 #include "GPUOpsLowering.h"
 
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -586,22 +587,15 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
   return success();
 }
 
-/// Unrolls op if it's operating on vectors.
-LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
-                                      ConversionPatternRewriter &rewriter,
-                                      const LLVMTypeConverter &converter) {
+/// Helper for impl::scalarizeVectorOp. Scalarizes vectors to elements.
+/// Used either directly (for ops on 1D vectors) or as the callback passed to
+/// detail::handleMultidimensionalVectors (for ops on higher-rank vectors).
+static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands,
+                                     Type llvm1DVectorTy,
+                                     ConversionPatternRewriter &rewriter,
+                                     const LLVMTypeConverter &converter) {
   TypeRange operandTypes(operands);
-  if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
-    return rewriter.notifyMatchFailure(op, "expected vector operand");
-  }
-  if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
-    return rewriter.notifyMatchFailure(op, "expected no region/successor");
-  if (op->getNumResults() != 1)
-    return rewriter.notifyMatchFailure(op, "expected single result");
-  VectorType vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
-  if (!vectorType)
-    return rewriter.notifyMatchFailure(op, "expected vector result");
-
+  VectorType vectorType = cast<VectorType>(llvm1DVectorTy);
   Location loc = op->getLoc();
   Value result = rewriter.create<LLVM::PoisonOp>(loc, vectorType);
   Type indexType = converter.convertType(rewriter.getIndexType());
@@ -621,9 +615,32 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
     result = rewriter.create<LLVM::InsertElementOp>(
         loc, result, scalarOp->getResult(0), index);
   }
+  return result;
+}
 
-  rewriter.replaceOp(op, result);
-  return success();
+/// Unrolls op to array/vector elements.
+LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
+                                      ConversionPatternRewriter &rewriter,
+                                      const LLVMTypeConverter &converter) {
+  TypeRange operandTypes(operands);
+  if (llvm::any_of(operandTypes, llvm::IsaPred<VectorType>)) {
+    VectorType vectorType = cast<VectorType>(op->getResultTypes()[0]);
+    rewriter.replaceOp(op, scalarizeVectorOpHelper(op, operands, vectorType,
+                                                   rewriter, converter));
+    return success();
+  }
+
+  if (llvm::any_of(operandTypes, llvm::IsaPred<LLVM::LLVMArrayType>)) {
+    return LLVM::detail::handleMultidimensionalVectors(
+        op, operands, converter,
+        [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
+          return scalarizeVectorOpHelper(op, operands, llvm1DVectorTy, rewriter,
+                                         converter);
+        },
+        rewriter);
+  }
+
+  return rewriter.notifyMatchFailure(op, "no llvm.array or vector to unroll");
 }
 
 static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index e73a74845d2b6..bd2fd020f684b 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -172,13 +172,13 @@ struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {
 };
 
 namespace impl {
-/// Unrolls op if it's operating on vectors.
+/// Unrolls op to array/vector elements.
 LogicalResult scalarizeVectorOp(Operation *op, ValueRange operands,
                                 ConversionPatternRewriter &rewriter,
                                 const LLVMTypeConverter &converter);
 } // namespace impl
 
-/// Rewriting that unrolls SourceOp to scalars if it's operating on vectors.
+/// Unrolls SourceOp to array/vector elements.
 template <typename SourceOp>
 struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
 public:
@@ -191,6 +191,7 @@ struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
                                    *this->getTypeConverter());
   }
 };
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_

diff  --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index e4b2f01d6544a..9448304f11dbd 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -513,3 +513,54 @@ module {
     "test.possible_terminator"() : () -> ()
   }) : () -> ()
 }
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
+  // CHECK-LABEL: func @math_sin_vector_1d
+  func.func @math_sin_vector_1d(%arg : vector<4xf16>) -> vector<4xf16> {
+    // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+    // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+    // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+    // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.extractelement {{.*}} : vector<4xf16>
+    // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+    // CHECK: llvm.insertelement {{.*}} : vector<4xf16>
+    %result = math.sin %arg : vector<4xf16>
+    func.return %result : vector<4xf16>
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
+  // CHECK-LABEL: func @math_sin_vector_2d
+  func.func @math_sin_vector_2d(%arg : vector<2x2xf16>) -> vector<2x2xf16> {
+    // CHECK: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+    // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+    // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+    // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+    // CHECK: llvm.insertelement {{.*}} : vector<2xf16>
+    // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>    
+    %result = math.sin %arg : vector<2x2xf16>
+    func.return %result : vector<2x2xf16>
+  }
+}


        


More information about the Mlir-commits mailing list