Segment Anything

HQSAMAdapter

HQSAMAdapter(
    target: SegmentAnything,
    hq_mask_only: bool = False,
    weights: dict[str, Tensor] | None = None,
)

Bases: Chain, Adapter[SegmentAnything]

Adapter for SAM introducing HQ features.

See [arXiv:2306.01567] Segment Anything in High Quality for details.

Example
from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.segment_anything.hq_sam import HQSAMAdapter

# sam_h is a SegmentAnythingH built with multimask_output=False (see SegmentAnythingH below)
# Tips: run scripts/prepare_test_weights.py to download the weights
tensor_path = "./tests/weights/refiners-sam-hq-vit-h.safetensors"
weights = load_from_safetensors(tensor_path)

hq_sam_adapter = HQSAMAdapter(sam_h, weights=weights)
hq_sam_adapter.inject()  # then use SAM as usual

Parameters:

target (SegmentAnything): The SegmentAnything model to adapt. Required.
hq_mask_only (bool): Whether to output only the high-quality mask or use it for mask correction (by summing it with the base SAM mask). Defaults to False.
weights (dict[str, Tensor] | None): The weights of the HQSAMAdapter. Defaults to None.
Source code in src/refiners/foundationals/segment_anything/hq_sam.py
def __init__(
    self,
    target: SegmentAnything,
    hq_mask_only: bool = False,
    weights: dict[str, torch.Tensor] | None = None,
) -> None:
    """Initialize the adapter.

    Args:
        target: The SegmentAnything model to adapt.
        hq_mask_only: Whether to output only the high-quality mask or use it for mask correction (by summing it with the base SAM mask).
        weights: The weights of the HQSAMAdapter.
    """
    self.vit_embedding_dim = target.image_encoder.embedding_dim
    self.target_num_mask_tokens = target.mask_decoder.num_multimask_outputs + 2

    with self.setup_adapter(target):
        super().__init__(target)

    if target.mask_decoder.multimask_output:
        raise NotImplementedError("Multi-mask mode is not supported in HQSAMAdapter.")

    mask_prediction = target.mask_decoder.ensure_find(MaskPrediction)

    self._mask_prediction_adapter = [
        MaskPredictionAdapter(
            mask_prediction, self.vit_embedding_dim, self.target_num_mask_tokens, target.device, target.dtype
        )
    ]
    self._register_adapter_module("Chain.HQSAMMaskPrediction", self.mask_prediction_adapter.hq_sam_mask_prediction)

    self._image_encoder_adapter = [SAMViTAdapter(target.image_encoder)]
    self._predictions_post_proc = [PredictionsPostProc(hq_mask_only)]

    mask_decoder_tokens = target.mask_decoder.ensure_find(MaskDecoderTokens)
    self._mask_decoder_tokens_extender = [MaskDecoderTokensExtender(mask_decoder_tokens)]
    self._register_adapter_module("MaskDecoderTokensExtender.hq_token", self.mask_decoder_tokens_extender.hq_token)

    if weights is not None:
        self.load_weights(weights)

    self.to(device=target.device, dtype=target.dtype)
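
The hq_mask_only flag changes how the adapted model's output is built. Below is a minimal sketch contrasting the two modes; it assumes sam_h is a SegmentAnythingH built with multimask_output=False, weights was loaded as in the example above, and the image and prompt point are placeholders. It also uses eject(), the Adapter counterpart of inject() in Refiners, to restore the unadapted behavior between the two runs.

from refiners.foundationals.segment_anything.hq_sam import HQSAMAdapter

# Correction mode (default): the HQ mask is summed with the base SAM mask.
adapter = HQSAMAdapter(sam_h, hq_mask_only=False, weights=weights)
adapter.inject()
corrected_masks, *_ = sam_h.predict(image, foreground_points=[(320.0, 240.0)])
adapter.eject()  # restore the vanilla SAM behavior

# HQ-only mode: only the high-quality mask is output.
adapter = HQSAMAdapter(sam_h, hq_mask_only=True, weights=weights)
adapter.inject()
hq_masks, *_ = sam_h.predict(image, foreground_points=[(320.0, 240.0)])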

SegmentAnything

SegmentAnything(
    image_encoder: SAMViT,
    point_encoder: PointEncoder,
    mask_encoder: MaskEncoder,
    mask_decoder: MaskDecoder,
    device: device | str = "cpu",
    dtype: dtype = torch.float32,
)

Bases: Chain

SegmentAnything model.

See [arXiv:2304.02643] Segment Anything

See SegmentAnythingH for example usage.

Attributes:

mask_threshold (float): 0.0. The threshold above which predicted mask logits are binarized (see predict).

Parameters:

image_encoder (SAMViT): The image encoder to use. Required.
point_encoder (PointEncoder): The point encoder to use. Required.
mask_encoder (MaskEncoder): The mask encoder to use. Required.
mask_decoder (MaskDecoder): The mask decoder to use. Required.
Source code in src/refiners/foundationals/segment_anything/model.py
def __init__(
    self,
    image_encoder: SAMViT,
    point_encoder: PointEncoder,
    mask_encoder: MaskEncoder,
    mask_decoder: MaskDecoder,
    device: Device | str = "cpu",
    dtype: DType = torch.float32,
) -> None:
    """Initialize SegmentAnything model.

    Args:
        image_encoder: The image encoder to use.
        point_encoder: The point encoder to use.
        mask_encoder: The mask encoder to use.
        mask_decoder: The mask decoder to use.
    """
    super().__init__(image_encoder, point_encoder, mask_encoder, mask_decoder)

    self.to(device=device, dtype=dtype)

image_encoder property

image_encoder: SAMViT

The image encoder.

image_encoder_resolution property

image_encoder_resolution: int

The resolution of the image encoder.

mask_decoder property

mask_decoder: MaskDecoder

The mask decoder.

mask_encoder property

mask_encoder: MaskEncoder

The mask encoder.

point_encoder property

point_encoder: PointEncoder

The point encoder.

compute_image_embedding

compute_image_embedding(image: Image) -> ImageEmbedding

Compute the embedding of an image.

Parameters:

image (Image): The image to compute the embedding of. Required.

Returns:

ImageEmbedding: The computed image embedding.

Source code in src/refiners/foundationals/segment_anything/model.py
@no_grad()
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
    """Compute the emmbedding of an image.

    Args:
        image: The image to compute the embedding of.

    Returns:
        The computed image embedding.
    """
    original_size = (image.height, image.width)
    return ImageEmbedding(
        features=self.image_encoder(self.preprocess_image(image)),
        original_image_size=original_size,
    )
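
A common pattern is to compute the embedding once and reuse it for several prompts, since the image encoder is typically the most expensive part of the pipeline. A minimal sketch, assuming sam_h is a loaded SegmentAnythingH and image is a PIL image (the point coordinates are placeholders):

# Run the image encoder once...
image_embedding = sam_h.compute_image_embedding(image)

# ...then prompt the same image several times without re-encoding it.
masks_a, *_ = sam_h.predict(image_embedding, foreground_points=[(100.0, 150.0)])
masks_b, *_ = sam_h.predict(
    image_embedding,
    foreground_points=[(200.0, 220.0)],
    background_points=[(10.0, 10.0)],
)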

normalize

normalize(
    coordinates: Tensor, original_size: tuple[int, int]
) -> Tensor

See normalize_coordinates (refiners.foundationals.segment_anything.utils.normalize_coordinates).

Parameters:

coordinates (Tensor): A tensor of coordinates. Required.
original_size (tuple[int, int]): (h, w), the original size of the image. Required.

Returns:

Tensor: The [0, 1] normalized coordinates tensor.

Source code in src/refiners/foundationals/segment_anything/model.py
def normalize(self, coordinates: Tensor, original_size: tuple[int, int]) -> Tensor:
    """
    See [`normalize_coordinates`][refiners.foundationals.segment_anything.utils.normalize_coordinates]
    Args:
        coordinates: a tensor of coordinates.
        original_size: (h, w) the original size of the image.
    Returns:
        The [0,1] normalized coordinates tensor.
    """
    return normalize_coordinates(coordinates, original_size, self.image_encoder_resolution)
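
A minimal sketch of calling it, assuming sam_h is a loaded SegmentAnythingH and that coordinates are given as a (1, N, 2) tensor of (x, y) pixel positions in the original image (the shape is an assumption, not stated on this page):

import torch

# Two prompt points in pixel space of a 480x640 (h, w) image.
coordinates = torch.tensor([[[320.0, 240.0], [100.0, 50.0]]])  # assumed shape: (1, N, 2)

# Rescaled to [0, 1] as described by normalize_coordinates.
normalized = sam_h.normalize(coordinates, original_size=(480, 640))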

postprocess_masks

postprocess_masks(
    low_res_masks: Tensor, original_size: tuple[int, int]
) -> Tensor

See postprocess_masks (refiners.foundationals.segment_anything.utils.postprocess_masks).

Parameters:

low_res_masks (Tensor): A mask tensor of size (N, 1, 256, 256). Required.
original_size (tuple[int, int]): (h, w), the original size of the image. Required.

Returns:

Tensor: The mask of shape (N, 1, H, W).

Source code in src/refiners/foundationals/segment_anything/model.py
def postprocess_masks(self, low_res_masks: Tensor, original_size: tuple[int, int]) -> Tensor:
    """
    See [`postprocess_masks`][refiners.foundationals.segment_anything.utils.postprocess_masks]
    Args:
        low_res_masks: a mask tensor of size (N, 1, 256, 256)
        original_size: (h, w) the original size of the image.
    Returns:
        The mask of shape (N, 1, H, W)
    """
    return postprocess_masks(low_res_masks, original_size, self.image_encoder_resolution)
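
A minimal sketch using the shapes documented above, assuming sam_h is a loaded SegmentAnythingH (the (h, w) size is a placeholder):

import torch

# One low-resolution mask of logits, as produced by the mask decoder.
low_res_masks = torch.randn(1, 1, 256, 256)

# Upscale it back to the original image size.
high_res_masks = sam_h.postprocess_masks(low_res_masks, original_size=(480, 640))
assert high_res_masks.shape == (1, 1, 480, 640)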

predict

predict(
    input: Image | ImageEmbedding,
    foreground_points: (
        Sequence[tuple[float, float]] | None
    ) = None,
    background_points: (
        Sequence[tuple[float, float]] | None
    ) = None,
    box_points: (
        Sequence[Sequence[tuple[float, float]]] | None
    ) = None,
    low_res_mask: (
        Float[Tensor, "1 1 256 256"] | None
    ) = None,
    binarize: bool = True,
) -> tuple[Tensor, Tensor, Tensor]

Predict the masks of the input image.

Parameters:

input (Image | ImageEmbedding): The input image or its embedding. Required.
foreground_points (Sequence[tuple[float, float]] | None): The points of the foreground. Defaults to None.
background_points (Sequence[tuple[float, float]] | None): The points of the background. Defaults to None.
box_points (Sequence[Sequence[tuple[float, float]]] | None): The points of the box. Defaults to None.
low_res_mask (Float[Tensor, "1 1 256 256"] | None): The low resolution mask. Defaults to None.
binarize (bool): Whether to binarize the masks. Defaults to True.

Returns:

Tensor: The predicted masks.
Tensor: The IOU prediction.
Tensor: The low resolution masks.

Source code in src/refiners/foundationals/segment_anything/model.py
@no_grad()
def predict(
    self,
    input: Image.Image | ImageEmbedding,
    foreground_points: Sequence[tuple[float, float]] | None = None,
    background_points: Sequence[tuple[float, float]] | None = None,
    box_points: Sequence[Sequence[tuple[float, float]]] | None = None,
    low_res_mask: Float[Tensor, "1 1 256 256"] | None = None,
    binarize: bool = True,
) -> tuple[Tensor, Tensor, Tensor]:
    """Predict the masks of the input image.

    Args:
        input: The input image or its embedding.
        foreground_points: The points of the foreground.
        background_points: The points of the background.
        box_points: The points of the box.
        low_res_mask: The low resolution mask.
        binarize: Whether to binarize the masks.

    Returns:
        The predicted masks.
        The IOU prediction.
        The low resolution masks.
    """
    if isinstance(input, ImageEmbedding):
        original_size = input.original_image_size
        image_embedding = input.features
    else:
        original_size = (input.height, input.width)
        image_embedding = self.image_encoder(self.preprocess_image(input))

    coordinates, type_mask = self.point_encoder.points_to_tensor(
        foreground_points=foreground_points,
        background_points=background_points,
        box_points=box_points,
    )
    self.point_encoder.set_type_mask(type_mask=type_mask)

    if low_res_mask is not None:
        mask_embedding = self.mask_encoder(low_res_mask)
    else:
        mask_embedding = self.mask_encoder.get_no_mask_dense_embedding(
            image_embedding_size=self.image_encoder.image_embedding_size
        )

    point_embedding = self.point_encoder(self.normalize(coordinates, original_size=original_size))
    dense_positional_embedding = self.point_encoder.get_dense_positional_embedding(
        image_embedding_size=self.image_encoder.image_embedding_size
    )

    self.mask_decoder.set_image_embedding(image_embedding=image_embedding)
    self.mask_decoder.set_mask_embedding(mask_embedding=mask_embedding)
    self.mask_decoder.set_point_embedding(point_embedding=point_embedding)
    self.mask_decoder.set_dense_positional_embedding(dense_positional_embedding=dense_positional_embedding)

    low_res_masks, iou_predictions = self.mask_decoder()

    high_res_masks = self.postprocess_masks(low_res_masks, original_size)

    if binarize:
        high_res_masks = high_res_masks > self.mask_threshold

    return high_res_masks, iou_predictions, low_res_masks
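
For reference, a minimal point-prompt sketch (the prompt coordinates are placeholders; see SegmentAnythingH below for the full setup and a box-prompt example):

import torch
from PIL import Image

from refiners.foundationals.segment_anything.model import SegmentAnythingH

sam_h = SegmentAnythingH(multimask_output=False, device="cpu")
sam_h.load_from_safetensors(tensors_path="./tests/weights/segment-anything-h.safetensors")

image = Image.open("image.png")

# A single foreground point, given as (x, y) pixel coordinates in the original image.
high_res_masks, iou_predictions, low_res_masks = sam_h.predict(
    image,
    foreground_points=[(320.0, 240.0)],
)

# With binarize=True (the default), the high resolution masks are boolean.
assert high_res_masks.dtype == torch.bool
assert low_res_masks.shape == (1, 1, 256, 256)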

preprocess_image

preprocess_image(image: Image) -> Tensor

See preprocess_image (refiners.foundationals.segment_anything.utils.preprocess_image).

Parameters:

image (Image): The image to preprocess. Required.

Returns:

Tensor: The preprocessed tensor.

Source code in src/refiners/foundationals/segment_anything/model.py
def preprocess_image(self, image: Image.Image) -> Tensor:
    """
    See [`preprocess_image`][refiners.foundationals.segment_anything.utils.preprocess_image]
    Args:
        image: The image to preprocess.
    Returns:
        The preprocessed tensor.
    """
    return preprocess_image(image, self.image_encoder_resolution, self.device, self.dtype)
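
A minimal sketch, assuming sam_h is a loaded SegmentAnythingH; the exact input resolution is given by sam_h.image_encoder_resolution (1024 for the ViT-H variant is an assumption, not stated on this page):

from PIL import Image

image = Image.open("image.png")

# The image is resized to the encoder input size and moved to the model's device and dtype.
tensor = sam_h.preprocess_image(image)
assert tensor.ndim == 4  # a batched image tensor, ready for sam_h.image_encoder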

SegmentAnythingH

SegmentAnythingH(
    image_encoder: SAMViTH | None = None,
    point_encoder: PointEncoder | None = None,
    mask_encoder: MaskEncoder | None = None,
    mask_decoder: MaskDecoder | None = None,
    multimask_output: bool | None = None,
    device: device | str = "cpu",
    dtype: dtype = torch.float32,
)

Bases: SegmentAnything

SegmentAnything huge model.

Parameters:

image_encoder (SAMViTH | None): The image encoder to use. Defaults to None.
point_encoder (PointEncoder | None): The point encoder to use. Defaults to None.
mask_encoder (MaskEncoder | None): The mask encoder to use. Defaults to None.
mask_decoder (MaskDecoder | None): The mask decoder to use. Defaults to None.
multimask_output (bool | None): Whether to use multimask output. Defaults to None.
device (device | str): The PyTorch device to use. Defaults to "cpu".
dtype (dtype): The PyTorch data type to use. Defaults to float32.
Example
import torch
from PIL import Image

from refiners.foundationals.segment_anything.model import SegmentAnythingH

device = "cuda" if torch.cuda.is_available() else "cpu"

# multimask_output=True is recommended for ambiguous prompts such as a single point.
# Below, a box prompt is passed, so multimask_output=False is enough and a single mask is returned.
sam_h = SegmentAnythingH(multimask_output=False, device=device)

# Tips: run scripts/prepare_test_weights.py to download the weights
tensors_path = "./tests/weights/segment-anything-h.safetensors"
sam_h.load_from_safetensors(tensors_path=tensors_path)

image = Image.open("image.png")

# (x1, y1) and (x2, y2) are the two opposite corners of the box prompt, in pixels
masks, *_ = sam_h.predict(image, box_points=[[(x1, y1), (x2, y2)]])

assert masks.shape == (1, 1, image.height, image.width)
assert masks.dtype == torch.bool

# convert it to a [0, 255] uint8 ndarray of shape (H, W)
mask = masks[0, 0].cpu().numpy().astype("uint8") * 255

Image.fromarray(mask).save("mask_image.png")
Source code in src/refiners/foundationals/segment_anything/model.py
def __init__(
    self,
    image_encoder: SAMViTH | None = None,
    point_encoder: PointEncoder | None = None,
    mask_encoder: MaskEncoder | None = None,
    mask_decoder: MaskDecoder | None = None,
    multimask_output: bool | None = None,
    device: Device | str = "cpu",
    dtype: DType = torch.float32,
) -> None:
    """Initialize SegmentAnything huge model.

    Args:
        image_encoder: The image encoder to use.
        point_encoder: The point encoder to use.
        mask_encoder: The mask encoder to use.
        mask_decoder: The mask decoder to use.
        multimask_output: Whether to use multimask output.
        device: The PyTorch device to use.
        dtype: The PyTorch data type to use.

    Example:
        ```py
        device="cuda" if torch.cuda.is_available() else "cpu"

        # multimask_output=True is recommended for ambiguous prompts such as a single point.
        # Below, a box prompt is passed, so just use multimask_output=False which will return a single mask
        sam_h = SegmentAnythingH(multimask_output=False, device=device)

        # Tips: run scripts/prepare_test_weights.py to download the weights
        tensors_path = "./tests/weights/segment-anything-h.safetensors"
        sam_h.load_from_safetensors(tensors_path=tensors_path)

        from PIL import Image
        image = Image.open("image.png")

        masks, *_ = sam_h.predict(image, box_points=[[(x1, y1), (x2, y2)]])

        assert masks.shape == (1, 1, image.height, image.width)
        assert masks.dtype == torch.bool

        # convert it to [0,255] uint8 ndarray of shape (H, W)
        mask = masks[0, 0].cpu().numpy().astype("uint8") * 255

        Image.fromarray(mask).save("mask_image.png")
        ```
    """
    image_encoder = image_encoder or SAMViTH()
    point_encoder = point_encoder or PointEncoder()
    mask_encoder = mask_encoder or MaskEncoder()

    if mask_decoder:
        assert (
            multimask_output is None or mask_decoder.multimask_output == multimask_output
        ), f"mask_decoder.multimask_output {mask_decoder.multimask_output} should match multimask_output ({multimask_output})"
    else:
        mask_decoder = MaskDecoder(multimask_output) if multimask_output is not None else MaskDecoder()

    super().__init__(image_encoder, point_encoder, mask_encoder, mask_decoder)

    self.to(device=device, dtype=dtype)

image_encoder property

image_encoder: SAMViTH

The image encoder.