OCamlで階層的手法によるクラスタリングをやってみました

Blogopolis すごい、しかもRubyのとこの隅っこに自分のIDがあってびっくりしました。とにかく、クラスタリングって格好いい!と思いました。クラスタリングって言語学だと概論で歴史言語学の何かに使われてるのくらいしか見たことがないから、これまで触る機会が全然ありませんでした。そこで、少しは知っておかなきゃと思いましたので、はてな村の地図『HatenarMaps』を公開しました - kaisehのブログに載っていたクラスタリング (クラスター分析)を参考に、OCamlクラスタリングをしてみることにしました。

とりあえず、いきなり難しいことは無理ですから、手始めにランダムな整数の列をクラスタリングしてみることにしました。というわけで、階層的手法における最短距離法、最長距離法、群平均法の三つにチャレンジしてみました。今回は整数だけだから、ウォード法ってのはダメなんですよね。クラスタのセントロイドってのがベクトルじゃないと求められないからなのかな?

type 'a cluster =
  | Empty
  | Leaf of 'a
  | Node of ('a cluster * 'a cluster)
;;

let rec elements_of_cluster = function
  | Empty -> []
  | Leaf elt -> [elt]
  | Node (a, b) ->
      List.rev_append (elements_of_cluster a) (elements_of_cluster b)
;;

let rec string_of_cluster = function
  | Empty -> ""
  | Leaf e -> Printf.sprintf "%d" e
  | Node (a, b) ->
      Printf.sprintf "(%s, %s)" (string_of_cluster a) (string_of_cluster b)
;;

module Clustering =
  struct
    module type ELEMENT =
      sig
	type t
	val distance : t -> t -> int
      end

    module type METHOD =
      sig
	type elt
	type t = elt cluster
	val distance : t -> t -> int
      end

    module type MAKE_METHOD =
      functor(Elt : ELEMENT) -> METHOD with type elt = Elt.t

    module type S =
      sig
	type t
	val join : t list -> t
      end

    module Make = functor(Method : METHOD) ->
    struct
      type t = Method.t
      let rec find clusters =
	let rec f (d, pair) = function
	  | [] | _::[] -> (d, pair)
	  | c1::tl -> f (List.fold_left (fun (d, pair) c2 ->
	      let d' = Method.distance c1 c2 in
	      if d = -1 || d > d' then (d', (c1, c2)) else (d, pair)
	    ) (d, pair) tl) tl
	in
	snd (f (-1, (Empty, Empty)) clusters)
      let rec join = function
	| cluster::[] -> cluster
	| clusters ->
	    let c1, c2 = find clusters in
	    let rest =
	      List.filter (fun c -> c <> c1 && c <> c2) clusters
	    in
	    join (Node (c1, c2) :: rest)
    end
  end
;;

module MakeSingleLinkage : Clustering.MAKE_METHOD =
  functor(Elt : Clustering.ELEMENT) ->
struct
  type elt = Elt.t
  type t = elt cluster
  let distance c1 c2 =
    let l1 = elements_of_cluster c1 and l2 = elements_of_cluster c2 in
    List.fold_left (fun d e1 -> List.fold_left (fun d' e2 ->
      let n = Elt.distance e1 e2 in
      if d'= -1 || n < d' then n else d'
    ) d l2) ~-1 l1
end

module MakeCompleteLinkage : Clustering.MAKE_METHOD =
  functor(Elt : Clustering.ELEMENT) ->
struct
  type elt = Elt.t
  type t = elt cluster
  let distance c1 c2 =
    let l1 = elements_of_cluster c1 and l2 = elements_of_cluster c2 in
    List.fold_left (fun d e1 -> List.fold_left (fun d' e2 ->
      let n = Elt.distance e1 e2 in
      if d'= -1 || n > d' then n else d'
    ) d l2) ~-1 l1
end

module MakeGroupAverage : Clustering.MAKE_METHOD =
  functor(Elt : Clustering.ELEMENT) ->
struct
  type elt = Elt.t
  type t = elt cluster
  let distance c1 c2 =
    let l1 = elements_of_cluster c1 and l2 = elements_of_cluster c2 in
    (List.fold_left (fun d e1 -> List.fold_left (fun d' e2 ->
      d' + (Elt.distance e1 e2)
    ) d l2) 0 l1) / ((List.length l1) * (List.length l2))
end

module IntElement =
  struct
    type t = int
    let distance e1 e2 = abs (e1 - e2)
  end
;;

module SingleLinkage = Clustering.Make(MakeSingleLinkage(IntElement))
module CompleteLinkage = Clustering.Make(MakeCompleteLinkage(IntElement))
module GroupAverage = Clustering.Make(MakeGroupAverage(IntElement))

