#include "stdafx.h"

// Nominal FM bit cell time in encoder ticks. At 5ns per tick, 640 ticks is a
// 3.2us cell at the 360 RPM output speed, which is the same angular length as
// a 4us cell at 288 RPM (4us * 288 / 360 = 3.2us).
constexpr uint32_t kNominalFMBitCellTime = 640;	// 4us @ 288 RPM
//const uint32_t kNominalFMBitCellTime = 614;	// 4us @ 300 RPM

// Builds the flux-transition timeline for a single sector.
//
// All times are in the same tick units as the bit cell time and are relative
// to the start of the sector. Every encoded bit spans two bit cells: a clock
// cell starting at mTime and a data cell starting at mTime + mBitCellTime. A
// transition timestamp is appended to mStream for each set clock/data bit, so
// (for a non-zero bit cell time) mStream stays in ascending order and can be
// binary-searched.
class SectorEncoder {
public:
	SectorEncoder() = default;

	// Sets the length of one bit cell in ticks; call before encoding anything.
	void SetBitCellTime(uint32_t bct) {
		mBitCellTime = bct;
	}

	// Records the current encode position as the start of the critical region.
	void BeginCritical() {
		mCriticalStart = mTime;
	}

	// Records the current encode position as the end of the critical region.
	void EndCritical() {
		mCriticalEnd = mTime;
	}

	// Encodes an FM data byte with the regular all-ones clock pattern.
	void EncodeByteFM(uint8_t v) {
		EncodeByteFM(0xFF, v);
	}

	// Encodes an FM byte with an explicit clock pattern (used for address
	// marks, which have some clock bits missing).
	void EncodeByteFM(uint8_t clock, uint8_t data) {
		EmitBits(clock, data);
	}

	// Encodes one byte time (16 bit cells) of transitions spaced ~1.5 cells
	// apart rather than on the regular cell grid; presumably intended to read
	// back inconsistently ("weak" bits).
	void EncodeWeakByteFM() {
		EmitWeakByte();
	}

	// Encodes an MFM data byte with no clock bits suppressed.
	void EncodeByteMFM(uint8_t v) {
		EncodeByteMFM(0xFF, v);
	}

	// Encodes an MFM byte. A clock transition is written only between two 0
	// data bits, including the boundary with the last bit of the previous MFM
	// byte. clock_mask clears additional clock bits to form sync marks
	// (e.g. 0xFB with data 0xA1).
	void EncodeByteMFM(uint8_t clock_mask, uint8_t data) {
		const uint8_t clock = (uint8_t)(~((data >> 1) | (mLastByte << 7) | data) & clock_mask);
		mLastByte = data;

		EmitBits(clock, data);
	}

	// MFM counterpart of EncodeWeakByteFM(); emits the identical pattern and
	// does not update the previous-byte state used for MFM clocking.
	void EncodeWeakByteMFM() {
		EmitWeakByte();
	}

	std::vector<uint32_t> mStream;		// transition times relative to sector start
	uint32_t mTime = 0;					// current encode position (ticks)
	uint32_t mCriticalStart = 0;		// set by BeginCritical()
	uint32_t mCriticalEnd = 0;			// set by EndCritical()
	uint32_t mBitCellTime = 0;			// ticks per bit cell
	uint8_t mLastByte = 0;				// previous MFM data byte, for clock generation

private:
	// Emits 8 bits MSB-first: a transition at the start of the clock cell for
	// each set clock bit and at the start of the data cell for each set data bit.
	void EmitBits(uint8_t clock, uint8_t data) {
		for(int bit = 7; bit >= 0; --bit) {
			if (clock & (1 << bit))
				mStream.push_back(mTime);

			if (data & (1 << bit))
				mStream.push_back(mTime + mBitCellTime);

			mTime += mBitCellTime * 2;
		}
	}

	// Emits 10 transitions alternately floor(1.5) and ceil(1.5) cells apart
	// (5 x 3 cells), then advances one more cell so the byte spans 16 cells.
	// Transitions are stamped with the current time (mTime), keeping mStream ordered.
	void EmitWeakByte() {
		for(int i=0; i<5; ++i) {
			mStream.push_back(mTime);
			mTime += (mBitCellTime * 3) >> 1;

			mStream.push_back(mTime);
			mTime += (mBitCellTime * 3 + 1) >> 1;
		}

		mTime += mBitCellTime;
	}
};

