[Mlir-commits] [mlir] [mlir][amdgpu] Add support for multi-dim arith.truncf/extf fp8 lowering (PR #98074)

Rob Suderman llvmlistbot at llvm.org
Mon Jul 8 16:25:59 PDT 2024


================
@@ -81,28 +78,37 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
-  if (!isa<VectorType>(in.getType())) {
+  auto inType = dyn_cast<VectorType>(in.getType());
+  if (!inType) {
     Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
         loc, rewriter.getF32Type(), in, 0);
     Value result = castF32To(outElemType, asFloat, loc, rewriter);
     return rewriter.replaceOp(op, result);
   }
-  VectorType inType = cast<VectorType>(in.getType());
   int64_t numElements = inType.getNumElements();
   Value zero = rewriter.create<arith::ConstantOp>(
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
-  Value result =
-      rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
   if (inType.getShape().empty()) {
     Value scalarIn =
         rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
     // Recurse to send the 0-D vector case to the 1-D vector case
     Value scalarExt =
         rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
-    result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
-                                               ArrayRef<int64_t>{});
+    Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
+                                                     ArrayRef<int64_t>{});
     return rewriter.replaceOp(op, result);
   }
+
+  VectorType flatTy =
+      VectorType::get(SmallVector<int64_t>{numElements}, outElemType);
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+
+  if (inType.getShape().size() > 1) {
----------------
rsuderman wrote:

Fixed

https://github.com/llvm/llvm-project/pull/98074


More information about the Mlir-commits mailing list