# Copied from: https://github.com/encode/httpx/blob/master/httpx/_models.py,
# which is licensed under the BSD License.
# See https://github.com/encode/httpx/blob/master/LICENSE.md
import typing
from collections.abc import Mapping

HeaderTypes = typing.Union[
    "Headers",
    typing.Mapping[str, str],
    typing.Mapping[bytes, bytes],
    typing.Sequence[typing.Tuple[str, str]],
    typing.Sequence[typing.Tuple[bytes, bytes]],
    typing.Sequence[str],
    typing.Sequence[bytes],
]


def to_str(value: typing.Union[str, bytes], encoding: str = "utf-8") -> str:
    return value if isinstance(value, str) else value.decode(encoding)


def to_bytes_or_str(value: str, match_type_of: typing.AnyStr) -> typing.AnyStr:
    return value if isinstance(match_type_of, str) else value.encode()  # type: ignore


SENSITIVE_HEADERS = {"authorization", "proxy-authorization"}


def obfuscate_sensitive_headers(
    items: typing.Iterable[typing.Tuple[typing.AnyStr, typing.AnyStr]]
) -> typing.Iterator[typing.Tuple[typing.AnyStr, typing.AnyStr]]:
    for k, v in items:
        if to_str(k.lower()) in SENSITIVE_HEADERS:
            v = to_bytes_or_str("[secure]", match_type_of=v)
        yield k, v


def normalize_header_key(
    value: typing.Union[str, bytes],
    lower: bool,
    encoding: typing.Optional[str] = None,
) -> bytes:
    """
    Coerce str/bytes into a strictly byte-wise HTTP header key.
    """
    if isinstance(value, bytes):
        bytes_value = value
    else:
        bytes_value = value.encode(encoding or "ascii")

    return bytes_value.lower() if lower else bytes_value


def normalize_header_value(
    value: typing.Union[str, bytes], encoding: typing.Optional[str] = None
) -> bytes:
    """
    Coerce str/bytes into a strictly byte-wise HTTP header value.
    """
    if isinstance(value, bytes):
        return value
    return value.encode(encoding or "ascii")


class Headers(typing.MutableMapping[str, str]):
    """
    HTTP headers, as a case-insensitive multi-dict.
    """

    def __init__(
        self,
        headers: typing.Optional[HeaderTypes] = None,
        encoding: typing.Optional[str] = None,
    ) -> None:
        if headers is None:
            self._list = []  # type: typing.List[typing.Tuple[bytes, bytes, bytes]]
        elif isinstance(headers, Headers):
            self._list = list(headers._list)
        elif isinstance(headers, Mapping):
            self._list = [
                (
                    normalize_header_key(k, lower=False, encoding=encoding),
                    normalize_header_key(k, lower=True, encoding=encoding),
                    normalize_header_value(v, encoding),
                )
                for k, v in headers.items()
            ]
        else:
            # Guard against empty sequences before peeking at headers[0],
            # which would otherwise raise IndexError.
            if headers and isinstance(headers[0], (str, bytes)):
                # A sequence of raw "Key: value" lines: split each line on the
                # first colon and strip leading whitespace from the value.
                sep = ":" if isinstance(headers[0], str) else b":"
                h = []
                for line in headers:
                    k, v = line.split(sep, maxsplit=1)  # type: ignore
                    v = v.lstrip()
                    h.append((k, v))
            else:
                # A sequence of (key, value) pairs.
                h = headers
            self._list = [
                (
                    normalize_header_key(k, lower=False, encoding=encoding),  # type: ignore
                    normalize_header_key(k, lower=True, encoding=encoding),  # type: ignore
                    normalize_header_value(v, encoding),  # type: ignore
                )
                for k, v in h
            ]

        self._encoding = encoding

    @property
    def encoding(self) -> str:
        """
        Header encoding is mandated as ascii, but we allow fallbacks to utf-8
        or iso-8859-1.
        """
        if self._encoding is None:
            for encoding in ["ascii", "utf-8"]:
                for key, value in self.raw:
                    try:
                        key.decode(encoding)
                        value.decode(encoding)
                    except UnicodeDecodeError:
                        break
                else:
                    # The else block runs if 'break' did not occur, meaning
                    # all values fitted the encoding.
                    self._encoding = encoding
                    break
            else:
                # The ISO-8859-1 encoding covers all 256 code points in a byte,
                # so will never raise decode errors.
