読者です 読者をやめる 読者になる 読者になる

Wavelet Matrix

Java

一部で大ブームのWavelet Matrix(ウェーブレット行列)をJavaで実装してみました。
以下を参考にしました。

高速文字列解析の世界――データ圧縮・全文検索・テキストマイニング (確率と情報の科学)

高速文字列解析の世界――データ圧縮・全文検索・テキストマイニング (確率と情報の科学)

また、以下のスライドを参考にしました。

http://d.hatena.ne.jp/echizen_tm/20120930/1348999144

ウェーブレット行列は、ウェーブレット木と同等のことができるデータ構造です。
ウェーブレット木については、書こうと思って書いてるうちに忙しくなって投げ出した経緯があります。

http://rn.hatenablog.com/entry/2012/06/22/000315

ただ、まぁ、ウェーブレット木についてはそこそこ説明していて、実装を公開するのを投げた状態です。
(この時期は会社でひたすらコーディングしていたので、家でコーディングするのが嫌になっていた時期ですね。)

なんかそうこうしているうちに、ウェーブレット木は、ウェーブレット行列で実現するのがいいって流れになっていたので、勉強がてら作ってみました。
とりあえず、簡単そうなaccessとrankを作ってみました。

真面目な解説は後日に書くとして、とりあえず、実装だけでも公開しておきます。
ウェーブレット行列内で利用する簡潔ビットベクトルは、以前に自分で作って公開しているsbvjを使いました。

http://code.google.com/p/sbvj/

また、簡潔ビットベクトルの解説は以下にあります。

http://rn.hatenablog.com/entry/20120328/1332951451

以下がウェーブレット行列です。

public class WaveletMatrix {
	
	private String mStr;
	
	private int mStrCodePointCount;
	
	private int mBVNum;
	
	/**
	 * 
	 * <p>
	 * ウェーブレット行列
	 * </p>
	 * 
	 */
	private List<SuccinctBitVector> mWaveletMatrix;
	
	/**
	 * 
	 * <p>
	 * コンストラクタ
	 * </p>
	 * 
	 */
	public WaveletMatrix( final String pStr ) {
		this.mStr = pStr;
		this.mStrCodePointCount = this.mStr.codePointCount(0, this.mStr.length());
		
		// 文字の最大ビット数を調べる。
		this.mBVNum = 0;
		for ( int i = 0; i < this.mStrCodePointCount; i++ ) {
			int codePoint = this.mStr.codePointAt(i);
			int msb = WMUtility.getMSB(codePoint);
			
			if ( this.mBVNum < msb ) {
				this.mBVNum = msb;
			}
		}
		
		this.mWaveletMatrix = new ArrayList<SuccinctBitVector>(this.mBVNum);
		for ( int i = 0; i < this.mBVNum; i++ ) {
			this.mWaveletMatrix.add(i, new SuccinctBitVector( this.mStrCodePointCount ));
		}
	}
	
	private String buildBitVector(
			final int pBitPos,
			final String pCurStr
			) {
		
		assert( 0 < pBitPos );
		
		SuccinctBitVector prevSBV = this.mWaveletMatrix.get(pBitPos - 1);
		SuccinctBitVector curSBV = this.mWaveletMatrix.get(pBitPos);
		
		long boundaryIndex = prevSBV.getRank(this.mStrCodePointCount, 0);
		
		long leftIndex = 0;
		long rightIndex = 0;
		StringBuffer leftNextStr = new StringBuffer();
		StringBuffer rightNextStr = new StringBuffer();
		for ( int i = 0; i < this.mStrCodePointCount; i++ ) {
			
			// 直前のビットベクトルでのビットを調べる。
			int prevBit = prevSBV.getBit( i );
			
			int codePoint = pCurStr.codePointAt(i);
			int bit = (codePoint >> pBitPos) & 1;
			
			if ( 0 == prevBit ) {
				// 左側に分類する。
				curSBV.setBit(leftIndex, bit);
				leftNextStr.appendCodePoint(codePoint);
				leftIndex++;
			} else {
				curSBV.setBit(rightIndex + boundaryIndex, bit);
				rightNextStr.appendCodePoint(codePoint);
				rightIndex++;
			}
		}
		
		assert( (long) this.mStrCodePointCount == (leftIndex + rightIndex) );
		
		curSBV.build( BuildType.SIMPLE );
		
		leftNextStr.append( rightNextStr );
		
		return leftNextStr.toString();
		
	}
	
