diff --git a/pyrit/models/storage_io.py b/pyrit/models/storage_io.py index 7d00dc157..e69306c07 100644 --- a/pyrit/models/storage_io.py +++ b/pyrit/models/storage_io.py @@ -255,6 +255,32 @@ def parse_blob_url(self, file_path: str) -> tuple[str, str]: return container_name, blob_name raise ValueError("Invalid blob URL") + def _resolve_blob_name(self, path: Union[Path, str]) -> str: + """ + Resolve a blob name from either a full blob URL or a relative blob path. + + When a full URL is provided the blob name is extracted from it. The container + name embedded in the URL is intentionally discarded — operations always run + against the container configured in the constructor. + + Backslashes are normalized to forward slashes so that ``Path`` objects + created on Windows still produce valid blob names. + + Args: + path (Union[Path, str]): Blob URL or relative blob path. + + Returns: + str: The resolved blob name. + + """ + path_str = str(path).replace("\\", "/") + try: + # parse_blob_url validates scheme + netloc internally + _, blob_name = self.parse_blob_url(path_str) + return blob_name + except ValueError: + return path_str + async def read_file(self, path: Union[Path, str]) -> bytes: """ Asynchronously reads the content of a file (blob) from Azure Blob Storage. @@ -285,7 +311,7 @@ async def read_file(self, path: Union[Path, str]) -> bytes: if not self._client_async: await self._create_container_client_async() - _, blob_name = self.parse_blob_url(str(path)) + blob_name = self._resolve_blob_name(path) try: blob_client = self._client_async.get_blob_client(blob=blob_name) @@ -305,14 +331,17 @@ async def write_file(self, path: Union[Path, str], data: bytes) -> None: """ Write data to Azure Blob Storage at the specified path. + If the provided ``path`` is a full URL, the blob name is extracted from it. + If a relative path is provided, it is used as the blob name directly. + Args: - path (str): The full Azure Blob Storage URL + path (Union[Path, str]): Full blob URL or relative blob path. data (bytes): The data to write. """ if not self._client_async: await self._create_container_client_async() - _, blob_name = self.parse_blob_url(str(path)) + blob_name = self._resolve_blob_name(path) try: await self._upload_blob_async(file_name=blob_name, data=data, content_type=self._blob_content_type) except Exception as exc: @@ -336,7 +365,7 @@ async def path_exists(self, path: Union[Path, str]) -> bool: if not self._client_async: await self._create_container_client_async() try: - _, blob_name = self.parse_blob_url(str(path)) + blob_name = self._resolve_blob_name(path) blob_client = self._client_async.get_blob_client(blob=blob_name) await blob_client.get_blob_properties() return True @@ -360,7 +389,7 @@ async def is_file(self, path: Union[Path, str]) -> bool: if not self._client_async: await self._create_container_client_async() try: - _, blob_name = self.parse_blob_url(str(path)) + blob_name = self._resolve_blob_name(path) blob_client = self._client_async.get_blob_client(blob=blob_name) blob_properties = await blob_client.get_blob_properties() return blob_properties.size > 0 diff --git a/tests/unit/models/test_storage_io.py b/tests/unit/models/test_storage_io.py index e173b06ee..0159d65b9 100644 --- a/tests/unit/models/test_storage_io.py +++ b/tests/unit/models/test_storage_io.py @@ -101,6 +101,25 @@ async def test_azure_blob_storage_io_read_file(azure_blob_storage_io): assert result == b"Test file content" +@pytest.mark.asyncio +async def test_azure_blob_storage_io_read_file_with_relative_path(azure_blob_storage_io): + mock_container_client = AsyncMock() + azure_blob_storage_io._client_async = mock_container_client + + mock_blob_client = AsyncMock() + mock_blob_stream = AsyncMock() + + mock_container_client.get_blob_client = Mock(return_value=mock_blob_client) + mock_blob_client.download_blob = AsyncMock(return_value=mock_blob_stream) + mock_blob_stream.readall = AsyncMock(return_value=b"Test file content") + mock_container_client.close = AsyncMock() + + result = await azure_blob_storage_io.read_file("dir1/dir2/sample.png") + + assert result == b"Test file content" + mock_container_client.get_blob_client.assert_called_once_with(blob="dir1/dir2/sample.png") + + @pytest.mark.asyncio async def test_azure_blob_storage_io_write_file(): container_url = "https://youraccount.blob.core.windows.net/yourcontainer" @@ -129,6 +148,29 @@ async def test_azure_blob_storage_io_write_file(): ) +@pytest.mark.asyncio +async def test_azure_blob_storage_io_write_file_with_relative_path(): + container_url = "https://youraccount.blob.core.windows.net/yourcontainer" + azure_blob_storage_io = AzureBlobStorageIO( + container_url=container_url, blob_content_type=SupportedContentType.PLAIN_TEXT + ) + + mock_container_client = AsyncMock() + + with patch.object(azure_blob_storage_io, "_create_container_client_async", return_value=None): + azure_blob_storage_io._client_async = mock_container_client + azure_blob_storage_io._upload_blob_async = AsyncMock() + + data_to_write = b"Test data" + await azure_blob_storage_io.write_file("dir1/dir2/testfile.txt", data_to_write) + + azure_blob_storage_io._upload_blob_async.assert_awaited_with( + file_name="dir1/dir2/testfile.txt", + data=data_to_write, + content_type=SupportedContentType.PLAIN_TEXT.value, + ) + + @pytest.mark.asyncio async def test_azure_blob_storage_io_create_container_client_uses_explicit_sas_token(): container_url = "https://youraccount.blob.core.windows.net/yourcontainer" @@ -164,6 +206,23 @@ async def test_azure_storage_io_path_exists(azure_blob_storage_io): assert exists is True +@pytest.mark.asyncio +async def test_azure_storage_io_path_exists_with_relative_path(azure_blob_storage_io): + mock_container_client = AsyncMock() + azure_blob_storage_io._client_async = mock_container_client + + mock_blob_client = AsyncMock() + + mock_container_client.get_blob_client = Mock(return_value=mock_blob_client) + mock_blob_client.get_blob_properties = AsyncMock() + mock_container_client.close = AsyncMock() + + exists = await azure_blob_storage_io.path_exists("dir1/dir2/blob_name.txt") + + assert exists is True + mock_container_client.get_blob_client.assert_called_once_with(blob="dir1/dir2/blob_name.txt") + + @pytest.mark.asyncio async def test_azure_storage_io_is_file(azure_blob_storage_io): azure_blob_storage_io._client_async = AsyncMock() @@ -179,6 +238,24 @@ async def test_azure_storage_io_is_file(azure_blob_storage_io): assert is_file is True +@pytest.mark.asyncio +async def test_azure_storage_io_is_file_with_relative_path(azure_blob_storage_io): + mock_container_client = AsyncMock() + azure_blob_storage_io._client_async = mock_container_client + + mock_blob_client = AsyncMock() + + mock_container_client.get_blob_client = Mock(return_value=mock_blob_client) + mock_blob_properties = Mock(size=1024) + mock_blob_client.get_blob_properties = AsyncMock(return_value=mock_blob_properties) + mock_container_client.close = AsyncMock() + + is_file = await azure_blob_storage_io.is_file("dir1/dir2/blob_name.txt") + + assert is_file is True + mock_container_client.get_blob_client.assert_called_once_with(blob="dir1/dir2/blob_name.txt") + + def test_azure_storage_io_parse_blob_url_valid(azure_blob_storage_io): file_path = "https://example.blob.core.windows.net/container/dir1/dir2/blob_name.txt" container_name, blob_name = azure_blob_storage_io.parse_blob_url(file_path) @@ -200,3 +277,27 @@ def test_azure_storage_io_parse_blob_url_without_scheme(azure_blob_storage_io): def test_azure_storage_io_parse_blob_url_without_netloc(azure_blob_storage_io): with pytest.raises(ValueError, match="Invalid blob URL"): azure_blob_storage_io.parse_blob_url("https:///container/dir1/blob_name.txt") + + +def test_resolve_blob_name_with_full_url(azure_blob_storage_io): + result = azure_blob_storage_io._resolve_blob_name("https://account.blob.core.windows.net/container/dir1/file.txt") + assert result == "dir1/file.txt" + + +def test_resolve_blob_name_with_relative_path(azure_blob_storage_io): + assert azure_blob_storage_io._resolve_blob_name("dir1/dir2/file.txt") == "dir1/dir2/file.txt" + + +def test_resolve_blob_name_with_simple_filename(azure_blob_storage_io): + assert azure_blob_storage_io._resolve_blob_name("file.txt") == "file.txt" + + +def test_resolve_blob_name_normalizes_backslashes(azure_blob_storage_io): + assert azure_blob_storage_io._resolve_blob_name("dir1\\dir2\\file.txt") == "dir1/dir2/file.txt" + + +def test_resolve_blob_name_with_path_object(azure_blob_storage_io): + from pathlib import PurePosixPath + + result = azure_blob_storage_io._resolve_blob_name(PurePosixPath("dir1/dir2/file.txt")) + assert result == "dir1/dir2/file.txt"