diff --git a/httpcore/_async/connection.py b/httpcore/_async/connection.py index 2f713d49d..896479665 100644 --- a/httpcore/_async/connection.py +++ b/httpcore/_async/connection.py @@ -1,5 +1,5 @@ from ssl import SSLContext -from typing import Optional, Tuple, cast +from typing import List, Optional, Tuple, cast from .._backends.auto import AsyncBackend, AsyncLock, AsyncSocketStream, AutoBackend from .._exceptions import ConnectError, ConnectTimeout @@ -23,6 +23,7 @@ class AsyncHTTPConnection(AsyncHTTPTransport): def __init__( self, origin: Origin, + http1: bool = True, http2: bool = False, uds: str = None, ssl_context: SSLContext = None, @@ -32,6 +33,7 @@ def __init__( backend: AsyncBackend = None, ): self.origin = origin + self.http1 = http1 self.http2 = http2 self.uds = uds self.ssl_context = SSLContext() if ssl_context is None else ssl_context @@ -39,8 +41,13 @@ def __init__( self.local_address = local_address self.retries = retries - if self.http2: - self.ssl_context.set_alpn_protocols(["http/1.1", "h2"]) + alpn_protocols: List[str] = [] + if http1: + alpn_protocols.append("http/1.1") + if http2: + alpn_protocols.append("h2") + + self.ssl_context.set_alpn_protocols(alpn_protocols) self.connection: Optional[AsyncBaseHTTPConnection] = None self.is_http11 = False @@ -147,7 +154,7 @@ def _create_connection(self, socket: AsyncSocketStream) -> None: logger.trace( "create_connection socket=%r http_version=%r", socket, http_version ) - if http_version == "HTTP/2": + if http_version == "HTTP/2" or (self.http2 and not self.http1): from .http2 import AsyncHTTP2Connection self.is_http2 = True diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index 650abfc47..089271344 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -89,8 +89,10 @@ class AsyncConnectionPool(AsyncHTTPTransport): connections. keepalive_expiry: The maximum time to allow before closing a keep-alive connection. + http1: + Enable/Disable HTTP/1.1 support. Defaults to True. http2: - Enable HTTP/2 support. + Enable/Disable HTTP/2 support. Defaults to False. uds: Path to a Unix Domain Socket to use instead of TCP sockets. local_address: @@ -110,6 +112,7 @@ def __init__( max_connections: int = None, max_keepalive_connections: int = None, keepalive_expiry: float = None, + http1: bool = True, http2: bool = False, uds: str = None, local_address: str = None, @@ -131,6 +134,7 @@ def __init__( self._max_connections = max_connections self._max_keepalive_connections = max_keepalive_connections self._keepalive_expiry = keepalive_expiry + self._http1 = http1 self._http2 = http2 self._uds = uds self._local_address = local_address @@ -140,6 +144,9 @@ def __init__( self._backend = backend self._next_keepalive_check = 0.0 + if not (http1 or http2): + raise ValueError("Either http1 or http2 must be True.") + if http2: try: import h2 # noqa: F401 @@ -175,6 +182,7 @@ def _create_connection( ) -> AsyncHTTPConnection: return AsyncHTTPConnection( origin=origin, + http1=self._http1, http2=self._http2, uds=self._uds, ssl_context=self._ssl_context, diff --git a/httpcore/_sync/connection.py b/httpcore/_sync/connection.py index 26e69e0ac..a49a13d4c 100644 --- a/httpcore/_sync/connection.py +++ b/httpcore/_sync/connection.py @@ -1,5 +1,5 @@ from ssl import SSLContext -from typing import Optional, Tuple, cast +from typing import List, Optional, Tuple, cast from .._backends.sync import SyncBackend, SyncLock, SyncSocketStream, SyncBackend from .._exceptions import ConnectError, ConnectTimeout @@ -23,6 +23,7 @@ class SyncHTTPConnection(SyncHTTPTransport): def __init__( self, origin: Origin, + http1: bool = True, http2: bool = False, uds: str = None, ssl_context: SSLContext = None, @@ -32,6 +33,7 @@ def __init__( backend: SyncBackend = None, ): self.origin = origin + self.http1 = http1 self.http2 = http2 self.uds = uds self.ssl_context = SSLContext() if ssl_context is None else ssl_context @@ -39,8 +41,13 @@ def __init__( self.local_address = local_address self.retries = retries - if self.http2: - self.ssl_context.set_alpn_protocols(["http/1.1", "h2"]) + alpn_protocols: List[str] = [] + if http1: + alpn_protocols.append("http/1.1") + if http2: + alpn_protocols.append("h2") + + self.ssl_context.set_alpn_protocols(alpn_protocols) self.connection: Optional[SyncBaseHTTPConnection] = None self.is_http11 = False @@ -147,7 +154,7 @@ def _create_connection(self, socket: SyncSocketStream) -> None: logger.trace( "create_connection socket=%r http_version=%r", socket, http_version ) - if http_version == "HTTP/2": + if http_version == "HTTP/2" or (self.http2 and not self.http1): from .http2 import SyncHTTP2Connection self.is_http2 = True diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index 7d2ba9eb4..a60437b64 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -89,8 +89,10 @@ class SyncConnectionPool(SyncHTTPTransport): connections. keepalive_expiry: The maximum time to allow before closing a keep-alive connection. + http1: + Enable/Disable HTTP/1.1 support. Defaults to True. http2: - Enable HTTP/2 support. + Enable/Disable HTTP/2 support. Defaults to False. uds: Path to a Unix Domain Socket to use instead of TCP sockets. local_address: @@ -110,6 +112,7 @@ def __init__( max_connections: int = None, max_keepalive_connections: int = None, keepalive_expiry: float = None, + http1: bool = True, http2: bool = False, uds: str = None, local_address: str = None, @@ -131,6 +134,7 @@ def __init__( self._max_connections = max_connections self._max_keepalive_connections = max_keepalive_connections self._keepalive_expiry = keepalive_expiry + self._http1 = http1 self._http2 = http2 self._uds = uds self._local_address = local_address @@ -140,6 +144,9 @@ def __init__( self._backend = backend self._next_keepalive_check = 0.0 + if not (http1 or http2): + raise ValueError("Either http1 or http2 must be True.") + if http2: try: import h2 # noqa: F401 @@ -175,6 +182,7 @@ def _create_connection( ) -> SyncHTTPConnection: return SyncHTTPConnection( origin=origin, + http1=self._http1, http2=self._http2, uds=self._uds, ssl_context=self._ssl_context, diff --git a/tests/async_tests/test_interfaces.py b/tests/async_tests/test_interfaces.py index 36b663946..b17f266b4 100644 --- a/tests/async_tests/test_interfaces.py +++ b/tests/async_tests/test_interfaces.py @@ -23,6 +23,11 @@ async def read_body(stream: httpcore.AsyncByteStream) -> bytes: await stream.aclose() +def test_must_configure_either_http1_or_http2() -> None: + with pytest.raises(ValueError): + httpcore.AsyncConnectionPool(http1=False, http2=False) + + @pytest.mark.anyio async def test_http_request(backend: str, server: Server) -> None: async with httpcore.AsyncConnectionPool(backend=backend) as http: diff --git a/tests/sync_tests/test_interfaces.py b/tests/sync_tests/test_interfaces.py index 270af4234..767357d08 100644 --- a/tests/sync_tests/test_interfaces.py +++ b/tests/sync_tests/test_interfaces.py @@ -23,6 +23,11 @@ def read_body(stream: httpcore.SyncByteStream) -> bytes: stream.close() +def test_must_configure_either_http1_or_http2() -> None: + with pytest.raises(ValueError): + httpcore.SyncConnectionPool(http1=False, http2=False) + + def test_http_request(backend: str, server: Server) -> None: with httpcore.SyncConnectionPool(backend=backend) as http: