//	Altirra - Atari 800/800XL/5200 emulator
//	Copyright (C) 2009-2014 Avery Lee
//
//	This program is free software; you can redistribute it and/or modify
//	it under the terms of the GNU General Public License as published by
//	the Free Software Foundation; either version 2 of the License, or
//	(at your option) any later version.
//
//	This program is distributed in the hope that it will be useful,
//	but WITHOUT ANY WARRANTY; without even the implied warranty of
//	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//	GNU General Public License for more details.
//
//	You should have received a copy of the GNU General Public License
//	along with this program; if not, write to the Free Software
//	Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

#include <stdafx.h>
#include <vd2/system/binary.h>
#include <vd2/system/error.h>
#include "scsi.h"
#include "scsidisk.h"
#include "idedisk.h"
#include "uirender.h"

// SCSI disk target that exposes an IATIDEDisk as a block device on an emulated
// SCSI bus. A command is started with BeginCommand() and then driven one phase
// at a time by AdvanceCommand(); AbortCommand() drops back to idle.
class ATSCSIDiskDevice : public vdrefcounted<IATSCSIDiskDevice> {
public:
	ATSCSIDiskDevice();

	void *AsInterface(uint32 iid);

	// Returns the backing disk stored by Init() (null if Init() was not called).
	IATIDEDisk *GetDisk() const { return mpDisk; }

	// Sets the backing disk image; it is held in a vdrefptr member.
	void Init(IATIDEDisk *disk);

	// Sets the optional renderer used to show disk activity; may be null.
	void SetUIRenderer(IATUIRenderer *r);

	virtual void Attach(ATSCSIBusEmulator *bus);
	virtual void Detach();

	virtual void BeginCommand(const uint8 *command, uint32 length);
	virtual void AdvanceCommand();
	virtual void AbortCommand();

protected:
	// Phase of the command in progress; AdvanceCommand() dispatches on this.
	enum State {
		kState_None,			// idle: AdvanceCommand() does nothing
		kState_RequestSense_0,	// send the 4 sense bytes
		kState_Read_0,			// read one block and send it to the initiator
		kState_Write_0,			// request one block from the initiator
		kState_Write_1,			// write the received block to the disk
		kState_UnknownCommand,	// unsupported opcode: go straight to status
		kState_Status,			// send the status bytes
		kState_Complete			// end the command on the bus
	};

	// Decodes LUN, block address and block count from a group 0 command
	// (opcode bits 7-5 == 0).
	void DecodeCmdGroup0(const uint8 *command);

	ATSCSIBusEmulator *mpBus;		// bus set by Attach(), or null
	IATUIRenderer *mpUIRenderer;	// optional activity indicator, or null
	vdrefptr<IATIDEDisk> mpDisk;	// backing disk image

	State mState;
	uint8 mLUN;						// logical unit from the last group 0 command
	uint32 mLBA;					// next block address to transfer
	uint32 mBlockCount;				// blocks remaining in the current transfer
	bool mbClearErrorNextCommand;	// set by request sense; next BeginCommand() clears mError
	uint8 mError;					// sense error code; nonzero sets bit 0x02 in the status byte
	uint32 mErrorLBA;				// block address returned by request sense

	uint8 mTransferBuffer[512];		// one-block buffer used by the data and status phases
};

// Creates an idle device with no bus, no disk and no pending error.
// Every scalar member is given a defined value. mLUN, mLBA and mBlockCount were
// previously left indeterminate until the first group 0 command was decoded.
// The order matches the declaration order, so the compiler does not warn about reordering.
ATSCSIDiskDevice::ATSCSIDiskDevice()
	: mpBus(nullptr)
	, mpUIRenderer(nullptr)
	, mpDisk()
	, mState(kState_None)
	, mLUN(0)
	, mLBA(0)
	, mBlockCount(0)
	, mbClearErrorNextCommand(false)
	, mError(0)
	, mErrorLBA(0)
{
}

// Interface query. Only IATSCSIDevice is exposed by this class; any other
// interface ID yields null.
void *ATSCSIDiskDevice::AsInterface(uint32 iid) {
	if (iid != IATSCSIDevice::kTypeID)
		return nullptr;

	return static_cast<IATSCSIDevice *>(this);
}

// Stores the backing disk image that read/write commands operate on.
void ATSCSIDiskDevice::Init(IATIDEDisk *disk) {
	this->mpDisk = disk;
}

// Sets (or clears, with null) the renderer that is notified of disk activity.
void ATSCSIDiskDevice::SetUIRenderer(IATUIRenderer *r) {
	this->mpUIRenderer = r;
}

// Records the bus that data transfers and command completion are sent to.
void ATSCSIDiskDevice::Attach(ATSCSIBusEmulator *bus) {
	this->mpBus = bus;
}

void ATSCSIDiskDevice::Detach() {
	mpBus = nullptr;
}

