[Mlir-commits] [mlir] [mlir][amdgpu] Add support for multi-dim arith.truncf/extf fp8 lowering (PR #98074)

Rob Suderman llvmlistbot at llvm.org
Mon Jul 8 16:26:23 PDT 2024


================
@@ -214,18 +224,27 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
   int64_t numElements = outType.getNumElements();
   Value zero = rewriter.create<arith::ConstantOp>(
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
   if (outType.getShape().empty()) {
     Value scalarIn =
         rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
     // Recurse to send the 0-D vector case to the 1-D vector case
     Value scalarTrunc =
         rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
-    result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
-                                               ArrayRef<int64_t>{});
+    Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
+                                                     ArrayRef<int64_t>{});
     return rewriter.replaceOp(op, result);
   }
 
+  VectorType flatTy =
+      VectorType::get(SmallVector<int64_t>{numElements}, outElemType);
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+
+  if (inVectorTy.getShape().size() > 1) {
+    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
----------------
rsuderman wrote:

Fixed.

https://github.com/llvm/llvm-project/pull/98074


More information about the Mlir-commits mailing list