Strassen算法笔记

1.算法推导

矩阵通用的乘法如下面演示:

Strassen算法笔记_第1张图片

即使对于方阵也是类似的计算过程:

Strassen算法笔记_第2张图片

不过,好在有德国数学家Volker Strassen的研究,提出了Strassen算法。该算法的简单介绍如下:
不妨设有矩阵A、B,矩阵C=A*B,如下:

用矩阵乘法可以得出下面的式子:

Strassen算法笔记_第3张图片

现在定义7个新矩阵(读者可以试着思考下它们的由来):

Strassen算法笔记_第4张图片

最后,通过化简得到:

Strassen算法笔记_第5张图片

Strassen算法仅仅比通用矩阵乘法好一点点儿,从运算上看,减少了一次乘运算。通用矩阵乘法算法的时间复杂度是,而Strassen算法的时间复杂度是。但是,随着矩阵行列数的增大,Strassen算法得到的收益也越大。

Strassen算法笔记_第6张图片

从图中不难发现,在矩阵行列数比较小的时候(<45),通用矩阵乘法算法是更好选择。但是,当行列数大于100时,Strassen就要优越很多。

在图论中有邻接矩阵,因此当处理某些图论算法时,可以用到Strassen算法处理矩阵相乘。

2.算法实现

Python代码:

from optparse import OptionParser
from math import ceil, log

def read(filename):
    lines = open(filename, 'r').read().splitlines()
    A = []
    B = []
    matrix = A
    for line in lines:
        if line != "":
            matrix.append(list(map(int, line.split("\t"))))
        else:
            matrix = B
    return A, B

def printMatrix(matrix):
    for line in matrix:
        print("\t".join(map(str,line)))

def ikjMatrixProduct(A, B):
    n = len(A)
    C = [[0 for i in range(n)] for j in range(n)]
    for i in range(n):
        for k in range(n):
            for j in range(n):
                C[i][j] += A[i][k] * B[k][j]
    return C

def add(A, B):
    n = len(A)
    C = [[0 for j in range(0, n)] for i in range(0, n)]
    for i in range(0, n):
        for j in range(0, n):
            C[i][j] = A[i][j] + B[i][j]
    return C

def subtract(A, B):
    n = len(A)
    C = [[0 for j in range(0, n)] for i in range(0, n)]
    for i in range(0, n):
        for j in range(0, n):
            C[i][j] = A[i][j] - B[i][j]
    return C

def strassenR(A, B):
    """ 
        Implementation of the strassen algorithm.
    """
    n = len(A)

    if n <= LEAF_SIZE:
        return ikjMatrixProduct(A, B)
    else:
        # initializing the new sub-matrices
        newSize = int(n/2)
