
Model Converter

ConversionStage

Bases: Enum

Represents the current stage of the conversion process.

Attributes:

  • INIT: The conversion process has not started.
  • BASIC_LAYERS_MATCH: The source and target models have the same number of basic layers.
  • SHAPE_AND_LAYERS_MATCH: The shapes of both models agree.
  • MODELS_OUTPUT_AGREE: The outputs of the source and target models agree.
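The current stage is exposed through the converter's stage attribute and can be checked between runs. A minimal sketch, assuming ConversionStage is importable from refiners.fluxion.model_converter and that a freshly built converter reports INIT:

import torch.nn as nn
from refiners.fluxion.model_converter import ConversionStage, ModelConverter

# Toy source and target models with the same structure.
source = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
target = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
converter = ModelConverter(source_model=source, target_model=target, verbose=False)

# Before `run` is called, the conversion has not started.
assert converter.stage == ConversionStage.INIT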

ModelConverter

ModelConverter(
    source_model: Module,
    target_model: Module,
    source_keys_to_skip: list[str] | None = None,
    target_keys_to_skip: list[str] | None = None,
    custom_layer_mapping: (
        dict[type[Module], type[Module]] | None
    ) = None,
    threshold: float = 1e-05,
    skip_output_check: bool = False,
    skip_init_check: bool = False,
    verbose: bool = True,
)

Converts a model's state_dict to match another model's state_dict.

The conversion process consists of four stages:
  1. Verify that the source and target models have the same number of basic layers.
  2. Find matching shapes and layers between the source and target models.
  3. Convert the source model's state_dict to match the target model's state_dict.
  4. Compare the outputs of the source and target models.

The conversion process can be run multiple times and resumes from the last completed stage.

Example
source = ...
target = ...

converter = ModelConverter(
    source_model=source,
    target_model=target,
    threshold=0.1,
    verbose=False
)

is_converted = converter(args)
if is_converted:
    converter.save_to_safetensors(path="converted_model.safetensors")

Parameters:

  • source_model (Module): The model to convert from. Required.
  • target_model (Module): The model to convert to. Required.
  • source_keys_to_skip (list[str] | None): A list of keys to skip when tracing the source model. Defaults to None.
  • target_keys_to_skip (list[str] | None): A list of keys to skip when tracing the target model. Defaults to None.
  • custom_layer_mapping (dict[type[Module], type[Module]] | None): A dictionary mapping custom layer types between the source and target models. Defaults to None.
  • threshold (float): The threshold for comparing outputs between the source and target models. Defaults to 1e-05.
  • skip_output_check (bool): Whether to skip comparing the outputs of the source and target models. Defaults to False.
  • skip_init_check (bool): Whether to skip checking that the source and target models have the same number of basic layers. Defaults to False.
  • verbose (bool): Whether to print messages during the conversion process. Defaults to True.
Source code in src/refiners/fluxion/model_converter.py
def __init__(
    self,
    source_model: nn.Module,
    target_model: nn.Module,
    source_keys_to_skip: list[str] | None = None,
    target_keys_to_skip: list[str] | None = None,
    custom_layer_mapping: dict[type[nn.Module], type[nn.Module]] | None = None,
    threshold: float = 1e-5,
    skip_output_check: bool = False,
    skip_init_check: bool = False,
    verbose: bool = True,
) -> None:
    """Initializes the ModelConverter.

    Args:
        source_model: The model to convert from.
        target_model: The model to convert to.
        source_keys_to_skip: A list of keys to skip when tracing the source model.
        target_keys_to_skip: A list of keys to skip when tracing the target model.
        custom_layer_mapping: A dictionary mapping custom layer types between the source and target models.
        threshold: The threshold for comparing outputs between the source and target models.
        skip_output_check: Whether to skip comparing the outputs of the source and target models.
        skip_init_check: Whether to skip checking that the source and target models have the same number of basic
            layers.
        verbose: Whether to print messages during the conversion process.

    """
    self.source_model = source_model
    self.target_model = target_model
    self.source_keys_to_skip = source_keys_to_skip or []
    self.target_keys_to_skip = target_keys_to_skip or []
    self.custom_layer_mapping = custom_layer_mapping or {}
    self.threshold = threshold
    self.skip_output_check = skip_output_check
    self.skip_init_check = skip_init_check
    self.verbose = verbose
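Beyond the basic example above, the skip lists and the custom layer mapping help when the two models do not line up one-to-one. A minimal sketch with hypothetical names; SourceAttention, TargetAttention, and the "position_ids" key are placeholders, not refiners identifiers:

import torch.nn as nn
from refiners.fluxion.model_converter import ModelConverter

# Hypothetical layer types standing in for matching custom layers
# in the source and target models.
class SourceAttention(nn.Module): ...
class TargetAttention(nn.Module): ...

converter = ModelConverter(
    source_model=nn.Linear(4, 4),  # stand-ins for the real models
    target_model=nn.Linear(4, 4),
    source_keys_to_skip=["position_ids"],  # keys ignored while tracing the source
    custom_layer_mapping={SourceAttention: TargetAttention},
    threshold=1e-4,
    verbose=False,
)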

compare_models

compare_models(
    source_args: ModuleArgs,
    target_args: ModuleArgs | None = None,
    threshold: float = 1e-05,
) -> bool

Compare the outputs of the source and target models.

Parameters:

  • source_args (ModuleArgs): The arguments to pass to the source model. They can be a tuple of positional arguments, a dictionary of keyword arguments, or a dictionary with positional and keyword keys. If target_args is not provided, these arguments will also be passed to the target model. Required.
  • target_args (ModuleArgs | None): The arguments to pass to the target model. They can be a tuple of positional arguments, a dictionary of keyword arguments, or a dictionary with positional and keyword keys. Defaults to None.
  • threshold (float): The threshold for comparing outputs between the source and target models. Defaults to 1e-05.

Returns:

  bool: True if the outputs of the source and target models agree.

Source code in src/refiners/fluxion/model_converter.py
def compare_models(
    self,
    source_args: ModuleArgs,
    target_args: ModuleArgs | None = None,
    threshold: float = 1e-5,
) -> bool:
    """Compare the outputs of the source and target models.

    Args:
        source_args: The arguments to pass to the source model. They can be a tuple of positional arguments,
            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
            is not provided, these arguments will also be passed to the target model.
        target_args: The arguments to pass to the target model. They can be a tuple of positional arguments,
            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
        threshold: The threshold for comparing outputs between the source and target models.

    Returns:
        True if the outputs of the source and target models agree.
    """
    if target_args is None:
        target_args = source_args

    source_outputs = self._collect_layers_outputs(
        module=self.source_model, args=source_args, keys_to_skip=self.source_keys_to_skip
    )
    target_outputs = self._collect_layers_outputs(
        module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip
    )

    diff, prev_source_key, prev_target_key = None, None, None
    for (source_key, source_output), (target_key, target_output) in zip(source_outputs, target_outputs):
        diff = norm(source_output - target_output.reshape(shape=source_output.shape)).item()
        if diff > threshold:
            self._log(
                f"Models diverged between {prev_source_key} and {source_key}, and between {prev_target_key} and"
                f" {target_key}, difference in norm: {diff}"
            )
            return False
        prev_source_key, prev_target_key = source_key, target_key

    self._log(message=f"Models agree. Difference in norm: {diff}")

    return True
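A minimal sketch of the three accepted source_args forms, using a toy pair of linear models whose weights are made identical so the comparison passes (the keyword name input matches nn.Linear's forward signature):

import torch
import torch.nn as nn
from refiners.fluxion.model_converter import ModelConverter

source, target = nn.Linear(4, 4), nn.Linear(4, 4)
target.load_state_dict(source.state_dict())  # identical weights, so outputs agree
converter = ModelConverter(source_model=source, target_model=target, verbose=False)

x = torch.randn(1, 4)
assert converter.compare_models(source_args=(x,))          # tuple of positional arguments
assert converter.compare_models(source_args={"input": x})  # dictionary of keyword arguments
assert converter.compare_models(source_args={"positional": (x,), "keyword": {}})  # explicit dict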

get_mapping

get_mapping() -> dict[str, str]

Get the mapping between the source and target models' state_dicts.

Source code in src/refiners/fluxion/model_converter.py
def get_mapping(self) -> dict[str, str]:
    """Get the mapping between the source and target models' state_dicts."""
    if not self:
        raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
    assert self._stored_mapping is not None, "Mapping is not stored"
    return self._stored_mapping
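A sketch of reusing the stored mapping, assuming converter is a ModelConverter whose run has already succeeded:

import json

# Keys are target state_dict keys, values are the matching source keys.
mapping = converter.get_mapping()
with open("state_dict_mapping.json", "w") as f:
    json.dump(mapping, f, indent=2)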

get_module_signature

get_module_signature(module: Module) -> ModelTypeShape

Get the signature of a module.

Source code in src/refiners/fluxion/model_converter.py
def get_module_signature(self, module: nn.Module) -> ModelTypeShape:
    """Get the signature of a module."""
    layer_type = self._infer_basic_layer_type(module=module)
    assert layer_type is not None, f"Module {module} is not a basic layer"
    param_shapes = [p.shape for p in module.parameters()]
    return (layer_type, tuple(param_shapes))
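For a basic layer, the signature combines the layer type with its parameter shapes. A small sketch, assuming converter is any ModelConverter instance; the commented return value is an expectation, not output captured from the library:

import torch.nn as nn

signature = converter.get_module_signature(nn.Linear(4, 8))
# Expected form: (nn.Linear, (torch.Size([8, 4]), torch.Size([8])))
print(signature)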

get_state_dict

get_state_dict() -> dict[str, Tensor]

Get the converted state_dict.

Source code in src/refiners/fluxion/model_converter.py
def get_state_dict(self) -> dict[str, Tensor]:
    """Get the converted state_dict."""
    if not self:
        raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
    return self.target_model.state_dict()
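A sketch assuming the conversion has finished; the converted weights can then be handled like any PyTorch state_dict, for example saved as a plain checkpoint instead of (or in addition to) safetensors:

import torch

state_dict = converter.get_state_dict()
torch.save(state_dict, "converted_model_state.pt")  # plain PyTorch checkpoint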

map_state_dicts

map_state_dicts(
    source_args: ModuleArgs,
    target_args: ModuleArgs | None = None,
) -> dict[str, str] | None

Find a mapping between the source and target models' state_dicts.

Parameters:

  • source_args (ModuleArgs): The arguments to pass to the source model. They can be a tuple of positional arguments, a dictionary of keyword arguments, or a dictionary with positional and keyword keys. If target_args is not provided, these arguments will also be passed to the target model. Required.
  • target_args (ModuleArgs | None): The arguments to pass to the target model. They can be a tuple of positional arguments, a dictionary of keyword arguments, or a dictionary with positional and keyword keys. Defaults to None.

Returns:

  dict[str, str] | None: A dictionary mapping keys in the target model's state_dict to keys in the source model's state_dict.

Source code in src/refiners/fluxion/model_converter.py
def map_state_dicts(
    self,
    source_args: ModuleArgs,
    target_args: ModuleArgs | None = None,
) -> dict[str, str] | None:
    """Find a mapping between the source and target models' state_dicts.

    Args:
        source_args: The arguments to pass to the source model. They can be a tuple of positional arguments,
            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
            is not provided, these arguments will also be passed to the target model.
        target_args: The arguments to pass to the target model. They can be a tuple of positional arguments,
            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.

    Returns:
        A dictionary mapping keys in the target model's state_dict to keys in the source model's state_dict.
    """
    if target_args is None:
        target_args = source_args

    source_order = self._trace_module_execution_order(
        module=self.source_model, args=source_args, keys_to_skip=self.source_keys_to_skip
    )
    target_order = self._trace_module_execution_order(
        module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip
    )

    if not self._assert_shapes_aligned(source_order=source_order, target_order=target_order):
        return None

    mapping: dict[str, str] = {}
    for source_type_shape in source_order:
        source_keys = source_order[source_type_shape]
        target_type_shape = source_type_shape
        if not self._is_torch_basic_layer(module_type=source_type_shape[0]):
            for source_custom_type, target_custom_type in self.custom_layer_mapping.items():
                if source_custom_type == source_type_shape[0]:
                    target_type_shape = (target_custom_type, source_type_shape[1])
                    break

        target_keys = target_order[target_type_shape]
        mapping.update(zip(target_keys, source_keys))

    return mapping
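A sketch of computing the mapping directly, reusing converter and x from the compare_models example above; a None result means the shapes could not be aligned:

mapping = converter.map_state_dicts(source_args=(x,))
assert mapping is not None  # None means the source and target shapes could not be aligned
for target_key, source_key in mapping.items():
    print(f"{target_key} <- {source_key}")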

run

run(
    source_args: ModuleArgs,
    target_args: ModuleArgs | None = None,
) -> bool

Run the conversion process.

Parameters:

  • source_args (ModuleArgs): The arguments to pass to the source model. They can be a tuple of positional arguments, a dictionary of keyword arguments, or a dictionary with positional and keyword keys. If target_args is not provided, these arguments will also be passed to the target model. Required.
  • target_args (ModuleArgs | None): The arguments to pass to the target model. They can be a tuple of positional arguments, a dictionary of keyword arguments, or a dictionary with positional and keyword keys. Defaults to None.

Returns:

  bool: True if the conversion process is done and the models agree.

Source code in src/refiners/fluxion/model_converter.py
def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool:
    """Run the conversion process.

    Args:
        source_args: The arguments to pass to the source model. They can be a tuple of positional arguments,
            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
            is not provided, these arguments will also be passed to the target model.
        target_args: The arguments to pass to the target model. They can be a tuple of positional arguments,
            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.

    Returns:
        True if the conversion process is done and the models agree.
    """
    if target_args is None:
        target_args = source_args

    match self.stage:
        case ConversionStage.MODELS_OUTPUT_AGREE:
            self._increment_stage()
            return True

        case ConversionStage.SHAPE_AND_LAYERS_MATCH if self._run_shape_and_layers_match_stage(
            source_args=source_args, target_args=target_args
        ):
            self._increment_stage()
            return True

        case ConversionStage.BASIC_LAYERS_MATCH if self._run_basic_layers_match_stage(
            source_args=source_args, target_args=target_args
        ):
            self._increment_stage()
            return self.run(source_args=source_args, target_args=target_args)

        case ConversionStage.INIT if self._run_init_stage():
            self._increment_stage()
            return self.run(source_args=source_args, target_args=target_args)

        case _:
            self._log(message=f"Conversion failed at stage {self.stage.value}")
            return False
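When the two models take their inputs differently, separate argument sets can be supplied. A sketch with a hypothetical keyword name (pixel_values is a placeholder, not a refiners name), assuming converter is an existing ModelConverter:

import torch

x = torch.randn(1, 3, 224, 224)
converted = converter.run(
    source_args=(x,),                 # the source model takes a positional tensor
    target_args={"pixel_values": x},  # the target model expects a keyword argument
)
if not converted:
    print(f"Conversion stopped at stage: {converter.stage}")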

save_to_safetensors

save_to_safetensors(
    path: Path | str,
    metadata: dict[str, str] | None = None,
    half: bool = False,
) -> None

Save the converted model to a SafeTensors file.

Warning

This method can only be called after the conversion process is done.

Parameters:

  • path (Path | str): The path to save the converted model to. Required.
  • metadata (dict[str, str] | None): Metadata to save with the converted model. Defaults to None.
  • half (bool): Whether to save the converted model as half precision. Defaults to False.

Raises:

  • ValueError: If the conversion process is not done yet. Run `converter` first.

Source code in src/refiners/fluxion/model_converter.py
def save_to_safetensors(self, path: Path | str, metadata: dict[str, str] | None = None, half: bool = False) -> None:
    """Save the converted model to a SafeTensors file.

    Warning:
        This method can only be called after the conversion process is done.

    Args:
        path: The path to save the converted model to.
        metadata: Metadata to save with the converted model.
        half: Whether to save the converted model as half precision.

    Raises:
        ValueError: If the conversion process is not done yet. Run `converter` first.
    """
    if not self:
        raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
    state_dict = self.get_state_dict()
    if half:
        state_dict = {key: value.half() for key, value in state_dict.items()}
    save_to_safetensors(path=path, tensors=state_dict, metadata=metadata)
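A short usage sketch once the run has succeeded, storing half-precision tensors along with free-form string metadata:

converter.save_to_safetensors(
    path="converted_model.safetensors",
    metadata={"origin": "ModelConverter"},  # arbitrary string-to-string metadata
    half=True,                              # store tensors as float16
)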

ModuleArgsDict

Bases: TypedDict

Represents positional and keyword arguments passed to a module.

  • positional: A tuple of positional arguments.
  • keyword: A dictionary of keyword arguments.
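A sketch of the explicit form; since ModuleArgsDict is a TypedDict, a plain dict literal with these two keys is accepted (the scale keyword is a hypothetical model argument):

import torch

x = torch.randn(1, 4)
args = {"positional": (x,), "keyword": {"scale": 1.0}}
# `args` can be passed anywhere ModuleArgs is accepted, e.g. converter.run(source_args=args)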