	public void build() {
		
		// 1ビット目を分類する。
		SuccinctBitVector sbv = this.mWaveletMatrix.get(0);
		for ( int i = 0; i < this.mStrCodePointCount; i++ ) {
			int codePoint = this.mStr.codePointAt( i );
			int bit = codePoint & 1;
			sbv.setBit(i, bit);
		}
		sbv.build( BuildType.SIMPLE );
		
		// 2ビット目以降は前回のSBVを使って分類する。
		String nextStr = this.mStr;
		for ( int bitPos = 1; bitPos < this.mBVNum; bitPos++ ) {
			nextStr = buildBitVector(bitPos, nextStr);
		}
		
		return;
	}
	
	public int access( final int pPos ) {
		
		int codePoint = 0;
		
		int sbvPos = pPos;
		for ( int bitPos = 0; bitPos < this.mBVNum; bitPos++ ) {
			SuccinctBitVector sbv = this.mWaveletMatrix.get(bitPos);
			
			int bit = sbv.getBit( sbvPos );
			codePoint |= (bit << bitPos);
			
			sbvPos = (int) sbv.getRank(sbvPos, bit);
			if ( 0 != bit ) {
				// 次回のビット位置を計算する。
				int boundaryIndex = (int) sbv.getRank(this.mStrCodePointCount, 0);
				sbvPos += boundaryIndex;
			}
			
		}
		
		return codePoint;
		
	}
	
	public int getRank( final int pPos , final int pCodePoint) {
		
		int rank = 0;
		
		int start = 0;
		int end = pPos;
		for ( int bitPos = 0; bitPos < this.mBVNum; bitPos++ ) {
			SuccinctBitVector sbv = this.mWaveletMatrix.get(bitPos);
			
			int bit = (pCodePoint >> bitPos) & 1;
			
			start = (int) sbv.getRank(start, bit);
			end = (int) sbv.getRank(end, bit);
			
			if ( 0 != bit ) {
				int boundaryIndex = (int) sbv.getRank(this.mStrCodePointCount, 0);
				start += boundaryIndex;
				end += boundaryIndex;
			}
			
			if ( (this.mBVNum - 1) == bitPos ) {
				int rank1 = (int) sbv.getRank(start, bit);
				int rank2 = (int) sbv.getRank(end, bit);
				
				assert( rank1 <= rank2 );
				rank = rank2 - rank1;
			}
			
		}
		
		return rank;
		
	}

}

以下がユーティリティです。

public class WMUtility {
	
	// インスタンス化の抑制を行う。
	private WMUtility() {}
	
