[Mlir-commits] [mlir] b251b60 - [mlir][gpu] Unroll ops on vectors which map to intrinsic calls

Christian Sigg llvmlistbot at llvm.org
Fri Oct 28 01:33:47 PDT 2022


Author: Christian Sigg
Date: 2022-10-28T10:33:38+02:00
New Revision: b251b608b5fc7c859bc73f0cb1b8cc16a626fecc

URL: https://github.com/llvm/llvm-project/commit/b251b608b5fc7c859bc73f0cb1b8cc16a626fecc
DIFF: https://github.com/llvm/llvm-project/commit/b251b608b5fc7c859bc73f0cb1b8cc16a626fecc.diff

LOG: [mlir][gpu] Unroll ops on vectors which map to intrinsic calls

Unroll ops that map to intrinsics when lowering to LLVM, because intrinsics don't support vector operands/results.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D136345

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
    mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
    mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index ca30af169ffd..40fb8e25d312 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -9,6 +9,7 @@
 #include "GPUOpsLowering.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 
@@ -355,3 +356,47 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
   rewriter.eraseOp(gpuPrintfOp);
   return success();
 }
+
+/// Unrolls op if it's operating on vectors.
+LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
+                                      ConversionPatternRewriter &rewriter,
+                                      LLVMTypeConverter &converter) {
+  TypeRange operandTypes(operands);
+  if (llvm::none_of(operandTypes,
+                    [](Type type) { return type.isa<VectorType>(); })) {
+    return rewriter.notifyMatchFailure(op, "expected vector operand");
+  }
+  if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
+    return rewriter.notifyMatchFailure(op, "expected no region/successor");
+  if (op->getNumResults() != 1)
+    return rewriter.notifyMatchFailure(op, "expected single result");
+  VectorType vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
+  if (!vectorType)
+    return rewriter.notifyMatchFailure(op, "expected vector result");
+
+  Location loc = op->getLoc();
+  Value result = rewriter.create<LLVM::UndefOp>(loc, vectorType);
+  Type indexType = converter.convertType(rewriter.getIndexType());
+  StringAttr name = op->getName().getIdentifier();
+  Type elementType = vectorType.getElementType();
+
+  for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
+    Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i);
+    auto extractElement = [&](Value operand) -> Value {
+      if (!operand.getType().isa<VectorType>())
+        return operand;
+      return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index);
+    };
+    auto scalarOperands =
+        llvm::to_vector(llvm::map_range(operands, extractElement));
+    Operation *scalarOp =
+        rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
+    // llvm.insertelement is pure: it returns a new vector instead of mutating
+    // `result` in place, so the returned value must be captured here —
+    // otherwise `result` stays the initial undef vector.
+    result = rewriter.create<LLVM::InsertElementOp>(
+        loc, result, scalarOp->getResult(0), index);
+  }
+
+  rewriter.replaceOp(op, result);
+  return success();
+}

diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 58d8ceae2404..17a3e9ff4723 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -78,6 +78,27 @@ struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {
   }
 };
 
+namespace impl {
+/// Unrolls op if it's operating on vectors.
+LogicalResult scalarizeVectorOp(Operation *op, ValueRange operands,
+                                ConversionPatternRewriter &rewriter,
+                                LLVMTypeConverter &converter);
+} // namespace impl
+
+/// Rewriting that unrolls SourceOp to scalars if it's operating on vectors.
+template <typename SourceOp>
+struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
+public:
+  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    return impl::scalarizeVectorOp(op, adaptor.getOperands(), rewriter,
+                                   *this->getTypeConverter());
+  }
+};
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_

diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index e5bc16ce54e1..6c3e8ab4c3b3 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -20,7 +20,6 @@
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -28,10 +27,8 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/FormatVariadic.h"
 
 #include "../GPUCommon/GPUOpsLowering.h"
 #include "../GPUCommon/IndexIntrinsicsOpLowering.h"
@@ -231,6 +228,14 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
   target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
 }
 
+template <typename OpTy>
+static void populateOpPatterns(LLVMTypeConverter &converter,
+                               RewritePatternSet &patterns, StringRef f32Func,
+                               StringRef f64Func) {
+  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+}
+
 void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns) {
   populateWithGenerated(patterns);
@@ -254,42 +259,38 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
       StringAttr::get(&converter.getContext(),
                       NVVM::NVVMDialect::getKernelFuncAttrName()));
 
