Умножение чисел с помощью быстрого преоразования Фурье



Необходимо перемножить два больших числа


Дискретным преобразованием Фурье многочлена A (x) n-ной степени называется вектор
{b0, b1, b2, ... , bn - 1} = {A (wn0), A (wn1), A (wn2), ... , A (wnn - 1)},
где wn - комплексный корень n-ной степени из 1, не равный 1.
Если бы мы умели находить быстро прямое и обратное ДПФ для многочлена, то
ДПФ (A(x) * B(x)) = ДПФ (A) * ДПФ (B)
A(x) * B(x) = Обратное_ДПФ (ДПФ (A) * ДПФ (B)),
где под умножением ДПФ подразумевается произведение соответствующих компонент векторов.

Предположим, что n является степенью двойки (если это не так, добавим нули).

Идея быстрого ДПФ заключается в разбиении исходного многочлена степени n на два многочлена степени n / 2, нахождении ДПФ для них и получении ДПФ для всего многочлена.
Рассмотрим вышесказанное подробнее.
Разделим многочлен A (x) = a0x0 + a1x1 + ... + an-1xn-1, на два многочлена
A0 (x) = a0x0 + a2x1 + a4x2 + ... + an-2x(n-2)/2
A1 (x) = a1x0 + a3x1 + a5x2 + ... + an-1x(n-2)/2
Тогда A(x) = A0 (x2) + x * A1 (x2).
Заметим, что среди всех чисел (wni)2 (0 <= i < n), только n / 2 различных. Поэтому, ДПФ A0 и A1 будут (n / 2)-элементными.
А так как ДПФ A0 и A1 состоит из n / 2 элементов, то комплексным корнем из 1 там будет корень степени n / 2.
Значит, A(wnk) = A0(wn2k) + wnk * A1(wn2k) = A0(wn/2k) + wnk * A1(wn/2k), где 0 <= k < n / 2
и A(wnk+n/2) = A0(wn2(k+n/2)) + wnk+n/2 * A1(wn2(k+n/2)) = A0(wn/2k+n/2) + wnk+n/2 * A1(wn/2k+n/2) = A0(wn/2k) - wnk * A1(wn/2k).
Мы использовали свойства комплексных чисел: различных корней степени n всего n; wnk+n/2 = wnk * e2*Pi/n * n/2 = wnk * ePi = - wnk.

Получаем рекурсивный алгоритм:
ДПФ одного элемента равно этому элементу
для нахождения ДПФ A разбиваем коэффициенты A на чётные и нечётные, получаем два многочлена A0 и A1, находим ДПФ для них, находим ДПФ A:
bk = b0k + wnk * b1k
bk + n/2 = b0k - wnk * b1k
для 0 <= k < n / 2.

Можно доказать, что обратное ДПФ можно находить тем же самым способом, что и прямое ДПФ, только за комплексный корень из 1 нужно брать симметричный относительно вещественной оси тому, что мы брали, и результаты (коэффициенты многочлена) необходимо будет поделить на n.

Можно обойтись без рекурсии. В общих словах, используя рекурсию мы идём "сверху вниз", если пойти "снизу вверх", то повысится быстродействие и значительно уменьшатся затраты на память.
Согласно нашим рассуждениям, рекурсия будет вызываться вот так (пример для n = 8):
(0, 1, 2, 3, 4, 5, 6, 7)
(0, 2, 4, 6)(1, 3, 5, 7)
(0, 4)(2, 6)(1, 5)(3, 7)
Если мы сначала возьмём коэффициенты многочлена и переставим их согласно номерам в примере (сначала коэффициент с индексом 0, потом с индексом 4), то для вычисления ДПФ всего многочлена нужно будет пройти по каждой строке снизу вверх (всего строк log2n) и объединить в ДПФ каждую пару наборов коэффициентов.
Любопытно, что числа в последней строке отсортированы по перевернутым битовым записям:
0 1 0 1 0 1 0 1
0 0 1 1 0 0 1 1
0 0 0 0 1 1 1 1
0 4 2 6 1 5 3 7
0 1 2 3 4 5 6 7
Значит, мы можем легко построить последнюю строку индексов. Поменять местами коэффициенты многочлена согласно этой строке. А потом, поднимаясь от нижней строки к верхней и объединяя соседние наборы коэффициентов в ДПФ, получим ДПФ всего многочлена.

