init
This commit is contained in:
568
Train_CR-GAN.ipynb
Normal file
568
Train_CR-GAN.ipynb
Normal file
@@ -0,0 +1,568 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "c886b44b-4288-4260-8b5b-bb78a87e172f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"📂 Working Directory: /home/jovyan/cloud_train/dataset\n",
|
||||
"✅ Đã import thành công (dùng đường dẫn tuyệt đối)!\n",
|
||||
"Using device: cpu\n",
|
||||
"✅ Data Ready: 29117 samples\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# ==========================================\n",
|
||||
"# CELL 1: SETUP & LOAD DỮ LIỆU (ĐÃ SỬA LỖI IMPORT)\n",
|
||||
"# ==========================================\n",
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
"import torch\n",
|
||||
"import numpy as np\n",
|
||||
"from torch.utils.data import Dataset, DataLoader\n",
|
||||
"import torch.nn as nn\n",
|
||||
"\n",
|
||||
"# --- KHẮC PHỤC LỖI KHÔNG TÌM THẤY MODULE ---\n",
|
||||
"# Thêm đường dẫn thư mục hiện tại vào hệ thống để Python tìm thấy file loader\n",
|
||||
"current_dir = os.getcwd()\n",
|
||||
"sys.path.append(current_dir)\n",
|
||||
"print(f\"📂 Working Directory: {current_dir}\")\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" from sen12ms_cr_dataLoader import SEN12MSCRDataset, Seasons, S1Bands, S2Bands\n",
|
||||
" print(\"✅ Đã import thành công sen12ms_cr_dataLoader!\")\n",
|
||||
"except ImportError:\n",
|
||||
" # Nếu vẫn lỗi, thử trỏ cứng vào thư mục cloud_train\n",
|
||||
" sys.path.append(\"/home/jovyan/cloud_train\")\n",
|
||||
" try:\n",
|
||||
" from sen12ms_cr_dataLoader import SEN12MSCRDataset, Seasons, S1Bands, S2Bands\n",
|
||||
" print(\"✅ Đã import thành công (dùng đường dẫn tuyệt đối)!\")\n",
|
||||
" except ImportError:\n",
|
||||
" raise RuntimeError(\"❌ Vẫn không tìm thấy file 'sen12ms_cr_dataLoader.py'. Hãy đảm bảo file này nằm cùng thư mục với Notebook!\")\n",
|
||||
"\n",
|
||||
"# Cấu hình thiết bị\n",
|
||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"print(f\"Using device: {device}\")\n",
|
||||
"\n",
|
||||
"# Đường dẫn dataset\n",
|
||||
"base_dir = os.path.join(os.path.expanduser(\"~\"), \"cloud_train\", \"dataset\")\n",
|
||||
"\n",
|
||||
"# Dataset Wrapper\n",
|
||||
"class SEN12MSCR_TorchDataset(Dataset):\n",
|
||||
" def __init__(self, base_dir, season=Seasons.SPRING):\n",
|
||||
" self.loader = SEN12MSCRDataset(base_dir)\n",
|
||||
" self.season = season\n",
|
||||
" self.season_ids = self.loader.get_season_ids(season)\n",
|
||||
" self.samples = [(sid, pid) for sid, pids in self.season_ids.items() for pid in pids]\n",
|
||||
" def __len__(self): return len(self.samples)\n",
|
||||
" def __getitem__(self, idx):\n",
|
||||
" sid, pid = self.samples[idx]\n",
|
||||
" s1, s2, c, _ = self.loader.get_s1s2s2cloudy_triplet(self.season, sid, pid)\n",
|
||||
" s1 = np.clip(s1, -25, 0) / 25.0\n",
|
||||
" s2 = (np.clip(s2, 0, 10000) / 5000.0) - 1.0\n",
|
||||
" c = (np.clip(c, 0, 10000) / 5000.0) - 1.0\n",
|
||||
" return torch.from_numpy(s1).float(), torch.from_numpy(s2).float(), torch.from_numpy(c).float()\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" train_dataset = SEN12MSCR_TorchDataset(base_dir)\n",
|
||||
" train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)\n",
|
||||
" print(f\"✅ Data Ready: {len(train_dataset)} samples\")\n",
|
||||
"except Exception as e:\n",
|
||||
" print(f\"❌ Lỗi Data: {e}\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "a6357434-8b1c-48ef-ac84-7afadc4ce04a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"# ==========================================\n",
|
||||
"# CELL 2: ĐỊNH NGHĨA MODEL CR-GAN\n",
|
||||
"# ==========================================\n",
|
||||
"class CR_Generator(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(CR_Generator, self).__init__()\n",
|
||||
" # Input: 13 kênh (Mây) + 2 kênh (Radar) = 15\n",
|
||||
" self.enc1 = nn.Sequential(nn.Conv2d(15, 64, 4, 2, 1), nn.LeakyReLU(0.2))\n",
|
||||
" self.enc2 = nn.Sequential(nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2))\n",
|
||||
" self.enc3 = nn.Sequential(nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2))\n",
|
||||
" \n",
|
||||
" self.dec1 = nn.Sequential(nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU())\n",
|
||||
" self.dec2 = nn.Sequential(nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU())\n",
|
||||
" self.dec3 = nn.ConvTranspose2d(64, 13, 4, 2, 1) # Output 13 kênh sạch\n",
|
||||
" self.tanh = nn.Tanh()\n",
|
||||
"\n",
|
||||
" def forward(self, s1, s2_cloudy):\n",
|
||||
" x = torch.cat([s1, s2_cloudy], dim=1)\n",
|
||||
" e1 = self.enc1(x)\n",
|
||||
" e2 = self.enc2(e1)\n",
|
||||
" e3 = self.enc3(e2)\n",
|
||||
" \n",
|
||||
" d1 = self.dec1(e3)\n",
|
||||
" d2 = self.dec2(d1 + e2) # Skip connection\n",
|
||||
" d3 = self.dec3(d2 + e1)\n",
|
||||
" return self.tanh(d3)\n",
|
||||
"\n",
|
||||
"class Discriminator(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(Discriminator, self).__init__()\n",
|
||||
" # Input: 13 (Ảnh) + 15 (Điều kiện) = 28\n",
|
||||
" self.model = nn.Sequential(\n",
|
||||
" nn.Conv2d(28, 64, 4, 2, 1), nn.LeakyReLU(0.2),\n",
|
||||
" nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2),\n",
|
||||
" nn.Conv2d(128, 1, 4, 1, 0), nn.Sigmoid()\n",
|
||||
" )\n",
|
||||
" def forward(self, img, cond1, cond2):\n",
|
||||
" return self.model(torch.cat([img, cond1, cond2], 1))\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "09276904-a4c9-4774-b171-32d2e577afa0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"🚀 Bắt đầu huấn luyện CR-GAN...\n",
|
||||
"Epoch 0 | Step 0 | Loss G: 72.3185 | Loss D: 0.7102\n",
|
||||
"Epoch 0 | Step 10 | Loss G: 37.1491 | Loss D: 0.5784\n",
|
||||
"Epoch 0 | Step 20 | Loss G: 19.9918 | Loss D: 0.5008\n",
|
||||
"Epoch 0 | Step 30 | Loss G: 21.9959 | Loss D: 0.6595\n",
|
||||
"Epoch 0 | Step 40 | Loss G: 10.0374 | Loss D: 0.5188\n",
|
||||
"Epoch 0 | Step 50 | Loss G: 7.7058 | Loss D: 0.5300\n",
|
||||
"Epoch 0 | Step 60 | Loss G: 6.9481 | Loss D: 0.5223\n",
|
||||
"Epoch 0 | Step 70 | Loss G: 12.7031 | Loss D: 0.6349\n",
|
||||
"Epoch 0 | Step 80 | Loss G: 6.1611 | Loss D: 0.6098\n",
|
||||
"Epoch 0 | Step 90 | Loss G: 8.4944 | Loss D: 0.4761\n",
|
||||
"Epoch 0 | Step 100 | Loss G: 16.4352 | Loss D: 1.2795\n",
|
||||
"Epoch 0 | Step 110 | Loss G: 6.7580 | Loss D: 0.6131\n",
|
||||
"Epoch 0 | Step 120 | Loss G: 6.7856 | Loss D: 0.5403\n",
|
||||
"Epoch 0 | Step 130 | Loss G: 6.2629 | Loss D: 0.6605\n",
|
||||
"Epoch 0 | Step 140 | Loss G: 5.4081 | Loss D: 0.5667\n",
|
||||
"Epoch 0 | Step 150 | Loss G: 5.5551 | Loss D: 0.5392\n",
|
||||
"Epoch 0 | Step 160 | Loss G: 5.4969 | Loss D: 0.5394\n",
|
||||
"Epoch 0 | Step 170 | Loss G: 4.6524 | Loss D: 0.5536\n",
|
||||
"Epoch 0 | Step 180 | Loss G: 9.2369 | Loss D: 0.5098\n",
|
||||
"Epoch 0 | Step 190 | Loss G: 10.5043 | Loss D: 0.4320\n",
|
||||
"Epoch 0 | Step 200 | Loss G: 5.1155 | Loss D: 0.5958\n",
|
||||
"Epoch 0 | Step 210 | Loss G: 5.1848 | Loss D: 0.5596\n",
|
||||
"Epoch 0 | Step 220 | Loss G: 6.5014 | Loss D: 0.5307\n",
|
||||
"Epoch 0 | Step 230 | Loss G: 6.4747 | Loss D: 0.6416\n",
|
||||
"Epoch 0 | Step 240 | Loss G: 4.6687 | Loss D: 0.5795\n",
|
||||
"Epoch 0 | Step 250 | Loss G: 3.8810 | Loss D: 0.6212\n",
|
||||
"Epoch 0 | Step 260 | Loss G: 4.5834 | Loss D: 0.5906\n",
|
||||
"Epoch 0 | Step 270 | Loss G: 6.8081 | Loss D: 0.4587\n",
|
||||
"Epoch 0 | Step 280 | Loss G: 6.4570 | Loss D: 0.5759\n",
|
||||
"Epoch 0 | Step 290 | Loss G: 4.5583 | Loss D: 0.5872\n",
|
||||
"Epoch 0 | Step 300 | Loss G: 5.8205 | Loss D: 0.6410\n",
|
||||
"Epoch 0 | Step 310 | Loss G: 4.7019 | Loss D: 0.6298\n",
|
||||
"Epoch 0 | Step 320 | Loss G: 4.3386 | Loss D: 0.5507\n",
|
||||
"Epoch 0 | Step 330 | Loss G: 4.2705 | Loss D: 0.5533\n",
|
||||
"Epoch 0 | Step 340 | Loss G: 6.9023 | Loss D: 0.8489\n",
|
||||
"Epoch 0 | Step 350 | Loss G: 8.0638 | Loss D: 0.7825\n",
|
||||
"Epoch 0 | Step 360 | Loss G: 4.4931 | Loss D: 0.6148\n",
|
||||
"Epoch 0 | Step 370 | Loss G: 3.7342 | Loss D: 0.6964\n",
|
||||
"Epoch 0 | Step 380 | Loss G: 4.2160 | Loss D: 0.6782\n",
|
||||
"Epoch 0 | Step 390 | Loss G: 5.1540 | Loss D: 0.6306\n",
|
||||
"Epoch 0 | Step 400 | Loss G: 4.5333 | Loss D: 0.6755\n",
|
||||
"Epoch 0 | Step 410 | Loss G: 4.0220 | Loss D: 0.6751\n",
|
||||
"Epoch 0 | Step 420 | Loss G: 5.7766 | Loss D: 0.7027\n",
|
||||
"Epoch 0 | Step 430 | Loss G: 4.0071 | Loss D: 0.5669\n",
|
||||
"Epoch 0 | Step 440 | Loss G: 4.1375 | Loss D: 0.4732\n",
|
||||
"Epoch 0 | Step 450 | Loss G: 3.8729 | Loss D: 0.5686\n",
|
||||
"Epoch 0 | Step 460 | Loss G: 3.5877 | Loss D: 0.6608\n",
|
||||
"Epoch 0 | Step 470 | Loss G: 5.6308 | Loss D: 0.7109\n",
|
||||
"Epoch 0 | Step 480 | Loss G: 5.3286 | Loss D: 0.7127\n",
|
||||
"Epoch 0 | Step 490 | Loss G: 2.9878 | Loss D: 0.6668\n",
|
||||
"Epoch 0 | Step 500 | Loss G: 4.0083 | Loss D: 0.7079\n",
|
||||
"Epoch 0 | Step 510 | Loss G: 3.6710 | Loss D: 0.7004\n",
|
||||
"Epoch 0 | Step 520 | Loss G: 5.5583 | Loss D: 0.6757\n",
|
||||
"Epoch 0 | Step 530 | Loss G: 4.2446 | Loss D: 0.6811\n",
|
||||
"Epoch 0 | Step 540 | Loss G: 3.7496 | Loss D: 0.6590\n",
|
||||
"Epoch 0 | Step 550 | Loss G: 7.4803 | Loss D: 0.8142\n",
|
||||
"Epoch 0 | Step 560 | Loss G: 3.7535 | Loss D: 0.6432\n",
|
||||
"Epoch 0 | Step 570 | Loss G: 2.8306 | Loss D: 0.5701\n",
|
||||
"Epoch 0 | Step 580 | Loss G: 2.9589 | Loss D: 0.6700\n",
|
||||
"Epoch 0 | Step 590 | Loss G: 4.7672 | Loss D: 0.7502\n",
|
||||
"Epoch 0 | Step 600 | Loss G: 3.1664 | Loss D: 0.6552\n",
|
||||
"Epoch 0 | Step 610 | Loss G: 3.2172 | Loss D: 0.6563\n",
|
||||
"Epoch 0 | Step 620 | Loss G: 3.1097 | Loss D: 0.7513\n",
|
||||
"Epoch 0 | Step 630 | Loss G: 3.8371 | Loss D: 0.7282\n",
|
||||
"Epoch 0 | Step 640 | Loss G: 3.1871 | Loss D: 0.6991\n",
|
||||
"Epoch 0 | Step 650 | Loss G: 3.4674 | Loss D: 0.6747\n",
|
||||
"Epoch 0 | Step 660 | Loss G: 3.5930 | Loss D: 0.6720\n",
|
||||
"Epoch 0 | Step 670 | Loss G: 2.9213 | Loss D: 0.6792\n",
|
||||
"Epoch 0 | Step 680 | Loss G: 3.5470 | Loss D: 0.6509\n",
|
||||
"Epoch 0 | Step 690 | Loss G: 3.6894 | Loss D: 0.6192\n",
|
||||
"Epoch 0 | Step 700 | Loss G: 3.6680 | Loss D: 0.5443\n",
|
||||
"Epoch 0 | Step 710 | Loss G: 3.6357 | Loss D: 0.5633\n",
|
||||
"Epoch 0 | Step 720 | Loss G: 3.2239 | Loss D: 0.6394\n",
|
||||
"Epoch 0 | Step 730 | Loss G: 2.5703 | Loss D: 0.6727\n",
|
||||
"Epoch 0 | Step 740 | Loss G: 3.2920 | Loss D: 0.6691\n",
|
||||
"Epoch 0 | Step 750 | Loss G: 3.1487 | Loss D: 0.7395\n",
|
||||
"Epoch 0 | Step 760 | Loss G: 3.2808 | Loss D: 0.6791\n",
|
||||
"Epoch 0 | Step 770 | Loss G: 3.6725 | Loss D: 0.7175\n",
|
||||
"Epoch 0 | Step 780 | Loss G: 2.8584 | Loss D: 0.6647\n",
|
||||
"Epoch 0 | Step 790 | Loss G: 3.5167 | Loss D: 0.7073\n",
|
||||
"Epoch 0 | Step 800 | Loss G: 2.9515 | Loss D: 0.6595\n",
|
||||
"Epoch 0 | Step 810 | Loss G: 3.0861 | Loss D: 0.7068\n",
|
||||
"Epoch 0 | Step 820 | Loss G: 5.8470 | Loss D: 0.9232\n",
|
||||
"Epoch 0 | Step 830 | Loss G: 3.4973 | Loss D: 0.6484\n",
|
||||
"Epoch 0 | Step 840 | Loss G: 3.3402 | Loss D: 0.6405\n",
|
||||
"Epoch 0 | Step 850 | Loss G: 2.9324 | Loss D: 0.6842\n",
|
||||
"Epoch 0 | Step 860 | Loss G: 3.3249 | Loss D: 0.5885\n",
|
||||
"Epoch 0 | Step 870 | Loss G: 3.2350 | Loss D: 0.7483\n",
|
||||
"Epoch 0 | Step 880 | Loss G: 2.9857 | Loss D: 0.6709\n",
|
||||
"Epoch 0 | Step 890 | Loss G: 3.1763 | Loss D: 0.5673\n",
|
||||
"Epoch 0 | Step 900 | Loss G: 3.0522 | Loss D: 0.7522\n",
|
||||
"Epoch 0 | Step 910 | Loss G: 2.5046 | Loss D: 0.7584\n",
|
||||
"Epoch 0 | Step 920 | Loss G: 3.5277 | Loss D: 0.7532\n",
|
||||
"Epoch 0 | Step 930 | Loss G: 2.7921 | Loss D: 0.6860\n",
|
||||
"Epoch 0 | Step 940 | Loss G: 4.3079 | Loss D: 0.7080\n",
|
||||
"Epoch 0 | Step 950 | Loss G: 2.7965 | Loss D: 0.6728\n",
|
||||
"Epoch 0 | Step 960 | Loss G: 2.7622 | Loss D: 0.6773\n",
|
||||
"Epoch 0 | Step 970 | Loss G: 2.8015 | Loss D: 0.6507\n",
|
||||
"Epoch 0 | Step 980 | Loss G: 2.0900 | Loss D: 0.6775\n",
|
||||
"Epoch 0 | Step 990 | Loss G: 2.9263 | Loss D: 0.6361\n",
|
||||
"Epoch 0 | Step 1000 | Loss G: 3.5840 | Loss D: 0.6793\n",
|
||||
"Epoch 0 | Step 1010 | Loss G: 3.0541 | Loss D: 0.7161\n",
|
||||
"Epoch 0 | Step 1020 | Loss G: 2.6563 | Loss D: 0.6252\n",
|
||||
"Epoch 0 | Step 1030 | Loss G: 2.9624 | Loss D: 0.6389\n",
|
||||
"Epoch 0 | Step 1040 | Loss G: 4.0138 | Loss D: 0.5921\n",
|
||||
"Epoch 0 | Step 1050 | Loss G: 3.4586 | Loss D: 0.6120\n",
|
||||
"Epoch 0 | Step 1060 | Loss G: 3.6436 | Loss D: 0.5100\n",
|
||||
"Epoch 0 | Step 1070 | Loss G: 2.5240 | Loss D: 0.7917\n",
|
||||
"Epoch 0 | Step 1080 | Loss G: 2.9808 | Loss D: 0.5955\n",
|
||||
"Epoch 0 | Step 1090 | Loss G: 3.5152 | Loss D: 0.6790\n",
|
||||
"Epoch 0 | Step 1100 | Loss G: 4.8162 | Loss D: 0.9497\n",
|
||||
"Epoch 0 | Step 1110 | Loss G: 4.0998 | Loss D: 0.6457\n",
|
||||
"Epoch 0 | Step 1120 | Loss G: 3.7265 | Loss D: 0.7151\n",
|
||||
"Epoch 0 | Step 1130 | Loss G: 2.8072 | Loss D: 0.7019\n",
|
||||
"Epoch 0 | Step 1140 | Loss G: 3.8286 | Loss D: 0.7170\n",
|
||||
"Epoch 0 | Step 1150 | Loss G: 2.3177 | Loss D: 0.6822\n",
|
||||
"Epoch 0 | Step 1160 | Loss G: 2.3291 | Loss D: 0.6801\n",
|
||||
"Epoch 0 | Step 1170 | Loss G: 2.5152 | Loss D: 0.6733\n",
|
||||
"Epoch 0 | Step 1180 | Loss G: 2.1710 | Loss D: 0.6602\n",
|
||||
"Epoch 0 | Step 1190 | Loss G: 3.0357 | Loss D: 0.6695\n",
|
||||
"Epoch 0 | Step 1200 | Loss G: 2.9576 | Loss D: 0.6414\n",
|
||||
"Epoch 0 | Step 1210 | Loss G: 3.0184 | Loss D: 0.6573\n",
|
||||
"Epoch 0 | Step 1220 | Loss G: 2.2493 | Loss D: 0.7150\n",
|
||||
"Epoch 0 | Step 1230 | Loss G: 2.7642 | Loss D: 0.7177\n",
|
||||
"Epoch 0 | Step 1240 | Loss G: 2.4553 | Loss D: 0.6698\n",
|
||||
"Epoch 0 | Step 1250 | Loss G: 2.5607 | Loss D: 0.6953\n",
|
||||
"Epoch 0 | Step 1260 | Loss G: 2.9071 | Loss D: 0.6687\n",
|
||||
"Epoch 0 | Step 1270 | Loss G: 2.6420 | Loss D: 0.6780\n",
|
||||
"Epoch 0 | Step 1280 | Loss G: 2.6417 | Loss D: 0.6567\n",
|
||||
"Epoch 0 | Step 1290 | Loss G: 2.5739 | Loss D: 0.6899\n",
|
||||
"Epoch 0 | Step 1300 | Loss G: 2.9631 | Loss D: 0.6819\n",
|
||||
"Epoch 0 | Step 1310 | Loss G: 2.5568 | Loss D: 0.6575\n",
|
||||
"Epoch 0 | Step 1320 | Loss G: 2.2896 | Loss D: 0.6679\n",
|
||||
"Epoch 0 | Step 1330 | Loss G: 2.9179 | Loss D: 0.6886\n",
|
||||
"Epoch 0 | Step 1340 | Loss G: 2.8625 | Loss D: 0.6136\n",
|
||||
"Epoch 0 | Step 1350 | Loss G: 2.6456 | Loss D: 0.6368\n",
|
||||
"Epoch 0 | Step 1360 | Loss G: 2.9531 | Loss D: 0.6491\n",
|
||||
"Epoch 0 | Step 1370 | Loss G: 2.5252 | Loss D: 0.7159\n",
|
||||
"Epoch 0 | Step 1380 | Loss G: 3.9087 | Loss D: 0.7079\n",
|
||||
"Epoch 0 | Step 1390 | Loss G: 2.3737 | Loss D: 0.6529\n",
|
||||
"Epoch 0 | Step 1400 | Loss G: 2.3664 | Loss D: 0.6481\n",
|
||||
"Epoch 0 | Step 1410 | Loss G: 2.7657 | Loss D: 0.5882\n",
|
||||
"Epoch 0 | Step 1420 | Loss G: 2.7599 | Loss D: 0.6327\n",
|
||||
"Epoch 0 | Step 1430 | Loss G: 2.6549 | Loss D: 0.8037\n",
|
||||
"Epoch 0 | Step 1440 | Loss G: 3.3464 | Loss D: 0.5944\n",
|
||||
"Epoch 0 | Step 1450 | Loss G: 2.6873 | Loss D: 0.7476\n",
|
||||
"Epoch 0 | Step 1460 | Loss G: 2.5144 | Loss D: 0.6655\n",
|
||||
"Epoch 0 | Step 1470 | Loss G: 3.1850 | Loss D: 0.7270\n",
|
||||
"Epoch 0 | Step 1480 | Loss G: 5.0466 | Loss D: 0.6669\n",
|
||||
"Epoch 0 | Step 1490 | Loss G: 2.7675 | Loss D: 0.6640\n",
|
||||
"Epoch 0 | Step 1500 | Loss G: 2.3371 | Loss D: 0.6604\n",
|
||||
"Epoch 0 | Step 1510 | Loss G: 2.6685 | Loss D: 0.6510\n",
|
||||
"Epoch 0 | Step 1520 | Loss G: 3.0199 | Loss D: 0.6744\n",
|
||||
"Epoch 0 | Step 1530 | Loss G: 2.9492 | Loss D: 0.7057\n",
|
||||
"Epoch 0 | Step 1540 | Loss G: 2.8013 | Loss D: 0.7118\n",
|
||||
"Epoch 0 | Step 1550 | Loss G: 2.8135 | Loss D: 0.6324\n",
|
||||
"Epoch 0 | Step 1560 | Loss G: 3.1775 | Loss D: 0.6969\n",
|
||||
"Epoch 0 | Step 1570 | Loss G: 3.1974 | Loss D: 0.6341\n",
|
||||
"Epoch 0 | Step 1580 | Loss G: 2.4508 | Loss D: 0.6909\n",
|
||||
"Epoch 0 | Step 1590 | Loss G: 3.2501 | Loss D: 0.7138\n",
|
||||
"Epoch 0 | Step 1600 | Loss G: 3.5432 | Loss D: 0.7108\n",
|
||||
"Epoch 0 | Step 1610 | Loss G: 2.5303 | Loss D: 0.6710\n",
|
||||
"Epoch 0 | Step 1620 | Loss G: 2.8239 | Loss D: 0.6950\n",
|
||||
"Epoch 0 | Step 1630 | Loss G: 2.2003 | Loss D: 0.6568\n",
|
||||
"Epoch 0 | Step 1640 | Loss G: 2.6544 | Loss D: 0.7159\n",
|
||||
"Epoch 0 | Step 1650 | Loss G: 2.3602 | Loss D: 0.6818\n",
|
||||
"Epoch 0 | Step 1660 | Loss G: 3.1921 | Loss D: 0.6072\n",
|
||||
"Epoch 0 | Step 1670 | Loss G: 2.5231 | Loss D: 0.7427\n",
|
||||
"Epoch 0 | Step 1680 | Loss G: 2.9751 | Loss D: 0.6027\n",
|
||||
"Epoch 0 | Step 1690 | Loss G: 2.3957 | Loss D: 0.6853\n",
|
||||
"Epoch 0 | Step 1700 | Loss G: 2.5812 | Loss D: 0.6981\n",
|
||||
"Epoch 0 | Step 1710 | Loss G: 2.7025 | Loss D: 0.7519\n",
|
||||
"Epoch 0 | Step 1720 | Loss G: 2.9266 | Loss D: 0.6202\n",
|
||||
"Epoch 0 | Step 1730 | Loss G: 2.7670 | Loss D: 0.6953\n",
|
||||
"Epoch 0 | Step 1740 | Loss G: 2.7067 | Loss D: 0.5941\n",
|
||||
"Epoch 0 | Step 1750 | Loss G: 3.5343 | Loss D: 0.7225\n",
|
||||
"Epoch 0 | Step 1760 | Loss G: 2.4662 | Loss D: 0.6810\n",
|
||||
"Epoch 0 | Step 1770 | Loss G: 2.6970 | Loss D: 0.6763\n",
|
||||
"Epoch 0 | Step 1780 | Loss G: 3.1359 | Loss D: 0.6382\n",
|
||||
"Epoch 0 | Step 1790 | Loss G: 2.6589 | Loss D: 0.6323\n",
|
||||
"Epoch 0 | Step 1800 | Loss G: 2.6840 | Loss D: 0.6116\n",
|
||||
"Epoch 0 | Step 1810 | Loss G: 3.0749 | Loss D: 0.7142\n",
|
||||
"Epoch 0 | Step 1820 | Loss G: 2.4837 | Loss D: 0.7109\n",
|
||||
"Epoch 0 | Step 1830 | Loss G: 2.9848 | Loss D: 0.6832\n",
|
||||
"Epoch 0 | Step 1840 | Loss G: 3.1044 | Loss D: 0.5613\n",
|
||||
"Epoch 0 | Step 1850 | Loss G: 2.6458 | Loss D: 1.0100\n",
|
||||
"Epoch 0 | Step 1860 | Loss G: 2.4200 | Loss D: 0.7013\n",
|
||||
"Epoch 0 | Step 1870 | Loss G: 2.4586 | Loss D: 0.7051\n",
|
||||
"Epoch 0 | Step 1880 | Loss G: 2.3607 | Loss D: 0.7131\n",
|
||||
"Epoch 0 | Step 1890 | Loss G: 2.6252 | Loss D: 0.7187\n",
|
||||
"Epoch 0 | Step 1900 | Loss G: 2.4684 | Loss D: 0.6961\n",
|
||||
"Epoch 0 | Step 1910 | Loss G: 2.2537 | Loss D: 0.6876\n",
|
||||
"Epoch 0 | Step 1920 | Loss G: 2.4095 | Loss D: 0.6838\n",
|
||||
"Epoch 0 | Step 1930 | Loss G: 2.8157 | Loss D: 0.6736\n",
|
||||
"Epoch 0 | Step 1940 | Loss G: 2.4893 | Loss D: 0.6775\n",
|
||||
"Epoch 0 | Step 1950 | Loss G: 3.3445 | Loss D: 0.6962\n",
|
||||
"Epoch 0 | Step 1960 | Loss G: 2.2544 | Loss D: 0.6767\n",
|
||||
"Epoch 0 | Step 1970 | Loss G: 2.6515 | Loss D: 0.6322\n",
|
||||
"Epoch 0 | Step 1980 | Loss G: 2.2835 | Loss D: 0.6343\n",
|
||||
"Epoch 0 | Step 1990 | Loss G: 2.4262 | Loss D: 0.6298\n",
|
||||
"Epoch 0 | Step 2000 | Loss G: 2.7222 | Loss D: 0.6135\n",
|
||||
"Epoch 0 | Step 2010 | Loss G: 2.3119 | Loss D: 0.6223\n",
|
||||
"Epoch 0 | Step 2020 | Loss G: 2.8205 | Loss D: 0.6000\n",
|
||||
"Epoch 0 | Step 2030 | Loss G: 2.8469 | Loss D: 0.5928\n",
|
||||
"Epoch 0 | Step 2040 | Loss G: 3.0901 | Loss D: 0.6076\n",
|
||||
"Epoch 0 | Step 2050 | Loss G: 2.4228 | Loss D: 0.6983\n",
|
||||
"Epoch 0 | Step 2060 | Loss G: 2.7112 | Loss D: 0.7587\n",
|
||||
"Epoch 0 | Step 2070 | Loss G: 2.6442 | Loss D: 0.7504\n",
|
||||
"Epoch 0 | Step 2080 | Loss G: 2.4068 | Loss D: 0.7240\n",
|
||||
"Epoch 0 | Step 2090 | Loss G: 2.9278 | Loss D: 0.6676\n",
|
||||
"Epoch 0 | Step 2100 | Loss G: 2.2134 | Loss D: 0.7018\n",
|
||||
"Epoch 0 | Step 2110 | Loss G: 2.8740 | Loss D: 0.6760\n",
|
||||
"Epoch 0 | Step 2120 | Loss G: 2.4062 | Loss D: 0.6768\n",
|
||||
"Epoch 0 | Step 2130 | Loss G: 2.2714 | Loss D: 0.6880\n",
|
||||
"Epoch 0 | Step 2140 | Loss G: 2.3448 | Loss D: 0.6871\n",
|
||||
"Epoch 0 | Step 2150 | Loss G: 2.1828 | Loss D: 0.6292\n",
|
||||
"Epoch 0 | Step 2160 | Loss G: 2.4968 | Loss D: 0.6260\n",
|
||||
"Epoch 0 | Step 2170 | Loss G: 2.2925 | Loss D: 0.6720\n",
|
||||
"Epoch 0 | Step 2180 | Loss G: 2.1419 | Loss D: 0.6476\n",
|
||||
"Epoch 0 | Step 2190 | Loss G: 2.4403 | Loss D: 0.6648\n",
|
||||
"Epoch 0 | Step 2200 | Loss G: 3.3253 | Loss D: 0.5541\n",
|
||||
"Epoch 0 | Step 2210 | Loss G: 2.3137 | Loss D: 0.6822\n",
|
||||
"Epoch 0 | Step 2220 | Loss G: 3.3092 | Loss D: 0.6984\n",
|
||||
"Epoch 0 | Step 2230 | Loss G: 2.7817 | Loss D: 0.6168\n",
|
||||
"Epoch 0 | Step 2240 | Loss G: 2.4387 | Loss D: 0.6215\n",
|
||||
"Epoch 0 | Step 2250 | Loss G: 2.7103 | Loss D: 0.5489\n",
|
||||
"Epoch 0 | Step 2260 | Loss G: 2.8750 | Loss D: 0.7436\n",
|
||||
"Epoch 0 | Step 2270 | Loss G: 3.2172 | Loss D: 0.7419\n",
|
||||
"Epoch 0 | Step 2280 | Loss G: 2.2944 | Loss D: 0.7735\n",
|
||||
"Epoch 0 | Step 2290 | Loss G: 2.9406 | Loss D: 0.6687\n",
|
||||
"Epoch 0 | Step 2300 | Loss G: 2.3666 | Loss D: 0.7192\n",
|
||||
"Epoch 0 | Step 2310 | Loss G: 2.0075 | Loss D: 0.7016\n",
|
||||
"Epoch 0 | Step 2320 | Loss G: 2.5500 | Loss D: 0.7104\n",
|
||||
"Epoch 0 | Step 2330 | Loss G: 2.2471 | Loss D: 0.6831\n",
|
||||
"Epoch 0 | Step 2340 | Loss G: 2.2565 | Loss D: 0.6743\n",
|
||||
"Epoch 0 | Step 2350 | Loss G: 2.1140 | Loss D: 0.6982\n",
|
||||
"Epoch 0 | Step 2360 | Loss G: 2.2506 | Loss D: 0.6729\n",
|
||||
"Epoch 0 | Step 2370 | Loss G: 2.4593 | Loss D: 0.6737\n",
|
||||
"Epoch 0 | Step 2380 | Loss G: 2.3511 | Loss D: 0.6809\n",
|
||||
"Epoch 0 | Step 2390 | Loss G: 2.4674 | Loss D: 0.6723\n",
|
||||
"Epoch 0 | Step 2400 | Loss G: 2.7037 | Loss D: 0.6395\n",
|
||||
"Epoch 0 | Step 2410 | Loss G: 2.3319 | Loss D: 0.7071\n",
|
||||
"Epoch 0 | Step 2420 | Loss G: 2.4786 | Loss D: 0.6931\n",
|
||||
"Epoch 0 | Step 2430 | Loss G: 3.5919 | Loss D: 0.7343\n",
|
||||
"Epoch 0 | Step 2440 | Loss G: 2.3883 | Loss D: 0.7144\n",
|
||||
"Epoch 0 | Step 2450 | Loss G: 2.1943 | Loss D: 0.6796\n",
|
||||
"Epoch 0 | Step 2460 | Loss G: 2.6514 | Loss D: 0.7204\n",
|
||||
"Epoch 0 | Step 2470 | Loss G: 2.0497 | Loss D: 0.6569\n",
|
||||
"Epoch 0 | Step 2480 | Loss G: 2.3691 | Loss D: 0.7213\n",
|
||||
"Epoch 0 | Step 2490 | Loss G: 2.7880 | Loss D: 0.6707\n",
|
||||
"Epoch 0 | Step 2500 | Loss G: 2.9353 | Loss D: 0.7046\n",
|
||||
"Epoch 0 | Step 2510 | Loss G: 2.0710 | Loss D: 0.7094\n",
|
||||
"Epoch 0 | Step 2520 | Loss G: 2.1715 | Loss D: 0.6868\n",
|
||||
"Epoch 0 | Step 2530 | Loss G: 2.9305 | Loss D: 0.7008\n",
|
||||
"Epoch 0 | Step 2540 | Loss G: 3.2625 | Loss D: 0.7100\n",
|
||||
"Epoch 0 | Step 2550 | Loss G: 2.4526 | Loss D: 0.6615\n",
|
||||
"Epoch 0 | Step 2560 | Loss G: 2.1840 | Loss D: 0.6649\n",
|
||||
"Epoch 0 | Step 2570 | Loss G: 2.3944 | Loss D: 0.6434\n",
|
||||
"Epoch 0 | Step 2580 | Loss G: 2.3077 | Loss D: 0.6637\n",
|
||||
"Epoch 0 | Step 2590 | Loss G: 2.2480 | Loss D: 0.6865\n",
|
||||
"Epoch 0 | Step 2600 | Loss G: 2.2794 | Loss D: 0.6606\n",
|
||||
"Epoch 0 | Step 2610 | Loss G: 2.2931 | Loss D: 0.6991\n",
|
||||
"Epoch 0 | Step 2620 | Loss G: 2.1852 | Loss D: 0.6694\n",
|
||||
"Epoch 0 | Step 2630 | Loss G: 2.4422 | Loss D: 0.6534\n",
|
||||
"Epoch 0 | Step 2640 | Loss G: 2.4345 | Loss D: 0.7486\n",
|
||||
"Epoch 0 | Step 2650 | Loss G: 2.1540 | Loss D: 0.7449\n",
|
||||
"Epoch 0 | Step 2660 | Loss G: 2.2600 | Loss D: 0.6953\n",
|
||||
"Epoch 0 | Step 2670 | Loss G: 2.6546 | Loss D: 0.6010\n",
|
||||
"Epoch 0 | Step 2680 | Loss G: 2.1378 | Loss D: 0.7238\n",
|
||||
"Epoch 0 | Step 2690 | Loss G: 2.4992 | Loss D: 0.6710\n",
|
||||
"Epoch 0 | Step 2700 | Loss G: 2.4387 | Loss D: 0.6570\n",
|
||||
"Epoch 0 | Step 2710 | Loss G: 2.7910 | Loss D: 0.7033\n",
|
||||
"Epoch 0 | Step 2720 | Loss G: 2.0291 | Loss D: 0.6740\n",
|
||||
"Epoch 0 | Step 2730 | Loss G: 2.3530 | Loss D: 0.6880\n",
|
||||
"Epoch 0 | Step 2740 | Loss G: 2.4699 | Loss D: 0.6142\n",
|
||||
"Epoch 0 | Step 2750 | Loss G: 2.2944 | Loss D: 0.7012\n",
|
||||
"Epoch 0 | Step 2760 | Loss G: 2.3153 | Loss D: 0.6831\n",
|
||||
"Epoch 0 | Step 2770 | Loss G: 1.9557 | Loss D: 0.6957\n",
|
||||
"Epoch 0 | Step 2780 | Loss G: 1.9761 | Loss D: 0.7643\n",
|
||||
"Epoch 0 | Step 2790 | Loss G: 2.0589 | Loss D: 0.7031\n",
|
||||
"Epoch 0 | Step 2800 | Loss G: 2.1808 | Loss D: 0.6616\n",
|
||||
"Epoch 0 | Step 2810 | Loss G: 3.0882 | Loss D: 0.6470\n",
|
||||
"Epoch 0 | Step 2820 | Loss G: 2.2611 | Loss D: 0.7233\n",
|
||||
"Epoch 0 | Step 2830 | Loss G: 2.2900 | Loss D: 0.6960\n",
|
||||
"Epoch 0 | Step 2840 | Loss G: 2.8689 | Loss D: 0.6929\n",
|
||||
"Epoch 0 | Step 2850 | Loss G: 2.6262 | Loss D: 0.6793\n",
|
||||
"Epoch 0 | Step 2860 | Loss G: 2.0969 | Loss D: 0.6568\n",
|
||||
"Epoch 0 | Step 2870 | Loss G: 2.0100 | Loss D: 0.6813\n",
|
||||
"Epoch 0 | Step 2880 | Loss G: 2.4225 | Loss D: 0.6647\n",
|
||||
"Epoch 0 | Step 2890 | Loss G: 2.3839 | Loss D: 0.6293\n",
|
||||
"Epoch 0 | Step 2900 | Loss G: 2.2553 | Loss D: 0.6671\n",
|
||||
"Epoch 0 | Step 2910 | Loss G: 2.2506 | Loss D: 0.6316\n",
|
||||
"Epoch 0 | Step 2920 | Loss G: 2.2017 | Loss D: 0.6144\n",
|
||||
"Epoch 0 | Step 2930 | Loss G: 2.3828 | Loss D: 0.7388\n",
|
||||
"Epoch 0 | Step 2940 | Loss G: 3.6152 | Loss D: 0.5791\n",
|
||||
"Epoch 0 | Step 2950 | Loss G: 2.4575 | Loss D: 0.6609\n",
|
||||
"Epoch 0 | Step 2960 | Loss G: 2.5841 | Loss D: 0.8060\n",
|
||||
"Epoch 0 | Step 2970 | Loss G: 2.6008 | Loss D: 0.6548\n",
|
||||
"Epoch 0 | Step 2980 | Loss G: 2.9643 | Loss D: 0.6375\n",
|
||||
"Epoch 0 | Step 2990 | Loss G: 2.0817 | Loss D: 0.6882\n",
|
||||
"Epoch 0 | Step 3000 | Loss G: 2.4347 | Loss D: 0.7490\n",
|
||||
"Epoch 0 | Step 3010 | Loss G: 2.7559 | Loss D: 0.7129\n",
|
||||
"Epoch 0 | Step 3020 | Loss G: 2.1537 | Loss D: 0.6953\n",
|
||||
"Epoch 0 | Step 3030 | Loss G: 2.2344 | Loss D: 0.7369\n",
|
||||
"Epoch 0 | Step 3040 | Loss G: 2.3036 | Loss D: 0.6224\n",
|
||||
"Epoch 0 | Step 3050 | Loss G: 2.3935 | Loss D: 0.7120\n",
|
||||
"Epoch 0 | Step 3060 | Loss G: 2.3280 | Loss D: 0.6134\n",
|
||||
"Epoch 0 | Step 3070 | Loss G: 2.5429 | Loss D: 0.7636\n",
|
||||
"Epoch 0 | Step 3080 | Loss G: 2.3605 | Loss D: 0.6986\n",
|
||||
"Epoch 0 | Step 3090 | Loss G: 2.1876 | Loss D: 0.7954\n",
|
||||
"Epoch 0 | Step 3100 | Loss G: 2.3063 | Loss D: 0.6299\n",
|
||||
"Epoch 0 | Step 3110 | Loss G: 2.3226 | Loss D: 0.8436\n",
|
||||
"Epoch 0 | Step 3120 | Loss G: 2.9968 | Loss D: 0.7128\n",
|
||||
"Epoch 0 | Step 3130 | Loss G: 2.0452 | Loss D: 0.7180\n",
|
||||
"Epoch 0 | Step 3140 | Loss G: 1.9460 | Loss D: 0.7010\n",
|
||||
"Epoch 0 | Step 3150 | Loss G: 2.7439 | Loss D: 0.6920\n",
|
||||
"Epoch 0 | Step 3160 | Loss G: 1.8548 | Loss D: 0.7069\n",
|
||||
"Epoch 0 | Step 3170 | Loss G: 2.6291 | Loss D: 0.6260\n",
|
||||
"Epoch 0 | Step 3180 | Loss G: 3.1567 | Loss D: 0.6356\n",
|
||||
"Epoch 0 | Step 3190 | Loss G: 1.9601 | Loss D: 0.6841\n",
|
||||
"Epoch 0 | Step 3200 | Loss G: 2.3887 | Loss D: 0.6708\n",
|
||||
"Epoch 0 | Step 3210 | Loss G: 2.0190 | Loss D: 0.6558\n",
|
||||
"Epoch 0 | Step 3220 | Loss G: 2.5288 | Loss D: 0.6406\n",
|
||||
"Epoch 0 | Step 3230 | Loss G: 2.0141 | Loss D: 0.6553\n",
|
||||
"Epoch 0 | Step 3240 | Loss G: 3.5275 | Loss D: 0.6934\n",
|
||||
"Epoch 0 | Step 3250 | Loss G: 2.5778 | Loss D: 0.5621\n",
|
||||
"Epoch 0 | Step 3260 | Loss G: 2.4638 | Loss D: 0.6106\n",
|
||||
"Epoch 0 | Step 3270 | Loss G: 3.1237 | Loss D: 0.6939\n",
|
||||
"Epoch 0 | Step 3280 | Loss G: 2.8655 | Loss D: 0.6382\n",
|
||||
"Epoch 0 | Step 3290 | Loss G: 2.6247 | Loss D: 0.6818\n",
|
||||
"Epoch 0 | Step 3300 | Loss G: 2.2132 | Loss D: 0.6349\n",
|
||||
"Epoch 0 | Step 3310 | Loss G: 2.3956 | Loss D: 0.7407\n",
|
||||
"Epoch 0 | Step 3320 | Loss G: 2.8687 | Loss D: 0.4859\n",
|
||||
"Epoch 0 | Step 3330 | Loss G: 2.4718 | Loss D: 0.7202\n",
|
||||
"Epoch 0 | Step 3340 | Loss G: 1.9500 | Loss D: 0.6743\n",
|
||||
"Epoch 0 | Step 3350 | Loss G: 2.4359 | Loss D: 0.6221\n",
|
||||
"Epoch 0 | Step 3360 | Loss G: 2.1875 | Loss D: 0.7105\n",
|
||||
"Epoch 0 | Step 3370 | Loss G: 2.2411 | Loss D: 0.5774\n",
|
||||
"Epoch 0 | Step 3380 | Loss G: 2.2889 | Loss D: 0.6689\n",
|
||||
"Epoch 0 | Step 3390 | Loss G: 2.5175 | Loss D: 0.7022\n",
|
||||
"Epoch 0 | Step 3400 | Loss G: 2.1678 | Loss D: 0.6610\n",
|
||||
"Epoch 0 | Step 3410 | Loss G: 1.8933 | Loss D: 0.7071\n",
|
||||
"Epoch 0 | Step 3420 | Loss G: 2.1478 | Loss D: 0.6712\n",
|
||||
"Epoch 0 | Step 3430 | Loss G: 2.5614 | Loss D: 0.6450\n",
|
||||
"Epoch 0 | Step 3440 | Loss G: 2.5616 | Loss D: 0.6267\n",
|
||||
"Epoch 0 | Step 3450 | Loss G: 2.2356 | Loss D: 0.6358\n",
|
||||
"Epoch 0 | Step 3460 | Loss G: 3.6788 | Loss D: 0.6079\n",
|
||||
"Epoch 0 | Step 3470 | Loss G: 2.4702 | Loss D: 0.6131\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# ==========================================\n",
|
||||
"# CELL 3: HUẤN LUYỆN & LƯU MODEL\n",
|
||||
"# ==========================================\n",
|
||||
"if 'train_loader' in locals():\n",
|
||||
" generator = CR_Generator().to(device)\n",
|
||||
" discriminator = Discriminator().to(device)\n",
|
||||
" \n",
|
||||
" criterion_GAN = nn.BCELoss()\n",
|
||||
" criterion_L1 = nn.L1Loss()\n",
|
||||
" opt_g = torch.optim.Adam(generator.parameters(), lr=0.0002)\n",
|
||||
" opt_d = torch.optim.Adam(discriminator.parameters(), lr=0.0002)\n",
|
||||
"\n",
|
||||
" print(\"🚀 Bắt đầu huấn luyện CR-GAN...\")\n",
|
||||
" epochs = 5 # Demo chạy 5 vòng (tăng lên nếu muốn train thật)\n",
|
||||
" \n",
|
||||
" for epoch in range(epochs):\n",
|
||||
" for i, (s1, s2, cloud) in enumerate(train_loader):\n",
|
||||
" s1, s2, cloud = s1.to(device), s2.to(device), cloud.to(device)\n",
|
||||
" \n",
|
||||
" # --- Train Generator ---\n",
|
||||
" opt_g.zero_grad()\n",
|
||||
" fake = generator(s1, cloud)\n",
|
||||
" pred_fake = discriminator(fake, s1, cloud)\n",
|
||||
" loss_g = criterion_GAN(pred_fake, torch.ones_like(pred_fake)) + 100 * criterion_L1(fake, s2)\n",
|
||||
" loss_g.backward()\n",
|
||||
" opt_g.step()\n",
|
||||
" \n",
|
||||
" # --- Train Discriminator ---\n",
|
||||
" opt_d.zero_grad()\n",
|
||||
" pred_real = discriminator(s2, s1, cloud)\n",
|
||||
" pred_fake_d = discriminator(fake.detach(), s1, cloud)\n",
|
||||
" loss_d = 0.5 * (criterion_GAN(pred_real, torch.ones_like(pred_real)) + \n",
|
||||
" criterion_GAN(pred_fake_d, torch.zeros_like(pred_fake_d)))\n",
|
||||
" loss_d.backward()\n",
|
||||
" opt_d.step()\n",
|
||||
" \n",
|
||||
" if i % 10 == 0: \n",
|
||||
" print(f\"Epoch {epoch} | Step {i} | Loss G: {loss_g.item():.4f} | Loss D: {loss_d.item():.4f}\")\n",
|
||||
"\n",
|
||||
" # --- LƯU MODEL ---\n",
|
||||
" save_dir = os.path.join(os.path.expanduser(\"~\"), \"cloud_train\", \"saved_models\")\n",
|
||||
" os.makedirs(save_dir, exist_ok=True)\n",
|
||||
" torch.save(generator.state_dict(), os.path.join(save_dir, \"CRGAN_generator.pth\"))\n",
|
||||
" print(f\"\\n💾 Đã lưu model tại: {save_dir}/CRGAN_generator.pth\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "87873ca2-6a64-4fc9-8687-8c8555b99ed4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
183
Train_GLF-CR.ipynb
Normal file
183
Train_GLF-CR.ipynb
Normal file
@@ -0,0 +1,183 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "a148f700-b14e-4dc3-a203-9e2f93955587",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"📂 Working Directory: /home/jovyan/cloud_train/dataset\n",
|
||||
"✅ Đã import thành công (dùng đường dẫn tuyệt đối)!\n",
|
||||
"Using device: cpu\n",
|
||||
"✅ Data Ready: 29117 samples\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# ==========================================\n",
|
||||
"# CELL 1: SETUP & LOAD DỮ LIỆU\n",
|
||||
"# ==========================================\n",
|
||||
"import os, sys, torch, numpy as np\n",
|
||||
"import torch.nn as nn\n",
|
||||
"from torch.utils.data import Dataset, DataLoader\n",
|
||||
"\n",
|
||||
"# Fix Import Path\n",
|
||||
"sys.path.append(os.getcwd())\n",
|
||||
"sys.path.append(\"/home/jovyan/cloud_train\")\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" from sen12ms_cr_dataLoader import SEN12MSCRDataset, Seasons, S1Bands, S2Bands\n",
|
||||
"except ImportError:\n",
|
||||
" raise RuntimeError(\"❌ Không tìm thấy file 'sen12ms_cr_dataLoader.py'\")\n",
|
||||
"\n",
|
||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"base_dir = os.path.join(os.path.expanduser(\"~\"), \"cloud_train\", \"dataset\")\n",
|
||||
"\n",
|
||||
"class SEN12MSCR_TorchDataset(Dataset):\n",
|
||||
" def __init__(self, base_dir, season=Seasons.SPRING):\n",
|
||||
" self.loader = SEN12MSCRDataset(base_dir)\n",
|
||||
" self.season = season\n",
|
||||
" self.season_ids = self.loader.get_season_ids(season)\n",
|
||||
" self.samples = [(sid, pid) for sid, pids in self.season_ids.items() for pid in pids]\n",
|
||||
" def __len__(self): return len(self.samples)\n",
|
||||
" def __getitem__(self, idx):\n",
|
||||
" sid, pid = self.samples[idx]\n",
|
||||
" s1, s2, c, _ = self.loader.get_s1s2s2cloudy_triplet(self.season, sid, pid)\n",
|
||||
" s1 = np.clip(s1, -25, 0) / 25.0\n",
|
||||
" s2 = (np.clip(s2, 0, 10000) / 5000.0) - 1.0\n",
|
||||
" c = (np.clip(c, 0, 10000) / 5000.0) - 1.0\n",
|
||||
" return torch.from_numpy(s1).float(), torch.from_numpy(s2).float(), torch.from_numpy(c).float()\n",
|
||||
"\n",
|
||||
"train_dataset = SEN12MSCR_TorchDataset(base_dir)\n",
|
||||
"train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)\n",
|
||||
"print(f\"✅ GLF-CR Data Ready: {len(train_dataset)} samples\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "cdb8ff0f-ea65-4ed4-8f5b-b99b05c602a8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"# ==========================================\n",
|
||||
"# CELL 2: MODEL GLF-CR\n",
|
||||
"# ==========================================\n",
|
||||
"class FusionBlock(nn.Module):\n",
|
||||
" def __init__(self, channels):\n",
|
||||
" super().__init__()\n",
|
||||
" self.conv = nn.Sequential(\n",
|
||||
" nn.Conv2d(channels * 2, channels, 1),\n",
|
||||
" nn.BatchNorm2d(channels), nn.ReLU()\n",
|
||||
" )\n",
|
||||
" def forward(self, x1, x2):\n",
|
||||
" return self.conv(torch.cat([x1, x2], dim=1))\n",
|
||||
"\n",
|
||||
"class GLF_Generator(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(GLF_Generator, self).__init__()\n",
|
||||
" # 2 Nhánh riêng biệt\n",
|
||||
" self.enc_s1 = nn.Sequential(nn.Conv2d(2, 64, 3, 1, 1), nn.ReLU()) # Radar\n",
|
||||
" self.enc_s2 = nn.Sequential(nn.Conv2d(13, 64, 3, 1, 1), nn.ReLU()) # Mây\n",
|
||||
" self.fuse = FusionBlock(64)\n",
|
||||
" self.dec = nn.Sequential(\n",
|
||||
" nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(),\n",
|
||||
" nn.Conv2d(128, 64, 3, 1, 1), nn.ReLU(),\n",
|
||||
" nn.Conv2d(64, 13, 3, 1, 1)\n",
|
||||
" )\n",
|
||||
" self.tanh = nn.Tanh()\n",
|
||||
"\n",
|
||||
" def forward(self, s1, s2_cloudy):\n",
|
||||
" f1 = self.enc_s1(s1)\n",
|
||||
" f2 = self.enc_s2(s2_cloudy)\n",
|
||||
" f_fused = self.fuse(f1, f2)\n",
|
||||
" return self.tanh(self.dec(f_fused))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "b124a548-a7fe-46e7-939d-ae3ff5e5b83c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"🚀 Bắt đầu train GLF-CR...\n",
|
||||
"Step 0 | MSE Loss: 0.4782\n",
|
||||
"Step 10 | MSE Loss: 0.0848\n",
|
||||
"Step 20 | MSE Loss: 0.0446\n",
|
||||
"Step 30 | MSE Loss: 0.0274\n",
|
||||
"Step 40 | MSE Loss: 0.2338\n",
|
||||
"Step 50 | MSE Loss: 0.0158\n",
|
||||
"\n",
|
||||
"💾 Đã lưu model tại: /home/jovyan/cloud_train/saved_models/GLFCR_generator.pth\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# ==========================================\n",
|
||||
"# CELL 3: TRAINING & SAVE\n",
|
||||
"# ==========================================\n",
|
||||
"if 'train_loader' in locals():\n",
|
||||
" gen = GLF_Generator().to(device)\n",
|
||||
" opt = torch.optim.Adam(gen.parameters(), lr=0.0002)\n",
|
||||
" crit = nn.MSELoss() \n",
|
||||
" \n",
|
||||
" print(\"🚀 Bắt đầu train GLF-CR...\")\n",
|
||||
" for i, (s1, s2, cloud) in enumerate(train_loader):\n",
|
||||
" s1, s2, cloud = s1.to(device), s2.to(device), cloud.to(device)\n",
|
||||
" \n",
|
||||
" opt.zero_grad()\n",
|
||||
" out = gen(s1, cloud)\n",
|
||||
" loss = crit(out, s2)\n",
|
||||
" loss.backward()\n",
|
||||
" opt.step()\n",
|
||||
" \n",
|
||||
" if i % 10 == 0: print(f\"Step {i} | MSE Loss: {loss.item():.4f}\")\n",
|
||||
" if i > 50: break\n",
|
||||
" \n",
|
||||
" # --- LƯU MODEL ---\n",
|
||||
" save_dir = os.path.join(os.path.expanduser(\"~\"), \"cloud_train\", \"saved_models\")\n",
|
||||
" os.makedirs(save_dir, exist_ok=True)\n",
|
||||
" torch.save(gen.state_dict(), os.path.join(save_dir, \"GLFCR_generator.pth\"))\n",
|
||||
" print(f\"\\n💾 Đã lưu model tại: {save_dir}/GLFCR_generator.pth\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1f863772-5925-45b4-a9aa-e1dc54f52658",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
177
Train_SpA-GAN.ipynb
Normal file
177
Train_SpA-GAN.ipynb
Normal file
@@ -0,0 +1,177 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "3813f900-ea54-4533-873a-70aa46417ba0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"✅ SpA-GAN Data Ready: 29117 samples\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# ==========================================\n",
|
||||
"# CELL 1: SETUP & LOAD DỮ LIỆU\n",
|
||||
"# ==========================================\n",
|
||||
"import os, sys, torch, numpy as np\n",
|
||||
"import torch.nn as nn\n",
|
||||
"from torch.utils.data import Dataset, DataLoader\n",
|
||||
"\n",
|
||||
"# Fix Import Path\n",
|
||||
"sys.path.append(os.getcwd())\n",
|
||||
"sys.path.append(\"/home/jovyan/cloud_train\")\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" from sen12ms_cr_dataLoader import SEN12MSCRDataset, Seasons, S1Bands, S2Bands\n",
|
||||
"except ImportError:\n",
|
||||
" raise RuntimeError(\"❌ Không tìm thấy file 'sen12ms_cr_dataLoader.py'\")\n",
|
||||
"\n",
|
||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"base_dir = os.path.join(os.path.expanduser(\"~\"), \"cloud_train\", \"dataset\")\n",
|
||||
"\n",
|
||||
"class SEN12MSCR_TorchDataset(Dataset):\n",
|
||||
" def __init__(self, base_dir, season=Seasons.SPRING):\n",
|
||||
" self.loader = SEN12MSCRDataset(base_dir)\n",
|
||||
" self.season = season\n",
|
||||
" self.season_ids = self.loader.get_season_ids(season)\n",
|
||||
" self.samples = [(sid, pid) for sid, pids in self.season_ids.items() for pid in pids]\n",
|
||||
" def __len__(self): return len(self.samples)\n",
|
||||
" def __getitem__(self, idx):\n",
|
||||
" sid, pid = self.samples[idx]\n",
|
||||
" s1, s2, c, _ = self.loader.get_s1s2s2cloudy_triplet(self.season, sid, pid)\n",
|
||||
" s1 = np.clip(s1, -25, 0) / 25.0\n",
|
||||
" s2 = (np.clip(s2, 0, 10000) / 5000.0) - 1.0\n",
|
||||
" c = (np.clip(c, 0, 10000) / 5000.0) - 1.0\n",
|
||||
" return torch.from_numpy(s1).float(), torch.from_numpy(s2).float(), torch.from_numpy(c).float()\n",
|
||||
"\n",
|
||||
"train_dataset = SEN12MSCR_TorchDataset(base_dir)\n",
|
||||
"train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)\n",
|
||||
"print(f\"✅ SpA-GAN Data Ready: {len(train_dataset)} samples\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "078b6951-1dcf-44e9-895c-9ec0c8c848b6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# ==========================================\n",
|
||||
"# CELL 2: MODEL SpA-GAN\n",
|
||||
"# ==========================================\n",
|
||||
"class SpatialAttention(nn.Module):\n",
|
||||
" def __init__(self, in_channels):\n",
|
||||
" super(SpatialAttention, self).__init__()\n",
|
||||
" self.conv = nn.Conv2d(in_channels, 1, kernel_size=1) \n",
|
||||
" self.sigmoid = nn.Sigmoid()\n",
|
||||
" def forward(self, x):\n",
|
||||
" return x * self.sigmoid(self.conv(x))\n",
|
||||
"\n",
|
||||
"class SpA_Generator(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(SpA_Generator, self).__init__()\n",
|
||||
" self.enc1 = nn.Sequential(nn.Conv2d(15, 64, 4, 2, 1), nn.LeakyReLU(0.2))\n",
|
||||
" self.attn = SpatialAttention(64) # Module Attention\n",
|
||||
" self.enc2 = nn.Sequential(nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2))\n",
|
||||
" \n",
|
||||
" self.dec1 = nn.Sequential(nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU())\n",
|
||||
" self.dec2 = nn.ConvTranspose2d(64, 13, 4, 2, 1)\n",
|
||||
" self.tanh = nn.Tanh()\n",
|
||||
"\n",
|
||||
" def forward(self, s1, s2_cloudy):\n",
|
||||
" x = torch.cat([s1, s2_cloudy], dim=1)\n",
|
||||
" e1 = self.enc1(x)\n",
|
||||
" e1_attn = self.attn(e1) # Chú ý vào vùng quan trọng\n",
|
||||
" e2 = self.enc2(e1_attn)\n",
|
||||
" d1 = self.dec1(e2)\n",
|
||||
" return self.tanh(self.dec2(d1 + e1))\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "e6bd89be-aa61-4429-b588-f2041311f2a6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"🚀 Bắt đầu train SpA-GAN...\n",
|
||||
"Step 0 | Loss G: 0.6843\n",
|
||||
"Step 10 | Loss G: 0.3739\n",
|
||||
"Step 20 | Loss G: 0.1896\n",
|
||||
"Step 30 | Loss G: 0.1224\n",
|
||||
"Step 40 | Loss G: 0.0938\n",
|
||||
"Step 50 | Loss G: 0.1833\n",
|
||||
"\n",
|
||||
"💾 Đã lưu model tại: /home/jovyan/cloud_train/saved_models/SpAGAN_generator.pth\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# ==========================================\n",
|
||||
"# CELL 3: TRAINING & SAVE\n",
|
||||
"# ==========================================\n",
|
||||
"if 'train_loader' in locals():\n",
|
||||
" gen = SpA_Generator().to(device)\n",
|
||||
" opt_g = torch.optim.Adam(gen.parameters(), lr=0.0002)\n",
|
||||
" criterion = nn.L1Loss()\n",
|
||||
" \n",
|
||||
" print(\"🚀 Bắt đầu train SpA-GAN...\")\n",
|
||||
" for i, (s1, s2, cloud) in enumerate(train_loader):\n",
|
||||
" s1, s2, cloud = s1.to(device), s2.to(device), cloud.to(device)\n",
|
||||
" \n",
|
||||
" opt_g.zero_grad()\n",
|
||||
" fake = gen(s1, cloud)\n",
|
||||
" loss_g = criterion(fake, s2) \n",
|
||||
" loss_g.backward()\n",
|
||||
" opt_g.step()\n",
|
||||
" \n",
|
||||
" if i % 10 == 0: print(f\"Step {i} | Loss G: {loss_g.item():.4f}\")\n",
|
||||
" if i > 50: break # Demo 50 bước\n",
|
||||
" \n",
|
||||
" # --- LƯU MODEL ---\n",
|
||||
" save_dir = os.path.join(os.path.expanduser(\"~\"), \"cloud_train\", \"saved_models\")\n",
|
||||
" os.makedirs(save_dir, exist_ok=True)\n",
|
||||
" torch.save(gen.state_dict(), os.path.join(save_dir, \"SpAGAN_generator.pth\"))\n",
|
||||
" print(f\"\\n💾 Đã lưu model tại: {save_dir}/SpAGAN_generator.pth\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bc1d8e75-0a3f-4197-8196-7132122f2be9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
321
Tải_Dự_liệu_train_xóa_mây.ipynb
Normal file
321
Tải_Dự_liệu_train_xóa_mây.ipynb
Normal file
@@ -0,0 +1,321 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "d-0LFUETwSj8",
|
||||
"outputId": "3b84f2ea-e6d6-4404-a0ae-68ed567ce5ef"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"✅ Đã cấu hình xong! Hãy chạy các cell bên dưới để tải từng phần.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# --- CELL 1: SETUP ---\n",
|
||||
"import os\n",
|
||||
"import time\n",
|
||||
"\n",
|
||||
"# 1. Cấu hình\n",
|
||||
"os.environ['RSYNC_PASSWORD'] = 'm1554803'\n",
|
||||
"output_dir = \"./dataset/\"\n",
|
||||
"os.makedirs(output_dir, exist_ok=True)\n",
|
||||
"base_url = \"rsync://m1554803@dataserv.ub.tum.de/m1554803/\"\n",
|
||||
"\n",
|
||||
"# 2. Hàm tải file (Dùng chung cho các cell bên dưới)\n",
|
||||
"def download_files(file_list):\n",
|
||||
" print(f\"📂 Lưu tại: {output_dir}\")\n",
|
||||
" print(\"-\" * 50)\n",
|
||||
" for filename in file_list:\n",
|
||||
" print(f\"\\n🚀 Đang tải: {filename}\")\n",
|
||||
"\n",
|
||||
" # Lệnh rsync chuẩn, có resume (-P) và bỏ qua lỗi permission\n",
|
||||
" command = f\"rsync -rvP --no-o --no-g --no-p --no-t {base_url}{filename} {output_dir}\"\n",
|
||||
"\n",
|
||||
" exit_code = os.system(command)\n",
|
||||
" if exit_code == 0:\n",
|
||||
" print(f\"✅ XONG: {filename}\")\n",
|
||||
" else:\n",
|
||||
" print(f\"❌ LỖI: {filename} (Code: {exit_code})\")\n",
|
||||
" time.sleep(1)\n",
|
||||
"\n",
|
||||
"print(\"✅ Đã cấu hình xong! Hãy chạy các cell bên dưới để tải từng phần.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "JIijwdmmwSho",
|
||||
"outputId": "b84d0afb-66bf-4fb7-b2c6-3db3c55c49d0"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"📂 Lưu tại: ./dataset/\n",
|
||||
"--------------------------------------------------\n",
|
||||
"\n",
|
||||
"🚀 Đang tải: checksums.sha512\n",
|
||||
"receiving incremental file list\n",
|
||||
"checksums.sha512\n",
|
||||
" 2,230 100% 2.13MB/s 0:00:00 (xfr#1, to-chk=0/1)\n",
|
||||
"\n",
|
||||
"sent 43 bytes received 2,330 bytes 431.45 bytes/sec\n",
|
||||
"total size is 2,230 speedup is 0.94\n",
|
||||
"✅ XONG: checksums.sha512\n",
|
||||
"\n",
|
||||
"🚀 Đang tải: sen12ms_cr_dataLoader.py\n",
|
||||
"receiving incremental file list\n",
|
||||
"sen12ms_cr_dataLoader.py\n",
|
||||
" 9,285 100% 8.85MB/s 0:00:00 (xfr#1, to-chk=0/1)\n",
|
||||
"\n",
|
||||
"sent 43 bytes received 9,394 bytes 1,715.82 bytes/sec\n",
|
||||
"total size is 9,285 speedup is 0.98\n",
|
||||
"✅ XONG: sen12ms_cr_dataLoader.py\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# --- CELL 2: FILES NHỎ ---\n",
|
||||
"files = [\n",
|
||||
" \"checksums.sha512\",\n",
|
||||
" \"sen12ms_cr_dataLoader.py\"\n",
|
||||
"]\n",
|
||||
"download_files(files)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "2SezTEcFwSf7",
|
||||
"outputId": "5fe6e5ed-4267-4623-8554-8f9646074639"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"📂 Lưu tại: ./dataset/\n",
|
||||
"--------------------------------------------------\n",
|
||||
"\n",
|
||||
"🚀 Đang tải: ROIs2017_winter_s1.tar.gz\n",
|
||||
"receiving incremental file list\n",
|
||||
"ROIs2017_winter_s1.tar.gz\n",
|
||||
" 8,294,291,725 100% 78.02MB/s 0:01:41 (xfr#1, to-chk=0/1)\n",
|
||||
"\n",
|
||||
"sent 728,687 bytes received 364,428 bytes 6,726.86 bytes/sec\n",
|
||||
"total size is 8,294,291,725 speedup is 7,587.76\n",
|
||||
"✅ XONG: ROIs2017_winter_s1.tar.gz\n",
|
||||
"\n",
|
||||
"🚀 Đang tải: ROIs2017_winter_s2.tar.gz\n",
|
||||
"receiving incremental file list\n",
|
||||
"ROIs2017_winter_s2.tar.gz\n",
|
||||
" 21,959,347,301 100% 76.10MB/s 0:04:35 (xfr#1, to-chk=0/1)\n",
|
||||
"\n",
|
||||
"sent 1,340,419 bytes received 670,295 bytes 5,215.86 bytes/sec\n",
|
||||
"total size is 21,959,347,301 speedup is 10,921.17\n",
|
||||
"✅ XONG: ROIs2017_winter_s2.tar.gz\n",
|
||||
"\n",
|
||||
"🚀 Đang tải: ROIs2017_winter_s2_cloudy.tar.gz\n",
|
||||
"receiving incremental file list\n",
|
||||
"ROIs2017_winter_s2_cloudy.tar.gz\n",
|
||||
" 13,391,641,622 100% 74.42MB/s 0:02:51 (xfr#1, to-chk=0/1)\n",
|
||||
"\n",
|
||||
"sent 925,899 bytes received 463,043 bytes 5,566.90 bytes/sec\n",
|
||||
"total size is 13,391,641,622 speedup is 9,641.61\n",
|
||||
"✅ XONG: ROIs2017_winter_s2_cloudy.tar.gz\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# --- CELL 3: MÙA ĐÔNG (WINTER) ---\n",
|
||||
"# Tổng dung lượng: ~40.6 GB\n",
|
||||
"files = [\n",
|
||||
" \"ROIs2017_winter_s1.tar.gz\",\n",
|
||||
" \"ROIs2017_winter_s2.tar.gz\",\n",
|
||||
" \"ROIs2017_winter_s2_cloudy.tar.gz\"\n",
|
||||
"]\n",
|
||||
"download_files(files)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "7iIFPApowSbu",
|
||||
"outputId": "9e4d45c7-0ee5-45f6-860e-99af4601aa22"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"📂 Lưu tại: ./dataset/\n",
|
||||
"--------------------------------------------------\n",
|
||||
"\n",
|
||||
"🚀 Đang tải: ROIs1158_spring_s1.tar.gz\n",
|
||||
"receiving incremental file list\n",
|
||||
"ROIs1158_spring_s1.tar.gz\n",
|
||||
" 13,229,071,106 100% 79.26MB/s 0:02:39 (xfr#1, to-chk=0/1)\n",
|
||||
"\n",
|
||||
"sent 920,259 bytes received 460,216 bytes 5,763.99 bytes/sec\n",
|
||||
"total size is 13,229,071,106 speedup is 9,582.98\n",
|
||||
"✅ XONG: ROIs1158_spring_s1.tar.gz\n",
|
||||
"\n",
|
||||
"🚀 Đang tải: ROIs1158_spring_s2.tar.gz\n",
|
||||
"receiving incremental file list\n",
|
||||
"ROIs1158_spring_s2.tar.gz\n",
|
||||
" 34,953,924,853 100% 76.34MB/s 0:07:16 (xfr#1, to-chk=0/1)\n",
|
||||
"\n",
|
||||
"sent 2,133,595 bytes received 1,066,884 bytes 5,419.95 bytes/sec\n",
|
||||
"total size is 34,953,924,853 speedup is 10,921.47\n",
|
||||
"✅ XONG: ROIs1158_spring_s2.tar.gz\n",
|
||||
"\n",
|
||||
"🚀 Đang tải: ROIs1158_spring_s2_cloudy.tar.gz\n",
|
||||
"receiving incremental file list\n",
|
||||
"ROIs1158_spring_s2_cloudy.tar.gz\n",
|
||||
" 21,910,216,767 100% 77.47MB/s 0:04:29 (xfr#1, to-chk=0/1)\n",
|
||||
"\n",
|
||||
"sent 1,337,419 bytes received 668,802 bytes 5,400.33 bytes/sec\n",
|
||||
"total size is 21,910,216,767 speedup is 10,921.14\n",
|
||||
"✅ XONG: ROIs1158_spring_s2_cloudy.tar.gz\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# --- CELL 4: MÙA XUÂN (SPRING) ---\n",
|
||||
"# Tổng dung lượng: ~65.3 GB\n",
|
||||
"files = [\n",
|
||||
" \"ROIs1158_spring_s1.tar.gz\",\n",
|
||||
" \"ROIs1158_spring_s2.tar.gz\",\n",
|
||||
" \"ROIs1158_spring_s2_cloudy.tar.gz\"\n",
|
||||
"]\n",
|
||||
"download_files(files)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "s6t_5qyUwSZ-",
|
||||
"outputId": "ddc101ae-753d-44f1-c3ee-c8413c16fb99"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"📂 Lưu tại: ./dataset/\n",
|
||||
"--------------------------------------------------\n",
|
||||
"\n",
|
||||
"🚀 Đang tải: ROIs1868_summer_s2_cloudy.tar.gz\n",
|
||||
"receiving incremental file list\n",
|
||||
"ROIs1868_summer_s2_cloudy.tar.gz\n",
|
||||
" 24,766,970,197 100% 77.10MB/s 0:05:06 (xfr#1, to-chk=0/1)\n",
|
||||
"\n",
|
||||
"sent 1,511,791 bytes received 755,986 bytes 5,329.68 bytes/sec\n",
|
||||
"total size is 24,766,970,197 speedup is 10,921.25\n",
|
||||
"✅ XONG: ROIs1868_summer_s2_cloudy.tar.gz\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# --- CELL 5: MÙA HÈ (SUMMER) ---\n",
|
||||
"# Tổng dung lượng: ~74.7 GB\n",
|
||||
"files = [\n",
|
||||
" \"ROIs1868_summer_s1.tar.gz\",\n",
|
||||
" \"ROIs1868_summer_s2.tar.gz\",\n",
|
||||
" \"ROIs1868_summer_s2_cloudy.tar.gz\"\n",
|
||||
"]\n",
|
||||
"download_files(files)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "KbccbbI9wSUd",
|
||||
"outputId": "7b208f9d-8c3f-45e8-9582-88a872527cb0"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"📂 Lưu tại: ./dataset/\n",
|
||||
"--------------------------------------------------\n",
|
||||
"\n",
|
||||
"🚀 Đang tải: ROIs1970_fall_s2_cloudy.tar.gz\n",
|
||||
"receiving incremental file list\n",
|
||||
"ROIs1970_fall_s2_cloudy.tar.gz\n",
|
||||
" 30,149,143,027 100% 36.83MB/s 0:13:00 (xfr#1, to-chk=0/1)\n",
|
||||
"\n",
|
||||
"sent 1,785,615 bytes received 897,343,620 bytes 1,008,557.75 bytes/sec\n",
|
||||
"total size is 30,149,143,027 speedup is 33.53\n",
|
||||
"✅ XONG: ROIs1970_fall_s2_cloudy.tar.gz\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"### --- CELL 6: MÙA THU (FALL) ---\n",
|
||||
"# Tổng dung lượng: ~91.2 GB\n",
|
||||
"files = [\n",
|
||||
" \"ROIs1970_fall_s1.tar.gz\",\n",
|
||||
" \"ROIs1970_fall_s2.tar.gz\",\n",
|
||||
" \"ROIs1970_fall_s2_cloudy.tar.gz\"\n",
|
||||
"]\n",
|
||||
"download_files(files)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
241
format_lại_đataset.ipynb
Normal file
241
format_lại_đataset.ipynb
Normal file
@@ -0,0 +1,241 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "415c3125-320a-4164-a9ad-4007ac750ff9",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"📂 Đang thực hiện sắp xếp tại: /home/jovyan/cloud_train/dataset\n",
|
||||
"➕ Đã tạo folder mẹ: ROIs1158_spring\n",
|
||||
" ➡ Di chuyển: ROIs1158_spring_s1 -> ROIs1158_spring/s1\n",
|
||||
" ➡ Di chuyển: ROIs1158_spring_s2 -> ROIs1158_spring/s2\n",
|
||||
" ➡ Di chuyển: ROIs1158_spring_s2_cloudy -> ROIs1158_spring/s2_cloudy\n",
|
||||
"➕ Đã tạo folder mẹ: ROIs1868_summer\n",
|
||||
" ➡ Di chuyển: ROIs1868_summer_s1 -> ROIs1868_summer/s1\n",
|
||||
" ➡ Di chuyển: ROIs1868_summer_s2 -> ROIs1868_summer/s2\n",
|
||||
" ➡ Di chuyển: ROIs1868_summer_s2_cloudy -> ROIs1868_summer/s2_cloudy\n",
|
||||
"➕ Đã tạo folder mẹ: ROIs1970_fall\n",
|
||||
" ➡ Di chuyển: ROIs1970_fall_s1 -> ROIs1970_fall/s1\n",
|
||||
" ➡ Di chuyển: ROIs1970_fall_s2 -> ROIs1970_fall/s2\n",
|
||||
" ➡ Di chuyển: ROIs1970_fall_s2_cloudy -> ROIs1970_fall/s2_cloudy\n",
|
||||
"➕ Đã tạo folder mẹ: ROIs2017_winter\n",
|
||||
" ➡ Di chuyển: ROIs2017_winter_s1 -> ROIs2017_winter/s1\n",
|
||||
" ➡ Di chuyển: ROIs2017_winter_s2 -> ROIs2017_winter/s2\n",
|
||||
" ➡ Di chuyển: ROIs2017_winter_s2_cloudy -> ROIs2017_winter/s2_cloudy\n",
|
||||
"\n",
|
||||
"✅ ĐÃ XONG! Đã di chuyển 12 thư mục.\n",
|
||||
"👉 Bây giờ hãy chạy lại Cell load dữ liệu (Train code), nó sẽ hoạt động.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import shutil\n",
|
||||
"\n",
|
||||
"# ==========================================\n",
|
||||
"# SCRIPT SẮP XẾP LẠI THƯ MỤC (CHẠY 1 LẦN)\n",
|
||||
"# ==========================================\n",
|
||||
"\n",
|
||||
"# 1. Xác định đường dẫn gốc\n",
|
||||
"user_home = os.path.expanduser(\"~\")\n",
|
||||
"base_dir = os.path.join(user_home, \"cloud_train\", \"dataset\")\n",
|
||||
"print(f\"📂 Đang thực hiện sắp xếp tại: {base_dir}\")\n",
|
||||
"\n",
|
||||
"if not os.path.exists(base_dir):\n",
|
||||
" print(\"❌ Lỗi: Không tìm thấy thư mục dataset!\")\n",
|
||||
"else:\n",
|
||||
" # 2. Định nghĩa cấu trúc chuẩn cần gom\n",
|
||||
" # Key = Tên thư mục mẹ (Mới)\n",
|
||||
" # Value = Danh sách các thư mục con (Cũ - đang nằm lẻ)\n",
|
||||
" seasons_map = {\n",
|
||||
" \"ROIs1158_spring\": [\"ROIs1158_spring_s1\", \"ROIs1158_spring_s2\", \"ROIs1158_spring_s2_cloudy\"],\n",
|
||||
" \"ROIs1868_summer\": [\"ROIs1868_summer_s1\", \"ROIs1868_summer_s2\", \"ROIs1868_summer_s2_cloudy\"],\n",
|
||||
" \"ROIs1970_fall\": [\"ROIs1970_fall_s1\", \"ROIs1970_fall_s2\", \"ROIs1970_fall_s2_cloudy\"],\n",
|
||||
" \"ROIs2017_winter\": [\"ROIs2017_winter_s1\", \"ROIs2017_winter_s2\", \"ROIs2017_winter_s2_cloudy\"]\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" count_moved = 0\n",
|
||||
" \n",
|
||||
" for target_season, sub_folders in seasons_map.items():\n",
|
||||
" # Tạo thư mục mẹ (ví dụ: ROIs1158_spring)\n",
|
||||
" target_path = os.path.join(base_dir, target_season)\n",
|
||||
" if not os.path.exists(target_path):\n",
|
||||
" os.makedirs(target_path)\n",
|
||||
" print(f\"➕ Đã tạo folder mẹ: {target_season}\")\n",
|
||||
" \n",
|
||||
" for sub in sub_folders:\n",
|
||||
" source_path = os.path.join(base_dir, sub) # Đường dẫn thư mục lẻ hiện tại\n",
|
||||
" \n",
|
||||
" # Nếu thư mục lẻ tồn tại, di chuyển nội dung của nó vào thư mục mẹ\n",
|
||||
" if os.path.exists(source_path):\n",
|
||||
" # Đổi tên thư mục lẻ cho đúng chuẩn (bỏ prefix mùa đi)\n",
|
||||
" # Ví dụ: ROIs1158_spring_s1 -> s1\n",
|
||||
" \n",
|
||||
" new_folder_name = \"\"\n",
|
||||
" if \"_s1\" in sub: new_folder_name = \"s1\"\n",
|
||||
" elif \"_s2_cloudy\" in sub: new_folder_name = \"s2_cloudy\"\n",
|
||||
" elif \"_s2\" in sub: new_folder_name = \"s2\"\n",
|
||||
" \n",
|
||||
" final_dest = os.path.join(target_path, new_folder_name)\n",
|
||||
" \n",
|
||||
" print(f\" ➡ Di chuyển: {sub} -> {target_season}/{new_folder_name}\")\n",
|
||||
" \n",
|
||||
" try:\n",
|
||||
" # Nếu đích chưa có thì move thẳng folder sang và đổi tên\n",
|
||||
" if not os.path.exists(final_dest):\n",
|
||||
" shutil.move(source_path, final_dest)\n",
|
||||
" count_moved += 1\n",
|
||||
" else:\n",
|
||||
" print(f\" ⚠️ Thư mục {new_folder_name} đã tồn tại trong {target_season}, bỏ qua.\")\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\" ❌ Lỗi: {e}\")\n",
|
||||
" else:\n",
|
||||
" # Kiểm tra xem nó đã nằm đúng chỗ chưa\n",
|
||||
" check_path = os.path.join(target_path, sub.split(\"_\")[-1].replace(\"cloudy\", \"s2_cloudy\") if \"cloudy\" in sub else sub.split(\"_\")[-1])\n",
|
||||
" # Logic check đơn giản\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" if count_moved > 0:\n",
|
||||
" print(f\"\\n✅ ĐÃ XONG! Đã di chuyển {count_moved} thư mục.\")\n",
|
||||
" print(\"👉 Bây giờ hãy chạy lại Cell load dữ liệu (Train code), nó sẽ hoạt động.\")\n",
|
||||
" else:\n",
|
||||
" print(\"\\nℹ️ Không có gì thay đổi. Có thể cấu trúc đã đúng hoặc tên file không khớp.\")\n",
|
||||
" print(\"Hãy kiểm tra thủ công bằng lệnh: !ls -F /home/jovyan/cloud_train/dataset/\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "0b006b3f-54f9-46d2-87a6-c4ee00f7b1e4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"📂 Đang xử lý tại: /home/jovyan/cloud_train/dataset\n",
|
||||
"🔍 Đang quét mùa: ROIs1158_spring\n",
|
||||
" found nested folder: s1 -> Đang di chuyển nội dung ra ngoài...\n",
|
||||
" ✅ Đã xóa folder rỗng: s1\n",
|
||||
" found nested folder: s2 -> Đang di chuyển nội dung ra ngoài...\n",
|
||||
" ✅ Đã xóa folder rỗng: s2\n",
|
||||
" found nested folder: s2_cloudy -> Đang di chuyển nội dung ra ngoài...\n",
|
||||
" ✅ Đã xóa folder rỗng: s2_cloudy\n",
|
||||
"🔍 Đang quét mùa: ROIs1868_summer\n",
|
||||
" found nested folder: s1 -> Đang di chuyển nội dung ra ngoài...\n",
|
||||
" ✅ Đã xóa folder rỗng: s1\n",
|
||||
" found nested folder: s2 -> Đang di chuyển nội dung ra ngoài...\n",
|
||||
" ✅ Đã xóa folder rỗng: s2\n",
|
||||
" found nested folder: s2_cloudy -> Đang di chuyển nội dung ra ngoài...\n",
|
||||
" ✅ Đã xóa folder rỗng: s2_cloudy\n",
|
||||
"🔍 Đang quét mùa: ROIs1970_fall\n",
|
||||
" found nested folder: s1 -> Đang di chuyển nội dung ra ngoài...\n",
|
||||
" ✅ Đã xóa folder rỗng: s1\n",
|
||||
" found nested folder: s2 -> Đang di chuyển nội dung ra ngoài...\n",
|
||||
" ✅ Đã xóa folder rỗng: s2\n",
|
||||
" found nested folder: s2_cloudy -> Đang di chuyển nội dung ra ngoài...\n",
|
||||
" ✅ Đã xóa folder rỗng: s2_cloudy\n",
|
||||
"🔍 Đang quét mùa: ROIs2017_winter\n",
|
||||
" found nested folder: s1 -> Đang di chuyển nội dung ra ngoài...\n",
|
||||
" ✅ Đã xóa folder rỗng: s1\n",
|
||||
" found nested folder: s2 -> Đang di chuyển nội dung ra ngoài...\n",
|
||||
" ✅ Đã xóa folder rỗng: s2\n",
|
||||
" found nested folder: s2_cloudy -> Đang di chuyển nội dung ra ngoài...\n",
|
||||
" ✅ Đã xóa folder rỗng: s2_cloudy\n",
|
||||
"\n",
|
||||
"✅ ĐÃ SỬA XONG CẤU TRÚC!\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import shutil\n",
|
||||
"\n",
|
||||
"# ==========================================\n",
|
||||
"# FIX CẤU TRÚC PHASE 2: ĐƯA THƯ MỤC CON RA NGOÀI\n",
|
||||
"# ==========================================\n",
|
||||
"\n",
|
||||
"user_home = os.path.expanduser(\"~\")\n",
|
||||
"base_dir = os.path.join(user_home, \"cloud_train\", \"dataset\")\n",
|
||||
"print(f\"📂 Đang xử lý tại: {base_dir}\")\n",
|
||||
"\n",
|
||||
"seasons = [\n",
|
||||
" \"ROIs1158_spring\", \n",
|
||||
" \"ROIs1868_summer\", \n",
|
||||
" \"ROIs1970_fall\", \n",
|
||||
" \"ROIs2017_winter\"\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Các folder trung gian cần loại bỏ\n",
|
||||
"sub_types = [\"s1\", \"s2\", \"s2_cloudy\"]\n",
|
||||
"\n",
|
||||
"for season in seasons:\n",
|
||||
" season_path = os.path.join(base_dir, season)\n",
|
||||
" if not os.path.exists(season_path):\n",
|
||||
" continue\n",
|
||||
" \n",
|
||||
" print(f\"🔍 Đang quét mùa: {season}\")\n",
|
||||
" \n",
|
||||
" for sub in sub_types:\n",
|
||||
" # Đường dẫn tới folder trung gian (ví dụ: ROIs1158_spring/s1)\n",
|
||||
" nested_path = os.path.join(season_path, sub)\n",
|
||||
" \n",
|
||||
" if os.path.exists(nested_path):\n",
|
||||
" print(f\" found nested folder: {sub} -> Đang di chuyển nội dung ra ngoài...\")\n",
|
||||
" \n",
|
||||
" # Lấy danh sách các scene bên trong (s1_1, s1_2...)\n",
|
||||
" scenes = os.listdir(nested_path)\n",
|
||||
" for scene in scenes:\n",
|
||||
" src = os.path.join(nested_path, scene)\n",
|
||||
" dst = os.path.join(season_path, scene)\n",
|
||||
" \n",
|
||||
" # Di chuyển ra folder mẹ\n",
|
||||
" if not os.path.exists(dst):\n",
|
||||
" shutil.move(src, dst)\n",
|
||||
" \n",
|
||||
" # Sau khi chuyển hết thì xóa folder rỗng đi\n",
|
||||
" try:\n",
|
||||
" os.rmdir(nested_path)\n",
|
||||
" print(f\" ✅ Đã xóa folder rỗng: {sub}\")\n",
|
||||
" except OSError:\n",
|
||||
" print(f\" ⚠️ Không xóa được {sub} (có thể còn file rác)\")\n",
|
||||
"\n",
|
||||
"print(\"\\n✅ ĐÃ SỬA XONG CẤU TRÚC!\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e6151206-f103-451c-a572-029a4da25816",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
270
sen12ms_cr_dataLoader.py
Executable file
270
sen12ms_cr_dataLoader.py
Executable file
@@ -0,0 +1,270 @@
|
||||
"""
|
||||
Generic data loading routines for the SEN12MS-CR dataset of corresponding Sentinel 1,
|
||||
Sentinel 2 and cloudy Sentinel 2 data.
|
||||
|
||||
The SEN12MS-CR class is meant to provide a set of helper routines for loading individual
|
||||
image patches as well as triplets of patches from the dataset. These routines can easily
|
||||
be wrapped or extended for use with many deep learning frameworks or as standalone helper
|
||||
methods. For an example use case please see the "main" routine at the end of this file.
|
||||
|
||||
NOTE: Some folder/file existence and validity checks are implemented but it is
|
||||
by no means complete.
|
||||
|
||||
Authors: Patrick Ebel (patrick.ebel@tum.de), Lloyd Hughes (lloyd.hughes@tum.de),
|
||||
based on the exemplary data loader code of https://mediatum.ub.tum.de/1474000, with minimal modifications applied.
|
||||
"""
|
||||
|
||||
import os
|
||||
import rasterio
|
||||
|
||||
import numpy as np
|
||||
|
||||
from enum import Enum
|
||||
from glob import glob
|
||||
|
||||
|
||||
class S1Bands(Enum):
|
||||
VV = 1
|
||||
VH = 2
|
||||
ALL = [VV, VH]
|
||||
NONE = []
|
||||
|
||||
|
||||
class S2Bands(Enum):
|
||||
B01 = aerosol = 1
|
||||
B02 = blue = 2
|
||||
B03 = green = 3
|
||||
B04 = red = 4
|
||||
B05 = re1 = 5
|
||||
B06 = re2 = 6
|
||||
B07 = re3 = 7
|
||||
B08 = nir1 = 8
|
||||
B08A = nir2 = 9
|
||||
B09 = vapor = 10
|
||||
B10 = cirrus = 11
|
||||
B11 = swir1 = 12
|
||||
B12 = swir2 = 13
|
||||
ALL = [B01, B02, B03, B04, B05, B06, B07, B08, B08A, B09, B10, B11, B12]
|
||||
RGB = [B04, B03, B02]
|
||||
NONE = []
|
||||
|
||||
|
||||
class Seasons(Enum):
|
||||
SPRING = "ROIs1158_spring"
|
||||
SUMMER = "ROIs1868_summer"
|
||||
FALL = "ROIs1970_fall"
|
||||
WINTER = "ROIs2017_winter"
|
||||
ALL = [SPRING, SUMMER, FALL, WINTER]
|
||||
|
||||
|
||||
class Sensor(Enum):
|
||||
s1 = "s1"
|
||||
s2 = "s2"
|
||||
s2cloudy = "s2cloudy"
|
||||
|
||||
# Note: The order in which you request the bands is the same order they will be returned in.
|
||||
|
||||
|
||||
class SEN12MSCRDataset:
|
||||
def __init__(self, base_dir):
|
||||
self.base_dir = base_dir
|
||||
|
||||
if not os.path.exists(self.base_dir):
|
||||
raise Exception(
|
||||
"The specified base_dir for SEN12MS-CR dataset does not exist")
|
||||
|
||||
"""
|
||||
Returns a list of scene ids for a specific season.
|
||||
"""
|
||||
|
||||
def get_scene_ids(self, season):
|
||||
season = Seasons(season).value
|
||||
path = os.path.join(self.base_dir, season)
|
||||
|
||||
if not os.path.exists(path):
|
||||
raise NameError("Could not find season {} in base directory {}".format(
|
||||
season, self.base_dir))
|
||||
|
||||
# add all dirs except "s2_cloudy" (which messes with subsequent string splits)
|
||||
scene_list = [os.path.basename(s)
|
||||
for s in glob(os.path.join(path, "*")) if "s2_cloudy" not in s]
|
||||
scene_list = [int(s.split("_")[1]) for s in scene_list]
|
||||
return set(scene_list)
|
||||
|
||||
"""
|
||||
Returns a list of patch ids for a specific scene within a specific season
|
||||
"""
|
||||
|
||||
def get_patch_ids(self, season, scene_id):
|
||||
season = Seasons(season).value
|
||||
path = os.path.join(self.base_dir, season, f"s1_{scene_id}")
|
||||
|
||||
if not os.path.exists(path):
|
||||
raise NameError(
|
||||
"Could not find scene {} within season {}".format(scene_id, season))
|
||||
|
||||
patch_ids = [os.path.splitext(os.path.basename(p))[0]
|
||||
for p in glob(os.path.join(path, "*"))]
|
||||
patch_ids = [int(p.rsplit("_", 1)[1].split("p")[1]) for p in patch_ids]
|
||||
|
||||
return patch_ids
|
||||
|
||||
"""
|
||||
Return a dict of scene ids and their corresponding patch ids.
|
||||
key => scene_ids, value => list of patch_ids
|
||||
"""
|
||||
|
||||
def get_season_ids(self, season):
|
||||
season = Seasons(season).value
|
||||
ids = {}
|
||||
scene_ids = self.get_scene_ids(season)
|
||||
|
||||
for sid in scene_ids:
|
||||
ids[sid] = self.get_patch_ids(season, sid)
|
||||
|
||||
return ids
|
||||
|
||||
"""
|
||||
Returns raster data and image bounds for the defined bands of a specific patch
|
||||
This method only loads a sinlge patch from a single sensor as defined by the bands specified
|
||||
"""
|
||||
|
||||
def get_patch(self, season, scene_id, patch_id, bands):
|
||||
season = Seasons(season).value
|
||||
sensor = None
|
||||
|
||||
if isinstance(bands, (list, tuple)):
|
||||
b = bands[0]
|
||||
else:
|
||||
b = bands
|
||||
|
||||
if isinstance(b, S1Bands):
|
||||
sensor = Sensor.s1.value
|
||||
bandEnum = S1Bands
|
||||
elif isinstance(b, S2Bands):
|
||||
sensor = Sensor.s2.value
|
||||
bandEnum = S2Bands
|
||||
else:
|
||||
raise Exception("Invalid bands specified")
|
||||
|
||||
if isinstance(bands, (list, tuple)):
|
||||
bands = [b.value for b in bands]
|
||||
else:
|
||||
bands = bands.value
|
||||
|
||||
scene = "{}_{}".format(sensor, scene_id)
|
||||
filename = "{}_{}_p{}.tif".format(season, scene, patch_id)
|
||||
patch_path = os.path.join(self.base_dir, season, scene, filename)
|
||||
|
||||
with rasterio.open(patch_path) as patch:
|
||||
data = patch.read(bands)
|
||||
bounds = patch.bounds
|
||||
|
||||
if len(data.shape) == 2:
|
||||
data = np.expand_dims(data, axis=0)
|
||||
|
||||
return data, bounds
|
||||
|
||||
"""
|
||||
Returns a triplet of patches. S1, S2 and cloudy S2 as well as the geo-bounds of the patch
|
||||
"""
|
||||
|
||||
def get_s1s2s2cloudy_triplet(self, season, scene_id, patch_id, s1_bands=S1Bands.ALL, s2_bands=S2Bands.ALL, s2cloudy_bands=S2Bands.ALL):
|
||||
s1, bounds = self.get_patch(season, scene_id, patch_id, s1_bands)
|
||||
s2, _ = self.get_patch(season, scene_id, patch_id, s2_bands)
|
||||
s2cloudy, _ = self.get_patch(season, scene_id, patch_id, s2cloudy_bands)
|
||||
|
||||
return s1, s2, s2cloudy, bounds
|
||||
|
||||
"""
|
||||
Returns a triplet of numpy arrays with dimensions D, B, W, H where D is the number of patches specified
|
||||
using scene_ids and patch_ids and B is the number of bands for S1, S2 or cloudy S2
|
||||
"""
|
||||
|
||||
def get_triplets(self, season, scene_ids=None, patch_ids=None, s1_bands=S1Bands.ALL, s2_bands=S2Bands.ALL, s2cloudy_bands=S2Bands.ALL):
|
||||
season = Seasons(season)
|
||||
scene_list = []
|
||||
patch_list = []
|
||||
bounds = []
|
||||
s1_data = []
|
||||
s2_data = []
|
||||
s2cloudy_data = []
|
||||
|
||||
# This is due to the fact that not all patch ids are available in all scenes
|
||||
# And not all scenes exist in all seasons
|
||||
if isinstance(scene_ids, list) and isinstance(patch_ids, list):
|
||||
raise Exception("Only scene_ids or patch_ids can be a list, not both.")
|
||||
|
||||
if scene_ids is None:
|
||||
scene_list = self.get_scene_ids(season)
|
||||
else:
|
||||
try:
|
||||
scene_list.extend(scene_ids)
|
||||
except TypeError:
|
||||
scene_list.append(scene_ids)
|
||||
|
||||
if patch_ids is not None:
|
||||
try:
|
||||
patch_list.extend(patch_ids)
|
||||
except TypeError:
|
||||
patch_list.append(patch_ids)
|
||||
|
||||
for sid in scene_list:
|
||||
if patch_ids is None:
|
||||
patch_list = self.get_patch_ids(season, sid)
|
||||
|
||||
for pid in patch_list:
|
||||
s1, s2, s2cloudy, bound = self.get_s1s2s2cloudy_triplet(
|
||||
season, sid, pid, s1_bands, s2_bands, s2cloudy_bands)
|
||||
s1_data.append(s1)
|
||||
s2_data.append(s2)
|
||||
s2cloudy_data.append(s2cloudy)
|
||||
bounds.append(bound)
|
||||
|
||||
return np.stack(s1_data, axis=0), np.stack(s2_data, axis=0), np.stack(s2cloudy_data, axis=0), bounds
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
# Load the dataset specifying the base directory
|
||||
sen12mscr = SEN12MSCRDataset(".")
|
||||
|
||||
spring_ids = sen12mscr.get_season_ids(Seasons.SPRING)
|
||||
cnt_patches = sum([len(pids) for pids in spring_ids.values()])
|
||||
print("Spring: {} scenes with a total of {} patches".format(
|
||||
len(spring_ids), cnt_patches))
|
||||
|
||||
start = time.time()
|
||||
# Load the RGB bands of the first S2 patch in scene 8
|
||||
SCENE_ID = 8
|
||||
s2_rgb_patch, bounds = sen12mscr.get_patch(Seasons.SPRING, SCENE_ID,
|
||||
spring_ids[SCENE_ID][0], bands=S2Bands.RGB)
|
||||
print("Time Taken {}s".format(time.time() - start))
|
||||
|
||||
print("S2 RGB: {} Bounds: {}".format(s2_rgb_patch.shape, bounds))
|
||||
|
||||
print("\n")
|
||||
|
||||
# Load a triplet of patches from the first three scenes of Spring - all S1 bands, NDVI S2 bands, and NDVI S2 cloudy bands
|
||||
i = 0
|
||||
start = time.time()
|
||||
for scene_id, patch_ids in spring_ids.items():
|
||||
if i >= 3:
|
||||
break
|
||||
|
||||
s1, s2, s2cloudy, bounds = sen12mscr.get_s1s2s2cloudy_triplet(Seasons.SPRING, scene_id, patch_ids[0], s1_bands=S1Bands.ALL,
|
||||
s2_bands=[S2Bands.red, S2Bands.nir1], s2cloudy_bands=[S2Bands.red, S2Bands.nir1])
|
||||
print(
|
||||
f"Scene: {scene_id}, S1: {s1.shape}, S2: {s2.shape}, cloudy S2: {s2cloudy.shape}, Bounds: {bounds}")
|
||||
i += 1
|
||||
|
||||
print("Time Taken {}s".format(time.time() - start))
|
||||
print("\n")
|
||||
|
||||
start = time.time()
|
||||
# Load all bands of all patches in a specified scene (scene 106)
|
||||
s1, s2, s2cloudy, _ = sen12mscr.get_triplets(Seasons.SPRING, 106, s1_bands=S1Bands.ALL,
|
||||
s2_bands=S2Bands.ALL, s2cloudy_bands=S2Bands.ALL)
|
||||
|
||||
print(f"Scene: 106, S1: {s1.shape}, S2: {s2.shape}, cloudy S2: {s2cloudy.shape}")
|
||||
print("Time Taken {}s".format(time.time() - start))
|
||||
Reference in New Issue
Block a user