# Source listing metadata: 267 lines, 8.8 KiB, Python
import io
|
|
import os
|
|
import typing
|
|
from pathlib import Path
|
|
|
|
from ._types import (
|
|
AsyncByteStream,
|
|
FileContent,
|
|
FileTypes,
|
|
RequestData,
|
|
RequestFiles,
|
|
SyncByteStream,
|
|
)
|
|
from ._utils import (
|
|
format_form_param,
|
|
guess_content_type,
|
|
peek_filelike_length,
|
|
primitive_value_to_str,
|
|
to_bytes,
|
|
)
|
|
|
|
|
|
def get_multipart_boundary_from_content_type(
    content_type: typing.Optional[bytes],
) -> typing.Optional[bytes]:
    """
    Extract the `boundary` parameter from a `multipart/form-data` content type.

    Returns the boundary with any surrounding double quotes removed, or
    `None` when `content_type` is empty, is not `multipart/form-data`, or
    carries no `boundary` parameter.
    """
    # Media type names are case-insensitive, so compare on a lowercased copy.
    # The boundary value itself is case-sensitive and is sliced from the
    # original bytes below.
    if not content_type or not content_type.lower().startswith(
        b"multipart/form-data"
    ):
        return None
    # parse boundary according to
    # https://www.rfc-editor.org/rfc/rfc2046#section-5.1.1
    for section in content_type.split(b";"):
        parameter = section.strip()
        # Parameter names are matched case-insensitively as well.
        if parameter.lower().startswith(b"boundary="):
            return parameter[len(b"boundary=") :].strip(b'"')
    return None
|
|
|
|
|
|
class DataField:
    """
    One plain (non-file) form field inside a multipart form body.
    """

    def __init__(
        self, name: str, value: typing.Union[str, bytes, int, float, None]
    ) -> None:
        # Reject anything but a `str` field name up front.
        if not isinstance(name, str):
            raise TypeError(
                f"Invalid type for name. Expected str, got {type(name)}: {name!r}"
            )

        # `None` and the primitive scalar types are the only accepted values.
        is_primitive = value is None or isinstance(value, (str, bytes, int, float))
        if not is_primitive:
            raise TypeError(
                f"Invalid type for value. Expected primitive type, got {type(value)}: {value!r}"
            )

        self.name = name

        # Bytes are stored untouched; everything else is converted to `str`.
        stored: typing.Union[str, bytes]
        if isinstance(value, bytes):
            stored = value
        else:
            stored = primitive_value_to_str(value)
        self.value = stored

    def render_headers(self) -> bytes:
        """Return the part's header block, building it on first use."""
        try:
            return self._headers
        except AttributeError:
            pass

        name_param = format_form_param("name", self.name)
        self._headers = b"".join(
            (b"Content-Disposition: form-data; ", name_param, b"\r\n\r\n")
        )
        return self._headers

    def render_data(self) -> bytes:
        """Return the field value as bytes, converting it on first use."""
        try:
            return self._data
        except AttributeError:
            pass

        self._data = to_bytes(self.value)
        return self._data

    def get_length(self) -> int:
        """Return the combined size in bytes of the headers and the value."""
        header_size = len(self.render_headers())
        body_size = len(self.render_data())
        return header_size + body_size

    def render(self) -> typing.Iterator[bytes]:
        """Yield the header block followed by the encoded value."""
        yield self.render_headers()
        yield self.render_data()
