# Copyright (c) 2025 Thomas Goirand <zigo@debian.org>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from oslo_policy import policy as oslo_policy
from oslo_context import context as oslo_context
from flask import request, jsonify
from functools import wraps
import logging

# Import your policies explicitly
from vmms.policy import rules as vmms_rules

# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)

# Cached policy enforcer shared by every function in this module.
# Stays None until init_enforcer() has created it.
_ENFORCER = None

def init_enforcer(CONF):
    """Create, populate and cache the module-wide policy enforcer.

    If an enforcer is already cached, it is returned unchanged and CONF is
    ignored. Otherwise a new ``oslo_policy.Enforcer`` is built from CONF,
    every rule returned by ``vmms_rules.list_rules()`` is registered as a
    default, and the rules are loaded before the enforcer is cached.

    :param CONF: configuration object handed to ``oslo_policy.Enforcer``.
    :returns: the cached enforcer instance.
    :raises Exception: whatever the enforcer raises while registering or
        loading rules; in that case nothing is cached, so a later call
        retries the whole initialization.
    """
    global _ENFORCER
    if _ENFORCER is not None:
        return _ENFORCER

    LOG.debug("🐛 Initializing policy enforcer")
    # Work on a local name and publish it only once setup has fully
    # succeeded; otherwise a failure below would leave a half-configured
    # enforcer cached and silently reused by every later call.
    enforcer = oslo_policy.Enforcer(CONF)

    # Explicitly register the VMMS policy defaults.
    LOG.debug("🐛 Registering VMMS policy rules")
    for rule in vmms_rules.list_rules():
        enforcer.register_default(rule)

    # Force load the rules now rather than on first enforcement.
    enforcer.load_rules()

    LOG.debug("🐛 Registered %d policy rules", len(enforcer.registered_rules))

    _ENFORCER = enforcer
    return _ENFORCER

def get_enforcer(CONF):
    """Return the cached policy enforcer, creating it from CONF if needed."""
    if _ENFORCER:
        # Already initialized: CONF is not consulted again.
        return _ENFORCER
    return init_enforcer(CONF)

def enforce_policy(CONF, action, target=None):
    """Enforce a policy rule for the current Flask request.

    Credentials are built from the ``X-User-Id``, ``X-Project-Id`` and
    ``X-Roles`` request headers (expected to be set by keystonemiddleware)
    and checked against ``action`` with the shared policy enforcer.

    :param CONF: configuration object used to obtain the enforcer.
    :param action: name of the policy rule to check.
    :param target: dict describing the object acted upon. When ``None``,
        a target holding the caller's own ``project_id`` and ``user_id``
        (empty string when the header is absent) is used.
    :returns: the truthy result of ``enforcer.enforce()``.
    :raises oslo_policy.PolicyNotAuthorized: when the check is not passed.
    :raises Exception: any other error raised by the enforcer is logged
        with its traceback and re-raised unchanged.
    """
    # X-Roles is a comma-separated list. Strip whitespace and drop blank
    # entries so a trailing or doubled comma does not yield an empty role.
    roles_header = request.headers.get('X-Roles', '')
    roles = [
        name.strip() for name in roles_header.split(',') if name.strip()
    ]

    ctx = oslo_context.RequestContext(
        user_id=request.headers.get('X-User-Id'),
        project_id=request.headers.get('X-Project-Id'),
        roles=roles,
        is_admin='admin' in roles,
    )

    enforcer = get_enforcer(CONF)

    # Default target: the caller's own project and user.
    if target is None:
        target = {
            'project_id': ctx.project_id or '',
            'user_id': ctx.user_id or '',
        }

    # Credentials in the form the policy engine expects.
    creds = ctx.to_policy_values()

    try:
        result = enforcer.enforce(action, target, creds)

        # enforce() is called without do_raise, so a denial comes back as
        # a falsy value; convert it into the exception callers expect.
        if not result:
            raise oslo_policy.PolicyNotAuthorized(action, target, creds)

        return result
    except oslo_policy.PolicyNotAuthorized:
        # A plain denial is not an error worth a traceback; pass it on.
        raise
    except Exception as e:
        LOG.exception("⧱ Policy enforcement failed for %s: %s", action, e)
        raise

def require_policy_factory(get_config_func):
    """Build a ``require_policy`` decorator bound to a config getter.

    :param get_config_func: zero-argument callable returning the
        configuration object; it is called on every request, right before
        the policy check.
    :returns: ``require_policy(action)``, a decorator factory for Flask
        view functions.

    The decorated view responds as follows before the wrapped function is
    ever called:

    * 401 when the ``X-Identity-Status`` header is missing or is not
      ``Confirmed`` (case-insensitive);
    * 403 when the policy check raises ``PolicyNotAuthorized``;
    * 500 when a ``KeyError`` or any other unexpected error occurs while
      obtaining the configuration or enforcing the policy.
    """
    def require_policy(action):
        """Decorator to enforce policy ``action`` on an API endpoint."""
        def decorator(f):
            @wraps(f)
            def wrapper(*args, **kwargs):
                # Check authentication first.
                identity_status = request.headers.get('X-Identity-Status')
                if not identity_status or identity_status.upper() != 'CONFIRMED':
                    return jsonify({'error': 'Authentication required'}), 401

                try:
                    CONF = get_config_func()
                    enforce_policy(CONF, action)
                except oslo_policy.PolicyNotAuthorized:
                    # Genuine denial by the policy engine.
                    return jsonify({'error': 'Forbidden'}), 403
                except KeyError as e:
                    LOG.error(f"⧱ Policy key error: {e}")
                    return jsonify({'error': 'Policy configuration error'}), 500
                except Exception:
                    # Anything else is a server-side failure, not a denial:
                    # log it instead of disguising it as "Forbidden". The
                    # request is still rejected, so this fails closed.
                    LOG.exception(
                        "⧱ Unexpected error while enforcing policy %s", action
                    )
                    return jsonify({'error': 'Internal server error'}), 500

                return f(*args, **kwargs)
            return wrapper
        return decorator
    return require_policy