void ATSCSIDiskDevice::BeginCommand(const uint8 *command, uint32 length) {
	if (mbClearErrorNextCommand) {
		mbClearErrorNextCommand = false;

		mError = 0;
	}

	switch(command[0] & 0xe0) {
		case 0x00:
			DecodeCmdGroup0(command);
			break;
	}

	switch(command[0]) {
		case 0x03:	// request sense
			mState = kState_RequestSense_0;

		case 0x08:	// read
			mState = kState_Read_0;
			break;

		case 0x0A:	// write
			mState = kState_Write_0;
			break;

		default:
			mState = kState_UnknownCommand;
			break;
	}
}

// Runs the next phase of the current command, as selected by mState.
// Each call does at most one of the following, then moves mState on:
//  - one bus data/status transfer, or
//  - a disk write, or
//  - ending the command.
// Failures are recorded in mError and mErrorLBA. They make the status phase
// report an error, and request sense returns them to the initiator.
void ATSCSIDiskDevice::AdvanceCommand() {
	switch(mState) {
		case kState_RequestSense_0:
			// The sense data is being reported; clear it when the next command begins.
			mbClearErrorNextCommand = true;

			mTransferBuffer[0] = mError;
			mTransferBuffer[1] = (uint8)(mErrorLBA >> 16);
			mTransferBuffer[2] = (uint8)(mErrorLBA >> 8);
			mTransferBuffer[3] = (uint8)mErrorLBA;

			mpBus->CommandSendData(ATSCSIBusEmulator::kSendMode_DataIn, mTransferBuffer, 4);
			// NOTE(review): unlike the other commands this skips kState_Status;
			// confirm that is intended for request sense.
			mState = kState_Complete;
			break;

		case kState_Read_0:
			if (mLBA >= mpDisk->GetSectorCount()) {
				mError = 0x21;		// Class 2 Illegal block address
				mErrorLBA = mLBA;
				mState = kState_Status;
			} else {
				try {
					if (mpUIRenderer)
						mpUIRenderer->SetIDEActivity(false, mLBA);

					mpDisk->ReadSectors(mTransferBuffer, mLBA, 1);

					// Advance only after the block was read, so that if the read
					// throws, mLBA still names the failing block.
					++mLBA;

					mpBus->CommandSendData(ATSCSIBusEmulator::kSendMode_DataIn, mTransferBuffer, 512);

					if (!--mBlockCount)
						mState = kState_Status;
				} catch(const MyError&) {
					mError = 0x21;		// Class 2 Illegal block address
					mErrorLBA = mLBA;
					mState = kState_Status;
				}
			}
			break;

		case kState_Write_0:
			if (mLBA >= mpDisk->GetSectorCount()) {
				mError = 0x21;		// Class 2 Illegal block address
				mErrorLBA = mLBA;
				mState = kState_Status;
			} else {
				// Ask the initiator for the next block; kState_Write_1 commits it.
				mpBus->CommandReceiveData(ATSCSIBusEmulator::kReceiveMode_DataOut, mTransferBuffer, 512);
				mState = kState_Write_1;
			}
			break;

		case kState_Write_1:
			try {
				if (mpUIRenderer)
					mpUIRenderer->SetIDEActivity(true, mLBA);

				mpDisk->WriteSectors(mTransferBuffer, mLBA, 1);

				// Advance only after a successful write (see the read path).
				++mLBA;

				if (--mBlockCount)
					mState = kState_Write_0;
				else
					mState = kState_Status;
			} catch(const MyError&) {
				mError = 0x11;		// Class 1 Uncorrectable data error
				mErrorLBA = mLBA;
				mState = kState_Status;
			}
			break;

		case kState_UnknownCommand:
			mState = kState_Status;
			break;

		case kState_Status:
			// Status byte (0x02 set on error) followed by a zero byte.
			mTransferBuffer[0] = 0x80 + (mError ? 0x02 : 0x00);
			mTransferBuffer[1] = 0x00;
			mpBus->CommandSendData(ATSCSIBusEmulator::kSendMode_Status, mTransferBuffer, 2);
			mState = kState_Complete;
			break;

		case kState_Complete:
			mpBus->CommandEnd();
			mState = kState_None;
			break;
	}
}

void ATSCSIDiskDevice::AbortCommand() {
	mState = kState_None;
}

void ATSCSIDiskDevice::DecodeCmdGroup0(const uint8 *command) {
	mLUN = command[1] >> 5;
	mLBA = VDReadUnalignedBEU32(command) & 0x1FFFFF;
	mBlockCount = command[4];
	if (!mBlockCount)
		mBlockCount = 256;
}

///////////////////////////////////////////////////////////////////////////

// Factory: builds a SCSI disk device backed by 'disk' and stores it in *dev.
// The reference held by the local smart pointer is released to the caller.
void ATCreateSCSIDiskDevice(IATIDEDisk *disk, IATSCSIDiskDevice **dev) {
	vdrefptr<ATSCSIDiskDevice> device(new ATSCSIDiskDevice);

	device->Init(disk);

	*dev = device.release();
}
