《TensorFlow深度学习算法原理与编程实战》
作者:蒋子阳
在所有笔记中搜索你感兴趣的关键词!
作者:蒋子阳
先,代码引用自https://github.com/wiseodd/generative-models
感谢这位网友的代码支持。
每个月总有30天不想看论文,所以直接看源码或许是一个好办法。因为有些时候它的改动就那么一点点。而论文却要用晦涩难懂的语言证明上十几页。
上边这个链接中给出了很多GAN和VAE以及各种变体的源码,并且写得清晰易懂,再次感谢这位网友的贡献。
ConditionalGAN顾名思义是条件GAN,就是给GAN增加一个条件。具体是怎么回事呢?看代码:
这段代码使用mnist数据集,来生成手写数字。以下代码可以直接正确运行。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
mb_size = 64
Z_dim = 100
X_dim = mnist.train.images.shape[1]
y_dim = mnist.train.labels.shape[1]
h_dim = 128
def xavier_init(size):
in_dim = size[0]
xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
return tf.random_normal(shape=size, stddev=xavier_stddev)
""" Discriminator Net model """
X = tf.placeholder(tf.float32, shape=[None, 784])
y = tf.placeholder(tf.float32, shape=[None, y_dim])
D_W1 = tf.Variable(xavier_init([X_dim + y_dim, h_dim]))
D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
D_W2 = tf.Variable(xavier_init([h_dim, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))
theta_D = [D_W1, D_W2, D_b1, D_b2]
以上与普通的GAN没有区别,从下边开始可以看到discriminator除了输入原来的x,还输入了一个y。这个y就是我们所说的condition。接下来的generator也一样,多了一个y。
def discriminator(x, y):
inputs = tf.concat(axis=1, values=[x, y])
D_h1 = tf.nn.relu(tf.matmul(inputs, D_W1) + D_b1)
D_logit = tf.matmul(D_h1, D_W2) + D_b2
D_prob = tf.nn.sigmoid(D_logit)
return D_prob, D_logit
""" Generator Net model """
Z = tf.placeholder(tf.float32, shape=[None, Z_dim])
G_W1 = tf.Variable(xavier_init([Z_dim + y_dim, h_dim]))
G_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
G_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
G_b2 = tf.Variable(tf.zeros(shape=[X_dim]))
theta_G = [G_W1, G_W2, G_b1, G_b2]
def generator(z, y):
inputs = tf.concat(axis=1, values=[z, y])
G_h1 = tf.nn.relu(tf.matmul(inputs, G_W1) + G_b1)
G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
G_prob = tf.nn.sigmoid(G_log_prob)
return G_prob
def sample_Z(m, n):
return np.random.uniform(-1., 1., size=[m, n])
def plot(samples):
fig = plt.figure(figsize=(4, 4))
gs = gridspec.GridSpec(4, 4)
gs.update(wspace=0.05, hspace=0.05)
for i, sample in enumerate(samples):
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
return fig
G_sample = generator(Z, y)
D_real, D_logit_real = discriminator(X, y)
D_fake, D_logit_fake = discriminator(G_sample, y)
可以看出来这边的discriminator和generator都是多输入了一个条件y。
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)
loss还是没有变化。
sess = tf.Session()
sess.run(tf.global_variables_initializer())
if not os.path.exists('out/'):
os.makedirs('out/')
i = 0
for it in range(1000000):
if it % 1000 == 0:
n_sample = 16
Z_sample = sample_Z(n_sample, Z_dim)
y_sample = np.zeros(shape=[n_sample, y_dim])
y_sample[:, 7] = 1
samples = sess.run(G_sample, feed_dict={Z: Z_sample, y:y_sample})
fig = plot(samples)
plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
i += 1
plt.close(fig)
X_mb, y_mb = mnist.train.next_batch(mb_size)
Z_sample = sample_Z(mb_size, Z_dim)
_, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: Z_sample, y:y_mb})
_, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: Z_sample, y:y_mb})
在训练中,输入的y是输入的x所一一对应的真实标签。
在生成的过程中,我们想生成什么就输入对应的标签。
例如以上代码中我们输入的是7的标签,也就是one-hot形式的label中第7位位1,其他位为0。
if it % 1000 == 0:
print('Iter: {}'.format(it))
print('D loss: {:.4}'. format(D_loss_curr))
print('G_loss: {:.4}'.format(G_loss_curr))
print()
到这里就结束了,这么一点代码就可以生成“我想要的(也就是附加了条件的)”逼真的手写数字,是不是很简单呢?
参考资料: https://blog.csdn.net/yinruiyang94/article/details/78122629
评论 (0)