This commit is contained in:
Victor Phan
2026-01-24 15:48:52 +00:00
parent f264d1856b
commit fb8a107e77
6 changed files with 1760 additions and 0 deletions

183
Train_GLF-CR.ipynb Normal file
View File

@@ -0,0 +1,183 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"id": "a148f700-b14e-4dc3-a203-9e2f93955587",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"📂 Working Directory: /home/jovyan/cloud_train/dataset\n",
"✅ Đã import thành công (dùng đường dẫn tuyệt đối)!\n",
"Using device: cpu\n",
"✅ Data Ready: 29117 samples\n"
]
}
],
"source": [
"# ==========================================\n",
"# CELL 1: SETUP & LOAD DỮ LIỆU\n",
"# ==========================================\n",
"import os, sys, torch, numpy as np\n",
"import torch.nn as nn\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"# Fix Import Path\n",
"sys.path.append(os.getcwd())\n",
"sys.path.append(\"/home/jovyan/cloud_train\")\n",
"\n",
"try:\n",
" from sen12ms_cr_dataLoader import SEN12MSCRDataset, Seasons, S1Bands, S2Bands\n",
"except ImportError:\n",
" raise RuntimeError(\"❌ Không tìm thấy file 'sen12ms_cr_dataLoader.py'\")\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"base_dir = os.path.join(os.path.expanduser(\"~\"), \"cloud_train\", \"dataset\")\n",
"\n",
"class SEN12MSCR_TorchDataset(Dataset):\n",
" def __init__(self, base_dir, season=Seasons.SPRING):\n",
" self.loader = SEN12MSCRDataset(base_dir)\n",
" self.season = season\n",
" self.season_ids = self.loader.get_season_ids(season)\n",
" self.samples = [(sid, pid) for sid, pids in self.season_ids.items() for pid in pids]\n",
" def __len__(self): return len(self.samples)\n",
" def __getitem__(self, idx):\n",
" sid, pid = self.samples[idx]\n",
" s1, s2, c, _ = self.loader.get_s1s2s2cloudy_triplet(self.season, sid, pid)\n",
" s1 = np.clip(s1, -25, 0) / 25.0\n",
" s2 = (np.clip(s2, 0, 10000) / 5000.0) - 1.0\n",
" c = (np.clip(c, 0, 10000) / 5000.0) - 1.0\n",
" return torch.from_numpy(s1).float(), torch.from_numpy(s2).float(), torch.from_numpy(c).float()\n",
"\n",
"train_dataset = SEN12MSCR_TorchDataset(base_dir)\n",
"train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)\n",
"print(f\"✅ GLF-CR Data Ready: {len(train_dataset)} samples\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "cdb8ff0f-ea65-4ed4-8f5b-b99b05c602a8",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# ==========================================\n",
"# CELL 2: MODEL GLF-CR\n",
"# ==========================================\n",
"class FusionBlock(nn.Module):\n",
" def __init__(self, channels):\n",
" super().__init__()\n",
" self.conv = nn.Sequential(\n",
" nn.Conv2d(channels * 2, channels, 1),\n",
" nn.BatchNorm2d(channels), nn.ReLU()\n",
" )\n",
" def forward(self, x1, x2):\n",
" return self.conv(torch.cat([x1, x2], dim=1))\n",
"\n",
"class GLF_Generator(nn.Module):\n",
" def __init__(self):\n",
" super(GLF_Generator, self).__init__()\n",
" # 2 Nhánh riêng biệt\n",
" self.enc_s1 = nn.Sequential(nn.Conv2d(2, 64, 3, 1, 1), nn.ReLU()) # Radar\n",
" self.enc_s2 = nn.Sequential(nn.Conv2d(13, 64, 3, 1, 1), nn.ReLU()) # Mây\n",
" self.fuse = FusionBlock(64)\n",
" self.dec = nn.Sequential(\n",
" nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(),\n",
" nn.Conv2d(128, 64, 3, 1, 1), nn.ReLU(),\n",
" nn.Conv2d(64, 13, 3, 1, 1)\n",
" )\n",
" self.tanh = nn.Tanh()\n",
"\n",
" def forward(self, s1, s2_cloudy):\n",
" f1 = self.enc_s1(s1)\n",
" f2 = self.enc_s2(s2_cloudy)\n",
" f_fused = self.fuse(f1, f2)\n",
" return self.tanh(self.dec(f_fused))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b124a548-a7fe-46e7-939d-ae3ff5e5b83c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🚀 Bắt đầu train GLF-CR...\n",
"Step 0 | MSE Loss: 0.4782\n",
"Step 10 | MSE Loss: 0.0848\n",
"Step 20 | MSE Loss: 0.0446\n",
"Step 30 | MSE Loss: 0.0274\n",
"Step 40 | MSE Loss: 0.2338\n",
"Step 50 | MSE Loss: 0.0158\n",
"\n",
"💾 Đã lưu model tại: /home/jovyan/cloud_train/saved_models/GLFCR_generator.pth\n"
]
}
],
"source": [
"# ==========================================\n",
"# CELL 3: TRAINING & SAVE\n",
"# ==========================================\n",
"if 'train_loader' in locals():\n",
" gen = GLF_Generator().to(device)\n",
" opt = torch.optim.Adam(gen.parameters(), lr=0.0002)\n",
" crit = nn.MSELoss() \n",
" \n",
" print(\"🚀 Bắt đầu train GLF-CR...\")\n",
" for i, (s1, s2, cloud) in enumerate(train_loader):\n",
" s1, s2, cloud = s1.to(device), s2.to(device), cloud.to(device)\n",
" \n",
" opt.zero_grad()\n",
" out = gen(s1, cloud)\n",
" loss = crit(out, s2)\n",
" loss.backward()\n",
" opt.step()\n",
" \n",
" if i % 10 == 0: print(f\"Step {i} | MSE Loss: {loss.item():.4f}\")\n",
" if i > 50: break\n",
" \n",
" # --- LƯU MODEL ---\n",
" save_dir = os.path.join(os.path.expanduser(\"~\"), \"cloud_train\", \"saved_models\")\n",
" os.makedirs(save_dir, exist_ok=True)\n",
" torch.save(gen.state_dict(), os.path.join(save_dir, \"GLFCR_generator.pth\"))\n",
" print(f\"\\n💾 Đã lưu model tại: {save_dir}/GLFCR_generator.pth\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f863772-5925-45b4-a9aa-e1dc54f52658",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}