	/** 1バイトの値のrank値 */
	static final int RANK_TABLE[];
	static {
		RANK_TABLE = new int [] {
				0, 1, 1, 2, 1, 2, 2, 3,
				1, 2, 2, 3, 2, 3, 3, 4,
				1, 2, 2, 3, 2, 3, 3, 4,
				2, 3, 3, 4, 3, 4, 4, 5,
				1, 2, 2, 3, 2, 3, 3, 4,
				2, 3, 3, 4, 3, 4, 4, 5,
				2, 3, 3, 4, 3, 4, 4, 5,
				3, 4, 4, 5, 4, 5, 5, 6,
				1, 2, 2, 3, 2, 3, 3, 4,
				2, 3, 3, 4, 3, 4, 4, 5,
				2, 3, 3, 4, 3, 4, 4, 5,
				3, 4, 4, 5, 4, 5, 5, 6,
				2, 3, 3, 4, 3, 4, 4, 5,
				3, 4, 4, 5, 4, 5, 5, 6,
				3, 4, 4, 5, 4, 5, 5, 6,
				4, 5, 5, 6, 5, 6, 6, 7,
				1, 2, 2, 3, 2, 3, 3, 4,
				2, 3, 3, 4, 3, 4, 4, 5,
				2, 3, 3, 4, 3, 4, 4, 5,
				3, 4, 4, 5, 4, 5, 5, 6,
				2, 3, 3, 4, 3, 4, 4, 5,
				3, 4, 4, 5, 4, 5, 5, 6,
				3, 4, 4, 5, 4, 5, 5, 6,
				4, 5, 5, 6, 5, 6, 6, 7,
				2, 3, 3, 4, 3, 4, 4, 5,
				3, 4, 4, 5, 4, 5, 5, 6,
				3, 4, 4, 5, 4, 5, 5, 6,
				4, 5, 5, 6, 5, 6, 6, 7,
				3, 4, 4, 5, 4, 5, 5, 6,
				4, 5, 5, 6, 5, 6, 6, 7,
				4, 5, 5, 6, 5, 6, 6, 7,
				5, 6, 6, 7, 6, 7, 7, 8
		};
	}
	
	/**
	 * 
	 * <p>
	 * 32bitのWORDから1のrank値を計算するメソッドである。
	 * </p>
	 * 
	 * @param pWord 32bitのWORD
	 * @return 32bitのWORD内の1のrank値
	 */
	static int getWordRankFromTable( final int pWord ) {
		// 32bit分のrank値を計算する。
		return RANK_TABLE[ (pWord >> 24) & 0xff ]
				+ RANK_TABLE[ (pWord >> 16) & 0xff ]
				+ RANK_TABLE[ (pWord >> 8) & 0xff ]
				+ RANK_TABLE[ pWord & 0xff ];
	}
	
	/**
	 * 
	 * <p>
	 * MSBを求めるメソッドである。
	 * </p>
	 * 
	 * @param pValue 整数値
	 * @return 指定された整数値のMSB
	 */
	static int getMSB( final int pValue ) {
		int rank = 0;
		int value = pValue;
		if ( 0 < pValue ) {
			value |= (value >> 1);
			value |= (value >> 2);
			value |= (value >> 4);
			value |= (value >> 16);
			value |= (value >> 32);
			
			// MSB以下は、すべて1が立っているので、その数を数える。
			rank = getWordRankFromTable(value);
		}
		return rank;
	}

}

以下がテストコードです。

public class WaveletMatrixTest {

	/**
	 * @param args
	 */
	public static void main(String[] args) {
		
		//int [] codePointArray = new int [] { 1, 7, 3, 5, 6, 2, 4, 0, 4, 1, 4, 7 };
		//String str = new String( codePointArray, 0, codePointArray.length );
		String str = "abracadabra";
		
		// Wavelet Matrix
		WaveletMatrix wm = new WaveletMatrix( str );
		wm.build();
		
		// access, rank
		int codePointCount = str.codePointCount(0, str.length());
		for ( int i = 0; i < codePointCount; i++ ) {
			
			// access
			int codePoint = wm.access(i);
			System.out.println( "access( " + i + " )" + " : " + codePoint + ", " + new String( Character.toChars(codePoint) ) );
			
			// rank
			int rank = wm.getRank(codePointCount, str.codePointAt(i));
			System.out.println( "rank(" + codePointCount + ", " + new String( Character.toChars(str.codePointAt(i)) ) + ") : " + rank );
		}

	}

}

さぁ、これで、BWTとウェーブレット行列が揃ったので、圧縮全文索引であるFM-Indexが実装できます。
なので、明日は、FM-Indexを実装してみます。

広告を非表示にする