Project

General

Profile

Download (7.95 KB) Statistics
| Branch: | Revision:

distorted-objectlib / src / main / java / org / distorted / objectlib / tablebases / PruningTable.java @ 00947987

1
///////////////////////////////////////////////////////////////////////////////////////////////////
2
// Copyright 2023 Leszek Koltunski                                                               //
3
//                                                                                               //
4
// This file is part of Magic Cube.                                                              //
5
//                                                                                               //
6
// Magic Cube is proprietary software licensed under an EULA which you should have received      //
7
// along with the code. If not, check https://distorted.org/magic/License-Magic-Cube.html        //
8
///////////////////////////////////////////////////////////////////////////////////////////////////
9

    
10
package org.distorted.objectlib.tablebases;
11

    
12
///////////////////////////////////////////////////////////////////////////////////////////////////
13

    
14
class PruningTable
15
{
16
  private static final int HEADER_SIZE = 5;
17
  private final byte[] mData;
18
  private final int mNumBits;
19
  private final int mBucketBytes;
20
  private final int mNumBuckets;
21
  private final int mLevel;
22

    
23
///////////////////////////////////////////////////////////////////////////////////////////////////
24

    
25
  private static String printByte(byte b)
26
    {
27
    return String.format("%8s", Integer.toBinaryString(b & 0xFF)).replace(' ', '0');
28
    }
29

    
30
///////////////////////////////////////////////////////////////////////////////////////////////////
31

    
32
  private int numBytesForIndices(int total)
33
    {
34
    int bytes = 0;
35

    
36
    while(total!=0)
37
      {
38
      total>>=8;
39
      bytes++;
40
      }
41

    
42
    return bytes;
43
    }
44

    
45
///////////////////////////////////////////////////////////////////////////////////////////////////
46

    
47
  private void writeHeader()
48
    {
49
    mData[0] = (byte)mLevel;
50
    mData[1] = (byte)mBucketBytes;
51
    mData[2] = (byte)mNumBits;
52
    mData[3] = (byte)(mNumBuckets/256);
53
    mData[4] = (byte)(mNumBuckets%256);
54
    }
55

    
56
///////////////////////////////////////////////////////////////////////////////////////////////////
57

    
58
  private int getBucketOffset(int bucket)
59
    {
60
    if( bucket>=mNumBuckets ) return mData.length;
61

    
62
    int ret = 0;
63
    int start = HEADER_SIZE + bucket*mBucketBytes;
64

    
65
    for(int i=0; i<mBucketBytes; i++)
66
      {
67
      ret<<=8;
68
      ret += ( ((int)(mData[start+i]))&0xff);
69
      }
70

    
71
    return HEADER_SIZE+mNumBuckets*mBucketBytes+ret;
72
    }
73

    
74
///////////////////////////////////////////////////////////////////////////////////////////////////
75

    
76
  private void writeBucketPointers(int[] bucket)
77
    {
78
    int start=HEADER_SIZE;
79

    
80
    for( int b : bucket )
81
      {
82
      for( int i=0; i<mBucketBytes; i++)
83
        {
84
        mData[start+mBucketBytes-1-i] = (byte)(b%256);
85
        b>>=8;
86
        }
87
      start += mBucketBytes;
88
      }
89
    }
90

    
91
///////////////////////////////////////////////////////////////////////////////////////////////////
92

    
93
  private int retrieveData(int index, int entryInBucket)
94
    {
95
    int bits = entryInBucket*mNumBits;
96
    int bytes = index + bits/8;
97
    int ret=0;
98

    
99
    switch(mNumBits)
100
      {
101
      case  4: ret = (bits%8)!=0 ? (mData[bytes]&0xf) : ((mData[bytes]>>>4)&0xf);
102
               break;
103
      case  8: ret = (mData[bytes]&0xff);
104
               break;
105
      case 12: ret = (bits%8)!=0 ? ((mData[bytes]&0xf)*256 + (mData[bytes+1]&0xff)) : ( (mData[bytes]&0xff)*16 + ((mData[bytes+1]>>>4)&0xf));
106
               break;
107
      case 16: ret = (mData[bytes]&0xff)*256 + (mData[bytes+1]&0xff);
108
               break;
109
      }
110

    
111
    return ret;
112
    }
113

    
114
///////////////////////////////////////////////////////////////////////////////////////////////////
115

    
116
  private void insertData(int index, int entryInBucket, int value)
117
    {
118
    int bits = entryInBucket*mNumBits;
119
    int bytes = index + bits/8;
120

    
121
    switch(mNumBits)
122
      {
123
      case  4: if( (bits%8) !=0 ) mData[bytes] = (byte)((mData[bytes]&(~0xf)) + value  );
124
               else               mData[bytes] = (byte)((value<<4) + (mData[bytes]&0xf));
125
               break;
126
      case  8: mData[bytes] = (byte)value;
127
               break;
128
      case 12: if( (bits%8) !=0 )
129
                 {
130
                 mData[bytes]  = (byte)((mData[bytes]&(~0xf)) + (value>>>8));
131
                 mData[bytes+1]= (byte)value;
132
                 }
133
               else
134
                 {
135
                 mData[bytes]  = (byte)(value>>>4);
136
                 mData[bytes+1]= (byte)((byte)(value<<4) + (mData[bytes+1]&( 0xf)));
137
                 }
138
               break;
139
      case 16: mData[bytes]   = (byte)(value>>>8);
140
               mData[bytes+1] = (byte)(value&(0xff));
141
      }
142
    }
143

    
144
///////////////////////////////////////////////////////////////////////////////////////////////////
145

    
146
  private void writeBuckets(Tablebase table)
147
    {
148
    int mask = (1<<mNumBits);
149
    int prevBucket = -1;
150
    int entryInBucket = 0;
151
    int index = 0;
152
    int size = table.getSize();
153

    
154
    for(int i=0; i<size; i++)
155
      if( table.retrieveUnpacked(i)==mLevel )
156
        {
157
        int toBeWritten = (i%mask);
158
        int currBucket = (i>>mNumBits);
159

    
160
        if( currBucket!=prevBucket )
161
          {
162
          index = getBucketOffset(currBucket);
163
          entryInBucket = 0;
164
          }
165
        else
166
          {
167
          entryInBucket ++;
168
          }
169

    
170
        prevBucket = currBucket;
171
        insertData(index,entryInBucket,toBeWritten);
172
        }
173
    }
174

    
175
///////////////////////////////////////////////////////////////////////////////////////////////////
176
// PUBLIC API
177
///////////////////////////////////////////////////////////////////////////////////////////////////
178
// only numBits = 4,8,12,16 actually supported
179

    
180
  PruningTable(Tablebase table, int level, int numBits)
181
    {
182
    mNumBits = numBits;
183
    mLevel = level;
184
    int size = table.getSize();
185
    int entrySize = (1<<numBits);
186
    mNumBuckets = (size+entrySize-1)/entrySize;
187
    int[] bucket = new int[mNumBuckets];
188

    
189
    for(int i=0; i<size; i++)
190
      if( table.retrieveUnpacked(i)==level )
191
        {
192
        int currBucket = i/entrySize;
193
        bucket[currBucket]++;
194
        }
195

    
196
    int totalBytesForTable = 0;
197

    
198
    for(int i=0; i<mNumBuckets; i++)
199
      {
200
      int numBytes = (bucket[i]*numBits+7)/8;
201
      bucket[i] = totalBytesForTable;
202
      totalBytesForTable+=numBytes;
203
      }
204

    
205
    mBucketBytes = numBytesForIndices(totalBytesForTable);
206
    int totalBytesForBuckets = mNumBuckets*mBucketBytes;
207
    int totalSize = HEADER_SIZE + totalBytesForBuckets + totalBytesForTable;
208
    mData = new byte[totalSize];
209

    
210
    writeHeader();
211
    writeBucketPointers(bucket);
212
    writeBuckets(table);
213
    }
214

    
215
///////////////////////////////////////////////////////////////////////////////////////////////////
216

    
217
  PruningTable(byte[] data)
218
    {
219
    mData = data;
220

    
221
    mLevel       = ((int)mData[0])&0xff;
222
    mBucketBytes = ((int)mData[1])&0xff;
223
    mNumBits     = ((int)mData[2])&0xff;
224
    mNumBuckets  =(((int)mData[3])&0xff)*256 + ((int)mData[4])&0xff;
225
    }
226

    
227
///////////////////////////////////////////////////////////////////////////////////////////////////
228

    
229
  byte[] getPacked()
230
    {
231
    return mData;
232
    }
233

    
234
///////////////////////////////////////////////////////////////////////////////////////////////////
235

    
236
  int getLevel()
237
    {
238
    return mLevel;
239
    }
240

    
241
///////////////////////////////////////////////////////////////////////////////////////////////////
242

    
243
  boolean belongs(int number)
244
    {
245
    int bucket = (number>>mNumBits);
246
    int offset1 = getBucketOffset(bucket);
247
    int offset2 = getBucketOffset(bucket+1);
248
    int numBytesInBucket = offset2-offset1;
249
    int numEntriesInBucket = (8*numBytesInBucket)/mNumBits;
250

    
251
    if( mNumBits==4 && numBytesInBucket>0 )
252
      {
253
      byte val = mData[offset2-1];
254
      if( (val/16)*16 == val ) numEntriesInBucket--;
255
      }
256

    
257
    int lower = 0;
258
    int upper = numEntriesInBucket-1;
259

    
260
    while( upper>=lower )
261
      {
262
      int mid = (lower+upper)/2;
263
      int value = retrieveData(offset1,mid);
264

    
265
      value += (bucket<<mNumBits);
266
      if( value==number ) return true;
267

    
268
      if( value>number ) upper = mid-1;
269
      else               lower = mid+1;
270
      }
271

    
272
    return false;
273
    }
274
}
(2-2/12)