Files
imagepipeline/imagepipeline/ai/zero_dce.py
T
2026-05-30 11:33:07 +02:00

112 lines
3.6 KiB
Python

from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
class CSDN_Tem(nn.Module):
    """Depthwise-separable convolution block.

    A 3x3 convolution with ``groups=in_ch`` filters every input channel
    independently, then a 1x1 convolution mixes the channels and maps
    them to ``out_ch``.  Stride 1 with padding 1 on the 3x3 kernel keeps
    the spatial size of the input.

    Args:
        in_ch: Number of channels of the incoming feature map.
        out_ch: Number of channels produced by the 1x1 convolution.
    """

    def __init__(self, in_ch: int, out_ch: int) -> None:
        super().__init__()
        # Per-channel spatial filtering: one 3x3 kernel per input channel.
        self.depth_conv = nn.Conv2d(
            in_ch, in_ch, kernel_size=3, stride=1, padding=1, groups=in_ch
        )
        # Channel mixing: an ordinary (groups=1) 1x1 convolution.
        self.point_conv = nn.Conv2d(
            in_ch, out_ch, kernel_size=1, stride=1, padding=0, groups=1
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Apply the depthwise convolution followed by the pointwise one."""
        return self.point_conv(self.depth_conv(input_tensor))
class EnhanceNetNoPool(nn.Module):
    """Curve-estimation network built from seven ``CSDN_Tem`` layers.

    The network predicts a 3-channel map ``x_r`` (tanh output, so values
    lie in (-1, 1)) and uses it to adjust the input by repeatedly applying
    the quadratic curve ``x <- x + x_r * (x**2 - x)``.

    Layers 5-7 each take the channel-wise concatenation of an earlier
    feature map and the previous layer's output (x3|x4, x2|x5, x1|x6).

    Args:
        scale_factor: When different from 1, the map is estimated on a
            copy of the input downsampled by this factor and then resized
            back to the input resolution.
    """

    def __init__(self, scale_factor: int = 1) -> None:
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.scale_factor = scale_factor
        # Retained so the attribute stays available to existing callers;
        # forward() resizes to the exact input size instead of using it.
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=self.scale_factor)
        number_f = 32
        self.e_conv1 = CSDN_Tem(3, number_f)
        self.e_conv2 = CSDN_Tem(number_f, number_f)
        self.e_conv3 = CSDN_Tem(number_f, number_f)
        self.e_conv4 = CSDN_Tem(number_f, number_f)
        # These three layers consume two concatenated feature maps each.
        self.e_conv5 = CSDN_Tem(number_f * 2, number_f)
        self.e_conv6 = CSDN_Tem(number_f * 2, number_f)
        self.e_conv7 = CSDN_Tem(number_f * 2, 3)

    def enhance(self, x: torch.Tensor, x_r: torch.Tensor) -> torch.Tensor:
        """Apply the quadratic curve ``x + x_r * (x**2 - x)`` eight times.

        Args:
            x: Image tensor to adjust.
            x_r: Curve-parameter map, broadcastable against ``x``.

        Returns:
            The tensor obtained after eight successive curve applications.
        """
        for _ in range(8):
            x = x + x_r * (torch.pow(x, 2) - x)
        return x

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Estimate the curve map for ``x`` and return the adjusted image.

        Args:
            x: Input batch; indexing fixes it as 4-D with 3 channels on
                dim 1 (presumably values in [0, 1] -- confirm with caller).

        Returns:
            ``(enhanced, x_r)``: the adjusted image and the curve map, both
            at the spatial resolution of ``x``.
        """
        if self.scale_factor == 1:
            x_down = x
        else:
            x_down = F.interpolate(
                x,
                scale_factor=1 / self.scale_factor,
                mode="bilinear",
                align_corners=False,
            )

        x1 = self.relu(self.e_conv1(x_down))
        x2 = self.relu(self.e_conv2(x1))
        x3 = self.relu(self.e_conv3(x2))
        x4 = self.relu(self.e_conv4(x3))
        x5 = self.relu(self.e_conv5(torch.cat([x3, x4], 1)))
        x6 = self.relu(self.e_conv6(torch.cat([x2, x5], 1)))
        x_r = torch.tanh(self.e_conv7(torch.cat([x1, x6], 1)))

        if self.scale_factor != 1:
            # Resize to the input's exact height/width.  A fixed-factor
            # upsample yields floor(H / s) * s, which differs from H when H
            # is not a multiple of s and then breaks the broadcast in
            # enhance().  align_corners=True matches UpsamplingBilinear2d,
            # so evenly divisible sizes give the same result as before.
            x_r = F.interpolate(
                x_r, size=x.shape[-2:], mode="bilinear", align_corners=True
            )
        return self.enhance(x, x_r), x_r
def load_zero_dce_model(weights_path, device: torch.device) -> EnhanceNetNoPool:
    """Create an ``EnhanceNetNoPool`` with scale factor 1 and load its weights.

    Args:
        weights_path: Location of the checkpoint, passed straight to
            ``torch.load``; its contents are handed to ``load_state_dict``.
        device: Device the checkpoint tensors are mapped to and the model
            is moved to.

    Returns:
        The model on ``device``, switched to evaluation mode.
    """
    net = EnhanceNetNoPool(scale_factor=1)
    # NOTE(review): weights_only=False permits full unpickling, which can
    # execute code embedded in the file -- only load trusted checkpoints.
    checkpoint = torch.load(weights_path, map_location=device, weights_only=False)
    net.load_state_dict(checkpoint)
    # Module.to() and Module.eval() both return the module itself.
    return net.to(device).eval()
def enhance_image(
    model: EnhanceNetNoPool,
    image_rgb,
    *,
    device: torch.device,
    strength: float = 1.0,
):
    """Run ``model`` on one image and return the result as a PIL image.

    Args:
        model: Callable returning ``(enhanced, curve_map)`` for a
            ``(1, 3, H, W)`` float tensor scaled to [0, 1].
        image_rgb: A ``PIL.Image.Image`` or an array accepted by
            ``PIL.Image.fromarray``.  Images in any other mode than RGB
            (grayscale, RGBA, palette, ...) are converted to RGB first, so
            an alpha channel is not preserved.
        device: Device the input tensor is moved to before inference.
        strength: Blend weight of the enhanced result.  Values below 1.0
            linearly mix the input back in; values of 1.0 or more use the
            enhanced result unchanged.

    Returns:
        An RGB ``PIL.Image.Image`` with 8-bit channels.
    """
    import numpy as np
    from PIL import Image

    if not isinstance(image_rgb, Image.Image):
        image_rgb = Image.fromarray(image_rgb)
    if image_rgb.mode != "RGB":
        # The tensor layout below needs an (H, W, 3) array; other modes
        # would fail in permute() or feed the wrong channel count.
        image_rgb = image_rgb.convert("RGB")

    arr = np.asarray(image_rgb, dtype=np.float32) / 255.0
    # (H, W, 3) -> (1, 3, H, W)
    tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        enhanced, _ = model(tensor)
        if strength < 1.0:
            enhanced = tensor * (1.0 - strength) + enhanced * strength
        enhanced = torch.clamp(enhanced, 0.0, 1.0)

    pixels = enhanced.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Round to nearest instead of truncating: a plain uint8 cast drops the
    # fractional part, so e.g. 254.7 became 254 and output skewed darker.
    out = np.rint(pixels * 255.0).astype(np.uint8)
    return Image.fromarray(out)