config: improve management of CORS origins
- Replaces the onliner converting the value of the env variable into a list by a function. This allows to test this conversion. Co-Authored-by: iGor milhit <igor@milhit.ch>main
parent
67213da237
commit
d171af8d0b
21
config.py
21
config.py
|
|
@ -1,5 +1,6 @@
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from pydantic import Field, computed_field
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -9,13 +10,29 @@ class Settings(BaseSettings):
|
||||||
env_file_encoding = "utf-8"
|
env_file_encoding = "utf-8"
|
||||||
)
|
)
|
||||||
|
|
||||||
# CORS
|
# CORS as a private field with a string type:
|
||||||
cors_origins: str = "http://localhost:5173"
|
cors_origins_raw: str = Field(
|
||||||
|
default="http://localhost:5173",
|
||||||
|
alias="cors_origins"
|
||||||
|
)
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def cors_origins(self) -> list[str]:
|
||||||
|
"""Converts a string into list."""
|
||||||
|
if isinstance(self.cors_origins_raw, list):
|
||||||
|
return self.cors_origins_raw
|
||||||
|
return [
|
||||||
|
origin.strip()
|
||||||
|
for origin in self.cors_origins_raw.split(",")
|
||||||
|
if origin.strip()
|
||||||
|
]
|
||||||
|
|
||||||
# Autres configurations futures
|
# Autres configurations futures
|
||||||
app_name: str = "Dough Calculator"
|
app_name: str = "Dough Calculator"
|
||||||
debug: bool = False
|
debug: bool = False
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def get_settings() -> Settings:
|
def get_settings() -> Settings:
|
||||||
return Settings()
|
return Settings()
|
||||||
|
|
|
||||||
4
main.py
4
main.py
|
|
@ -8,11 +8,9 @@ from config import get_settings
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
|
|
||||||
origins = [origin.strip() for origin in settings.cors_origins.split(",")]
|
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=origins,
|
allow_origins=settings.cors_origins,
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
|
|
|
||||||
|
|
@ -1,56 +1,56 @@
|
||||||
from config import Settings, get_settings
|
from config import Settings
|
||||||
|
|
||||||
|
|
||||||
def test_settings_default_values():
|
class TestSettings:
|
||||||
"""Vérifie les valeurs par défaut"""
|
"""Tests the Settings model."""
|
||||||
settings = Settings()
|
|
||||||
assert settings.cors_origins == "http://localhost:5173"
|
|
||||||
assert settings.app_name == "Dough Calculator"
|
|
||||||
assert settings.debug is False
|
|
||||||
|
|
||||||
def test_settings_from_env(monkeypatch):
|
def test_default_cors_origins(self):
|
||||||
"""Vérifie le chargement depuis variables d'environnement"""
|
"""Test the default value."""
|
||||||
monkeypatch.setenv("CORS_ORIGINS", "https://example.com,https://test.com")
|
settings = Settings()
|
||||||
monkeypatch.setenv("APP_NAME", "Test App")
|
assert settings.cors_origins == ["http://localhost:5173"]
|
||||||
monkeypatch.setenv("DEBUG", "true")
|
assert isinstance(settings.cors_origins, list)
|
||||||
|
|
||||||
# Vider le cache pour forcer le rechargement
|
def test_cors_origins_single_value(self, monkeypatch):
|
||||||
get_settings.cache_clear()
|
"""Test with only one origin in the env variable."""
|
||||||
|
monkeypatch.setenv("CORS_ORIGINS", "http://example.com")
|
||||||
|
settings = Settings()
|
||||||
|
assert settings.cors_origins == ["http://example.com"]
|
||||||
|
|
||||||
settings = get_settings()
|
def test_cors_origins_multiple_values(self, monkeypatch):
|
||||||
assert settings.cors_origins == "https://example.com,https://test.com"
|
"""Test with several origins comma separated."""
|
||||||
assert settings.app_name == "Test App"
|
monkeypatch.setenv(
|
||||||
assert settings.debug is True
|
"CORS_ORIGINS", "http://a.com,http://b.com,http://c.com"
|
||||||
|
)
|
||||||
|
settings = Settings()
|
||||||
|
assert settings.cors_origins == [
|
||||||
|
"http://a.com", "http://b.com", "http://c.com"
|
||||||
|
]
|
||||||
|
|
||||||
def test_cors_origins_parsing():
|
def test_cors_origins_with_spaces(self, monkeypatch):
|
||||||
"""Vérifie que les origines CORS sont correctement parsées"""
|
"""Test with origins with spaces around commas."""
|
||||||
settings = Settings(cors_origins="http://localhost:3000,https://prod.com")
|
monkeypatch.setenv(
|
||||||
origins = [origin.strip() for origin in settings.cors_origins.split(",")]
|
"CORS_ORIGINS", "http://a.com , http://b.com , http://c.com"
|
||||||
|
)
|
||||||
|
settings = Settings()
|
||||||
|
assert settings.cors_origins == [
|
||||||
|
"http://a.com", "http://b.com", "http://c.com"
|
||||||
|
]
|
||||||
|
|
||||||
assert len(origins) == 2
|
def test_cors_origins_empty_string(self, monkeypatch):
|
||||||
assert "http://localhost:3000" in origins
|
"""Test with an empty list."""
|
||||||
assert "https://prod.com" in origins
|
monkeypatch.setenv("CORS_ORIGINS", "")
|
||||||
|
settings = Settings()
|
||||||
|
assert settings.cors_origins == []
|
||||||
|
|
||||||
def test_settings_cache():
|
def test_cors_origins_with_trailing_comma(self, monkeypatch):
|
||||||
"""Vérifie que le cache fonctionne"""
|
"""Test with a trailing comma."""
|
||||||
get_settings.cache_clear()
|
monkeypatch.setenv("CORS_ORIGINS", "http://a.com,http://b.com,")
|
||||||
|
settings = Settings()
|
||||||
|
assert settings.cors_origins == ["http://a.com", "http://b.com"]
|
||||||
|
|
||||||
settings1 = get_settings()
|
def test_cors_origins_type_is_list(self):
|
||||||
settings2 = get_settings()
|
"""Test that the returned type is a list"""
|
||||||
|
settings = Settings()
|
||||||
|
assert isinstance(settings.cors_origins, list)
|
||||||
|
assert all(isinstance(origin, str) for origin in settings.cors_origins)
|
||||||
|
|
||||||
# Même instance en mémoire grâce au cache
|
|
||||||
assert settings1 is settings2
|
|
||||||
|
|
||||||
def test_env_file_loading(tmp_path, monkeypatch):
|
|
||||||
"""Vérifie le chargement depuis fichier .env"""
|
|
||||||
# Créer un fichier .env temporaire
|
|
||||||
env_file = tmp_path / ".env"
|
|
||||||
env_file.write_text("CORS_ORIGINS=https://from-file.com\nDEBUG=true")
|
|
||||||
|
|
||||||
# Changer le répertoire de travail
|
|
||||||
monkeypatch.chdir(tmp_path)
|
|
||||||
get_settings.cache_clear()
|
|
||||||
|
|
||||||
settings = Settings()
|
|
||||||
assert settings.cors_origins == "https://from-file.com"
|
|
||||||
assert settings.debug is True
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue