レキシカルな Measure.with が実現できたっぽい

前回の日記で書いた改良はまだやってないです・・・今日は、できるかどうか分からないけどやってみたらできちゃったので、レキシカルな Measure.with の紹介をします。

現行の Measure.form は動的に単位メソッドの振舞いが置き換わります。たとえば、トップレベルで定義された以下の foo メソッドについて考えてみましょう。

def foo
  1.km
end

このメソッドは、普通に呼び出すと NoMethodError を投げます。

p foo.to_s #=> NoMethodError

しかし、次のように呼び出すと、Measure オブジェクトを返してくれます。

require 'measure'
require 'measure/support'
require 'measure/length'
p Measure.form { foo }.to_s #=> "1 [km]"

自由変数についてのこのような挙動は動的バインディングって呼びます (e.g. Emacs Lisp)。メソッドの場合も同じ呼称でいいのかな?

個人的に、ダイナミックバインディングは気持ち悪いです。なぜかというと、自分以外の人が定義した Numeric のメソッドが意図せずに置き換えられてしまうからです。たとえば、Measure と ActiveSupport を同時に使う場合に Numeric#seconds などが被る可能性は多いに有り得ます。そこで、Measure.form (次期バージョンでは Measure.with) の後置ブロックのフレームとその下すべてではなく、そのブロックの do から end までのレキシカルな範囲だけで単位メソッドをバインディングしたいのです。

いろいろと試行錯誤した結果、set_trace_func を使ってメソッド呼び出しを監視し、Proc.new, Kernel.lambda, そして Kernel.proc を置き換えることでレキシカルな Measure.with を実現できました。以下にプロトタイプコードを貼っておきます。このコードは http://gist.github.com/27299 にも置いてあります。

module Measure
  class Unit
    class << self
      def unit_name
        return nil
      end

      alias __inspect__ inspect
      def inspect
        if unit_name
          "#<Measure::Unit[#{unit_name}]>"
        else
          __inspect__
        end
      end
    end
  end

  class Context
    def initialize
      @units = {}
    end

    def has_unit?(name)
      name = name.intern if String === name
      return @units.has_key? name
    end

    def [](name)
      name = name.intern if String === name
      return nil unless has_unit? name
      return @units[name]
    end

    def def_unit(name)
      name = name.intern if String === name
      unless has_unit? name
        @units[name] = Class.new Unit
        @units[name].class_eval do
          @@unit_name = name
          def self.unit_name
            @@unit_name
          end
        end
      end
      return @units[name]
    end
  end

  class Quantity
    def initialize(value, unit)
      @value, @unit = value, unit
    end

    def to_s
      "#{@value} [#{@unit.unit_name}]"
    end
  end

  class << self
    def default_context
      @@default_context ||= Context.new
      return @@default_context
    end

    def current_context
      get_thread_context || default_context
    end

    def with(ctx, &block)
      begin
        saved_ctx = get_thread_context
        set_thread_context ctx
        block.call
      ensure
        set_thread_context saved_ctx
      end
    end

    private

    def has_thread_storage?
      return !!Thread.current[:measure]
    end

    def get_thread_storage(key)
      return nil unless has_thread_storage?
      return Thread.current[:measure][key]
    end

    def set_thread_storage(key, value)
      Thread.current[:measure] ||= {}
      Thread.current[:measure][key] = value
      return nil
    end

    def get_thread_context
      return nil unless has_thread_storage?
      return get_thread_storage(:context)
    end

    def set_thread_context(context)
      if context == Measure.default_context
        set_thread_storage :context, nil
      else
        set_thread_storage :context, context
      end
      return context
    end

    def push_call_stack(*values)
      stack = get_thread_storage :call_stack
      unless stack
        stack = []
        set_thread_storage :call_stack, stack
      end
      stack.push values
      return nil
    end

    def pop_call_stack
      stack = get_thread_storage :call_stack
      return nil unless stack
      return stack.pop
    end

    def call_stack(index)
      stack = get_thread_storage :call_stack
      return stack[index] if stack
      return nil
    end

    def dump_call_stack
      p get_thread_storage :call_stack
    end

    def with_context
      # dump_call_stack
      a = call_stack(-3) # expect Measure.with
      b = call_stack(-2) # expect Proc.call with Measure::Context
      return get_thread_context if a && a[0] == Measure && a[1] == :with
      return b[2] if b && b[0] == Proc && b[1] == :call
      return nil
    end

    def set_context_to_proc(f)
      # dump_call_stack
      a = call_stack(-3) # expected Measure.with
      b = call_stack(-2) # expected Proc.call with Measure::Context
      if (a && a[0] == Measure && a[1] == :with) ||
          (b && b[0] == Proc && b[1] == :call && b[2])
        context = Measure.current_context
        f.instance_eval { @_measure_context = context }
      end
      return f
    end

    tracer = lambda do |event, file, line, id, binding, klass|
      case event
      when "call"
        next if klass == Measure && id != :with
        if klass == Proc && (id == :call || id == :[])
          f = binding.eval("self")
          context = f.instance_eval { @_measure_context }
          Measure.module_eval { push_call_stack Proc, :call, context }
        else
          Measure.module_eval { push_call_stack klass, id }
        end
      when "return"
        top = Measure.module_eval { call_stack(-1) }
        if top && top[0] == klass && top[1] == id
          Measure.module_eval { pop_call_stack }
        end
      end
    end
    set_trace_func tracer
  end
