Source code for ormlite.orm

import logging
import dataclasses as dc
import sqlite3
from collections.abc import Sequence
from typing import (
    dataclass_transform,
    Any,
    Optional,
    TypeVar,
    ClassVar,
    Generic,
    Protocol,
)

from ormlite.errors import (
    MissingAdapterError,
    InvalidForeignKeyError,
    MultiplePrimaryKeysError,
)
from ormlite.utils import get_optional_type_arg

logger = logging.getLogger(__name__)

T = TypeVar("T")


class Adapter(Generic[T]):
    """Interface for translating a custom Python type to and from SQL text.

    ``register_adapter`` in this module stores an adapter under its
    ``python_type`` and hands ``adapt`` / ``convert`` to the ``sqlite3``
    module. The method bodies below are placeholders (``...``).
    """

    # SQL type name for this adapter; used as the converter key passed to
    # sqlite3.register_converter and as the column type in the SQL mapping.
    sql_type: ClassVar[str]
    # Python type this adapter handles; used as the key in Context.ADAPTERS.
    python_type: ClassVar[type]

    def convert(self, b: bytes) -> T:
        """Build a Python value of type ``T`` from the raw bytes ``b``."""
        ...  # pragma: no cover

    def adapt(self, val: T) -> str:
        """Return the string representation of ``val`` used in SQL."""
        ...  # pragma: no cover


class DatabaseConnection(Protocol):
    """Structural (duck-typed) interface describing a database connection.

    Any object exposing ``close`` and ``execute`` with these signatures
    satisfies the protocol.
    """

    def close(self) -> None:
        """Close the connection."""
        ...  # pragma: no cover

    def execute(self, statement: str, **kwargs: Any) -> sqlite3.Cursor:
        """Run ``statement`` and return a ``sqlite3.Cursor``.

        NOTE(review): the meaning of ``**kwargs`` is not visible in this
        module -- presumably statement parameters; confirm against callers.
        """
        ...  # pragma: no cover


class Context:
    """Class-level registries shared by the ORM helpers in this module.

    All state lives on the class itself; the class is used as a namespace.
    """

    # python type -> adapter instance
    ADAPTERS: dict[type, Adapter] = {}
    # python type -> SQL column type; empty until setup() is called
    PYTHON_TO_SQL_MAPPING: dict[type, str] = {}

    # two-way lookup between model classes and SQL table names
    MODEL_TO_TABLE: dict[type, str] = {}
    TABLE_TO_MODEL: dict[str, type] = {}

    @classmethod
    def setup(cls):
        """Rebuild ``PYTHON_TO_SQL_MAPPING`` from built-ins plus adapters."""
        cls.PYTHON_TO_SQL_MAPPING = cls.python_to_sql_mapping()

    @classmethod
    def python_to_sql_mapping(cls) -> dict[type, str]:
        """Return a new python-type -> SQL-type dictionary.

        The built-in types come first; entries derived from ``ADAPTERS`` are
        layered on top and win when they share a key with a built-in.
        """
        builtin_types: dict[type, str] = {
            bytes: "BLOB",
            str: "TEXT",
            int: "INTEGER",
            float: "REAL",
        }
        adapter_types = {
            registered.python_type: registered.sql_type
            for registered in cls.ADAPTERS.values()
        }
        return builtin_types | adapter_types


[docs]@dataclass_transform() def model(sql_table_name: str): if isinstance(sql_table_name, type): raise TypeError("@model(sql_table_name) must be called with the sql table name") def wrap(model: type) -> type: if sql_table_name in Context.TABLE_TO_MODEL: logger.warning( f"Reregistering the sql table '{sql_table_name}' with {model}" ) else: logger.debug(f"applying @model({sql_table_name}) to {model})") model = dc.dataclass(model, slots=True) # pyright: ignore validate_model(model) Context.TABLE_TO_MODEL[sql_table_name] = model Context.MODEL_TO_TABLE[model] = sql_table_name return model return wrap
def validate_model(model: type):
    """Check that a dataclass model can be mapped to a SQL table.

    Every field's type is passed through ``to_sql_type`` (which raises
    ``MissingAdapterError`` for unmapped types), and at most one field may
    carry truthy ``pk`` metadata.

    Raises:
        MultiplePrimaryKeysError: if a second primary-key field is found.
    """
    seen_primary = False
    for column in dc.fields(model):
        # Called only for its side effect of raising on unsupported types.
        to_sql_type(column.type)
        if not column.metadata.get("pk"):
            continue
        if seen_primary:
            raise MultiplePrimaryKeysError
        seen_primary = True


@dc.dataclass
class ForeignKey:
    """Reference to a column in another table.

    ``key`` is optional; when it is missing or empty the referencing field's
    own name is used as the referenced column.
    """

    table: str
    key: Optional[str] = None

    def to_constraint(self, field: dc.Field[Any]) -> str:
        """Render the ``FOREIGN KEY ... REFERENCES ...`` clause for ``field``."""
        referenced_column = self.key if self.key else field.name
        local_part = f"FOREIGN KEY ({field.name}) "
        remote_part = f"REFERENCES {self.table}({referenced_column})"
        return local_part + remote_part
