Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions pyrit/models/storage_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,32 @@ def parse_blob_url(self, file_path: str) -> tuple[str, str]:
return container_name, blob_name
raise ValueError("Invalid blob URL")

def _resolve_blob_name(self, path: Union[Path, str]) -> str:
    """
    Resolve a blob name from either a full blob URL or a relative blob path.

    When a full URL is provided the blob name is extracted from it. The container
    name embedded in the URL is intentionally discarded — operations always run
    against the container configured in the constructor.

    Backslashes are normalized to forward slashes so that ``Path`` objects
    created on Windows still produce valid blob names.

    ``pathlib`` collapses repeated slashes, so a URL wrapped in a ``Path``
    stringifies as ``https:/host/...``. The ``//`` after an ``http``/``https``
    scheme is restored before parsing so that such input is still recognized as
    a URL instead of being used verbatim as a blob name.

    Args:
        path (Union[Path, str]): Blob URL or relative blob path.

    Returns:
        str: The resolved blob name.

    """
    path_str = str(path).replace("\\", "/")

    # Undo pathlib's slash collapsing ("https:/host" -> "https://host"); strings
    # that already carry "://" are left untouched.
    for collapsed_scheme in ("https:/", "http:/"):
        if path_str.startswith(collapsed_scheme) and not path_str.startswith(collapsed_scheme + "/"):
            path_str = collapsed_scheme + "/" + path_str[len(collapsed_scheme):]
            break

    try:
        # parse_blob_url validates scheme + netloc internally
        _, blob_name = self.parse_blob_url(path_str)
    except ValueError:
        # Not a URL: the normalized string is itself the relative blob name.
        return path_str
    return blob_name

async def read_file(self, path: Union[Path, str]) -> bytes:
"""
Asynchronously reads the content of a file (blob) from Azure Blob Storage.
Expand Down Expand Up @@ -285,7 +311,7 @@ async def read_file(self, path: Union[Path, str]) -> bytes:
if not self._client_async:
await self._create_container_client_async()

_, blob_name = self.parse_blob_url(str(path))
blob_name = self._resolve_blob_name(path)

try:
blob_client = self._client_async.get_blob_client(blob=blob_name)
Expand All @@ -305,14 +331,17 @@ async def write_file(self, path: Union[Path, str], data: bytes) -> None:
"""
Write data to Azure Blob Storage at the specified path.

If the provided ``path`` is a full URL, the blob name is extracted from it.
If a relative path is provided, it is used as the blob name directly.

Args:
path (str): The full Azure Blob Storage URL
path (Union[Path, str]): Full blob URL or relative blob path.
data (bytes): The data to write.