end

$measure = Measure::Context.new
$measure.def_unit :km

module Kernel
  alias __lambda__ lambda
  def lambda(&block)
    f = __lambda__(&block)
    return Measure.module_eval { set_context_to_proc f }
  end

  alias __proc__ proc
  def proc(&block)
    f = __proc__(&block)
    return Measure.module_eval { set_context_to_proc f }
  end
end

class Proc
  class << self
    alias __new__ new
    def new(&block)
      f = __new__(&block)
      return Measure.module_eval { set_context_to_proc f }
    end
  end

  alias __call__ call
  def call(*args)
    if @_measure_context
      begin
        ctx = @_measure_context
        save_ctx = Measure.current_context
        Measure.module_eval { set_thread_context ctx }
        return __call__(*args)
      ensure
        Measure.module_eval { set_thread_context save_ctx }
      end
    else
      return __call__(*args)
    end
  end
  alias [] call
end

class Numeric
  def km
    context = Measure.module_eval { with_context }
    if context
      Measure::Quantity.new(self, context[:km])
    else
      self
    end
  end
end

def method_outside_with
  puts "method_outside_with(1): #{1.km}"
end

method_outside_with
$proc_outside_with = lambda { puts "proc_outside_with(1): #{1.km}" }
Measure.with($measure) do
  method_outside_with
  puts "inside_with(1 [km]): #{1.km}"
  lambda { puts "proc_inside_with_call(1 [km]): #{1.km}" }.call
  $proc_inside_with = lambda { puts "proc_inside_with(1 [km]): #{1.km}" }
  $proc_outside_with.call
  lambda {
    $proc_inside_proc_inside_with = lambda {
      puts "proc_inside_proc_inside_with(1 [km]): #{1.km}"
    }
  }.call
  lambda {
    lambda {
      $proc_inside_proc_inside_proc_inside_with = lambda {
        puts "proc_inside_proc_inside_proc_inside_with(1 [km]): #{1.km}"
      }
    }.call
  }.call
  lambda {
    lambda {
      Measure.with(nil) do
        $context_conservation = lambda {
          puts "(false): #{Measure.current_context == $measure}"
        }
      end
    }.call
  }.call
  $proc_inside_proc_inside_with.call
  $context_conservation.call
end
$proc_inside_with.call
$proc_inside_proc_inside_with.call
$proc_inside_proc_inside_proc_inside_with.call
$context_conservation.call

## $ ruby measure_lexical_with_test.rb
## method_outside_with(1): 1
## method_outside_with(1): 1
## inside_with(1 [km]): 1 [km]
## proc_inside_with_call(1 [km]): 1 [km]
## proc_outside_with(1): 1
## proc_inside_proc_inside_with(1 [km]): 1 [km]
## (false): false
## proc_inside_with(1 [km]): 1 [km]
## proc_inside_proc_inside_with(1 [km]): 1 [km]
## proc_inside_proc_inside_proc_inside_with(1 [km]): 1 [km]
## (false): false