Segment Anything
HQSAMAdapter
HQSAMAdapter(
target: SegmentAnything,
hq_mask_only: bool = False,
weights: dict[str, Tensor] | None = None,
)
Bases: Chain, Adapter[SegmentAnything]
Adapter for SAM introducing HQ features.
See [arXiv:2306.01567] Segment Anything in High Quality for details.
Example
from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.segment_anything.hq_sam import HQSAMAdapter

# Tip: run scripts/prepare_test_weights.py to download the weights
tensor_path = "./tests/weights/refiners-sam-hq-vit-h.safetensors"
weights = load_from_safetensors(tensor_path)

# sam_h is an existing SegmentAnythingH instance (see the SegmentAnythingH example)
hq_sam_adapter = HQSAMAdapter(sam_h, weights=weights)
hq_sam_adapter.inject()  # then use SAM as usual
Parameters:
Name | Type | Description | Default |
---|---|---|---|
target | SegmentAnything | The SegmentAnything model to adapt. | required |
hq_mask_only | bool | Whether to output only the high-quality mask or use it for mask correction (by summing it with the base SAM mask). | False |
weights | dict[str, Tensor] \| None | The weights of the HQSAMAdapter. | None |
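The two behaviours of `hq_mask_only` can be illustrated with a minimal pure-Python sketch. It assumes the masks are equally sized 2D grids of logits; `combine_mask_logits` is a hypothetical helper written for this illustration, not part of the refiners API, and the real adapter operates on tensors inside the mask decoder.

```python
def combine_mask_logits(base, hq, hq_mask_only=False):
    """Hypothetical helper: combine base SAM and HQ mask logits (2D lists)."""
    if hq_mask_only:
        # Output only the high-quality mask.
        return [row[:] for row in hq]
    # Mask correction: sum the HQ mask with the base SAM mask, element by element.
    return [
        [b + h for b, h in zip(base_row, hq_row)]
        for base_row, hq_row in zip(base, hq)
    ]


base_logits = [[1.0, -2.0], [0.5, 0.0]]
hq_logits = [[0.5, 1.0], [-1.0, 2.0]]

corrected = combine_mask_logits(base_logits, hq_logits)
hq_only = combine_mask_logits(base_logits, hq_logits, hq_mask_only=True)
```

With the default `hq_mask_only=False`, the HQ mask acts as a correction on top of the base prediction; with `hq_mask_only=True`, the base mask is ignored.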
Source code in src/refiners/foundationals/segment_anything/hq_sam.py
SegmentAnything
SegmentAnything(
image_encoder: SAMViT,
point_encoder: PointEncoder,
mask_encoder: MaskEncoder,
mask_decoder: MaskDecoder,
device: device | str = "cpu",
dtype: dtype = torch.float32,
)
Bases: Chain
SegmentAnything model.
See [arXiv:2304.02643] Segment Anything
See SegmentAnythingH for a usage example.
Attributes:
Name | Type | Description |
---|---|---|
mask_threshold | float | 0.0 |
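As a rough illustration of how a threshold of 0.0 relates to binarized masks, the sketch below assumes that binarization compares each mask logit to `mask_threshold` with a strict greater-than. Both that comparison and the helper `binarize_logits` are assumptions made for this example; the library itself works on tensors.

```python
MASK_THRESHOLD = 0.0  # value documented for the mask_threshold attribute


def binarize_logits(logits, threshold=MASK_THRESHOLD):
    """Hypothetical helper: True where a logit is strictly above the threshold."""
    return [[value > threshold for value in row] for row in logits]


binary = binarize_logits([[-1.2, 0.0], [0.3, 4.0]])
```

Under this assumption, a logit exactly equal to the threshold is treated as background.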
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image_encoder | SAMViT | The image encoder to use. | required |
point_encoder | PointEncoder | The point encoder to use. | required |
mask_encoder | MaskEncoder | The mask encoder to use. | required |
mask_decoder | MaskDecoder | The mask decoder to use. | required |
Source code in src/refiners/foundationals/segment_anything/model.py
image_encoder_resolution
property
image_encoder_resolution: int
The resolution of the image encoder.
compute_image_embedding
Compute the embedding of an image.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image | Image | The image to compute the embedding of. | required |
Returns:
Type | Description |
---|---|
ImageEmbedding | The computed image embedding. |
Source code in src/refiners/foundationals/segment_anything/model.py
normalize
See normalize_coordinates (refiners.foundationals.segment_anything.utils.normalize_coordinates).
Args:
coordinates: a tensor of coordinates.
original_size: (h, w) the original size of the image.
Returns:
The [0,1] normalized coordinates tensor.
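The idea of mapping pixel coordinates into [0,1] can be sketched in pure Python. This is a simplified illustration only: it assumes points are given as (x, y) pairs and divides by the image width and height, whereas the real `normalize_coordinates` works on tensors and may also account for the encoder resolution or pixel-centre offsets. `normalize_points` is a hypothetical name.

```python
def normalize_points(points, original_size):
    """Hypothetical helper: scale (x, y) pixel points into the [0, 1] range."""
    height, width = original_size  # original_size is (h, w), as documented above
    return [(x / width, y / height) for x, y in points]


normalized = normalize_points(
    [(0, 0), (320, 120), (640, 480)],
    original_size=(480, 640),
)
```

Note the ordering: sizes are (h, w) while each point is assumed to be (x, y), so x is divided by the width and y by the height.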
Source code in src/refiners/foundationals/segment_anything/model.py
postprocess_masks
See postprocess_masks (refiners.foundationals.segment_anything.utils.postprocess_masks).
Args:
low_res_masks: a mask tensor of size (N, 1, 256, 256)
original_size: (h, w) the original size of the image.
Returns:
The mask of shape (N, 1, H, W)
Source code in src/refiners/foundationals/segment_anything/model.py
predict
predict(
input: Image | ImageEmbedding,
foreground_points: (
Sequence[tuple[float, float]] | None
) = None,
background_points: (
Sequence[tuple[float, float]] | None
) = None,
box_points: (
Sequence[Sequence[tuple[float, float]]] | None
) = None,
low_res_mask: (
Float[Tensor, "1 1 256 256"] | None
) = None,
binarize: bool = True,
) -> tuple[Tensor, Tensor, Tensor]
Predict the masks of the input image.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input | Image \| ImageEmbedding | The input image or its embedding. | required |
foreground_points | Sequence[tuple[float, float]] \| None | The points of the foreground. | None |
background_points | Sequence[tuple[float, float]] \| None | The points of the background. | None |
box_points | Sequence[Sequence[tuple[float, float]]] \| None | The points of the box. | None |
low_res_mask | Float[Tensor, '1 1 256 256'] \| None | The low resolution mask. | None |
binarize | bool | Whether to binarize the masks. | True |
Returns:
Type | Description |
---|---|
Tensor | The predicted masks. |
Tensor | The IOU prediction. |
Tensor | The low resolution masks. |
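The nested type of `box_points` (a sequence of boxes, each a sequence of corner points) is easy to get wrong. The sketch below shows one way to build it from four numbers. `box_to_box_points` is a hypothetical helper, and ordering the corners as top-left then bottom-right is an assumption made for this example, not a documented requirement.

```python
def box_to_box_points(x1, y1, x2, y2):
    """Hypothetical helper: wrap one box's corners in the nested box_points layout."""
    # Sort so the first corner is top-left and the second is bottom-right
    # (an assumed convention for this sketch).
    left, right = sorted((float(x1), float(x2)))
    top, bottom = sorted((float(y1), float(y2)))
    return [[(left, top), (right, bottom)]]


box_points = box_to_box_points(200, 50, 20, 10)
# The result could then be passed as: sam.predict(image, box_points=box_points)
```

The outer list holds one entry per box, so a single box still needs two levels of nesting.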
Source code in src/refiners/foundationals/segment_anything/model.py
preprocess_image
preprocess_image(image: Image) -> Tensor
See preprocess_image (refiners.foundationals.segment_anything.utils.preprocess_image).
Args:
image: The image to preprocess.
Returns:
The preprocessed tensor.
Source code in src/refiners/foundationals/segment_anything/model.py
SegmentAnythingH
SegmentAnythingH(
image_encoder: SAMViTH | None = None,
point_encoder: PointEncoder | None = None,
mask_encoder: MaskEncoder | None = None,
mask_decoder: MaskDecoder | None = None,
multimask_output: bool | None = None,
device: device | str = "cpu",
dtype: dtype = torch.float32,
)
Bases: SegmentAnything
SegmentAnything huge model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image_encoder | SAMViTH \| None | The image encoder to use. | None |
point_encoder | PointEncoder \| None | The point encoder to use. | None |
mask_encoder | MaskEncoder \| None | The mask encoder to use. | None |
mask_decoder | MaskDecoder \| None | The mask decoder to use. | None |
multimask_output | bool \| None | Whether to use multimask output. | None |
device | device \| str | The PyTorch device to use. | 'cpu' |
dtype | dtype | The PyTorch data type to use. | float32 |
Example
import torch
from PIL import Image
from refiners.foundationals.segment_anything.model import SegmentAnythingH

device = "cuda" if torch.cuda.is_available() else "cpu"

# multimask_output=True is recommended for ambiguous prompts such as a single point.
# Below, a box prompt is passed, so multimask_output=False is used to return a single mask.
sam_h = SegmentAnythingH(multimask_output=False, device=device)

# Tip: run scripts/prepare_test_weights.py to download the weights
tensors_path = "./tests/weights/segment-anything-h.safetensors"
sam_h.load_from_safetensors(tensors_path=tensors_path)

image = Image.open("image.png")

# x1, y1, x2, y2 are placeholders for the two corners of the box prompt; set them to your own values
masks, *_ = sam_h.predict(image, box_points=[[(x1, y1), (x2, y2)]])
assert masks.shape == (1, 1, image.height, image.width)
assert masks.dtype == torch.bool

# convert the mask to a [0, 255] uint8 ndarray of shape (H, W)
mask = masks[0, 0].cpu().numpy().astype("uint8") * 255
Image.fromarray(mask).save("mask_image.png")