[Mlir-commits] [mlir] [mlir][amdgpu] Add support for multi-dim arith.truncf/extf fp8 lowering (PR #98074)

Rob Suderman llvmlistbot at llvm.org
Mon Jul 8 16:25:55 PDT 2024


https://github.com/rsuderman updated https://github.com/llvm/llvm-project/pull/98074

>From c71df3b8206ca11a8b20177beb147041b3630f5f Mon Sep 17 00:00:00 2001
From: Rob Suderman <rob.suderman at gmail.com>
Date: Mon, 8 Jul 2024 13:30:06 -0700
Subject: [PATCH 1/2] [mlir][amdgpu] Add support for multi-dim
 arith.truncf/extf fp8 lowering

The existing `fp8` lowering from `arith` to `amdgpu` bails out on
multidimensional vectors. We can handle this case by using
`vector.shape_cast` to collapse the input to 1-D before extraction and
casting the result back to the desired output shape.
---
 .../ArithToAMDGPU/ArithToAMDGPU.cpp           | 56 +++++++++++++-----
 .../ArithToAMDGPU/8-bit-floats.mlir           | 58 +++++++++++++++++++
 2 files changed, 98 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 3d3ff001c541b..7524db4da5917 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -68,9 +68,6 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
   if (auto inVecType = dyn_cast<VectorType>(inType)) {
     if (inVecType.isScalable())
       return failure();
-    if (inVecType.getShape().size() > 1)
-      // Multi-dimensional vectors are currently unsupported.
-      return failure();
     inType = inVecType.getElementType();
   }
   return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
@@ -81,28 +78,37 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
-  if (!isa<VectorType>(in.getType())) {
+  auto inType = dyn_cast<VectorType>(in.getType());
+  if (!inType) {
     Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
         loc, rewriter.getF32Type(), in, 0);
     Value result = castF32To(outElemType, asFloat, loc, rewriter);
     return rewriter.replaceOp(op, result);
   }
-  VectorType inType = cast<VectorType>(in.getType());
   int64_t numElements = inType.getNumElements();
   Value zero = rewriter.create<arith::ConstantOp>(
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
-  Value result =
-      rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
   if (inType.getShape().empty()) {
     Value scalarIn =
         rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
     // Recurse to send the 0-D vector case to the 1-D vector case
     Value scalarExt =
         rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
-    result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
-                                               ArrayRef<int64_t>{});
+    Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
+                                                     ArrayRef<int64_t>{});
     return rewriter.replaceOp(op, result);
   }
+
+  VectorType flatTy =
+      VectorType::get(SmallVector<int64_t>{numElements}, outElemType);
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+
+  if (inType.getShape().size() > 1) {
+    inType = VectorType::get(SmallVector<int64_t>{numElements},
+                             inType.getElementType());
+    in = rewriter.create<vector::ShapeCastOp>(loc, inType, in);
+  }
+
   for (int64_t i = 0; i < numElements; i += 4) {
     int64_t elemsThisOp = std::min(numElements, i + 4) - i;
     Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
@@ -114,6 +120,12 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
       result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
     }
   }
+
+  VectorType outType = cast<VectorType>(op.getOut().getType());
+  if (inType.getShape().size() != outType.getShape().size()) {
+    result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
+  }
+
   rewriter.replaceOp(op, result);
 }
 
@@ -182,9 +194,6 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
   if (auto outVecType = dyn_cast<VectorType>(outType)) {
     if (outVecType.isScalable())
       return failure();
-    if (outVecType.getShape().size() > 1)
-      // Multi-dimensional vectors are currently unsupported.
-      return failure();
     outType = outVecType.getElementType();
   }
   auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
@@ -201,8 +210,9 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
   if (saturateFP8)
     in = clampInput(rewriter, loc, outElemType, in);
+  auto inVectorTy = dyn_cast<VectorType>(in.getType());
   VectorType truncResType = VectorType::get(4, outElemType);
-  if (!isa<VectorType>(in.getType())) {
+  if (!inVectorTy) {
     Value asFloat = castToF32(in, loc, rewriter);
     Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
         loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
@@ -214,18 +224,27 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
   int64_t numElements = outType.getNumElements();
   Value zero = rewriter.create<arith::ConstantOp>(
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
   if (outType.getShape().empty()) {
     Value scalarIn =
         rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
     // Recurse to send the 0-D vector case to the 1-D vector case
     Value scalarTrunc =
         rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
-    result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
-                                               ArrayRef<int64_t>{});
+    Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
+                                                     ArrayRef<int64_t>{});
     return rewriter.replaceOp(op, result);
   }
 
+  VectorType flatTy =
+      VectorType::get(SmallVector<int64_t>{numElements}, outElemType);
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+
+  if (inVectorTy.getShape().size() > 1) {
+    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
+                                 inVectorTy.getElementType());
+    in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
+  }
+
   for (int64_t i = 0; i < numElements; i += 4) {
     int64_t elemsThisOp = std::min(numElements, i + 4) - i;
     Value thisResult = nullptr;
@@ -246,6 +265,11 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
     result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
                                                            result, i, 1);
   }
