Windless
订阅/Feed
稗田千秋(i@wind.moe)

「数论篇」矩阵快速幂

稗田千秋
Dec.24 2016 algorithm

荒废了好久的算法,写几篇关于算法的复习笔记找找手感,在测试 Mathjax 的矩阵时有感而发,便写写矩阵快速幂吧。

幂运算即指数运算,用于表示某数自乘数次,因此易得最基本计算的方法复杂度为 O(N),

快速幂全称快速幂取模,可以在 O(log₂N) 得出结果,用于快速计算某个数的n次幂,此处的n一般很大而导致 O(N) 复杂度超时,而且题目通常会给出一个额外要求使结果对数k取模。

原理用二进制相对较好理解,十进制的 10 转为二进制是 1010,即可以写作$ 10 = 2^3 + 2^1 $,先将n当前末位进行按位与运算 n&1,可以得出n是否为奇数,当 n & 1 为 true 时,就要给当前的结果乘上 a,n >>= 1 则是右移一位,因为当前末位已经无用了,等效于 n /= 2,综合起来就是当 n 可以被2整除时, 用二分的思想转化为 a = a * a,n /= 2,若某一步的 n 不能被2整除,此时 a *=2 便不能计入结果,要给当前结果单独乘上一个 a,待到之后能整除再将此时的a算入,这样便将复杂度巧妙地降到了 O(log₂N) 。

int pow(int a, int n, int k) {
  int res = 1;
  while (n) {
    if (n & 1)
      res = (res * a) % k;
    n >>= 1;
    a = (a * a) % k;
  }
  return res;
}

易推得朴素算法求矩阵相乘的复杂度为 O(N³),既然要求幂,那么自然可以考虑能否由整数的快速幂来进行相似的加速,先构造一个矩阵,同时重载运算符减少之后的工作量。

struct Matrix {
  int n;
  int num[MAX][MAX];
  Matrix operator*(Matrix &a) {
    Matrix res;
    res.n = n;
    for (int i = 1; i <= n; i++)
      for (int j = 1; j <= n; j++) {
        res.num[i][j] = 0;
        for (int k = 1; k <= n; k++)
          res.num[i][j] = (res.num[i][j] + num[i][k] * a.num[k][j]) % MOD;
      }
    return res;
  }
};

模仿上面的快速幂,只不过这里要把结果矩阵初始化为一个单位矩阵,学过线性代数的都知道单位矩阵 $ E $与任何矩阵相乘,结果不改变,那么就等效于整数快速幂中的初始值 1。.

Matrix pow(Matrix &A, int n) {
  Matrix res;
  res.n = A.n;
  for (int i = 1; i <= A.n; i++)
    for (int j = 1; j <= A.n; j++)
      res.num[i][j] = (i == j) ? 1 : 0;

  while (n) {
    if (n & 1)
      res = (res * A);
    n >>= 1;
    A = A * A;
  }
  return res;
}

这样,就实现了一个最简单的矩阵快速幂算法。

部分题解

HDU 1575

#include <cstdio>
#define MAX 100
#define MOD 9973

struct Matrix {
  int n;
  int num[MAX][MAX];
  Matrix operator*(Matrix &a) {
    Matrix res;
    res.n = n;
    for (int i = 1; i <= n; i++)
      for (int j = 1; j <= n; j++) {
        res.num[i][j] = 0;
        for (int k = 1; k <= n; k++)
          res.num[i][j] = (res.num[i][j] + num[i][k] * a.num[k][j]) % MOD;
      }
    return res;
  }
};

Matrix pow(Matrix &A, int n) {
  Matrix res;
  res.n = A.n;
  for (int i = 1; i <= A.n; i++)
    for (int j = 1; j <= A.n; j++)
      res.num[i][j] = (i == j) ? 1 : 0;

  while (n) {
    if (n & 1)
      res = (res * A);
    n >>= 1;
    A = A * A;
  }
  return res;
}

int main(int argc, char const *argv[]) {
  Matrix a;
  int t, m;
  scanf("%d
", &t);
  while (t--) {
    scanf("%d %d
", &a.n, &m);
    for (int i = 1; i <= a.n; i++) {
      for (int j = 1; j <= a.n; j++) {
        scanf("%d", &a.num[i][j]);
      }
    }
    Matrix r = pow(a, m);
    int sum = 0;
    for (int i = 1; i <= a.n; i++) {
      sum += r.num[i][i];
    }
    sum %= MOD;
    printf("%d
", sum);
  }
  return 0;
}

--END--
文章创建于 2016-12-24 01:41:21,最后更新 2016-12-24 01:41:21
Comment
尝试加载Disqus评论, 失败则会使用基础模式.
    • play_arrow

    About this site

    version:1.02 Alpha
    博客主题: Lime
    联系方式: i@wind.moe
    写作语言: zh_CN & en_US
    博客遵循 CC BY-NC-SA 4.0许可进行创作

    此外,本博客会基于访客的Request Headers记录部分匿名数据用于统计(Logger的源码见Github),包含Referer, User-Agent & IP Address.个人绝不会主动将数据泄露给第三方