ABC275 D - Yet Another Recursive Function の考え方を書きました。

問題の概要

# Python

def f(k):
    if k==0:
        return 1
    else:
        return f(k//2) + f(k//3)
# Ruby

def f(k)
  if k==0
    1
  else
    f(k/2) + f(k/3)
  end
end

について、 f(N) を求めます。
与えられるNの範囲は、 0以上 かつ 10^18以下 です。

愚直な実装をすると何が起きるか

上のコードにNの入力とf(N)の出力を追加しました。

# Python

N = int(input())

def f(k):
    if k==0:
        return 1
    else:
        # "k//N" は、Nで割って小数点以下切り捨て の意味
        return f(k//2) + f(k//3)

print(f(N))
# Ruby

N = gets.to_i

def f(k)
  if k==0
    1
  else
    f(k/2) + f(k/3)
  end
end

puts f(N)

これで、正しい答えを求められるようになりました。

しかし、このまま提出すると TLE(Time Limit Exceeded / 実行時間制限超過) になってしまいます。 この問題の実行時間制限 “2秒以内” で処理が終わらない場合がある、ということです。

入出力は整数たった1つですから、重いのは(実行時間のほとんどは) f() の部分と考えてよいでしょう。

このようなとき、まず”ループの実行回数”を考えると、高速化の手がかりになります。
しかし、このコードにはfor文がありません。こまった。

無駄な処理はどこか

実動作から”無駄”を探してみることにします。

何を計算中か、都度出力するようにしてみました↓

# Python

def f(k):
    if k==0:
        return 1
    else:
        print(f'f({k})を計算中...')
        return f(k//2) + f(k//3)
# Ruby

def f(k)
  if k==0
    1
  else
    puts("f(#{k})を計算中...")
    f(k/2) + f(k/3)
  end
end

さらに、見やすさのため、再帰の深さに応じてインデントを追加しました↓

# Python

def f(k, depth=0):
    if k==0:
        return 1
    else:
        print('   '*depth, f'f({k})を計算中...')
        return f(k//2, depth+1) + f(k//3, depth+1)
# Ruby

def f(k, depth=0)
  if k==0
    1
  else
    puts('   '*depth + "f(#{k})を計算中...")
    f(k/2, depth+1) + f(k/3, depth+1)
  end
end

Pythonインタプリタで動かしてみると、以下の出力が得られました。

>>> f(36)
 f(36)を計算中...
    f(18)を計算中...
       f(9)を計算中...
          f(4)を計算中...
             f(2)を計算中...
                f(1)を計算中...
             f(1)を計算中...
          f(3)を計算中...
             f(1)を計算中...
             f(1)を計算中...
       f(6)を計算中...
          f(3)を計算中...
             f(1)を計算中...
             f(1)を計算中...
          f(2)を計算中...
             f(1)を計算中...
    f(12)を計算中...
       f(6)を計算中...
          f(3)を計算中...
             f(1)を計算中...
             f(1)を計算中...
          f(2)を計算中...
             f(1)を計算中...
       f(4)を計算中...
          f(2)を計算中...
             f(1)を計算中...
          f(1)を計算中...

f(1) が繰り返し計算されています。
その理由が、 f(2)f(3) の計算において必要となるから ということもわかりました。
さらによく見ると、 f(6)f(4) の計算がそれぞれ2回ずつ行われています。

もし、この計算を手作業でやった場合、「f(6)は先ほど計算したから…」と、計算結果を再利用することでしょう。
このコードは計算結果の記憶と再利用をしないため、全く同じ計算を繰り返してしまっています。
N=36 なら実行時間への影響は誤差程度ですが、特にNが大きい(10^18に近い)場合には重複した計算の実行回数が膨大になってしまいます。

結果を”メモ”しよう

重複した計算を行わないようにするため、計算結果をメモして再利用できるようにします。

# Python

memo = {}

def f(k):
    if k==0:
        return 1
    elif k in memo: # まずはメモにないかチェック
        return memo[k] # あるなら利用する
    else:
        print(f'f({k})を計算中...')
        val = f(k//2) + f(k//3)
        memo[k] = val # メモに書き込む
        return val
# Ruby

@memo = {}

def f(k)
  if k==0
    1
  elsif @memo.has_key?(k) # まずはメモにないかチェック
    @memo[k] # あるなら利用する
  else
    puts("f(#{k})を計算中...")
    val = f(k/2) + f(k/3)
    @memo[k] = val # メモに書き込む
    val
  end
end

表示整形用のインデントは削除しました。

Pythonインタプリタで動かしてみると、以下の出力が得られました。

>>> f(36)
f(36)を計算中...
f(18)を計算中...
f(9)を計算中...
f(4)を計算中...
f(2)を計算中...
f(1)を計算中...
f(3)を計算中...
f(6)を計算中...
f(12)を計算中...

先程の結果と比べて、明らかに計算回数が減っています。

同じ入力に対しては常に同じ出力であり、かつ、同じ入出力が何度も必要になる という場面では、一度計算した”入力と出力のペア”をメモすることで無駄な処理を省くことができます。
このテクニックを再帰関数に適用する手法は「メモ化再帰」と呼ばれます。

ということで、メモ化再帰の典型問題でした。

計算量の考察については、公式解説をご確認ください。

実装例

  • Python 3.8.2 (実行時間: 27 ms)
  • Ruby 2.7.1 (実行時間: 59 ms)
    • ||= 演算子を使うと、memoの参照と代入をシンプルに書けます。