def field(*, pk: bool = False, fk: Optional[str] = None, **kwargs: Any):
    """Create a dataclass field carrying ORM metadata.

    Args:
        pk: mark the column as the primary key.
        fk: foreign-key target written as ``"table"`` or ``"table.column"``.
        **kwargs: forwarded to ``dataclasses.field``. A ``metadata`` mapping
            supplied here is merged with the ORM entries instead of colliding
            with them; the ``"pk"`` and ``"fk"`` keys are always set from the
            ``pk`` / ``fk`` arguments.

    Returns:
        The object produced by ``dataclasses.field`` whose metadata contains
        ``"pk"`` (bool) and ``"fk"`` (``ForeignKey`` or ``None``).

    Raises:
        InvalidForeignKeyError: if ``fk`` has more than two dot-separated
            parts.
    """
    foreign_key: Optional[ForeignKey] = None
    if fk:
        parts = fk.split(".")
        if len(parts) > 2:
            raise InvalidForeignKeyError
        table = parts[0]
        key = parts[1] if len(parts) > 1 else None
        foreign_key = ForeignKey(table=table, key=key)

    # Passing ``metadata`` both explicitly and through **kwargs would raise
    # "got multiple values for keyword argument", so fold caller metadata in
    # first and let the ORM keys take precedence.
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata["pk"] = pk
    metadata["fk"] = foreign_key
    return dc.field(**kwargs, metadata=metadata)
def get(seq: Sequence[T], index: int) -> Optional[T]:
    """Return ``seq[index]``, or ``None`` when ``index`` is past the end."""
    if index >= len(seq):
        return None
    return seq[index]


def _quote_sql_text(text: str) -> str:
    """Wrap ``text`` in single quotes, doubling any embedded single quote."""
    return "'" + text.replace("'", "''") + "'"


def to_sql_literal(value: Any) -> str:
    """Render a Python value as a SQL literal.

    Supports ``None``, ``str``, ``int`` (excluding ``bool``), ``float``,
    lists/tuples of supported values, and any type with an adapter in
    ``Context.ADAPTERS``. Text is single-quoted with embedded single quotes
    doubled so the result is a well-formed SQL string literal.

    Raises:
        MissingAdapterError: if no rule or adapter matches ``type(value)``.
    """
    if value is None:
        return "NULL"
    if isinstance(value, str):
        return _quote_sql_text(value)
    # bool is a subtype of int; it is deliberately left to the adapter lookup
    if isinstance(value, int) and not isinstance(value, bool):
        return f"{value}"
    if isinstance(value, float):
        return f"{value}"
    if isinstance(value, (list, tuple)):
        return f"({','.join(to_sql_literal(item) for item in value)})"

    adapter = Context.ADAPTERS.get(type(value))
    if adapter:
        # ASSUMPTION: custom adapters always encode to text
        return _quote_sql_text(f"{adapter.adapt(value)}")

    raise MissingAdapterError


def to_sql_type(field: type) -> str:
    """Return the SQL column type for a Python type annotation.

    ``Optional[X]`` is unwrapped to ``X`` before the lookup in
    ``Context.PYTHON_TO_SQL_MAPPING``.

    Raises:
        MissingAdapterError: if the type has no entry in the mapping.
    """
    python_type = get_optional_type_arg(field) or field
    sql_type = Context.PYTHON_TO_SQL_MAPPING.get(python_type)
    if sql_type is not None:
        return sql_type

    raise MissingAdapterError


def column_def(field: dc.Field[Any]) -> str:
    """Build the ``name TYPE constraint`` fragment for one dataclass field.

    Constraint selection, first match wins:

    * ``pk`` metadata set: ``PRIMARY KEY``
    * default is ``None`` or the annotation is ``Optional[...]``: nullable
    * a ``default_factory`` is set: nullable
    * a plain default is set: ``DEFAULT <literal> NOT NULL``
    * otherwise: ``NOT NULL``
    """
    optional_inner_type = get_optional_type_arg(field.type)

    # not null is applied to all fields automatically
    # use default = None to get a nullable field
    constraint = "NOT NULL"
    if field.metadata.get("pk"):
        constraint = "PRIMARY KEY"
    elif field.default is None or optional_inner_type is not None:
        constraint = ""
    # nullable, since we can't convert a python factory into a sql factory
    elif field.default_factory != dc.MISSING:
        constraint = ""
    elif field.default != dc.MISSING:
        constraint = f"DEFAULT {to_sql_literal(field.default)} NOT NULL"

    # strip() drops the trailing space left behind by an empty constraint
    return f"{field.name} {to_sql_type(field.type)} {constraint}".strip()


def models() -> dict[str, type]:
    """Return the table-name -> model registry (the live dict, not a copy)."""
    return Context.TABLE_TO_MODEL


def sql_table_name(model: type) -> str:
    """Return the table name registered for ``model``.

    Raises:
        KeyError: if ``model`` is not in ``Context.MODEL_TO_TABLE``.
    """
    return Context.MODEL_TO_TABLE[model]


def register_adapter(adapter: Adapter[Any]):
    """Register ``adapter`` with this module and with ``sqlite3``.

    The adapter is stored in ``Context.ADAPTERS`` under its ``python_type``;
    its ``adapt`` is passed to ``sqlite3.register_adapter`` and its
    ``convert`` to ``sqlite3.register_converter`` keyed by ``sql_type``.

    NOTE: ``Context.PYTHON_TO_SQL_MAPPING`` is not updated here; it is only
    rebuilt by ``Context.setup()``.
    """
    Context.ADAPTERS[adapter.python_type] = adapter
    sqlite3.register_adapter(adapter.python_type, adapter.adapt)
    sqlite3.register_converter(adapter.sql_type, adapter.convert)