OCamlで階層的手法によるクラスタリングをやってみました
Blogopolis すごい、しかもRubyのとこの隅っこに自分のIDがあってびっくりしました。とにかく、クラスタリングって格好いい!と思いました。クラスタリングって言語学だと概論で歴史言語学の何かに使われてるのくらいしか見たことがないから、これまで触る機会が全然ありませんでした。そこで、少しは知っておかなきゃと思いましたので、はてな村の地図『HatenarMaps』を公開しました - kaisehのブログに載っていたクラスタリング (クラスター分析)を参考に、OCamlでクラスタリングをしてみることにしました。
とりあえず、いきなり難しいことは無理ですから、手始めにランダムな整数の列をクラスタリングしてみることにしました。というわけで、階層的手法における最短距離法、最長距離法、群平均法の三つにチャレンジしてみました。今回は整数だけだから、ウォード法ってのはダメなんですよね。クラスタのセントロイドってのがベクトルじゃないと求められないからなのかな?
type 'a cluster = | Empty | Leaf of 'a | Node of ('a cluster * 'a cluster) ;; let rec elements_of_cluster = function | Empty -> [] | Leaf elt -> [elt] | Node (a, b) -> List.rev_append (elements_of_cluster a) (elements_of_cluster b) ;; let rec string_of_cluster = function | Empty -> "" | Leaf e -> Printf.sprintf "%d" e | Node (a, b) -> Printf.sprintf "(%s, %s)" (string_of_cluster a) (string_of_cluster b) ;; module Clustering = struct module type ELEMENT = sig type t val distance : t -> t -> int end module type METHOD = sig type elt type t = elt cluster val distance : t -> t -> int end module type MAKE_METHOD = functor(Elt : ELEMENT) -> METHOD with type elt = Elt.t module type S = sig type t val join : t list -> t end module Make = functor(Method : METHOD) -> struct type t = Method.t let rec find clusters = let rec f (d, pair) = function | [] | _::[] -> (d, pair) | c1::tl -> f (List.fold_left (fun (d, pair) c2 -> let d' = Method.distance c1 c2 in if d = -1 || d > d' then (d', (c1, c2)) else (d, pair) ) (d, pair) tl) tl in snd (f (-1, (Empty, Empty)) clusters) let rec join = function | cluster::[] -> cluster | clusters -> let c1, c2 = find clusters in let rest = List.filter (fun c -> c <> c1 && c <> c2) clusters in join (Node (c1, c2) :: rest) end end ;; module MakeSingleLinkage : Clustering.MAKE_METHOD = functor(Elt : Clustering.ELEMENT) -> struct type elt = Elt.t type t = elt cluster let distance c1 c2 = let l1 = elements_of_cluster c1 and l2 = elements_of_cluster c2 in List.fold_left (fun d e1 -> List.fold_left (fun d' e2 -> let n = Elt.distance e1 e2 in if d'= -1 || n < d' then n else d' ) d l2) ~-1 l1 end module MakeCompleteLinkage : Clustering.MAKE_METHOD = functor(Elt : Clustering.ELEMENT) -> struct type elt = Elt.t type t = elt cluster let distance c1 c2 = let l1 = elements_of_cluster c1 and l2 = elements_of_cluster c2 in List.fold_left (fun d e1 -> List.fold_left (fun d' e2 -> let n = Elt.distance e1 e2 in if d'= -1 || n > d' then n else d' ) d l2) ~-1 l1 end module MakeGroupAverage : Clustering.MAKE_METHOD = functor(Elt : Clustering.ELEMENT) -> struct type elt = Elt.t type t = elt cluster let distance c1 c2 = let l1 = elements_of_cluster c1 and l2 = elements_of_cluster c2 in (List.fold_left (fun d e1 -> List.fold_left (fun d' e2 -> d' + (Elt.distance e1 e2) ) d l2) 0 l1) / ((List.length l1) * (List.length l2)) end module IntElement = struct type t = int let distance e1 e2 = abs (e1 - e2) end ;; module SingleLinkage = Clustering.Make(MakeSingleLinkage(IntElement)) module CompleteLinkage = Clustering.Make(MakeCompleteLinkage(IntElement)) module GroupAverage = Clustering.Make(MakeGroupAverage(IntElement)) let make_cluster size = let rec make_list list = function | 0 -> list | size -> Random.int 1000 :: (make_list list (size - 1)) in List.map (fun e -> Leaf e) (make_list [] size) let main = Random.self_init (); let clusters = (make_cluster 100) in print_endline (Printf.sprintf "[%s]" (List.fold_left (fun s c -> s ^ (Printf.sprintf "%s; " (string_of_cluster c)) ) "" clusters)); print_endline (string_of_cluster (SingleLinkage.join clusters)); print_endline (string_of_cluster (CompleteLinkage.join clusters)); print_endline (string_of_cluster (GroupAverage.join clusters))
なんというかfunctorがうまい事書けてない気がするんですが、クラスタリングできましたよ!
[194; 642; 604; 553; 162; 384; 25; 527; 419; 762; 603; 138; 516; 808; 807; 422; 502; 73; 150; 261; 986; 437; 211; 132; 415; 553; 525; 672; 266; 740; 891; 674; 841; 994; 892; 710; 338; 98; 853; 818; 149; 795; 372; 463; 879; 772; 76; 246; 396; 859; 975; 636; 576; 253; 616; 344; 801; 434; 661; 276; 46; 927; 498; 601; 613; 808; 766; 53; 929; 35; 346; 256; 677; 115; 286; 350; 861; 782; 516; 895; 752; 915; 835; 814; 694; 964; 0; 64; 381; 84; 944; 285; 993; 818; 296; 893; 942; 81; 1; 529; ]
という列に対して、最短距離法は次のようにクラスタリングしてくれました。
(((((((553, 553), 576), ((((527, 525), 529), (516, 516)), (502, 498))), ((((710, 694), (((672, 674), 677), 661)), (642, 636)), ((616, 613), ((604, 603), 601)))), ((((((((((762, 766), 772), 782), 752), 740), (((((818, 818), 814), ((808, 808), 807)), 801), 795)), ((841, 835), ((859, 861), 853))), ((((891, 892), 893), 895), 879)), (((927, 929), 915), (944, 942))), ((((994, 993), 986), 975), 964))), ((((((419, 422), 415), (437, 434)), (((384, 381), 372), 396)), (((344, 346), 350), 338)), 463)), (((((((((25, 35), (46, 53)), (((84, 81), (73, 76)), 64)), 98), 115), (((138, 132), (150, 149)), 162)), (0, 1)), (194, 211)), ((((286, 285), 276), ((((253, 256), 261), 266), 246)), 296)))
最長距離法では次のようになりました。
(((((((((419, 422), 415), (437, 434)), 463), ((((527, 525), 529), (516, 516)), (502, 498))), ((((384, 381), 372), 396), (((344, 346), 350), 338))), ((((((672, 674), 677), 661), (642, 636)), (710, 694)), (((553, 553), 576), ((616, 613), ((604, 603), 601))))), ((((((795, 801), 782), (((818, 818), 814), ((808, 808), 807))), ((740, 752), ((762, 766), 772))), ((((859, 861), 853), (841, 835)), ((((891, 892), 893), 895), 879))), (((975, 964), ((994, 993), 986)), (((927, 929), 915), (944, 942))))), ((((((46, 53), 64), ((84, 81), (73, 76))), ((25, 35), (0, 1))), ((((150, 149), 162), (138, 132)), (98, 115))), (((((253, 256), 246), (261, 266)), (((286, 285), 276), 296)), (194, 211))))
群平均法では次のようになりました。
(((((((((384, 381), 372), 396), (((344, 346), 350), 338)), ((((419, 422), 415), (437, 434)), 463)), ((((286, 285), 276), 296), (((253, 256), 246), (261, 266)))), (((((616, 613), ((604, 603), 601)), (642, 636)), ((710, 694), (((672, 674), 677), 661))), (((553, 553), 576), ((((527, 525), 529), (516, 516)), (502, 498))))), (((((25, 35), (46, 53)), (0, 1)), ((98, 115), (((84, 81), (73, 76)), 64))), ((((150, 149), 162), (138, 132)), (194, 211)))), (((((795, 782), ((((808, 808), 807), 801), ((818, 818), 814))), ((740, 752), ((762, 766), 772))), ((((859, 861), 853), (841, 835)), ((((891, 892), 893), 895), 879))), (((975, 964), ((994, 993), 986)), (((927, 929), 915), (944, 942)))))
一応そっれぽくないでしょうか。うーん、やってみたのはいいけれども、これらの結果が本当に正しいのかどうかの確認方法が分かりません。あと、デンドログラムを描画してみたいのですが、これを手軽にやるためにはどうすれば良いのでしょうか。graphvizあたりに落とせばいいのかな?