[Mlir-commits] [mlir] [mlir][amdgpu] Add conversion from arith.scaling_extf to amdgpu (PR #146372)

Tim Gymnich llvmlistbot at llvm.org
Tue Jul 1 04:03:48 PDT 2025


================
@@ -395,6 +424,242 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   return success();
 }
 
+/// Walks up the use-def chain of `value` to find the value it was reshaped
+/// or broadcast from.
+///
+/// - `vector.shape_cast` ops are looked through repeatedly.
+/// - A single `vector.broadcast` or `vector.splat` is stepped through to its
+///   source/input, after which the walk stops; any shape_casts feeding that
+///   source are NOT peeled.
+/// - The walk also stops at any other defining op, or when the current value
+///   has no defining op (e.g. a block argument).
+///
+/// The returned value is not guaranteed to have a VectorType (e.g. the scalar
+/// operand of a splat), so callers should check its type before using it as
+/// a vector.
+static Value getOriginalVectorValue(Value value) {
+  Value current = value;
+  while (Operation *definingOp = current.getDefiningOp()) {
+    // `skipOp` is true when the walk should continue past the op just
+    // visited. Broadcast/splat still update `current`, but end the walk.
+    bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
+                      .Case<vector::ShapeCastOp>([&current](auto op) {
+                        current = op.getSource();
+                        return true;
+                      })
+                      .Case<vector::BroadcastOp>([&current](auto op) {
+                        current = op.getSource();
+                        return false;
+                      })
+                      .Case<vector::SplatOp>([&current](auto op) {
+                        current = op.getInput();
+                        return false;
+                      })
+                      .Default([](Operation *) { return false; });
+
+    if (!skipOp) {
+      break;
+    }
+  }
+  return current;
+}
+
+LogicalResult
+ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
+                                           PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  constexpr const int64_t opWidth = 2;
+
+  Value in = op.getIn();
+  Value scale = op.getScale();
+  Value out = op.getOut();
+
+  Type f32 = rewriter.getF32Type();
+  Type inType = getElementTypeOrSelf(in);
+  Type scaleType = getElementTypeOrSelf(scale);
+  Type outType = getElementTypeOrSelf(out);
+  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
+  VectorType inVecType = dyn_cast<VectorType>(in.getType());
+  VectorType outVecType = dyn_cast<VectorType>(out.getType());
+
+  if (outVecType && outVecType.isScalable())
+    return failure();
+
+  Type scaleF32Type =
+      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
+  if (scaleType.getIntOrFloatBitWidth() < 32)
+    scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+  else if (scaleType.getIntOrFloatBitWidth() > 32)
+    scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+
+  VectorType extScaleResultType = VectorType::get(opWidth, outType);
+
+  if (!outVecType) {
+    Value inCast =
+        rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+    // TODO: replace this with non-packed ScaledExtOp
+    Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
+        loc, extScaleResultType, inCast, scale, 0);
+    scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
+    return success();
+  }
+
+  Value origScale = getOriginalVectorValue(scale);
+  Type origScaleType = origScale.getType();
+  VectorType origScaleVecType = isa<VectorType>(origScaleType)
+                                    ? cast<VectorType>(origScaleType)
+                                    : VectorType::get(1, origScaleType);
+
+  ArrayRef<int64_t> originalScaleShape = origScaleVecType.getShape();
+  ArrayRef<int64_t> inShape = inVecType.getShape();
+
+  SmallVector<int64_t> paddedScaleShape(originalScaleShape);
+  paddedScaleShape.insert(paddedScaleShape.end(),
+                          inShape.size() - originalScaleShape.size(), 1);
+
+  auto ratio = computeShapeRatio(inShape, paddedScaleShape);
+  if (!ratio)
+    return failure();
+
+  const int64_t blockSize = computeProduct(*ratio);
+
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, outType, rewriter.getFloatAttr(outType, 0.0));
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+
+  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, *ratio)) {
----------------
tgymnich wrote:

This vector is created in place on the stack.

https://github.com/llvm/llvm-project/pull/146372


More information about the Mlir-commits mailing list