-  patterns.add<OpToFuncCallLowering<math::AbsFOp>>(converter, "__nv_fabsf",
-                                                   "__nv_fabs");
-  patterns.add<OpToFuncCallLowering<math::AtanOp>>(converter, "__nv_atanf",
-                                                   "__nv_atan");
-  patterns.add<OpToFuncCallLowering<math::Atan2Op>>(converter, "__nv_atan2f",
-                                                    "__nv_atan2");
-  patterns.add<OpToFuncCallLowering<math::CeilOp>>(converter, "__nv_ceilf",
-                                                   "__nv_ceil");
-  patterns.add<OpToFuncCallLowering<math::CosOp>>(converter, "__nv_cosf",
-                                                  "__nv_cos");
-  patterns.add<OpToFuncCallLowering<math::ExpOp>>(converter, "__nv_expf",
-                                                  "__nv_exp");
-  patterns.add<OpToFuncCallLowering<math::Exp2Op>>(converter, "__nv_exp2f",
-                                                   "__nv_exp2");
-  patterns.add<OpToFuncCallLowering<math::ExpM1Op>>(converter, "__nv_expm1f",
-                                                    "__nv_expm1");
-  patterns.add<OpToFuncCallLowering<math::FloorOp>>(converter, "__nv_floorf",
-                                                    "__nv_floor");
-  patterns.add<OpToFuncCallLowering<math::LogOp>>(converter, "__nv_logf",
-                                                  "__nv_log");
-  patterns.add<OpToFuncCallLowering<math::Log1pOp>>(converter, "__nv_log1pf",
-                                                    "__nv_log1p");
-  patterns.add<OpToFuncCallLowering<math::Log10Op>>(converter, "__nv_log10f",
-                                                    "__nv_log10");
-  patterns.add<OpToFuncCallLowering<math::Log2Op>>(converter, "__nv_log2f",
-                                                   "__nv_log2");
-  patterns.add<OpToFuncCallLowering<math::PowFOp>>(converter, "__nv_powf",
-                                                   "__nv_pow");
-  patterns.add<OpToFuncCallLowering<math::RsqrtOp>>(converter, "__nv_rsqrtf",
-                                                    "__nv_rsqrt");
-  patterns.add<OpToFuncCallLowering<math::SinOp>>(converter, "__nv_sinf",
-                                                  "__nv_sin");
-  patterns.add<OpToFuncCallLowering<math::SqrtOp>>(converter, "__nv_sqrtf",
-                                                   "__nv_sqrt");
-  patterns.add<OpToFuncCallLowering<math::TanhOp>>(converter, "__nv_tanhf",
-                                                   "__nv_tanh");
+  populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
+                                   "__nv_fabs");
+  populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
+                                   "__nv_atan");
+  populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
+                                    "__nv_atan2");
+  populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
+                                   "__nv_ceil");
+  populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos");
+  populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp");
+  populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
+                                   "__nv_exp2");
+  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
+                                    "__nv_expm1");
+  populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
+                                    "__nv_floor");
+  populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log");
+  populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
+                                    "__nv_log1p");
+  populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
+                                    "__nv_log10");
+  populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
+                                   "__nv_log2");
+  populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf",
+                                   "__nv_pow");
+  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
+                                    "__nv_rsqrt");
+  populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin");
+  populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
+                                   "__nv_sqrt");
+  populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
+                                   "__nv_tanh");
 }
 
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>

diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 3a9c816f2fa1..1f8159017ee9 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -157,6 +157,14 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
   target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
 }
 
+template <typename OpTy>
+static void populateOpPatterns(LLVMTypeConverter &converter,
+                               RewritePatternSet &patterns, StringRef f32Func,
+                               StringRef f64Func) {
+  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+}
+
 void mlir::populateGpuToROCDLConversionPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns,
     mlir::gpu::amd::Runtime runtime) {
@@ -184,42 +192,42 @@ void mlir::populateGpuToROCDLConversionPatterns(
     patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4);
   }
 
