Fix Pixtral 12b compatibility

doctorpangloss 2025-03-03 13:07:36 -08:00
parent a4a9e4b59f
commit c6111fae7d


@@ -399,14 +399,21 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
         if hasattr(self.processor, "to"):
             self.processor.to(device=self.offload_device)
         assert "input_ids" in batch_feature
-        batch_feature.to(device=self.load_device, dtype=self.model_dtype())
+        try:
+            batch_feature.to(device=self.load_device, dtype=self.model_dtype())
+        except TypeError:
+            # works around Pixtral processor bug
+            batch_feature.to(self.load_device)
+            batch_feature.to(self.model_dtype())
         # noinspection PyTypeChecker
-        return {
-            "image_sizes": image_sizes,
-            "images": batch_feature["pixel_values"],
+        batch_feature_dict = {
             "inputs": batch_feature["input_ids"],
+            **batch_feature
         }
+        if "pixel_values" in batch_feature:
+            batch_feature_dict["image_sizes"] = image_sizes
+            batch_feature_dict["images"] = batch_feature["pixel_values"]
+        return batch_feature_dict

     @property
     def repo_id(self) -> str:
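
For readers outside the diff, the change boils down to the sketch below: if the processor's BatchFeature rejects a combined device/dtype move with a TypeError (as Pixtral's does), fall back to two separate .to() calls, then only include the image keys when pixel_values are actually present. This is a minimal standalone rendering under stated assumptions, not the class's real code: the helper name prepare_generation_inputs and the free-standing device, dtype, and image_sizes parameters are illustrative stand-ins for the instance attributes used in TransformersManagedModel.

from transformers import BatchFeature


def prepare_generation_inputs(batch_feature: BatchFeature, device, dtype, image_sizes=None) -> dict:
    """Move a processor's BatchFeature to the target device/dtype and repack it
    into the keyword arguments used downstream."""
    try:
        # Usual path: a single call places tensors on the device and casts floating-point ones.
        batch_feature.to(device=device, dtype=dtype)
    except TypeError:
        # Pixtral-style workaround: split the move into a device step and a dtype step.
        batch_feature.to(device)
        batch_feature.to(dtype)

    inputs = {"inputs": batch_feature["input_ids"], **batch_feature}
    # Only image-bearing requests carry pixel_values; text-only prompts omit the image keys.
    if "pixel_values" in batch_feature:
        inputs["image_sizes"] = image_sizes
        inputs["images"] = batch_feature["pixel_values"]
    return inputs

Guarding on "pixel_values" also keeps text-only prompts working against the same code path, since those never produce image tensors in the first place.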