let make_cluster size =
  let rec make_list list = function
    | 0 -> list
    | size -> Random.int 1000 :: (make_list list (size - 1))
  in
  List.map (fun e -> Leaf e) (make_list [] size)

let main =
  Random.self_init ();
  let clusters = (make_cluster 100) in
  print_endline (Printf.sprintf "[%s]"
    (List.fold_left (fun s c ->
      s ^ (Printf.sprintf "%s; " (string_of_cluster c))
    ) "" clusters));
  print_endline (string_of_cluster (SingleLinkage.join clusters));
  print_endline (string_of_cluster (CompleteLinkage.join clusters));
  print_endline (string_of_cluster (GroupAverage.join clusters))

なんというかfunctorがうまい事書けてない気がするんですが、クラスタリングできましたよ!

[194; 642; 604; 553; 162; 384; 25; 527; 419; 762; 603; 138; 516; 808; 807; 422; 502; 73; 150; 261; 986; 437; 211; 132; 415; 553; 525; 672; 266; 740; 891; 674; 841; 994; 892; 710; 338; 98; 853; 818; 149; 795; 372; 463; 879; 772; 76; 246; 396; 859; 975; 636; 576; 253; 616; 344; 801; 434; 661; 276; 46; 927; 498; 601; 613; 808; 766; 53; 929; 35; 346; 256; 677; 115; 286; 350; 861; 782; 516; 895; 752; 915; 835; 814; 694; 964; 0; 64; 381; 84; 944; 285; 993; 818; 296; 893; 942; 81; 1; 529; ]

という列に対して、最短距離法は次のようにクラスタリングしてくれました。

(((((((553, 553), 576), ((((527, 525), 529), (516, 516)), (502, 498))), ((((710, 694), (((672, 674), 677), 661)), (642, 636)), ((616, 613), ((604, 603), 601)))), ((((((((((762, 766), 772), 782), 752), 740), (((((818, 818), 814), ((808, 808), 807)), 801), 795)), ((841, 835), ((859, 861), 853))), ((((891, 892), 893), 895), 879)), (((927, 929), 915), (944, 942))), ((((994, 993), 986), 975), 964))), ((((((419, 422), 415), (437, 434)), (((384, 381), 372), 396)), (((344, 346), 350), 338)), 463)), (((((((((25, 35), (46, 53)), (((84, 81), (73, 76)), 64)), 98), 115), (((138, 132), (150, 149)), 162)), (0, 1)), (194, 211)), ((((286, 285), 276), ((((253, 256), 261), 266), 246)), 296)))

最長距離法では次のようになりました。

(((((((((419, 422), 415), (437, 434)), 463), ((((527, 525), 529), (516, 516)), (502, 498))), ((((384, 381), 372), 396), (((344, 346), 350), 338))), ((((((672, 674), 677), 661), (642, 636)), (710, 694)), (((553, 553), 576), ((616, 613), ((604, 603), 601))))), ((((((795, 801), 782), (((818, 818), 814), ((808, 808), 807))), ((740, 752), ((762, 766), 772))), ((((859, 861), 853), (841, 835)), ((((891, 892), 893), 895), 879))), (((975, 964), ((994, 993), 986)), (((927, 929), 915), (944, 942))))), ((((((46, 53), 64), ((84, 81), (73, 76))), ((25, 35), (0, 1))), ((((150, 149), 162), (138, 132)), (98, 115))), (((((253, 256), 246), (261, 266)), (((286, 285), 276), 296)), (194, 211))))

群平均法では次のようになりました。

(((((((((384, 381), 372), 396), (((344, 346), 350), 338)), ((((419, 422), 415), (437, 434)), 463)), ((((286, 285), 276), 296), (((253, 256), 246), (261, 266)))), (((((616, 613), ((604, 603), 601)), (642, 636)), ((710, 694), (((672, 674), 677), 661))), (((553, 553), 576), ((((527, 525), 529), (516, 516)), (502, 498))))), (((((25, 35), (46, 53)), (0, 1)), ((98, 115), (((84, 81), (73, 76)), 64))), ((((150, 149), 162), (138, 132)), (194, 211)))), (((((795, 782), ((((808, 808), 807), 801), ((818, 818), 814))), ((740, 752), ((762, 766), 772))), ((((859, 861), 853), (841, 835)), ((((891, 892), 893), 895), 879))), (((975, 964), ((994, 993), 986)), (((927, 929), 915), (944, 942)))))

一応そっれぽくないでしょうか。うーん、やってみたのはいいけれども、これらの結果が本当に正しいのかどうかの確認方法が分かりません。あと、デンドログラムを描画してみたいのですが、これを手軽にやるためにはどうすれば良いのでしょうか。graphvizあたりに落とせばいいのかな?