178 lines
6.3 KiB
Plaintext
178 lines
6.3 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "3813f900-ea54-4533-873a-70aa46417ba0",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"✅ SpA-GAN Data Ready: 29117 samples\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# ==========================================\n",
"# CELL 1: SETUP & DATA LOADING\n",
"# ==========================================\n",
"import os\n",
"import random\n",
"import sys\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"# Project root: the dataset lives under it, and it is put on sys.path so the\n",
"# local `sen12ms_cr_dataLoader.py` module can be imported.\n",
"PROJECT_DIR = os.path.join(os.path.expanduser(\"~\"), \"cloud_train\")\n",
"\n",
"# Skip paths already present so re-running this cell does not grow sys.path.\n",
"for _path in (os.getcwd(), PROJECT_DIR):\n",
"    if _path not in sys.path:\n",
"        sys.path.append(_path)\n",
"\n",
"try:\n",
"    from sen12ms_cr_dataLoader import SEN12MSCRDataset, Seasons, S1Bands, S2Bands\n",
"except ImportError as err:\n",
"    raise RuntimeError(\"❌ Không tìm thấy file 'sen12ms_cr_dataLoader.py'\") from err\n",
"\n",
"# Seed every RNG used downstream (DataLoader shuffling, weight init) so that\n",
"# Restart & Run All reproduces the same run.\n",
"SEED = 42\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"base_dir = os.path.join(PROJECT_DIR, \"dataset\")\n",
"\n",
|
|
"class SEN12MSCR_TorchDataset(Dataset):\n",
"    \"\"\"One season of SEN12MS-CR as a PyTorch Dataset of normalized triplets.\n",
"\n",
"    Each item is a tuple of float32 tensors ``(s1, s2, cloudy)``, in the order\n",
"    returned by ``get_s1s2s2cloudy_triplet`` (its 4th return value is dropped):\n",
"      * ``s1``     -- clipped to [-25, 0] then divided by 25, i.e. range [-1, 0]\n",
"                      (presumably SAR backscatter in dB -- confirm).\n",
"      * ``s2``     -- clipped to [0, 10000], divided by 5000, minus 1 -> [-1, 1].\n",
"      * ``cloudy`` -- same scaling as ``s2``.\n",
"    NaN pixels are replaced by the lower clip bound (and +/-inf saturate to the\n",
"    bounds) before clipping, so a single bad pixel cannot make the loss NaN.\n",
"    \"\"\"\n",
"\n",
"    # Normalization constants (same values as the original inline literals).\n",
"    S1_CLIP = (-25.0, 0.0)\n",
"    S1_SCALE = 25.0          # [-25, 0] / 25 -> [-1, 0]\n",
"    S2_CLIP = (0.0, 10000.0)\n",
"    S2_HALF_RANGE = 5000.0   # [0, 10000] / 5000 - 1 -> [-1, 1]\n",
"\n",
"    def __init__(self, base_dir, season=Seasons.SPRING):\n",
"        self.loader = SEN12MSCRDataset(base_dir)\n",
"        self.season = season\n",
"        # Mapping whose values are iterables of patch ids (presumably keyed by scene id).\n",
"        self.season_ids = self.loader.get_season_ids(season)\n",
"        # Flatten to (scene_id, patch_id) pairs so items can be indexed directly.\n",
"        self.samples = [(sid, pid) for sid, pids in self.season_ids.items() for pid in pids]\n",
"\n",
"    def __len__(self):\n",
"        return len(self.samples)\n",
"\n",
"    @staticmethod\n",
"    def _clip_finite(x, lo, hi):\n",
"        \"\"\"Replace NaN with ``lo`` (+/-inf saturate), then clip to [lo, hi].\"\"\"\n",
"        return np.clip(np.nan_to_num(x, nan=lo), lo, hi)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        sid, pid = self.samples[idx]\n",
"        s1, s2, c, _ = self.loader.get_s1s2s2cloudy_triplet(self.season, sid, pid)\n",
"        s1 = self._clip_finite(s1, *self.S1_CLIP) / self.S1_SCALE\n",
"        s2 = self._clip_finite(s2, *self.S2_CLIP) / self.S2_HALF_RANGE - 1.0\n",
"        c = self._clip_finite(c, *self.S2_CLIP) / self.S2_HALF_RANGE - 1.0\n",
"        return torch.from_numpy(s1).float(), torch.from_numpy(s2).float(), torch.from_numpy(c).float()\n",
|
|
"\n",
"# Training split (default season) wrapped in a shuffled mini-batch loader.\n",
"train_dataset = SEN12MSCR_TorchDataset(base_dir)\n",
"train_loader = DataLoader(\n",
"    train_dataset,\n",
"    batch_size=4,\n",
"    shuffle=True,\n",
"    num_workers=0,  # 0 = batches are loaded in the main process\n",
")\n",
"print(f\"✅ SpA-GAN Data Ready: {len(train_dataset)} samples\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "078b6951-1dcf-44e9-895c-9ec0c8c848b6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# ==========================================\n",
"# CELL 2: MODEL SpA-GAN\n",
"# ==========================================\n",
"class SpatialAttention(nn.Module):\n",
"    \"\"\"Gate every pixel of a feature map with a learned weight in (0, 1).\n",
"\n",
"    A 1x1 convolution collapses the ``in_channels`` maps to a single map, a\n",
"    sigmoid squashes it into (0, 1), and the input is multiplied by it\n",
"    (broadcast over channels), so the output has the same shape as ``x``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channels):\n",
"        super().__init__()\n",
"        # `conv` is a state_dict key prefix: keep the name so saved checkpoints still load.\n",
"        self.conv = nn.Conv2d(in_channels, 1, kernel_size=1)\n",
"        self.sigmoid = nn.Sigmoid()\n",
"\n",
"    def forward(self, x):\n",
"        # Assumes x is (N, C, H, W); the gate is then (N, 1, H, W).\n",
"        gate = self.sigmoid(self.conv(x))\n",
"        return x * gate\n",
"\n",
|
|
"class SpA_Generator(nn.Module):\n",
"    \"\"\"Small encoder-decoder generator with one spatial-attention stage.\n",
"\n",
"    ``s1`` and ``s2_cloudy`` are concatenated on dim 1 and must total 15\n",
"    channels (presumably 2 S1 + 13 S2 bands -- confirm against the loader).\n",
"    Two stride-2 convolutions downsample by 4 and two stride-2 transposed\n",
"    convolutions upsample back; the output has 13 channels in (-1, 1) (tanh).\n",
"    H and W should be divisible by 4, otherwise the skip addition fails or the\n",
"    output size differs from the input size.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # Attribute names are state_dict keys and construction order drives\n",
"        # weight-init RNG draws: keep both unchanged for reproducibility.\n",
"        # Encoder: 15 -> 64 channels at H/2, attention, then 64 -> 128 at H/4.\n",
"        self.enc1 = nn.Sequential(nn.Conv2d(15, 64, 4, 2, 1), nn.LeakyReLU(0.2))\n",
"        self.attn = SpatialAttention(64)\n",
"        self.enc2 = nn.Sequential(\n",
"            nn.Conv2d(64, 128, 4, 2, 1),\n",
"            nn.BatchNorm2d(128),\n",
"            nn.LeakyReLU(0.2),\n",
"        )\n",
"        # Decoder: 128 -> 64 channels at H/2, then 64 -> 13 at H.\n",
"        self.dec1 = nn.Sequential(\n",
"            nn.ConvTranspose2d(128, 64, 4, 2, 1),\n",
"            nn.BatchNorm2d(64),\n",
"            nn.ReLU(),\n",
"        )\n",
"        self.dec2 = nn.ConvTranspose2d(64, 13, 4, 2, 1)\n",
"        self.tanh = nn.Tanh()\n",
"\n",
"    def forward(self, s1, s2_cloudy):\n",
"        stacked = torch.cat((s1, s2_cloudy), dim=1)\n",
"        skip = self.enc1(stacked)\n",
"        # Attention-gated features feed the bottleneck ...\n",
"        bottleneck = self.enc2(self.attn(skip))\n",
"        # ... but the skip connection adds the un-gated enc1 features.\n",
"        merged = self.dec1(bottleneck) + skip\n",
"        return self.tanh(self.dec2(merged))\n",
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "e6bd89be-aa61-4429-b588-f2041311f2a6",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"🚀 Bắt đầu train SpA-GAN...\n",
|
|
"Step 0 | Loss G: 0.6843\n",
|
|
"Step 10 | Loss G: 0.3739\n",
|
|
"Step 20 | Loss G: 0.1896\n",
|
|
"Step 30 | Loss G: 0.1224\n",
|
|
"Step 40 | Loss G: 0.0938\n",
|
|
"Step 50 | Loss G: 0.1833\n",
|
|
"\n",
|
|
"💾 Đã lưu model tại: /home/jovyan/cloud_train/saved_models/SpAGAN_generator.pth\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# ==========================================\n",
"# CELL 3: TRAINING & SAVE\n",
"# ==========================================\n",
"# Short demo run with an L1 reconstruction loss only (no discriminator term).\n",
"MAX_STEPS = 50  # optimizer steps in this demo\n",
"LOG_EVERY = 10  # print the loss every LOG_EVERY steps\n",
"\n",
"if 'train_loader' not in globals():\n",
"    # Fail loudly rather than silently skipping training (and the save).\n",
"    raise RuntimeError(\"train_loader is not defined - run CELL 1 first.\")\n",
"\n",
"gen = SpA_Generator().to(device)\n",
"gen.train()  # BatchNorm layers use batch statistics while training\n",
"opt_g = torch.optim.Adam(gen.parameters(), lr=0.0002)\n",
"criterion = nn.L1Loss()\n",
"\n",
"print(\"🚀 Bắt đầu train SpA-GAN...\")\n",
"for step, (s1, s2, cloud) in enumerate(train_loader):\n",
"    s1, s2, cloud = s1.to(device), s2.to(device), cloud.to(device)\n",
"\n",
"    opt_g.zero_grad()\n",
"    fake = gen(s1, cloud)         # generator is conditioned on s1 + the cloudy input\n",
"    loss_g = criterion(fake, s2)  # L1 against s2, the reconstruction target\n",
"    loss_g.backward()\n",
"    opt_g.step()\n",
"\n",
"    if step % LOG_EVERY == 0:\n",
"        print(f\"Step {step} | Loss G: {loss_g.item():.4f}\")\n",
"    # Break after the update so exactly MAX_STEPS optimizer steps run.\n",
"    if step + 1 >= MAX_STEPS:\n",
"        break\n",
"\n",
"# --- SAVE MODEL ---\n",
"save_dir = os.path.join(os.path.expanduser(\"~\"), \"cloud_train\", \"saved_models\")\n",
"os.makedirs(save_dir, exist_ok=True)\n",
"save_path = os.path.join(save_dir, \"SpAGAN_generator.pth\")\n",
"torch.save(gen.state_dict(), save_path)\n",
"print(f\"\\n💾 Đã lưu model tại: {save_path}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "bc1d8e75-0a3f-4197-8196-7132122f2be9",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|