// One placement of an encoded sector on the output timeline. The sector's
// encoded stream (time 0) is placed at mPosition; only the portion in
// [mEncodeStart, mEncodeEnd) is actually emitted, which lets overlapping
// neighbouring copies be trimmed.
struct SectorCopy {
	const SectorInfo *mpSector = nullptr;			// source sector
	const SectorEncoder *mpEncodedSector = nullptr;	// its encoded flux stream
	uint32_t mPosition = 0;		// absolute time of the encoded stream's time 0
	uint32_t mEncodeStart = 0;	// absolute start of the emitted portion
	uint32_t mEncodeEnd = 0;	// absolute end (exclusive) of the emitted portion
};

void encode_track(RawTrack& dst, TrackInfo& src, int track, double periodMultiplier) {
	uint32_t bitCellTime = (uint32_t)(0.5 + kNominalFMBitCellTime * periodMultiplier);

	// check if we have MFM sectors
	bool mfm = false;

	for(auto it = src.mSectors.begin(), itEnd = src.mSectors.end();
		it != itEnd;
		++it)
	{
		if (it->mbMFM)
			mfm = true;
	}

	if (mfm)
		bitCellTime >>= 1;

	// We use 5ns encoding so as to be able to hit both KryoFlux (40ns) and SuperCard Pro (25ns).
	// Rotational speed is 360 RPM.
	dst.mSamplesPerRev = 200000000.0f / 6.0f;

	for(int i=0; i<6; ++i)
		dst.mIndexTimes.push_back(200000000 * (i + 1) / 6);

	// collect sectors to encode
	std::vector<SectorInfo *> sectors;
	sift_sectors(src, track, sectors);

	// find the lowest numbered sector and use that for the index mark
	SectorInfo *lowest_sec = nullptr;
	size_t numsecs = sectors.size();
	for(size_t i=0; i<numsecs; ++i) {
		if (!lowest_sec || sectors[i]->mIndex < lowest_sec->mIndex)
			lowest_sec = sectors[i];
	}

	// find the biggest gap and use that for the splice point
	SectorInfo *first_sec = nullptr;

	if (numsecs)
		first_sec = sectors[0];

	dst.mSpliceStart = -1;
	dst.mSpliceEnd = -1;

	// encode sectors to flux transition bitstreams
	std::vector<SectorEncoder> sector_encoders(sectors.size());

	// encode sector to FM or MFM
	for(size_t i=0; i<numsecs; ++i) {
		const SectorInfo& sec = *sectors[i];
		SectorEncoder& enc = sector_encoders[i];
		enc.SetBitCellTime(bitCellTime);

		bool first_sec = (&sec == lowest_sec);

		if (mfm) {
			for(int j=0; j<11; ++j)
				enc.EncodeByteMFM(0x00);

			enc.BeginCritical();
			enc.EncodeByteMFM(0x00);

			// write sector header
			uint8_t sechdr[10] = {
				0xA1,
				0xA1,
				0xA1,
				0xFE,
				(uint8_t)track,
				0,
				(uint8_t)sec.mIndex,
				(uint8_t)(sec.mSectorSize > 128 ? 1 : 0),
				0,
				0
			};

			const uint16_t crc = ComputeCRC(sechdr, 8);
			sechdr[8] = (uint8_t)(crc >> 8);
			sechdr[9] = (uint8_t)(crc >> 0);

			// first three bytes require special clocking, but are included in CRC
			enc.EncodeByteMFM(0xFB, 0xA1);
			enc.EncodeByteMFM(0xFB, 0xA1);
			enc.EncodeByteMFM(0xFB, 0xA1);

			for(int i=3; i<10; ++i)
				enc.EncodeByteMFM(sechdr[i]);

			for(int i=0; i<22; ++i)
				enc.EncodeByteMFM(0x4E);

			for(int i=0; i<12; ++i)
				enc.EncodeByteMFM(0x0D);

			// write DAM
			enc.EncodeByteMFM(0xFB, 0xA1);
			enc.EncodeByteMFM(0xFB, 0xA1);
			enc.EncodeByteMFM(0xFB, 0xA1);
			enc.EncodeByteMFM(sec.mAddressMark);

			// write payload
			for(uint32_t i=0; i<sec.mSectorSize; ++i)
				enc.EncodeByteMFM(~sec.mData[i]);

			// compute and write CRC
			const uint8_t secdhdr[4] = {
				0xA1,
				0xA1,
				0xA1,
				sec.mAddressMark,
			};

			uint16_t crc2 = ComputeCRC(secdhdr, 4);

			crc2 = ComputeInvertedCRC(sec.mData, sec.mSectorSize, crc2);

			if (sec.mRecordedCRC != sec.mComputedCRC)
				crc2 = ~crc2;

			enc.EncodeByteMFM((uint8_t)(crc2 >> 8));
			enc.EncodeByteMFM((uint8_t)(crc2 >> 0));

			// write trailer
			for(int i=0; i<24; ++i)
				enc.EncodeByteMFM(0x4E);
		} else {
			if (first_sec) {
				enc.BeginCritical();
				enc.EncodeByteFM(0x00);
				enc.EncodeByteFM(0xD7, 0xFC);
			}

			for(int j=0; j<4; ++j)
				enc.EncodeByteFM(0x00);

			if (!first_sec)
				enc.BeginCritical();
		
			enc.EncodeByteFM(0x00);
			enc.EncodeByteFM(0x00);
			enc.EncodeByteFM(0xC7, 0xFE);
		
			uint8_t sechdr[7]={
				0xFE,
				(uint8_t)track,
				0,
				(uint8_t)sec.mIndex,
				0,
				0,
				0
			};

			switch(sec.mSectorSize) {
				case 128:
				default:
					sechdr[4] = 0;
					break;

				case 256:
					sechdr[4] = 1;
					break;

				case 512:
					sechdr[4] = 2;
					break;

				case 1024:
					sechdr[4] = 3;
					break;
			}

			uint16_t crc = ComputeCRC(sechdr, 5);
			sechdr[5] = (uint8_t)(crc >> 8);
			sechdr[6] = (uint8_t)crc;

			for(int j=1; j<7; ++j)
				enc.EncodeByteFM(sechdr[j]);

			for(int j=0; j<17; ++j)
				enc.EncodeByteFM(0x00);

			uint8_t secdat[1024 + 3];
			secdat[0] = sec.mAddressMark;

			for(uint32_t j=0; j<sec.mSectorSize; ++j)
				secdat[j+1] = ~sec.mData[j];

			secdat[sec.mSectorSize + 1] = (uint8_t)(sec.mRecordedCRC >> 8);
			secdat[sec.mSectorSize + 2] = (uint8_t)sec.mRecordedCRC;

			enc.EncodeByteFM(0xC7, secdat[0]);

			// If this sector has a CRC error AND it's a long sector, don't bother writing
			// out the full sector to save room on the track.
			if (sec.mComputedCRC != sec.mRecordedCRC && sec.mSectorSize > 128) {
				for(uint32_t j=1; j<131; ++j) {
					if (sec.mWeakOffset >= 0 && j >= (uint32_t)sec.mWeakOffset+1)
						enc.EncodeWeakByteFM();
					else
						enc.EncodeByteFM(secdat[j]);
				}
			} else {
				for(uint32_t j=1; j<sec.mSectorSize + 3; ++j) {
					if (sec.mWeakOffset >= 0 && j >= (uint32_t)sec.mWeakOffset+1) {
						enc.EncodeWeakByteFM();
					} else
						enc.EncodeByteFM(secdat[j]);
				}
			}

			enc.EncodeByteFM(0x00);
			enc.EndCritical();

			for(int j=0; j<8; ++j)
				enc.EncodeByteFM(0x00);
		}
	}

	// create copies of all sectors
	std::vector<SectorCopy> copies;

	uint32_t encodingPosition = 0;

	if (lowest_sec)
		encodingPosition = (uint32_t)(0.5 + lowest_sec->mPosition * 200000000.0 / 6.0 / (double)bitCellTime) * bitCellTime;

	for(size_t i=0; i<numsecs; ++i) {
		const SectorInfo& sec = *sectors[i];
		const SectorEncoder& enc = sector_encoders[i];

		for(int j=0; j<7; ++j) {
			// we round off the position to bitcells to make the bitstream cleaner
			//uint32_t position = (uint32_t)(0.5 + (sec.mPosition + j) * 200000000.0 / 6.0 / (double)bitCellTime) * bitCellTime;
			uint32_t position = encodingPosition + (uint32_t)(0.5 + j * 200000000.0 / 6.0 / (double)bitCellTime) * bitCellTime;

			SectorCopy copy;

			copy.mpSector = &sec;
			copy.mpEncodedSector = &enc;
			copy.mPosition = position;
			copy.mEncodeStart = position;
			copy.mEncodeEnd = position + enc.mTime;

			A8RC_RT_ASSERT(enc.mTime >= enc.mCriticalEnd);

			copies.push_back(copy);

			if (!j && g_verbosity >= 1) {
				printf("Encoding track %2u, sector %2u at %.3f-%.3f (critical %.3f-%.3f)\n"
					, track
					, sec.mIndex
					, fmod(encodingPosition / (200000000.0 / 6.0), 1.0)
					, fmod((encodingPosition + enc.mTime) / (200000000.0 / 6.0), 1.0)
					, fmod((encodingPosition + enc.mCriticalStart) / (200000000.0 / 6.0), 1.0)
					, fmod((encodingPosition + enc.mCriticalEnd) / (200000000.0 / 6.0), 1.0));
			}
		}

		encodingPosition += enc.mTime;
	}

	// sort all copies by rotational position
	std::sort(copies.begin(), copies.end(), [](const SectorCopy& x, const SectorCopy& y) { return x.mPosition < y.mPosition; });

	// scan the copies and try to deal with overlaps
	std::set<std::pair<uint32_t, uint32_t>> reportedOverlaps;

	const size_t numcopies = copies.size();

	for(size_t i=1; i<numcopies; ++i) {
		SectorCopy& cp0 = copies[i-1];
		SectorCopy& cp1 = copies[i];

		if (cp0.mEncodeEnd > cp1.mEncodeStart) {
			uint32_t cut = cp1.mEncodeStart + ((cp0.mEncodeEnd - cp1.mEncodeStart) >> 1);
			uint32_t lo = cp0.mPosition + cp0.mpEncodedSector->mCriticalEnd;
			uint32_t hi = cp1.mPosition + cp1.mpEncodedSector->mCriticalStart;

			if (lo > hi) {
				if (reportedOverlaps.insert(std::make_pair((int)cp0.mpSector->mIndex, (int)cp1.mpSector->mIndex)).second) {
					printf("WARNING: Track %u, sectors %u and %u overlapped during encoding. Encoded track may not work.\n"
						, track
						, cp0.mpSector->mIndex
						, cp1.mpSector->mIndex);
				}
			} else {
				if (cut < lo)
					cut = lo;
				else if (cut > hi)
					cut = hi;
			}

			// trim both sectors
			cp0.mEncodeEnd = cut;
			cp1.mEncodeStart = cut;
		}
	}

	// encode unified bitstream
	uint32_t time_last = 0;

	for(size_t i=0; i<numcopies; ++i) {
		const SectorCopy& cp = copies[i];
		uint32_t sector_start = cp.mEncodeStart;

		// if sector start is with first rev, mark it as the splice point
		if (cp.mpSector == first_sec && cp.mPosition >= dst.mIndexTimes[0] && cp.mPosition < dst.mIndexTimes[1]) {
			dst.mSpliceStart = (int32_t)cp.mPosition;
			dst.mSpliceEnd = dst.mSpliceStart + (dst.mIndexTimes[1] - dst.mIndexTimes[0]);

			if (g_verbosity >= 2) {
				printf("Using [%u, %u] as the splice points for track\n", dst.mSpliceStart, dst.mSpliceEnd);
			}
		}

		// encode 1 bits until the sector starts, while we have room
		if (mfm) {
			while(sector_start - time_last > bitCellTime * 2) {
				dst.mTransitions.push_back(time_last);
				time_last += bitCellTime * 2;
			}
		} else {
			while(sector_start - time_last > bitCellTime) {
				dst.mTransitions.push_back(time_last);
				time_last += bitCellTime;
			}
		}

		// determine the range to copy over from the original encoded stream
		uint32_t xfer_start = cp.mEncodeStart - cp.mPosition;
		uint32_t xfer_end = cp.mEncodeEnd - cp.mPosition;

		A8RC_RT_ASSERT(xfer_start < 0x80000000U);
		A8RC_RT_ASSERT(xfer_end <= cp.mpEncodedSector->mTime);

		if (g_verbosity >= 2)
			printf("Encoding %u-%u of sector %u (critical %u-%u) to %u-%u\n", xfer_start, xfer_end, cp.mpSector->mIndex, cp.mpEncodedSector->mCriticalStart, cp.mpEncodedSector->mCriticalEnd, cp.mEncodeStart, cp.mEncodeEnd);

		if (xfer_end > xfer_start) {
			const auto& src_stream = cp.mpEncodedSector->mStream;
			auto xfer1 = std::lower_bound(src_stream.begin(), src_stream.end(), xfer_start);
			auto xfer2 = std::lower_bound(xfer1, src_stream.end(), xfer_end);

			for(auto it = xfer1; it != xfer2; ++it)
				dst.mTransitions.push_back(cp.mPosition + *it);
		}

		time_last = cp.mEncodeEnd;
	}
}

void encode_disk(RawDisk& dst, DiskInfo& src, double periodMultiplier, int trackSelect) {
	// Encode every one of the 40 tracks, or only track `trackSelect` when it is non-negative.
	const int kTrackCount = 40;

	int track = 0;
	while(track < kTrackCount) {
		const bool wanted = trackSelect < 0 || track == trackSelect;

		if (wanted) {
			if (g_verbosity >= 1)
				printf("Encoding track %u\n", track);

			encode_track(dst.mTracks[track], src.mTracks[track], track, periodMultiplier);
		}

		++track;
	}
}
