# Copied from: https://github.com/encode/httpx/blob/master/httpx/_models.py,
# which is licensed under the BSD License.
# See https://github.com/encode/httpx/blob/master/LICENSE.md
import typing

from collections.abc import Mapping


# Every shape a caller may supply headers in: an existing `Headers` instance,
# a str/bytes mapping, a sequence of str/bytes two-tuples, or a sequence of
# raw "Key: value" header lines.
HeaderTypes = typing.Union[
    "Headers",
    typing.Mapping[str, str],
    typing.Mapping[bytes, bytes],
    typing.Sequence[typing.Tuple[str, str]],
    typing.Sequence[typing.Tuple[bytes, bytes]],
    typing.Sequence[str],
    typing.Sequence[bytes],
]
def to_str(value: typing.Union[str, bytes], encoding: str = "utf-8") -> str:
    """Decode *value* with *encoding* if it is bytes; pass str through as-is."""
    if isinstance(value, bytes):
        return value.decode(encoding)
    return value
def to_bytes_or_str(value: str, match_type_of: typing.AnyStr) -> typing.AnyStr:
    """Coerce *value* to the same type (str or bytes) as *match_type_of*."""
    if isinstance(match_type_of, str):
        return value  # type: ignore
    return value.encode()  # type: ignore
# Lower-cased header names whose values must never appear in reprs or logs.
SENSITIVE_HEADERS = {
    "authorization",
    "proxy-authorization",
}
def obfuscate_sensitive_headers(
    items: typing.Iterable[typing.Tuple[typing.AnyStr, typing.AnyStr]]
) -> typing.Iterator[typing.Tuple[typing.AnyStr, typing.AnyStr]]:
    """
    Yield ``(key, value)`` pairs, replacing the value of any sensitive
    header (e.g. authorization) with a "[secure]" placeholder that matches
    the value's str/bytes type.
    """
    for key, value in items:
        if to_str(key.lower()) in SENSITIVE_HEADERS:
            value = to_bytes_or_str("[secure]", match_type_of=value)
        yield key, value
def normalize_header_key(
    value: typing.Union[str, bytes],
    lower: bool,
    encoding: typing.Optional[str] = None,
) -> bytes:
    """
    Coerce str/bytes into a strictly byte-wise HTTP header key.

    A str key is encoded using *encoding*, defaulting to ascii; when
    *lower* is true the result is lower-cased for case-insensitive lookup.
    """
    if isinstance(value, str):
        bytes_value = value.encode(encoding or "ascii")
    else:
        bytes_value = value

    if lower:
        return bytes_value.lower()
    return bytes_value
def normalize_header_value(
    value: typing.Union[str, bytes], encoding: typing.Optional[str] = None
) -> bytes:
    """
    Coerce str/bytes into a strictly byte-wise HTTP header value.

    A str value is encoded using *encoding*, defaulting to ascii.
    """
    return value if isinstance(value, bytes) else value.encode(encoding or "ascii")
