# Source code for flowodm.model

"""
FlowBaseModel - Pydantic v2 base class for Kafka message models.

Provides both synchronous and asynchronous methods for produce/consume operations.
"""

from __future__ import annotations

import inspect
import io
import json
import logging
import struct
import types
from collections.abc import AsyncIterator, Iterator
from datetime import datetime
from enum import Enum
from typing import Any, TypeVar, Union, get_args, get_origin

import fastavro
from confluent_kafka import Consumer, Message, Producer
from pydantic import BaseModel, ConfigDict

from flowodm.connection import (
    get_async_consumer,
    get_async_producer,
    get_consumer,
    get_producer,
    get_schema_registry,
)
from flowodm.exceptions import (
    ConfigurationError,
    DeserializationError,
    ProducerError,
    SerializationError,
    SettingsError,
)
from flowodm.settings import BaseSettings

T = TypeVar("T", bound="FlowBaseModel")

# Cache of Schema Registry IDs keyed by the model's fully-qualified name
# ("{module}.{qualname}"). It is stored at module level, rather than on the
# class, so Pydantic does not treat it as a model attribute.
_schema_id_cache: dict[str, int] = {}

# Confluent wire format constants
# Confluent Kafka messages use a wire format with:
# - Byte 0: Magic byte (0x00)
# - Bytes 1-4: Schema ID (big-endian 4-byte unsigned integer)
# - Bytes 5+: Actual Avro serialized data
CONFLUENT_MAGIC_BYTE = 0x00
CONFLUENT_HEADER_SIZE = 5  # 1 byte magic + 4 bytes schema ID


