"""Read database configuration from environment variables (.env files are no longer supported)"""
from typing import Optional, cast, Callable, ClassVar, Any
import logging
import functools
from dataclasses import dataclass, field
from sqlalchemy.engine.url import URL, make_url
from starlette.config import Config
from starlette.datastructures import Secret
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
# Shared settings reader used by every DBConfig field default below.
# Constructed without an env-file argument on purpose, see the linked discussion.
config = Config()  # not supporting .env files anymore because https://github.com/encode/starlette/discussions/2446
# FIXME: this should probably be in some common library of ours
@dataclass
class DBConfig:  # pylint: disable=R0902
    """Database connection settings resolved through the module-level ``config`` reader.

    Every field default is a ``default_factory`` wrapping a ``functools.partial``
    around ``config``, so the lookup happens when an instance is created instead
    of at import time (this avoids import-time side-effects).  Any field can be
    overridden by passing it as a keyword argument to the constructor.

    When no DSN is available (neither passed in nor set via ``RM_DB_DSN``),
    ``__post_init__`` assembles one from the driver, user, password, host, port
    and database fields.
    """

    # Driver name used as the URL's drivername (RM_DATABASE_DRIVER)
    driver: str = field(
        default_factory=cast(Callable[..., str], functools.partial(config, "RM_DATABASE_DRIVER", default="postgresql"))
    )
    # Database host (RM_DATABASE_HOST); defaults to None
    host: Optional[str] = field(default_factory=functools.partial(config, "RM_DATABASE_HOST", default=None))
    # Database port (RM_DATABASE_PORT); Optional because the declared default is None
    port: Optional[int] = field(
        default_factory=cast(
            Callable[..., Optional[int]], functools.partial(config, "RM_DATABASE_PORT", cast=int, default=None)
        )
    )
    # Database user name (RM_DATABASE_USER)
    user: Optional[str] = field(
        default_factory=cast(Callable[..., str], functools.partial(config, "RM_DATABASE_USER", default="raesenmaeher"))
    )
    # Database password wrapped in starlette's Secret (RM_DATABASE_PASSWORD);
    # Optional because the declared default is None
    password: Optional[Secret] = field(
        default_factory=cast(
            Callable[..., Optional[Secret]],
            functools.partial(config, "RM_DATABASE_PASSWORD", cast=Secret, default=None),
        )
    )
    # Database name (RM_DATABASE_NAME)
    database: str = field(
        default_factory=cast(Callable[..., str], functools.partial(config, "RM_DATABASE_NAME", default="raesenmaeher"))
    )
    # Complete DSN (RM_DB_DSN) parsed with make_url; when None it is built in __post_init__
    dsn: Optional[URL] = field(
        default_factory=cast(
            Callable[..., Optional[URL]],
            functools.partial(
                config,
                "RM_DB_DSN",
                cast=make_url,
                default=None,
            ),
        )
    )
    # DB_POOL_MIN_SIZE, default 1
    pool_min_size: int = field(
        default_factory=cast(Callable[..., int], functools.partial(config, "DB_POOL_MIN_SIZE", cast=int, default=1))
    )
    # DB_POOL_MAX_SIZE, default 16
    pool_max_size: int = field(
        default_factory=cast(Callable[..., int], functools.partial(config, "DB_POOL_MAX_SIZE", cast=int, default=16))
    )
    # DB_ECHO, default False
    echo: bool = field(
        default_factory=cast(Callable[..., bool], functools.partial(config, "DB_ECHO", cast=bool, default=False))
    )
    # DB_SSL, default "prefer"
    ssl: str = field(
        default_factory=cast(Callable[..., str], functools.partial(config, "DB_SSL", cast=str, default="prefer"))
    )  # see asyncpg.connect()
    # DB_USE_CONNECTION_FOR_REQUEST, default True
    use_connection_for_request: bool = field(
        default_factory=cast(
            Callable[..., bool], functools.partial(config, "DB_USE_CONNECTION_FOR_REQUEST", cast=bool, default=True)
        )
    )
    # DB_RETRY_LIMIT, default 1
    retry_limit: int = field(
        default_factory=cast(Callable[..., int], functools.partial(config, "DB_RETRY_LIMIT", cast=int, default=1))
    )
    # DB_RETRY_INTERVAL, default 1
    retry_interval: int = field(
        default_factory=cast(Callable[..., int], functools.partial(config, "DB_RETRY_INTERVAL", cast=int, default=1))
    )

    # private: cached instance handed out by singleton(); ClassVar so it is not a dataclass field
    _singleton: ClassVar[Optional["DBConfig"]] = None

    @classmethod
    def singleton(cls, **kwargs: Any) -> "DBConfig":
        """Return the shared DBConfig instance, creating it on first use.

        :param kwargs: forwarded to the ``DBConfig`` constructor, but only on the
            call that creates the instance; once an instance is cached the
            arguments are ignored.
        :return: the cached instance stored on ``DBConfig._singleton``.
        """
        # The cache lives on DBConfig itself (not on ``cls``), so it is shared by subclasses too.
        instance = DBConfig._singleton
        if instance is None:
            instance = DBConfig(**kwargs)
            DBConfig._singleton = instance
        return instance

    def __post_init__(self) -> None:
        """Build the DSN from the individual fields when none was given, then log the result."""
        if self.dsn is None:
            # NOTE(review): newer SQLAlchemy releases prefer URL.create() over calling
            # the constructor directly - confirm against the pinned SQLAlchemy version.
            self.dsn = URL(
                drivername=self.driver,
                username=self.user,
                password=self.password,
                host=self.host,
                port=self.port,
                database=self.database,
            )
        # Log repr() of the URL: it masks the password, whereas str() can render
        # it in clear text depending on the SQLAlchemy version.
        LOGGER.debug("DSN=%r", self.dsn)
        LOGGER.debug("HOST=%s", self.host)
        LOGGER.debug("DATABASE=%s", self.database)