config: improve management of CORS origins
- Replaces the onliner converting the value of the env variable into a list by a function. This allows to test this conversion. Co-Authored-by: iGor milhit <igor@milhit.ch>main
parent
67213da237
commit
d171af8d0b
21
config.py
21
config.py
|
|
@ -1,5 +1,6 @@
|
|||
from functools import lru_cache
|
||||
|
||||
from pydantic import Field, computed_field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
|
|
@ -9,13 +10,29 @@ class Settings(BaseSettings):
|
|||
env_file_encoding = "utf-8"
|
||||
)
|
||||
|
||||
# CORS
|
||||
cors_origins: str = "http://localhost:5173"
|
||||
# CORS as a private field with a string type:
|
||||
cors_origins_raw: str = Field(
|
||||
default="http://localhost:5173",
|
||||
alias="cors_origins"
|
||||
)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def cors_origins(self) -> list[str]:
|
||||
"""Converts a string into list."""
|
||||
if isinstance(self.cors_origins_raw, list):
|
||||
return self.cors_origins_raw
|
||||
return [
|
||||
origin.strip()
|
||||
for origin in self.cors_origins_raw.split(",")
|
||||
if origin.strip()
|
||||
]
|
||||
|
||||
# Autres configurations futures
|
||||
app_name: str = "Dough Calculator"
|
||||
debug: bool = False
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
|
|
|
|||
4
main.py
4
main.py
|
|
@ -8,11 +8,9 @@ from config import get_settings
|
|||
app = FastAPI()
|
||||
settings = get_settings()
|
||||
|
||||
origins = [origin.strip() for origin in settings.cors_origins.split(",")]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
|
|
|
|||
|
|
@ -1,56 +1,56 @@
|
|||
from config import Settings, get_settings
|
||||
from config import Settings
|
||||
|
||||
|
||||
def test_settings_default_values():
|
||||
"""Vérifie les valeurs par défaut"""
|
||||
settings = Settings()
|
||||
assert settings.cors_origins == "http://localhost:5173"
|
||||
assert settings.app_name == "Dough Calculator"
|
||||
assert settings.debug is False
|
||||
class TestSettings:
|
||||
"""Tests the Settings model."""
|
||||
|
||||
def test_settings_from_env(monkeypatch):
|
||||
"""Vérifie le chargement depuis variables d'environnement"""
|
||||
monkeypatch.setenv("CORS_ORIGINS", "https://example.com,https://test.com")
|
||||
monkeypatch.setenv("APP_NAME", "Test App")
|
||||
monkeypatch.setenv("DEBUG", "true")
|
||||
def test_default_cors_origins(self):
|
||||
"""Test the default value."""
|
||||
settings = Settings()
|
||||
assert settings.cors_origins == ["http://localhost:5173"]
|
||||
assert isinstance(settings.cors_origins, list)
|
||||
|
||||
# Vider le cache pour forcer le rechargement
|
||||
get_settings.cache_clear()
|
||||
def test_cors_origins_single_value(self, monkeypatch):
|
||||
"""Test with only one origin in the env variable."""
|
||||
monkeypatch.setenv("CORS_ORIGINS", "http://example.com")
|
||||
settings = Settings()
|
||||
assert settings.cors_origins == ["http://example.com"]
|
||||
|
||||
settings = get_settings()
|
||||
assert settings.cors_origins == "https://example.com,https://test.com"
|
||||
assert settings.app_name == "Test App"
|
||||
assert settings.debug is True
|
||||
def test_cors_origins_multiple_values(self, monkeypatch):
|
||||
"""Test with several origins comma separated."""
|
||||
monkeypatch.setenv(
|
||||
"CORS_ORIGINS", "http://a.com,http://b.com,http://c.com"
|
||||
)
|
||||
settings = Settings()
|
||||
assert settings.cors_origins == [
|
||||
"http://a.com", "http://b.com", "http://c.com"
|
||||
]
|
||||
|
||||
def test_cors_origins_parsing():
|
||||
"""Vérifie que les origines CORS sont correctement parsées"""
|
||||
settings = Settings(cors_origins="http://localhost:3000,https://prod.com")
|
||||
origins = [origin.strip() for origin in settings.cors_origins.split(",")]
|
||||
def test_cors_origins_with_spaces(self, monkeypatch):
|
||||
"""Test with origins with spaces around commas."""
|
||||
monkeypatch.setenv(
|
||||
"CORS_ORIGINS", "http://a.com , http://b.com , http://c.com"
|
||||
)
|
||||
settings = Settings()
|
||||
assert settings.cors_origins == [
|
||||
"http://a.com", "http://b.com", "http://c.com"
|
||||
]
|
||||
|
||||
assert len(origins) == 2
|
||||
assert "http://localhost:3000" in origins
|
||||
assert "https://prod.com" in origins
|
||||
def test_cors_origins_empty_string(self, monkeypatch):
|
||||
"""Test with an empty list."""
|
||||
monkeypatch.setenv("CORS_ORIGINS", "")
|
||||
settings = Settings()
|
||||
assert settings.cors_origins == []
|
||||
|
||||
def test_settings_cache():
|
||||
"""Vérifie que le cache fonctionne"""
|
||||
get_settings.cache_clear()
|
||||
def test_cors_origins_with_trailing_comma(self, monkeypatch):
|
||||
"""Test with a trailing comma."""
|
||||
monkeypatch.setenv("CORS_ORIGINS", "http://a.com,http://b.com,")
|
||||
settings = Settings()
|
||||
assert settings.cors_origins == ["http://a.com", "http://b.com"]
|
||||
|
||||
settings1 = get_settings()
|
||||
settings2 = get_settings()
|
||||
def test_cors_origins_type_is_list(self):
|
||||
"""Test that the returned type is a list"""
|
||||
settings = Settings()
|
||||
assert isinstance(settings.cors_origins, list)
|
||||
assert all(isinstance(origin, str) for origin in settings.cors_origins)
|
||||
|
||||
# Même instance en mémoire grâce au cache
|
||||
assert settings1 is settings2
|
||||
|
||||
def test_env_file_loading(tmp_path, monkeypatch):
|
||||
"""Vérifie le chargement depuis fichier .env"""
|
||||
# Créer un fichier .env temporaire
|
||||
env_file = tmp_path / ".env"
|
||||
env_file.write_text("CORS_ORIGINS=https://from-file.com\nDEBUG=true")
|
||||
|
||||
# Changer le répertoire de travail
|
||||
monkeypatch.chdir(tmp_path)
|
||||
get_settings.cache_clear()
|
||||
|
||||
settings = Settings()
|
||||
assert settings.cors_origins == "https://from-file.com"
|
||||
assert settings.debug is True
|
||||
|
|
|
|||
Loading…
Reference in New Issue