Project

General

Profile

Download (11.1 KB) Statistics
| Branch: | Revision:

distorted-objectlib / src / main / java / org / distorted / objectlib / tablebases / PruningTable.java @ fb8eaaec

1
///////////////////////////////////////////////////////////////////////////////////////////////////
2
// Copyright 2023 Leszek Koltunski                                                               //
3
//                                                                                               //
4
// This file is part of Magic Cube.                                                              //
5
//                                                                                               //
6
// Magic Cube is proprietary software licensed under an EULA which you should have received      //
7
// along with the code. If not, check https://distorted.org/magic/License-Magic-Cube.html        //
8
///////////////////////////////////////////////////////////////////////////////////////////////////
9

    
10
package org.distorted.objectlib.tablebases;
11

    
12
///////////////////////////////////////////////////////////////////////////////////////////////////
13

    
14
class PruningTable
15
{
16
  static final int[] SUPPORTED = {4,8,12,16,20,24};
17

    
18
  private static final int HEADER_SIZE = 6;
19
  private byte[] mData;
20
  private int mNumBits;
21
  private int mBucketBytes;
22
  private int mNumBuckets;
23
  private int mLevel;
24

    
25
///////////////////////////////////////////////////////////////////////////////////////////////////
26

    
27
  private static String printByte(byte b)
28
    {
29
    return String.format("%8s", Integer.toBinaryString(b & 0xFF)).replace(' ', '0');
30
    }
31

    
32
///////////////////////////////////////////////////////////////////////////////////////////////////
33

    
34
  private int numBytesForIndices(int total)
35
    {
36
    int bytes = 0;
37

    
38
    while(total!=0)
39
      {
40
      total>>=8;
41
      bytes++;
42
      }
43

    
44
    return bytes;
45
    }
46

    
47
///////////////////////////////////////////////////////////////////////////////////////////////////
48

    
49
  private void writeHeader()
50
    {
51
    mData[0] = (byte)mLevel;
52
    mData[1] = (byte)mBucketBytes;
53
    mData[2] = (byte)mNumBits;
54
    mData[3] = (byte)(mNumBuckets>>>16);
55
    mData[4] = (byte)(mNumBuckets>>> 8);
56
    mData[5] = (byte)(mNumBuckets&0xff);
57
    }
58

    
59
///////////////////////////////////////////////////////////////////////////////////////////////////
60

    
61
  private int getBucketOffset(int bucket)
62
    {
63
    if( bucket>=mNumBuckets ) return mData.length;
64

    
65
    int ret = 0;
66
    int start = HEADER_SIZE + bucket*mBucketBytes;
67

    
68
    for(int i=0; i<mBucketBytes; i++)
69
      {
70
      ret<<=8;
71
      ret += ( ((int)(mData[start+i]))&0xff);
72
      }
73

    
74
    return HEADER_SIZE+mNumBuckets*mBucketBytes+ret;
75
    }
76

    
77
///////////////////////////////////////////////////////////////////////////////////////////////////
78

    
79
  private void writeBucketPointers(int[] bucket)
80
    {
81
    int start=HEADER_SIZE;
82

    
83
    for( int b : bucket )
84
      {
85
      for( int i=0; i<mBucketBytes; i++)
86
        {
87
        mData[start+mBucketBytes-1-i] = (byte)(b%256);
88
        b>>=8;
89
        }
90
      start += mBucketBytes;
91
      }
92
    }
93

    
94
///////////////////////////////////////////////////////////////////////////////////////////////////
95

    
96
  private int retrieveData(int index, int entryInBucket)
97
    {
98
    int bits = entryInBucket*mNumBits;
99
    int bytes = index + bits/8;
100
    int ret=0;
101

    
102
    switch(mNumBits)
103
      {
104
      case  4: ret = (bits%8)!=0 ? (mData[bytes]&0xf) : ((mData[bytes]>>>4)&0xf);
105
               break;
106
      case  8: ret = (mData[bytes]&0xff);
107
               break;
108
      case 12: ret = (bits%8)!=0 ? (((mData[bytes]&0xf)<<8) + (mData[bytes+1]&0xff)) : (((mData[bytes]&0xff)<<4) + ((mData[bytes+1]>>>4)&0xf));
109
               break;
110
      case 16: ret = ((mData[bytes]&0xff)<<8) + (mData[bytes+1]&0xff);
111
               break;
112
      case 20: if( (bits%8) !=0 ) ret = (((mData[bytes]&0x0f)<<16) + ((mData[bytes+1]&0xff)<<8) + ( mData[bytes+2]     &0xff));
113
               else               ret = (((mData[bytes]&0xff)<<12) + ((mData[bytes+1]&0xff)<<4) + ((mData[bytes+2]>>>4)&0x0f));
114
               break;
115
      case 24: ret = ((mData[bytes]&0xff)<<16) + ((mData[bytes+1]&0xff)<<8) + (mData[bytes+2]&0xff);
116
               break;
117
      }
118

    
119
    return ret;
120
    }
121

    
122
///////////////////////////////////////////////////////////////////////////////////////////////////
123

    
124
  private void insertData(int index, int entryInBucket, int value)
125
    {
126
    int bits = entryInBucket*mNumBits;
127
    int bytes = index + bits/8;
128

    
129
    switch(mNumBits)
130
      {
131
      case  4: if( (bits%8) !=0 ) mData[bytes] = (byte)((mData[bytes]&(~0xf)) + value  );
132
               else               mData[bytes] = (byte)((value<<4) + (mData[bytes]&0xf));
133
               break;
134
      case  8: mData[bytes] = (byte)value;
135
               break;
136
      case 12: if( (bits%8) !=0 )
137
                 {
138
                 mData[bytes]  = (byte)((mData[bytes]&(~0xf)) + (value>>>8));
139
                 mData[bytes+1]= (byte)value;
140
                 }
141
               else
142
                 {
143
                 mData[bytes]  = (byte)(value>>>4);
144
                 mData[bytes+1]= (byte)((byte)(value<<4) + (mData[bytes+1]&0xf));
145
                 }
146
               break;
147
      case 16: mData[bytes]   = (byte)(value>>>8);
148
               mData[bytes+1] = (byte)(value&(0xff));
149
               break;
150
      case 20: if( (bits%8) !=0 )
151
                 {
152
                 mData[bytes]  = (byte)((mData[bytes]&(~0xf)) + (value>>>16));
153
                 mData[bytes+1]= (byte)(value>>>8);
154
                 mData[bytes+2]= (byte)(value&(0xff));
155
                 }
156
               else
157
                 {
158
                 mData[bytes]  = (byte)(value>>>12);
159
                 mData[bytes+1]= (byte)(value>>>4);
160
                 mData[bytes+2]= (byte)((byte)(value<<4) + (mData[bytes+2]&0xf));
161
                 }
162
               break;
163
      case 24: mData[bytes]   = (byte)(value>>>16);
164
               mData[bytes+1] = (byte)(value>>>8);
165
               mData[bytes+2] = (byte)(value&(0xff));
166
               break;
167
      }
168
    }
169

    
170
///////////////////////////////////////////////////////////////////////////////////////////////////
171

    
172
  private void writeBuckets(Tablebase table)
173
    {
174
    int mask = (1<<mNumBits);
175
    int prevBucket = -1;
176
    int entryInBucket = 0;
177
    int index = 0;
178
    int size = table.getSize();
179

    
180
    for(int i=0; i<size; i++)
181
      if( table.retrieveUnpacked(i)==mLevel )
182
        {
183
        int toBeWritten = (i%mask);
184
        int currBucket = (i>>mNumBits);
185

    
186
        if( currBucket!=prevBucket )
187
          {
188
          index = getBucketOffset(currBucket);
189
          entryInBucket = 0;
190
          }
191
        else
192
          {
193
          entryInBucket ++;
194
          }
195

    
196
        prevBucket = currBucket;
197
        insertData(index,entryInBucket,toBeWritten);
198
        }
199
    }
200

    
201
///////////////////////////////////////////////////////////////////////////////////////////////////
202

    
203
  private int retNumPositions(Tablebase table, int level)
204
    {
205
    int size=0, tableSize = table.getSize();
206

    
207
    for(int i=0; i<tableSize; i++)
208
      if( table.retrieveUnpacked(i)==level ) size++;
209

    
210
    return size;
211
    }
212

    
213
///////////////////////////////////////////////////////////////////////////////////////////////////
214

    
215
  private int computeApproximateSize(int tableSize, int positionSize, int numBits)
216
    {
217
    int entrySize = (1<<numBits);
218
    int numBuckets = (tableSize+entrySize-1)/entrySize;
219
    int overhang = (numBits==4 || numBits==12) ? numBuckets/4 : 0;
220
    int totalBytesForTable = (positionSize*numBits)/8 + overhang;
221
    int bucketBytes = numBytesForIndices(totalBytesForTable);
222
    int totalBytesForBuckets = numBuckets*bucketBytes;
223

    
224
    return HEADER_SIZE + totalBytesForBuckets + totalBytesForTable;
225
    }
226

    
227
///////////////////////////////////////////////////////////////////////////////////////////////////
228

    
229
  private void construct(Tablebase table, int level, int numBits)
230
    {
231
    mNumBits = numBits;
232
    mLevel = level;
233
    int size = table.getSize();
234
    int entrySize = (1<<numBits);
235
    mNumBuckets = (size+entrySize-1)/entrySize;
236
    int[] bucket = new int[mNumBuckets];
237

    
238
    for(int i=0; i<size; i++)
239
      if( table.retrieveUnpacked(i)==level )
240
        {
241
        int currBucket = i/entrySize;
242
        bucket[currBucket]++;
243
        }
244

    
245
    int totalBytesForTable = 0;
246

    
247
    for(int i=0; i<mNumBuckets; i++)
248
      {
249
      int numBytes = (bucket[i]*numBits+7)/8;
250
      bucket[i] = totalBytesForTable;
251
      totalBytesForTable+=numBytes;
252
      }
253

    
254
    mBucketBytes = numBytesForIndices(totalBytesForTable);
255
    int totalBytesForBuckets = mNumBuckets*mBucketBytes;
256
    int totalSize = HEADER_SIZE + totalBytesForBuckets + totalBytesForTable;
257
    mData = new byte[totalSize];
258

    
259
    android.util.Log.e("D", "Constructing a pruningTable, numBits="+numBits+" size "+totalSize);
260

    
261
    writeHeader();
262
    writeBucketPointers(bucket);
263
    writeBuckets(table);
264
    }
265

    
266
///////////////////////////////////////////////////////////////////////////////////////////////////
267
// PUBLIC API
268
///////////////////////////////////////////////////////////////////////////////////////////////////
269
// only numBits = 4,8,12,16,20,24 actually supported
270

    
271
  PruningTable(Tablebase table, int level, int numBits)
272
    {
273
    construct(table,level,numBits);
274
    }
275

    
276
///////////////////////////////////////////////////////////////////////////////////////////////////
277

    
278
  PruningTable(Tablebase table, int level)
279
    {
280
    int minimum = 0, chosen =  0;
281
    int size = retNumPositions(table,level);
282
    int tableSize = table.getSize();
283

    
284
    android.util.Log.e("D", "----- PRUNING TABLE LEVEL "+level+" ------");
285

    
286
    for(int i=0; i<SUPPORTED.length; i++)
287
      {
288
      int approx = computeApproximateSize( tableSize, size, SUPPORTED[i]);
289

    
290
      if( i==0 || approx<minimum )
291
        {
292
        minimum = approx;
293
        chosen  = i;
294
        }
295

    
296
      android.util.Log.e("D", "trying numBits: "+SUPPORTED[i]+" approx size: "+approx);
297
      }
298

    
299
    construct(table,level,SUPPORTED[chosen]);
300
    }
301

    
302
///////////////////////////////////////////////////////////////////////////////////////////////////
303

    
304
  PruningTable(byte[] data)
305
    {
306
    mData = data;
307

    
308
    mLevel       =  ((int)mData[0])&0xff;
309
    mBucketBytes =  ((int)mData[1])&0xff;
310
    mNumBits     =  ((int)mData[2])&0xff;
311
    mNumBuckets  =((((int)mData[3])&0xff)<<16) + ((((int)mData[4])&0xff)<<8) + (((int)mData[5])&0xff);
312
    }
313

    
314
///////////////////////////////////////////////////////////////////////////////////////////////////
315

    
316
  byte[] getPacked()
317
    {
318
    return mData;
319
    }
320

    
321
///////////////////////////////////////////////////////////////////////////////////////////////////
322

    
323
  int getLevel()
324
    {
325
    return mLevel;
326
    }
327

    
328
///////////////////////////////////////////////////////////////////////////////////////////////////
329

    
330
  boolean contains(int number)
331
    {
332
    int bucket = (number>>mNumBits);
333
    int offset1 = getBucketOffset(bucket);
334
    int offset2 = getBucketOffset(bucket+1);
335
    int numBytesInBucket = offset2-offset1;
336
    int numEntriesInBucket = (8*numBytesInBucket)/mNumBits;
337

    
338
    if( mNumBits==4 && numBytesInBucket>0 )
339
      {
340
      byte val = mData[offset2-1];
341
      if( (val/16)*16 == val ) numEntriesInBucket--;
342
      }
343

    
344
    int lower = 0;
345
    int upper = numEntriesInBucket-1;
346

    
347
    while( upper>=lower )
348
      {
349
      int mid = (lower+upper)/2;
350
      int value = retrieveData(offset1,mid);
351

    
352
      value += (bucket<<mNumBits);
353
      if( value==number ) return true;
354

    
355
      if( value>number ) upper = mid-1;
356
      else               lower = mid+1;
357
      }
358

    
359
    return false;
360
    }
361
}
(2-2/12)