[Mlir-commits] [mlir] b251b60 - [mlir][gpu] Unroll ops on vectors which map to intrinsic calls
Christian Sigg
llvmlistbot at llvm.org
Fri Oct 28 01:33:47 PDT 2022
Author: Christian Sigg
Date: 2022-10-28T10:33:38+02:00
New Revision: b251b608b5fc7c859bc73f0cb1b8cc16a626fecc
URL: https://github.com/llvm/llvm-project/commit/b251b608b5fc7c859bc73f0cb1b8cc16a626fecc
DIFF: https://github.com/llvm/llvm-project/commit/b251b608b5fc7c859bc73f0cb1b8cc16a626fecc.diff
LOG: [mlir][gpu] Unroll ops on vectors which map to intrinsic calls
Unroll ops that map to intrinsics when lowering to LLVM, because intrinsics don't support vector operands/results.
Reviewed By: herhut
Differential Revision: https://reviews.llvm.org/D136345
Added:
Modified:
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index ca30af169ffd..40fb8e25d312 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -9,6 +9,7 @@
#include "GPUOpsLowering.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
@@ -355,3 +356,45 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
rewriter.eraseOp(gpuPrintfOp);
return success();
}
+
+/// Unrolls op if it's operating on vectors.
+LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
+ ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &converter) {
+ TypeRange operandTypes(operands);
+ if (llvm::none_of(operandTypes,
+ [](Type type) { return type.isa<VectorType>(); })) {
+ return rewriter.notifyMatchFailure(op, "expected vector operand");
+ }
+ if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
+ return rewriter.notifyMatchFailure(op, "expected no region/successor");
+ if (op->getNumResults() != 1)
+ return rewriter.notifyMatchFailure(op, "expected single result");
+ VectorType vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
+ if (!vectorType)
+ return rewriter.notifyMatchFailure(op, "expected vector result");
+
+ Location loc = op->getLoc();
+ Value result = rewriter.create<LLVM::UndefOp>(loc, vectorType);
+ Type indexType = converter.convertType(rewriter.getIndexType());
+ StringAttr name = op->getName().getIdentifier();
+ Type elementType = vectorType.getElementType();
+
+ for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
+ Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i);
+ auto extractElement = [&](Value operand) -> Value {
+ if (!operand.getType().isa<VectorType>())
+ return operand;
+ return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index);
+ };
+ auto scalarOperands =
+ llvm::to_vector(llvm::map_range(operands, extractElement));
+ Operation *scalarOp =
+ rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
+ result = rewriter.create<LLVM::InsertElementOp>(loc, result,
+ scalarOp->getResult(0), index);
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+}
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 58d8ceae2404..17a3e9ff4723 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -78,6 +78,27 @@ struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {
}
};
+namespace impl {
+/// Unrolls op if it's operating on vectors.
+LogicalResult scalarizeVectorOp(Operation *op, ValueRange operands,
+ ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &converter);
+} // namespace impl
+
+/// Rewriting that unrolls SourceOp to scalars if it's operating on vectors.
+template <typename SourceOp>
+struct ScalarizeVectorOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
+public:
+ using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ return impl::scalarizeVectorOp(op, adaptor.getOperands(), rewriter,
+ *this->getTypeConverter());
+ }
+};
+
} // namespace mlir
#endif // MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index e5bc16ce54e1..6c3e8ab4c3b3 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -20,7 +20,6 @@
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -28,10 +27,8 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/FormatVariadic.h"
#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
@@ -231,6 +228,14 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
}
+template <typename OpTy>
+static void populateOpPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns, StringRef f32Func,
+ StringRef f64Func) {
+ patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+}
+
void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
populateWithGenerated(patterns);
@@ -254,42 +259,38 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
StringAttr::get(&converter.getContext(),
NVVM::NVVMDialect::getKernelFuncAttrName()));
- patterns.add<OpToFuncCallLowering<math::AbsFOp>>(converter, "__nv_fabsf",
- "__nv_fabs");
- patterns.add<OpToFuncCallLowering<math::AtanOp>>(converter, "__nv_atanf",
- "__nv_atan");
- patterns.add<OpToFuncCallLowering<math::Atan2Op>>(converter, "__nv_atan2f",
- "__nv_atan2");
- patterns.add<OpToFuncCallLowering<math::CeilOp>>(converter, "__nv_ceilf",
- "__nv_ceil");
- patterns.add<OpToFuncCallLowering<math::CosOp>>(converter, "__nv_cosf",
- "__nv_cos");
- patterns.add<OpToFuncCallLowering<math::ExpOp>>(converter, "__nv_expf",
- "__nv_exp");
- patterns.add<OpToFuncCallLowering<math::Exp2Op>>(converter, "__nv_exp2f",
- "__nv_exp2");
- patterns.add<OpToFuncCallLowering<math::ExpM1Op>>(converter, "__nv_expm1f",
- "__nv_expm1");
- patterns.add<OpToFuncCallLowering<math::FloorOp>>(converter, "__nv_floorf",
- "__nv_floor");
- patterns.add<OpToFuncCallLowering<math::LogOp>>(converter, "__nv_logf",
- "__nv_log");
- patterns.add<OpToFuncCallLowering<math::Log1pOp>>(converter, "__nv_log1pf",
- "__nv_log1p");
- patterns.add<OpToFuncCallLowering<math::Log10Op>>(converter, "__nv_log10f",
- "__nv_log10");
- patterns.add<OpToFuncCallLowering<math::Log2Op>>(converter, "__nv_log2f",
- "__nv_log2");
- patterns.add<OpToFuncCallLowering<math::PowFOp>>(converter, "__nv_powf",
- "__nv_pow");
- patterns.add<OpToFuncCallLowering<math::RsqrtOp>>(converter, "__nv_rsqrtf",
- "__nv_rsqrt");
- patterns.add<OpToFuncCallLowering<math::SinOp>>(converter, "__nv_sinf",
- "__nv_sin");
- patterns.add<OpToFuncCallLowering<math::SqrtOp>>(converter, "__nv_sqrtf",
- "__nv_sqrt");
- patterns.add<OpToFuncCallLowering<math::TanhOp>>(converter, "__nv_tanhf",
- "__nv_tanh");
+ populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
+ "__nv_fabs");
+ populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
+ "__nv_atan");
+ populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
+ "__nv_atan2");
+ populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
+ "__nv_ceil");
+ populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos");
+ populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp");
+ populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
+ "__nv_exp2");
+ populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
+ "__nv_expm1");
+ populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
+ "__nv_floor");
+ populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log");
+ populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
+ "__nv_log1p");
+ populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
+ "__nv_log10");
+ populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
+ "__nv_log2");
+ populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf",
+ "__nv_pow");
+ populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
+ "__nv_rsqrt");
+ populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin");
+ populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
+ "__nv_sqrt");
+ populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
+ "__nv_tanh");
}
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 3a9c816f2fa1..1f8159017ee9 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -157,6 +157,14 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
}
+template <typename OpTy>
+static void populateOpPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns, StringRef f32Func,
+ StringRef f64Func) {
+ patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+}
+
void mlir::populateGpuToROCDLConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
mlir::gpu::amd::Runtime runtime) {
@@ -184,42 +192,42 @@ void mlir::populateGpuToROCDLConversionPatterns(
patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4);
}
- patterns.add<OpToFuncCallLowering<math::AbsFOp>>(converter, "__ocml_fabs_f32",
- "__ocml_fabs_f64");
- patterns.add<OpToFuncCallLowering<math::AtanOp>>(converter, "__ocml_atan_f32",
- "__ocml_atan_f64");
- patterns.add<OpToFuncCallLowering<math::Atan2Op>>(
- converter, "__ocml_atan2_f32", "__ocml_atan2_f64");
- patterns.add<OpToFuncCallLowering<math::CeilOp>>(converter, "__ocml_ceil_f32",
- "__ocml_ceil_f64");
- patterns.add<OpToFuncCallLowering<math::CosOp>>(converter, "__ocml_cos_f32",
- "__ocml_cos_f64");
- patterns.add<OpToFuncCallLowering<math::ExpOp>>(converter, "__ocml_exp_f32",
- "__ocml_exp_f64");
- patterns.add<OpToFuncCallLowering<math::Exp2Op>>(converter, "__ocml_exp2_f32",
- "__ocml_exp2_f64");
- patterns.add<OpToFuncCallLowering<math::ExpM1Op>>(
- converter, "__ocml_expm1_f32", "__ocml_expm1_f64");
- patterns.add<OpToFuncCallLowering<math::FloorOp>>(
- converter, "__ocml_floor_f32", "__ocml_floor_f64");
- patterns.add<OpToFuncCallLowering<math::LogOp>>(converter, "__ocml_log_f32",
- "__ocml_log_f64");
- patterns.add<OpToFuncCallLowering<math::Log10Op>>(
- converter, "__ocml_log10_f32", "__ocml_log10_f64");
- patterns.add<OpToFuncCallLowering<math::Log1pOp>>(
- converter, "__ocml_log1p_f32", "__ocml_log1p_f64");
- patterns.add<OpToFuncCallLowering<math::Log2Op>>(converter, "__ocml_log2_f32",
- "__ocml_log2_f64");
- patterns.add<OpToFuncCallLowering<math::PowFOp>>(converter, "__ocml_pow_f32",
- "__ocml_pow_f64");
- patterns.add<OpToFuncCallLowering<math::RsqrtOp>>(
- converter, "__ocml_rsqrt_f32", "__ocml_rsqrt_f64");
- patterns.add<OpToFuncCallLowering<math::SinOp>>(converter, "__ocml_sin_f32",
- "__ocml_sin_f64");
- patterns.add<OpToFuncCallLowering<math::SqrtOp>>(converter, "__ocml_sqrt_f32",
- "__ocml_sqrt_f64");
- patterns.add<OpToFuncCallLowering<math::TanhOp>>(converter, "__ocml_tanh_f32",
- "__ocml_tanh_f64");
+ populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
+ "__ocml_fabs_f64");
+ populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
+ "__ocml_atan_f64");
+ populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
+ "__ocml_atan2_f64");
+ populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
+ "__ocml_ceil_f64");
+ populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
+ "__ocml_cos_f64");
+ populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
+ "__ocml_exp_f64");
+ populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
+ "__ocml_exp2_f64");
+ populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
+ "__ocml_expm1_f64");
+ populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
+ "__ocml_floor_f64");
+ populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
+ "__ocml_log_f64");
+ populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
+ "__ocml_log10_f64");
+ populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
+ "__ocml_log1p_f64");
+ populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
+ "__ocml_log2_f64");
+ populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
+ "__ocml_pow_f64");
+ populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
+ "__ocml_rsqrt_f64");
+ populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
+ "__ocml_sin_f64");
+ populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
+ "__ocml_sqrt_f64");
+ populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
+ "__ocml_tanh_f64");
}
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index c036e1c1b6e6..4014a19f89d9 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -478,6 +478,20 @@ gpu.module @test_module {
// -----
+gpu.module @test_module {
+ // CHECK-LABEL: func @gpu_unroll
+ func.func @gpu_unroll(%arg0 : vector<4xf32>) -> vector<4xf32> {
+ %result = math.exp %arg0 : vector<4xf32>
+ // CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
+ // CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
+ // CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
+ // CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
+ func.return %result : vector<4xf32>
+ }
+}
+
+// -----
+
gpu.module @test_module {
// CHECK-LABEL: @kernel_func
// CHECK: attributes
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index d2cab517822c..918fdcc08e90 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -377,6 +377,20 @@ gpu.module @test_module {
// -----
+gpu.module @test_module {
+ // CHECK-LABEL: func @gpu_unroll
+ func.func @gpu_unroll(%arg0 : vector<4xf32>) -> vector<4xf32> {
+ %result = math.exp %arg0 : vector<4xf32>
+ // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+ // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+ // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+ // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+ func.return %result : vector<4xf32>
+ }
+}
+
+// -----
+
gpu.module @test_module {
// CHECK-LABEL: @kernel_func
// CHECK: attributes
More information about the Mlir-commits
mailing list