"""
if not self._client_async:
await self._create_container_client_async()
_, blob_name = self.parse_blob_url(str(path))
blob_name = self._resolve_blob_name(path)
try:
await self._upload_blob_async(file_name=blob_name, data=data, content_type=self._blob_content_type)
except Exception as exc:
Expand All @@ -336,7 +365,7 @@ async def path_exists(self, path: Union[Path, str]) -> bool:
if not self._client_async:
await self._create_container_client_async()
try:
_, blob_name = self.parse_blob_url(str(path))
blob_name = self._resolve_blob_name(path)
blob_client = self._client_async.get_blob_client(blob=blob_name)
await blob_client.get_blob_properties()
return True
Expand All @@ -360,7 +389,7 @@ async def is_file(self, path: Union[Path, str]) -> bool:
if not self._client_async:
await self._create_container_client_async()
try:
_, blob_name = self.parse_blob_url(str(path))
blob_name = self._resolve_blob_name(path)
blob_client = self._client_async.get_blob_client(blob=blob_name)
blob_properties = await blob_client.get_blob_properties()
return blob_properties.size > 0
Expand Down
101 changes: 101 additions & 0 deletions tests/unit/models/test_storage_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,25 @@ async def test_azure_blob_storage_io_read_file(azure_blob_storage_io):
assert result == b"Test file content"


@pytest.mark.asyncio
async def test_azure_blob_storage_io_read_file_with_relative_path(azure_blob_storage_io):
    """Reading with a relative blob path requests a blob client for that exact path."""
    relative_path = "dir1/dir2/sample.png"
    expected_content = b"Test file content"

    # Build the mock chain from the innermost object outwards:
    # stream -> blob client -> container client.
    blob_stream = AsyncMock()
    blob_stream.readall = AsyncMock(return_value=expected_content)

    blob_client = AsyncMock()
    blob_client.download_blob = AsyncMock(return_value=blob_stream)

    container_client = AsyncMock()
    container_client.get_blob_client = Mock(return_value=blob_client)
    container_client.close = AsyncMock()
    azure_blob_storage_io._client_async = container_client

    content = await azure_blob_storage_io.read_file(relative_path)

    assert content == expected_content
    container_client.get_blob_client.assert_called_once_with(blob=relative_path)


@pytest.mark.asyncio
async def test_azure_blob_storage_io_write_file():
container_url = "https://youraccount.blob.core.windows.net/yourcontainer"
Expand Down Expand Up @@ -129,6 +148,29 @@ async def test_azure_blob_storage_io_write_file():
)


@pytest.mark.asyncio
async def test_azure_blob_storage_io_write_file_with_relative_path():
    """Writing to a relative path uploads under that path with the configured content type."""
    relative_path = "dir1/dir2/testfile.txt"
    payload = b"Test data"

    storage = AzureBlobStorageIO(
        container_url="https://youraccount.blob.core.windows.net/yourcontainer",
        blob_content_type=SupportedContentType.PLAIN_TEXT,
    )
    container_client = AsyncMock()

    with patch.object(storage, "_create_container_client_async", return_value=None):
        storage._client_async = container_client
        upload_spy = AsyncMock()
        storage._upload_blob_async = upload_spy

        await storage.write_file(relative_path, payload)

        upload_spy.assert_awaited_with(
            file_name=relative_path,
            data=payload,
            content_type=SupportedContentType.PLAIN_TEXT.value,
        )


@pytest.mark.asyncio
async def test_azure_blob_storage_io_create_container_client_uses_explicit_sas_token():
container_url = "https://youraccount.blob.core.windows.net/yourcontainer"
Expand Down Expand Up @@ -164,6 +206,23 @@ async def test_azure_storage_io_path_exists(azure_blob_storage_io):
assert exists is True


@pytest.mark.asyncio
async def test_azure_storage_io_path_exists_with_relative_path(azure_blob_storage_io):
    """path_exists returns True for a relative path whose blob properties can be fetched."""
    relative_path = "dir1/dir2/blob_name.txt"

    blob_client = AsyncMock()
    blob_client.get_blob_properties = AsyncMock()

    container_client = AsyncMock()
    container_client.get_blob_client = Mock(return_value=blob_client)
    container_client.close = AsyncMock()
    azure_blob_storage_io._client_async = container_client

    found = await azure_blob_storage_io.path_exists(relative_path)

    assert found is True
    container_client.get_blob_client.assert_called_once_with(blob=relative_path)


@pytest.mark.asyncio
async def test_azure_storage_io_is_file(azure_blob_storage_io):
azure_blob_storage_io._client_async = AsyncMock()
Expand All @@ -179,6 +238,24 @@ async def test_azure_storage_io_is_file(azure_blob_storage_io):
assert is_file is True


@pytest.mark.asyncio
async def test_azure_storage_io_is_file_with_relative_path(azure_blob_storage_io):
    """is_file returns True for a relative path whose blob reports a non-zero size."""
    relative_path = "dir1/dir2/blob_name.txt"

    blob_client = AsyncMock()
    blob_client.get_blob_properties = AsyncMock(return_value=Mock(size=1024))

    container_client = AsyncMock()
    container_client.get_blob_client = Mock(return_value=blob_client)
    container_client.close = AsyncMock()
    azure_blob_storage_io._client_async = container_client

    outcome = await azure_blob_storage_io.is_file(relative_path)

    assert outcome is True
    container_client.get_blob_client.assert_called_once_with(blob=relative_path)


def test_azure_storage_io_parse_blob_url_valid(azure_blob_storage_io):
file_path = "https://example.blob.core.windows.net/container/dir1/dir2/blob_name.txt"
container_name, blob_name = azure_blob_storage_io.parse_blob_url(file_path)
Expand All @@ -200,3 +277,27 @@ def test_azure_storage_io_parse_blob_url_without_scheme(azure_blob_storage_io):
def test_azure_storage_io_parse_blob_url_without_netloc(azure_blob_storage_io):
with pytest.raises(ValueError, match="Invalid blob URL"):
azure_blob_storage_io.parse_blob_url("https:///container/dir1/blob_name.txt")


def test_resolve_blob_name_with_full_url(azure_blob_storage_io):
    """A full blob URL resolves to the path that follows the container segment."""
    blob_url = "https://account.blob.core.windows.net/container/dir1/file.txt"
    assert azure_blob_storage_io._resolve_blob_name(blob_url) == "dir1/file.txt"


def test_resolve_blob_name_with_relative_path(azure_blob_storage_io):
    """A nested relative blob path is returned unchanged."""
    resolved = azure_blob_storage_io._resolve_blob_name("dir1/dir2/file.txt")
    assert resolved == "dir1/dir2/file.txt"


def test_resolve_blob_name_with_simple_filename(azure_blob_storage_io):
    """A bare filename with no directory part is returned unchanged."""
    resolved = azure_blob_storage_io._resolve_blob_name("file.txt")
    assert resolved == "file.txt"


def test_resolve_blob_name_normalizes_backslashes(azure_blob_storage_io):
    """Backslash separators in the input come back as forward slashes."""
    resolved = azure_blob_storage_io._resolve_blob_name("dir1\\dir2\\file.txt")
    assert resolved == "dir1/dir2/file.txt"


def test_resolve_blob_name_with_path_object(azure_blob_storage_io):
    """A path object holding a relative path resolves to its forward-slash string form."""
    from pathlib import PurePosixPath

    blob_path = PurePosixPath("dir1/dir2/file.txt")
    assert azure_blob_storage_io._resolve_blob_name(blob_path) == "dir1/dir2/file.txt"
Loading