112 lines
3.6 KiB
Python
112 lines
3.6 KiB
Python
from __future__ import annotations
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class CSDN_Tem(nn.Module):
    """Depthwise-separable convolution: a per-channel 3x3 conv, then a 1x1 conv.

    The 3x3 stage uses ``groups=in_ch`` so every input channel is filtered on
    its own, and ``padding=1`` with ``stride=1`` so height and width are kept.
    The 1x1 stage mixes channels and maps ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch: int, out_ch: int) -> None:
        """Create the two convolution stages.

        Args:
            in_ch: Number of channels in the input feature map.
            out_ch: Number of channels produced by the 1x1 stage.
        """
        super().__init__()
        # Attribute names and creation order are kept as-is: they determine
        # the state-dict keys and the order in which weights are initialised.
        self.depth_conv = nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1, groups=in_ch)
        self.point_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Run the depthwise stage followed by the pointwise stage."""
        return self.point_conv(self.depth_conv(input_tensor))
|
|
|
|
|
|
class EnhanceNetNoPool(nn.Module):
    """Estimate a per-pixel curve map and apply it to the input as a quadratic curve.

    Seven ``CSDN_Tem`` layers (with skip concatenations between them) predict
    a three-channel map ``x_r``; ``tanh`` bounds it to ``[-1, 1]``. The input
    is then updated several times with ``x <- x + x_r * (x**2 - x)``.

    When ``scale_factor`` is not 1 the map is predicted on a bilinearly
    down-scaled copy of the input and resized back to the input resolution
    before the curve is applied.
    """

    # How many times ``enhance`` applies the quadratic curve.
    _CURVE_ITERATIONS = 8

    def __init__(self, scale_factor: int = 1) -> None:
        """Build the layers.

        Args:
            scale_factor: Factor by which the input is down-scaled before the
                curve map is predicted. ``1`` disables rescaling.
        """
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.scale_factor = scale_factor
        # Retained so code that accesses ``model.upsample`` keeps working.
        # ``forward`` resizes to the exact input size instead (see there).
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=self.scale_factor)
        number_f = 32

        self.e_conv1 = CSDN_Tem(3, number_f)
        self.e_conv2 = CSDN_Tem(number_f, number_f)
        self.e_conv3 = CSDN_Tem(number_f, number_f)
        self.e_conv4 = CSDN_Tem(number_f, number_f)
        # Layers 5-7 take two concatenated feature maps, hence ``number_f * 2``.
        self.e_conv5 = CSDN_Tem(number_f * 2, number_f)
        self.e_conv6 = CSDN_Tem(number_f * 2, number_f)
        self.e_conv7 = CSDN_Tem(number_f * 2, 3)

    def enhance(self, x: torch.Tensor, x_r: torch.Tensor) -> torch.Tensor:
        """Apply ``x <- x + x_r * (x**2 - x)`` repeatedly with the same ``x_r``.

        Args:
            x: Image tensor to adjust.
            x_r: Curve map; must be broadcastable against ``x``.

        Returns:
            The tensor after ``_CURVE_ITERATIONS`` curve applications.
        """
        for _ in range(self._CURVE_ITERATIONS):
            x = x + x_r * (torch.pow(x, 2) - x)
        return x

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Predict the curve map for ``x`` and return the adjusted image.

        Args:
            x: Input tensor; channel dimension 1 must have size 3 and the
                last two dimensions are treated as height and width.

        Returns:
            ``(enhanced, x_r)`` where ``x_r`` is the curve map at the spatial
            size of ``x``.
        """
        if self.scale_factor == 1:
            x_down = x
        else:
            x_down = F.interpolate(
                x, scale_factor=1 / self.scale_factor, mode="bilinear"
            )

        x1 = self.relu(self.e_conv1(x_down))
        x2 = self.relu(self.e_conv2(x1))
        x3 = self.relu(self.e_conv3(x2))
        x4 = self.relu(self.e_conv4(x3))
        x5 = self.relu(self.e_conv5(torch.cat([x3, x4], 1)))
        x6 = self.relu(self.e_conv6(torch.cat([x2, x5], 1)))
        # torch.tanh is the equivalent of the legacy ``F.tanh`` alias.
        x_r = torch.tanh(self.e_conv7(torch.cat([x1, x6], 1)))
        if self.scale_factor != 1:
            # Resize to the exact input size. Down-scaling floors the size, so
            # scaling back up by ``scale_factor`` would come out smaller than
            # ``x`` whenever height/width is not a multiple of the factor and
            # ``enhance`` would fail. ``align_corners=True`` matches what
            # ``nn.UpsamplingBilinear2d`` does, so evenly divisible inputs
            # give the same result as before.
            x_r = F.interpolate(
                x_r, size=tuple(x.shape[-2:]), mode="bilinear", align_corners=True
            )
        return self.enhance(x, x_r), x_r
|
|
|
|
|
|
def load_zero_dce_model(
    weights_path, device: torch.device, *, scale_factor: int = 1
) -> EnhanceNetNoPool:
    """Build an ``EnhanceNetNoPool``, load its weights, and put it in eval mode.

    Args:
        weights_path: Checkpoint location, passed straight to ``torch.load``.
            The loaded object is handed to ``load_state_dict`` unchanged, so
            it must be a state dict whose keys match the model.
        device: Device the tensors are mapped to and the model is moved to.
        scale_factor: Forwarded to ``EnhanceNetNoPool``. It only controls the
            rescaling done in ``forward`` and does not change any layer
            shapes, so the same checkpoint can be used with any value.
            Defaults to ``1``, the previously hard-coded setting.

    Returns:
        The model on ``device`` with ``eval()`` applied.
    """
    model = EnhanceNetNoPool(scale_factor=scale_factor)
    # SECURITY NOTE(review): ``weights_only=False`` makes torch.load run the
    # full pickle machinery, which can execute arbitrary code from a malicious
    # file. Only load checkpoints from a trusted source. If the checkpoint is
    # a plain tensor state dict, ``weights_only=True`` should be sufficient —
    # confirm against the shipped weights before switching.
    state = torch.load(weights_path, map_location=device, weights_only=False)
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model
|
|
|
|
|
|
def enhance_image(
    model: EnhanceNetNoPool,
    image_rgb,
    *,
    device: torch.device,
    strength: float = 1.0,
):
    """Run ``model`` on one image and return the result as a PIL image.

    Args:
        model: Callable that takes a ``(1, 3, H, W)`` float tensor with values
            in ``[0, 1]`` and returns a tuple whose first item is the enhanced
            tensor of the same shape.
        image_rgb: A ``PIL.Image.Image``, or anything ``Image.fromarray``
            accepts. Images that are not in RGB mode are converted to RGB.
        device: Device the input tensor is moved to before calling ``model``.
        strength: Blend weight between the input (``0.0``) and the model
            output (``1.0``). Values outside ``[0, 1]`` are clamped.

    Returns:
        An 8-bit RGB ``PIL.Image.Image``.
    """
    import numpy as np
    from PIL import Image

    if not isinstance(image_rgb, Image.Image):
        image_rgb = Image.fromarray(image_rgb)
    # The tensor layout below needs an (H, W, 3) array; greyscale, palette or
    # RGBA input would otherwise fail in ``permute`` or carry the wrong number
    # of channels.
    if image_rgb.mode != "RGB":
        image_rgb = image_rgb.convert("RGB")

    # Values above 1 were already treated as "full strength"; clamp the lower
    # end too so a negative value cannot extrapolate away from the input.
    strength = min(max(float(strength), 0.0), 1.0)

    arr = np.asarray(image_rgb, dtype=np.float32) / 255.0
    tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        enhanced, _ = model(tensor)
        if strength < 1.0:
            enhanced = tensor * (1.0 - strength) + enhanced * strength
        enhanced = torch.clamp(enhanced, 0.0, 1.0)

    pixels = enhanced.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Round to nearest before the uint8 cast: a bare ``astype`` truncates, so
    # float round-off such as 0.99999994 would become 0 instead of 1.
    out = np.rint(pixels * 255.0).astype(np.uint8)
    return Image.fromarray(out)
|