Add compression for to_csv() and to_json() #502

Merged (2 commits, Jan 6, 2021)
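For context, a minimal sketch of what this change enables from the caller's side. The bucket, keys, and DataFrame are illustrative, and the `compression` keyword is assumed to be forwarded to pandas via the writers' pandas kwargs, which is what this PR wires up for to_csv() and to_json():

    import awswrangler as wr
    import pandas as pd

    df = pd.DataFrame({"id": [1, 2], "value": ["foo", "boo"]})

    # pandas performs the compression; the resulting bytes are streamed to S3
    # through open_s3_object() from the diff below.
    wr.s3.to_csv(df=df, path="s3://my-bucket/my-key.csv.gz", index=False, compression="gzip")
    wr.s3.to_json(df=df, path="s3://my-bucket/my-key.json.gz", compression="gzip")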
awswrangler/s3/_fs.py: 64 additions & 44 deletions
@@ -7,6 +7,7 @@
 import math
 import socket
 from contextlib import contextmanager
+from errno import ESPIPE
 from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Set, Tuple, Type, Union, cast
 
 import boto3
@@ -178,9 +179,11 @@ def close(self) -> List[Dict[str, Union[str, int]]]:
         if self.closed is True:
             return []
         if self._exec is not None:
-            for future in concurrent.futures.as_completed(self._futures):
-                self._results.append(future.result())
-            self._exec.shutdown(wait=True)
+            try:
+                for future in concurrent.futures.as_completed(self._futures):
+                    self._results.append(future.result())
+            finally:
+                self._exec.shutdown(wait=True)
         self.closed = True
         return self._sort_by_part_number(parts=self._results)

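The new try/finally matters because `future.result()` re-raises any exception from a worker thread, so the old code could skip `shutdown()` entirely when a part upload failed and leak the executor. A standalone sketch of the pattern (not library API):

    import concurrent.futures

    def collect(executor: concurrent.futures.ThreadPoolExecutor, futures: list) -> list:
        results = []
        try:
            for future in concurrent.futures.as_completed(futures):
                results.append(future.result())  # re-raises any worker exception
        finally:
            executor.shutdown(wait=True)  # always runs, success or failure
        return results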
@@ -198,7 +201,11 @@ def __init__(
         boto3_session: Optional[boto3.Session],
         newline: Optional[str],
         encoding: Optional[str],
+        raw_buffer: bool,
     ) -> None:
+        if raw_buffer is True and "w" not in mode:
+            raise exceptions.InvalidArgumentValue("raw_buffer=True is only acceptable on write mode.")
+        self._raw_buffer: bool = raw_buffer
         self.closed: bool = False
         self._use_threads = use_threads
         self._newline: str = "\n" if newline is None else newline
@@ -242,7 +249,7 @@ def __init__(
         else:
             raise RuntimeError(f"Invalid mode: {self._mode}")
 
-    def __enter__(self) -> Union["_S3ObjectBase", io.TextIOWrapper]:
+    def __enter__(self) -> Union["_S3ObjectBase"]:
         return self
 
     def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None:
@@ -256,6 +263,19 @@ def __del__(self) -> None:
         """Delete object tear down."""
         self.close()
 
+    def __next__(self) -> bytes:
+        """Next line."""
+        out: Union[bytes, None] = self.readline()
+        if not out:
+            raise StopIteration
+        return out
+
+    next = __next__
+
+    def __iter__(self) -> "_S3ObjectBase":
+        """Iterate over lines."""
+        return self
+
     @staticmethod
     def _merge_range(ranges: List[Tuple[int, bytes]]) -> bytes:
         return b"".join(data for start, data in sorted(ranges, key=lambda r: r[0]))
@@ -372,7 +392,7 @@ def tell(self) -> int:
     def seek(self, loc: int, whence: int = 0) -> int:
         """Set current file location."""
         if self.readable() is False:
-            raise ValueError("Seek only available in read mode")
+            raise OSError(ESPIPE, "Seek only available in read mode")
         if whence == 0:
             loc_tmp: int = loc
         elif whence == 1:
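Raising `OSError` with `ESPIPE` (errno 29, "Illegal seek") matches what genuinely unseekable streams raise, so callers that probe seekability with `except OSError` now behave correctly; the old `ValueError` slipped straight through such handlers. A standalone illustration of that contract:

    from errno import ESPIPE

    def is_seekable(f) -> bool:
        try:
            f.seek(0, 1)  # relative no-op seek
            return True
        except OSError as exc:
            if exc.errno == ESPIPE:
                return False  # write-only S3 objects now land here
            raise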
@@ -425,6 +445,9 @@ def flush(self, force: bool = False) -> None:
                         function_name="upload_part", s3_additional_kwargs=self._s3_additional_kwargs
                     ),
                 )
+            self._buffer.seek(0)
+            self._buffer.truncate(0)
+            self._buffer.close()
             self._buffer = io.BytesIO()
         return None

@@ -448,9 +471,9 @@ def close(self) -> None:
             _logger.debug("Closing: %s parts", self._parts_count)
             if self._parts_count > 0:
                 self.flush(force=True)
-                pasts: List[Dict[str, Union[str, int]]] = self._upload_proxy.close()
-                part_info: Dict[str, List[Dict[str, Any]]] = {"Parts": pasts}
-                _logger.debug("complete_multipart_upload")
+                parts: List[Dict[str, Union[str, int]]] = self._upload_proxy.close()
+                part_info: Dict[str, List[Dict[str, Any]]] = {"Parts": parts}
+                _logger.debug("Running complete_multipart_upload...")
                 _utils.try_it(
                     f=self._client.complete_multipart_upload,
                     ex=_S3_RETRYABLE_ERRORS,
@@ -464,7 +487,8 @@ def close(self) -> None:
                         function_name="complete_multipart_upload", s3_additional_kwargs=self._s3_additional_kwargs
                     ),
                 )
-            elif self._buffer.tell() > 0:
+                _logger.debug("complete_multipart_upload done!")
+            elif self._buffer.tell() > 0 or self._raw_buffer is True:
                 _logger.debug("put_object")
                 _utils.try_it(
                     f=self._client.put_object,
@@ -482,43 +506,21 @@ def close(self) -> None:
             self._buffer.seek(0)
             self._buffer.truncate(0)
             self._upload_proxy.close()
+            self._buffer.close()
         elif self.readable():
             self._cache = b""
         else:
             raise RuntimeError(f"Invalid mode: {self._mode}")
         self.closed = True
         return None
 
+    def get_raw_buffer(self) -> io.BytesIO:
+        """Return the Raw Buffer if it is possible."""
+        if self._raw_buffer is False:
+            raise exceptions.InvalidArgumentValue("Trying to get raw buffer with raw_buffer=False.")
+        return self._buffer
+
-class _S3ObjectWriter(_S3ObjectBase):
-    def write(self, data: bytes) -> int:
-        """Write data to buffer and only upload on close() or if buffer is greater than or equal to _MIN_WRITE_BLOCK."""
-        if self.writable() is False:
-            raise RuntimeError("File not in write mode.")
-        if self.closed:
-            raise RuntimeError("I/O operation on closed file.")
-        n: int = self._buffer.write(data)
-        self._loc += n
-        if self._buffer.tell() >= _MIN_WRITE_BLOCK:
-            self.flush()
-        return n
-
-
-class _S3ObjectReader(_S3ObjectBase):
-    def __next__(self) -> Union[bytes, str]:
-        """Next line."""
-        out: Union[bytes, str, None] = self.readline()
-        if not out:
-            raise StopIteration
-        return out
-
-    next = __next__
-
-    def __iter__(self) -> "_S3ObjectReader":
-        """Iterate over lines."""
-        return self
-
-    def read(self, length: int = -1) -> Union[bytes, str]:
+    def read(self, length: int = -1) -> bytes:
         """Return cached data and fetch on demand chunks."""
         if self.readable() is False:
             raise ValueError("File not in read mode.")
@@ -532,7 +534,7 @@ def read(self, length: int = -1) -> Union[bytes, str]:
         self._loc += len(out)
         return out
 
-    def readline(self, length: int = -1) -> Union[bytes, str]:
+    def readline(self, length: int = -1) -> bytes:
         """Read until the next line terminator."""
         end: int = self._loc + self._s3_block_size
         end = self._size if end > self._size else end
@@ -551,11 +553,25 @@ def readline(self, length: int = -1) -> Union[bytes, str]:
             end = self._size if end > self._size else end
             self._fetch(self._loc, end)
 
-    def readlines(self) -> List[Union[bytes, str]]:
+    def readlines(self) -> List[bytes]:
         """Return all lines as list."""
         return list(self)
 
 
+class _S3ObjectWriter(_S3ObjectBase):
+    def write(self, data: bytes) -> int:
+        """Write data to buffer and only upload on close() or if buffer is greater than or equal to _MIN_WRITE_BLOCK."""
+        if self.writable() is False:
+            raise RuntimeError("File not in write mode.")
+        if self.closed:
+            raise RuntimeError("I/O operation on closed file.")
+        n: int = self._buffer.write(data)
+        self._loc += n
+        if self._buffer.tell() >= _MIN_WRITE_BLOCK:
+            self.flush()
+        return n
+
+
 @contextmanager
 @apply_configs
 def open_s3_object(
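`write()` itself is unchanged, just relocated below the reader methods: it accumulates bytes locally and only flushes a multipart part once the buffer reaches `_MIN_WRITE_BLOCK`, presumably S3's 5 MiB minimum part size. A toy illustration of that cadence (the constant's value is an assumption):

    _MIN_WRITE_BLOCK = 5 * 2**20  # assumed: S3's minimum multipart part size (5 MiB)

    buffered, parts = 0, 0
    for _ in range(6 * 1024):      # write 6 MiB in 1 KiB chunks
        buffered += 1024
        if buffered >= _MIN_WRITE_BLOCK:
            parts += 1             # flush() -> one upload_part call
            buffered = 0

    print(parts, buffered)         # 1 part uploaded; ~1 MiB remains for close()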
@@ -567,11 +583,12 @@ def open_s3_object(
     boto3_session: Optional[boto3.Session] = None,
     newline: Optional[str] = "\n",
     encoding: Optional[str] = "utf-8",
-) -> Iterator[Union[_S3ObjectReader, _S3ObjectWriter, io.TextIOWrapper]]:
+    raw_buffer: bool = False,
+) -> Iterator[Union[_S3ObjectBase, _S3ObjectWriter, io.TextIOWrapper, io.BytesIO]]:
     """Return a _S3Object or TextIOWrapper based in the received mode."""
-    s3obj: Optional[Union[_S3ObjectReader, _S3ObjectWriter]] = None
+    s3obj: Optional[Union[_S3ObjectBase, _S3ObjectWriter]] = None
     text_s3obj: Optional[io.TextIOWrapper] = None
-    s3_class: Union[Type[_S3ObjectReader], Type[_S3ObjectWriter]] = _S3ObjectWriter if "w" in mode else _S3ObjectReader
+    s3_class: Union[Type[_S3ObjectBase], Type[_S3ObjectWriter]] = _S3ObjectWriter if "w" in mode else _S3ObjectBase
     try:
         s3obj = s3_class(
             path=path,
@@ -582,8 +599,11 @@ def open_s3_object(
             boto3_session=boto3_session,
             encoding=encoding,
             newline=newline,
+            raw_buffer=raw_buffer,
         )
-        if "b" in mode:  # binary
+        if raw_buffer is True:  # Only useful for plain io.BytesIO write
+            yield s3obj.get_raw_buffer()
+        elif "b" in mode:  # binary
             yield s3obj
         else:  # text
             text_s3obj = io.TextIOWrapper(
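With the new branch, `open_s3_object` now yields one of three shapes depending on the arguments. A hedged summary of the call styles (paths illustrative, remaining parameters assumed to have defaults):

    from awswrangler.s3._fs import open_s3_object

    # Text mode -> io.TextIOWrapper around the S3 object.
    with open_s3_object(path="s3://my-bucket/a.csv", mode="w") as f:
        f.write("id,value\n1,foo\n")

    # Binary mode -> the _S3ObjectWriter/_S3ObjectBase instance itself.
    with open_s3_object(path="s3://my-bucket/a.bin", mode="wb") as f:
        f.write(b"\x00\x01")

    # raw_buffer=True -> a plain io.BytesIO for callers that need a real file.
    with open_s3_object(path="s3://my-bucket/a.bin", mode="wb", raw_buffer=True) as buf:
        buf.write(b"\x00\x01")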
awswrangler/s3/_write.py: 8 additions & 1 deletion
@@ -9,7 +9,14 @@
 
 _logger: logging.Logger = logging.getLogger(__name__)
 
-_COMPRESSION_2_EXT: Dict[Optional[str], str] = {None: "", "gzip": ".gz", "snappy": ".snappy"}
+_COMPRESSION_2_EXT: Dict[Optional[str], str] = {
+    None: "",
+    "gzip": ".gz",
+    "snappy": ".snappy",
+    "bz2": ".bz2",
+    "xz": ".xz",
+    "zip": ".zip",
+}
 
 
 def _extract_dtypes_from_table_input(table_input: Dict[str, Any]) -> Dict[str, str]:
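The extension table is presumably used to append each codec's canonical suffix to generated object keys, so both pandas and downstream readers can infer the codec from the filename. A sketch with a hypothetical helper (`file_path_for` is not library code):

    from typing import Dict, Optional

    _COMPRESSION_2_EXT: Dict[Optional[str], str] = {
        None: "", "gzip": ".gz", "snappy": ".snappy", "bz2": ".bz2", "xz": ".xz", "zip": ".zip"
    }

    def file_path_for(prefix: str, compression: Optional[str]) -> str:
        """Hypothetical helper: append the canonical suffix for the codec."""
        return f"{prefix}.csv{_COMPRESSION_2_EXT[compression]}"

    print(file_path_for("s3://my-bucket/part-0", "gzip"))  # s3://my-bucket/part-0.csv.gz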