class FlowBaseModel(BaseModel):
    """
    Base class for Kafka message models with ODM functionality.

    Provides both synchronous and asynchronous methods for produce/consume.
    Maps Pydantic models to Avro schemas automatically.

    Subclasses must define an inner Settings class:

    Example:
        class UserEvent(FlowBaseModel):
            class Settings:
                topic = "user-events"
                schema_subject = "user-events-value"  # Optional
                consumer_group = "my-service"  # Optional

            user_id: str
            action: str
            timestamp: datetime
    """

    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    class Settings:
        """
        Settings for the Kafka model.

        Configuration class for defining Kafka topic, schema, and consumer settings.
        The type annotations provide the documentation for each setting.
        """

        topic: str | None = None
        """Kafka topic name (required)"""

        schema_subject: str | None = None
        """Schema Registry subject (defaults to {topic}-value)"""

        schema_path: str | None = None
        """Path to local .avsc file (optional)"""

        consumer_group: str | None = None
        """Consumer group ID (optional)"""

        key_field: str | None = None
        """Field name to use as message key (optional)"""

        # NOTE(review): key_serializer / value_serializer are not read by
        # FlowBaseModel itself: values are always Avro and keys are always
        # str(value).encode("utf-8"). Confirm whether other code uses them.
        key_serializer: str = "string"
        """Key serialization format: "string", "avro", "json" """

        value_serializer: str = "avro"
        """Value serialization format: "avro", "json" """

        confluent_wire_format: bool = True
        """Prepend Confluent wire format header (magic byte + schema ID) when serializing"""

    # ==================== Class Methods for Configuration ====================

    @classmethod
    def _get_setting(cls, name: str, default: Any = None) -> Any:
        """Read ``name`` from the inner Settings class, falling back to ``default``."""
        settings = getattr(cls, "Settings", None)
        if settings is None:
            return default
        return getattr(settings, name, default)

    @classmethod
    def _get_topic(cls) -> str:
        """
        Get topic name from Settings.

        Raises:
            SettingsError: If there is no Settings class or it has no topic.
        """
        if getattr(cls, "Settings", None) is None:
            raise SettingsError(f"{cls.__name__} must define an inner Settings class")
        topic: str | None = cls._get_setting("topic")
        if not topic:
            raise SettingsError(f"{cls.__name__}.Settings must define 'topic'")
        return topic

    @classmethod
    def _get_schema_subject(cls) -> str:
        """Get schema subject (defaults to {topic}-value)."""
        subject: str | None = cls._get_setting("schema_subject")
        return subject or f"{cls._get_topic()}-value"

    @classmethod
    def _get_consumer_group(cls) -> str | None:
        """Get consumer group from Settings."""
        return cls._get_setting("consumer_group")

    @classmethod
    def _get_key_field(cls) -> str | None:
        """Get key field from Settings."""
        return cls._get_setting("key_field")

    @classmethod
    def _get_schema_path(cls) -> str | None:
        """Get schema path from Settings."""
        return cls._get_setting("schema_path")

    @classmethod
    def _get_confluent_wire_format(cls) -> bool:
        """Get confluent_wire_format setting (defaults to True)."""
        return cls._get_setting("confluent_wire_format", True)

    @classmethod
    def _get_or_register_schema_id(cls) -> int:
        """
        Get or register the schema ID for this model.

        Uses a module-level cache (keyed by "{module}.{qualname}") to avoid
        repeated registry calls.

        Returns:
            Schema ID from registry

        Raises:
            ConfigurationError: If no Schema Registry is configured
        """
        cache_key = f"{cls.__module__}.{cls.__qualname__}"
        if cache_key not in _schema_id_cache:
            _schema_id_cache[cache_key] = cls.register_schema()
        return _schema_id_cache[cache_key]

    @classmethod
    def _get_avro_schema(cls) -> dict[str, Any]:
        """
        Get Avro schema for this model.

        Priority:
            1. Load from schema_path if specified
            2. Load the latest version of the schema subject (which defaults to
               {topic}-value) from Schema Registry
            3. Auto-generate from Pydantic model
        """
        schema_path = cls._get_schema_path()
        if schema_path:
            # .avsc files are JSON, which is UTF-8 by definition; don't depend on
            # the platform's default encoding.
            with open(schema_path, encoding="utf-8") as f:
                schema_from_file: dict[str, Any] = json.load(f)
            return schema_from_file

        try:
            registry = get_schema_registry()
            latest = registry.get_latest_version(cls._get_schema_subject())
            schema_str = latest.schema.schema_str
            if not schema_str:
                raise ValueError("Schema string is empty")
            schema_from_registry: dict[str, Any] = json.loads(schema_str)
            return schema_from_registry
        except Exception:
            # Deliberate best-effort: no registry configured, subject not yet
            # registered or registry unreachable all fall back to generation.
            pass

        return cls._generate_avro_schema()

    @classmethod
    def _generate_avro_schema(cls) -> dict[str, Any]:
        """Generate an Avro record schema from the Pydantic model fields."""
        fields = []
        for field_name, field_info in cls.model_fields.items():
            field_type = cls._python_type_to_avro(field_info.annotation)
            if not field_info.is_required():
                # Fields with defaults may be written as null.
                field_type = cls._make_nullable(field_type)
            fields.append({"name": field_name, "type": field_type})

        return {
            "type": "record",
            "name": cls.__name__,
            "namespace": cls.__module__,
            "fields": fields,
        }

    @staticmethod
    def _make_nullable(avro_type: Any) -> Any:
        """
        Return ``avro_type`` widened to accept null.

        Avro forbids unions nested directly in unions and duplicate branches, so an
        existing union gets "null" prepended (if missing) rather than wrapped.
        """
        if avro_type == "null":
            return "null"
        if isinstance(avro_type, list):
            return avro_type if "null" in avro_type else ["null", *avro_type]
        return ["null", avro_type]

    @classmethod
    def _python_type_to_avro(cls, python_type: Any) -> str | dict[str, Any] | list[Any]:
        """
        Convert a Python type annotation to an Avro type.

        Supported mappings:
            - None -> "null"
            - str/int/float/bool/bytes -> string/long/double/boolean/bytes
            - datetime -> long with logicalType timestamp-millis
            - Enum -> "long" when every member value is an int, else "string"
            - ``X | Y`` / Optional[X] / Union[...] -> Avro union ("null" first)
            - list[X] / set[X] / frozenset[X] -> array of X
            - dict[str, X] -> map of X
        Anything else falls back to "string".
        """
        if python_type is None or python_type is type(None):
            return "null"

        primitives = {
            str: "string",
            int: "long",
            float: "double",
            bool: "boolean",
            bytes: "bytes",
        }

        origin = get_origin(python_type)
        if origin is None:
            if python_type in primitives:
                return primitives[python_type]
            if python_type is datetime:
                return {"type": "long", "logicalType": "timestamp-millis"}
            if isinstance(python_type, type) and issubclass(python_type, Enum):
                # _to_avro_dict writes enum members as their .value.
                values = [member.value for member in python_type]
                if values and all(
                    isinstance(v, int) and not isinstance(v, bool) for v in values
                ):
                    return "long"
                return "string"
            return "string"  # Default fallback

        # typing.Union covers Optional[X]; types.UnionType covers X | Y (3.10+).
        union_origins = {Union, getattr(types, "UnionType", Union)}
        if origin in union_origins:
            branches: list[Any] = []
            for arg in get_args(python_type):
                converted = cls._python_type_to_avro(arg)
                for branch in converted if isinstance(converted, list) else [converted]:
                    if branch not in branches:
                        branches.append(branch)
            if "null" in branches:
                branches.remove("null")
                branches.insert(0, "null")
            return branches[0] if len(branches) == 1 else branches

        if origin in (list, set, frozenset):
            args = get_args(python_type)
            items = cls._python_type_to_avro(args[0]) if args else "string"
            return {"type": "array", "items": items}

        if origin is dict:
            args = get_args(python_type)
            # Avro map keys are always strings.
            if len(args) == 2 and args[0] is str:
                return {"type": "map", "values": cls._python_type_to_avro(args[1])}
            return "string"

        return "string"  # Default fallback

    # ==================== Serialization ====================

    @staticmethod
    def _datetime_to_millis(value: datetime) -> int:
        """
        Convert a datetime to epoch milliseconds using exact integer arithmetic.

        ``int(value.timestamp() * 1000)`` goes through float rounding, which can lose
        a millisecond, and truncates toward zero for pre-1970 values. Here the
        whole-second part comes from ``timestamp()`` of the datetime with its
        microseconds cleared (an integral float), and the sub-second part is added as
        an integer, so the result is floored to the millisecond. Naive datetimes are
        interpreted by ``timestamp()`` as local time, as before.
        """
        whole_seconds = int(value.replace(microsecond=0).timestamp())
        return whole_seconds * 1000 + value.microsecond // 1000

    def _to_avro_dict(self) -> dict[str, Any]:
        """Convert model to an Avro-compatible dictionary (top-level fields only)."""
        data = self.model_dump(mode="python")
        for key, value in data.items():
            if isinstance(value, datetime):
                data[key] = self._datetime_to_millis(value)
            elif isinstance(value, Enum):
                data[key] = value.value
        return data

    @classmethod
    def _from_avro_dict(cls: type[T], data: dict[str, Any]) -> T:
        """
        Create model instance from an Avro record dictionary.

        Fields annotated exactly as ``datetime`` that arrive as integer epoch
        milliseconds are converted back with ``datetime.fromtimestamp``. The
        ``data`` dict is modified in place.
        """
        for field_name, field_info in cls.model_fields.items():
            if field_name in data and field_info.annotation == datetime:
                if isinstance(data[field_name], int):
                    data[field_name] = datetime.fromtimestamp(data[field_name] / 1000)
        return cls.model_validate(data)

    def _serialize_avro(self) -> bytes:
        """
        Serialize model to Avro bytes.

        When confluent_wire_format is enabled (default) and a Schema Registry is
        configured, prepends the 5-byte Confluent wire format header (magic byte
        0x00 + 4-byte big-endian schema ID) before the Avro data.

        Raises:
            SerializationError: If fastavro cannot encode the model data.
        """
        schema = self._get_avro_schema()
        data = self._to_avro_dict()
        output = io.BytesIO()

        if self._get_confluent_wire_format():
            try:
                schema_id = self._get_or_register_schema_id()
            except ConfigurationError:
                pass  # No Schema Registry configured -> raw Avro
            else:
                # Magic byte is an unsigned byte, schema ID an unsigned 32-bit int.
                output.write(struct.pack(">BI", CONFLUENT_MAGIC_BYTE, schema_id))

        try:
            fastavro.schemaless_writer(output, schema, data)
        except Exception as e:
            raise SerializationError(f"Failed to serialize to Avro: {e}") from e
        return output.getvalue()

    @classmethod
    def _strip_confluent_header(cls, data: bytes) -> bytes:
        """
        Strip Confluent wire format header if present.

        Confluent Kafka messages use a wire format with:
            - Byte 0: Magic byte (0x00)
            - Bytes 1-4: Schema ID (big-endian 4-byte integer)
            - Bytes 5+: Actual Avro serialized data

        This method detects and removes this header, returning pure Avro bytes.
        Detection is heuristic (first byte 0x00 and at least 5 bytes).

        Args:
            data: Raw message bytes (may include Confluent header)

        Returns:
            Pure Avro bytes without the Confluent header
        """
        if len(data) >= CONFLUENT_HEADER_SIZE and data[0] == CONFLUENT_MAGIC_BYTE:
            return data[CONFLUENT_HEADER_SIZE:]
        return data

    @classmethod
    def _read_avro_payload(cls: type[T], payload: bytes, schema: dict[str, Any]) -> T:
        """
        Decode pure Avro ``payload`` with ``schema`` into a model instance.

        Raises:
            DeserializationError: If decoding or validation fails, or if the payload
                is not fully consumed (a sign of a wire format or schema mismatch).
        """
        stream = io.BytesIO(payload)
        try:
            record: dict[str, Any] = fastavro.schemaless_reader(stream, schema)  # type: ignore[assignment,call-arg]
        except Exception as e:
            raise DeserializationError(f"Failed to deserialize from Avro: {e}") from e

        bytes_read = stream.tell()
        if bytes_read != len(payload):
            raise DeserializationError(
                f"Incomplete deserialization: read {bytes_read} of {len(payload)} bytes. "
                "This may indicate a wire format mismatch or schema incompatibility."
            )

        try:
            return cls._from_avro_dict(record)
        except Exception as e:
            raise DeserializationError(f"Failed to deserialize from Avro: {e}") from e

    @classmethod
    def _deserialize_avro(cls: type[T], data: bytes) -> T:
        """
        Deserialize Avro bytes to model instance.

        Automatically handles both pure Avro format and Confluent wire format
        (which includes a 5-byte header with magic byte and schema ID).

        Args:
            data: Avro bytes (with or without Confluent header)

        Returns:
            Model instance

        Raises:
            DeserializationError: If no interpretation of ``data`` decodes.
        """
        schema = cls._get_avro_schema()

        candidates = [cls._strip_confluent_header(data)]
        if len(candidates[0]) != len(data):
            # Raw Avro may legitimately begin with 0x00 (e.g. a zig-zag encoded 0
            # or an empty string/bytes as the first field), so if the header-less
            # interpretation fails, retry with the bytes untouched.
            candidates.append(data)

        first_error: DeserializationError | None = None
        for payload in candidates:
            try:
                return cls._read_avro_payload(payload, schema)
            except DeserializationError as e:
                if first_error is None:
                    first_error = e
        assert first_error is not None  # candidates is never empty
        raise first_error

    def _get_message_key(self) -> bytes | None:
        """Get message key (UTF-8 of str(value)) based on key_field setting."""
        key_field = self._get_key_field()
        if not key_field:
            return None
        key_value = getattr(self, key_field, None)
        if key_value is None:
            return None
        return str(key_value).encode("utf-8")

    # ==================== Producer/Consumer Access ====================

    @classmethod
    def get_producer(cls) -> Producer:
        """Get sync Kafka producer. Override for custom connection logic."""
        return get_producer()

    @classmethod
    async def get_async_producer(cls) -> Any:
        """Get async Kafka producer. Override for custom connection logic."""
        return await get_async_producer()

    @classmethod
    def get_consumer(
        cls, group_id: str | None = None, settings: BaseSettings | None = None
    ) -> Consumer:
        """Get sync Kafka consumer. Override for custom connection logic."""
        group = group_id or cls._get_consumer_group()
        if not group:
            raise SettingsError(
                f"{cls.__name__} requires consumer_group in Settings or group_id parameter"
            )
        return get_consumer(group, [cls._get_topic()], settings)

    @classmethod
    async def get_async_consumer(
        cls, group_id: str | None = None, settings: BaseSettings | None = None
    ) -> Any:
        """Get async Kafka consumer. Override for custom connection logic."""
        group = group_id or cls._get_consumer_group()
        if not group:
            raise SettingsError(
                f"{cls.__name__} requires consumer_group in Settings or group_id parameter"
            )
        return await get_async_consumer(group, [cls._get_topic()], settings)

    @staticmethod
    async def _maybe_await(result: Any) -> Any:
        """
        Await ``result`` if it is awaitable, otherwise return it unchanged.

        Lets the async paths drive both sync clients (plain return values) and
        clients whose produce/poll/commit methods are coroutines. Without this,
        such coroutines would be created but never executed.
        """
        if inspect.isawaitable(result):
            return await result
        return result

    # ==================== Produce Operations (Sync) ====================

    def produce_nowait(self, callback: Any | None = None) -> None:
        """
        Produce message to Kafka (non-blocking, fire-and-forget).

        Args:
            callback: Optional delivery callback function(err, msg)

        Raises:
            ProducerError: If the producer rejects the message.
        """
        producer = self.get_producer()
        topic = self._get_topic()
        value = self._serialize_avro()
        key = self._get_message_key()

        try:
            producer.produce(
                topic=topic,
                value=value,
                key=key,
                callback=callback,
            )
            producer.poll(0)  # Trigger delivery reports
        except Exception as e:
            raise ProducerError(f"Failed to produce message: {e}") from e

    def produce(self, timeout: float = 10.0) -> None:
        """
        Produce message and wait for delivery confirmation (blocking).

        Args:
            timeout: Maximum time to wait for delivery (seconds)

        Raises:
            ProducerError: On produce failure, delivery failure, or timeout.
        """
        delivery_error: Exception | None = None

        def on_delivery(err: Any, msg: Message) -> None:
            nonlocal delivery_error
            if err:
                delivery_error = ProducerError(f"Delivery failed: {err}")

        self.produce_nowait(callback=on_delivery)

        producer = self.get_producer()
        remaining = producer.flush(timeout=timeout)
        if remaining > 0:
            raise ProducerError(f"Timed out waiting for message delivery ({remaining} pending)")
        if delivery_error:
            raise delivery_error

    @classmethod
    def produce_many(cls, messages: list[FlowBaseModel], flush: bool = True) -> int:
        """
        Produce multiple messages (batch).

        Args:
            messages: List of model instances to produce
            flush: Whether to wait for all deliveries (flushes ``cls.get_producer()``)

        Returns:
            Number of messages produced
        """
        producer = cls.get_producer() if messages else None
        count = 0
        for msg in messages:
            msg.produce_nowait()
            count += 1
        if flush and producer:
            producer.flush()
        return count

    # ==================== Produce Operations (Async) ====================

    async def aproduce(self) -> None:
        """
        Produce message to Kafka (asynchronous).

        Raises:
            ProducerError: If the producer rejects the message.
        """
        producer = await self.get_async_producer()
        topic = self._get_topic()
        value = self._serialize_avro()
        key = self._get_message_key()

        try:
            if hasattr(producer, "produce_async"):
                await producer.produce_async(topic=topic, value=value, key=key)
            else:
                # Sync-style API; await the results in case produce()/poll() are
                # coroutines on this client.
                await self._maybe_await(producer.produce(topic=topic, value=value, key=key))
                await self._maybe_await(producer.poll(0))
        except Exception as e:
            raise ProducerError(f"Failed to produce message: {e}") from e

    @classmethod
    async def aproduce_many(cls, messages: list[FlowBaseModel]) -> int:
        """Produce multiple messages asynchronously (one after another)."""
        count = 0
        for msg in messages:
            await msg.aproduce()
            count += 1
        return count

    # ==================== Consume helpers ====================

    @classmethod
    def _decode_message(cls: type[T], msg: Message | None) -> T | None:
        """
        Turn a polled Kafka message into a model instance.

        Returns None (the message is skipped) when nothing was polled, when the poll
        returned an error event, for messages whose value is None, and when the
        payload cannot be deserialized. Error events and undecodable payloads are
        logged rather than silently discarded.
        """
        if msg is None:
            return None
        error = msg.error()
        if error:
            logging.getLogger(__name__).warning(
                "%s: skipping Kafka error event: %s", cls.__name__, error
            )
            return None
        value = msg.value()
        if value is None:
            return None
        try:
            return cls._deserialize_avro(value)
        except Exception:
            logging.getLogger(__name__).warning(
                "%s: skipping undecodable message at %s[%s]@%s",
                cls.__name__,
                msg.topic(),
                msg.partition(),
                msg.offset(),
                exc_info=True,
            )
            return None

    @classmethod
    def _commit(cls, consumer: Any, msg: Message) -> bool:
        """Commit ``msg``'s offset; log and return False if the commit raises."""
        try:
            consumer.commit(msg)
        except Exception:
            logging.getLogger(__name__).warning(
                "%s: failed to commit offset", cls.__name__, exc_info=True
            )
            return False
        return True

    @classmethod
    async def _acommit(cls, consumer: Any, msg: Message) -> bool:
        """Async variant of ``_commit`` that awaits coroutine-returning commits."""
        try:
            await cls._maybe_await(consumer.commit(msg))
        except Exception:
            logging.getLogger(__name__).warning(
                "%s: failed to commit offset", cls.__name__, exc_info=True
            )
            return False
        return True

    @classmethod
    async def _apoll(cls, consumer: Any, timeout: float) -> Any:
        """Poll via ``poll_async`` when available, else ``poll`` (awaited if needed)."""
        if hasattr(consumer, "poll_async"):
            return await consumer.poll_async(timeout)
        return await cls._maybe_await(consumer.poll(timeout))

    # ==================== Consume Operations (Sync) ====================

    @classmethod
    def consume_one(
        cls: type[T],
        timeout: float = 1.0,
        group_id: str | None = None,
        settings: BaseSettings | None = None,
    ) -> T | None:
        """
        Consume single message (synchronous).

        Args:
            timeout: Poll timeout in seconds
            group_id: Consumer group ID (uses Settings.consumer_group if not specified)
            settings: Optional settings profile

        Returns:
            Model instance, or None if no usable message was available or its
            offset could not be committed.
        """
        consumer = cls.get_consumer(group_id, settings)
        msg = consumer.poll(timeout)
        instance = cls._decode_message(msg)
        if instance is None:
            return None
        if not cls._commit(consumer, msg):
            return None
        return instance

    @classmethod
    def consume_iter(
        cls: type[T],
        timeout: float = 1.0,
        group_id: str | None = None,
        settings: BaseSettings | None = None,
    ) -> Iterator[T]:
        """
        Iterate over messages (synchronous generator).

        Each message's offset is committed after the caller resumes the
        generator, i.e. once the yielded instance has been handed back.

        Args:
            timeout: Poll timeout in seconds
            group_id: Consumer group ID
            settings: Optional settings profile

        Yields:
            Model instances
        """
        consumer = cls.get_consumer(group_id, settings)
        while True:
            msg = consumer.poll(timeout)
            instance = cls._decode_message(msg)
            if instance is None:
                continue
            yield instance
            cls._commit(consumer, msg)

    @classmethod
    def consume_batch(
        cls: type[T],
        max_messages: int,
        timeout: float = 1.0,
        group_id: str | None = None,
        settings: BaseSettings | None = None,
    ) -> list[T]:
        """
        Consume batch of messages.

        Stops early as soon as a poll returns nothing within ``timeout``.

        Args:
            max_messages: Maximum number of messages to consume
            timeout: Poll timeout in seconds
            group_id: Consumer group ID
            settings: Optional settings profile

        Returns:
            List of model instances
        """
        results: list[T] = []
        consumer = cls.get_consumer(group_id, settings)
        while len(results) < max_messages:
            msg = consumer.poll(timeout)
            if msg is None:
                break
            instance = cls._decode_message(msg)
            if instance is None:
                continue
            results.append(instance)
            cls._commit(consumer, msg)
        return results

    # ==================== Consume Operations (Async) ====================

    @classmethod
    async def aconsume_one(
        cls: type[T],
        timeout: float = 1.0,
        group_id: str | None = None,
        settings: BaseSettings | None = None,
    ) -> T | None:
        """Consume single message (asynchronous). Same contract as ``consume_one``."""
        consumer = await cls.get_async_consumer(group_id, settings)
        msg = await cls._apoll(consumer, timeout)
        instance = cls._decode_message(msg)
        if instance is None:
            return None
        if not await cls._acommit(consumer, msg):
            return None
        return instance

    @classmethod
    async def aconsume_iter(
        cls: type[T],
        timeout: float = 1.0,
        group_id: str | None = None,
        settings: BaseSettings | None = None,
    ) -> AsyncIterator[T]:
        """Iterate over messages (async generator). Same contract as ``consume_iter``."""
        consumer = await cls.get_async_consumer(group_id, settings)
        while True:
            msg = await cls._apoll(consumer, timeout)
            instance = cls._decode_message(msg)
            if instance is None:
                continue
            yield instance
            await cls._acommit(consumer, msg)

    @classmethod
    async def aconsume_batch(
        cls: type[T],
        max_messages: int,
        timeout: float = 1.0,
        group_id: str | None = None,
        settings: BaseSettings | None = None,
    ) -> list[T]:
        """Consume batch of messages asynchronously. Same contract as ``consume_batch``."""
        results: list[T] = []
        consumer = await cls.get_async_consumer(group_id, settings)
        while len(results) < max_messages:
            msg = await cls._apoll(consumer, timeout)
            if msg is None:
                break
            instance = cls._decode_message(msg)
            if instance is None:
                continue
            results.append(instance)
            await cls._acommit(consumer, msg)
        return results

    # ==================== Schema Operations ====================

    @classmethod
    def register_schema(cls) -> int:
        """
        Register Avro schema with Schema Registry.

        Uses the same schema resolution as serialization (file -> registry ->
        auto-generate) to ensure the registered schema matches what is used for
        producing messages.

        Returns:
            Schema ID from registry
        """
        from confluent_kafka.schema_registry import Schema

        registry = get_schema_registry()
        subject = cls._get_schema_subject()
        schema = cls._get_avro_schema()
        avro_schema = Schema(json.dumps(schema), "AVRO")
        return registry.register_schema(subject, avro_schema)