Adapters

Adapter

Bases: Generic[T]

Base class for adapters.

An Adapter modifies the structure of a Module (typically by adding, removing or replacing layers), to adapt it to a new task.

target property

target: T

The target of the adapter.

eject

eject() -> None

Eject the adapter.

This method is the inverse of inject, and should leave the target in the same state as before the injection.

Source code in src/refiners/fluxion/adapters/adapter.py
def eject(self) -> None:
    """Eject the adapter.

    This method is the inverse of [`inject`][refiners.fluxion.adapters.Adapter.inject],
    and should leave the target in the same state as before the injection.
    """
    assert isinstance(self, fl.Chain)

    # In general, the "actual target" is the target.
    # Here we deal with the edge case where the target
    # is part of the replacement block and has been adapted by
    # another adapter after this one. For instance, this is the
    # case when stacking Controlnets.
    actual_target = lookup_top_adapter(self, self.target)

    if (parent := self.parent) is None:
        if isinstance(actual_target, fl.ContextModule):
            actual_target._set_parent(None)  # type: ignore[reportPrivateUsage]
    else:
        parent.replace(old_module=self, new_module=actual_target)

inject

inject(parent: Chain | None = None) -> TAdapter

Inject the adapter.

This method replaces the target of the adapter with the adapter itself, inside the parent of the target.

Parameters:

- parent (Chain | None, default None): The parent to inject the adapter into, if the target doesn't have a parent.
Source code in src/refiners/fluxion/adapters/adapter.py
def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:
    """Inject the adapter.

    This method replaces the target of the adapter by the adapter inside the parent of the target.

    Args:
        parent: The parent to inject the adapter into, if the target doesn't have a parent.
    """
    assert isinstance(self, fl.Chain)

    if (parent is None) and isinstance(self.target, fl.ContextModule):
        parent = self.target.parent
        if parent is not None:
            assert isinstance(parent, fl.Chain), f"{self.target} has invalid parent {parent}"

    target_parent = self.find_parent(self.target)

    if parent is None:
        if isinstance(self.target, fl.ContextModule):
            self.target._set_parent(target_parent)  # type: ignore[reportPrivateUsage]
        return self

    # In general, `true_parent` is `parent`. We do this to support multiple adaptation,
    # i.e. initializing two adapters before injecting them.
    true_parent = parent.ensure_find_parent(self.target)
    true_parent.replace(
        old_module=self.target,
        new_module=self,
        old_module_parent=target_parent,
    )
    return self
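
For illustration, a minimal sketch of an inject/eject round trip, using the LoraAdapter and LinearLora classes documented further down this page (layer sizes and names are arbitrary):

import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.lora import LinearLora, LoraAdapter

linear = fl.Linear(in_features=16, out_features=16)
model = fl.Chain(linear)

adapter = LoraAdapter(linear, LinearLora("demo", in_features=16, out_features=16, rank=4))
adapter.inject()  # inside `model`, `linear` is now wrapped by the adapter
adapter.eject()   # `model` is back to its original structure, with `linear` restored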

setup_adapter

setup_adapter(target: T) -> Iterator[None]

Setup the adapter.

This method should be called by the constructor of the adapter. It sets the target of the adapter and ensures that the adapter is not a submodule of the target.

Parameters:

- target (T, required): The target of the adapter.
Source code in src/refiners/fluxion/adapters/adapter.py
@contextlib.contextmanager
def setup_adapter(self, target: T) -> Iterator[None]:
    """Setup the adapter.

    This method should be called by the constructor of the adapter.
    It sets the target of the adapter and ensures that the adapter
    is not a submodule of the target.

    Args:
        target: The target of the adapter.
    """
    assert isinstance(self, fl.Chain)
    assert (not hasattr(self, "_modules")) or (
        len(self) == 0
    ), "Call the Chain constructor in the setup_adapter context."
    self._target = [target]

    if isinstance(target, fl.ContextModule):
        assert isinstance(target, fl.ContextModule)
        with target.no_parent_refresh():
            yield
    else:
        yield
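
In practice this means the Chain constructor must be called inside the setup_adapter context. A minimal sketch of a conforming custom adapter (the class, layer sizes and scale value are illustrative):

import refiners.fluxion.layers as fl
from refiners.fluxion.adapters import Adapter

class ScaledLinearAdapter(fl.Chain, Adapter[fl.Linear]):
    """Toy adapter: rescale the input before feeding it to the target Linear."""

    def __init__(self, target: fl.Linear, scale: float = 0.5) -> None:
        with self.setup_adapter(target):
            # The Chain constructor runs inside the context, which satisfies
            # the assertion above and registers `target` as self.target.
            super().__init__(fl.Multiply(scale), target)

linear = fl.Linear(in_features=8, out_features=8)
container = fl.Chain(linear)
ScaledLinearAdapter(linear).inject()  # `container` now holds the adapter in place of `linear`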

Conv2dLora

Conv2dLora(
    name: str,
    /,
    in_channels: int,
    out_channels: int,
    rank: int = 16,
    scale: float = 1.0,
    kernel_size: tuple[int, int] = (1, 3),
    stride: tuple[int, int] = (1, 1),
    padding: tuple[int, int] = (0, 1),
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: Lora[Conv2d]

Low-Rank Adaptation (LoRA) layer for 2D convolutional layers.

This layer uses two Conv2d layers as its down and up layers.

Parameters:

- name (str, required): The name of the LoRA.
- in_channels (int, required): The number of input channels.
- out_channels (int, required): The number of output channels.
- rank (int, default 16): The rank of the LoRA.
- scale (float, default 1.0): The scale of the LoRA.
- kernel_size (tuple[int, int], default (1, 3)): The kernel size of the LoRA.
- stride (tuple[int, int], default (1, 1)): The stride of the LoRA.
- padding (tuple[int, int], default (0, 1)): The padding of the LoRA.
- device (device | str | None, default None): The device of the LoRA weights.
- dtype (dtype | None, default None): The dtype of the LoRA weights.
Source code in src/refiners/fluxion/adapters/lora.py
def __init__(
    self,
    name: str,
    /,
    in_channels: int,
    out_channels: int,
    rank: int = 16,
    scale: float = 1.0,
    kernel_size: tuple[int, int] = (1, 3),
    stride: tuple[int, int] = (1, 1),
    padding: tuple[int, int] = (0, 1),
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    """Initialize the LoRA layer.

    Args:
        name: The name of the LoRA.
        in_channels: The number of input channels.
        out_channels: The number of output channels.
        rank: The rank of the LoRA.
        scale: The scale of the LoRA.
        kernel_size: The kernel size of the LoRA.
        stride: The stride of the LoRA.
        padding: The padding of the LoRA.
        device: The device of the LoRA weights.
        dtype: The dtype of the LoRA weights.
    """
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = kernel_size
    self.stride = stride
    self.padding = padding

    super().__init__(
        name,
        rank=rank,
        scale=scale,
        device=device,
        dtype=dtype,
    )
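
A sketch of adapting a 3x3 convolution (the fl.Conv2d construction is illustrative; the tuple defaults are presumably split between the down and up convolutions, which pairs naturally with a 3x3, padding-1 target):

import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.lora import Conv2dLora, LoraAdapter

conv = fl.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
lora = Conv2dLora("demo_conv", in_channels=64, out_channels=64, rank=8)  # defaults: kernel_size=(1, 3), padding=(0, 1)
LoraAdapter(conv, lora).inject()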

LinearLora

LinearLora(
    name: str,
    /,
    in_features: int,
    out_features: int,
    rank: int = 16,
    scale: float = 1.0,
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: Lora[Linear]

Low-Rank Adaptation (LoRA) layer for linear layers.

This layer uses two Linear layers as its down and up layers.

Parameters:

- name (str, required): The name of the LoRA.
- in_features (int, required): The number of input features.
- out_features (int, required): The number of output features.
- rank (int, default 16): The rank of the LoRA.
- scale (float, default 1.0): The scale of the LoRA.
- device (device | str | None, default None): The device of the LoRA weights.
- dtype (dtype | None, default None): The dtype of the LoRA weights.
Source code in src/refiners/fluxion/adapters/lora.py
def __init__(
    self,
    name: str,
    /,
    in_features: int,
    out_features: int,
    rank: int = 16,
    scale: float = 1.0,
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    """Initialize the LoRA layer.

    Args:
        name: The name of the LoRA.
        in_features: The number of input features.
        out_features: The number of output features.
        rank: The rank of the LoRA.
        scale: The scale of the LoRA.
        device: The device of the LoRA weights.
        dtype: The dtype of the LoRA weights.
    """
    self.in_features = in_features
    self.out_features = out_features

    super().__init__(
        name,
        rank=rank,
        scale=scale,
        device=device,
        dtype=dtype,
    )
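
A short sketch of what a LinearLora computes, assuming the default device and dtype and the down, up, Multiply ordering described below for the Lora base class (sizes are arbitrary):

import torch
from refiners.fluxion.adapters.lora import LinearLora

lora = LinearLora("demo", in_features=320, out_features=64, rank=8)
# The input goes through the down Linear (320 -> 8), the up Linear (8 -> 64),
# then a Multiply(scale); the up weight is zero-initialized by reset_parameters.
x = torch.randn(2, 320)
assert lora(x).shape == (2, 64)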

Lora

Lora(
    name: str,
    /,
    rank: int = 16,
    scale: float = 1.0,
    device: device | str | None = None,
    dtype: dtype | None = None,
)

Bases: Generic[T], Chain, ABC

Low-Rank Adaptation (LoRA) layer.

This layer's purpose is to approximate a given layer by two smaller layers: the down layer (aka A) and the up layer (aka B). See LoRA: Low-Rank Adaptation of Large Language Models (arXiv:2106.09685) for more details.

Note

This layer is not meant to be used directly. Instead, use one of its subclasses: LinearLora or Conv2dLora.

Parameters:

- name (str, required): The name of the LoRA.
- rank (int, default 16): The rank of the LoRA.
- scale (float, default 1.0): The scale of the LoRA.
- device (device | str | None, default None): The device of the LoRA weights.
- dtype (dtype | None, default None): The dtype of the LoRA weights.
Source code in src/refiners/fluxion/adapters/lora.py
def __init__(
    self,
    name: str,
    /,
    rank: int = 16,
    scale: float = 1.0,
    device: Device | str | None = None,
    dtype: DType | None = None,
) -> None:
    """Initialize the LoRA layer.

    Args:
        name: The name of the LoRA.
        rank: The rank of the LoRA.
        scale: The scale of the LoRA.
        device: The device of the LoRA weights.
        dtype: The dtype of the LoRA weights.
    """
    self.name = name
    self._rank = rank
    self._scale = scale

    super().__init__(
        *self.lora_layers(device=device, dtype=dtype),
        fl.Multiply(scale),
    )
    self.reset_parameters()

down property

down: T

The down layer.

rank property

rank: int

The rank of the low-rank approximation.

scale property writable

scale: float

The scale of the low-rank approximation.

up property

up: T

The up layer.
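
Since scale is writable, the strength of an individual LoRA can be adjusted after creation; a small sketch (names and sizes are arbitrary):

from refiners.fluxion.adapters.lora import LinearLora

lora = LinearLora("demo", in_features=128, out_features=128, rank=4, scale=1.0)
lora.scale = 0.5  # halve this LoRA's contribution; the rank and weights are untouched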

from_dict classmethod

from_dict(
    name: str, /, state_dict: dict[str, Tensor]
) -> dict[str, Lora[Any]]

Create a dictionary of LoRA layers from a state dict.

Expects the state dict to be a succession of down and up weights.

Source code in src/refiners/fluxion/adapters/lora.py
@classmethod
def from_dict(cls, name: str, /, state_dict: dict[str, Tensor]) -> dict[str, "Lora[Any]"]:
    """
    Create a dictionary of LoRA layers from a state dict.

    Expects the state dict to be a succession of down and up weights.
    """
    state_dict = {k: v for k, v in state_dict.items() if ".weight" in k}
    loras: dict[str, Lora[Any]] = {}
    for down_key, down_tensor, up_tensor in zip(
        list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]
    ):
        key = ".".join(down_key.split(".")[:-2])
        loras[key] = cls.from_weights(name, down=down_tensor, up=up_tensor)
    return loras
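
A sketch of the expected key layout (the key names and shapes are made up; from_weights is assumed to build LinearLora layers here because the weights are 2D):

import torch
from refiners.fluxion.adapters.lora import Lora

state_dict = {
    # For each adapted layer: the down weight, then the up weight.
    "unet.down_blocks.0.attn.to_q.down.weight": torch.randn(8, 320),
    "unet.down_blocks.0.attn.to_q.up.weight": torch.zeros(320, 8),
}
loras = Lora.from_dict("my_lora", state_dict)
# Expected: {"unet.down_blocks.0.attn.to_q": <LinearLora "my_lora", rank 8>}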

load_weights

load_weights(
    down_weight: Tensor, up_weight: Tensor
) -> None

Load the (pre-trained) weights of the LoRA.

Parameters:

- down_weight (Tensor, required): The down weight.
- up_weight (Tensor, required): The up weight.
Source code in src/refiners/fluxion/adapters/lora.py
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
    """Load the (pre-trained) weights of the LoRA.

    Args:
        down_weight: The down weight.
        up_weight: The up weight.
    """
    assert down_weight.shape == self.down.weight.shape
    assert up_weight.shape == self.up.weight.shape
    self.down.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
    self.up.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
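
For example, loading externally trained weights into a LinearLora, assuming the usual PyTorch (out_features, in_features) layout for Linear weights:

import torch
from refiners.fluxion.adapters.lora import LinearLora

lora = LinearLora("demo", in_features=320, out_features=320, rank=8)
down_weight = torch.randn(8, 320)  # (rank, in_features)
up_weight = torch.randn(320, 8)    # (out_features, rank)
lora.load_weights(down_weight, up_weight)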

lora_layers abstractmethod

lora_layers(
    device: device | str | None = None,
    dtype: dtype | None = None,
) -> tuple[T, T]

Create the down and up layers of the LoRA.

Parameters:

- device (device | str | None, default None): The device of the LoRA weights.
- dtype (dtype | None, default None): The dtype of the LoRA weights.
Source code in src/refiners/fluxion/adapters/lora.py
@abstractmethod
def lora_layers(self, device: Device | str | None = None, dtype: DType | None = None) -> tuple[T, T]:
    """Create the down and up layers of the LoRA.

    Args:
        device: The device of the LoRA weights.
        dtype: The dtype of the LoRA weights.
    """
    ...

reset_parameters

reset_parameters() -> None

Reset the parameters of up and down layers.

Source code in src/refiners/fluxion/adapters/lora.py
def reset_parameters(self) -> None:
    """Reset the parameters of up and down layers."""
    normal_(tensor=self.down.weight, std=1 / self.rank)
    zeros_(tensor=self.up.weight)

LoraAdapter

LoraAdapter(target: WeightedModule, /, *loras: Lora[Any])

Bases: Sum, Adapter[WeightedModule]

Adapter for LoRA layers.

This adapter simply sums the target layer with the given LoRA layers.

Parameters:

- target (WeightedModule, required): The target layer.
- loras (Lora[Any], default ()): The LoRA layers.
Source code in src/refiners/fluxion/adapters/lora.py
def __init__(self, target: fl.WeightedModule, /, *loras: Lora[Any]) -> None:
    """Initialize the adapter.

    Args:
        target: The target layer.
        loras: The LoRA layers.
    """
    with self.setup_adapter(target):
        super().__init__(target, *loras)
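
A sketch with two LoRA layers attached to the same Linear target (names, sizes and scales are arbitrary):

import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.lora import LinearLora, LoraAdapter

target = fl.Linear(in_features=320, out_features=320)
adapter = LoraAdapter(
    target,
    LinearLora("style_a", in_features=320, out_features=320, rank=8, scale=1.0),
    LinearLora("style_b", in_features=320, out_features=320, rank=4, scale=0.5),
)
adapter.inject()
print(adapter.names)   # expected: ['style_a', 'style_b']
print(adapter.scales)  # expected: {'style_a': 1.0, 'style_b': 0.5}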

lora_layers property

lora_layers: Iterator[Lora[Any]]

The LoRA layers.

loras property

loras: dict[str, Lora[Any]]

The LoRA layers indexed by name.

names property

names: list[str]

The names of the LoRA layers.

scales property

scales: dict[str, float]

The scales of the LoRA layers, indexed by name.

add_lora

add_lora(lora: Lora[Any], /) -> None

Add a LoRA layer to the adapter.

Raises:

- AssertionError: If the adapter already contains a LoRA layer with the same name.

Parameters:

- lora (Lora[Any], required): The LoRA layer to add.
Source code in src/refiners/fluxion/adapters/lora.py
def add_lora(self, lora: Lora[Any], /) -> None:
    """Add a LoRA layer to the adapter.

    Raises:
        AssertionError: If the adapter already contains a LoRA layer with the same name.

    Args:
        lora: The LoRA layer to add.
    """
    assert lora.name not in self.names, f"LoRA layer with name {lora.name} already exists"
    self.append(lora)

remove_lora

remove_lora(name: str, /) -> Lora[Any] | None

Remove a LoRA layer from the adapter.

Note

If the adapter doesn't contain a LoRA layer with the given name, nothing happens and None is returned.

Parameters:

- name (str, required): The name of the LoRA layer to remove.
Source code in src/refiners/fluxion/adapters/lora.py
def remove_lora(self, name: str, /) -> Lora[Any] | None:
    """Remove a LoRA layer from the adapter.

    Note:
        If the adapter doesn't contain a LoRA layer with the given name, nothing happens and `None` is returned.

    Args:
        name: The name of the LoRA layer to remove.
    """
    if name in self.names:
        lora = self.loras[name]
        self.remove(lora)
        return lora
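
Together with add_lora, LoRA layers can be added and removed by name after the adapter is created; a sketch (names and sizes are arbitrary):

import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.lora import LinearLora, LoraAdapter

adapter = LoraAdapter(fl.Linear(in_features=64, out_features=64))
adapter.add_lora(LinearLora("extra", in_features=64, out_features=64, rank=4))

removed = adapter.remove_lora("extra")         # returns the LinearLora named "extra"
assert adapter.remove_lora("unknown") is None  # unknown name: nothing happens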

auto_attach_loras

auto_attach_loras(
    loras: dict[str, Lora[Any]],
    target: Chain,
    /,
    include: list[str] | None = None,
    exclude: list[str] | None = None,
    sanity_check: bool = True,
    debug_map: list[tuple[str, str]] | None = None,
) -> list[str]

Auto-attach several LoRA layers to a Chain.

Parameters:

- loras (dict[str, Lora[Any]], required): A dictionary of LoRA layers associated to their respective key. The keys are typically derived from the state dict and only used for debug_map and the return value.
- target (Chain, required): The target Chain.
- include (list[str] | None, default None): A list of layer names; only layers with such a layer in their ancestors will be considered.
- exclude (list[str] | None, default None): A list of layer names; layers with such a layer in their ancestors will not be considered.
- sanity_check (bool, default True): Check that the LoRAs passed are correctly attached.
- debug_map (list[tuple[str, str]] | None, default None): Pass a list to collect a debug mapping of (key, path) pairs for attached points.

Returns:

A list of keys of LoRA layers which failed to attach.

Source code in src/refiners/fluxion/adapters/lora.py
def auto_attach_loras(
    loras: dict[str, Lora[Any]],
    target: fl.Chain,
    /,
    include: list[str] | None = None,
    exclude: list[str] | None = None,
    sanity_check: bool = True,
    debug_map: list[tuple[str, str]] | None = None,
) -> list[str]:
    """Auto-attach several LoRA layers to a Chain.

    Args:
        loras: A dictionary of LoRA layers associated to their respective key. The keys are typically
            derived from the state dict and only used for `debug_map` and the return value.
        target: The target Chain.
        include: A list of layer names, only layers with such a layer in their ancestors will be considered.
        exclude: A list of layer names, layers with such a layer in their ancestors will not be considered.
        sanity_check: Check that LoRAs passed are correctly attached.
        debug_map: Pass a list to get a debug mapping of key - path pairs of attached points.
    Returns:
        A list of keys of LoRA layers which failed to attach.
    """

    if not sanity_check:
        return _auto_attach_loras(loras, target, include=include, exclude=exclude, debug_map=debug_map)

    loras_copy = {key: Lora.from_weights(lora.name, lora.down.weight, lora.up.weight) for key, lora in loras.items()}
    debug_map_1: list[tuple[str, str]] = []
    failed_keys_1 = _auto_attach_loras(loras, target, include=include, exclude=exclude, debug_map=debug_map_1)
    if debug_map is not None:
        debug_map += debug_map_1
    if len(debug_map_1) != len(loras) or failed_keys_1:
        raise ValueError(
            f"sanity check failed: {len(debug_map_1)} / {len(loras)} LoRA layers attached, {len(failed_keys_1)} failed"
        )

    # Extra sanity check: if we re-run the attach, all layers should fail.
    debug_map_2: list[tuple[str, str]] = []
    failed_keys_2 = _auto_attach_loras(loras_copy, target, include=include, exclude=exclude, debug_map=debug_map_2)
    if debug_map_2 or len(failed_keys_2) != len(loras):
        raise ValueError(
            f"sanity check failed: {len(debug_map_2)} / {len(loras)} LoRA layers attached twice, {len(failed_keys_2)} skipped"
        )

    return failed_keys_1
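
A sketch of the typical workflow: build LoRA layers from a state dict with Lora.from_dict, then attach them to a model. The helper below is illustrative; the model and state dict are assumed to come from elsewhere.

import torch
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.lora import Lora, auto_attach_loras

def attach_loras(model: fl.Chain, state_dict: dict[str, torch.Tensor]) -> list[str]:
    """Attach every LoRA found in `state_dict` to `model`; return the keys that failed."""
    loras = Lora.from_dict("my_lora", state_dict)
    debug_map: list[tuple[str, str]] = []
    # With the default sanity_check=True, a partial attach raises ValueError
    # instead of failing silently.
    failed = auto_attach_loras(loras, model, debug_map=debug_map)
    for key, path in debug_map:
        print(f"{key} -> {path}")  # where each LoRA ended up in the Chain
    return failed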