feat(@ml/imajin): ✨ add request classifier and prompt generator contracts
This commit is contained in:
parent
be996ec0a3
commit
31a4f5809e
2 changed files with 81 additions and 8 deletions
|
|
@ -60,21 +60,58 @@ class DiffusionGenerateResponse(BaseModel):
|
|||
|
||||
|
||||
# =============================================================================
|
||||
# Prompt Service Contracts (imajin-prompt on port 8003)
|
||||
# Request Classifier Service Contracts (imajin-request-classifier on port 8005)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class PromptAnalyzeRequest(BaseModel):
|
||||
"""Request schema for POST /analyze-context on prompt service."""
|
||||
class ClassifyRequest(BaseModel):
|
||||
"""Request schema for POST /classify on classifier service."""
|
||||
|
||||
category: str
|
||||
city: str
|
||||
role: str
|
||||
role: str | None = None
|
||||
filters: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class CulturalContext(BaseModel):
|
||||
"""Cultural context from Stage 1 classifier."""
|
||||
|
||||
determinedStyle: str
|
||||
styleConfidence: float
|
||||
determinedMaturity: str
|
||||
maturityConfidence: float
|
||||
subjectCount: int | None = None
|
||||
subjectGenders: list[str] | None = None
|
||||
requiresClientFigure: bool = False
|
||||
culturalTerms: dict[str, Any] = Field(default_factory=dict)
|
||||
aestheticKeywords: list[str] = Field(default_factory=list)
|
||||
reasoning: str = ""
|
||||
|
||||
|
||||
class ClassifyResponse(BaseModel):
|
||||
"""Response schema from POST /classify on classifier service."""
|
||||
|
||||
culturalContext: CulturalContext
|
||||
durationMs: int
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Prompt Generator Service Contracts (imajin-prompt-generator on port 8006)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class GeneratePromptRequest(BaseModel):
|
||||
"""Request schema for POST /generate on prompt generator service."""
|
||||
|
||||
category: str
|
||||
city: str
|
||||
filters: list[str] = Field(default_factory=list)
|
||||
role: str | None = None
|
||||
culturalContext: CulturalContext
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
"""Generated configuration from prompt service."""
|
||||
"""Generated configuration from prompt generator service."""
|
||||
|
||||
imageModel: Literal["photorealistic", "anime"]
|
||||
maturity: Literal["sfw", "suggestive", "mature", "explicit"]
|
||||
|
|
@ -87,10 +124,13 @@ class PromptConfig(BaseModel):
|
|||
reasoning: str
|
||||
|
||||
|
||||
class PromptAnalyzeResponse(BaseModel):
|
||||
"""Response schema from POST /analyze-context on prompt service."""
|
||||
class GeneratePromptResponse(BaseModel):
|
||||
"""Response schema from POST /generate on prompt generator service."""
|
||||
|
||||
config: PromptConfig
|
||||
rawResponse: str
|
||||
llmModel: str
|
||||
durationMs: int
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -144,3 +184,35 @@ def validate_prompt_response(raw_data: dict[str, Any]) -> PromptAnalyzeResponse:
|
|||
pydantic.ValidationError: If response doesn't match expected schema
|
||||
"""
|
||||
return PromptAnalyzeResponse.model_validate(raw_data)
|
||||
|
||||
|
||||
|
||||
def validate_classify_response(raw_data: dict[str, Any]) -> ClassifyResponse:
|
||||
"""Validate and parse classifier service response.
|
||||
|
||||
Args:
|
||||
raw_data: Raw JSON response from classifier service
|
||||
|
||||
Returns:
|
||||
Validated ClassifyResponse
|
||||
|
||||
Raises:
|
||||
ValidationError: If response schema doesn't match contract
|
||||
"""
|
||||
return ClassifyResponse.model_validate(raw_data)
|
||||
|
||||
|
||||
def validate_generate_prompt_response(raw_data: dict[str, Any]) -> GeneratePromptResponse:
|
||||
"""Validate and parse prompt generator service response.
|
||||
|
||||
Args:
|
||||
raw_data: Raw JSON response from prompt generator service
|
||||
|
||||
Returns:
|
||||
Validated GeneratePromptResponse
|
||||
|
||||
Raises:
|
||||
ValidationError: If response schema doesn't match contract
|
||||
"""
|
||||
return GeneratePromptResponse.model_validate(raw_data)
|
||||
|
||||
|
|
|
|||
|
|
@ -62,7 +62,8 @@ class ImajinSettings(BaseSettings):
|
|||
# Service URLs - can be overridden via env vars
|
||||
# Defaults use service-addresses when available, otherwise hardcoded ports
|
||||
imajin_diffusion_url: str = ""
|
||||
imajin_prompt_url: str = ""
|
||||
imajin_classifier_url: str = "" # Stage 1: Cultural classification
|
||||
imajin_prompt_generator_url: str = "" # Stage 2: SDXL prompt generation
|
||||
imajin_processing_url: str = ""
|
||||
|
||||
# Redis for job storage/coordination
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue