223 lines
7.6 KiB
Python
223 lines
7.6 KiB
Python
from __future__ import annotations
|
|
|
|
import shutil
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from imagepipeline.core.context import ModuleContext
|
|
from imagepipeline.core.exceptions import DependencyError
|
|
from imagepipeline.modules.ai_exposure import AIExposureModule
|
|
from imagepipeline.modules.ai_tone_map import AIToneMapModule
|
|
from imagepipeline.modules.comfy_flux_edit import ComfyFluxEditModule
|
|
from imagepipeline.modules.openrouter_edit import OpenRouterEditModule
|
|
from imagepipeline.modules.registry import get_module, list_modules
|
|
from tests.conftest import make_png
|
|
|
|
# Optional-dependency flags consumed by the ``skipif`` markers in this module.
# Each ``has_*`` import flag is True only when the corresponding import succeeds.
try:
    import numpy  # noqa: F401
except ImportError:
    has_numpy = False
else:
    has_numpy = True

try:
    import torch  # noqa: F401
except ImportError:
    has_torch = False
else:
    has_torch = True

# True when either executable name is found on PATH; ``any`` stops at the
# first hit, so ``convert`` is only looked up when ``magick`` is absent.
# NOTE(review): not referenced in the visible part of this file. On Windows,
# ``convert`` may resolve to the system disk utility rather than ImageMagick —
# confirm before relying on this flag there.
has_magick = any(shutil.which(tool) for tool in ("magick", "convert"))
|
|
|
|
|
|
class TestAIModuleRegistration:
    """The AI-backed modules are reachable through the module registry."""

    @staticmethod
    def _expected_modules() -> dict[str, type]:
        """Return the registry name -> class mapping both tests check.

        A single source of truth so the "is registered" and "resolves to the
        right class" tests cannot drift apart.
        """
        return {
            "ai_exposure": AIExposureModule,
            "ai_tone_map": AIToneMapModule,
            "openrouter_edit": OpenRouterEditModule,
            "comfy_flux_edit": ComfyFluxEditModule,
        }

    def test_ai_modules_registered(self) -> None:
        """Every AI module name appears in ``list_modules()``."""
        registered = list_modules()
        missing = [name for name in self._expected_modules() if name not in registered]
        # Report every missing name at once instead of stopping at the first.
        assert not missing, f"AI modules missing from registry: {missing}"

    def test_get_ai_modules(self) -> None:
        """``get_module(name)`` returns exactly the expected class object."""
        for name, expected_cls in self._expected_modules().items():
            actual = get_module(name)
            assert actual is expected_cls, (
                f"get_module({name!r}) returned {actual!r}, expected {expected_cls!r}"
            )
|
|
|
|
|
|
class TestAIParameters:
    """Parameter validation and defaults for the AI modules."""

    def test_ai_exposure_defaults(self) -> None:
        """Validating an empty dict fills in ai_exposure's defaults."""
        defaults = AIExposureModule.validate_module_params({})
        # Identity check on purpose: the default must be the bool True itself.
        assert defaults["skip_existing"] is True
        expected = {"max_edge": 2048, "device": "cpu", "strength": 1.0}
        for key, value in expected.items():
            assert defaults[key] == value

    def test_ai_tone_map_defaults(self) -> None:
        """Validating an empty dict fills in ai_tone_map's defaults."""
        defaults = AIToneMapModule.validate_module_params({})
        for key, value in (
            ("checkpoint", ""),
            ("strength", 1.0),
            ("net_input_size", 256),
        ):
            assert defaults[key] == value

    def test_openrouter_requires_prompt(self) -> None:
        """openrouter_edit rejects a parameter set with no prompt."""
        no_params: dict = {}
        with pytest.raises(ValueError, match="required"):
            OpenRouterEditModule.validate_module_params(no_params)

    def test_openrouter_defaults(self) -> None:
        """Given only a prompt, openrouter_edit fills model/strength/key env."""
        defaults = OpenRouterEditModule.validate_module_params({"prompt": "brighten shadows"})
        observed = (defaults["model"], defaults["strength"], defaults["api_key_env"])
        assert observed == ("black-forest-labs/flux.2-klein-4b", 0.3, "OPENROUTER_API_KEY")

    def test_comfy_requires_prompt(self) -> None:
        """comfy_flux_edit rejects a parameter set with no prompt."""
        no_params: dict = {}
        with pytest.raises(ValueError, match="required"):
            ComfyFluxEditModule.validate_module_params(no_params)
|
|
|
|
|
|
class TestAIToneMapFallback:
    """ai_tone_map still produces output when no checkpoint is configured."""

    @pytest.mark.skipif(not has_numpy, reason="numpy not installed")
    def test_clahe_fallback_writes_output(self, tmp_path: Path) -> None:
        """Run with the default (empty) checkpoint and expect a non-empty
        output file named after the input.

        Presumably this exercises the CLAHE fallback path (per the test name).
        """
        photo = tmp_path / "photo.png"
        make_png(photo, width=8, height=8, rgb=(40, 80, 120))
        out_dir = tmp_path / "out"
        out_dir.mkdir()

        tone_params = AIToneMapModule.validate_module_params({"strength": 1.0})
        context = ModuleContext(
            input_paths=[photo],
            matched_groups=[],
            output_dir=out_dir,
            params=tone_params,
            pipeline_output_root=tmp_path,
            step_id="ai_tone_map_01",
            logger=None,
        )
        AIToneMapModule().run(context)

        written = out_dir / "photo.png"
        # Short-circuits, so stat() is only reached when the file exists.
        assert written.is_file() and written.stat().st_size > 0
|
|
|
|
|
|
class TestSkipExisting:
    """A module re-run whose outputs already exist must be skipped."""

    @pytest.mark.skipif(not has_numpy, reason="numpy not installed")
    def test_second_run_skips_existing_outputs(
        self, tmp_path: Path, capsys: pytest.CaptureFixture[str]
    ) -> None:
        """Run ai_tone_map twice on the same context; only the second run may skip.

        Captured stdout is drained after the first run, so the final assertion
        can only be satisfied by output produced by the second run.
        """
        src = tmp_path / "photo.png"
        make_png(src, width=8, height=8)
        output_dir = tmp_path / "out"
        output_dir.mkdir()
        params = AIToneMapModule.validate_module_params({"checkpoint": ""})
        from imagepipeline.core.log import PipelineLogger

        # Verbose so the skip notice is printed to stdout, where capsys sees it.
        logger = PipelineLogger(verbose=True)
        ctx = ModuleContext(
            input_paths=[src],
            matched_groups=[],
            output_dir=output_dir,
            params=params,
            pipeline_output_root=tmp_path,
            step_id="ai_tone_map_01",
            logger=logger,
        )
        module = AIToneMapModule()
        skip_message = "Skipped module ai_tone_map"
        dst = output_dir / "photo.png"

        module.run(ctx)
        assert dst.is_file()
        # readouterr() returns and resets the buffer: this isolates run 1.
        first_output = capsys.readouterr().out
        # Nothing existed before run 1, so it must not report a skip.
        assert skip_message not in first_output
        first_mtime = dst.stat().st_mtime_ns

        module.run(ctx)
        second_output = capsys.readouterr().out
        assert skip_message in second_output
        # A skipped run must leave the existing output untouched.
        assert dst.stat().st_mtime_ns == first_mtime
|
|
|
|
|
|
@pytest.mark.skipif(not has_torch, reason="torch not installed")
class TestAIExposure:
    """ai_exposure end-to-end with model loading and enhancement stubbed."""

    def test_ai_exposure_processes_image(self, tmp_path: Path) -> None:
        """``run`` writes an output image when ``_get_model`` and the zero_dce
        enhancer are replaced by stubs (the enhancer returns its input)."""
        photo = tmp_path / "photo.png"
        make_png(photo, width=4, height=4, rgb=(30, 60, 90))
        out_dir = tmp_path / "out"
        out_dir.mkdir()

        # (model, device) pair handed back by the patched _get_model.
        stub_model_and_device = (MagicMock(), MagicMock())

        def identity_enhance(_model, image, *, device, strength):
            # Pass-through stub: keeps the test independent of real weights.
            return image

        with (
            patch.object(AIExposureModule, "_get_model", return_value=stub_model_and_device),
            patch("imagepipeline.ai.zero_dce.enhance_image", side_effect=identity_enhance),
        ):
            # NOTE(review): max_edge=0 presumably disables resizing — confirm.
            exposure_params = AIExposureModule.validate_module_params({"max_edge": 0})
            context = ModuleContext(
                input_paths=[photo],
                matched_groups=[],
                output_dir=out_dir,
                params=exposure_params,
                pipeline_output_root=tmp_path,
                step_id="ai_exposure_01",
                logger=None,
            )
            AIExposureModule().run(context)

        assert (out_dir / "photo.png").is_file()
|
|
|
|
|
|
class TestComfyFluxEdit:
    """comfy_flux_edit reports missing prerequisites as DependencyError."""

    @staticmethod
    def _context(tmp_path: Path, raw_params: dict) -> "ModuleContext":
        """Create an input PNG and an output dir under *tmp_path* and wrap
        them, together with the validated *raw_params*, in a ModuleContext."""
        photo = tmp_path / "photo.png"
        make_png(photo)
        out_dir = tmp_path / "out"
        out_dir.mkdir()
        return ModuleContext(
            input_paths=[photo],
            matched_groups=[],
            output_dir=out_dir,
            params=ComfyFluxEditModule.validate_module_params(raw_params),
            pipeline_output_root=tmp_path,
            step_id="comfy_flux_edit_01",
            logger=None,
        )

    def test_server_unreachable_raises(self, tmp_path: Path) -> None:
        """A valid workflow but an unreachable server makes ``run`` raise."""
        workflow = tmp_path / "workflow.json"
        workflow.write_text('{"1": {"class_type": "LoadImage", "inputs": {"image": "x"}}}')
        # Assumes nothing listens on loopback port 1, so the connection fails.
        ctx = self._context(
            tmp_path,
            {
                "prompt": "test",
                "server_url": "http://127.0.0.1:1",
                "workflow_path": workflow,
            },
        )
        with pytest.raises(DependencyError, match="not reachable"):
            ComfyFluxEditModule().run(ctx)

    def test_missing_workflow_raises(self, tmp_path: Path) -> None:
        """A workflow path that does not exist makes ``run`` raise."""
        ctx = self._context(
            tmp_path,
            {
                "prompt": "test",
                "workflow_path": tmp_path / "missing.json",
            },
        )
        with pytest.raises(DependencyError, match="workflow not found"):
            ComfyFluxEditModule().run(ctx)
|
|
|
|
|
|
class TestOpenRouterDependencies:
    """openrouter_edit's dependency check requires its API key variable."""

    def test_missing_api_key_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """With the key variable unset, ``check_dependencies`` raises a
        DependencyError whose message names that variable."""
        key_var = "OPENROUTER_API_KEY"
        # raising=False: succeed whether or not the variable was set.
        monkeypatch.delenv(key_var, raising=False)
        with pytest.raises(DependencyError, match=key_var):
            OpenRouterEditModule.check_dependencies()
|