self._encoding = "iso-8859-1" return self._encoding @encoding.setter def encoding(self, value: str) -> None: self._encoding = value @property def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]: """ Returns a list of the raw header items, as byte pairs. """ return [(raw_key, value) for raw_key, _, value in self._list] def keys(self) -> typing.KeysView[str]: return {key.decode(self.encoding): None for _, key, _ in self._list}.keys() def values(self) -> typing.ValuesView[str]: values_dict: typing.Dict[str, str] = {} for _, key, value in self._list: str_key = key.decode(self.encoding) str_value = value.decode(self.encoding) if str_key in values_dict: values_dict[str_key] += f", {str_value}" else: values_dict[str_key] = str_value return values_dict.values() def items(self) -> typing.ItemsView[str, str]: """ Return `(key, value)` items of headers. Concatenate headers into a single comma separated value when a key occurs multiple times. """ values_dict: typing.Dict[str, str] = {} for _, key, value in self._list: str_key = key.decode(self.encoding) str_value = value.decode(self.encoding) if str_key in values_dict: values_dict[str_key] += f", {str_value}" else: values_dict[str_key] = str_value return values_dict.items() def multi_items(self) -> typing.List[typing.Tuple[str, str]]: """ Return a list of `(key, value)` pairs of headers. Allow multiple occurrences of the same key without concatenating into a single comma separated value. """ return [ (key.decode(self.encoding), value.decode(self.encoding)) for _, key, value in self._list ] def get(self, key: str, default: typing.Any = None) -> typing.Any: """ Return a header value. If multiple occurrences of the header occur then concatenate them together with commas. """ try: return self[key] except KeyError: return default def get_list(self, key: str, split_commas: bool = False) -> typing.List[str]: """ Return a list of all header values for a given key. If `split_commas=True` is passed, then any comma separated header values are split into multiple return strings. """ get_header_key = key.lower().encode(self.encoding) values = [ item_value.decode(self.encoding) for _, item_key, item_value in self._list if item_key.lower() == get_header_key ] if not split_commas: return values split_values = [] for value in values: split_values.extend([item.strip() for item in value.split(",")]) return split_values def update(self, headers: typing.Optional[HeaderTypes] = None) -> None: # type: ignore headers = Headers(headers) for key in headers.keys(): if key in self: self.pop(key) self._list.extend(headers._list) def copy(self) -> "Headers": return Headers(self, encoding=self.encoding) def __getitem__(self, key: str) -> str: """ Return a single header value. If there are multiple headers with the same key, then we concatenate them with commas. See: https://tools.ietf.org/html/rfc7230#section-3.2.2 """ normalized_key = key.lower().encode(self.encoding) items = [ header_value.decode(self.encoding) for _, header_key, header_value in self._list if header_key == normalized_key ] if items: return ", ".join(items) raise KeyError(key) def __setitem__(self, key: str, value: str) -> None: """ Set the header `key` to `value`, removing any duplicate entries. Retains insertion order. 
""" set_key = key.encode(self._encoding or "utf-8") set_value = value.encode(self._encoding or "utf-8") lookup_key = set_key.lower() found_indexes = [ idx for idx, (_, item_key, _) in enumerate(self._list) if item_key == lookup_key ] for idx in reversed(found_indexes[1:]): del self._list[idx] if found_indexes: idx = found_indexes[0] self._list[idx] = (set_key, lookup_key, set_value) else: self._list.append((set_key, lookup_key, set_value)) def __delitem__(self, key: str) -> None: """ Remove the header `key`. """ del_key = key.lower().encode(self.encoding) pop_indexes = [ idx for idx, (_, item_key, _) in enumerate(self._list) if item_key.lower() == del_key ] if not pop_indexes: raise KeyError(key) for idx in reversed(pop_indexes): del self._list[idx] def __contains__(self, key: typing.Any) -> bool: header_key = key.lower().encode(self.encoding) return header_key in [key for _, key, _ in self._list] def __iter__(self) -> typing.Iterator[typing.Any]: return iter(self.keys()) def __len__(self) -> int: return len(self._list) def __eq__(self, other: typing.Any) -> bool: try: other_headers = Headers(other) except ValueError: return False self_list = [(key, value) for _, key, value in self._list] other_list = [(key, value) for _, key, value in other_headers._list] return sorted(self_list) == sorted(other_list) def __repr__(self) -> str: class_name = self.__class__.__name__ encoding_str = "" if self.encoding != "ascii": encoding_str = f", encoding={self.encoding!r}" as_list = list(obfuscate_sensitive_headers(self.multi_items())) as_dict = dict(as_list) no_duplicate_keys = len(as_dict) == len(as_list) if no_duplicate_keys: return f"{class_name}({as_dict!r}{encoding_str})" return f"{class_name}({as_list!r}{encoding_str})"