This commit is contained in:
Victor Phan
2026-01-24 15:48:52 +00:00
parent f264d1856b
commit fb8a107e77
6 changed files with 1760 additions and 0 deletions

568
Train_CR-GAN.ipynb Normal file
View File

@@ -0,0 +1,568 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "c886b44b-4288-4260-8b5b-bb78a87e172f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"📂 Working Directory: /home/jovyan/cloud_train/dataset\n",
"✅ Đã import thành công (dùng đường dẫn tuyệt đối)!\n",
"Using device: cpu\n",
"✅ Data Ready: 29117 samples\n"
]
}
],
"source": [
"# ==========================================\n",
"# CELL 1: SETUP & LOAD DỮ LIỆU (ĐÃ SỬA LỖI IMPORT)\n",
"# ==========================================\n",
"import os\n",
"import sys\n",
"import torch\n",
"import numpy as np\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import torch.nn as nn\n",
"\n",
"# --- Make the project data loader importable ---\n",
"# The notebook may be started from a sub-directory, so the working directory\n",
"# is searched first and ~/cloud_train (the root base_dir is built from) second.\n",
"current_dir = os.getcwd()\n",
"project_dir = os.path.join(os.path.expanduser(\"~\"), \"cloud_train\")\n",
"if current_dir not in sys.path:\n",
"    sys.path.append(current_dir)\n",
"print(f\"📂 Working Directory: {current_dir}\")\n",
"\n",
"try:\n",
"    from sen12ms_cr_dataLoader import SEN12MSCRDataset, Seasons, S1Bands, S2Bands\n",
"    print(\"✅ Đã import thành công sen12ms_cr_dataLoader!\")\n",
"except ImportError:\n",
"    # Not found via the working directory: fall back to the project root under\n",
"    # the user's home directory instead of a hard-coded /home/<user> path.\n",
"    if project_dir not in sys.path:\n",
"        sys.path.append(project_dir)\n",
"    try:\n",
"        from sen12ms_cr_dataLoader import SEN12MSCRDataset, Seasons, S1Bands, S2Bands\n",
"        print(\"✅ Đã import thành công (dùng đường dẫn tuyệt đối)!\")\n",
"    except ImportError as err:\n",
"        # Chain the original error explicitly so the real cause stays visible.\n",
"        raise RuntimeError(\"❌ Vẫn không tìm thấy file 'sen12ms_cr_dataLoader.py'. Hãy đảm bảo file này nằm cùng thư mục với Notebook!\") from err\n",
"\n",
"# Select the compute device (GPU when available, otherwise CPU)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using device: {device}\")\n",
"\n",
"# Dataset location: <home>/cloud_train/dataset\n",
"base_dir = os.path.join(project_dir, \"dataset\")\n",
"\n",
"# Dataset wrapper\n",
"class SEN12MSCR_TorchDataset(Dataset):\n",
"    \"\"\"Serve one season of SEN12MS-CR as (s1, s2, cloudy s2) float tensors.\"\"\"\n",
"\n",
"    def __init__(self, base_dir, season=Seasons.SPRING):\n",
"        self.loader = SEN12MSCRDataset(base_dir)\n",
"        self.season = season\n",
"        self.season_ids = self.loader.get_season_ids(season)\n",
"        # Flatten the {id: [sub-id, ...]} mapping into one (id, sub-id) pair per sample.\n",
"        pairs = []\n",
"        for scene_id, patch_ids in self.season_ids.items():\n",
"            pairs.extend((scene_id, patch_id) for patch_id in patch_ids)\n",
"        self.samples = pairs\n",
"\n",
"    def __len__(self):\n",
"        return len(self.samples)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        scene_id, patch_id = self.samples[idx]\n",
"        sar, clear, cloudy, _ = self.loader.get_s1s2s2cloudy_triplet(self.season, scene_id, patch_id)\n",
"        # SAR: clip to [-25, 0] and divide by 25, giving values in [-1, 0].\n",
"        sar = np.clip(sar, -25, 0) / 25.0\n",
"        # Optical: clip to [0, 10000], then map linearly onto [-1, 1].\n",
"        clear, cloudy = ((np.clip(band, 0, 10000) / 5000.0) - 1.0 for band in (clear, cloudy))\n",
"        return tuple(torch.from_numpy(arr).float() for arr in (sar, clear, cloudy))\n",
"\n",
"# Build the training dataset and a shuffling loader (batch size 4, num_workers=0).\n",
"# Any exception is printed and swallowed, so train_dataset / train_loader may be\n",
"# left undefined after this cell if construction fails.\n",
"try:\n",
" train_dataset = SEN12MSCR_TorchDataset(base_dir)\n",
" train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)\n",
" print(f\"✅ Data Ready: {len(train_dataset)} samples\")\n",
"except Exception as e:\n",
" print(f\"❌ Lỗi Data: {e}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a6357434-8b1c-48ef-ac84-7afadc4ce04a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# ==========================================\n",
"# CELL 2: ĐỊNH NGHĨA MODEL CR-GAN\n",
"# ==========================================\n",
"class CR_Generator(nn.Module):\n",
"    \"\"\"Small encoder-decoder generator for cloud removal.\n",
"\n",
"    The SAR image and the cloudy optical image are concatenated along the\n",
"    channel axis, encoded by three stride-2 convolutions and decoded by three\n",
"    stride-2 transposed convolutions. Encoder features are added to decoder\n",
"    features of the same resolution, and the result is squashed with tanh.\n",
"\n",
"    Args:\n",
"        s1_channels: channel count of the SAR input (default 2).\n",
"        s2_channels: channel count of the cloudy optical input and of the\n",
"            generated output (default 13).\n",
"\n",
"    Input height and width should be multiples of 8; otherwise the encoder and\n",
"    decoder feature maps no longer have matching sizes for the skip additions.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, s1_channels=2, s2_channels=13):\n",
"        super().__init__()\n",
"        # Default: 2 (radar) + 13 (cloudy optical) = 15 input channels.\n",
"        in_channels = s1_channels + s2_channels\n",
"        # Attribute names are kept so previously saved state_dicts still load.\n",
"        self.enc1 = nn.Sequential(nn.Conv2d(in_channels, 64, 4, 2, 1), nn.LeakyReLU(0.2))\n",
"        self.enc2 = nn.Sequential(nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2))\n",
"        self.enc3 = nn.Sequential(nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2))\n",
"\n",
"        self.dec1 = nn.Sequential(nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU())\n",
"        self.dec2 = nn.Sequential(nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU())\n",
"        self.dec3 = nn.ConvTranspose2d(64, s2_channels, 4, 2, 1)  # cloud-free output bands\n",
"        self.tanh = nn.Tanh()\n",
"\n",
"    def forward(self, s1, s2_cloudy):\n",
"        \"\"\"Return the generated image for a (SAR, cloudy optical) pair, in [-1, 1].\"\"\"\n",
"        x = torch.cat([s1, s2_cloudy], dim=1)\n",
"        e1 = self.enc1(x)\n",
"        e2 = self.enc2(e1)\n",
"        e3 = self.enc3(e2)\n",
"\n",
"        d1 = self.dec1(e3)\n",
"        d2 = self.dec2(d1 + e2)  # additive skip connection from enc2\n",
"        d3 = self.dec3(d2 + e1)  # additive skip connection from enc1\n",
"        return self.tanh(d3)\n",
"\n",
"class Discriminator(nn.Module):\n",
"    \"\"\"Conditional convolutional discriminator.\n",
"\n",
"    The candidate image and both conditioning inputs are concatenated along the\n",
"    channel axis and reduced by three convolutions to a sigmoid score map.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # 13 image channels + 15 conditioning channels = 28 input channels.\n",
"        layers = [\n",
"            nn.Conv2d(28, 64, kernel_size=4, stride=2, padding=1),\n",
"            nn.LeakyReLU(0.2),\n",
"            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),\n",
"            nn.BatchNorm2d(128),\n",
"            nn.LeakyReLU(0.2),\n",
"            nn.Conv2d(128, 1, kernel_size=4, stride=1, padding=0),\n",
"            nn.Sigmoid(),\n",
"        ]\n",
"        self.model = nn.Sequential(*layers)\n",
"\n",
"    def forward(self, img, cond1, cond2):\n",
"        stacked = torch.cat((img, cond1, cond2), dim=1)\n",
"        return self.model(stacked)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "09276904-a4c9-4774-b171-32d2e577afa0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🚀 Bắt đầu huấn luyện CR-GAN...\n",
"Epoch 0 | Step 0 | Loss G: 72.3185 | Loss D: 0.7102\n",
"Epoch 0 | Step 10 | Loss G: 37.1491 | Loss D: 0.5784\n",
"Epoch 0 | Step 20 | Loss G: 19.9918 | Loss D: 0.5008\n",
"Epoch 0 | Step 30 | Loss G: 21.9959 | Loss D: 0.6595\n",
"Epoch 0 | Step 40 | Loss G: 10.0374 | Loss D: 0.5188\n",
"Epoch 0 | Step 50 | Loss G: 7.7058 | Loss D: 0.5300\n",
"Epoch 0 | Step 60 | Loss G: 6.9481 | Loss D: 0.5223\n",
"Epoch 0 | Step 70 | Loss G: 12.7031 | Loss D: 0.6349\n",
"Epoch 0 | Step 80 | Loss G: 6.1611 | Loss D: 0.6098\n",
"Epoch 0 | Step 90 | Loss G: 8.4944 | Loss D: 0.4761\n",
"Epoch 0 | Step 100 | Loss G: 16.4352 | Loss D: 1.2795\n",
"Epoch 0 | Step 110 | Loss G: 6.7580 | Loss D: 0.6131\n",
"Epoch 0 | Step 120 | Loss G: 6.7856 | Loss D: 0.5403\n",
"Epoch 0 | Step 130 | Loss G: 6.2629 | Loss D: 0.6605\n",
"Epoch 0 | Step 140 | Loss G: 5.4081 | Loss D: 0.5667\n",
"Epoch 0 | Step 150 | Loss G: 5.5551 | Loss D: 0.5392\n",
"Epoch 0 | Step 160 | Loss G: 5.4969 | Loss D: 0.5394\n",
"Epoch 0 | Step 170 | Loss G: 4.6524 | Loss D: 0.5536\n",
"Epoch 0 | Step 180 | Loss G: 9.2369 | Loss D: 0.5098\n",
"Epoch 0 | Step 190 | Loss G: 10.5043 | Loss D: 0.4320\n",
"Epoch 0 | Step 200 | Loss G: 5.1155 | Loss D: 0.5958\n",
"Epoch 0 | Step 210 | Loss G: 5.1848 | Loss D: 0.5596\n",
"Epoch 0 | Step 220 | Loss G: 6.5014 | Loss D: 0.5307\n",
"Epoch 0 | Step 230 | Loss G: 6.4747 | Loss D: 0.6416\n",
"Epoch 0 | Step 240 | Loss G: 4.6687 | Loss D: 0.5795\n",
"Epoch 0 | Step 250 | Loss G: 3.8810 | Loss D: 0.6212\n",
"Epoch 0 | Step 260 | Loss G: 4.5834 | Loss D: 0.5906\n",
"Epoch 0 | Step 270 | Loss G: 6.8081 | Loss D: 0.4587\n",
"Epoch 0 | Step 280 | Loss G: 6.4570 | Loss D: 0.5759\n",
"Epoch 0 | Step 290 | Loss G: 4.5583 | Loss D: 0.5872\n",
"Epoch 0 | Step 300 | Loss G: 5.8205 | Loss D: 0.6410\n",
"Epoch 0 | Step 310 | Loss G: 4.7019 | Loss D: 0.6298\n",
"Epoch 0 | Step 320 | Loss G: 4.3386 | Loss D: 0.5507\n",
"Epoch 0 | Step 330 | Loss G: 4.2705 | Loss D: 0.5533\n",
"Epoch 0 | Step 340 | Loss G: 6.9023 | Loss D: 0.8489\n",
"Epoch 0 | Step 350 | Loss G: 8.0638 | Loss D: 0.7825\n",
"Epoch 0 | Step 360 | Loss G: 4.4931 | Loss D: 0.6148\n",
"Epoch 0 | Step 370 | Loss G: 3.7342 | Loss D: 0.6964\n",
"Epoch 0 | Step 380 | Loss G: 4.2160 | Loss D: 0.6782\n",
"Epoch 0 | Step 390 | Loss G: 5.1540 | Loss D: 0.6306\n",
"Epoch 0 | Step 400 | Loss G: 4.5333 | Loss D: 0.6755\n",
"Epoch 0 | Step 410 | Loss G: 4.0220 | Loss D: 0.6751\n",
"Epoch 0 | Step 420 | Loss G: 5.7766 | Loss D: 0.7027\n",
"Epoch 0 | Step 430 | Loss G: 4.0071 | Loss D: 0.5669\n",
"Epoch 0 | Step 440 | Loss G: 4.1375 | Loss D: 0.4732\n",
"Epoch 0 | Step 450 | Loss G: 3.8729 | Loss D: 0.5686\n",
"Epoch 0 | Step 460 | Loss G: 3.5877 | Loss D: 0.6608\n",
"Epoch 0 | Step 470 | Loss G: 5.6308 | Loss D: 0.7109\n",
"Epoch 0 | Step 480 | Loss G: 5.3286 | Loss D: 0.7127\n",
"Epoch 0 | Step 490 | Loss G: 2.9878 | Loss D: 0.6668\n",
"Epoch 0 | Step 500 | Loss G: 4.0083 | Loss D: 0.7079\n",
"Epoch 0 | Step 510 | Loss G: 3.6710 | Loss D: 0.7004\n",
"Epoch 0 | Step 520 | Loss G: 5.5583 | Loss D: 0.6757\n",
"Epoch 0 | Step 530 | Loss G: 4.2446 | Loss D: 0.6811\n",
"Epoch 0 | Step 540 | Loss G: 3.7496 | Loss D: 0.6590\n",
"Epoch 0 | Step 550 | Loss G: 7.4803 | Loss D: 0.8142\n",
"Epoch 0 | Step 560 | Loss G: 3.7535 | Loss D: 0.6432\n",
"Epoch 0 | Step 570 | Loss G: 2.8306 | Loss D: 0.5701\n",
"Epoch 0 | Step 580 | Loss G: 2.9589 | Loss D: 0.6700\n",
"Epoch 0 | Step 590 | Loss G: 4.7672 | Loss D: 0.7502\n",
"Epoch 0 | Step 600 | Loss G: 3.1664 | Loss D: 0.6552\n",
"Epoch 0 | Step 610 | Loss G: 3.2172 | Loss D: 0.6563\n",
"Epoch 0 | Step 620 | Loss G: 3.1097 | Loss D: 0.7513\n",
"Epoch 0 | Step 630 | Loss G: 3.8371 | Loss D: 0.7282\n",
"Epoch 0 | Step 640 | Loss G: 3.1871 | Loss D: 0.6991\n",
"Epoch 0 | Step 650 | Loss G: 3.4674 | Loss D: 0.6747\n",
"Epoch 0 | Step 660 | Loss G: 3.5930 | Loss D: 0.6720\n",
"Epoch 0 | Step 670 | Loss G: 2.9213 | Loss D: 0.6792\n",
"Epoch 0 | Step 680 | Loss G: 3.5470 | Loss D: 0.6509\n",
"Epoch 0 | Step 690 | Loss G: 3.6894 | Loss D: 0.6192\n",
"Epoch 0 | Step 700 | Loss G: 3.6680 | Loss D: 0.5443\n",
"Epoch 0 | Step 710 | Loss G: 3.6357 | Loss D: 0.5633\n",
"Epoch 0 | Step 720 | Loss G: 3.2239 | Loss D: 0.6394\n",
"Epoch 0 | Step 730 | Loss G: 2.5703 | Loss D: 0.6727\n",
"Epoch 0 | Step 740 | Loss G: 3.2920 | Loss D: 0.6691\n",
"Epoch 0 | Step 750 | Loss G: 3.1487 | Loss D: 0.7395\n",
"Epoch 0 | Step 760 | Loss G: 3.2808 | Loss D: 0.6791\n",
"Epoch 0 | Step 770 | Loss G: 3.6725 | Loss D: 0.7175\n",
"Epoch 0 | Step 780 | Loss G: 2.8584 | Loss D: 0.6647\n",
"Epoch 0 | Step 790 | Loss G: 3.5167 | Loss D: 0.7073\n",
"Epoch 0 | Step 800 | Loss G: 2.9515 | Loss D: 0.6595\n",
"Epoch 0 | Step 810 | Loss G: 3.0861 | Loss D: 0.7068\n",
"Epoch 0 | Step 820 | Loss G: 5.8470 | Loss D: 0.9232\n",
"Epoch 0 | Step 830 | Loss G: 3.4973 | Loss D: 0.6484\n",
"Epoch 0 | Step 840 | Loss G: 3.3402 | Loss D: 0.6405\n",
"Epoch 0 | Step 850 | Loss G: 2.9324 | Loss D: 0.6842\n",
"Epoch 0 | Step 860 | Loss G: 3.3249 | Loss D: 0.5885\n",
"Epoch 0 | Step 870 | Loss G: 3.2350 | Loss D: 0.7483\n",
"Epoch 0 | Step 880 | Loss G: 2.9857 | Loss D: 0.6709\n",
"Epoch 0 | Step 890 | Loss G: 3.1763 | Loss D: 0.5673\n",
"Epoch 0 | Step 900 | Loss G: 3.0522 | Loss D: 0.7522\n",
"Epoch 0 | Step 910 | Loss G: 2.5046 | Loss D: 0.7584\n",
"Epoch 0 | Step 920 | Loss G: 3.5277 | Loss D: 0.7532\n",
"Epoch 0 | Step 930 | Loss G: 2.7921 | Loss D: 0.6860\n",
"Epoch 0 | Step 940 | Loss G: 4.3079 | Loss D: 0.7080\n",
"Epoch 0 | Step 950 | Loss G: 2.7965 | Loss D: 0.6728\n",
"Epoch 0 | Step 960 | Loss G: 2.7622 | Loss D: 0.6773\n",
"Epoch 0 | Step 970 | Loss G: 2.8015 | Loss D: 0.6507\n",
"Epoch 0 | Step 980 | Loss G: 2.0900 | Loss D: 0.6775\n",
"Epoch 0 | Step 990 | Loss G: 2.9263 | Loss D: 0.6361\n",
"Epoch 0 | Step 1000 | Loss G: 3.5840 | Loss D: 0.6793\n",
"Epoch 0 | Step 1010 | Loss G: 3.0541 | Loss D: 0.7161\n",
"Epoch 0 | Step 1020 | Loss G: 2.6563 | Loss D: 0.6252\n",
"Epoch 0 | Step 1030 | Loss G: 2.9624 | Loss D: 0.6389\n",
"Epoch 0 | Step 1040 | Loss G: 4.0138 | Loss D: 0.5921\n",
"Epoch 0 | Step 1050 | Loss G: 3.4586 | Loss D: 0.6120\n",
"Epoch 0 | Step 1060 | Loss G: 3.6436 | Loss D: 0.5100\n",
"Epoch 0 | Step 1070 | Loss G: 2.5240 | Loss D: 0.7917\n",
"Epoch 0 | Step 1080 | Loss G: 2.9808 | Loss D: 0.5955\n",
"Epoch 0 | Step 1090 | Loss G: 3.5152 | Loss D: 0.6790\n",
"Epoch 0 | Step 1100 | Loss G: 4.8162 | Loss D: 0.9497\n",
"Epoch 0 | Step 1110 | Loss G: 4.0998 | Loss D: 0.6457\n",
"Epoch 0 | Step 1120 | Loss G: 3.7265 | Loss D: 0.7151\n",
"Epoch 0 | Step 1130 | Loss G: 2.8072 | Loss D: 0.7019\n",
"Epoch 0 | Step 1140 | Loss G: 3.8286 | Loss D: 0.7170\n",
"Epoch 0 | Step 1150 | Loss G: 2.3177 | Loss D: 0.6822\n",
"Epoch 0 | Step 1160 | Loss G: 2.3291 | Loss D: 0.6801\n",
"Epoch 0 | Step 1170 | Loss G: 2.5152 | Loss D: 0.6733\n",
"Epoch 0 | Step 1180 | Loss G: 2.1710 | Loss D: 0.6602\n",
"Epoch 0 | Step 1190 | Loss G: 3.0357 | Loss D: 0.6695\n",
"Epoch 0 | Step 1200 | Loss G: 2.9576 | Loss D: 0.6414\n",
"Epoch 0 | Step 1210 | Loss G: 3.0184 | Loss D: 0.6573\n",
"Epoch 0 | Step 1220 | Loss G: 2.2493 | Loss D: 0.7150\n",
"Epoch 0 | Step 1230 | Loss G: 2.7642 | Loss D: 0.7177\n",
"Epoch 0 | Step 1240 | Loss G: 2.4553 | Loss D: 0.6698\n",
"Epoch 0 | Step 1250 | Loss G: 2.5607 | Loss D: 0.6953\n",
"Epoch 0 | Step 1260 | Loss G: 2.9071 | Loss D: 0.6687\n",
"Epoch 0 | Step 1270 | Loss G: 2.6420 | Loss D: 0.6780\n",
"Epoch 0 | Step 1280 | Loss G: 2.6417 | Loss D: 0.6567\n",
"Epoch 0 | Step 1290 | Loss G: 2.5739 | Loss D: 0.6899\n",
"Epoch 0 | Step 1300 | Loss G: 2.9631 | Loss D: 0.6819\n",
"Epoch 0 | Step 1310 | Loss G: 2.5568 | Loss D: 0.6575\n",
"Epoch 0 | Step 1320 | Loss G: 2.2896 | Loss D: 0.6679\n",
"Epoch 0 | Step 1330 | Loss G: 2.9179 | Loss D: 0.6886\n",
"Epoch 0 | Step 1340 | Loss G: 2.8625 | Loss D: 0.6136\n",
"Epoch 0 | Step 1350 | Loss G: 2.6456 | Loss D: 0.6368\n",
"Epoch 0 | Step 1360 | Loss G: 2.9531 | Loss D: 0.6491\n",
"Epoch 0 | Step 1370 | Loss G: 2.5252 | Loss D: 0.7159\n",
"Epoch 0 | Step 1380 | Loss G: 3.9087 | Loss D: 0.7079\n",
"Epoch 0 | Step 1390 | Loss G: 2.3737 | Loss D: 0.6529\n",
"Epoch 0 | Step 1400 | Loss G: 2.3664 | Loss D: 0.6481\n",
"Epoch 0 | Step 1410 | Loss G: 2.7657 | Loss D: 0.5882\n",
"Epoch 0 | Step 1420 | Loss G: 2.7599 | Loss D: 0.6327\n",
"Epoch 0 | Step 1430 | Loss G: 2.6549 | Loss D: 0.8037\n",
"Epoch 0 | Step 1440 | Loss G: 3.3464 | Loss D: 0.5944\n",
"Epoch 0 | Step 1450 | Loss G: 2.6873 | Loss D: 0.7476\n",
"Epoch 0 | Step 1460 | Loss G: 2.5144 | Loss D: 0.6655\n",
"Epoch 0 | Step 1470 | Loss G: 3.1850 | Loss D: 0.7270\n",
"Epoch 0 | Step 1480 | Loss G: 5.0466 | Loss D: 0.6669\n",
"Epoch 0 | Step 1490 | Loss G: 2.7675 | Loss D: 0.6640\n",
"Epoch 0 | Step 1500 | Loss G: 2.3371 | Loss D: 0.6604\n",
"Epoch 0 | Step 1510 | Loss G: 2.6685 | Loss D: 0.6510\n",
"Epoch 0 | Step 1520 | Loss G: 3.0199 | Loss D: 0.6744\n",
"Epoch 0 | Step 1530 | Loss G: 2.9492 | Loss D: 0.7057\n",
"Epoch 0 | Step 1540 | Loss G: 2.8013 | Loss D: 0.7118\n",
"Epoch 0 | Step 1550 | Loss G: 2.8135 | Loss D: 0.6324\n",
"Epoch 0 | Step 1560 | Loss G: 3.1775 | Loss D: 0.6969\n",
"Epoch 0 | Step 1570 | Loss G: 3.1974 | Loss D: 0.6341\n",
"Epoch 0 | Step 1580 | Loss G: 2.4508 | Loss D: 0.6909\n",
"Epoch 0 | Step 1590 | Loss G: 3.2501 | Loss D: 0.7138\n",
"Epoch 0 | Step 1600 | Loss G: 3.5432 | Loss D: 0.7108\n",
"Epoch 0 | Step 1610 | Loss G: 2.5303 | Loss D: 0.6710\n",
"Epoch 0 | Step 1620 | Loss G: 2.8239 | Loss D: 0.6950\n",
"Epoch 0 | Step 1630 | Loss G: 2.2003 | Loss D: 0.6568\n",
"Epoch 0 | Step 1640 | Loss G: 2.6544 | Loss D: 0.7159\n",
"Epoch 0 | Step 1650 | Loss G: 2.3602 | Loss D: 0.6818\n",
"Epoch 0 | Step 1660 | Loss G: 3.1921 | Loss D: 0.6072\n",
"Epoch 0 | Step 1670 | Loss G: 2.5231 | Loss D: 0.7427\n",
"Epoch 0 | Step 1680 | Loss G: 2.9751 | Loss D: 0.6027\n",
"Epoch 0 | Step 1690 | Loss G: 2.3957 | Loss D: 0.6853\n",
"Epoch 0 | Step 1700 | Loss G: 2.5812 | Loss D: 0.6981\n",
"Epoch 0 | Step 1710 | Loss G: 2.7025 | Loss D: 0.7519\n",
"Epoch 0 | Step 1720 | Loss G: 2.9266 | Loss D: 0.6202\n",
"Epoch 0 | Step 1730 | Loss G: 2.7670 | Loss D: 0.6953\n",
"Epoch 0 | Step 1740 | Loss G: 2.7067 | Loss D: 0.5941\n",
"Epoch 0 | Step 1750 | Loss G: 3.5343 | Loss D: 0.7225\n",
"Epoch 0 | Step 1760 | Loss G: 2.4662 | Loss D: 0.6810\n",
"Epoch 0 | Step 1770 | Loss G: 2.6970 | Loss D: 0.6763\n",
"Epoch 0 | Step 1780 | Loss G: 3.1359 | Loss D: 0.6382\n",
"Epoch 0 | Step 1790 | Loss G: 2.6589 | Loss D: 0.6323\n",
"Epoch 0 | Step 1800 | Loss G: 2.6840 | Loss D: 0.6116\n",
"Epoch 0 | Step 1810 | Loss G: 3.0749 | Loss D: 0.7142\n",
"Epoch 0 | Step 1820 | Loss G: 2.4837 | Loss D: 0.7109\n",
"Epoch 0 | Step 1830 | Loss G: 2.9848 | Loss D: 0.6832\n",
"Epoch 0 | Step 1840 | Loss G: 3.1044 | Loss D: 0.5613\n",
"Epoch 0 | Step 1850 | Loss G: 2.6458 | Loss D: 1.0100\n",
"Epoch 0 | Step 1860 | Loss G: 2.4200 | Loss D: 0.7013\n",
"Epoch 0 | Step 1870 | Loss G: 2.4586 | Loss D: 0.7051\n",
"Epoch 0 | Step 1880 | Loss G: 2.3607 | Loss D: 0.7131\n",
"Epoch 0 | Step 1890 | Loss G: 2.6252 | Loss D: 0.7187\n",
"Epoch 0 | Step 1900 | Loss G: 2.4684 | Loss D: 0.6961\n",
"Epoch 0 | Step 1910 | Loss G: 2.2537 | Loss D: 0.6876\n",
"Epoch 0 | Step 1920 | Loss G: 2.4095 | Loss D: 0.6838\n",
"Epoch 0 | Step 1930 | Loss G: 2.8157 | Loss D: 0.6736\n",
"Epoch 0 | Step 1940 | Loss G: 2.4893 | Loss D: 0.6775\n",
"Epoch 0 | Step 1950 | Loss G: 3.3445 | Loss D: 0.6962\n",
"Epoch 0 | Step 1960 | Loss G: 2.2544 | Loss D: 0.6767\n",
"Epoch 0 | Step 1970 | Loss G: 2.6515 | Loss D: 0.6322\n",
"Epoch 0 | Step 1980 | Loss G: 2.2835 | Loss D: 0.6343\n",
"Epoch 0 | Step 1990 | Loss G: 2.4262 | Loss D: 0.6298\n",
"Epoch 0 | Step 2000 | Loss G: 2.7222 | Loss D: 0.6135\n",
"Epoch 0 | Step 2010 | Loss G: 2.3119 | Loss D: 0.6223\n",
"Epoch 0 | Step 2020 | Loss G: 2.8205 | Loss D: 0.6000\n",
"Epoch 0 | Step 2030 | Loss G: 2.8469 | Loss D: 0.5928\n",
"Epoch 0 | Step 2040 | Loss G: 3.0901 | Loss D: 0.6076\n",
"Epoch 0 | Step 2050 | Loss G: 2.4228 | Loss D: 0.6983\n",
"Epoch 0 | Step 2060 | Loss G: 2.7112 | Loss D: 0.7587\n",
"Epoch 0 | Step 2070 | Loss G: 2.6442 | Loss D: 0.7504\n",
"Epoch 0 | Step 2080 | Loss G: 2.4068 | Loss D: 0.7240\n",
"Epoch 0 | Step 2090 | Loss G: 2.9278 | Loss D: 0.6676\n",
"Epoch 0 | Step 2100 | Loss G: 2.2134 | Loss D: 0.7018\n",
"Epoch 0 | Step 2110 | Loss G: 2.8740 | Loss D: 0.6760\n",
"Epoch 0 | Step 2120 | Loss G: 2.4062 | Loss D: 0.6768\n",
"Epoch 0 | Step 2130 | Loss G: 2.2714 | Loss D: 0.6880\n",
"Epoch 0 | Step 2140 | Loss G: 2.3448 | Loss D: 0.6871\n",
"Epoch 0 | Step 2150 | Loss G: 2.1828 | Loss D: 0.6292\n",
"Epoch 0 | Step 2160 | Loss G: 2.4968 | Loss D: 0.6260\n",
"Epoch 0 | Step 2170 | Loss G: 2.2925 | Loss D: 0.6720\n",
"Epoch 0 | Step 2180 | Loss G: 2.1419 | Loss D: 0.6476\n",
"Epoch 0 | Step 2190 | Loss G: 2.4403 | Loss D: 0.6648\n",
"Epoch 0 | Step 2200 | Loss G: 3.3253 | Loss D: 0.5541\n",
"Epoch 0 | Step 2210 | Loss G: 2.3137 | Loss D: 0.6822\n",
"Epoch 0 | Step 2220 | Loss G: 3.3092 | Loss D: 0.6984\n",
"Epoch 0 | Step 2230 | Loss G: 2.7817 | Loss D: 0.6168\n",
"Epoch 0 | Step 2240 | Loss G: 2.4387 | Loss D: 0.6215\n",
"Epoch 0 | Step 2250 | Loss G: 2.7103 | Loss D: 0.5489\n",
"Epoch 0 | Step 2260 | Loss G: 2.8750 | Loss D: 0.7436\n",
"Epoch 0 | Step 2270 | Loss G: 3.2172 | Loss D: 0.7419\n",
"Epoch 0 | Step 2280 | Loss G: 2.2944 | Loss D: 0.7735\n",
"Epoch 0 | Step 2290 | Loss G: 2.9406 | Loss D: 0.6687\n",
"Epoch 0 | Step 2300 | Loss G: 2.3666 | Loss D: 0.7192\n",
"Epoch 0 | Step 2310 | Loss G: 2.0075 | Loss D: 0.7016\n",
"Epoch 0 | Step 2320 | Loss G: 2.5500 | Loss D: 0.7104\n",
"Epoch 0 | Step 2330 | Loss G: 2.2471 | Loss D: 0.6831\n",
"Epoch 0 | Step 2340 | Loss G: 2.2565 | Loss D: 0.6743\n",
"Epoch 0 | Step 2350 | Loss G: 2.1140 | Loss D: 0.6982\n",
"Epoch 0 | Step 2360 | Loss G: 2.2506 | Loss D: 0.6729\n",
"Epoch 0 | Step 2370 | Loss G: 2.4593 | Loss D: 0.6737\n",
"Epoch 0 | Step 2380 | Loss G: 2.3511 | Loss D: 0.6809\n",
"Epoch 0 | Step 2390 | Loss G: 2.4674 | Loss D: 0.6723\n",
"Epoch 0 | Step 2400 | Loss G: 2.7037 | Loss D: 0.6395\n",
"Epoch 0 | Step 2410 | Loss G: 2.3319 | Loss D: 0.7071\n",
"Epoch 0 | Step 2420 | Loss G: 2.4786 | Loss D: 0.6931\n",
"Epoch 0 | Step 2430 | Loss G: 3.5919 | Loss D: 0.7343\n",
"Epoch 0 | Step 2440 | Loss G: 2.3883 | Loss D: 0.7144\n",
"Epoch 0 | Step 2450 | Loss G: 2.1943 | Loss D: 0.6796\n",
"Epoch 0 | Step 2460 | Loss G: 2.6514 | Loss D: 0.7204\n",
"Epoch 0 | Step 2470 | Loss G: 2.0497 | Loss D: 0.6569\n",
"Epoch 0 | Step 2480 | Loss G: 2.3691 | Loss D: 0.7213\n",
"Epoch 0 | Step 2490 | Loss G: 2.7880 | Loss D: 0.6707\n",
"Epoch 0 | Step 2500 | Loss G: 2.9353 | Loss D: 0.7046\n",
"Epoch 0 | Step 2510 | Loss G: 2.0710 | Loss D: 0.7094\n",
"Epoch 0 | Step 2520 | Loss G: 2.1715 | Loss D: 0.6868\n",
"Epoch 0 | Step 2530 | Loss G: 2.9305 | Loss D: 0.7008\n",
"Epoch 0 | Step 2540 | Loss G: 3.2625 | Loss D: 0.7100\n",
"Epoch 0 | Step 2550 | Loss G: 2.4526 | Loss D: 0.6615\n",
"Epoch 0 | Step 2560 | Loss G: 2.1840 | Loss D: 0.6649\n",
"Epoch 0 | Step 2570 | Loss G: 2.3944 | Loss D: 0.6434\n",
"Epoch 0 | Step 2580 | Loss G: 2.3077 | Loss D: 0.6637\n",
"Epoch 0 | Step 2590 | Loss G: 2.2480 | Loss D: 0.6865\n",
"Epoch 0 | Step 2600 | Loss G: 2.2794 | Loss D: 0.6606\n",
"Epoch 0 | Step 2610 | Loss G: 2.2931 | Loss D: 0.6991\n",
"Epoch 0 | Step 2620 | Loss G: 2.1852 | Loss D: 0.6694\n",
"Epoch 0 | Step 2630 | Loss G: 2.4422 | Loss D: 0.6534\n",
"Epoch 0 | Step 2640 | Loss G: 2.4345 | Loss D: 0.7486\n",
"Epoch 0 | Step 2650 | Loss G: 2.1540 | Loss D: 0.7449\n",
"Epoch 0 | Step 2660 | Loss G: 2.2600 | Loss D: 0.6953\n",
"Epoch 0 | Step 2670 | Loss G: 2.6546 | Loss D: 0.6010\n",
"Epoch 0 | Step 2680 | Loss G: 2.1378 | Loss D: 0.7238\n",
"Epoch 0 | Step 2690 | Loss G: 2.4992 | Loss D: 0.6710\n",
"Epoch 0 | Step 2700 | Loss G: 2.4387 | Loss D: 0.6570\n",
"Epoch 0 | Step 2710 | Loss G: 2.7910 | Loss D: 0.7033\n",
"Epoch 0 | Step 2720 | Loss G: 2.0291 | Loss D: 0.6740\n",
"Epoch 0 | Step 2730 | Loss G: 2.3530 | Loss D: 0.6880\n",
"Epoch 0 | Step 2740 | Loss G: 2.4699 | Loss D: 0.6142\n",
"Epoch 0 | Step 2750 | Loss G: 2.2944 | Loss D: 0.7012\n",
"Epoch 0 | Step 2760 | Loss G: 2.3153 | Loss D: 0.6831\n",
"Epoch 0 | Step 2770 | Loss G: 1.9557 | Loss D: 0.6957\n",
"Epoch 0 | Step 2780 | Loss G: 1.9761 | Loss D: 0.7643\n",
"Epoch 0 | Step 2790 | Loss G: 2.0589 | Loss D: 0.7031\n",
"Epoch 0 | Step 2800 | Loss G: 2.1808 | Loss D: 0.6616\n",
"Epoch 0 | Step 2810 | Loss G: 3.0882 | Loss D: 0.6470\n",
"Epoch 0 | Step 2820 | Loss G: 2.2611 | Loss D: 0.7233\n",
"Epoch 0 | Step 2830 | Loss G: 2.2900 | Loss D: 0.6960\n",
"Epoch 0 | Step 2840 | Loss G: 2.8689 | Loss D: 0.6929\n",
"Epoch 0 | Step 2850 | Loss G: 2.6262 | Loss D: 0.6793\n",
"Epoch 0 | Step 2860 | Loss G: 2.0969 | Loss D: 0.6568\n",
"Epoch 0 | Step 2870 | Loss G: 2.0100 | Loss D: 0.6813\n",
"Epoch 0 | Step 2880 | Loss G: 2.4225 | Loss D: 0.6647\n",
"Epoch 0 | Step 2890 | Loss G: 2.3839 | Loss D: 0.6293\n",
"Epoch 0 | Step 2900 | Loss G: 2.2553 | Loss D: 0.6671\n",
"Epoch 0 | Step 2910 | Loss G: 2.2506 | Loss D: 0.6316\n",
"Epoch 0 | Step 2920 | Loss G: 2.2017 | Loss D: 0.6144\n",
"Epoch 0 | Step 2930 | Loss G: 2.3828 | Loss D: 0.7388\n",
"Epoch 0 | Step 2940 | Loss G: 3.6152 | Loss D: 0.5791\n",
"Epoch 0 | Step 2950 | Loss G: 2.4575 | Loss D: 0.6609\n",
"Epoch 0 | Step 2960 | Loss G: 2.5841 | Loss D: 0.8060\n",
"Epoch 0 | Step 2970 | Loss G: 2.6008 | Loss D: 0.6548\n",
"Epoch 0 | Step 2980 | Loss G: 2.9643 | Loss D: 0.6375\n",
"Epoch 0 | Step 2990 | Loss G: 2.0817 | Loss D: 0.6882\n",
"Epoch 0 | Step 3000 | Loss G: 2.4347 | Loss D: 0.7490\n",
"Epoch 0 | Step 3010 | Loss G: 2.7559 | Loss D: 0.7129\n",
"Epoch 0 | Step 3020 | Loss G: 2.1537 | Loss D: 0.6953\n",
"Epoch 0 | Step 3030 | Loss G: 2.2344 | Loss D: 0.7369\n",
"Epoch 0 | Step 3040 | Loss G: 2.3036 | Loss D: 0.6224\n",
"Epoch 0 | Step 3050 | Loss G: 2.3935 | Loss D: 0.7120\n",
"Epoch 0 | Step 3060 | Loss G: 2.3280 | Loss D: 0.6134\n",
"Epoch 0 | Step 3070 | Loss G: 2.5429 | Loss D: 0.7636\n",
"Epoch 0 | Step 3080 | Loss G: 2.3605 | Loss D: 0.6986\n",
"Epoch 0 | Step 3090 | Loss G: 2.1876 | Loss D: 0.7954\n",
"Epoch 0 | Step 3100 | Loss G: 2.3063 | Loss D: 0.6299\n",
"Epoch 0 | Step 3110 | Loss G: 2.3226 | Loss D: 0.8436\n",
"Epoch 0 | Step 3120 | Loss G: 2.9968 | Loss D: 0.7128\n",
"Epoch 0 | Step 3130 | Loss G: 2.0452 | Loss D: 0.7180\n",
"Epoch 0 | Step 3140 | Loss G: 1.9460 | Loss D: 0.7010\n",
"Epoch 0 | Step 3150 | Loss G: 2.7439 | Loss D: 0.6920\n",
"Epoch 0 | Step 3160 | Loss G: 1.8548 | Loss D: 0.7069\n",
"Epoch 0 | Step 3170 | Loss G: 2.6291 | Loss D: 0.6260\n",
"Epoch 0 | Step 3180 | Loss G: 3.1567 | Loss D: 0.6356\n",
"Epoch 0 | Step 3190 | Loss G: 1.9601 | Loss D: 0.6841\n",
"Epoch 0 | Step 3200 | Loss G: 2.3887 | Loss D: 0.6708\n",
"Epoch 0 | Step 3210 | Loss G: 2.0190 | Loss D: 0.6558\n",
"Epoch 0 | Step 3220 | Loss G: 2.5288 | Loss D: 0.6406\n",
"Epoch 0 | Step 3230 | Loss G: 2.0141 | Loss D: 0.6553\n",
"Epoch 0 | Step 3240 | Loss G: 3.5275 | Loss D: 0.6934\n",
"Epoch 0 | Step 3250 | Loss G: 2.5778 | Loss D: 0.5621\n",
"Epoch 0 | Step 3260 | Loss G: 2.4638 | Loss D: 0.6106\n",
"Epoch 0 | Step 3270 | Loss G: 3.1237 | Loss D: 0.6939\n",
"Epoch 0 | Step 3280 | Loss G: 2.8655 | Loss D: 0.6382\n",
"Epoch 0 | Step 3290 | Loss G: 2.6247 | Loss D: 0.6818\n",
"Epoch 0 | Step 3300 | Loss G: 2.2132 | Loss D: 0.6349\n",
"Epoch 0 | Step 3310 | Loss G: 2.3956 | Loss D: 0.7407\n",
"Epoch 0 | Step 3320 | Loss G: 2.8687 | Loss D: 0.4859\n",
"Epoch 0 | Step 3330 | Loss G: 2.4718 | Loss D: 0.7202\n",
"Epoch 0 | Step 3340 | Loss G: 1.9500 | Loss D: 0.6743\n",
"Epoch 0 | Step 3350 | Loss G: 2.4359 | Loss D: 0.6221\n",
"Epoch 0 | Step 3360 | Loss G: 2.1875 | Loss D: 0.7105\n",
"Epoch 0 | Step 3370 | Loss G: 2.2411 | Loss D: 0.5774\n",
"Epoch 0 | Step 3380 | Loss G: 2.2889 | Loss D: 0.6689\n",
"Epoch 0 | Step 3390 | Loss G: 2.5175 | Loss D: 0.7022\n",
"Epoch 0 | Step 3400 | Loss G: 2.1678 | Loss D: 0.6610\n",
"Epoch 0 | Step 3410 | Loss G: 1.8933 | Loss D: 0.7071\n",
"Epoch 0 | Step 3420 | Loss G: 2.1478 | Loss D: 0.6712\n",
"Epoch 0 | Step 3430 | Loss G: 2.5614 | Loss D: 0.6450\n",
"Epoch 0 | Step 3440 | Loss G: 2.5616 | Loss D: 0.6267\n",
"Epoch 0 | Step 3450 | Loss G: 2.2356 | Loss D: 0.6358\n",
"Epoch 0 | Step 3460 | Loss G: 3.6788 | Loss D: 0.6079\n",
"Epoch 0 | Step 3470 | Loss G: 2.4702 | Loss D: 0.6131\n"
]
}
],
"source": [
"# ==========================================\n",
"# CELL 3: TRAIN & SAVE THE MODEL\n",
"# ==========================================\n",
"if 'train_loader' in locals():\n",
"    generator = CR_Generator().to(device)\n",
"    discriminator = Discriminator().to(device)\n",
"\n",
"    criterion_GAN = nn.BCELoss()\n",
"    criterion_L1 = nn.L1Loss()\n",
"    opt_g = torch.optim.Adam(generator.parameters(), lr=0.0002)\n",
"    opt_d = torch.optim.Adam(discriminator.parameters(), lr=0.0002)\n",
"\n",
"    epochs = 5  # demo setting; increase for a real training run\n",
"    l1_weight = 100  # weight of the L1 reconstruction term in the generator loss\n",
"\n",
"    # Resolve the checkpoint path before training so a bad location fails early.\n",
"    save_dir = os.path.join(os.path.expanduser(\"~\"), \"cloud_train\", \"saved_models\")\n",
"    os.makedirs(save_dir, exist_ok=True)\n",
"    checkpoint_path = os.path.join(save_dir, \"CRGAN_generator.pth\")\n",
"\n",
"    print(\"🚀 Bắt đầu huấn luyện CR-GAN...\")\n",
"    for epoch in range(epochs):\n",
"        for i, (s1, s2, cloud) in enumerate(train_loader):\n",
"            s1, s2, cloud = s1.to(device), s2.to(device), cloud.to(device)\n",
"\n",
"            # --- Generator step: fool the discriminator and stay close to the clean image ---\n",
"            opt_g.zero_grad()\n",
"            fake = generator(s1, cloud)\n",
"            pred_fake = discriminator(fake, s1, cloud)\n",
"            loss_g = criterion_GAN(pred_fake, torch.ones_like(pred_fake)) + l1_weight * criterion_L1(fake, s2)\n",
"            loss_g.backward()\n",
"            opt_g.step()\n",
"\n",
"            # --- Discriminator step: real pairs -> 1, generated pairs -> 0 ---\n",
"            # zero_grad() also discards the gradients the generator step left on D.\n",
"            opt_d.zero_grad()\n",
"            pred_real = discriminator(s2, s1, cloud)\n",
"            pred_fake_d = discriminator(fake.detach(), s1, cloud)  # detach: no gradient into G\n",
"            loss_d = 0.5 * (criterion_GAN(pred_real, torch.ones_like(pred_real)) +\n",
"                            criterion_GAN(pred_fake_d, torch.zeros_like(pred_fake_d)))\n",
"            loss_d.backward()\n",
"            opt_d.step()\n",
"\n",
"            if i % 10 == 0:\n",
"                print(f\"Epoch {epoch} | Step {i} | Loss G: {loss_g.item():.4f} | Loss D: {loss_d.item():.4f}\")\n",
"\n",
"        # Checkpoint after every epoch so an interrupted run still leaves weights on disk.\n",
"        torch.save(generator.state_dict(), checkpoint_path)\n",
"\n",
"    # --- Final save (also covers epochs == 0) ---\n",
"    torch.save(generator.state_dict(), checkpoint_path)\n",
"    print(f\"\\n💾 Đã lưu model tại: {save_dir}/CRGAN_generator.pth\")\n",
"else:\n",
"    # Make the skipped training visible instead of finishing silently.\n",
"    print(\"⚠️ Chưa có train_loader. Hãy chạy CELL 1 trước khi huấn luyện.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "87873ca2-6a64-4fc9-8687-8c8555b99ed4",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

183
Train_GLF-CR.ipynb Normal file
View File

@@ -0,0 +1,183 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"id": "a148f700-b14e-4dc3-a203-9e2f93955587",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"📂 Working Directory: /home/jovyan/cloud_train/dataset\n",
"✅ Đã import thành công (dùng đường dẫn tuyệt đối)!\n",
"Using device: cpu\n",
"✅ Data Ready: 29117 samples\n"
]
}
],
"source": [
"# ==========================================\n",
"# CELL 1: SETUP & LOAD DỮ LIỆU\n",
"# ==========================================\n",
"import os, sys, torch, numpy as np\n",
"import torch.nn as nn\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"# Make the project data loader importable: search the working directory and\n",
"# ~/cloud_train (the same root that base_dir is derived from) instead of a\n",
"# hard-coded /home/<user> path.\n",
"project_dir = os.path.join(os.path.expanduser(\"~\"), \"cloud_train\")\n",
"for extra_path in (os.getcwd(), project_dir):\n",
"    if extra_path not in sys.path:\n",
"        sys.path.append(extra_path)\n",
"\n",
"try:\n",
"    from sen12ms_cr_dataLoader import SEN12MSCRDataset, Seasons, S1Bands, S2Bands\n",
"except ImportError as err:\n",
"    # Chain the original error explicitly so the real cause stays visible.\n",
"    raise RuntimeError(\"❌ Không tìm thấy file 'sen12ms_cr_dataLoader.py'\") from err\n",
"\n",
"# GPU when available, otherwise CPU; dataset lives in <home>/cloud_train/dataset.\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"base_dir = os.path.join(project_dir, \"dataset\")\n",
"\n",
"class SEN12MSCR_TorchDataset(Dataset):\n",
"    \"\"\"Serve one season of SEN12MS-CR as (s1, s2, cloudy s2) float tensors.\"\"\"\n",
"\n",
"    def __init__(self, base_dir, season=Seasons.SPRING):\n",
"        self.loader = SEN12MSCRDataset(base_dir)\n",
"        self.season = season\n",
"        self.season_ids = self.loader.get_season_ids(season)\n",
"        # Flatten the {id: [sub-id, ...]} mapping into one (id, sub-id) pair per sample.\n",
"        pairs = []\n",
"        for scene_id, patch_ids in self.season_ids.items():\n",
"            pairs.extend((scene_id, patch_id) for patch_id in patch_ids)\n",
"        self.samples = pairs\n",
"\n",
"    def __len__(self):\n",
"        return len(self.samples)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        scene_id, patch_id = self.samples[idx]\n",
"        sar, clear, cloudy, _ = self.loader.get_s1s2s2cloudy_triplet(self.season, scene_id, patch_id)\n",
"        # SAR: clip to [-25, 0] and divide by 25, giving values in [-1, 0].\n",
"        sar = np.clip(sar, -25, 0) / 25.0\n",
"        # Optical: clip to [0, 10000], then map linearly onto [-1, 1].\n",
"        clear, cloudy = ((np.clip(band, 0, 10000) / 5000.0) - 1.0 for band in (clear, cloudy))\n",
"        return tuple(torch.from_numpy(arr).float() for arr in (sar, clear, cloudy))\n",
"\n",
"# Instantiate the dataset with its default season and wrap it in a shuffling\n",
"# loader (batch size 4, num_workers=0). Unlike a guarded construction, any\n",
"# error raised here stops the cell.\n",
"train_dataset = SEN12MSCR_TorchDataset(base_dir)\n",
"train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)\n",
"print(f\"✅ GLF-CR Data Ready: {len(train_dataset)} samples\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "cdb8ff0f-ea65-4ed4-8f5b-b99b05c602a8",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# ==========================================\n",
"# CELL 2: MODEL GLF-CR\n",
"# ==========================================\n",
"class FusionBlock(nn.Module):\n",
"    \"\"\"Merge two feature maps of `channels` channels each into one.\n",
"\n",
"    The inputs are concatenated along the channel axis and projected back to\n",
"    `channels` channels by a 1x1 convolution followed by BatchNorm and ReLU.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, channels):\n",
"        super().__init__()\n",
"        merge = nn.Conv2d(2 * channels, channels, kernel_size=1)\n",
"        norm = nn.BatchNorm2d(channels)\n",
"        self.conv = nn.Sequential(merge, norm, nn.ReLU())\n",
"\n",
"    def forward(self, x1, x2):\n",
"        stacked = torch.cat((x1, x2), dim=1)\n",
"        return self.conv(stacked)\n",
"\n",
"class GLF_Generator(nn.Module):\n",
"    \"\"\"Two-branch generator: separate stems for SAR and cloudy optical input.\n",
"\n",
"    Each input passes through its own 3x3 convolution, the two 64-channel\n",
"    feature maps are merged by a FusionBlock, and a three-layer convolutional\n",
"    decoder followed by tanh produces the 13-channel output. All convolutions\n",
"    use stride 1 and padding 1, so the spatial size is unchanged.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # One stem per modality\n",
"        self.enc_s1 = nn.Sequential(nn.Conv2d(2, 64, kernel_size=3, stride=1, padding=1), nn.ReLU())  # radar\n",
"        self.enc_s2 = nn.Sequential(nn.Conv2d(13, 64, kernel_size=3, stride=1, padding=1), nn.ReLU())  # cloudy optical\n",
"        self.fuse = FusionBlock(64)\n",
"        decoder_layers = [\n",
"            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),\n",
"            nn.ReLU(),\n",
"            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),\n",
"            nn.ReLU(),\n",
"            nn.Conv2d(64, 13, kernel_size=3, stride=1, padding=1),\n",
"        ]\n",
"        self.dec = nn.Sequential(*decoder_layers)\n",
"        self.tanh = nn.Tanh()\n",
"\n",
"    def forward(self, s1, s2_cloudy):\n",
"        radar_feat = self.enc_s1(s1)\n",
"        optical_feat = self.enc_s2(s2_cloudy)\n",
"        fused = self.fuse(radar_feat, optical_feat)\n",
"        decoded = self.dec(fused)\n",
"        return self.tanh(decoded)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b124a548-a7fe-46e7-939d-ae3ff5e5b83c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🚀 Bắt đầu train GLF-CR...\n",
"Step 0 | MSE Loss: 0.4782\n",
"Step 10 | MSE Loss: 0.0848\n",
"Step 20 | MSE Loss: 0.0446\n",
"Step 30 | MSE Loss: 0.0274\n",
"Step 40 | MSE Loss: 0.2338\n",
"Step 50 | MSE Loss: 0.0158\n",
"\n",
"💾 Đã lưu model tại: /home/jovyan/cloud_train/saved_models/GLFCR_generator.pth\n"
]
}
],
"source": [
"# ==========================================\n",
"# CELL 3: TRAINING & SAVE\n",
"# ==========================================\n",
"if 'train_loader' in locals():\n",
"    gen = GLF_Generator().to(device)\n",
"    crit = nn.MSELoss()\n",
"    opt = torch.optim.Adam(gen.parameters(), lr=0.0002)\n",
"\n",
"    print(\"🚀 Bắt đầu train GLF-CR...\")\n",
"    for i, (s1, s2, cloud) in enumerate(train_loader):\n",
"        s1 = s1.to(device)\n",
"        s2 = s2.to(device)\n",
"        cloud = cloud.to(device)\n",
"\n",
"        # One optimisation step on the reconstruction (MSE) loss.\n",
"        opt.zero_grad()\n",
"        out = gen(s1, cloud)\n",
"        loss = crit(out, s2)\n",
"        loss.backward()\n",
"        opt.step()\n",
"\n",
"        if not i % 10:\n",
"            print(f\"Step {i} | MSE Loss: {loss.item():.4f}\")\n",
"        # Short demo run: stop once step 51 has been processed.\n",
"        if i >= 51:\n",
"            break\n",
"\n",
"    # --- Save the trained generator weights ---\n",
"    save_dir = os.path.join(\n",
"        os.path.expanduser(\"~\"),\n",
"        \"cloud_train\",\n",
"        \"saved_models\",\n",
"    )\n",
"    os.makedirs(save_dir, exist_ok=True)\n",
"    torch.save(gen.state_dict(), os.path.join(save_dir, \"GLFCR_generator.pth\"))\n",
"    print(f\"\\n💾 Đã lưu model tại: {save_dir}/GLFCR_generator.pth\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f863772-5925-45b4-a9aa-e1dc54f52658",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

177
Train_SpA-GAN.ipynb Normal file
View File

@@ -0,0 +1,177 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "3813f900-ea54-4533-873a-70aa46417ba0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ SpA-GAN Data Ready: 29117 samples\n"
]
}
],
"source": [
"# ==========================================\n",
"# CELL 1: SETUP & LOAD DỮ LIỆU\n",
"# ==========================================\n",
"import os, sys, torch, numpy as np\n",
"import torch.nn as nn\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"# Fix Import Path\n",
"sys.path.append(os.getcwd())\n",
"sys.path.append(\"/home/jovyan/cloud_train\")\n",
"\n",
"try:\n",
" from sen12ms_cr_dataLoader import SEN12MSCRDataset, Seasons, S1Bands, S2Bands\n",
"except ImportError:\n",
" raise RuntimeError(\"❌ Không tìm thấy file 'sen12ms_cr_dataLoader.py'\")\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"base_dir = os.path.join(os.path.expanduser(\"~\"), \"cloud_train\", \"dataset\")\n",
"\n",
"class SEN12MSCR_TorchDataset(Dataset):\n",
"    \"\"\"Torch Dataset over every (scene, patch) pair of one SEN12MS-CR season.\n",
"\n",
"    Each item is a tuple of float tensors (s1, s2, s2_cloudy) as returned by\n",
"    SEN12MSCRDataset.get_s1s2s2cloudy_triplet, after the scaling done in __getitem__.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, base_dir, season=Seasons.SPRING):\n",
"        self.loader = SEN12MSCRDataset(base_dir)\n",
"        self.season = season\n",
"        self.season_ids = self.loader.get_season_ids(season)\n",
"        # Flatten {scene_id: [patch_id, ...]} into a list of (scene_id, patch_id) pairs.\n",
"        self.samples = [(sid, pid) for sid, pids in self.season_ids.items() for pid in pids]\n",
"\n",
"    def __len__(self):\n",
"        return len(self.samples)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        sid, pid = self.samples[idx]\n",
"        s1, s2, c, _ = self.loader.get_s1s2s2cloudy_triplet(self.season, sid, pid)\n",
"        # S1: clip to [-25, 0] and divide by 25 -> values in [-1, 0].\n",
"        s1 = np.clip(s1, -25, 0) / 25.0\n",
"        # S2 (clear and cloudy): clip to [0, 10000] and map linearly to [-1, 1].\n",
"        s2 = (np.clip(s2, 0, 10000) / 5000.0) - 1.0\n",
"        c = (np.clip(c, 0, 10000) / 5000.0) - 1.0\n",
"        return torch.from_numpy(s1).float(), torch.from_numpy(s2).float(), torch.from_numpy(c).float()\n",
"\n",
"train_dataset = SEN12MSCR_TorchDataset(base_dir)\n",
"train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)\n",
"print(f\"✅ SpA-GAN Data Ready: {len(train_dataset)} samples\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "078b6951-1dcf-44e9-895c-9ec0c8c848b6",
"metadata": {},
"outputs": [],
"source": [
"# ==========================================\n",
"# CELL 2: MODEL SpA-GAN\n",
"# ==========================================\n",
"class SpatialAttention(nn.Module):\n",
" def __init__(self, in_channels):\n",
" super(SpatialAttention, self).__init__()\n",
" self.conv = nn.Conv2d(in_channels, 1, kernel_size=1) \n",
" self.sigmoid = nn.Sigmoid()\n",
" def forward(self, x):\n",
" return x * self.sigmoid(self.conv(x))\n",
"\n",
"class SpA_Generator(nn.Module):\n",
"    \"\"\"Small encoder/decoder generator with a spatial-attention gate.\n",
"\n",
"    forward() concatenates s1 and s2_cloudy along the channel axis (15 input\n",
"    channels in total) and returns a 13-channel tensor squashed by tanh.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super(SpA_Generator, self).__init__()\n",
"        self.enc1 = nn.Sequential(nn.Conv2d(15, 64, 4, 2, 1), nn.LeakyReLU(0.2))\n",
"        self.attn = SpatialAttention(64)  # attention gate on the first encoder stage\n",
"        self.enc2 = nn.Sequential(nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2))\n",
"\n",
"        self.dec1 = nn.Sequential(nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU())\n",
"        self.dec2 = nn.ConvTranspose2d(64, 13, 4, 2, 1)\n",
"        self.tanh = nn.Tanh()\n",
"\n",
"    def forward(self, s1, s2_cloudy):\n",
"        x = torch.cat([s1, s2_cloudy], dim=1)\n",
"        e1 = self.enc1(x)\n",
"        e1_attn = self.attn(e1)  # re-weight spatial positions before the second stage\n",
"        e2 = self.enc2(e1_attn)\n",
"        d1 = self.dec1(e2)\n",
"        # Skip connection uses the un-gated first-stage features (e1, not e1_attn).\n",
"        return self.tanh(self.dec2(d1 + e1))\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e6bd89be-aa61-4429-b588-f2041311f2a6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🚀 Bắt đầu train SpA-GAN...\n",
"Step 0 | Loss G: 0.6843\n",
"Step 10 | Loss G: 0.3739\n",
"Step 20 | Loss G: 0.1896\n",
"Step 30 | Loss G: 0.1224\n",
"Step 40 | Loss G: 0.0938\n",
"Step 50 | Loss G: 0.1833\n",
"\n",
"💾 Đã lưu model tại: /home/jovyan/cloud_train/saved_models/SpAGAN_generator.pth\n"
]
}
],
"source": [
"# ==========================================\n",
"# CELL 3: TRAINING & SAVE\n",
"# ==========================================\n",
"# Short demo run: Adam + L1 on the first batches of train_loader, then the\n",
"# generator weights are saved under ~/cloud_train/saved_models.\n",
"# The whole cell is skipped when train_loader was not created by an earlier cell.\n",
"if 'train_loader' in locals():\n",
"    gen = SpA_Generator().to(device)\n",
"    opt_g = torch.optim.Adam(gen.parameters(), lr=0.0002)\n",
"    criterion = nn.L1Loss()\n",
"\n",
"    print(\"🚀 Bắt đầu train SpA-GAN...\")\n",
"    for i, (s1, s2, cloud) in enumerate(train_loader):\n",
"        s1, s2, cloud = s1.to(device), s2.to(device), cloud.to(device)\n",
"\n",
"        # One optimisation step: the generator output is compared with s2 (L1).\n",
"        opt_g.zero_grad()\n",
"        fake = gen(s1, cloud)\n",
"        loss_g = criterion(fake, s2)\n",
"        loss_g.backward()\n",
"        opt_g.step()\n",
"\n",
"        if i % 10 == 0:\n",
"            print(f\"Step {i} | Loss G: {loss_g.item():.4f}\")\n",
"        # Demo cut-off. Checked after the update, so batches 0..51 (52 steps) are trained.\n",
"        if i > 50:\n",
"            break\n",
"\n",
"    # --- Save the generator weights ---\n",
"    save_dir = os.path.join(os.path.expanduser(\"~\"), \"cloud_train\", \"saved_models\")\n",
"    os.makedirs(save_dir, exist_ok=True)\n",
"    torch.save(gen.state_dict(), os.path.join(save_dir, \"SpAGAN_generator.pth\"))\n",
"    print(f\"\\n💾 Đã lưu model tại: {save_dir}/SpAGAN_generator.pth\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bc1d8e75-0a3f-4197-8196-7132122f2be9",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,321 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "d-0LFUETwSj8",
"outputId": "3b84f2ea-e6d6-4404-a0ae-68ed567ce5ef"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ Đã cấu hình xong! Hãy chạy các cell bên dưới để tải từng phần.\n"
]
}
],
"source": [
"# --- CELL 1: SETUP ---\n",
"import os\n",
"import time\n",
"\n",
"# 1. Cấu hình\n",
"os.environ['RSYNC_PASSWORD'] = 'm1554803'\n",
"output_dir = \"./dataset/\"\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"base_url = \"rsync://m1554803@dataserv.ub.tum.de/m1554803/\"\n",
"\n",
"# 2. Hàm tải file (Dùng chung cho các cell bên dưới)\n",
"def download_files(file_list):\n",
"    \"\"\"Fetch every entry of ``file_list`` from ``base_url`` into ``output_dir`` with rsync.\n",
"\n",
"    Args:\n",
"        file_list: iterable of file names, relative to ``base_url``.\n",
"\n",
"    A success or failure line is printed per file; a failed transfer does not\n",
"    stop the remaining downloads.\n",
"    \"\"\"\n",
"    import subprocess  # local import: only this helper needs it\n",
"\n",
"    print(f\"📂 Lưu tại: {output_dir}\")\n",
"    print(\"-\" * 50)\n",
"    for filename in file_list:\n",
"        print(f\"\\n🚀 Đang tải: {filename}\")\n",
"\n",
"        # -P keeps partial files so an interrupted transfer can resume; the --no-*\n",
"        # flags skip owner/group/permission/time preservation.\n",
"        # The command is an argument list (no shell), so unusual characters in a\n",
"        # file name are passed to rsync verbatim instead of being shell-interpreted.\n",
"        command = [\n",
"            \"rsync\", \"-rvP\", \"--no-o\", \"--no-g\", \"--no-p\", \"--no-t\",\n",
"            f\"{base_url}{filename}\", output_dir,\n",
"        ]\n",
"\n",
"        try:\n",
"            # returncode is rsync's own exit status (os.system returned the raw wait status).\n",
"            exit_code = subprocess.run(command).returncode\n",
"        except FileNotFoundError:\n",
"            # rsync is not installed; 127 is the shell's \"command not found\" status.\n",
"            exit_code = 127\n",
"\n",
"        if exit_code == 0:\n",
"            print(f\"✅ XONG: {filename}\")\n",
"        else:\n",
"            print(f\"❌ LỖI: {filename} (Code: {exit_code})\")\n",
"        time.sleep(1)\n",
"\n",
"print(\"✅ Đã cấu hình xong! Hãy chạy các cell bên dưới để tải từng phần.\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "JIijwdmmwSho",
"outputId": "b84d0afb-66bf-4fb7-b2c6-3db3c55c49d0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"📂 Lưu tại: ./dataset/\n",
"--------------------------------------------------\n",
"\n",
"🚀 Đang tải: checksums.sha512\n",
"receiving incremental file list\n",
"checksums.sha512\n",
" 2,230 100% 2.13MB/s 0:00:00 (xfr#1, to-chk=0/1)\n",
"\n",
"sent 43 bytes received 2,330 bytes 431.45 bytes/sec\n",
"total size is 2,230 speedup is 0.94\n",
"✅ XONG: checksums.sha512\n",
"\n",
"🚀 Đang tải: sen12ms_cr_dataLoader.py\n",
"receiving incremental file list\n",
"sen12ms_cr_dataLoader.py\n",
" 9,285 100% 8.85MB/s 0:00:00 (xfr#1, to-chk=0/1)\n",
"\n",
"sent 43 bytes received 9,394 bytes 1,715.82 bytes/sec\n",
"total size is 9,285 speedup is 0.98\n",
"✅ XONG: sen12ms_cr_dataLoader.py\n"
]
}
],
"source": [
"# --- CELL 2: FILES NHỎ ---\n",
"files = [\n",
" \"checksums.sha512\",\n",
" \"sen12ms_cr_dataLoader.py\"\n",
"]\n",
"download_files(files)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2SezTEcFwSf7",
"outputId": "5fe6e5ed-4267-4623-8554-8f9646074639"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"📂 Lưu tại: ./dataset/\n",
"--------------------------------------------------\n",
"\n",
"🚀 Đang tải: ROIs2017_winter_s1.tar.gz\n",
"receiving incremental file list\n",
"ROIs2017_winter_s1.tar.gz\n",
" 8,294,291,725 100% 78.02MB/s 0:01:41 (xfr#1, to-chk=0/1)\n",
"\n",
"sent 728,687 bytes received 364,428 bytes 6,726.86 bytes/sec\n",
"total size is 8,294,291,725 speedup is 7,587.76\n",
"✅ XONG: ROIs2017_winter_s1.tar.gz\n",
"\n",
"🚀 Đang tải: ROIs2017_winter_s2.tar.gz\n",
"receiving incremental file list\n",
"ROIs2017_winter_s2.tar.gz\n",
" 21,959,347,301 100% 76.10MB/s 0:04:35 (xfr#1, to-chk=0/1)\n",
"\n",
"sent 1,340,419 bytes received 670,295 bytes 5,215.86 bytes/sec\n",
"total size is 21,959,347,301 speedup is 10,921.17\n",
"✅ XONG: ROIs2017_winter_s2.tar.gz\n",
"\n",
"🚀 Đang tải: ROIs2017_winter_s2_cloudy.tar.gz\n",
"receiving incremental file list\n",
"ROIs2017_winter_s2_cloudy.tar.gz\n",
" 13,391,641,622 100% 74.42MB/s 0:02:51 (xfr#1, to-chk=0/1)\n",
"\n",
"sent 925,899 bytes received 463,043 bytes 5,566.90 bytes/sec\n",
"total size is 13,391,641,622 speedup is 9,641.61\n",
"✅ XONG: ROIs2017_winter_s2_cloudy.tar.gz\n"
]
}
],
"source": [
"# --- CELL 3: MÙA ĐÔNG (WINTER) ---\n",
"# Tổng dung lượng: ~40.6 GB\n",
"files = [\n",
" \"ROIs2017_winter_s1.tar.gz\",\n",
" \"ROIs2017_winter_s2.tar.gz\",\n",
" \"ROIs2017_winter_s2_cloudy.tar.gz\"\n",
"]\n",
"download_files(files)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "7iIFPApowSbu",
"outputId": "9e4d45c7-0ee5-45f6-860e-99af4601aa22"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"📂 Lưu tại: ./dataset/\n",
"--------------------------------------------------\n",
"\n",
"🚀 Đang tải: ROIs1158_spring_s1.tar.gz\n",
"receiving incremental file list\n",
"ROIs1158_spring_s1.tar.gz\n",
" 13,229,071,106 100% 79.26MB/s 0:02:39 (xfr#1, to-chk=0/1)\n",
"\n",
"sent 920,259 bytes received 460,216 bytes 5,763.99 bytes/sec\n",
"total size is 13,229,071,106 speedup is 9,582.98\n",
"✅ XONG: ROIs1158_spring_s1.tar.gz\n",
"\n",
"🚀 Đang tải: ROIs1158_spring_s2.tar.gz\n",
"receiving incremental file list\n",
"ROIs1158_spring_s2.tar.gz\n",
" 34,953,924,853 100% 76.34MB/s 0:07:16 (xfr#1, to-chk=0/1)\n",
"\n",
"sent 2,133,595 bytes received 1,066,884 bytes 5,419.95 bytes/sec\n",
"total size is 34,953,924,853 speedup is 10,921.47\n",
"✅ XONG: ROIs1158_spring_s2.tar.gz\n",
"\n",
"🚀 Đang tải: ROIs1158_spring_s2_cloudy.tar.gz\n",
"receiving incremental file list\n",
"ROIs1158_spring_s2_cloudy.tar.gz\n",
" 21,910,216,767 100% 77.47MB/s 0:04:29 (xfr#1, to-chk=0/1)\n",
"\n",
"sent 1,337,419 bytes received 668,802 bytes 5,400.33 bytes/sec\n",
"total size is 21,910,216,767 speedup is 10,921.14\n",
"✅ XONG: ROIs1158_spring_s2_cloudy.tar.gz\n"
]
}
],
"source": [
"# --- CELL 4: MÙA XUÂN (SPRING) ---\n",
"# Tổng dung lượng: ~65.3 GB\n",
"files = [\n",
" \"ROIs1158_spring_s1.tar.gz\",\n",
" \"ROIs1158_spring_s2.tar.gz\",\n",
" \"ROIs1158_spring_s2_cloudy.tar.gz\"\n",
"]\n",
"download_files(files)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "s6t_5qyUwSZ-",
"outputId": "ddc101ae-753d-44f1-c3ee-c8413c16fb99"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"📂 Lưu tại: ./dataset/\n",
"--------------------------------------------------\n",
"\n",
"🚀 Đang tải: ROIs1868_summer_s2_cloudy.tar.gz\n",
"receiving incremental file list\n",
"ROIs1868_summer_s2_cloudy.tar.gz\n",
" 24,766,970,197 100% 77.10MB/s 0:05:06 (xfr#1, to-chk=0/1)\n",
"\n",
"sent 1,511,791 bytes received 755,986 bytes 5,329.68 bytes/sec\n",
"total size is 24,766,970,197 speedup is 10,921.25\n",
"✅ XONG: ROIs1868_summer_s2_cloudy.tar.gz\n"
]
}
],
"source": [
"# --- CELL 5: MÙA HÈ (SUMMER) ---\n",
"# Tổng dung lượng: ~74.7 GB\n",
"files = [\n",
" \"ROIs1868_summer_s1.tar.gz\",\n",
" \"ROIs1868_summer_s2.tar.gz\",\n",
" \"ROIs1868_summer_s2_cloudy.tar.gz\"\n",
"]\n",
"download_files(files)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "KbccbbI9wSUd",
"outputId": "7b208f9d-8c3f-45e8-9582-88a872527cb0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"📂 Lưu tại: ./dataset/\n",
"--------------------------------------------------\n",
"\n",
"🚀 Đang tải: ROIs1970_fall_s2_cloudy.tar.gz\n",
"receiving incremental file list\n",
"ROIs1970_fall_s2_cloudy.tar.gz\n",
" 30,149,143,027 100% 36.83MB/s 0:13:00 (xfr#1, to-chk=0/1)\n",
"\n",
"sent 1,785,615 bytes received 897,343,620 bytes 1,008,557.75 bytes/sec\n",
"total size is 30,149,143,027 speedup is 33.53\n",
"✅ XONG: ROIs1970_fall_s2_cloudy.tar.gz\n"
]
}
],
"source": [
"### --- CELL 6: MÙA THU (FALL) ---\n",
"# Tổng dung lượng: ~91.2 GB\n",
"files = [\n",
" \"ROIs1970_fall_s1.tar.gz\",\n",
" \"ROIs1970_fall_s2.tar.gz\",\n",
" \"ROIs1970_fall_s2_cloudy.tar.gz\"\n",
"]\n",
"download_files(files)"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

241
format_lại_đataset.ipynb Normal file
View File

@@ -0,0 +1,241 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "415c3125-320a-4164-a9ad-4007ac750ff9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"📂 Đang thực hiện sắp xếp tại: /home/jovyan/cloud_train/dataset\n",
" Đã tạo folder mẹ: ROIs1158_spring\n",
" ➡ Di chuyển: ROIs1158_spring_s1 -> ROIs1158_spring/s1\n",
" ➡ Di chuyển: ROIs1158_spring_s2 -> ROIs1158_spring/s2\n",
" ➡ Di chuyển: ROIs1158_spring_s2_cloudy -> ROIs1158_spring/s2_cloudy\n",
" Đã tạo folder mẹ: ROIs1868_summer\n",
" ➡ Di chuyển: ROIs1868_summer_s1 -> ROIs1868_summer/s1\n",
" ➡ Di chuyển: ROIs1868_summer_s2 -> ROIs1868_summer/s2\n",
" ➡ Di chuyển: ROIs1868_summer_s2_cloudy -> ROIs1868_summer/s2_cloudy\n",
" Đã tạo folder mẹ: ROIs1970_fall\n",
" ➡ Di chuyển: ROIs1970_fall_s1 -> ROIs1970_fall/s1\n",
" ➡ Di chuyển: ROIs1970_fall_s2 -> ROIs1970_fall/s2\n",
" ➡ Di chuyển: ROIs1970_fall_s2_cloudy -> ROIs1970_fall/s2_cloudy\n",
" Đã tạo folder mẹ: ROIs2017_winter\n",
" ➡ Di chuyển: ROIs2017_winter_s1 -> ROIs2017_winter/s1\n",
" ➡ Di chuyển: ROIs2017_winter_s2 -> ROIs2017_winter/s2\n",
" ➡ Di chuyển: ROIs2017_winter_s2_cloudy -> ROIs2017_winter/s2_cloudy\n",
"\n",
"✅ ĐÃ XONG! Đã di chuyển 12 thư mục.\n",
"👉 Bây giờ hãy chạy lại Cell load dữ liệu (Train code), nó sẽ hoạt động.\n"
]
}
],
"source": [
"import os\n",
"import shutil\n",
"\n",
"# ==========================================\n",
"# SCRIPT SẮP XẾP LẠI THƯ MỤC (CHẠY 1 LẦN)\n",
"# ==========================================\n",
"\n",
"# 1. Resolve the dataset root\n",
"user_home = os.path.expanduser(\"~\")\n",
"base_dir = os.path.join(user_home, \"cloud_train\", \"dataset\")\n",
"print(f\"📂 Đang thực hiện sắp xếp tại: {base_dir}\")\n",
"\n",
"if not os.path.exists(base_dir):\n",
"    print(\"❌ Lỗi: Không tìm thấy thư mục dataset!\")\n",
"else:\n",
"    # 2. Target layout\n",
"    #    key   = season folder to create\n",
"    #    value = per-sensor folders that currently sit side by side in base_dir\n",
"    seasons_map = {\n",
"        \"ROIs1158_spring\": [\"ROIs1158_spring_s1\", \"ROIs1158_spring_s2\", \"ROIs1158_spring_s2_cloudy\"],\n",
"        \"ROIs1868_summer\": [\"ROIs1868_summer_s1\", \"ROIs1868_summer_s2\", \"ROIs1868_summer_s2_cloudy\"],\n",
"        \"ROIs1970_fall\": [\"ROIs1970_fall_s1\", \"ROIs1970_fall_s2\", \"ROIs1970_fall_s2_cloudy\"],\n",
"        \"ROIs2017_winter\": [\"ROIs2017_winter_s1\", \"ROIs2017_winter_s2\", \"ROIs2017_winter_s2_cloudy\"]\n",
"    }\n",
"\n",
"    count_moved = 0\n",
"\n",
"    for target_season, sub_folders in seasons_map.items():\n",
"        # Create the season folder (e.g. ROIs1158_spring) if it is missing\n",
"        target_path = os.path.join(base_dir, target_season)\n",
"        if not os.path.exists(target_path):\n",
"            os.makedirs(target_path)\n",
"            print(f\" Đã tạo folder mẹ: {target_season}\")\n",
"\n",
"        for sub in sub_folders:\n",
"            source_path = os.path.join(base_dir, sub)  # current location of the per-sensor folder\n",
"\n",
"            # Nothing to do when the per-sensor folder is not present in base_dir.\n",
"            if not os.path.exists(source_path):\n",
"                continue\n",
"\n",
"            # Drop the season prefix, e.g. ROIs1158_spring_s1 -> s1.\n",
"            # \"_s2_cloudy\" is tested before \"_s2\" because it contains it.\n",
"            new_folder_name = \"\"\n",
"            if \"_s1\" in sub:\n",
"                new_folder_name = \"s1\"\n",
"            elif \"_s2_cloudy\" in sub:\n",
"                new_folder_name = \"s2_cloudy\"\n",
"            elif \"_s2\" in sub:\n",
"                new_folder_name = \"s2\"\n",
"\n",
"            final_dest = os.path.join(target_path, new_folder_name)\n",
"\n",
"            print(f\" ➡ Di chuyển: {sub} -> {target_season}/{new_folder_name}\")\n",
"\n",
"            try:\n",
"                # Move (and thereby rename) the folder only when the destination is free.\n",
"                if not os.path.exists(final_dest):\n",
"                    shutil.move(source_path, final_dest)\n",
"                    count_moved += 1\n",
"                else:\n",
"                    print(f\" ⚠️ Thư mục {new_folder_name} đã tồn tại trong {target_season}, bỏ qua.\")\n",
"            except Exception as e:\n",
"                # Report and continue with the remaining folders.\n",
"                print(f\" ❌ Lỗi: {e}\")\n",
"\n",
"    if count_moved > 0:\n",
"        print(f\"\\n✅ ĐÃ XONG! Đã di chuyển {count_moved} thư mục.\")\n",
"        print(\"👉 Bây giờ hãy chạy lại Cell load dữ liệu (Train code), nó sẽ hoạt động.\")\n",
"    else:\n",
"        print(\"\\n Không có gì thay đổi. Có thể cấu trúc đã đúng hoặc tên file không khớp.\")\n",
"        print(\"Hãy kiểm tra thủ công bằng lệnh: !ls -F /home/jovyan/cloud_train/dataset/\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0b006b3f-54f9-46d2-87a6-c4ee00f7b1e4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"📂 Đang xử lý tại: /home/jovyan/cloud_train/dataset\n",
"🔍 Đang quét mùa: ROIs1158_spring\n",
" found nested folder: s1 -> Đang di chuyển nội dung ra ngoài...\n",
" ✅ Đã xóa folder rỗng: s1\n",
" found nested folder: s2 -> Đang di chuyển nội dung ra ngoài...\n",
" ✅ Đã xóa folder rỗng: s2\n",
" found nested folder: s2_cloudy -> Đang di chuyển nội dung ra ngoài...\n",
" ✅ Đã xóa folder rỗng: s2_cloudy\n",
"🔍 Đang quét mùa: ROIs1868_summer\n",
" found nested folder: s1 -> Đang di chuyển nội dung ra ngoài...\n",
" ✅ Đã xóa folder rỗng: s1\n",
" found nested folder: s2 -> Đang di chuyển nội dung ra ngoài...\n",
" ✅ Đã xóa folder rỗng: s2\n",
" found nested folder: s2_cloudy -> Đang di chuyển nội dung ra ngoài...\n",
" ✅ Đã xóa folder rỗng: s2_cloudy\n",
"🔍 Đang quét mùa: ROIs1970_fall\n",
" found nested folder: s1 -> Đang di chuyển nội dung ra ngoài...\n",
" ✅ Đã xóa folder rỗng: s1\n",
" found nested folder: s2 -> Đang di chuyển nội dung ra ngoài...\n",
" ✅ Đã xóa folder rỗng: s2\n",
" found nested folder: s2_cloudy -> Đang di chuyển nội dung ra ngoài...\n",
" ✅ Đã xóa folder rỗng: s2_cloudy\n",
"🔍 Đang quét mùa: ROIs2017_winter\n",
" found nested folder: s1 -> Đang di chuyển nội dung ra ngoài...\n",
" ✅ Đã xóa folder rỗng: s1\n",
" found nested folder: s2 -> Đang di chuyển nội dung ra ngoài...\n",
" ✅ Đã xóa folder rỗng: s2\n",
" found nested folder: s2_cloudy -> Đang di chuyển nội dung ra ngoài...\n",
" ✅ Đã xóa folder rỗng: s2_cloudy\n",
"\n",
"✅ ĐÃ SỬA XONG CẤU TRÚC!\n"
]
}
],
"source": [
"import os\n",
"import shutil\n",
"\n",
"# ==========================================\n",
"# FIX CẤU TRÚC PHASE 2: ĐƯA THƯ MỤC CON RA NGOÀI\n",
"# ==========================================\n",
"\n",
"user_home = os.path.expanduser(\"~\")\n",
"base_dir = os.path.join(user_home, \"cloud_train\", \"dataset\")\n",
"print(f\"📂 Đang xử lý tại: {base_dir}\")\n",
"\n",
"seasons = [\n",
"    \"ROIs1158_spring\",\n",
"    \"ROIs1868_summer\",\n",
"    \"ROIs1970_fall\",\n",
"    \"ROIs2017_winter\"\n",
"]\n",
"\n",
"# Intermediate folders whose content is lifted one level up and which are then removed\n",
"sub_types = [\"s1\", \"s2\", \"s2_cloudy\"]\n",
"\n",
"for season in seasons:\n",
"    season_path = os.path.join(base_dir, season)\n",
"    if not os.path.exists(season_path):\n",
"        continue\n",
"\n",
"    print(f\"🔍 Đang quét mùa: {season}\")\n",
"\n",
"    for sub in sub_types:\n",
"        # Path of the intermediate folder (e.g. ROIs1158_spring/s1)\n",
"        nested_path = os.path.join(season_path, sub)\n",
"\n",
"        if os.path.exists(nested_path):\n",
"            print(f\" found nested folder: {sub} -> Đang di chuyển nội dung ra ngoài...\")\n",
"\n",
"            # Every entry inside the intermediate folder (s1_1, s1_2, ...)\n",
"            scenes = os.listdir(nested_path)\n",
"            for scene in scenes:\n",
"                src = os.path.join(nested_path, scene)\n",
"                dst = os.path.join(season_path, scene)\n",
"\n",
"                # Lift the entry into the season folder unless that name is already taken\n",
"                if not os.path.exists(dst):\n",
"                    shutil.move(src, dst)\n",
"\n",
"            # os.rmdir only succeeds on an empty folder, so skipped entries keep it alive\n",
"            try:\n",
"                os.rmdir(nested_path)\n",
"                print(f\" ✅ Đã xóa folder rỗng: {sub}\")\n",
"            except OSError:\n",
"                print(f\" ⚠️ Không xóa được {sub} (có thể còn file rác)\")\n",
"\n",
"print(\"\\n✅ ĐÃ SỬA XONG CẤU TRÚC!\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e6151206-f103-451c-a572-029a4da25816",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

270
sen12ms_cr_dataLoader.py Executable file
View File

@@ -0,0 +1,270 @@
"""
Generic data loading routines for the SEN12MS-CR dataset of corresponding Sentinel 1,
Sentinel 2 and cloudy Sentinel 2 data.
The SEN12MS-CR class is meant to provide a set of helper routines for loading individual
image patches as well as triplets of patches from the dataset. These routines can easily
be wrapped or extended for use with many deep learning frameworks or as standalone helper
methods. For an example use case please see the "main" routine at the end of this file.
NOTE: Some folder/file existence and validity checks are implemented but it is
by no means complete.
Authors: Patrick Ebel (patrick.ebel@tum.de), Lloyd Hughes (lloyd.hughes@tum.de),
based on the exemplary data loader code of https://mediatum.ub.tum.de/1474000, with minimal modifications applied.
"""
import os
import rasterio
import numpy as np
from enum import Enum
from glob import glob
class S1Bands(Enum):
    """Sentinel-1 bands.

    Member values are the band indexes that ``SEN12MSCRDataset.get_patch``
    hands to the raster reader; ``ALL`` and ``NONE`` carry lists of indexes.
    """
    VV = 1
    VH = 2
    ALL = [VV, VH]
    NONE = []
class S2Bands(Enum):
    """Sentinel-2 bands.

    Every band has two names for the same value (e.g. ``B04`` and ``red``); the
    second name is an alias of the first. ``ALL``, ``RGB`` and ``NONE`` carry
    lists of band indexes.
    """
    B01 = aerosol = 1
    B02 = blue = 2
    B03 = green = 3
    B04 = red = 4
    B05 = re1 = 5
    B06 = re2 = 6
    B07 = re3 = 7
    B08 = nir1 = 8
    B08A = nir2 = 9
    B09 = vapor = 10
    B10 = cirrus = 11
    B11 = swir1 = 12
    B12 = swir2 = 13
    ALL = [B01, B02, B03, B04, B05, B06, B07, B08, B08A, B09, B10, B11, B12]
    RGB = [B04, B03, B02]  # red, green, blue order
    NONE = []
class Seasons(Enum):
    """Seasons of the dataset; each value is the name of the season's folder
    below the dataset base directory. ``ALL`` carries the list of all four."""
    SPRING = "ROIs1158_spring"
    SUMMER = "ROIs1868_summer"
    FALL = "ROIs1970_fall"
    WINTER = "ROIs2017_winter"
    ALL = [SPRING, SUMMER, FALL, WINTER]
class Sensor(Enum):
    """Sensor identifiers; ``s1`` and ``s2`` are used as the scene-folder prefix
    when ``SEN12MSCRDataset.get_patch`` builds a patch path."""
    s1 = "s1"
    s2 = "s2"
    # NOTE(review): get_scene_ids refers to the cloudy scene folders as
    # "s2_cloudy" (with an underscore). This member is spelled differently and
    # is not referenced by the loader code in this file - confirm before using
    # it to build paths.
    s2cloudy = "s2cloudy"
# Note: The order in which you request the bands is the same order they will be returned in.
class SEN12MSCRDataset:
    """Locate and read SEN12MS-CR patches stored below ``base_dir``.

    The methods expect one folder per season (the values of ``Seasons``), each
    holding one folder per sensor and scene (``s1_<scene_id>``, ``s2_<scene_id>``,
    ``s2_cloudy_<scene_id>``) with GeoTIFF patches named
    ``<season>_<sensor>_<scene_id>_p<patch_id>.tif``.
    """

    # Folder / file-name prefix of the cloudy Sentinel-2 scenes (the same
    # spelling that get_scene_ids filters out when collecting scene ids).
    # NOTE(review): cloudy files are assumed to follow the same naming pattern
    # as s1/s2 files - confirm against the extracted archives.
    _S2_CLOUDY_PREFIX = "s2_cloudy"

    def __init__(self, base_dir):
        """Store ``base_dir``.

        Raises:
            Exception: if ``base_dir`` does not exist.
        """
        self.base_dir = base_dir

        if not os.path.exists(self.base_dir):
            raise Exception(
                "The specified base_dir for SEN12MS-CR dataset does not exist")

    def get_scene_ids(self, season):
        """Return the set of scene ids found in the folder of ``season``.

        Raises:
            NameError: if the season folder does not exist below ``base_dir``.
        """
        season = Seasons(season).value
        path = os.path.join(self.base_dir, season)

        if not os.path.exists(path):
            raise NameError("Could not find season {} in base directory {}".format(
                season, self.base_dir))

        scene_ids = set()
        for entry in glob(os.path.join(path, "*")):
            name = os.path.basename(entry)
            # Skip "s2_cloudy_<id>" folders: their extra underscore breaks the
            # split below. Only the entry name is tested, so a base_dir whose
            # own path contains "s2_cloudy" no longer hides every scene.
            if self._S2_CLOUDY_PREFIX in name:
                continue
            try:
                scene_ids.add(int(name.split("_")[1]))
            except (IndexError, ValueError):
                # Stray entry that is not named "<sensor>_<scene_id>": ignore it.
                continue
        return scene_ids

    def get_patch_ids(self, season, scene_id):
        """Return the list of patch ids of scene ``scene_id`` in ``season``.

        The ids are read from the file names in the ``s1_<scene_id>`` folder.

        Raises:
            NameError: if that folder does not exist.
        """
        season = Seasons(season).value
        path = os.path.join(self.base_dir, season, f"s1_{scene_id}")

        if not os.path.exists(path):
            raise NameError(
                "Could not find scene {} within season {}".format(scene_id, season))

        patch_ids = [os.path.splitext(os.path.basename(p))[0]
                     for p in glob(os.path.join(path, "*"))]
        # File stems end in "_p<patch_id>"; keep only the number after the "p".
        patch_ids = [int(p.rsplit("_", 1)[1].split("p")[1]) for p in patch_ids]

        return patch_ids

    def get_season_ids(self, season):
        """Return a dict mapping each scene id of ``season`` to its list of patch ids."""
        season = Seasons(season).value
        ids = {}
        scene_ids = self.get_scene_ids(season)

        for sid in scene_ids:
            ids[sid] = self.get_patch_ids(season, sid)

        return ids

    def get_patch(self, season, scene_id, patch_id, bands, cloudy=False):
        """Read the requested bands of a single patch from a single sensor.

        Args:
            season: a ``Seasons`` member or its value.
            scene_id: id of the scene folder.
            patch_id: id of the patch inside the scene.
            bands: one ``S1Bands``/``S2Bands`` member or a list/tuple of them;
                the type of the first element selects the sensor, and the bands
                are returned in the order requested.
            cloudy: when True (Sentinel-2 bands only) the patch is read from the
                ``s2_cloudy_<scene_id>`` folder instead of ``s2_<scene_id>``.

        Returns:
            ``(data, bounds)``: the raster array (a band axis is prepended when
            the reader returns a 2-D array) and the bounds of the patch.

        Raises:
            Exception: if ``bands`` is empty, is neither S1 nor S2 bands, or if
                ``cloudy`` is requested together with Sentinel-1 bands.
        """
        season = Seasons(season).value

        if isinstance(bands, (list, tuple)):
            if not bands:
                # Previously an IndexError on bands[0]
                raise Exception("Invalid bands specified")
            b = bands[0]
        else:
            b = bands

        if isinstance(b, S1Bands):
            if cloudy:
                raise Exception("Cloudy patches are only available for S2 bands")
            sensor = Sensor.s1.value
        elif isinstance(b, S2Bands):
            sensor = self._S2_CLOUDY_PREFIX if cloudy else Sensor.s2.value
        else:
            raise Exception("Invalid bands specified")

        # Translate enum members into the plain band indexes the reader expects.
        if isinstance(bands, (list, tuple)):
            bands = [b.value for b in bands]
        else:
            bands = bands.value

        scene = "{}_{}".format(sensor, scene_id)
        filename = "{}_{}_p{}.tif".format(season, scene, patch_id)
        patch_path = os.path.join(self.base_dir, season, scene, filename)

        with rasterio.open(patch_path) as patch:
            data = patch.read(bands)
            bounds = patch.bounds

        if len(data.shape) == 2:
            data = np.expand_dims(data, axis=0)

        return data, bounds

    def get_s1s2s2cloudy_triplet(self, season, scene_id, patch_id, s1_bands=S1Bands.ALL, s2_bands=S2Bands.ALL, s2cloudy_bands=S2Bands.ALL):
        """Return ``(s1, s2, s2cloudy, bounds)`` for one patch.

        ``bounds`` are those of the S1 patch. The cloudy image is read with
        ``cloudy=True``; without it the S2 band enums would resolve to the clear
        ``s2`` folder and the "cloudy" array would be a copy of the clear one.
        """
        s1, bounds = self.get_patch(season, scene_id, patch_id, s1_bands)
        s2, _ = self.get_patch(season, scene_id, patch_id, s2_bands)
        s2cloudy, _ = self.get_patch(
            season, scene_id, patch_id, s2cloudy_bands, cloudy=True)

        return s1, s2, s2cloudy, bounds

    def get_triplets(self, season, scene_ids=None, patch_ids=None, s1_bands=S1Bands.ALL, s2_bands=S2Bands.ALL, s2cloudy_bands=S2Bands.ALL):
        """Load several triplets and stack them along a new leading axis.

        Args:
            season: a ``Seasons`` member or its value.
            scene_ids: one scene id, a list of them, or None for every scene of
                the season.
            patch_ids: one patch id, a list of them, or None for every patch of
                each selected scene.
            s1_bands, s2_bands, s2cloudy_bands: bands to read per sensor.

        Returns:
            ``(s1, s2, s2cloudy, bounds)``: three stacked arrays (one entry per
            loaded patch) and the list of patch bounds.

        Raises:
            Exception: if both ``scene_ids`` and ``patch_ids`` are lists.
        """
        season = Seasons(season)
        scene_list = []
        patch_list = []
        bounds = []
        s1_data = []
        s2_data = []
        s2cloudy_data = []

        # Not all patch ids are available in all scenes, and not all scenes
        # exist in all seasons, so only one of the two may be a list.
        if isinstance(scene_ids, list) and isinstance(patch_ids, list):
            raise Exception("Only scene_ids or patch_ids can be a list, not both.")

        if scene_ids is None:
            scene_list = self.get_scene_ids(season)
        else:
            # extend() fails with TypeError for a single (non-iterable) id
            try:
                scene_list.extend(scene_ids)
            except TypeError:
                scene_list.append(scene_ids)

        if patch_ids is not None:
            try:
                patch_list.extend(patch_ids)
            except TypeError:
                patch_list.append(patch_ids)

        for sid in scene_list:
            if patch_ids is None:
                patch_list = self.get_patch_ids(season, sid)

            for pid in patch_list:
                s1, s2, s2cloudy, bound = self.get_s1s2s2cloudy_triplet(
                    season, sid, pid, s1_bands, s2_bands, s2cloudy_bands)
                s1_data.append(s1)
                s2_data.append(s2)
                s2cloudy_data.append(s2cloudy)
                bounds.append(bound)

        return np.stack(s1_data, axis=0), np.stack(s2_data, axis=0), np.stack(s2cloudy_data, axis=0), bounds
if __name__ == "__main__":
    # Example usage: expects the season folders in the current working directory.
    import time

    # Load the dataset specifying the base directory
    sen12mscr = SEN12MSCRDataset(".")

    spring_ids = sen12mscr.get_season_ids(Seasons.SPRING)
    cnt_patches = sum([len(pids) for pids in spring_ids.values()])
    print("Spring: {} scenes with a total of {} patches".format(
        len(spring_ids), cnt_patches))

    start = time.time()
    # Load the RGB bands of the first S2 patch in scene 8
    SCENE_ID = 8
    s2_rgb_patch, bounds = sen12mscr.get_patch(Seasons.SPRING, SCENE_ID,
                                               spring_ids[SCENE_ID][0], bands=S2Bands.RGB)
    print("Time Taken {}s".format(time.time() - start))
    print("S2 RGB: {} Bounds: {}".format(s2_rgb_patch.shape, bounds))

    print("\n")

    # Load a triplet of patches from the first three scenes of Spring - all S1 bands,
    # the red/NIR S2 bands (the NDVI inputs) and the same bands of the cloudy S2 image
    i = 0
    start = time.time()
    for scene_id, patch_ids in spring_ids.items():
        if i >= 3:
            break

        s1, s2, s2cloudy, bounds = sen12mscr.get_s1s2s2cloudy_triplet(Seasons.SPRING, scene_id, patch_ids[0], s1_bands=S1Bands.ALL,
                                                                      s2_bands=[S2Bands.red, S2Bands.nir1], s2cloudy_bands=[S2Bands.red, S2Bands.nir1])
        print(
            f"Scene: {scene_id}, S1: {s1.shape}, S2: {s2.shape}, cloudy S2: {s2cloudy.shape}, Bounds: {bounds}")
        i += 1

    print("Time Taken {}s".format(time.time() - start))
    print("\n")

    start = time.time()
    # Load all bands of all patches in a specified scene (scene 106)
    s1, s2, s2cloudy, _ = sen12mscr.get_triplets(Seasons.SPRING, 106, s1_bands=S1Bands.ALL,
                                                 s2_bands=S2Bands.ALL, s2cloudy_bands=S2Bands.ALL)

    print(f"Scene: 106, S1: {s1.shape}, S2: {s2.shape}, cloudy S2: {s2cloudy.shape}")
    print("Time Taken {}s".format(time.time() - start))