[Mlir-commits] [mlir] [MLIR][GPU] Support bf16 and i1 gpu::shuffles to LLVMSPIRV conversion (PR #119675)

Victor Perez llvmlistbot at llvm.org
Thu Jan 2 04:01:49 PST 2025


================
@@ -286,30 +291,94 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
            val == getSubgroupSize(op);
   }
 
+  static bool needsBitCastOrExt(gpu::ShuffleOp op) {
+    Type type = op.getType(0);
+    return isa<BFloat16Type>(type) || type.isInteger(1);
+  }
+
+  static Type getBitCastOrExtTy(Type oldTy,
+                                ConversionPatternRewriter &rewriter) {
+    return TypeSwitch<Type, Type>(oldTy)
+        .Case<BFloat16Type>([&](auto) { return rewriter.getIntegerType(16); })
+        .Case<IntegerType>([&](auto intTy) -> Type {
+          if (intTy.getWidth() == 1)
+            return rewriter.getIntegerType(8);
+          return Type{};
+        })
+        .Default([](auto) { return Type{}; });
+  }
+
+  static Value doBitcastOrExt(Value oldVal, Type newTy, Location loc,
+                              ConversionPatternRewriter &rewriter) {
+    return TypeSwitch<Type, Value>(oldVal.getType())
+        .Case<BFloat16Type>([&](auto) {
+          return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+        })
+        .Case<IntegerType>([&](auto intTy) -> Value {
+          if (intTy.getWidth() == 1)
+            return rewriter.create<LLVM::ZExtOp>(loc, newTy, oldVal);
+          return Value{};
+        })
+        .Default([](auto) { return Value{}; });
+  }
+
+  static Value doBitcastOrTrunc(Value oldVal, Type newTy, Location loc,
+                                ConversionPatternRewriter &rewriter) {
+    return TypeSwitch<Type, Value>(newTy)
+        .Case<BFloat16Type>([&](auto) {
+          return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+        })
+        .Case<IntegerType>([&](auto intTy) -> Value {
+          if (intTy.getWidth() == 1)
+            return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
+          return Value{};
+        })
+        .Default([](auto) { return Value{}; });
+  }
+
   LogicalResult
   matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     if (!hasValidWidth(op))
       return rewriter.notifyMatchFailure(
           op, "shuffle width and subgroup size mismatch");
 
-    std::optional<std::string> funcName = getFuncName(op);
+    Location loc = op->getLoc();
+    Type bitcastOrExtDestTy = getBitCastOrExtTy(op.getType(0), rewriter);
+    std::optional<std::string> funcName;
+    Value inValue;
+    if (bitcastOrExtDestTy) {
+      Value newVal =
+          doBitcastOrExt(adaptor.getValue(), bitcastOrExtDestTy, loc, rewriter);
+      assert(newVal && "Unhandled op type in bitcastorext");
+      funcName = getFuncName(op.getMode(), bitcastOrExtDestTy);
+      inValue = newVal;
+    } else {
+      funcName = getFuncName(op);
+      inValue = adaptor.getValue();
+    }
     if (!funcName)
       return rewriter.notifyMatchFailure(op, "unsupported value type");
 
     Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
     assert(moduleOp && "Expecting module");
-    Type valueType = adaptor.getValue().getType();
+    Type valueType = inValue.getType();
     Type offsetType = adaptor.getOffset().getType();
     Type resultType = valueType;
     LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
         moduleOp, funcName.value(), {valueType, offsetType}, resultType,
         /*isMemNone=*/false, /*isConvergent=*/true);
 
-    Location loc = op->getLoc();
-    std::array<Value, 2> args{adaptor.getValue(), adaptor.getOffset()};
+    std::array<Value, 2> args{inValue, adaptor.getOffset()};
     Value result =
         createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
+    if (bitcastOrExtDestTy) {
+      Value newVal =
+          doBitcastOrTrunc(result, adaptor.getValue().getType(), loc, rewriter);
+      assert(newVal && "Unhandled op type in bitcastortrunc");
+      result = newVal;
+    }
----------------
victor-eds wrote:

```suggestion
    result = bitcastOrTruncAfterShuffle(loc, result, rewriter);
```

https://github.com/llvm/llvm-project/pull/119675


More information about the Mlir-commits mailing list