Skip to content
Merged
15 changes: 11 additions & 4 deletions httpcore/_async/connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ssl import SSLContext
from typing import Optional, Tuple, cast
from typing import List, Optional, Tuple, cast

from .._backends.auto import AsyncBackend, AsyncLock, AsyncSocketStream, AutoBackend
from .._exceptions import ConnectError, ConnectTimeout
Expand All @@ -23,6 +23,7 @@ class AsyncHTTPConnection(AsyncHTTPTransport):
def __init__(
self,
origin: Origin,
http1: bool = True,
http2: bool = False,
uds: str = None,
ssl_context: SSLContext = None,
Expand All @@ -32,15 +33,21 @@ def __init__(
backend: AsyncBackend = None,
):
self.origin = origin
self.http1 = http1
self.http2 = http2
self.uds = uds
self.ssl_context = SSLContext() if ssl_context is None else ssl_context
self.socket = socket
self.local_address = local_address
self.retries = retries

if self.http2:
self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])
alpn_protocols: List[str] = []
if http1:
alpn_protocols.append("http/1.1")
if http2:
alpn_protocols.append("h2")

self.ssl_context.set_alpn_protocols(alpn_protocols)

self.connection: Optional[AsyncBaseHTTPConnection] = None
self.is_http11 = False
Expand Down Expand Up @@ -147,7 +154,7 @@ def _create_connection(self, socket: AsyncSocketStream) -> None:
logger.trace(
"create_connection socket=%r http_version=%r", socket, http_version
)
if http_version == "HTTP/2":
if http_version == "HTTP/2" or (self.http2 and not self.http1):
from .http2 import AsyncHTTP2Connection

self.is_http2 = True
Expand Down
10 changes: 9 additions & 1 deletion httpcore/_async/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,10 @@ class AsyncConnectionPool(AsyncHTTPTransport):
connections.
keepalive_expiry:
The maximum time to allow before closing a keep-alive connection.
http1:
Enable/Disable HTTP/1.1 support. Defaults to True.
http2:
Enable HTTP/2 support.
Enable/Disable HTTP/2 support. Defaults to False.
uds:
Path to a Unix Domain Socket to use instead of TCP sockets.
local_address:
Expand All @@ -110,6 +112,7 @@ def __init__(
max_connections: int = None,
max_keepalive_connections: int = None,
keepalive_expiry: float = None,
http1: bool = True,
http2: bool = False,
uds: str = None,
local_address: str = None,
Expand All @@ -131,6 +134,7 @@ def __init__(
self._max_connections = max_connections
self._max_keepalive_connections = max_keepalive_connections
self._keepalive_expiry = keepalive_expiry
self._http1 = http1
self._http2 = http2
self._uds = uds
self._local_address = local_address
Expand All @@ -140,6 +144,9 @@ def __init__(
self._backend = backend
self._next_keepalive_check = 0.0

if not (http1 or http2):
raise ValueError("Either http1 or http2 must be True.")

if http2:
try:
import h2 # noqa: F401
Expand Down Expand Up @@ -175,6 +182,7 @@ def _create_connection(
) -> AsyncHTTPConnection:
return AsyncHTTPConnection(
origin=origin,
http1=self._http1,
http2=self._http2,
uds=self._uds,
ssl_context=self._ssl_context,
Expand Down
15 changes: 11 additions & 4 deletions httpcore/_sync/connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ssl import SSLContext
from typing import Optional, Tuple, cast
from typing import List, Optional, Tuple, cast

from .._backends.sync import SyncBackend, SyncLock, SyncSocketStream, SyncBackend
from .._exceptions import ConnectError, ConnectTimeout
Expand All @@ -23,6 +23,7 @@ class SyncHTTPConnection(SyncHTTPTransport):
def __init__(
self,
origin: Origin,
http1: bool = True,
http2: bool = False,
uds: str = None,
ssl_context: SSLContext = None,
Expand All @@ -32,15 +33,21 @@ def __init__(
backend: SyncBackend = None,
):
self.origin = origin
self.http1 = http1
self.http2 = http2
self.uds = uds
self.ssl_context = SSLContext() if ssl_context is None else ssl_context
self.socket = socket
self.local_address = local_address
self.retries = retries

if self.http2:
self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])
alpn_protocols: List[str] = []
if http1:
alpn_protocols.append("http/1.1")
if http2:
alpn_protocols.append("h2")

self.ssl_context.set_alpn_protocols(alpn_protocols)

self.connection: Optional[SyncBaseHTTPConnection] = None
self.is_http11 = False
Expand Down Expand Up @@ -147,7 +154,7 @@ def _create_connection(self, socket: SyncSocketStream) -> None:
logger.trace(
"create_connection socket=%r http_version=%r", socket, http_version
)
if http_version == "HTTP/2":
if http_version == "HTTP/2" or (self.http2 and not self.http1):
from .http2 import SyncHTTP2Connection

self.is_http2 = True
Expand Down
10 changes: 9 additions & 1 deletion httpcore/_sync/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,10 @@ class SyncConnectionPool(SyncHTTPTransport):
connections.
keepalive_expiry:
The maximum time to allow before closing a keep-alive connection.
http1:
Enable/Disable HTTP/1.1 support. Defaults to True.
http2:
Enable HTTP/2 support.
Enable/Disable HTTP/2 support. Defaults to False.
uds:
Path to a Unix Domain Socket to use instead of TCP sockets.
local_address:
Expand All @@ -110,6 +112,7 @@ def __init__(
max_connections: int = None,
max_keepalive_connections: int = None,
keepalive_expiry: float = None,
http1: bool = True,
http2: bool = False,
uds: str = None,
local_address: str = None,
Expand All @@ -131,6 +134,7 @@ def __init__(
self._max_connections = max_connections
self._max_keepalive_connections = max_keepalive_connections
self._keepalive_expiry = keepalive_expiry
self._http1 = http1
self._http2 = http2
self._uds = uds
self._local_address = local_address
Expand All @@ -140,6 +144,9 @@ def __init__(
self._backend = backend
self._next_keepalive_check = 0.0

if not (http1 or http2):
raise ValueError("Either http1 or http2 must be True.")

if http2:
try:
import h2 # noqa: F401
Expand Down Expand Up @@ -175,6 +182,7 @@ def _create_connection(
) -> SyncHTTPConnection:
return SyncHTTPConnection(
origin=origin,
http1=self._http1,
http2=self._http2,
uds=self._uds,
ssl_context=self._ssl_context,
Expand Down
5 changes: 5 additions & 0 deletions tests/async_tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ async def read_body(stream: httpcore.AsyncByteStream) -> bytes:
await stream.aclose()


def test_must_configure_either_http1_or_http2() -> None:
with pytest.raises(ValueError):
httpcore.AsyncConnectionPool(http1=False, http2=False)


@pytest.mark.anyio
async def test_http_request(backend: str, server: Server) -> None:
async with httpcore.AsyncConnectionPool(backend=backend) as http:
Expand Down
5 changes: 5 additions & 0 deletions tests/sync_tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ def read_body(stream: httpcore.SyncByteStream) -> bytes:
stream.close()


def test_must_configure_either_http1_or_http2() -> None:
with pytest.raises(ValueError):
httpcore.SyncConnectionPool(http1=False, http2=False)



def test_http_request(backend: str, server: Server) -> None:
with httpcore.SyncConnectionPool(backend=backend) as http:
Expand Down