[Mlir-commits] [mlir] [mlir] Extend CombineTransferReadOpTranspose pattern to handle extf ops. (PR #74754)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 7 12:16:32 PST 2023
https://github.com/harsh-nod created https://github.com/llvm/llvm-project/pull/74754
This patch extends the CombineTransferReadOpTranspose pattern to handle arith.extf ops (in addition to the arith.extsi and arith.extui ops it already handled). It also adds a test showing the transpose being folded into the transfer_read.
>From 6c2ae62789e79a706619cc690f497eae2f38ef50 Mon Sep 17 00:00:00 2001
From: Harsh Menon <harsh at nod-labs.com>
Date: Thu, 7 Dec 2023 12:03:54 -0800
Subject: [PATCH] [mlir] Extend CombineTransferReadOpTranspose pattern to
handle extf ops
This patch extends the CombineTransferReadOpTranspose
pattern to handle arith.extf ops. It also adds a test
showing the transpose being folded into the transfer_read.
---
.../Conversion/VectorToGPU/VectorToGPU.cpp | 8 +++--
.../VectorToGPU/vector-to-mma-ops.mlir | 30 +++++++++++++++++++
2 files changed, 36 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 429d1137b6f37..f151011ee48af 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -455,7 +455,8 @@ struct CombineTransferReadOpTranspose final
Type resultType = op.getType();
Operation *extOp;
if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
- (extOp = source.getDefiningOp<arith::ExtUIOp>())) {
+ (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
+ (extOp = source.getDefiningOp<arith::ExtFOp>())) {
source = extOp->getOperand(0);
resultType =
VectorType::get(cast<VectorType>(resultType).getShape(),
@@ -493,9 +494,12 @@ struct CombineTransferReadOpTranspose final
if (isa<arith::ExtSIOp>(extOp))
result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
.getResult();
- else
+ else if (isa<arith::ExtUIOp>(extOp))
result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
.getResult();
+ else
+ result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
+ .getResult();
}
rewriter.replaceOp(op, result);
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index fa9fff2dad664..962ed7de584a2 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -460,3 +460,33 @@ func.func @cast_f16_to_f32_write(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf
vector.transfer_write %cast, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>
return
}
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-LABEL: func @fold_transpose_into_transfer_read(
+// CHECK-SAME: %[[ALLOC:.+]]: memref<64x128xf16>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f16
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true], permutation_map = #[[$MAP]]}
+// CHECK: %[[EXTF1:.+]] = arith.extf %[[READ]]
+// CHECK-NOT: vector.transpose
+// CHECK: %[[RESULT:.+]] = vector.contract
+func.func @fold_transpose_into_transfer_read(%alloc: memref<64x128xf16>, %vector: vector<32x128xf16>, %alloc2: memref<32x64xf32>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f16
+ %init = arith.constant dense<0.000000e+00> : vector<32x64xf32>
+ %0 = vector.transfer_read %alloc[%c0, %c0], %cst {in_bounds = [true, true]} : memref<64x128xf16>, vector<64x128xf16>
+ %1 = arith.extf %0 : vector<64x128xf16> to vector<64x128xf32>
+ %2 = arith.extf %vector : vector<32x128xf16> to vector<32x128xf32>
+ %3 = vector.transpose %1, [1, 0] : vector<64x128xf32> to vector<128x64xf32>
+ %4 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %3, %init : vector<32x128xf32>, vector<128x64xf32> into vector<32x64xf32>
+ vector.transfer_write %4, %alloc2[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32>
+ return
+}
+
+// -----
More information about the Mlir-commits mailing list