[Mlir-commits] [mlir] [MLIR][ROCDL] Add math.clampf -> rocdl.fmed3 conversion (PR #163259)

Krzysztof Drewniak llvmlistbot at llvm.org
Tue Oct 14 08:57:03 PDT 2025


================
@@ -42,8 +44,65 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
                                            f32ApproxFunc, f16Func);
 }
 
+/// Lowers `math.clampf` to `rocdl.fmed3`.
+///
+/// Only f16 and f32 element types (scalars or vectors of them) are matched;
+/// anything else is reported as a match failure so other patterns may apply.
+/// Multi-dimensional vectors, which the LLVM type converter turns into LLVM
+/// arrays, are unrolled into one `fmed3` per 1-D vector slice.
+///
+/// NOTE(review): this relies on fmed3 (median of three) agreeing with
+/// clamp(value, min, max); confirm the NaN and `min > max` behavior matches
+/// the `math.clampf` semantics before relying on it for those inputs.
+struct ClampFOpConversion final
+    : public ConvertOpToLLVMPattern<math::ClampFOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+  ClampFOpConversion(const LLVMTypeConverter &converter,
+                     amdgpu::Chipset chipset)
+      : ConvertOpToLLVMPattern<math::ClampFOp>(converter), chipset(chipset) {}
+
+  LogicalResult
+  matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // fmed3 only exists for f16 and f32, so inspect the (vector) element type
+    // before doing any rewriting.
+    Type opTy = op.getType();
+    Type elementTy = opTy;
+    if (auto vectorType = dyn_cast<VectorType>(opTy))
+      elementTy = vectorType.getElementType();
+    if (!elementTy.isF16() && !elementTy.isF32())
+      return rewriter.notifyMatchFailure(
+          op, "fmed3 only supports f16 and f32 types");
+
+    // A null type means the converter rejected `opTy`; bail out instead of
+    // handing a null type to isa/dyn_cast (which assert on null).
+    Type resultType = getTypeConverter()->convertType(opTy);
+    if (!resultType)
+      return rewriter.notifyMatchFailure(op, "failed to convert result type");
+
+    // Multi-dimensional vectors are converted to LLVM arrays of 1-D vectors;
+    // emit one fmed3 per 1-D slice.
+    if (isa<LLVM::LLVMArrayType>(resultType)) {
+      return LLVM::detail::handleMultidimensionalVectors(
+          op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+          [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
+            // Named differently from the outer `adaptor` to avoid shadowing.
+            math::ClampFOp::Adaptor sliceAdaptor(operands);
+            return rewriter.create<ROCDL::FMed3Op>(
+                op.getLoc(), llvm1DVectorTy, sliceAdaptor.getValue(),
+                sliceAdaptor.getMin(), sliceAdaptor.getMax());
+          },
+          rewriter);
+    }
+
+    // Scalars and 1-D vectors map directly. Use the adaptor's (already
+    // converted) operands and the converted result type: the original op's
+    // operands are pre-conversion values and must not feed new ops.
+    rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, resultType,
+                                                adaptor.getValue(),
+                                                adaptor.getMin(),
+                                                adaptor.getMax());
+    return success();
+  }
+
+  /// Target chipset this pattern was created for.
+  /// NOTE(review): stored but not consulted by matchAndRewrite.
+  amdgpu::Chipset chipset;
+};
+
+void addChipsetDependentPatterns(const LLVMTypeConverter &converter,
----------------
krzysz00 wrote:

```suggestion
static void addChipsetDependentPatterns(const LLVMTypeConverter &converter,
```

https://github.com/llvm/llvm-project/pull/163259


More information about the Mlir-commits mailing list