require 'druby/contract'

module DRuby
  module Contract

    # Global table of method signatures, keyed first by class name and then
    # by method name.
    class Registry
      @gradual_classes = {}

      class << self
        # Stores +sig+ as the signature for method +meth+ of +clazz+,
        # overwriting any previous entry for that pair. Returns +sig+.
        def register(clazz, meth, sig)
          (@gradual_classes[clazz] ||= {})[meth] = sig
        end

        # Returns the signature registered for method +m+ of class +name+,
        # or nil when the class or the method has no entry.
        def get_sig(name, m)
          table = @gradual_classes[name]
          table[m] if table
        end
      end
    end

    # Common superclass of all contracts; provides violation reporting.
    class BaseContract
      # Returns a new string: +msg+ followed by one "\n  "-indented line per
      # entry of +ctx+, terminated by "\n".
      # The caller's +msg+ is copied first so it is never mutated (it may be
      # shared or frozen).
      def build_ctx_string(msg,ctx)
        s = msg.dup
        ctx.each { |m| s << "\n  " << m.to_s }
        s << "\n"
      end

      # Aborts the current check by throwing :ContractViolation with the
      # payload [:err, message], where message is +msg+ plus the +ctx+ lines.
      def Violation(ctx,msg)
        throw :ContractViolation, [:err, build_ctx_string(msg, ctx)]
      end
    end

    # Contract for a positional parameter list: required parameters, then
    # optional ones, then an optional contract shared by all splat arguments.
    class Params < BaseContract
      # +reqs+ and +opts+ are arrays of per-parameter contracts (each responding
      # to assert_contract(ctx, value)); +var+ is the contract applied to every
      # extra argument, or nil when the method takes no splat.
      def initialize(reqs,opts = [],var = nil)
        @required = reqs
        @opts = opts
        @varargs = var
      end

      # Checks the argument array +actuals+ against this list, throwing
      # :ContractViolation (via Violation) at the first mismatch.
      def assert_contract(ctx_,actuals)
        assert_length(["checking parameter list"] + ctx_, actuals.length)
        num_req = @required.length
        num_fixed = num_req + @opts.length
        req_actuals = actuals.first(num_req)
        # first/drop always return arrays; slicing a start index past the end
        # of a short argument list returned nil and crashed the splat loop
        opt_actuals = actuals.drop(num_req).first(@opts.length)
        var_actuals = actuals.drop(num_fixed)
        num = 0
        @required.zip(req_actuals).each do |(req,act)|
          num += 1
          req.assert_contract(["checking parameter #{num}"] + ctx_, act)
        end
        @opts.zip(opt_actuals).each do |(opt,act)|
          num += 1
          # missing or nil optional arguments are not checked
          opt.assert_contract(["checking parameter #{num}"] + ctx_, act) if opt && act
        end
        return unless @varargs
        var_actuals.each do |act|
          num += 1
          @varargs.assert_contract(["checking parameter #{num}"] + ctx_, act)
        end
      end

      # Throws :ContractViolation unless +num_actuals+ lies within the arity
      # accepted by this list. Returns true when a splat contract is present
      # and the minimum arity is met.
      def assert_length(ctx,num_actuals)
        if num_actuals < @required.length
          Violation(ctx, format("expected at least %d formal parameters, got %d actuals",
                                @required.length, num_actuals))
        end
        return true if @varargs
        max = @required.length + @opts.length
        if num_actuals > max
          Violation(ctx, format("expected at most %d formal parameters, got %d actuals",
                                max, num_actuals))
        end
      end
    end


    # Contract satisfied when the value satisfies at least one member contract.
    class UnionType < BaseContract
      # +tlist+ is an array of member contracts.
      def initialize(tlist)
        @types = tlist
      end

      # Accepts nil without consulting the members. Otherwise throws
      # :ContractViolation unless some member's assert_contract completes
      # for +tprime+ without throwing.
      def assert_contract(ctx,tprime)
        ctx = ["Union Type"] + ctx
        return if tprime.nil?
        return if @types.any? { |t| member_accepts?(t, ctx, tprime) }
        Violation(ctx,"contract type error: expected one of the union members, got #{tprime.class.name}")
      end

      private

      # True when member contract +t+ checks +tprime+ without throwing
      # :ContractViolation (a thrown payload is never the literal true).
      def member_accepts?(t, ctx, tprime)
        outcome = catch :ContractViolation do
          t.assert_contract(ctx,tprime)
          true
        end
        outcome == true
      end
    end

    # Nominal contract: a non-nil value must match (via ===) the object that
    # the stored string evaluates to.
    class Type < BaseContract
      # +t+ is a Ruby expression string, presumably a class name such as
      # "String" (it is passed to eval and .name is called on the result).
      def initialize(t)
        @cname = t
      end

      # Throws :ContractViolation when +tprime+ is non-nil and not matched by
      # the evaluated class; nil always passes.
      def assert_contract(ctx,tprime)
        # NOTE(review): eval runs on every check; @cname presumably comes from
        # signature annotations rather than untrusted input.
        clazz = eval @cname
        return if tprime.nil? || clazz === tprime
        Violation(ctx,"contract type error: expected #{clazz.name}, got #{tprime.class.name}")
      end
    end

    # Contract for a method with a single signature.
    class MonoMethod < BaseContract
      # +args+: parameter-list contract; +blk+: block contract or nil;
      # +ret+: return-value contract; +pos+: source position shown in messages.
      def initialize(args,blk,ret,pos)
        @pos = pos
        @params = args
        @block = blk
        @ret = ret
      end

      # Checks +actuals+ against the parameter contract, calls +meth+ with the
      # arguments and the block (wrapped by the block contract when one exists),
      # then checks the result against the return contract. Returns the
      # method's result. On a violation the message is printed and the process
      # exits with status 1.
      def assert_contract(actuals,blk,&meth)
        outcome = catch :ContractViolation do
          ctx_ = [@pos.to_s,*caller]
          @params.assert_contract(["checking parameters"] + ctx_, actuals)
          blk2 = @block ? @block.project(["verifying block signature"] + ctx_, blk) : blk
          r = meth.call(actuals,blk2)
          @ret.assert_contract(["verifying return type"] + ctx_, r)
          # Tag normal completion: a thrown violation arrives as [:err, msg], so
          # an untagged method result with that shape would be misread as one.
          [:ok, r]
        end
        status, value = outcome
        if status == :err
          puts value
          exit 1
        end
        value
      end
    end

    # Contract for an overloaded method described as an intersection of
    # signatures. A call passes when at least one member accepts the
    # arguments, every block invocation, and the return value.
    class InterMethod < BaseContract
      # +funcs+ is an array of [params, block_contract_or_nil, return_contract]
      # triples; +pos+ is a source position shown in violation messages.
      def initialize(funcs, pos)
        @funcs = funcs
        @pos = pos
      end

      # Filters the members by the call's arguments (violating before the
      # method runs if none match). It then calls +meth+ with the argument array
      # and a checking wrapper around +blk+, and filters the members again by the
      # return value. Returns the method's result. If no member survives, the
      # violation is printed and the process exits with status 1.
      def assert_contract(actuals,blk,&meth)
        outcome = catch :ContractViolation do
          ctx = [@pos.to_s,*caller]
          # [block_contract, return_contract] of members accepting the arguments
          valid = @funcs.reject { |(fargs,_fblk,_fret)| violates?(fargs, ctx, actuals) }.map { |(_fargs,fblk,fret)| [fblk, fret] }
          if valid.empty?
            Violation(ctx,"all members of intersection type failed in argument test")
          end
          blk2 = checking_block(ctx.clone, valid, blk) if blk
          r = meth.call(actuals,blk2)
          ctx = ["verifying return type #{r.class.name}"] + ctx
          valid.reject! { |(_fblk,fret)| violates?(fret, ctx, r) }
          if valid.empty?
            Violation(ctx,"all members of intersection type failed in method return")
          end
          # Tag normal completion: a thrown violation arrives as [:err, msg], so
          # an untagged method result with that shape would be misread as one.
          [:ok, r]
        end
        status, value = outcome
        if status == :err
          puts value
          exit 1
        end
        value
      end

      private

      # True when +contract+ throws :ContractViolation while checking +value+
      # (a thrown payload is never the literal true).
      def violates?(contract, ctx, value)
        passed = catch :ContractViolation do
          contract.assert_contract(ctx, value)
          true
        end
        passed != true
      end

      # Wraps +blk+ in a Proc that narrows +valid+ in place on every call.
      # It discards members that have no block contract, or whose block
      # contract rejects the yielded arguments or the block's result. It
      # throws :ContractViolation once no member remains.
      def checking_block(blk_ctx, valid, blk)
        Proc.new do |*blk_args|
          if valid.empty?
            Violation(blk_ctx,"all members of intersection type failed in before block test")
          end
          valid.reject! { |(fblk,_fret)| !fblk || violates?(fblk.args, blk_ctx, blk_args) }
          if valid.empty?
            Violation(blk_ctx,"all members of intersection type failed in block test: #{blk_args.inspect}")
          end
          r = blk.call(*blk_args)
          valid.reject! { |(fblk,_fret)| !fblk || violates?(fblk.ret, blk_ctx, r) }
          if valid.empty?
            Violation(blk_ctx,"all members of intersection type failed in block return test")
          end
          r
        end
      end
    end

    # Contract for a block argument: an argument-list contract plus a
    # return-value contract.
    class Block < BaseContract
      attr_accessor :args, :ret

      def initialize(args,ret)
        @args = args
        @ret = ret
      end

      # Returns a Proc standing in for +blk+. Each call throws
      # :ContractViolation if +blk+ is nil. Otherwise it checks the yielded
      # arguments against +args+, calls +blk+, checks the result against
      # +ret+, and returns that result.
      def project(ctx_,blk)
        Proc.new do |*yielded|
          Violation(ctx_,"block was not provided") unless blk
          @args.assert_contract(["verify block arguments"] + ctx_, yielded)
          result = blk.call(*yielded)
          @ret.assert_contract(["verify block return"] + ctx_, result)
          result
        end
      end
    end
  end
end

class Module
  @@inserted = {}

  alias :old_include :include
  # Pass-through override of Module#include; it only delegates to the
  # original implementation (the commented line is a debugging hook).
  def include(*args)
    #    args.each {|m| puts "#{self.name} include #{m}"}
    old_include(*args)
  end

  alias :old_added :method_added
  # Hook run for every method definition. When the contract registry holds
  # a signature for this class/method pair, the pair is recorded in
  # @@inserted and the method is wrapped by insert_cast (except initialize).
  # The original method_added is always called afterwards.
  def method_added(meth)
    if DRuby::Contract::Registry.get_sig(self.name,meth)
      (@@inserted[self.name] ||= {})[meth] = 1
      insert_cast(meth) unless meth == :initialize
    end
    old_added(meth)
  end

  @@count = 1
  # Wraps method +m+ in a contract check. The current body is copied to an
  # "untyped_..." method, and +m+ is redefined (via eval) to fetch its
  # signature from DRuby::Contract::Registry and call the copy through
  # sig.assert_contract. Redefining +m+ fires method_added again, so the
  # @inserting_cast guard turns that nested call into a no-op.
  def insert_cast(m)
    return if @inserting_cast
    @inserting_cast = true
    begin
      @@count += 1
      puts "instrumenting #{self.name} #{m}"
      # Names made only of letters/underscores keep a readable alias; any other
      # name (digits, ?, !, operators) gets a unique numbered alias.
      if /^[a-zA-Z_]*$/ =~ m.to_s
        orig = :"untyped_#{m}"
      else
        orig = :"untyped_#{@@count}"
      end
      self.class_eval do
        define_method(orig,instance_method(m))
        # caller[3] is presumably the user's `def` site (frames above it:
        # class_eval, insert_cast, method_added), so the generated wrapper is
        # attributed there. Do not add call frames between method_added and here.
        if /^([^:]+):(\d+)/ =~ caller[3]
          file = $1
          line = $2.to_i
        else
          file = __FILE__
          line = __LINE__ + 3
        end
        eval <<-EOF, binding(), file, line
          def #{m}(*args,&blk)
            sig = DRuby::Contract::Registry.get_sig("#{self.name}",#{m.inspect})
            sig.assert_contract(args,blk) do |args,blk|
              if args then
                #{orig}(*args,&blk)
              else
                #{orig}(&blk)
              end
            end
          end
        EOF
      end
    ensure
      # Reset even if define_method/eval raises; otherwise the guard would stay
      # set and silently disable instrumentation of every later method.
      @inserting_cast = false
    end
  end
end