+
+  if (inVectorTy.getShape().size() != outType.getShape().size()) {
+    result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
+  }
+
   rewriter.replaceOp(op, result);
 }
 
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
index 159a2f02f0560..32c069180ece5 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
@@ -115,3 +115,61 @@ func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf8E4M3FNUZ> {
   %w = arith.truncf %v : vector<9xf32> to vector<9xf8E4M3FNUZ>
   return %w : vector<9xf8E4M3FNUZ>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc_long
+// CHECK-SAME: ([[V:%.+]]: vector<1x9xf32>)
+// CHECK: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FNUZ>
+// CHECK: [[T0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
+// CHECK: [[T1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T0]][word 1]
+// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[T1]], [[ZEROES]] {offsets = [0], strides = [1]}
+
+// CHECK: [[T2:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
+// CHECK: [[T3:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T2]][word 1]
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[T3]], [[W0]] {offsets = [4], strides = [1]}
+
+// CHECK: [[T4:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0]
+// CHECK: [[T4_SHORT:%.+]] = vector.extract_strided_slice [[T4]] {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: [[W:%.+]] = vector.insert_strided_slice [[T4_SHORT]], [[W1]] {offsets = [8], strides = [1]}
+// CHECK: [[RE:%.+]] = vector.shape_cast [[W]] : vector<9xf8E4M3FNUZ> to vector<1x9xf8E4M3FNUZ>
+// CHECK: return [[RE]]
+func.func @vector_trunc_long(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FNUZ> {
+  %w = arith.truncf %v : vector<1x9xf32> to vector<1x9xf8E4M3FNUZ>
+  return %w : vector<1x9xf8E4M3FNUZ>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_ext_long
+// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FNUZ>)
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FNUZ> to vector<9xf8E4M3FNUZ>
+// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
+// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
+// CHECK: [[W0:%.+]] = vector.insert [[F0]]
+// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
+// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
+// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
+// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
+// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
+// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
+
+// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
+// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
+// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
+// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
+// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
+// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
+// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
+// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
+// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
+
+// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
+// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
+// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32>
+// CHECK: return [[CAST]]
+func.func @vector_ext_long(%v: vector<1x9xf8E4M3FNUZ>) -> vector<1x9xf32> {
+  %w = arith.extf %v : vector<1x9xf8E4M3FNUZ> to vector<1x9xf32>
+  return %w : vector<1x9xf32>
+}

>From 3d5d4657ddc1256649c708f931f658b6e82ad962 Mon Sep 17 00:00:00 2001
From: Rob Suderman <rob.suderman at gmail.com>
Date: Mon, 8 Jul 2024 16:25:33 -0700
Subject: [PATCH 2/2] Address review comments: use getRank() and ShapedType::clone

---
 .../ArithToAMDGPU/ArithToAMDGPU.cpp           | 28 ++++++++-----------
 1 file changed, 12 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 7524db4da5917..5ec53874057b2 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -78,7 +78,7 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
-  auto inType = dyn_cast<VectorType>(in.getType());
+  auto inType = dyn_cast<ShapedType>(in.getType());
   if (!inType) {
     Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
         loc, rewriter.getF32Type(), in, 0);
@@ -99,13 +99,12 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
     return rewriter.replaceOp(op, result);
   }
 
-  VectorType flatTy =
-      VectorType::get(SmallVector<int64_t>{numElements}, outElemType);
+  ShapedType outType = cast<ShapedType>(op.getOut().getType());
+  ShapedType flatTy = outType.clone(SmallVector<int64_t>{numElements});
   Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
 
-  if (inType.getShape().size() > 1) {
-    inType = VectorType::get(SmallVector<int64_t>{numElements},
-                             inType.getElementType());
+  if (inType.getRank() > 1) {
+    inType = inType.clone(SmallVector<int64_t>{numElements});
     in = rewriter.create<vector::ShapeCastOp>(loc, inType, in);
   }
 
@@ -121,8 +120,7 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
     }
   }
 
-  VectorType outType = cast<VectorType>(op.getOut().getType());
-  if (inType.getShape().size() != outType.getShape().size()) {
+  if (inType.getRank() != outType.getRank()) {
     result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
   }
 
@@ -210,7 +208,7 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
   if (saturateFP8)
     in = clampInput(rewriter, loc, outElemType, in);
-  auto inVectorTy = dyn_cast<VectorType>(in.getType());
+  auto inVectorTy = dyn_cast<ShapedType>(in.getType());
   VectorType truncResType = VectorType::get(4, outElemType);
   if (!inVectorTy) {
     Value asFloat = castToF32(in, loc, rewriter);
@@ -220,7 +218,7 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
     Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
     return rewriter.replaceOp(op, result);
   }
-  VectorType outType = cast<VectorType>(op.getOut().getType());
+  ShapedType outType = cast<ShapedType>(op.getOut().getType());
   int64_t numElements = outType.getNumElements();
   Value zero = rewriter.create<arith::ConstantOp>(
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
@@ -235,13 +233,11 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
     return rewriter.replaceOp(op, result);
   }
 
-  VectorType flatTy =
-      VectorType::get(SmallVector<int64_t>{numElements}, outElemType);
+  ShapedType flatTy = outType.clone(SmallVector<int64_t>{numElements});
   Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
 
-  if (inVectorTy.getShape().size() > 1) {
-    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
-                                 inVectorTy.getElementType());
+  if (inVectorTy.getRank() > 1) {
+    inVectorTy = inVectorTy.clone(SmallVector<int64_t>{numElements});
     in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
   }
 
@@ -266,7 +262,7 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
                                                            result, i, 1);
   }
 
-  if (inVectorTy.getShape().size() != outType.getShape().size()) {
+  if (inVectorTy.getRank() != outType.getRank()) {
     result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
   }
 



More information about the Mlir-commits mailing list