diff --git a/fastai/plots.py b/fastai/plots.py index efa989018..4b334008d 100755 --- a/fastai/plots.py +++ b/fastai/plots.py @@ -39,13 +39,13 @@ def plots_from_files(imspaths, figsize=(10,5), rows=1, titles=None, maintitle=No plt.imshow(img) -def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues): +def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues, figsize=None): """ This function prints and plots the confusion matrix. Normalization can be applied by setting `normalize=True`. (This function is copied from the scikit docs.) """ - plt.figure() + plt.figure(figsize=figsize) plt.imshow(cm, interpolation='nearest', cmap=cmap) plt.title(title) plt.colorbar()