|
|
|
|
|
|
class FileField:
    """
    A single file field item, within a multipart form field.

    `value` is either the file content itself, or a tuple of
    `(filename, content)`, `(filename, content, content_type)` or
    `(filename, content, content_type, headers)`.
    """

    CHUNK_SIZE = 64 * 1024

    def __init__(self, name: str, value: FileTypes) -> None:
        """
        Unpack `value` into a filename, the file content and the part headers.

        Raises `TypeError` if the content is a text-mode file object.
        """
        self.name = name

        fileobj: FileContent

        headers: typing.Dict[str, str] = {}
        content_type: typing.Optional[str] = None

        # This large tuple based API largely mirrors requests' API
        # It would be good to think of better APIs for this that we could include in httpx 2.0
        # since variable length tuples (especially of 4 elements) are quite unwieldy
        if isinstance(value, tuple):
            if len(value) == 2:
                # neither the 3rd parameter (content_type) nor the 4th (headers) was included
                filename, fileobj = value  # type: ignore
            elif len(value) == 3:
                filename, fileobj, content_type = value  # type: ignore
            else:
                # all 4 parameters included
                filename, fileobj, content_type, caller_headers = value  # type: ignore
                # Work on a private copy: a default Content-Type may be added
                # below, and that must not leak into the caller's mapping.
                # A `None` headers element is treated as "no extra headers".
                headers = dict(caller_headers or {})
        else:
            filename = Path(str(getattr(value, "name", "upload"))).name
            fileobj = value

        if content_type is None:
            content_type = guess_content_type(filename)

        # Match the header name exactly (case-insensitively), so that names
        # merely containing "content-type" (e.g. "X-Content-Type-Options")
        # do not suppress the default Content-Type header.
        has_content_type_header = any(
            key.lower() == "content-type" for key in headers
        )
        if content_type is not None and not has_content_type_header:
            # note that unlike requests, we ignore the content_type
            # provided in the 3rd tuple element if it is also included in the headers
            # requests does the opposite (it overwrites the header with the 3rd tuple element)
            headers["Content-Type"] = content_type

        # `io.StringIO` is checked first so that it gets the more specific message.
        if isinstance(fileobj, io.StringIO):
            raise TypeError(
                "Multipart file uploads require 'io.BytesIO', not 'io.StringIO'."
            )
        if isinstance(fileobj, io.TextIOBase):
            raise TypeError(
                "Multipart file uploads must be opened in binary mode, not text mode."
            )

        self.filename = filename
        self.file = fileobj
        self.headers = headers

    def get_length(self) -> typing.Optional[int]:
        """
        Return the size in bytes of the headers plus the file content, or
        `None` when the content length cannot be determined upfront.
        """
        headers = self.render_headers()

        if isinstance(self.file, (str, bytes)):
            return len(headers) + len(to_bytes(self.file))

        file_length = peek_filelike_length(self.file)

        # If we can't determine the filesize without reading it into memory,
        # then return `None` here, to indicate an unknown file length.
        if file_length is None:
            return None

        return len(headers) + file_length

    def render_headers(self) -> bytes:
        """Return the part's header block, building it on first use."""
        if not hasattr(self, "_headers"):
            parts = [
                b"Content-Disposition: form-data; ",
                format_form_param("name", self.name),
            ]
            # The filename parameter is omitted when the filename is empty.
            if self.filename:
                filename = format_form_param("filename", self.filename)
                parts.extend([b"; ", filename])
            for header_name, header_value in self.headers.items():
                key, val = f"\r\n{header_name}: ".encode(), header_value.encode()
                parts.extend([key, val])
            parts.append(b"\r\n\r\n")
            self._headers = b"".join(parts)

        return self._headers

    def render_data(self) -> typing.Iterator[bytes]:
        """
        Yield the file content as bytes, in chunks of at most `CHUNK_SIZE`
        when reading from a file-like object.
        """
        if isinstance(self.file, (str, bytes)):
            yield to_bytes(self.file)
            return

        # Rewind to the start before reading; streams that do not support
        # seeking are read from their current position instead.
        if hasattr(self.file, "seek"):
            try:
                self.file.seek(0)
            except io.UnsupportedOperation:
                pass

        chunk = self.file.read(self.CHUNK_SIZE)
        while chunk:
            yield to_bytes(chunk)
            chunk = self.file.read(self.CHUNK_SIZE)

    def render(self) -> typing.Iterator[bytes]:
        """Yield the header block followed by the file content chunks."""
        yield self.render_headers()
        yield from self.render_data()
|
|
|
|
|
|
class MultipartStream(SyncByteStream, AsyncByteStream):
    """
    Streaming request body encoded as multipart form data.
    """

    def __init__(
        self,
        data: RequestData,
        files: RequestFiles,
        boundary: typing.Optional[bytes] = None,
    ) -> None:
        # Generate a random hex boundary when the caller did not supply one.
        if boundary is None:
            boundary = os.urandom(16).hex().encode("ascii")

        self.boundary = boundary
        boundary_text = boundary.decode("ascii")
        self.content_type = "multipart/form-data; boundary=%s" % boundary_text
        self.fields = list(self._iter_fields(data, files))

    def _iter_fields(
        self, data: RequestData, files: RequestFiles
    ) -> typing.Iterator[typing.Union[FileField, DataField]]:
        """Yield a field object for every data item, then for every file item."""
        for field_name, field_value in data.items():
            # A tuple or list value expands into one field per element.
            if isinstance(field_value, (tuple, list)):
                items = field_value
            else:
                items = [field_value]
            for item in items:
                yield DataField(name=field_name, value=item)

        # `files` is either a mapping or an iterable of (name, value) pairs.
        if isinstance(files, typing.Mapping):
            file_pairs = files.items()
        else:
            file_pairs = files
        for file_name, file_value in file_pairs:
            yield FileField(name=file_name, value=file_value)

    def iter_chunks(self) -> typing.Iterator[bytes]:
        """Yield the encoded body: each delimited part, then the closing delimiter."""
        opening_delimiter = b"--%s\r\n" % self.boundary
        for part in self.fields:
            yield opening_delimiter
            yield from part.render()
            yield b"\r\n"
        yield b"--%s--\r\n" % self.boundary

    def get_content_length(self) -> typing.Optional[int]:
        """
        Return the total size of the encoded body in bytes, or `None` as
        soon as one field reports a length that is not known upfront.
        """
        boundary_size = len(self.boundary)
        # b"--{boundary}\r\n" before each part, plus b"\r\n" after it.
        per_part_overhead = (2 + boundary_size + 2) + 2
        # Start from the closing b"--{boundary}--\r\n" delimiter.
        total = 2 + boundary_size + 4

        for part in self.fields:
            part_size = part.get_length()
            if part_size is None:
                return None
            total += per_part_overhead + part_size

        return total

    # Content stream interface.

    def get_headers(self) -> typing.Dict[str, str]:
        """
        Return the headers describing this body: `Content-Length` when the
        size is known, otherwise `Transfer-Encoding: chunked`, followed by
        `Content-Type` in both cases.
        """
        known_length = self.get_content_length()
        if known_length is not None:
            return {
                "Content-Length": str(known_length),
                "Content-Type": self.content_type,
            }
        return {"Transfer-Encoding": "chunked", "Content-Type": self.content_type}

    def __iter__(self) -> typing.Iterator[bytes]:
        yield from self.iter_chunks()

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        # The chunks come from the same synchronous generator as `__iter__`.
        for piece in self.iter_chunks():
            yield piece
|