class Headers(typing.MutableMapping[str, str]):
    """
    HTTP headers, as a case-insensitive multi-dict.

    Each header is stored internally as a ``(raw_key, lookup_key, value)``
    triple of bytes, where ``lookup_key`` is the lower-cased form of
    ``raw_key`` used for case-insensitive matching.
    """

    def __init__(
        self,
        headers: typing.Optional[HeaderTypes] = None,
        encoding: typing.Optional[str] = None,
    ) -> None:
        """
        Accept headers as another `Headers` instance, a mapping, a sequence
        of ``(key, value)`` pairs, or a sequence of raw "Key: value" lines.

        If *encoding* is None it is auto-detected lazily by the `encoding`
        property on first decode.
        """
        if headers is None:
            self._list = []  # type: typing.List[typing.Tuple[bytes, bytes, bytes]]
        elif isinstance(headers, Headers):
            self._list = list(headers._list)
        elif isinstance(headers, Mapping):
            self._list = [
                (
                    normalize_header_key(k, lower=False, encoding=encoding),
                    normalize_header_key(k, lower=True, encoding=encoding),
                    normalize_header_value(v, encoding),
                )
                for k, v in headers.items()
            ]
        else:
            # `headers` is a sequence: either raw "Key: value" lines or
            # (key, value) pairs. Guard against an empty sequence before
            # peeking at headers[0] (previously raised IndexError on
            # `Headers([])`).
            if headers and isinstance(headers[0], (str, bytes)):
                sep = ":" if isinstance(headers[0], str) else b":"
                h = []
                for line in headers:
                    k, v = line.split(sep, maxsplit=1)  # type: ignore
                    h.append((k, v.lstrip()))
            else:
                h = headers

            self._list = [
                (
                    normalize_header_key(k, lower=False, encoding=encoding),  # type: ignore
                    normalize_header_key(k, lower=True, encoding=encoding),  # type: ignore
                    normalize_header_value(v, encoding),  # type: ignore
                )
                for k, v in h
            ]

        self._encoding = encoding

    @property
    def encoding(self) -> str:
        """
        Header encoding is mandated as ascii, but we allow fallbacks to utf-8
        or iso-8859-1. The detected encoding is cached on first access.
        """
        if self._encoding is None:
            for encoding in ["ascii", "utf-8"]:
                for key, value in self.raw:
                    try:
                        key.decode(encoding)
                        value.decode(encoding)
                    except UnicodeDecodeError:
                        break
                else:
                    # The else block runs if 'break' did not occur, meaning
                    # all values fitted the encoding.
                    self._encoding = encoding
                    break
            else:
                # The ISO-8859-1 encoding covers all 256 code points in a byte,
                # so will never raise decode errors.
                self._encoding = "iso-8859-1"
        return self._encoding

    @encoding.setter
    def encoding(self, value: str) -> None:
        self._encoding = value

    @property
    def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
        """
        Returns a list of the raw header items, as byte pairs.
        """
        return [(raw_key, value) for raw_key, _, value in self._list]

    def keys(self) -> typing.KeysView[str]:
        # A dict preserves first-seen order while deduplicating repeated keys.
        return {key.decode(self.encoding): None for _, key, _ in self._list}.keys()

    def _joined_by_key(self) -> typing.Dict[str, str]:
        """
        Build an ordered mapping of decoded lookup key to value, joining
        repeated headers into a single comma separated value.

        Shared by `values()` and `items()`.
        """
        joined: typing.Dict[str, str] = {}
        for _, key, value in self._list:
            str_key = key.decode(self.encoding)
            str_value = value.decode(self.encoding)
            if str_key in joined:
                joined[str_key] += f", {str_value}"
            else:
                joined[str_key] = str_value
        return joined

    def values(self) -> typing.ValuesView[str]:
        return self._joined_by_key().values()

    def items(self) -> typing.ItemsView[str, str]:
        """
        Return `(key, value)` items of headers. Concatenate headers
        into a single comma separated value when a key occurs multiple times.
        """
        return self._joined_by_key().items()

    def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
        """
        Return a list of `(key, value)` pairs of headers. Allow multiple
        occurrences of the same key without concatenating into a single
        comma separated value.
        """
        return [
            (key.decode(self.encoding), value.decode(self.encoding))
            for _, key, value in self._list
        ]

    def get(self, key: str, default: typing.Any = None) -> typing.Any:
        """
        Return a header value. If multiple occurrences of the header occur
        then concatenate them together with commas.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def get_list(self, key: str, split_commas: bool = False) -> typing.List[str]:
        """
        Return a list of all header values for a given key.

        If `split_commas=True` is passed, then any comma separated header
        values are split into multiple return strings.
        """
        get_header_key = key.lower().encode(self.encoding)

        # Lookup keys are stored lower-cased, so they can be compared directly.
        values = [
            item_value.decode(self.encoding)
            for _, item_key, item_value in self._list
            if item_key == get_header_key
        ]

        if not split_commas:
            return values

        split_values = []
        for value in values:
            split_values.extend([item.strip() for item in value.split(",")])
        return split_values

    def update(self, headers: typing.Optional[HeaderTypes] = None) -> None:  # type: ignore
        """
        Merge *headers* into this instance, replacing any existing entries
        that share a key with the incoming headers.
        """
        headers = Headers(headers)
        for key in headers.keys():
            if key in self:
                self.pop(key)
        self._list.extend(headers._list)

    def copy(self) -> "Headers":
        return Headers(self, encoding=self.encoding)

    def __getitem__(self, key: str) -> str:
        """
        Return a single header value.

        If there are multiple headers with the same key, then we concatenate
        them with commas. See: https://tools.ietf.org/html/rfc7230#section-3.2.2
        """
        normalized_key = key.lower().encode(self.encoding)

        items = [
            header_value.decode(self.encoding)
            for _, header_key, header_value in self._list
            if header_key == normalized_key
        ]

        if items:
            return ", ".join(items)

        raise KeyError(key)

    def __setitem__(self, key: str, value: str) -> None:
        """
        Set the header `key` to `value`, removing any duplicate entries.
        Retains insertion order.
        """
        # NOTE: uses `self._encoding or "utf-8"` rather than the `encoding`
        # property, so a plain assignment never triggers auto-detection.
        set_key = key.encode(self._encoding or "utf-8")
        set_value = value.encode(self._encoding or "utf-8")
        lookup_key = set_key.lower()

        found_indexes = [
            idx
            for idx, (_, item_key, _) in enumerate(self._list)
            if item_key == lookup_key
        ]

        # Drop all but the first occurrence; iterate in reverse so earlier
        # indexes remain valid while deleting.
        for idx in reversed(found_indexes[1:]):
            del self._list[idx]

        if found_indexes:
            idx = found_indexes[0]
            self._list[idx] = (set_key, lookup_key, set_value)
        else:
            self._list.append((set_key, lookup_key, set_value))

    def __delitem__(self, key: str) -> None:
        """
        Remove the header `key`, including every duplicate occurrence.
        """
        del_key = key.lower().encode(self.encoding)

        pop_indexes = [
            idx
            for idx, (_, item_key, _) in enumerate(self._list)
            if item_key == del_key
        ]

        if not pop_indexes:
            raise KeyError(key)

        for idx in reversed(pop_indexes):
            del self._list[idx]

    def __contains__(self, key: typing.Any) -> bool:
        header_key = key.lower().encode(self.encoding)
        # Short-circuit instead of materializing a list of all keys.
        return any(item_key == header_key for _, item_key, _ in self._list)

    def __iter__(self) -> typing.Iterator[typing.Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._list)

    def __eq__(self, other: typing.Any) -> bool:
        try:
            other_headers = Headers(other)
        except ValueError:
            return False

        # Compare on (lookup_key, value) so raw-key casing is ignored and
        # ordering does not matter.
        self_list = [(key, value) for _, key, value in self._list]
        other_list = [(key, value) for _, key, value in other_headers._list]
        return sorted(self_list) == sorted(other_list)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__

        encoding_str = ""
        if self.encoding != "ascii":
            encoding_str = f", encoding={self.encoding!r}"

        # Mask sensitive values so the repr is safe to log.
        as_list = list(obfuscate_sensitive_headers(self.multi_items()))
        as_dict = dict(as_list)

        no_duplicate_keys = len(as_dict) == len(as_list)
        if no_duplicate_keys:
            return f"{class_name}({as_dict!r}{encoding_str})"
        return f"{class_name}({as_list!r}{encoding_str})"