-  patterns.add<OpToFuncCallLowering<math::AbsFOp>>(converter, "__ocml_fabs_f32",
-                                                   "__ocml_fabs_f64");
-  patterns.add<OpToFuncCallLowering<math::AtanOp>>(converter, "__ocml_atan_f32",
-                                                   "__ocml_atan_f64");
-  patterns.add<OpToFuncCallLowering<math::Atan2Op>>(
-      converter, "__ocml_atan2_f32", "__ocml_atan2_f64");
-  patterns.add<OpToFuncCallLowering<math::CeilOp>>(converter, "__ocml_ceil_f32",
-                                                   "__ocml_ceil_f64");
-  patterns.add<OpToFuncCallLowering<math::CosOp>>(converter, "__ocml_cos_f32",
-                                                  "__ocml_cos_f64");
-  patterns.add<OpToFuncCallLowering<math::ExpOp>>(converter, "__ocml_exp_f32",
-                                                  "__ocml_exp_f64");
-  patterns.add<OpToFuncCallLowering<math::Exp2Op>>(converter, "__ocml_exp2_f32",
-                                                   "__ocml_exp2_f64");
-  patterns.add<OpToFuncCallLowering<math::ExpM1Op>>(
-      converter, "__ocml_expm1_f32", "__ocml_expm1_f64");
-  patterns.add<OpToFuncCallLowering<math::FloorOp>>(
-      converter, "__ocml_floor_f32", "__ocml_floor_f64");
-  patterns.add<OpToFuncCallLowering<math::LogOp>>(converter, "__ocml_log_f32",
-                                                  "__ocml_log_f64");
-  patterns.add<OpToFuncCallLowering<math::Log10Op>>(
-      converter, "__ocml_log10_f32", "__ocml_log10_f64");
-  patterns.add<OpToFuncCallLowering<math::Log1pOp>>(
-      converter, "__ocml_log1p_f32", "__ocml_log1p_f64");
-  patterns.add<OpToFuncCallLowering<math::Log2Op>>(converter, "__ocml_log2_f32",
-                                                   "__ocml_log2_f64");
-  patterns.add<OpToFuncCallLowering<math::PowFOp>>(converter, "__ocml_pow_f32",
-                                                   "__ocml_pow_f64");
-  patterns.add<OpToFuncCallLowering<math::RsqrtOp>>(
-      converter, "__ocml_rsqrt_f32", "__ocml_rsqrt_f64");
-  patterns.add<OpToFuncCallLowering<math::SinOp>>(converter, "__ocml_sin_f32",
-                                                  "__ocml_sin_f64");
-  patterns.add<OpToFuncCallLowering<math::SqrtOp>>(converter, "__ocml_sqrt_f32",
-                                                   "__ocml_sqrt_f64");
-  patterns.add<OpToFuncCallLowering<math::TanhOp>>(converter, "__ocml_tanh_f32",
-                                                   "__ocml_tanh_f64");
+  populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
+                                   "__ocml_fabs_f64");
+  populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
+                                   "__ocml_atan_f64");
+  populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
+                                    "__ocml_atan2_f64");
+  populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
+                                   "__ocml_ceil_f64");
+  populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
+                                  "__ocml_cos_f64");
+  populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
+                                  "__ocml_exp_f64");
+  populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
+                                   "__ocml_exp2_f64");
+  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
+                                    "__ocml_expm1_f64");
+  populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
+                                    "__ocml_floor_f64");
+  populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
+                                  "__ocml_log_f64");
+  populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
+                                    "__ocml_log10_f64");
+  populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
+                                    "__ocml_log1p_f64");
+  populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
+                                   "__ocml_log2_f64");
+  populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
+                                   "__ocml_pow_f64");
+  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
+                                    "__ocml_rsqrt_f64");
+  populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
+                                  "__ocml_sin_f64");
+  populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
+                                   "__ocml_sqrt_f64");
+  populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
+                                   "__ocml_tanh_f64");
 }
 
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>

diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index c036e1c1b6e6..4014a19f89d9 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -478,6 +478,20 @@ gpu.module @test_module {
 
 // -----
 
+gpu.module @test_module {
+  // CHECK-LABEL: func @gpu_unroll
+  func.func @gpu_unroll(%arg0 : vector<4xf32>) -> vector<4xf32> {
+    %result = math.exp %arg0 : vector<4xf32>
+    // CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
+    // CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
+    // CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
+    // CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
+    func.return %result : vector<4xf32>
+  }
+}
+
+// -----
+
 gpu.module @test_module {
   // CHECK-LABEL: @kernel_func
   // CHECK: attributes

diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index d2cab517822c..918fdcc08e90 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -377,6 +377,20 @@ gpu.module @test_module {
 
 // -----
 
+gpu.module @test_module {
+  // CHECK-LABEL: func @gpu_unroll
+  func.func @gpu_unroll(%arg0 : vector<4xf32>) -> vector<4xf32> {
+    %result = math.exp %arg0 : vector<4xf32>
+    // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    func.return %result : vector<4xf32>
+  }
+}
+
+// -----
+
 gpu.module @test_module {
   // CHECK-LABEL: @kernel_func
   // CHECK: attributes


        


More information about the Mlir-commits mailing list