[Mlir-commits] [mlir] 91594b5 - [mlir][nvgpu] Prevent F32ToTF32 pattern from generating illegal IR
Thomas Raoux
llvmlistbot at llvm.org
Mon Aug 15 09:46:29 PDT 2022
Author: Thomas Raoux
Date: 2022-08-15T16:46:18Z
New Revision: 91594b5b985cf144f9cde64cddc534684d644665
URL: https://github.com/llvm/llvm-project/commit/91594b5b985cf144f9cde64cddc534684d644665
DIFF: https://github.com/llvm/llvm-project/commit/91594b5b985cf144f9cde64cddc534684d644665.diff
LOG: [mlir][nvgpu] Prevent F32ToTF32 pattern from generating illegal IR
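This pattern must not be applied to mma.sync operations that are not F32->F32 (for example, ops with f16 operands); it now bails out unless the element type of matrix A is f32.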
Differential Revision: https://reviews.llvm.org/D131902
Added:
Modified:
mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
index 4ef93b30978a4..d24001c4d28ae 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
@@ -42,7 +42,8 @@ struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
PatternRewriter &rewrite) const override {
Location location = op->getLoc();
- if (op->hasAttr(op.getTf32EnabledAttrName()))
+ if (op->hasAttr(op.getTf32EnabledAttrName()) ||
+ !op.getMatrixA().getType().cast<VectorType>().getElementType().isF32())
return failure();
if (precision == MmaSyncF32Lowering::Unkown)
diff --git a/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir b/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
index a8c72262f101b..80de11fad711b 100644
--- a/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
+++ b/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
@@ -18,3 +18,12 @@ func.func @m16n8k8_tf32(%arg0: vector<4x1xf32>, %arg1: vector<2x1xf32>, %arg2: v
return %d : vector<2x2xf32>
}
// -----
+
+// Negative test: the F32ToTF32 pattern must leave non-f32 (here f16) mma.sync ops untouched.
+// CHECK-LABEL: mma_sync_f16
+// CHECK-NOT: tf32Enabled
+// CHECK: return
+func.func @mma_sync_f16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ return %d : vector<2x2xf16>
+}
More information about the Mlir-commits mailing list