そのうちシンギュラリティを起こすblog

強い人工知能を作ってそのうちシンギュラリティを起こします。

Numpyのndarrayを使ったPythonのコードをCythonで高速化する

 私がPythonに入門したその日に行ったのは、「C や C++ による Python の拡張」だ。MNISTのサンプルを改造して画像を8bitパソコンみたいに変換するようにしてみたのだが、とにかく遅い。

 速さは力だ!トライアンドエラーをたくさん繰り返して学ぶのは人間もニューラルネットワークも同じだ。スピードが遅ければそれだけ学習も遅くなる。だからとにかく速くしたい!

 ErlangのNIFでC言語を使って高速化するのには慣れているのでPythonでもやってみた。でもあまり速くならなかった。

 Pythonで高速化するにはもっとPythonの作法を知らなければならない。配列とかnumpyのことをよく知らずに小手先だけ高速化してもほとんど効果はないのだ。

 さて、Pillowで読み込んだイメージからタイルをnumpyのarrayにどんどん切り出していく初めて書いたコードがこれだ。

def get_tiles2c(im,ts2,off):
    wlen = im.width
    buf = np.array(im.getdata(),np.float32).reshape( im.width*im.height*Components ) / 255.0
    for ty in xrange(Margin,im.height -Margin):
        tiles = []
        for tx in xrange(Margin,im.width -Margin):

            list = array.array('f')
            for y in xrange(ts2):
                for x in xrange(ts2):
                    idx = ( x+tx-Margin + ( y+ty-Margin ) * wlen ) * Components
                    list.append( buf[ idx ] )
            tiles.append( np.frombuffer(list,dtype=np.float32) )
        yield tiles
    yield []

うそです。最初はもっとひどいコードでした。いろいろいじくりまわしてこれでも相当高速化したつもりです。
# 302.02212405204773秒

これをCythonで高速化していく。CythonはCPythonの書き間違いじゃないよ。
なんと、CythonはPythonソースコードC言語にコンバートたりPythonから読み込めるDLL,SOに変換してくれるのだ!!
このコードを何も考えずにCythonでコンパイルしてみると
# 274.2430958747864秒
少し速くなった!

これに型情報をつけて高速化していく

@cython.boundscheck(False)
@cython.wraparound(False) 
def get_tiles2(im,ts2,off):
    cdef int tx,ty
    cdef int x,y
    cdef int wlen
    cdef int idx
    cdef np.ndarray buf

    wlen = im.width
    buf = np.array(im.getdata(),np.float32).reshape( im.width*im.height*Components ) / 255.0
    for ty in xrange(Margin,im.height -Margin):
        tiles = []
        for tx in xrange(Margin,im.width -Margin):

            list = array.array('f')
            for y in xrange(ts2):
                for x in xrange(ts2):
                    idx = ( x+tx-Margin + ( y+ty-Margin ) * wlen ) * Components
                    list.append( buf[ idx ] )
            tiles.append( np.frombuffer(list,dtype=np.float32) )
        yield tiles
    yield []

# 144.34209299087524秒
単純に型を追加するだけでだいたいPythonの倍速くなる。すごい!

公式のチュートリアルによるとバッファーの型を指定することでもっと高速化することができるらしい。

cdef np.ndarray[float,ndim=1] buf

しかしこれをやるとコンパイルが通らなくなる。

Buffer types only allowed as function local variables

というエラーなのだが、ちゃんと関数内で定義しているのにさっぱり意味がわからない。
1日悩んで、どうやらyieldしているとこの定義ができないらしい。だから次のようにyieldが必要な部分とそうでない部分に分離する。

@cython.boundscheck(False)
@cython.wraparound(False) 
def get_tiles2d(im,int ts2,int off):
    #cdef np.ndarray[float,ndim=1] buf
    cdef np.ndarray buf
    cdef int ty

    buf = np.array(im.getdata(),np.float32).reshape( im.width*im.height*Components ) / 255.0
    for ty in xrange(Margin,im.height -Margin):
        yield get_tiles2_(im,ts2,off,ty,im.width,buf)
    yield []

@cython.boundscheck(False) # 関数全体で境界チェックを無効化
@cython.wraparound(False) 
cdef object get_tiles2_(im,int ts2,int off ,int ty,int wlen,np.ndarray[float,ndim=1] buf):
    cdef int x,y,tx
    cdef int idx

    list = array.array('f')
    for tx in xrange(Margin,im.width -Margin):
        for y in xrange(ts2):
            for x in xrange(ts2):
                idx = ( x+tx-Margin + ( y+ty-Margin ) * wlen ) * Components
                list.append( buf[ idx ] )
    
    return np.frombuffer(list,dtype=np.float32).reshape( im.width-Margin*2,ts2*ts2) 

# 95.15703916549683秒
キタコレ! オリジナルコードの3倍速い!

でもねぇ、C言語でばりばりに高速化したら3倍どころじゃないと思うんですよ。
もうちょっとなんとかならないもんですかねぇ。

arrayで馬鹿正直にappendするんじゃなくて、バッファーに直接書き込むようなことはできないんでしょうか。
いろいろ試してみたところ、次のようなコードでコンパイルが通った。

@cython.boundscheck(False)
@cython.wraparound(False) 
cdef object get_tiles2_d(im,int ts2,int off ,int ty,int wlen,np.ndarray[float,ndim=1] buf):
    cdef int x,y,tx
    cdef int idx
    cdef np.ndarray[float,ndim=1] list = np.empty(ts2*ts2*(im.width-Margin*2), dtype=np.float32)
    cdef int ci = 0
    cdef int yy

    for tx in xrange(Margin,im.width -Margin):
        for y in xrange(ts2):
            idx = ( tx-Margin + ( y+ty-Margin )*wlen  ) * Components
            for x in xrange(ts2):
                list[ci] = buf[ idx ]
                ci += 1
                idx += Components
    
    return list.reshape( im.width-Margin*2,ts2*ts2) 

見るがいい!かものはしペリー!
# 5.6306798458099365秒

オリジナルコードの53倍速くなったぞ!

今の私にはこれ以上高速化できなそうなので、後はこの記事を見たプロフェッショナルに任せよう。