Для умножения двух чисел необходимо найти ДПФ каждого числа. Перемножить элементы обоих ДПФ, которые имеют одинаковый индекс (или степень корня из 1). Найти обратное ДПФ от полученной ДПФ. Найденный многочлен будет результатом произведения исходных чисел. Но так как это многочлен, а не число, то там не соблюдены переносы, если число > 9. Пройдёмся один раз по найденному вектору, выполним нужные межразрядные переносы и получим результат.
Заметим, что так как результат в два раза больше по длине, чем каждый множитель, то исходные числа до нахождения ДПФ необходимо дополнить нулями, чтобы длины множителей были больше либо равны длине результата и равнялись степени двойки.

Сложность алгоритма O(N * log N).



Листинг С++

#include <complex>
#include <vector>
#include <cmath>
using namespace std;

#define comp complex < double >
#define vcomp vector < comp >

#define sint int
#define vsint vector < sint >
#define sz size ()
#define at(v, i) (0 <= (i) && (i) < (v).size () ? v[i] : 0)
#define erase_null(v) while ((v).size () > 1 && (v).back () == 0) (v).pop_back ();
#define next_mod(v, i)     if (0 <= (i) + 1 && (i) + 1 < (v).size ()) \
                        {\
                            v[(i) + 1] += v[(i)] / 10;\
                            v[(i)] %= 10; \
                        }
                       


int reverse_bits (int a, int n)
{
    int res = 0;
    for (int i = 0; i < n; ++ i)
    {
        res <<= 1;
        res |= a & 1;
        a >>= 1;
    }
    return res;
}

int logn (int a)
{
    int cnt = 0;
    do
    {
        ++ cnt;
        a >>= 1;
    }
    while (a);
    return cnt;
}

int topow2 (int a)
{
    for (int i = 0; i < 32; ++ i)
        if ((1 << i) >= a)
            return 1 << i;
    return - 1;
}

void swap_revbits (vcomp & a)
{
    int l = logn (a.size ()) - 1;
    vcomp res (a.size ());
    for (int i = 0; i < a.size  (); ++ i)
        res[i] = a[reverse_bits (i, l)];
    a = res;
}

void fft (vcomp & a, bool back)
{
    swap_revbits (a);
    int n = a.size ();
    double t = (back ? - 1 : 1);
    
    for (int m = 2; m <= n; m *= 2)
    {
        comp wm (cos (t * 2 * M_PI / (double)m), sin (t * 2 * M_PI / (double)m));
        
        for (int k = 0; k < n; k += m)
        {
            comp w (1);
            for (int j = 0; j < m / 2; ++ j)
            {
                comp a0 = a[k + j];
                comp w_a1 = w * a[k + j + m / 2];
                a[k + j] = a0 + w_a1;
                a[k + j + m / 2] = a0 - w_a1;
                
                if (back)
                {
                    a[k + j] /= 2.0;
                    a[k + j + m / 2] /= 2.0;
                }
                w *= wm;
            }
        }
    }
}

vsint fft_mul (vsint a, vsint b)
{
    int n = topow2 (max (a.size (), b.size ())) * 2;
    a.resize (n);
    b.resize (n);
    
    vcomp ac (a.begin (), a.end ());
    vcomp bc (b.begin (), b.end ());
    
    fft (ac, false);
    fft (bc, false);
    
    vcomp cc (n);
    for (int i = 0; i < n; ++ i)
        cc[i] = ac[i] * bc[i];
    fft (cc, true);
    
    vsint c (n, 0);
    for (int i = 0; i < n; ++ i)
        c[i] += (int)(cc[i].real () + 0.5);
    for (int i = 0; i < n - 1; ++ i)
        if (c[i] > 9)
        {
            c[i + 1] += c[i] / 10;
            c[i] %= 10;
        }
    while (c[c.size () - 1] > 9)
    {
        c.push_back (c[c.size () - 1] / 10);
        c[c.size () - 2] %= 10;
    }
    
    while (c.back () == 0 && c.size () > 1)
        c.pop_back ();
    
    return c;
}

17:48
14.07.2010


По всем вопросам обращаться: rumterg@gmail.com