#         print(newSize)
        a11 = [[0 for j in range(0, newSize)] for i in range(0, newSize)]
        a12 = [[0 for j in range(0, newSize)] for i in range(0, newSize)]
        a21 = [[0 for j in range(0, newSize)] for i in range(0, newSize)]
        a22 = [[0 for j in range(0, newSize)] for i in range(0, newSize)]

        b11 = [[0 for j in range(0, newSize)] for i in range(0, newSize)]
        b12 = [[0 for j in range(0, newSize)] for i in range(0, newSize)]
        b21 = [[0 for j in range(0, newSize)] for i in range(0, newSize)]
        b22 = [[0 for j in range(0, newSize)] for i in range(0, newSize)]

        aResult = [[0 for j in range(0, newSize)] for i in range(0, newSize)]
        bResult = [[0 for j in range(0, newSize)] for i in range(0, newSize)]

        # dividing the matrices in 4 sub-matrices:
        for i in range(0, newSize):
            for j in range(0, newSize):
                a11[i][j] = A[i][j]            # top left
                a12[i][j] = A[i][j + newSize]    # top right
                a21[i][j] = A[i + newSize][j]    # bottom left
                a22[i][j] = A[i + newSize][j + newSize] # bottom right
 
                b11[i][j] = B[i][j]            # top left
                b12[i][j] = B[i][j + newSize]    # top right
                b21[i][j] = B[i + newSize][j]    # bottom left
                b22[i][j] = B[i + newSize][j + newSize] # bottom right

        # Calculating p1 to p7:
        aResult = add(a11, a22)
        bResult = add(b11, b22)
        p1 = strassenR(aResult, bResult) # p1 = (a11+a22) * (b11+b22)
 
        aResult = add(a21, a22)      # a21 + a22
        p2 = strassenR(aResult, b11)  # p2 = (a21+a22) * (b11)
 
        bResult = subtract(b12, b22) # b12 - b22
        p3 = strassenR(a11, bResult)  # p3 = (a11) * (b12 - b22)
 
        bResult = subtract(b21, b11) # b21 - b11
        p4 =strassenR(a22, bResult)   # p4 = (a22) * (b21 - b11)
 
        aResult = add(a11, a12)      # a11 + a12
        p5 = strassenR(aResult, b22)  # p5 = (a11+a12) * (b22)   
 
        aResult = subtract(a21, a11) # a21 - a11
        bResult = add(b11, b12)      # b11 + b12
        p6 = strassenR(aResult, bResult) # p6 = (a21-a11) * (b11+b12)
 
        aResult = subtract(a12, a22) # a12 - a22
        bResult = add(b21, b22)      # b21 + b22
        p7 = strassenR(aResult, bResult) # p7 = (a12-a22) * (b21+b22)

        # calculating c21, c21, c11 e c22:
        c12 = add(p3, p5) # c12 = p3 + p5
        c21 = add(p2, p4)  # c21 = p2 + p4
 
        aResult = add(p1, p4) # p1 + p4
        bResult = add(aResult, p7) # p1 + p4 + p7
        c11 = subtract(bResult, p5) # c11 = p1 + p4 - p5 + p7
 
        aResult = add(p1, p3) # p1 + p3
        bResult = add(aResult, p6) # p1 + p3 + p6
        c22 = subtract(bResult, p2) # c22 = p1 + p3 - p2 + p6
 
        # Grouping the results obtained in a single matrix:
        C = [[0 for j in range(0, n)] for i in range(0, n)]
        for i in range(0, newSize):
            for j in range(0, newSize):
                C[i][j] = c11[i][j]
                C[i][j + newSize] = c12[i][j]
                C[i + newSize][j] = c21[i][j]
                C[i + newSize][j + newSize] = c22[i][j]
        return C

def strassen(A, B):
    assert type(A) == list and type(B) == list
    assert len(A) == len(A[0]) == len(B) == len(B[0])

    # Make the matrices bigger so that you can apply the strassen
    # algorithm recursively without having to deal with odd
    # matrix sizes
    nextPowerOfTwo = lambda n: 2**int(ceil(log(n,2)))
    n = len(A)
    m = nextPowerOfTwo(n)
    print(m)
    APrep = [[0 for i in range(m)] for j in range(m)]
    BPrep = [[0 for i in range(m)] for j in range(m)]
    for i in range(n):
        for j in range(n):
            APrep[i][j] = A[i][j]
            BPrep[i][j] = B[i][j]
    CPrep = strassenR(APrep, BPrep)
#     print(CPrep)
    C = [[0 for i in range(n)] for j in range(n)]
    for i in range(n):
        for j in range(n):
            C[i][j] = CPrep[i][j]
    return C
    
if __name__ == "__main__":
    parser = OptionParser()
    parser.add_option("-i", dest="filename", default="2000.in",
         help="input file with two matrices", metavar="FILE")
    parser.add_option("-l", dest="LEAF_SIZE", default="8",
         help="when do you start using ikj", metavar="LEAF_SIZE")
    (options, args) = parser.parse_args()
 
    LEAF_SIZE = options.LEAF_SIZE
#     LEAF_SIZE = 10
    A, B = read(options.filename)
#     print(A)
#     A = [[1,2,3],[4,5,6],[7,8,9]]
#     B = [[1,1,1],[1,1,1],[1,1,1]]
    C = strassen(A, B)
    printMatrix(C)


参考文献:

http://www.ituring.com.cn/article/17978

GEMMW: a portable level 3 BLAS Winograd variant of Strassen's matrix-matrix multiply algorithm

你可能感兴趣的:(算法,Strassen)