pix2pix & cyclegan

به نام خدا

کلاس نوت – محمد حسین نیکی ملکی
مقدمه ای بر cyclegan & pix2pix

ساختار شبکه ای pix2pix از ساختار شبکه های شرطی GAN برای ساخت دیتا ها و تبدیل و generate کردن آن هااستفاده میکند .
فرض کنید تصاویری ورودی شبکه ما باینری شده و انتظار خروجی شبکه امان هم از همان جنس دیتا رنگی rgb باشد .طبیعتا انتظار داریم که generator در اینجا یاد بگیرد که چطور رمگ امیزی کند .و همچنین discriminator هم تلاش میکند تا تفاوت بین رنگی که generator تولید کرده و رنگ اصلی به نظر صحیح که در بین دیتاست ها بوده را محاسبه کند .

ساختار generator بصورت “encoder-decoder” است و pix2pix هم تقریبا مانند همین (در شکل پایین هم میبینید ) است .

مقادیر دایمنشن های هر لایه در کنار آن نوشته شده است .در اینجا ورودی شبکه ما یک 256*256*3 است و پس از اعمال تغییرات دوباره به همان شیپ ابتدایی در خروجی باز میگردد .
در اینجا generator ورودی هارا میگیرد وبا انکودر ها مقادیرش را کم میکند ( لازم به یاداوری است که انکودر محصولی از کانولوشن و اکتیویشن فانکشن مربوطه است .) همچنین در ادامه ی مسیر هم از دیکودر برای بازگشت ز مسیر استفاده میکنیم .
نکته : برای بهبود \رفورمنس شبکه و سرعت پردازش میتونیم از شبکه unet هم به جای encoder-decoder استفاده کنیم . شبکه unet در واقع همون encoder-decoder است اما بصورت موازی هر لایه encoder را به لایه decoder مربوطه اش متصل میکند و سودش در این است که ممکن است به برخی از لایه ها احتیاجی نداشته باشیم و از آنها ب\ریم و گذر نکنیم .
Descriminator :

کار descriminator این است که دو عکس یکی از ورودی و یکی از خروجی یا عکس هدف میگیرد و تصمیم میگیرد که ایا عکس دوم توسط generator ساخته شده است یا خیر .

ساختار DISCRIMINATOR شبیه به ENCODER-DECODER به نظر میرسد اما فرق دارد .

خروجی اش یک عکس ۳۰ در ۳۰ باینری است که مشخص میکند چقدر قسمت های مختلف عکس دوم به اول شبیه هستند . در پیاده سازی PIX2PIX هر پیکسل ۳۰ در ۳۰ به یک قسمت ۷۰ در ۷۰ عکس ورودی مپ میشود .

How to train :

برای ترین این شبکه باید هم discriminator و هم generator را اموزش بدهیم .
برای اموزس دیسکریمینیتور ابتدا جنریتور خروجی را تولید میکند و دیسکریمینیتور هر دفه خروجی هارا با ورودی ها مقایسه میکند و یاد میگیرد و میتواند حدس بزند که چقدر آنها طبیعی هستند . پس وزن های دیسکریمینیتور براسا ورودی و خروجی های شبکه ی جنریتور ما تعیین میگردد .
همینطور وزن های جنریتور هم براساس خروجی های دیسکریمینیتور اپدیت میشود هر دفعه .

cyclegans :

pipx2pix میتواند خروجی های بینظیری را تولید کند . چالش اصلی در قسمت ترین است . قسمت لیبل زدن و دو عکسی برای عملیات ترینینگ باید به ورودی شبکه بدهیم به عنوان x و y .
که ممکن است کار را برا ی ما بسیار سخت و گاها غیر ممکن کند.
به خاطر همین از سایکل گن ها استفاده میکنیم .
نکته اصلی انها این است که میتوانند بر اساس معماری pix2pix هم پیاده سازی بشوند اما به شما اجازه دهند که به دو مجموعه از تصاویر اشاره کنید و فراخوانی کنید .
در اینجا بعنوان مثال فرض کنید دو دسته عکس از ساحل دارید یکی در زمان صبح و دیگری در عصر هنگام غروب . سایکل گن میتواند یادبگیرد که چگونه عکس های مربوط به صبح را به عصر تبدیل کند و برعکس بدون آنکه نیاز باشد تا بصورت یک در یک مقابل هم قرارشان دهیم . دلیل اصلی که سایکل را تا این حد قدرتمند کرده پیاده سازی این فرضیه است که بطور کامل تبدیل رفت و برگشتی صورت میپذیرد که باعث میشود هردو جنریتور هم بطور همزمان تقویت شوند .

Loss function :

قدرت سایکل گن در این است که loss function ای برای هر سایکل رفت و برگشت ست میکند و آن را به عنوان یک optimizer اضافی در مسله قرار میدهد .
در این قسمت نحوه ی محاسبه ی loss function برای generator را میبینیم :


g_loss_G_disc = tf.reduce_mean((discY_fake — tf.ones_like(discY_fake)) ** 2)
g_loss_F_dicr = tf.reduce_mean((discX_fake — tf.ones_like(discX_fake)) ** 2)


g_loss_G_cycle = tf.reduce_mean(tf.abs(real_X — genF_back)) + tf.reduce_mean(tf.abs(real_Y — genG_back))
g_loss_F_cycle = tf.reduce_mean(tf.abs(real_X — genF_back)) + tf.reduce_mean(tf.abs(real_Y — genG_back))

در این قسمت نحوه ی محاسبه ی loss function برای discriminator را بررسی میکنیم :


DY_loss_real = tf.reduce_mean((DY — tf.ones_like(DY))** 2)
DY_loss_fake = tf.reduce_mean((DY_fake_sample — tf.zeros_like(DY_fake_sample)) ** 2)
DY_loss = (DY_loss_real + DY_loss_fake) / 2
DX_loss_real = tf.reduce_mean((DX — tf.ones_like(DX)) ** 2)
DX_loss_fake = tf.reduce_mean((DX_fake_sample — tf.zeros_like(DX_fake_sample)) ** 2)
DX_loss = (DX_loss_real + DX_loss_fake) / 2