diff --git a/.gitignore b/.gitignore index d2ad6f7..3998024 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,7 @@ *.png *.jpg itchat.pkl +*.tgz +English +train +valid diff --git a/Preprocessing dataset.html b/Preprocessing dataset.html new file mode 100644 index 0000000..41436c2 --- /dev/null +++ b/Preprocessing dataset.html @@ -0,0 +1,12303 @@ + + +
+数据集来自 http://www.ee.surrey.ac.uk/CVSSP/demos/chars74k/
+我们下载的是EnglishFnt.tgz,是印刷体数字加大小写字母。
+ + +import requests
+from tqdm import tqdm
+import os
+
+fileurl = 'http://www.ee.surrey.ac.uk/CVSSP/demos/chars74k/EnglishFnt.tgz'
+filename = 'EnglishFnt.tgz'
+if not os.path.exists(filename):
+ r = requests.get(fileurl, stream=True)
+ with open(filename, 'wb') as f:
+ for chunk in tqdm(r.iter_content(1024), unit='KB', total=int(r.headers['Content-Length'])/1024):
+ f.write(chunk)
+
import tarfile
+import shutil
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+def rmdir(path):
+ if os.path.exists(path):
+ shutil.rmtree(path)
+
+with tarfile.open(filename, 'r') as tfile:
+ print 'loading'
+ members = tfile.getmembers()
+ for member in tqdm(members):
+ if tarfile.TarInfo.isdir(member):
+ mkdir(member.name)
+ continue
+ with open(member.name, 'wb') as f:
+ f.write(tfile.extractfile(member).read())
+
notnumdir = 'English/Fnt/Sample011/'
+for i in tqdm(range(12, 63)):
+ path = 'English/Fnt/Sample%03d/' % i
+ for filename in os.listdir(path):
+ os.rename(path+filename, notnumdir+filename)
+ os.rmdir(path)
+
import cv2
+import numpy as np
+
+def resize(rawimg): # resize img to 28*28
+ fx = 28.0 / rawimg.shape[0]
+ fy = 28.0 / rawimg.shape[1]
+ fx = fy = min(fx, fy)
+ img = cv2.resize(rawimg, None, fx=fx, fy=fy, interpolation=cv2.INTER_CUBIC)
+ outimg = np.ones((28, 28), dtype=np.uint8) * 255
+ w = img.shape[1]
+ h = img.shape[0]
+ x = (28 - w) / 2
+ y = (28 - h) / 2
+ outimg[y:y+h, x:x+w] = img
+ return outimg
+
+def convert(imgpath):
+ img = cv2.imread(imgpath)
+ gray = cv2.imread(imgpath, cv2.IMREAD_GRAYSCALE)
+ bw = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 25, 25)
+ img2, ctrs, hier = cv2.findContours(bw.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+ rects = [cv2.boundingRect(ctr) for ctr in ctrs]
+ x, y, w, h = rects[-1]
+ roi = gray[y:y+h, x:x+w]
+ return resize(roi)
+
import matplotlib.pyplot as plt
+
+%matplotlib inline
+
+imgpath = 'English/Fnt/Sample001/img001-00001.png'
+img = cv2.imread(imgpath)
+rsz = convert(imgpath)
+
+plt.subplot(1, 2, 1)
+plt.imshow(img, cmap='gray')
+plt.subplot(1, 2, 2)
+plt.imshow(rsz, cmap='gray')
+
rmdir('train')
+
+for i in range(11):
+ path = 'English/Fnt/Sample%03d/' % (i+1)
+ trainpath = 'train/%d/' % i
+ mkdir(trainpath)
+ for filename in tqdm(os.listdir(path), desc=trainpath):
+ try:
+ cv2.imwrite(trainpath + filename, convert(path + filename))
+ except:
+ pass
+
from sklearn.model_selection import train_test_split
+for i in range(11):
+ trainpath = 'train/%d/' % i
+ validpath = 'valid/%d/' % i
+ mkdir(validpath)
+ imgs = os.listdir(trainpath)
+ trainimgs, validimgs = train_test_split(imgs, test_size=0.1)
+ for filename in validimgs:
+ os.rename(trainpath+filename, validpath+filename)
+
+
from keras.models import Sequential
+from keras.layers.core import Dense, Dropout, Activation, Flatten
+from keras.preprocessing import image
+
+import numpy as np
+import tensorflow as tf
+import matplotlib.pyplot as plt
+import cv2
+
+%matplotlib inline
+%config InlineBackend.figure_format = 'retina'
+
from keras.preprocessing.image import ImageDataGenerator
+
+train_datagen = ImageDataGenerator(rescale=1.0/255)
+train_generator = train_datagen.flow_from_directory(
+ 'train', # this is the target directory
+ classes=map(str, range(11)),
+ color_mode='grayscale',
+ target_size=(28,28),
+ batch_size=128,
+ class_mode='categorical')
+
+validation_datagen = ImageDataGenerator(rescale=1.0/255)
+validation_generator = validation_datagen.flow_from_directory(
+ 'valid', # this is the target directory
+ classes=map(str, range(11)),
+ color_mode='grayscale',
+ target_size=(28,28),
+ batch_size=128,
+ class_mode='categorical')
+
x, y = train_generator.next()
+
+for i, (img, label) in enumerate(zip(x, y)[:24]):
+ plt.subplot(3, 8, i+1)
+ plt.title(str(np.argmax(label)))
+ plt.axis('off')
+ plt.imshow(img[:,:,0], interpolation="nearest", cmap='gray')
+
我们的模型结构很简单,784->512->512->11
+ +model = Sequential()
+model.add(Flatten(input_shape=(28, 28, 1)))
+model.add(Dense(512))
+model.add(Activation('relu'))
+model.add(Dropout(0.2))
+model.add(Dense(512))
+model.add(Activation('relu'))
+model.add(Dropout(0.2))
+model.add(Dense(11))
+model.add(Activation('softmax'))
+
+model.compile(loss='categorical_crossentropy',
+ optimizer='adadelta',
+ metrics=['accuracy'])
+
+model.summary()
+
model.fit_generator(
+ train_generator,
+ samples_per_epoch=51200,
+ nb_epoch=10,
+ validation_data=validation_generator,
+ nb_val_samples=5120)
+
with open('model.json', 'w') as f:
+ f.write(model.to_json())
+model.save_weights('model.h5', overwrite=True)
+
def resize(rawimg): # resize img to 28*28
+ fx = 28.0 / rawimg.shape[0]
+ fy = 28.0 / rawimg.shape[1]
+ fx = fy = min(fx, fy)
+ img = cv2.resize(rawimg, None, fx=fx, fy=fy, interpolation=cv2.INTER_CUBIC)
+ outimg = np.ones((28, 28), dtype=np.uint8) * 255
+ w = img.shape[1]
+ h = img.shape[0]
+ x = (28 - w) / 2
+ y = (28 - h) / 2
+ outimg[y:y+h, x:x+w] = img
+ return outimg
+
+
+def convert(imgpath): # read digits
+ img = cv2.imread(imgpath)
+ gray = cv2.imread(imgpath, cv2.IMREAD_GRAYSCALE)
+ bw = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 25, 25)
+ img2, ctrs, hier = cv2.findContours(bw.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+ rects = [cv2.boundingRect(ctr) for ctr in ctrs]
+
+ for rect in rects:
+ x, y, w, h = rect
+ roi = gray[y:y+h, x:x+w]
+ hw = float(h) / w
+ if (w < 200) & (h < 200) & (h > 10) & (w > 10) & (1.1 < hw) & (hw < 5):
+ res = resize(roi)
+ res = np.resize(res, (1, 28, 28, 1))
+
+ predictions = model.predict(res)
+ predictions = np.argmax(predictions)
+ if predictions != 10:
+ cv2.rectangle(img, (x, y), (x+w, y+h), (0, 255, 0), 1)
+ cv2.putText(img, '{:.0f}'.format(predictions), (x, y), cv2.FONT_HERSHEY_DUPLEX, h/25.0, (255, 0, 0))
+ return img
+
+plt.imshow(convert('test.png')[:,:,::-1])
+
+